stream_multiplexer/
lib.rs

1//! This crate provides natural backpressure to classes of streams.
2//!
//! Streams are gathered into 'channels' that can be polled via `recv()`. Channels are independent
4//! of each other and have their own backpressure.
5//!
6//! ## Example
7//!
8//! With a TCP server you may have two different classes of connections: Authenticated and
//! Unauthenticated. By grouping each class of connection into its own channel, you can favor the
10//! Authenticated connections over the Unauthenticated. This would provide a better experience for
11//! those that have been able to authenticate.
12//!
13//! ## Code Example
14//!
15//! ```
16//! use futures_util::stream::StreamExt;
17//! use tokio_util::compat::*;
18//!
19//! smol::block_on(async move {
20//!     const CHANNEL_ONE: usize = 1;
21//!     const CHANNEL_TWO: usize = 2;
22//!
23//!     // Initialize a multiplexer
24//!     let mut multiplexer = stream_multiplexer::Multiplexer::new();
25//!
26//!     // Set up the recognized channels
27//!     multiplexer.add_channel(CHANNEL_ONE);
28//!     multiplexer.add_channel(CHANNEL_TWO);
29//!
30//!     // Bind to a random port on localhost
31//!     let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
32//!     let local_addr = listener.local_addr().unwrap();
33//!
34//!     // Set up a task to add incoming connections into multiplexer
35//!     let mut incoming_multiplexer = multiplexer.clone();
36//!     smol::Task::spawn(async move {
37//!         for stream in listener.incoming() {
38//!             match stream {
39//!                 Ok(stream) => {
40//!                     let stream = async_io::Async::new(stream).unwrap();
41//!                     let codec = tokio_util::codec::LinesCodec::new();
42//!                     let framed = tokio_util::codec::Framed::new(stream.compat(), codec);
43//!                     let (sink, stream) = framed.split();
44//!                     let _stream_id = incoming_multiplexer.add_stream_pair(sink, stream, CHANNEL_ONE);
45//!                 }
46//!                 Err(_) => unimplemented!()
47//!             }
48//!         }
49//!     }).detach();
50//!
51//!     // test clients to put into channels
52//!     let mut client_1 = std::net::TcpStream::connect(local_addr).unwrap();
53//!     let mut client_2 = std::net::TcpStream::connect(local_addr).unwrap();
54//!
55//!     let mut multiplexer_ch_1 = multiplexer.clone();
56//!
//!     // Simple server that echoes the data back to the stream and moves the stream to channel 2.
58//!     smol::Task::spawn(async move {
59//!         while let Ok((stream_id, message)) = multiplexer_ch_1.recv(CHANNEL_ONE).await {
60//!             match message {
61//!                 Some(Ok(data)) => {
62//!                     // echo the data back and move it to channel 2
//!                     multiplexer_ch_1.send(vec![stream_id], data).await;
64//!                     multiplexer_ch_1.change_stream_channel(stream_id, CHANNEL_TWO);
65//!                 }
66//!                 Some(Err(err)) => {
67//!                     // the stream had an error
68//!                 }
69//!                 None => {
70//!                     // stream_id has been dropped
71//!                 }
72//!             }
73//!         }
74//!     }).detach();
75//! });
76//! ```
77
78#![forbid(unsafe_code)]
79#![warn(
80    missing_docs,
81    missing_debug_implementations,
82    missing_copy_implementations,
83    trivial_casts,
84    trivial_numeric_casts,
85    unreachable_pub,
86    unsafe_code,
87    unstable_features,
88    unused_import_braces,
89    unused_qualifications,
90    rust_2018_idioms
91)]
92
93mod channel;
94
95mod error;
96mod stream_dropper;
97
98use crate::stream_dropper::*;
99use channel::Channel;
100pub use error::MultiplexerError;
101
102use async_mutex::Mutex;
103use dashmap::DashMap;
104use futures_util::future::FutureExt;
105use futures_util::sink::{Sink, SinkExt};
106use futures_util::stream::{FuturesUnordered, StreamExt, TryStream};
107use parking_lot::RwLock;
108use sharded_slab::Slab;
109
110use std::sync::Arc;
111
/// Identifier that `Multiplexer` hands back for each stream pair added to it.
pub type StreamId = usize;

/// Identifier under which a channel is registered with `Multiplexer`.
pub type ChannelId = usize;

