penguin_mux/lib.rs
1//! Multiplexing streamed data and datagrams over a single WebSocket
2//! connection.
3//!
4//! This is not a general-purpose WebSocket multiplexing library.
5//! It is tailored to the needs of `penguin`.
6//
7// SPDX-License-Identifier: Apache-2.0 OR GPL-3.0-or-later
8#![deny(rust_2018_idioms, missing_docs, missing_debug_implementations)]
9#![deny(clippy::pedantic, clippy::cargo, clippy::nursery, clippy::unwrap_used)]
10#![allow(clippy::multiple_crate_versions)]
11
12pub mod config;
13#[cfg(feature = "deadlock-detection")]
14pub mod deadlock_detection;
15mod dupe;
16pub mod frame;
17mod loom;
18mod proto_version;
19mod stream;
20mod task;
21#[cfg(test)]
22mod tests;
23pub mod timing;
24pub mod ws;
25
26use crate::frame::{BindPayload, BindType, FinalizedFrame, Frame};
27use crate::loom::{Arc, AtomicBool, AtomicU32, AtomicWaker, Mutex, Ordering, RwLock};
28use crate::task::{Task, TaskData};
29use crate::ws::WebSocket;
30use bytes::Bytes;
31use rand::distr::uniform::SampleUniform;
32use std::future::poll_fn;
33use std::hash::{BuildHasher, Hash};
34use std::time::Instant;
35use thiserror::Error;
36use tokio::sync::mpsc::error::TrySendError;
37use tokio::sync::{mpsc, oneshot, watch};
38use tokio::task::JoinSet;
39use tracing::{error, trace, warn};
40
41#[cfg(feature = "nohash")]
42use nohash_hasher::IntMap;
43#[cfg(not(feature = "nohash"))]
44use std::collections::HashMap as IntMap;
45
46pub use crate::dupe::Dupe;
47pub use crate::proto_version::{PROTOCOL_VERSION, PROTOCOL_VERSION_NUMBER};
48pub use crate::stream::MuxStream;
49
/// Errors returned by the multiplexor.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error {
    /// The requester exited before receiving the stream
    /// (i.e. the `Receiver` was dropped before the task could send the stream).
    #[error("requester exited before receiving the stream")]
    SendStreamToClient,
    /// The multiplexor is closed.
    #[error("mux is already closed")]
    Closed,
    /// The peer does not support the requested operation.
    #[error("peer does not support requested operation")]
    PeerUnsupportedOperation,
    /// This `Multiplexor` is not configured for this operation, e.g.
    /// [`Multiplexor::next_bind_request`] when `Bind` requests are disabled.
    #[error("unsupported operation")]
    UnsupportedOperation,
    /// The peer rejected the flow ID selection. Returned by
    /// [`Multiplexor::new_stream_channel`] when no attempt succeeded within
    /// the configured number of attempts.
    #[error("peer rejected flow ID selection")]
    FlowIdRejected,
    /// Keepalive timeout: no `Pong` was received within the allowed time.
    #[error("`Pong` response not received before the configured timeout")]
    KeepaliveTimeout,

    /// An error reported by the underlying `WebSocket` implementation.
    #[error("unspecified `WebSocket` error: {0}")]
    WebSocket(Box<dyn std::error::Error + Send>),

    // The variants below should not occur during normal operation.
    /// A `Datagram` frame with a target host longer than 255 octets
    /// (rejected by [`Multiplexor::send_datagram`]).
    #[error("datagram target host longer than 255 octets")]
    DatagramHostTooLong,
    /// Received an invalid frame.
    #[error("invalid frame: {0}")]
    InvalidFrame(#[from] frame::Error),
    /// The peer sent a `Text` message.
    /// "The client and server MUST NOT use other WebSocket data frame types"
    #[error("received `Text` message")]
    TextMessage,
    /// An `Acknowledge` frame that does not match any pending [`Connect`](frame::OpCode::Connect) request.
    #[error("bogus `Acknowledge` frame")]
    ConnAckGone,
    /// An internal channel closed unexpectedly; the payload names the channel.
    #[error("internal channel `{0}` closed unexpectedly")]
    ChannelClosed(&'static str),
}
96
97/// A variant of [`std::result::Result`] with [`enum@Error`] as the error type.
98pub type Result<T> = std::result::Result<T, Error>;
99
/// A multiplexor over a `WebSocket` connection.
///
/// This is the user-facing handle. The `WebSocket` itself is moved into a
/// background task that is spawned by [`Multiplexor::new`].
#[derive(Debug)]
pub struct Multiplexor {
    /// Flow table shared with the background task: `flow_id` -> `FlowSlot`.
    /// Besides established streams, it also holds pending `Connect` and
    /// `Bind` requests.
    flows: Arc<RwLock<IntMap<u32, FlowSlot>>>,
    /// Where tasks queue frames to be sent
    tx_frame_tx: mpsc::UnboundedSender<FinalizedFrame>,
    /// We only use this to inform the task that the multiplexor is closed
    /// and it should stop processing (flow ID `0` is sent on drop).
    dropped_ports_tx: mpsc::UnboundedSender<u32>,
    /// Channel of received datagram frames for processing.
    datagram_rx: Mutex<mpsc::Receiver<Datagram>>,
    /// Channel for a `Multiplexor` to receive newly
    /// established streams after the peer requests one.
    con_recv_stream_rx: Mutex<mpsc::Receiver<MuxStream>>,
    /// Channel for incoming `Bind` requests.
    /// `None` when the configured `bind_buffer_size` is 0, in which case
    /// [`Multiplexor::next_bind_request`] reports `UnsupportedOperation`.
    bnd_request_rx: Option<Mutex<mpsc::Receiver<BindRequest<'static>>>>,
    /// Number of attempts `new_stream_channel` makes to find a flow ID the
    /// peer accepts. Despite the name, this counts all attempts, including
    /// the first one, not only the retries after it.
    /// See [`config::Options`] for more details.
    max_flow_id_retries: usize,
    /// Number of `StreamFrame`s to buffer in `MuxStream`'s channels before blocking.
    /// This value is also advertised in the `Connect` frames we send.
    /// See [`config::Options`] for more details.
    rwnd: u32,
}
124
125impl Multiplexor {
126 /// Create a new `Multiplexor`.
127 ///
128 /// # Arguments
129 ///
130 /// * `ws`: The `WebSocket` connection to multiplex over.
131 ///
132 /// * `options`: Multiplexor options. See [`config::Options`] for more details.
133 /// If `None`, the default options will be used.
134 ///
135 /// * `task_joinset`: A [`JoinSet`] to spawn the multiplexor task into so
136 /// that the caller can notice if the task exits. If it is `None`, the
137 /// task will be spawned by `tokio::spawn` and errors will be logged.
138 #[tracing::instrument(skip_all, level = "debug")]
139 pub fn new<S: WebSocket>(
140 ws: S,
141 options: Option<config::Options>,
142 task_joinset: Option<&mut JoinSet<Result<()>>>,
143 ) -> Self {
144 let options = options.unwrap_or_default();
145 let (datagram_tx, datagram_rx) = mpsc::channel(options.datagram_buffer_size);
146 let (con_recv_stream_tx, con_recv_stream_rx) = mpsc::channel(options.stream_buffer_size);
147 // This one is unbounded because the protocol provides its own flow control for `Push` frames
148 // and other frame types are to be immediately processed without any backpressure,
149 // so they are ok to be unbounded channels.
150 let (tx_frame_tx, tx_frame_rx) = mpsc::unbounded_channel();
151 // This one cannot be bounded because it needs to be used in Drop
152 let (dropped_ports_tx, dropped_ports_rx) = mpsc::unbounded_channel();
153 let (last_pong_timestamp_tx, last_pong_timestamp_rx) = watch::channel(Instant::now());
154
155 let (bnd_request_tx, bnd_request_rx) = if options.bind_buffer_size > 0 {
156 let (tx, rx) = mpsc::channel(options.bind_buffer_size);
157 (Some(tx), Some(rx))
158 } else {
159 (None, None)
160 };
161 let flows = Arc::new(RwLock::new(IntMap::default()));
162
163 let mux = Self {
164 tx_frame_tx: tx_frame_tx.dupe(),
165 flows: flows.dupe(),
166 dropped_ports_tx: dropped_ports_tx.dupe(),
167 datagram_rx: Mutex::new(datagram_rx),
168 con_recv_stream_rx: Mutex::new(con_recv_stream_rx),
169 bnd_request_rx: bnd_request_rx.map(Mutex::new),
170 max_flow_id_retries: options.max_flow_id_retries,
171 rwnd: options.rwnd,
172 };
173 let taskdata = TaskData {
174 task: Task {
175 ws: Mutex::new(ws),
176 tx_frame_tx,
177 flows,
178 dropped_ports_tx,
179 con_recv_stream_tx,
180 last_pong_timestamp_tx,
181 default_rwnd_threshold: options.default_rwnd_threshold,
182 rwnd: options.rwnd,
183 datagram_tx,
184 bnd_request_tx,
185 keepalive_interval: options.keepalive_interval,
186 keepalive_timeout: options.keepalive_timeout,
187 },
188 dropped_ports_rx,
189 tx_frame_rx,
190 last_pong_timestamp_rx,
191 };
192 taskdata.spawn(task_joinset);
193 mux
194 }
195
196 /// Request a channel for `host` and `port`.
197 ///
198 /// # Arguments
199 /// * `host`: The host to forward to. While the current implementation
200 /// supports a domain of arbitrary length, Section 3.2.2 of
201 /// [RFC 3986](https://www.rfc-editor.org/rfc/rfc3986#section-3.2.2)
202 /// specifies that the host component of a URI is limited to 255 octets.
203 /// * `port`: The port to forward to.
204 ///
205 /// # Cancel safety
206 /// This function is not cancel safe. If the task is cancelled while waiting
207 /// for the channel to be established, that channel may be established but
208 /// inaccessible through normal means. Subsequent calls to this function
209 /// will result in a new channel being established.
210 ///
211 /// # Errors
212 /// - Returns [`Error::Closed`] if the `Multiplexor` is already closed.
213 /// - Returns [`Error::FlowIdRejected`] if a flow ID could not be allocated
214 /// after `max_flow_id_retries` attempts.
215 #[tracing::instrument(skip(self), level = "debug")]
216 pub async fn new_stream_channel(&self, host: &[u8], port: u16) -> Result<MuxStream> {
217 let mut retries_left = self.max_flow_id_retries;
218 // Normally this should terminate in one loop
219 while retries_left > 0 {
220 retries_left -= 1;
221 let (stream_tx, stream_rx) = oneshot::channel();
222 let flow_id = {
223 let mut streams = self.flows.write();
224 // Allocate a new port
225 let flow_id = u32::next_available_key(&*streams);
226 trace!("flow_id = {flow_id:08x}");
227 streams.insert(flow_id, FlowSlot::Requested(stream_tx));
228 flow_id
229 };
230 trace!("sending `Connect`");
231 self.tx_frame_tx
232 .send(Frame::new_connect(host, port, flow_id, self.rwnd).finalize())
233 .or(Err(Error::Closed))?;
234 trace!("sending stream to user");
235 let stream = stream_rx
236 .await
237 // Happens if the task exits before sending the stream,
238 // thus `Closed` is the correct error
239 .or(Err(Error::Closed))?;
240 if let Some(s) = stream {
241 return Ok(s);
242 }
243 // For testing purposes. Make sure the previous flow ID is gone
244 debug_assert!(!self.flows.read().contains_key(&flow_id));
245 }
246 Err(Error::FlowIdRejected)
247 }
248
249 /// Accept a new stream channel from the remote peer.
250 ///
251 /// # Errors
252 /// - Returns [`Error::Closed`] if the connection is closed.
253 ///
254 /// # Cancel Safety
255 /// This function is cancel safe. If the task is cancelled while waiting
256 /// for a new connection, it is guaranteed that no connected stream will
257 /// be lost.
258 #[tracing::instrument(skip(self), level = "debug")]
259 pub async fn accept_stream_channel(&self) -> Result<MuxStream> {
260 poll_fn(|cx| self.con_recv_stream_rx.lock().poll_recv(cx))
261 .await
262 .ok_or(Error::Closed)
263 }
264
265 /// Get the next available datagram.
266 ///
267 /// # Errors
268 /// - Returns [`Error::Closed`] if the connection is closed.
269 ///
270 /// # Cancel Safety
271 /// This function is cancel safe. If the task is cancelled while waiting
272 /// for a datagram, it is guaranteed that no datagram will be lost.
273 #[tracing::instrument(skip(self), level = "debug")]
274 #[inline]
275 pub async fn get_datagram(&self) -> Result<Datagram> {
276 poll_fn(|cx| self.datagram_rx.lock().poll_recv(cx))
277 .await
278 .ok_or(Error::Closed)
279 }
280
281 /// Send a datagram
282 ///
283 /// # Errors
284 /// - Returns [`Error::DatagramHostTooLong`] if the destination host is
285 /// longer than 255 octets.
286 /// - Returns [`Error::Closed`] if the Multiplexor is already closed.
287 ///
288 /// # Cancel Safety
289 /// This function is cancel safe. If the task is cancelled, it is
290 /// guaranteed that the datagram has not been sent.
291 #[tracing::instrument(skip(self), level = "debug")]
292 #[inline]
293 pub async fn send_datagram(&self, datagram: Datagram) -> Result<()> {
294 if datagram.target_host.len() > 255 {
295 return Err(Error::DatagramHostTooLong);
296 }
297 let frame = Frame::new_datagram_owned(
298 datagram.flow_id,
299 datagram.target_host,
300 datagram.target_port,
301 datagram.data,
302 );
303 self.tx_frame_tx
304 .send(frame.finalize())
305 .or(Err(Error::Closed))?;
306 Ok(())
307 }
308
309 /// Request a `Bind` for `host` and `port`.
310 ///
311 /// # Arguments
312 /// * `host`: The local address or host to bind to. Hostname resolution might
313 /// not be supported by the remote peer.
314 /// * `port`: The local port to bind to.
315 ///
316 /// # Cancel Safety
317 /// This function is not cancel safe. If the task is cancelled while waiting
318 /// for the peer to reply, the user will not be able to receive whether the
319 /// peer accepted the bind request.
320 ///
321 /// # Errors
322 /// - Returns [`Error::Closed`] if the `Multiplexor` is already closed.
323 #[tracing::instrument(skip(self), level = "debug")]
324 pub async fn request_bind(&self, host: &[u8], port: u16, bind_type: BindType) -> Result<bool> {
325 let (result_tx, result_rx) = oneshot::channel();
326 let flow_id = {
327 let mut streams = self.flows.write();
328 // Allocate a new port
329 let flow_id = u32::next_available_key(&*streams);
330 trace!("flow_id = {flow_id:08x}");
331 streams.insert(flow_id, FlowSlot::BindRequested(result_tx));
332 flow_id
333 };
334 let bnd_frame = Frame::new_bind(flow_id, bind_type, host, port).finalize();
335 self.tx_frame_tx.send(bnd_frame).or(Err(Error::Closed))?;
336 let result = result_rx.await.or(Err(Error::Closed))?;
337 Ok(result)
338 }
339
340 /// Accept a `Bind` request from the remote peer.
341 ///
342 /// # Cancel Safety
343 /// This function is cancel safe. If the task is cancelled while waiting
344 /// for a `Bind` request, it is guaranteed that no request will be lost.
345 ///
346 /// # Errors
347 /// - Returns [`Error::Closed`] if the `Multiplexor` is already closed.
348 /// - Returns [`Error::UnsupportedOperation`] if the `Multiplexor` was not
349 /// configured to allow `Bind` requests.
350 #[tracing::instrument(skip(self), level = "debug")]
351 pub async fn next_bind_request(&self) -> Result<BindRequest<'static>> {
352 if let Some(rx) = self.bnd_request_rx.as_ref() {
353 poll_fn(|cx| rx.lock().poll_recv(cx))
354 .await
355 .ok_or(Error::Closed)
356 } else {
357 Err(Error::UnsupportedOperation)
358 }
359 }
360}
361
362impl Drop for Multiplexor {
363 fn drop(&mut self) {
364 if self.dropped_ports_tx.send(0).is_err() {
365 error!("Failed to inform task of dropped multiplexor");
366 }
367 }
368}
369
/// Per-flow state kept in the flow table once a stream is established.
#[derive(Debug)]
struct EstablishedStreamData {
    /// Channel for sending received data to the `MuxStream`'s `AsyncRead` side.
    /// If `None`, we have received `Finish` from the peer, so no more data will
    /// be delivered to the reader, but we can possibly still send data.
    sender: Option<mpsc::Sender<Bytes>>,
    /// Whether writes are disallowed, i.e. `true` means writes should fail.
    /// There are two cases for `true`:
    /// 1. `Finish` has been sent.
    /// 2. The stream has been removed from the flow table (`flows`).
    // In general, our `Atomic*` types don't need more than `Relaxed` ordering
    // because we are not protecting memory accesses, but rather counting the
    // frames we have sent and received.
    finish_sent: Arc<AtomicBool>,
    /// Number of `Push` frames we are allowed to send before waiting for an `Acknowledge` frame.
    psh_send_remaining: Arc<AtomicU32>,
    /// Waker for the writer task, woken when `psh_send_remaining` increases
    /// or when writes become disallowed.
    writer_waker: Arc<AtomicWaker>,
}
389
390impl EstablishedStreamData {
391 /// Process a `Finish` frame from the peer and thus disallowing further `AsyncRead` operations
392 /// Returns the sender if it was not already taken.
393 #[inline]
394 const fn disallow_read(&mut self) -> Option<mpsc::Sender<Bytes>> {
395 self.sender.take()
396 }
397
398 /// Process a `Acknowledge` frame from the peer
399 #[inline]
400 fn acknowledge(&self, acknowledged: u32) {
401 // Atomic ordering: as long as the value is incremented atomically,
402 // whether a writer sees the new value or the old value is not
403 // important. If it sees the old value and decides to return
404 // `Poll::Pending`, it will be woken up by the `Waker` anyway.
405 self.psh_send_remaining
406 .fetch_add(acknowledged, Ordering::Relaxed);
407 // Wake up the writer if it is waiting for `Acknowledge`
408 self.writer_waker.wake();
409 }
410
411 /// Disallow any `AsyncWrite` operations.
412 /// Note that this should not be used from inside the `MuxStream` itself
413 #[inline]
414 fn disallow_write(&self) -> bool {
415 // Atomic ordering:
416 // Load part:
417 // If the user calls `poll_shutdown`, but we see `true` here,
418 // the other end will receive a bogus `Reset` frame, which is fine.
419 // Store part:
420 // We need to make sure the writer can see the new value
421 // before we call `wake()`.
422 let old = self.finish_sent.swap(true, Ordering::AcqRel);
423 // If there is a writer waiting for `Acknowledge`, wake it up because it will never receive one.
424 // Waking it here and the user should receive a `BrokenPipe` error.
425 self.writer_waker.wake();
426 old
427 }
428}
429
/// State of a single entry in the flow table.
#[derive(Debug)]
enum FlowSlot {
    /// A `Connect` frame was sent and we are waiting for the peer to `Acknowledge`.
    /// The sender delivers the established stream, or `None` if this attempt
    /// failed and the requester should try again with a new flow ID.
    Requested(oneshot::Sender<Option<MuxStream>>),
    /// The stream is established.
    Established(EstablishedStreamData),
    /// A `Bind` request was sent and we are waiting for the peer's answer.
    /// The sender delivers whether the peer accepted the request.
    // NOTE(review): `BindRequest::reply` answers with `Finish` (accept) or
    // `Reset` (reject), not `Acknowledge`. Confirm the task matches on those.
    BindRequested(oneshot::Sender<bool>),
}
439
440impl FlowSlot {
441 /// Take the sender and set the slot to `Established`.
442 /// Returns `None` if the slot is already established.
443 #[inline]
444 fn establish(
445 &mut self,
446 data: EstablishedStreamData,
447 ) -> Option<oneshot::Sender<Option<MuxStream>>> {
448 // Make sure it is not replaced in the error case
449 if matches!(self, Self::Established(_) | Self::BindRequested(_)) {
450 error!("establishing an established or invalid slot");
451 return None;
452 }
453 let sender = match std::mem::replace(self, Self::Established(data)) {
454 Self::Requested(sender) => sender,
455 Self::Established(_) | Self::BindRequested(_) => unreachable!(),
456 };
457 Some(sender)
458 }
459
460 /// If the slot is established, send data. Otherwise, return `None`.
461 #[inline]
462 fn dispatch(&self, data: Bytes) -> Option<std::result::Result<(), TrySendError<()>>> {
463 if let Self::Established(stream_data) = self {
464 let r = stream_data
465 .sender
466 .as_ref()
467 .map(|sender| sender.try_send(data))?
468 .map_err(|e| match e {
469 TrySendError::Full(_) => TrySendError::Full(()),
470 TrySendError::Closed(_) => TrySendError::Closed(()),
471 });
472 Some(r)
473 } else {
474 None
475 }
476 }
477}
478
/// Datagram frame data
///
/// Used both for datagrams received from the peer
/// ([`Multiplexor::get_datagram`]) and for datagrams to send
/// ([`Multiplexor::send_datagram`]).
#[derive(Clone, Debug)]
pub struct Datagram {
    /// Flow ID
    pub flow_id: u32,
    /// Target host. It must be at most 255 octets long, otherwise
    /// [`Multiplexor::send_datagram`] rejects it.
    pub target_host: Bytes,
    /// Target port
    pub target_port: u16,
    /// Payload data
    pub data: Bytes,
}
491
/// A `Bind` request that the user can respond to
///
/// Answer it with [`BindRequest::reply`]. Dropping it sends a rejection.
#[derive(Debug)]
pub struct BindRequest<'data> {
    /// Flow ID
    flow_id: u32,
    /// Bind payload (bind type, target host and port)
    payload: BindPayload<'data>,
    /// Outgoing frame queue of the multiplexor, used to send the reply
    tx_frame_tx: mpsc::UnboundedSender<FinalizedFrame>,
}
502
503impl BindRequest<'_> {
504 /// Get the flow ID of the bind request
505 #[inline]
506 pub const fn flow_id(&self) -> u32 {
507 self.flow_id
508 }
509
510 /// Get the bind type of the bind request
511 #[inline]
512 pub const fn bind_type(&self) -> BindType {
513 self.payload.bind_type
514 }
515
516 /// Get the host of the bind request
517 #[inline]
518 pub fn host(&self) -> &[u8] {
519 self.payload.target_host.as_ref()
520 }
521
522 /// Get the port of the bind request
523 #[inline]
524 pub const fn port(&self) -> u16 {
525 self.payload.target_port
526 }
527
528 /// Accept or reject the bind request
529 ///
530 /// # Errors
531 /// - Returns [`Error::Closed`] if the `Multiplexor` is already closed.
532 #[tracing::instrument(skip(self), level = "debug")]
533 pub fn reply(&self, accepted: bool) -> Result<()> {
534 if accepted {
535 self.tx_frame_tx
536 .send(Frame::new_finish(self.flow_id).finalize())
537 } else {
538 self.tx_frame_tx
539 .send(Frame::new_reset(self.flow_id).finalize())
540 }
541 .or(Err(Error::Closed))
542 }
543}
544
545impl Drop for BindRequest<'_> {
546 /// Dropping a `BindRequest` will reject the request
547 fn drop(&mut self) {
548 self.reply(false).ok();
549 }
550}
551
552/// Randomly generate a new number
553pub trait IntKey: Eq + Hash + Copy + SampleUniform + PartialOrd {
554 /// The minimum value of the key
555 const MIN: Self;
556 /// The maximum value of the key
557 const MAX: Self;
558
559 /// Generate a new key that is not in the map
560 #[inline]
561 #[must_use]
562 fn next_available_key<V, S: BuildHasher>(map: &std::collections::HashMap<Self, V, S>) -> Self {
563 loop {
564 let i = rand::random_range(Self::MIN..Self::MAX);
565 if !map.contains_key(&i) {
566 break i;
567 }
568 }
569 }
570}
571
572macro_rules! impl_int_key {
573 ($($t:ty),*) => {
574 $(
575 impl IntKey for $t {
576 // 0 is for special use
577 const MIN : Self = 1;
578 const MAX : Self = Self::MAX;
579 }
580 )*
581 };
582}
583
584impl_int_key!(u8, u16, u32, u64, u128, usize);