Skip to main content

penguin_mux/
lib.rs

1//! Multiplexing streamed data and datagrams over a single WebSocket
2//! connection.
3//!
4//! This is not a general-purpose WebSocket multiplexing library.
5//! It is tailored to the needs of `penguin`.
6//
7// SPDX-License-Identifier: Apache-2.0 OR GPL-3.0-or-later
8#![deny(rust_2018_idioms, missing_docs, missing_debug_implementations)]
9#![deny(clippy::pedantic, clippy::cargo, clippy::nursery, clippy::unwrap_used)]
10#![allow(clippy::multiple_crate_versions)]
11
12pub mod config;
13#[cfg(feature = "deadlock-detection")]
14pub mod deadlock_detection;
15mod dupe;
16pub mod frame;
17mod loom;
18mod proto_version;
19mod stream;
20mod task;
21#[cfg(test)]
22mod tests;
23pub mod timing;
24pub mod ws;
25
26use crate::frame::{BindPayload, BindType, FinalizedFrame, Frame};
27use crate::loom::{Arc, AtomicBool, AtomicU32, AtomicWaker, Mutex, Ordering, RwLock};
28use crate::task::{Task, TaskData};
29use crate::ws::WebSocket;
30use bytes::Bytes;
31use rand::distr::uniform::SampleUniform;
32use std::future::poll_fn;
33use std::hash::{BuildHasher, Hash};
34use std::time::Instant;
35use thiserror::Error;
36use tokio::sync::mpsc::error::TrySendError;
37use tokio::sync::{mpsc, oneshot, watch};
38use tokio::task::JoinSet;
39use tracing::{error, trace, warn};
40
41#[cfg(feature = "nohash")]
42use nohash_hasher::IntMap;
43#[cfg(not(feature = "nohash"))]
44use std::collections::HashMap as IntMap;
45
46pub use crate::dupe::Dupe;
47pub use crate::proto_version::{PROTOCOL_VERSION, PROTOCOL_VERSION_NUMBER};
48pub use crate::stream::MuxStream;
49
/// Errors produced by the multiplexor.
///
/// The enum is `#[non_exhaustive]`, so code matching on it from outside this
/// crate needs a wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
    /// Requester exited before receiving the stream
    /// (i.e. the `Receiver` was dropped before the task could send the stream)
    #[error("requester exited before receiving the stream")]
    SendStreamToClient,
    /// The multiplexor is closed
    #[error("mux is already closed")]
    Closed,
    /// The peer does not support the requested operation
    #[error("peer does not support requested operation")]
    PeerUnsupportedOperation,
    /// This `Multiplexor` is not configured for this operation
    #[error("unsupported operation")]
    UnsupportedOperation,
    /// Peer rejected the flow ID selection
    #[error("peer rejected flow ID selection")]
    FlowIdRejected,
    /// Keepalive timeout: no pong received within the allowed time
    #[error("`Pong` response not received before the configured timeout")]
    KeepaliveTimeout,

    /// An error reported by the underlying `WebSocket` implementation,
    /// boxed because its concrete type is not known to this crate
    #[error("unspecified `WebSocket` error: {0}")]
    WebSocket(Box<dyn std::error::Error + Send>),

    // These are the ones that shouldn't normally happen
    /// A `Datagram` frame with a target host longer than 255 octets
    #[error("datagram target host longer than 255 octets")]
    DatagramHostTooLong,
    /// Received an invalid frame.
    /// Converts automatically from [`frame::Error`] via `?`.
    #[error("invalid frame: {0}")]
    InvalidFrame(#[from] frame::Error),
    /// The peer sent a `Text` message
    /// "The client and server MUST NOT use other WebSocket data frame types"
    #[error("received `Text` message")]
    TextMessage,
    /// An `Acknowledge` frame that does not match any pending
    /// [`Connect`](frame::OpCode::Connect) request
    #[error("bogus `Acknowledge` frame")]
    ConnAckGone,
    /// An internal channel closed; the payload names the channel
    #[error("internal channel `{0}` closed unexpectedly")]
    ChannelClosed(&'static str),
}
96
/// Shorthand for [`std::result::Result`] whose error type is fixed to
/// this crate's [`enum@Error`].
pub type Result<T> = std::result::Result<T, Error>;
99
/// A multiplexor over a `WebSocket` connection.
///
/// This is the user-facing handle. The channel endpoints stored here are the
/// counterparts of the ones handed to the background task in
/// [`Multiplexor::new`].
#[derive(Debug)]
pub struct Multiplexor {
    /// Open stream channels: `flow_id` -> `FlowSlot`
    /// (shared with the background task)
    flows: Arc<RwLock<IntMap<u32, FlowSlot>>>,
    /// Where tasks queue frames to be sent
    tx_frame_tx: mpsc::UnboundedSender<FinalizedFrame>,
    /// We only use this to inform the task that the multiplexor is closed
    /// and it should stop processing. `Drop` sends the value `0` on it.
    dropped_ports_tx: mpsc::UnboundedSender<u32>,
    /// Channel of received datagram frames for processing.
    datagram_rx: Mutex<mpsc::Receiver<Datagram>>,
    /// Channel for a `Multiplexor` to receive newly
    /// established streams after the peer requests one.
    con_recv_stream_rx: Mutex<mpsc::Receiver<MuxStream>>,
    /// Channel for `Bnd` requests.
    /// `None` when the configured `bind_buffer_size` is zero, in which case
    /// [`Multiplexor::next_bind_request`] returns [`Error::UnsupportedOperation`].
    bnd_request_rx: Option<Mutex<mpsc::Receiver<BindRequest<'static>>>>,
    /// Number of retries to find a suitable flow ID
    /// See [`config::Options`] for more details.
    max_flow_id_retries: usize,
    /// Number of `StreamFrame`s to buffer in `MuxStream`'s channels before blocking.
    /// This value is included in every `Connect` frame we send.
    /// See [`config::Options`] for more details.
    rwnd: u32,
}
124
125impl Multiplexor {
126    /// Create a new `Multiplexor`.
127    ///
128    /// # Arguments
129    ///
130    /// * `ws`: The `WebSocket` connection to multiplex over.
131    ///
132    /// * `options`: Multiplexor options. See [`config::Options`] for more details.
133    ///   If `None`, the default options will be used.
134    ///
135    /// * `task_joinset`: A [`JoinSet`] to spawn the multiplexor task into so
136    ///   that the caller can notice if the task exits. If it is `None`, the
137    ///   task will be spawned by `tokio::spawn` and errors will be logged.
138    #[tracing::instrument(skip_all, level = "debug")]
139    pub fn new<S: WebSocket>(
140        ws: S,
141        options: Option<config::Options>,
142        task_joinset: Option<&mut JoinSet<Result<()>>>,
143    ) -> Self {
144        let options = options.unwrap_or_default();
145        let (datagram_tx, datagram_rx) = mpsc::channel(options.datagram_buffer_size);
146        let (con_recv_stream_tx, con_recv_stream_rx) = mpsc::channel(options.stream_buffer_size);
147        // This one is unbounded because the protocol provides its own flow control for `Push` frames
148        // and other frame types are to be immediately processed without any backpressure,
149        // so they are ok to be unbounded channels.
150        let (tx_frame_tx, tx_frame_rx) = mpsc::unbounded_channel();
151        // This one cannot be bounded because it needs to be used in Drop
152        let (dropped_ports_tx, dropped_ports_rx) = mpsc::unbounded_channel();
153        let (last_pong_timestamp_tx, last_pong_timestamp_rx) = watch::channel(Instant::now());
154
155        let (bnd_request_tx, bnd_request_rx) = if options.bind_buffer_size > 0 {
156            let (tx, rx) = mpsc::channel(options.bind_buffer_size);
157            (Some(tx), Some(rx))
158        } else {
159            (None, None)
160        };
161        let flows = Arc::new(RwLock::new(IntMap::default()));
162
163        let mux = Self {
164            tx_frame_tx: tx_frame_tx.dupe(),
165            flows: flows.dupe(),
166            dropped_ports_tx: dropped_ports_tx.dupe(),
167            datagram_rx: Mutex::new(datagram_rx),
168            con_recv_stream_rx: Mutex::new(con_recv_stream_rx),
169            bnd_request_rx: bnd_request_rx.map(Mutex::new),
170            max_flow_id_retries: options.max_flow_id_retries,
171            rwnd: options.rwnd,
172        };
173        let taskdata = TaskData {
174            task: Task {
175                ws: Mutex::new(ws),
176                tx_frame_tx,
177                flows,
178                dropped_ports_tx,
179                con_recv_stream_tx,
180                last_pong_timestamp_tx,
181                default_rwnd_threshold: options.default_rwnd_threshold,
182                rwnd: options.rwnd,
183                datagram_tx,
184                bnd_request_tx,
185                keepalive_interval: options.keepalive_interval,
186                keepalive_timeout: options.keepalive_timeout,
187            },
188            dropped_ports_rx,
189            tx_frame_rx,
190            last_pong_timestamp_rx,
191        };
192        taskdata.spawn(task_joinset);
193        mux
194    }
195
196    /// Request a channel for `host` and `port`.
197    ///
198    /// # Arguments
199    /// * `host`: The host to forward to. While the current implementation
200    ///   supports a domain of arbitrary length, Section 3.2.2 of
201    ///   [RFC 3986](https://www.rfc-editor.org/rfc/rfc3986#section-3.2.2)
202    ///   specifies that the host component of a URI is limited to 255 octets.
203    /// * `port`: The port to forward to.
204    ///
205    /// # Cancel safety
206    /// This function is not cancel safe. If the task is cancelled while waiting
207    /// for the channel to be established, that channel may be established but
208    /// inaccessible through normal means. Subsequent calls to this function
209    /// will result in a new channel being established.
210    ///
211    /// # Errors
212    /// - Returns [`Error::Closed`] if the `Multiplexor` is already closed.
213    /// - Returns [`Error::FlowIdRejected`] if a flow ID could not be allocated
214    ///   after `max_flow_id_retries` attempts.
215    #[tracing::instrument(skip(self), level = "debug")]
216    pub async fn new_stream_channel(&self, host: &[u8], port: u16) -> Result<MuxStream> {
217        let mut retries_left = self.max_flow_id_retries;
218        // Normally this should terminate in one loop
219        while retries_left > 0 {
220            retries_left -= 1;
221            let (stream_tx, stream_rx) = oneshot::channel();
222            let flow_id = {
223                let mut streams = self.flows.write();
224                // Allocate a new port
225                let flow_id = u32::next_available_key(&*streams);
226                trace!("flow_id = {flow_id:08x}");
227                streams.insert(flow_id, FlowSlot::Requested(stream_tx));
228                flow_id
229            };
230            trace!("sending `Connect`");
231            self.tx_frame_tx
232                .send(Frame::new_connect(host, port, flow_id, self.rwnd).finalize())
233                .or(Err(Error::Closed))?;
234            trace!("sending stream to user");
235            let stream = stream_rx
236                .await
237                // Happens if the task exits before sending the stream,
238                // thus `Closed` is the correct error
239                .or(Err(Error::Closed))?;
240            if let Some(s) = stream {
241                return Ok(s);
242            }
243            // For testing purposes. Make sure the previous flow ID is gone
244            debug_assert!(!self.flows.read().contains_key(&flow_id));
245        }
246        Err(Error::FlowIdRejected)
247    }
248
249    /// Accept a new stream channel from the remote peer.
250    ///
251    /// # Errors
252    /// - Returns [`Error::Closed`] if the connection is closed.
253    ///
254    /// # Cancel Safety
255    /// This function is cancel safe. If the task is cancelled while waiting
256    /// for a new connection, it is guaranteed that no connected stream will
257    /// be lost.
258    #[tracing::instrument(skip(self), level = "debug")]
259    pub async fn accept_stream_channel(&self) -> Result<MuxStream> {
260        poll_fn(|cx| self.con_recv_stream_rx.lock().poll_recv(cx))
261            .await
262            .ok_or(Error::Closed)
263    }
264
265    /// Get the next available datagram.
266    ///
267    /// # Errors
268    /// - Returns [`Error::Closed`] if the connection is closed.
269    ///
270    /// # Cancel Safety
271    /// This function is cancel safe. If the task is cancelled while waiting
272    /// for a datagram, it is guaranteed that no datagram will be lost.
273    #[tracing::instrument(skip(self), level = "debug")]
274    #[inline]
275    pub async fn get_datagram(&self) -> Result<Datagram> {
276        poll_fn(|cx| self.datagram_rx.lock().poll_recv(cx))
277            .await
278            .ok_or(Error::Closed)
279    }
280
281    /// Send a datagram
282    ///
283    /// # Errors
284    /// - Returns [`Error::DatagramHostTooLong`] if the destination host is
285    /// longer than 255 octets.
286    /// - Returns [`Error::Closed`] if the Multiplexor is already closed.
287    ///
288    /// # Cancel Safety
289    /// This function is cancel safe. If the task is cancelled, it is
290    /// guaranteed that the datagram has not been sent.
291    #[tracing::instrument(skip(self), level = "debug")]
292    #[inline]
293    pub async fn send_datagram(&self, datagram: Datagram) -> Result<()> {
294        if datagram.target_host.len() > 255 {
295            return Err(Error::DatagramHostTooLong);
296        }
297        let frame = Frame::new_datagram_owned(
298            datagram.flow_id,
299            datagram.target_host,
300            datagram.target_port,
301            datagram.data,
302        );
303        self.tx_frame_tx
304            .send(frame.finalize())
305            .or(Err(Error::Closed))?;
306        Ok(())
307    }
308
309    /// Request a `Bind` for `host` and `port`.
310    ///
311    /// # Arguments
312    /// * `host`: The local address or host to bind to. Hostname resolution might
313    /// not be supported by the remote peer.
314    /// * `port`: The local port to bind to.
315    ///
316    /// # Cancel Safety
317    /// This function is not cancel safe. If the task is cancelled while waiting
318    /// for the peer to reply, the user will not be able to receive whether the
319    /// peer accepted the bind request.
320    ///
321    /// # Errors
322    /// - Returns [`Error::Closed`] if the `Multiplexor` is already closed.
323    #[tracing::instrument(skip(self), level = "debug")]
324    pub async fn request_bind(&self, host: &[u8], port: u16, bind_type: BindType) -> Result<bool> {
325        let (result_tx, result_rx) = oneshot::channel();
326        let flow_id = {
327            let mut streams = self.flows.write();
328            // Allocate a new port
329            let flow_id = u32::next_available_key(&*streams);
330            trace!("flow_id = {flow_id:08x}");
331            streams.insert(flow_id, FlowSlot::BindRequested(result_tx));
332            flow_id
333        };
334        let bnd_frame = Frame::new_bind(flow_id, bind_type, host, port).finalize();
335        self.tx_frame_tx.send(bnd_frame).or(Err(Error::Closed))?;
336        let result = result_rx.await.or(Err(Error::Closed))?;
337        Ok(result)
338    }
339
340    /// Accept a `Bind` request from the remote peer.
341    ///
342    /// # Cancel Safety
343    /// This function is cancel safe. If the task is cancelled while waiting
344    /// for a `Bind` request, it is guaranteed that no request will be lost.
345    ///
346    /// # Errors
347    /// - Returns [`Error::Closed`] if the `Multiplexor` is already closed.
348    /// - Returns [`Error::UnsupportedOperation`] if the `Multiplexor` was not
349    ///  configured to allow `Bind` requests.
350    #[tracing::instrument(skip(self), level = "debug")]
351    pub async fn next_bind_request(&self) -> Result<BindRequest<'static>> {
352        if let Some(rx) = self.bnd_request_rx.as_ref() {
353            poll_fn(|cx| rx.lock().poll_recv(cx))
354                .await
355                .ok_or(Error::Closed)
356        } else {
357            Err(Error::UnsupportedOperation)
358        }
359    }
360}
361
362impl Drop for Multiplexor {
363    fn drop(&mut self) {
364        if self.dropped_ports_tx.send(0).is_err() {
365            error!("Failed to inform task of dropped multiplexor");
366        }
367    }
368}
369
/// Per-stream state kept in the flow table for an established stream.
#[derive(Debug)]
struct EstablishedStreamData {
    /// Channel for sending data to `MuxStream`'s `AsyncRead`
    /// If `None`, we have received `Finish` from the peer but we can possibly still send data.
    sender: Option<mpsc::Sender<Bytes>>,
    /// Set to `true` once writes must no longer succeed
    /// (see `disallow_write`, which sets it and wakes the writer).
    /// There are two cases for `true`:
    /// 1. `Finish` has been sent.
    /// 2. The stream has been removed from `inner.streams`.
    ///    NOTE(review): `inner.streams` is not a name visible in this file;
    ///    presumably this refers to the `flows` map — confirm.
    // In general, our `Atomic*` types don't need more than `Relaxed` ordering
    // because we are not protecting memory accesses, but rather counting the
    // frames we have sent and received.
    finish_sent: Arc<AtomicBool>,
    /// Number of `Push` frames we are allowed to send before waiting for an `Acknowledge` frame.
    psh_send_remaining: Arc<AtomicU32>,
    /// Waker to wake up the task that sends frames because their `psh_send_remaining`
    /// has increased.
    writer_waker: Arc<AtomicWaker>,
}
389
390impl EstablishedStreamData {
391    /// Process a `Finish` frame from the peer and thus disallowing further `AsyncRead` operations
392    /// Returns the sender if it was not already taken.
393    #[inline]
394    const fn disallow_read(&mut self) -> Option<mpsc::Sender<Bytes>> {
395        self.sender.take()
396    }
397
398    /// Process a `Acknowledge` frame from the peer
399    #[inline]
400    fn acknowledge(&self, acknowledged: u32) {
401        // Atomic ordering: as long as the value is incremented atomically,
402        // whether a writer sees the new value or the old value is not
403        // important. If it sees the old value and decides to return
404        // `Poll::Pending`, it will be woken up by the `Waker` anyway.
405        self.psh_send_remaining
406            .fetch_add(acknowledged, Ordering::Relaxed);
407        // Wake up the writer if it is waiting for `Acknowledge`
408        self.writer_waker.wake();
409    }
410
411    /// Disallow any `AsyncWrite` operations.
412    /// Note that this should not be used from inside the `MuxStream` itself
413    #[inline]
414    fn disallow_write(&self) -> bool {
415        // Atomic ordering:
416        // Load part:
417        // If the user calls `poll_shutdown`, but we see `true` here,
418        // the other end will receive a bogus `Reset` frame, which is fine.
419        // Store part:
420        // We need to make sure the writer can see the new value
421        // before we call `wake()`.
422        let old = self.finish_sent.swap(true, Ordering::AcqRel);
423        // If there is a writer waiting for `Acknowledge`, wake it up because it will never receive one.
424        // Waking it here and the user should receive a `BrokenPipe` error.
425        self.writer_waker.wake();
426        old
427    }
428}
429
/// State of one entry in the flow table, keyed by flow ID.
#[derive(Debug)]
enum FlowSlot {
    /// A `Connect` frame was sent and waiting for the peer to `Acknowledge`.
    /// The sender delivers `Some(stream)` on success or `None` when the
    /// requester should retry with another flow ID.
    Requested(oneshot::Sender<Option<MuxStream>>),
    /// The stream is established.
    Established(EstablishedStreamData),
    /// A `Bind` request was sent and waiting for the peer to `Acknowledge` or `Reset`.
    /// The sender delivers whether the peer accepted the request.
    BindRequested(oneshot::Sender<bool>),
}
439
440impl FlowSlot {
441    /// Take the sender and set the slot to `Established`.
442    /// Returns `None` if the slot is already established.
443    #[inline]
444    fn establish(
445        &mut self,
446        data: EstablishedStreamData,
447    ) -> Option<oneshot::Sender<Option<MuxStream>>> {
448        // Make sure it is not replaced in the error case
449        if matches!(self, Self::Established(_) | Self::BindRequested(_)) {
450            error!("establishing an established or invalid slot");
451            return None;
452        }
453        let sender = match std::mem::replace(self, Self::Established(data)) {
454            Self::Requested(sender) => sender,
455            Self::Established(_) | Self::BindRequested(_) => unreachable!(),
456        };
457        Some(sender)
458    }
459
460    /// If the slot is established, send data. Otherwise, return `None`.
461    #[inline]
462    fn dispatch(&self, data: Bytes) -> Option<std::result::Result<(), TrySendError<()>>> {
463        if let Self::Established(stream_data) = self {
464            let r = stream_data
465                .sender
466                .as_ref()
467                .map(|sender| sender.try_send(data))?
468                .map_err(|e| match e {
469                    TrySendError::Full(_) => TrySendError::Full(()),
470                    TrySendError::Closed(_) => TrySendError::Closed(()),
471                });
472            Some(r)
473        } else {
474            None
475        }
476    }
477}
478
/// Datagram frame data, as passed to [`Multiplexor::send_datagram`] and
/// returned by [`Multiplexor::get_datagram`].
#[derive(Clone, Debug)]
pub struct Datagram {
    /// Flow ID
    pub flow_id: u32,
    /// Target host.
    /// [`Multiplexor::send_datagram`] rejects hosts longer than 255 octets.
    pub target_host: Bytes,
    /// Target port
    pub target_port: u16,
    /// Data
    pub data: Bytes,
}
491
/// A `Bind` request that the user can respond to.
///
/// Dropping the value sends a rejection; see the `Drop` implementation.
#[derive(Debug)]
pub struct BindRequest<'data> {
    /// Flow ID
    flow_id: u32,
    /// Bind payload (bind type, target host and target port)
    payload: BindPayload<'data>,
    /// Place to respond to the bind request
    tx_frame_tx: mpsc::UnboundedSender<FinalizedFrame>,
}
502
503impl BindRequest<'_> {
504    /// Get the flow ID of the bind request
505    #[inline]
506    pub const fn flow_id(&self) -> u32 {
507        self.flow_id
508    }
509
510    /// Get the bind type of the bind request
511    #[inline]
512    pub const fn bind_type(&self) -> BindType {
513        self.payload.bind_type
514    }
515
516    /// Get the host of the bind request
517    #[inline]
518    pub fn host(&self) -> &[u8] {
519        self.payload.target_host.as_ref()
520    }
521
522    /// Get the port of the bind request
523    #[inline]
524    pub const fn port(&self) -> u16 {
525        self.payload.target_port
526    }
527
528    /// Accept or reject the bind request
529    ///
530    /// # Errors
531    /// - Returns [`Error::Closed`] if the `Multiplexor` is already closed.
532    #[tracing::instrument(skip(self), level = "debug")]
533    pub fn reply(&self, accepted: bool) -> Result<()> {
534        if accepted {
535            self.tx_frame_tx
536                .send(Frame::new_finish(self.flow_id).finalize())
537        } else {
538            self.tx_frame_tx
539                .send(Frame::new_reset(self.flow_id).finalize())
540        }
541        .or(Err(Error::Closed))
542    }
543}
544
545impl Drop for BindRequest<'_> {
546    /// Dropping a `BindRequest` will reject the request
547    fn drop(&mut self) {
548        self.reply(false).ok();
549    }
550}
551
552/// Randomly generate a new number
553pub trait IntKey: Eq + Hash + Copy + SampleUniform + PartialOrd {
554    /// The minimum value of the key
555    const MIN: Self;
556    /// The maximum value of the key
557    const MAX: Self;
558
559    /// Generate a new key that is not in the map
560    #[inline]
561    #[must_use]
562    fn next_available_key<V, S: BuildHasher>(map: &std::collections::HashMap<Self, V, S>) -> Self {
563        loop {
564            let i = rand::random_range(Self::MIN..Self::MAX);
565            if !map.contains_key(&i) {
566                break i;
567            }
568        }
569    }
570}
571
572macro_rules! impl_int_key {
573    ($($t:ty),*) => {
574        $(
575            impl IntKey for $t {
576                // 0 is for special use
577                const MIN : Self = 1;
578                const MAX : Self = Self::MAX;
579            }
580        )*
581    };
582}
583
584impl_int_key!(u8, u16, u32, u64, u128, usize);