/// Request to relocate a stream's read half into a different channel.
struct ChannelChange<St> {
    /// Channel that the read half should be placed into.
    pub(crate) next_channel_id: ChannelId,
    /// Stream pair that the read half belongs to.
    pub(crate) stream_id: StreamId,
    /// The read half being relocated.
    pub(crate) stream: St,
}
123
124type AddStreamToChannelTx<St> = async_channel::Sender<StreamDropper<St>>;
125
126///
127#[derive(Debug)]
128pub struct Multiplexer<St, Si>
129where
130    St: 'static,
131{
132    stream_controls: Arc<DashMap<StreamId, StreamDropControl>>,
133    sinks: Arc<RwLock<Slab<Arc<Mutex<Si>>>>>,
134    channels: Arc<DashMap<ChannelId, (AddStreamToChannelTx<St>, Mutex<Channel<St>>)>>,
135    channel_change_rx: async_channel::Receiver<ChannelChange<St>>,
136    channel_change_tx: async_channel::Sender<ChannelChange<St>>,
137}
138
139impl<St, Si> Clone for Multiplexer<St, Si> {
140    fn clone(&self) -> Self {
141        let stream_controls = Arc::clone(&self.stream_controls);
142        let sinks = Arc::clone(&self.sinks);
143        let channels = Arc::clone(&self.channels);
144        let channel_change_rx = self.channel_change_rx.clone();
145        let channel_change_tx = self.channel_change_tx.clone();
146
147        Self {
148            stream_controls,
149            sinks,
150            channels,
151            channel_change_rx,
152            channel_change_tx,
153        }
154    }
155}
156
157impl<St, Si> Multiplexer<St, Si>
158where
159    St: TryStream + Unpin,
160    St::Ok: Clone,
161    Si: Sink<St::Ok> + Unpin,
162    Si::Error: std::fmt::Debug,
163{
164    /// Creates a Multiplexer
165    pub fn new() -> Self {
166        let (channel_change_tx, channel_change_rx) = async_channel::unbounded();
167
168        Self {
169            channel_change_rx,
170            channel_change_tx,
171            channels: Arc::new(DashMap::new()),
172            sinks: Arc::new(RwLock::new(Slab::new())),
173            stream_controls: Arc::new(DashMap::new()),
174        }
175    }
176
177    /// Adds a channel id to the internal table used for validation checks.
178    ///
179    /// Returns an error if `channel` already exists.
180    pub fn add_channel(
181        &mut self,
182        channel_id: ChannelId,
183    ) -> Result<(), MultiplexerError<Si::Error>> {
184        // Ensure that the channel does not already exist
185        if self.has_channel(channel_id) {
186            return Err(MultiplexerError::DuplicateChannel(channel_id));
187        }
188
189        // Create the channel and store it
190        let (add_tx, channel) = Channel::new(channel_id);
191        let channel = Mutex::new(channel);
192        self.channels.insert(channel_id, (add_tx, channel));
193
194        Ok(())
195    }
196
197    /// Returns true if the channel id exists.
198    #[inline]
199    pub fn has_channel(&self, channel: ChannelId) -> bool {
200        self.channels.contains_key(&channel)
201    }
202
203    /// Removes a channel.
204    ///
205    /// Returns `true` if found, `false` otherwise.
206    pub fn remove_channel(&mut self, channel: ChannelId) -> bool {
207        // FIXME: What do we do with the sockets in that channel?
208        self.channels.remove(&channel)
209    }
210
211    /// Sends `data` to a set of `stream_ids`, waiting for them to be sent.
212    pub async fn send(
213        &self,
214        stream_ids: impl IntoIterator<Item = StreamId>,
215        data: St::Ok,
216    ) -> Vec<Result<StreamId, MultiplexerError<Si::Error>>> {
217        let futures: FuturesUnordered<_> = stream_ids
218            .into_iter()
219            .map(|stream_id| {
220                // Clone the data before moving it into the async block.
221                let data = data.clone();
222
223                // Async block is the return value.
224                async move {
225                    let sink = {
226                        // Fetch the write-half of the stream from the slab
227                        let slab_read_guard = self.sinks.read(); // read guard
228                        let sink = slab_read_guard.get(stream_id); // shard guard
229                        sink.map(|s| Arc::clone(&s))
230                    };
231
232                    match sink {
233                        Some(sink) => {
234                            // Sending data via the stream
235                            sink.lock()
236                                .await
237                                .send(data)
238                                .await
239                                // to match the return type
240                                .map(|()| stream_id)
241                                .map_err(|err| MultiplexerError::SendError(stream_id, err))
242                        }
243                        None => {
244                            // It's possible for the stream to not be in the slab, should be a rare case.
245                            Err(MultiplexerError::UnknownStream(stream_id))
246                        }
247                    }
248                }
249            })
250            .collect();
251
252        // Waiting for all of the sends to complete
253        futures.collect().await
254    }
255
256    /// Adds `stream` to the `channel` and stores `sink`.
257    ///
258    /// Returns a `StreamId` that represents the pair. It can be used in functions such as `send()` or `change_stream_channel()`.
259    pub fn add_stream_pair(
260        &mut self,
261        sink: Si,
262        stream: St,
263        channel_id: ChannelId,
264    ) -> Result<StreamId, MultiplexerError<Si::Error>> {
265        // Check that we have the channel before we commit the stream to the slab.
266        let channel = self
267            .channels
268            .get(&channel_id)
269            .ok_or_else(|| MultiplexerError::UnknownChannel(channel_id))?;
270
271        // Add the write-half of the stream to the slab, can't return the stream if there isn't room.
272        let stream_id = self
273            .sinks
274            .write()
275            .insert(Arc::new(Mutex::new(sink)))
276            .ok_or_else(move || MultiplexerError::ChannelFull(channel_id))?;
277
278        // wrap the stream in a dropper so that it can be ejected
279        let (stream_control, stream_dropper) =
280            StreamDropControl::wrap(stream_id, stream, self.channel_change_tx.clone());
281        self.stream_controls.insert(stream_id, stream_control);
282
283        // Add the read-half of the stream to the channel
284        channel
285            .0
286            .try_send(stream_dropper)
287            .map_err(|_err| MultiplexerError::ChannelAdd(stream_id, channel_id))?;
288
289        Ok(stream_id)
290    }
291
292    /// Signals to the stream that it should move to a given channel.
293    ///
294    /// The channel change is not instantaneous. Calling `.recv()` on the stream's current channel
295    /// may result in that channel receiving more of the stream's data.
296    pub fn change_stream_channel(
297        &self,
298        stream_id: StreamId,
299        channel_id: ChannelId,
300    ) -> Result<(), MultiplexerError<Si::Error>> {
301        // Ensure that the channel exists
302        if !self.has_channel(channel_id) {
303            return Err(MultiplexerError::UnknownChannel(channel_id));
304        }
305
306        let control = self
307            .stream_controls
308            .get(&stream_id)
309            .ok_or_else(|| MultiplexerError::UnknownStream(stream_id))?;
310
311        control.change_channel(channel_id);
312
313        Ok(())
314    }
315
316    /// Removes the stream from the multiplexer
317    pub fn remove_stream(&mut self, stream_id: StreamId) -> bool {
318        log::trace!("Attempting to remove stream {}", stream_id);
319
320        // Removing it first from the sinks because the sinks are checked when changing channels
321        let res = self.sinks.write().take(stream_id).is_some();
322        log::trace!("Stream {} exists", stream_id);
323
324        // If the stream is changing channels, it may not have a control and will be dropped in process_add_channel
325        if let Some(control) = self.stream_controls.get(&stream_id) {
326            log::trace!("Stream {} is getting dropped()", stream_id);
327            control.drop_stream();
328        }
329
330        res
331    }
332
333    /// Receives the next packet available from a channel:
334    ///
335    /// Returns a tuple of the stream's ID and it's data or an error.
336    pub async fn recv(
337        &mut self,
338        channel_id: ChannelId,
339    ) -> Result<(StreamId, Option<St::Item>), MultiplexerError<Si::Error>> {
340        log::debug!("recv({})", channel_id);
341        match self.channels.get(&channel_id) {
342            Some(channel_guard) => {
343                log::debug!("recv({}) before loop {{}}", channel_id);
344                let mut channel = channel_guard.value().1.lock().await;
345                loop {
346                    log::debug!("recv({}) awaiting the select", channel_id);
347                    futures_util::select! {
348                        channel_next = channel.next().fuse() => {
349                            return Ok(channel_next);
350                        }
351                        add_res = self.channel_change_rx.next().fuse() => {
352                            log::debug!("recv({}) channel_change has message.", channel_id);
353                            if let Some(add_res) = add_res {
354                                // If the stream fails to get added to this channel, bail
355                                self.process_add_channel(add_res)?;
356                            }
357                        }
358                    }
359                }
360            }
361            None => Err(MultiplexerError::UnknownChannel(channel_id)),
362        }
363    }
364
365    fn process_add_channel(
366        &mut self,
367        channel_change: ChannelChange<St>,
368    ) -> Result<(), MultiplexerError<Si::Error>> {
369        let ChannelChange {
370            next_channel_id,
371            stream_id,
372            stream,
373        } = channel_change;
374
375        // Check that we have the channel before we commit the stream to the slab.
376        let channel = self
377            .channels
378            .get(&next_channel_id)
379            .ok_or_else(|| MultiplexerError::ChannelAdd(stream_id, next_channel_id))?;
380
381        // If the sink has been dropped, don't add this half to a channel
382        if self.sinks.read().contains(stream_id) {
383            // Wrap the stream in another drop control
384            let (stream_control, stream_dropper) =
385                StreamDropControl::wrap(stream_id, stream, self.channel_change_tx.clone());
386            self.stream_controls.insert(stream_id, stream_control);
387
388            // FIXME: If a channel is removed while a stream is being transitioned to it why are we
389            // returning the error to some other channel's next() call?
390            channel
391                .0
392                .try_send(stream_dropper)
393                .map_err(|_err| MultiplexerError::ChannelAdd(stream_id, next_channel_id))?;
394        }
395
396        Ok(())
397    }
398}