penguin_mux/
lib.rs

1//! Multiplexing streamed data and datagrams over a single WebSocket
2//! connection.
3//!
4//! This is not a general-purpose WebSocket multiplexing library.
5//! It is tailored to the needs of `penguin`.
6//
7// SPDX-License-Identifier: Apache-2.0 OR GPL-3.0-or-later
8#![deny(missing_docs, missing_debug_implementations)]
9
10pub mod config;
11mod dupe;
12pub mod frame;
13mod proto_version;
14mod stream;
15mod task;
16#[cfg(test)]
17mod tests;
18pub mod timing;
19pub mod ws;
20
21use crate::frame::{BindPayload, BindType, FinalizedFrame, Frame};
22use crate::task::{Task, TaskData};
23use crate::ws::WebSocket;
24use bytes::Bytes;
25use futures_util::task::AtomicWaker;
26use parking_lot::{Mutex, RwLock};
27use rand::distr::uniform::SampleUniform;
28use std::collections::HashMap;
29use std::future::poll_fn;
30use std::hash::Hash;
31use std::sync::Arc;
32use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
33use thiserror::Error;
34use tokio::sync::mpsc::error::TrySendError;
35use tokio::sync::{mpsc, oneshot};
36use tokio::task::JoinSet;
37use tracing::{error, trace, warn};
38
39pub use crate::dupe::Dupe;
40pub use crate::proto_version::{PROTOCOL_VERSION, PROTOCOL_VERSION_NUMBER};
41pub use crate::stream::MuxStream;
42
43/// Multiplexor error
44#[derive(Debug, Error)]
45#[non_exhaustive]
46pub enum Error {
47    /// Requester exited before receiving the stream
48    /// (i.e. the `Receiver` was dropped before the task could send the stream).
49    #[error("Requester exited before receiving the stream")]
50    SendStreamToClient,
51    /// The multiplexor is closed.
52    #[error("Mux is already closed")]
53    Closed,
54    /// The peer does not support the requested operation.
55    #[error("Peer does not support requested operation")]
56    PeerUnsupportedOperation,
57    /// This `Multiplexor` is not configured for this operation.
58    #[error("Unsupported operation")]
59    UnsupportedOperation,
60    /// Peer rejected the flow ID selection.
61    #[error("Peer rejected flow ID selection")]
62    FlowIdRejected,
63
64    /// WebSocket errors
65    #[error("WebSocket Error: {0}")]
66    WebSocket(Box<dyn std::error::Error + Send>),
67
68    // These are the ones that shouldn't normally happen
69    /// A `Datagram` frame with a target host longer than 255 octets.
70    #[error("Datagram target host longer than 255 octets")]
71    DatagramHostTooLong,
72    /// Received an invalid frame.
73    #[error("Invalid frame: {0}")]
74    InvalidFrame(#[from] frame::Error),
75    /// The peer sent a `Text` message.
76    /// "The client and server MUST NOT use other WebSocket data frame types"
77    #[error("Received `Text` message")]
78    TextMessage,
79    /// A `Acknowledge` frame that does not match any pending [`Connect`](frame::OpCode::Connect) request.
80    #[error("Bogus `Acknowledge` frame")]
81    ConnAckGone,
82    /// An internal channel closed
83    #[error("Internal channel `{0}` closed")]
84    ChannelClosed(&'static str),
85}
86
87/// A variant of [`std::result::Result`] with [`enum@Error`] as the error type.
88pub type Result<T> = std::result::Result<T, Error>;
89
90/// A multiplexor over a `WebSocket` connection.
91#[derive(Debug)]
92pub struct Multiplexor {
93    /// Open stream channels: `flow_id` -> `FlowSlot`
94    flows: Arc<RwLock<HashMap<u32, FlowSlot>>>,
95    /// Where tasks queue frames to be sent
96    tx_frame_tx: mpsc::UnboundedSender<FinalizedFrame>,
97    /// We only use this to inform the task that the multiplexor is closed
98    /// and it should stop processing.
99    dropped_ports_tx: mpsc::UnboundedSender<u32>,
100    /// Channel of received datagram frames for processing.
101    datagram_rx: Mutex<mpsc::Receiver<Datagram>>,
102    /// Channel for a `Multiplexor` to receive newly
103    /// established streams after the peer requests one.
104    con_recv_stream_rx: Mutex<mpsc::Receiver<MuxStream>>,
105    /// Channel for `Bnd` requests.
106    bnd_request_rx: Option<Mutex<mpsc::Receiver<BindRequest<'static>>>>,
107    /// Number of retries to find a suitable flow ID
108    /// See [`config::Options`] for more details.
109    max_flow_id_retries: usize,
110    /// Number of `StreamFrame`s to buffer in `MuxStream`'s channels before blocking
111    /// See [`config::Options`] for more details.
112    rwnd: u32,
113}
114
115impl Multiplexor {
116    /// Create a new `Multiplexor`.
117    ///
118    /// # Arguments
119    ///
120    /// * `ws`: The `WebSocket` connection to multiplex over.
121    ///
122    /// * `options`: Multiplexor options. See [`config::Options`] for more details.
123    ///   If `None`, the default options will be used.
124    ///
125    /// * `task_joinset`: A [`JoinSet`] to spawn the multiplexor task into so
126    ///   that the caller can notice if the task exits. If it is `None`, the
127    ///   task will be spawned by `tokio::spawn` and errors will be logged.
128    #[tracing::instrument(skip_all, level = "debug")]
129    pub fn new<S: WebSocket>(
130        ws: S,
131        options: Option<config::Options>,
132        task_joinset: Option<&mut JoinSet<Result<()>>>,
133    ) -> Self {
134        let options = options.unwrap_or_default();
135        let (datagram_tx, datagram_rx) = mpsc::channel(options.datagram_buffer_size);
136        let (con_recv_stream_tx, con_recv_stream_rx) = mpsc::channel(options.stream_buffer_size);
137        // This one is unbounded because the protocol provides its own flow control for `Push` frames
138        // and other frame types are to be immediately processed without any backpressure,
139        // so they are ok to be unbounded channels.
140        let (tx_frame_tx, tx_frame_rx) = mpsc::unbounded_channel();
141        // This one cannot be bounded because it needs to be used in Drop
142        let (dropped_ports_tx, dropped_ports_rx) = mpsc::unbounded_channel();
143
144        let (bnd_request_tx, bnd_request_rx) = if options.bind_buffer_size > 0 {
145            let (tx, rx) = mpsc::channel(options.bind_buffer_size);
146            (Some(tx), Some(rx))
147        } else {
148            (None, None)
149        };
150        let flows = Arc::new(RwLock::new(HashMap::new()));
151
152        let mux = Self {
153            tx_frame_tx: tx_frame_tx.dupe(),
154            flows: flows.dupe(),
155            dropped_ports_tx: dropped_ports_tx.dupe(),
156            datagram_rx: Mutex::new(datagram_rx),
157            con_recv_stream_rx: Mutex::new(con_recv_stream_rx),
158            bnd_request_rx: bnd_request_rx.map(Mutex::new),
159            max_flow_id_retries: options.max_flow_id_retries,
160            rwnd: options.rwnd,
161        };
162        let taskdata = TaskData {
163            task: Task {
164                ws: Mutex::new(ws),
165                tx_frame_tx,
166                flows,
167                dropped_ports_tx,
168                con_recv_stream_tx,
169                default_rwnd_threshold: options.default_rwnd_threshold,
170                rwnd: options.rwnd,
171                datagram_tx,
172                bnd_request_tx,
173                keepalive_interval: options.keepalive_interval,
174            },
175            dropped_ports_rx,
176            tx_frame_rx,
177        };
178        taskdata.spawn(task_joinset);
179        mux
180    }
181
182    /// Request a channel for `host` and `port`.
183    ///
184    /// # Arguments
185    /// * `host`: The host to forward to. While the current implementation
186    ///   supports a domain of arbitrary length, Section 3.2.2 of
187    ///   [RFC 3986](https://www.rfc-editor.org/rfc/rfc3986#section-3.2.2)
188    ///   specifies that the host component of a URI is limited to 255 octets.
189    /// * `port`: The port to forward to.
190    ///
191    /// # Cancel safety
192    /// This function is not cancel safe. If the task is cancelled while waiting
193    /// for the channel to be established, that channel may be established but
194    /// inaccessible through normal means. Subsequent calls to this function
195    /// will result in a new channel being established.
196    #[tracing::instrument(skip(self), level = "debug")]
197    pub async fn new_stream_channel(&self, host: &[u8], port: u16) -> Result<MuxStream> {
198        let mut retries_left = self.max_flow_id_retries;
199        // Normally this should terminate in one loop
200        while retries_left > 0 {
201            retries_left -= 1;
202            let (stream_tx, stream_rx) = oneshot::channel();
203            let flow_id = {
204                let mut streams = self.flows.write();
205                // Allocate a new port
206                let flow_id = u32::next_available_key(&*streams);
207                trace!("flow_id = {flow_id:08x}");
208                streams.insert(flow_id, FlowSlot::Requested(stream_tx));
209                flow_id
210            };
211            trace!("sending `Connect`");
212            self.tx_frame_tx
213                .send(Frame::new_connect(host, port, flow_id, self.rwnd).finalize())
214                .map_err(|_| Error::Closed)?;
215            trace!("sending stream to user");
216            let stream = stream_rx
217                .await
218                // Happens if the task exits before sending the stream,
219                // thus `Closed` is the correct error
220                .map_err(|_| Error::Closed)?;
221            if let Some(s) = stream {
222                return Ok(s);
223            }
224            // For testing purposes. Make sure the previous flow ID is gone
225            debug_assert!(!self.flows.read().contains_key(&flow_id));
226        }
227        Err(Error::FlowIdRejected)
228    }
229
230    /// Accept a new stream channel from the remote peer.
231    ///
232    /// # Errors
233    /// Returns [`Error::Closed`] if the connection is closed.
234    ///
235    /// # Cancel Safety
236    /// This function is cancel safe. If the task is cancelled while waiting
237    /// for a new connection, it is guaranteed that no connected stream will
238    /// be lost.
239    #[tracing::instrument(skip(self), level = "debug")]
240    pub async fn accept_stream_channel(&self) -> Result<MuxStream> {
241        poll_fn(|cx| self.con_recv_stream_rx.lock().poll_recv(cx))
242            .await
243            .ok_or(Error::Closed)
244    }
245
246    /// Get the next available datagram.
247    ///
248    /// # Errors
249    /// Returns [`Error::Closed`] if the connection is closed.
250    ///
251    /// # Cancel Safety
252    /// This function is cancel safe. If the task is cancelled while waiting
253    /// for a datagram, it is guaranteed that no datagram will be lost.
254    #[tracing::instrument(skip(self), level = "debug")]
255    #[inline]
256    pub async fn get_datagram(&self) -> Result<Datagram> {
257        poll_fn(|cx| self.datagram_rx.lock().poll_recv(cx))
258            .await
259            .ok_or(Error::Closed)
260    }
261
262    /// Send a datagram
263    ///
264    /// # Errors
265    /// * Returns [`Error::DatagramHostTooLong`] if the destination host is
266    /// longer than 255 octets.
267    /// * Returns [`Error::Closed`] if the Multiplexor is already closed.
268    ///
269    /// # Cancel Safety
270    /// This function is cancel safe. If the task is cancelled, it is
271    /// guaranteed that the datagram has not been sent.
272    #[tracing::instrument(skip(self), level = "debug")]
273    #[inline]
274    pub async fn send_datagram(&self, datagram: Datagram) -> Result<()> {
275        if datagram.target_host.len() > 255 {
276            return Err(Error::DatagramHostTooLong);
277        }
278        let frame = Frame::new_datagram_owned(
279            datagram.flow_id,
280            datagram.target_host,
281            datagram.target_port,
282            datagram.data,
283        );
284        self.tx_frame_tx
285            .send(frame.finalize())
286            .map_err(|_| Error::Closed)?;
287        Ok(())
288    }
289
290    /// Request a `Bind` for `host` and `port`.
291    ///
292    /// # Arguments
293    /// * `host`: The local address or host to bind to. Hostname resolution might
294    /// not be supported by the remote peer.
295    /// * `port`: The local port to bind to.
296    ///
297    /// # Cancel Safety
298    /// This function is not cancel safe. If the task is cancelled while waiting
299    /// for the peer to reply, the user will not be able to receive whether the
300    /// peer accepted the bind request.
301    #[tracing::instrument(skip(self), level = "debug")]
302    #[inline]
303    pub async fn request_bind(&self, host: &[u8], port: u16, bind_type: BindType) -> Result<bool> {
304        let (result_tx, result_rx) = oneshot::channel();
305        let flow_id = {
306            let mut streams = self.flows.write();
307            // Allocate a new port
308            let flow_id = u32::next_available_key(&*streams);
309            trace!("flow_id = {flow_id:08x}");
310            streams.insert(flow_id, FlowSlot::BindRequested(result_tx));
311            flow_id
312        };
313        let bnd_frame = Frame::new_bind(flow_id, bind_type, host, port).finalize();
314        self.tx_frame_tx
315            .send(bnd_frame)
316            .map_err(|_| Error::Closed)?;
317        let result = result_rx.await.map_err(|_| Error::Closed)?;
318        Ok(result)
319    }
320
321    /// Accept a `Bind` request from the remote peer.
322    ///
323    /// # Cancel Safety
324    /// This function is cancel safe. If the task is cancelled while waiting
325    /// for a `Bind` request, it is guaranteed that no request will be lost.
326    #[tracing::instrument(skip(self), level = "debug")]
327    pub async fn next_bind_request(&self) -> Result<BindRequest<'static>> {
328        if let Some(rx) = self.bnd_request_rx.as_ref() {
329            poll_fn(|cx| rx.lock().poll_recv(cx))
330                .await
331                .ok_or(Error::Closed)
332        } else {
333            Err(Error::UnsupportedOperation)
334        }
335    }
336}
337
338impl Drop for Multiplexor {
339    fn drop(&mut self) {
340        if self.dropped_ports_tx.send(0).is_err() {
341            error!("Failed to inform task of dropped multiplexor");
342        }
343    }
344}
345
346#[derive(Debug)]
347struct EstablishedStreamData {
348    /// Channel for sending data to `MuxStream`'s `AsyncRead`
349    /// If `None`, we have received `Finish` from the peer but we can possibly still send data.
350    sender: Option<mpsc::Sender<Bytes>>,
351    /// Whether writes should succeed.
352    /// There are two cases for `true`:
353    /// 1. `Finish` has been sent.
354    /// 2. The stream has been removed from `inner.streams`.
355    // In general, our `Atomic*` types don't need more than `Relaxed` ordering
356    // because we are not protecting memory accesses, but rather counting the
357    // frames we have sent and received.
358    finish_sent: Arc<AtomicBool>,
359    /// Number of `Push` frames we are allowed to send before waiting for a `Acknowledge` frame.
360    psh_send_remaining: Arc<AtomicU32>,
361    /// Waker to wake up the task that sends frames because their `psh_send_remaining`
362    /// has increased.
363    writer_waker: Arc<AtomicWaker>,
364}
365
366impl EstablishedStreamData {
367    /// Process a `Finish` frame from the peer and thus disallowing further `AsyncRead` operations
368    /// Returns the sender if it was not already taken.
369    #[inline]
370    const fn disallow_read(&mut self) -> Option<mpsc::Sender<Bytes>> {
371        self.sender.take()
372    }
373
374    /// Process a `Acknowledge` frame from the peer
375    #[inline]
376    fn acknowledge(&self, acknowledged: u32) {
377        // Atomic ordering: as long as the value is incremented atomically,
378        // whether a writer sees the new value or the old value is not
379        // important. If it sees the old value and decides to return
380        // `Poll::Pending`, it will be woken up by the `Waker` anyway.
381        self.psh_send_remaining
382            .fetch_add(acknowledged, Ordering::Relaxed);
383        // Wake up the writer if it is waiting for `Acknowledge`
384        self.writer_waker.wake();
385    }
386
387    /// Disallow any `AsyncWrite` operations.
388    /// Note that this should not be used from inside the `MuxStream` itself
389    #[inline]
390    fn disallow_write(&self) -> bool {
391        // Atomic ordering:
392        // Load part:
393        // If the user calls `poll_shutdown`, but we see `true` here,
394        // the other end will receive a bogus `Reset` frame, which is fine.
395        // Store part:
396        // We need to make sure the writer can see the new value
397        // before we call `wake()`.
398        let old = self.finish_sent.swap(true, Ordering::AcqRel);
399        // If there is a writer waiting for `Acknowledge`, wake it up because it will never receive one.
400        // Waking it here and the user should receive a `BrokenPipe` error.
401        self.writer_waker.wake();
402        old
403    }
404}
405
406#[derive(Debug)]
407enum FlowSlot {
408    /// A `Connect` frame was sent and waiting for the peer to `Acknowledge`.
409    Requested(oneshot::Sender<Option<MuxStream>>),
410    /// The stream is established.
411    Established(EstablishedStreamData),
412    /// A `Bind` request was sent and waiting for the peer to `Acknowledge` or `Reset`.
413    BindRequested(oneshot::Sender<bool>),
414}
415
416impl FlowSlot {
417    /// Take the sender and set the slot to `Established`.
418    /// Returns `None` if the slot is already established.
419    #[inline]
420    fn establish(
421        &mut self,
422        data: EstablishedStreamData,
423    ) -> Option<oneshot::Sender<Option<MuxStream>>> {
424        // Make sure it is not replaced in the error case
425        if matches!(self, Self::Established(_) | Self::BindRequested(_)) {
426            error!("establishing an established or invalid slot");
427            return None;
428        }
429        let sender = match std::mem::replace(self, Self::Established(data)) {
430            Self::Requested(sender) => sender,
431            Self::Established(_) | Self::BindRequested(_) => unreachable!(),
432        };
433        Some(sender)
434    }
435
436    /// If the slot is established, send data. Otherwise, return `None`.
437    #[inline]
438    fn dispatch(&self, data: Bytes) -> Option<std::result::Result<(), TrySendError<()>>> {
439        if let Self::Established(stream_data) = self {
440            let r = stream_data
441                .sender
442                .as_ref()
443                .map(|sender| sender.try_send(data))?
444                .map_err(|e| match e {
445                    TrySendError::Full(_) => TrySendError::Full(()),
446                    TrySendError::Closed(_) => TrySendError::Closed(()),
447                });
448            Some(r)
449        } else {
450            None
451        }
452    }
453}
454
455/// Datagram frame data
456#[derive(Clone, Debug)]
457pub struct Datagram {
458    /// Flow ID
459    pub flow_id: u32,
460    /// Target host
461    pub target_host: Bytes,
462    /// Target port
463    pub target_port: u16,
464    /// Data
465    pub data: Bytes,
466}
467
468/// A `Bind` request that the user can respond to
469#[derive(Debug)]
470pub struct BindRequest<'data> {
471    /// Flow ID
472    flow_id: u32,
473    /// Bind payload
474    payload: BindPayload<'data>,
475    /// Place to respond to the bind request
476    tx_frame_tx: mpsc::UnboundedSender<FinalizedFrame>,
477}
478
479impl BindRequest<'_> {
480    /// Get the flow ID of the bind request
481    #[inline]
482    pub const fn flow_id(&self) -> u32 {
483        self.flow_id
484    }
485
486    /// Get the bind type of the bind request
487    #[inline]
488    pub const fn bind_type(&self) -> BindType {
489        self.payload.bind_type
490    }
491
492    /// Get the host of the bind request
493    #[inline]
494    pub fn host(&self) -> &[u8] {
495        self.payload.target_host.as_ref()
496    }
497
498    /// Get the port of the bind request
499    #[inline]
500    pub const fn port(&self) -> u16 {
501        self.payload.target_port
502    }
503
504    /// Accept or reject the bind request
505    #[tracing::instrument(skip(self), level = "debug")]
506    pub fn reply(&self, accepted: bool) -> Result<()> {
507        if accepted {
508            self.tx_frame_tx
509                .send(Frame::new_finish(self.flow_id).finalize())
510        } else {
511            self.tx_frame_tx
512                .send(Frame::new_reset(self.flow_id).finalize())
513        }
514        .map_err(|_| Error::Closed)
515    }
516}
517
518impl Drop for BindRequest<'_> {
519    /// Dropping a `BindRequest` will reject the request
520    fn drop(&mut self) {
521        self.reply(false).ok();
522    }
523}
524
525/// Randomly generate a new number
526pub trait IntKey: Eq + Hash + Copy + SampleUniform + PartialOrd {
527    /// The minimum value of the key
528    const MIN: Self;
529    /// The maximum value of the key
530    const MAX: Self;
531
532    /// Generate a new key that is not in the map
533    #[inline]
534    #[must_use]
535    fn next_available_key<V>(map: &HashMap<Self, V>) -> Self {
536        loop {
537            let i = rand::random_range(Self::MIN..Self::MAX);
538            if !map.contains_key(&i) {
539                break i;
540            }
541        }
542    }
543}
544
545macro_rules! impl_int_key {
546    ($($t:ty),*) => {
547        $(
548            impl IntKey for $t {
549                // 0 is for special use
550                const MIN : Self = 1;
551                const MAX : Self = Self::MAX;
552            }
553        )*
554    };
555}
556
557impl_int_key!(u8, u16, u32, u64, u128, usize);