stream_multiplexer/lib.rs
1//! This crate provides natural backpressure to classes of streams.
2//!
//! Streams are gathered into 'channels' that can be polled via `recv()`. Channels are independent
//! of each other and have their own backpressure.
5//!
6//! ## Example
7//!
//! With a TCP server you may have two different classes of connections: Authenticated and
//! Unauthenticated. By grouping each class of connection into its own channel, you can favor the
//! Authenticated connections over the Unauthenticated ones. This provides a better experience for
//! those that have been able to authenticate.
12//!
13//! ## Code Example
14//!
15//! ```
16//! use futures_util::stream::StreamExt;
17//! use tokio_util::compat::*;
18//!
19//! smol::block_on(async move {
20//! const CHANNEL_ONE: usize = 1;
21//! const CHANNEL_TWO: usize = 2;
22//!
23//! // Initialize a multiplexer
24//! let mut multiplexer = stream_multiplexer::Multiplexer::new();
25//!
//! // Set up the recognized channels
//! multiplexer.add_channel(CHANNEL_ONE).unwrap();
//! multiplexer.add_channel(CHANNEL_TWO).unwrap();
29//!
30//! // Bind to a random port on localhost
31//! let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
32//! let local_addr = listener.local_addr().unwrap();
33//!
34//! // Set up a task to add incoming connections into multiplexer
35//! let mut incoming_multiplexer = multiplexer.clone();
36//! smol::Task::spawn(async move {
37//! for stream in listener.incoming() {
38//! match stream {
39//! Ok(stream) => {
40//! let stream = async_io::Async::new(stream).unwrap();
41//! let codec = tokio_util::codec::LinesCodec::new();
42//! let framed = tokio_util::codec::Framed::new(stream.compat(), codec);
43//! let (sink, stream) = framed.split();
44//! let _stream_id = incoming_multiplexer.add_stream_pair(sink, stream, CHANNEL_ONE);
45//! }
46//! Err(_) => unimplemented!()
47//! }
48//! }
49//! }).detach();
50//!
51//! // test clients to put into channels
52//! let mut client_1 = std::net::TcpStream::connect(local_addr).unwrap();
53//! let mut client_2 = std::net::TcpStream::connect(local_addr).unwrap();
54//!
55//! let mut multiplexer_ch_1 = multiplexer.clone();
56//!
//! // Simple server that echoes the data back to the stream and moves the stream to channel 2.
//! smol::Task::spawn(async move {
//! while let Ok((stream_id, message)) = multiplexer_ch_1.recv(CHANNEL_ONE).await {
//! match message {
//! Some(Ok(data)) => {
//! // echo the data back and move it to channel 2
//! // `send` is async: it must be awaited, otherwise nothing is written.
//! let _ = multiplexer_ch_1.send(vec![stream_id], data).await;
//! let _ = multiplexer_ch_1.change_stream_channel(stream_id, CHANNEL_TWO);
//! }
//! Some(Err(_err)) => {
//! // the stream had an error
68//! }
69//! None => {
70//! // stream_id has been dropped
71//! }
72//! }
73//! }
74//! }).detach();
75//! });
76//! ```
77
78#![forbid(unsafe_code)]
79#![warn(
80 missing_docs,
81 missing_debug_implementations,
82 missing_copy_implementations,
83 trivial_casts,
84 trivial_numeric_casts,
85 unreachable_pub,
86 unsafe_code,
87 unstable_features,
88 unused_import_braces,
89 unused_qualifications,
90 rust_2018_idioms
91)]
92
93mod channel;
94
95mod error;
96mod stream_dropper;
97
98use crate::stream_dropper::*;
99use channel::Channel;
100pub use error::MultiplexerError;
101
102use async_mutex::Mutex;
103use dashmap::DashMap;
104use futures_util::future::FutureExt;
105use futures_util::sink::{Sink, SinkExt};
106use futures_util::stream::{FuturesUnordered, StreamExt, TryStream};
107use parking_lot::RwLock;
108use sharded_slab::Slab;
109
110use std::sync::Arc;
111
/// Identifier handed back by `Multiplexer::add_stream_pair`; it names one stream/sink pair and
/// is what `send()`, `change_stream_channel()` and `remove_stream()` take.
pub type StreamId = usize;

/// Identifier under which a channel is registered with `Multiplexer::add_channel`.
pub type ChannelId = usize;
117
118struct ChannelChange<St> {
119 pub(crate) next_channel_id: ChannelId,
120 pub(crate) stream_id: StreamId,
121 pub(crate) stream: St,
122}
123
124type AddStreamToChannelTx<St> = async_channel::Sender<StreamDropper<St>>;
125
126///
127#[derive(Debug)]
128pub struct Multiplexer<St, Si>
129where
130 St: 'static,
131{
132 stream_controls: Arc<DashMap<StreamId, StreamDropControl>>,
133 sinks: Arc<RwLock<Slab<Arc<Mutex<Si>>>>>,
134 channels: Arc<DashMap<ChannelId, (AddStreamToChannelTx<St>, Mutex<Channel<St>>)>>,
135 channel_change_rx: async_channel::Receiver<ChannelChange<St>>,
136 channel_change_tx: async_channel::Sender<ChannelChange<St>>,
137}
138
139impl<St, Si> Clone for Multiplexer<St, Si> {
140 fn clone(&self) -> Self {
141 let stream_controls = Arc::clone(&self.stream_controls);
142 let sinks = Arc::clone(&self.sinks);
143 let channels = Arc::clone(&self.channels);
144 let channel_change_rx = self.channel_change_rx.clone();
145 let channel_change_tx = self.channel_change_tx.clone();
146
147 Self {
148 stream_controls,
149 sinks,
150 channels,
151 channel_change_rx,
152 channel_change_tx,
153 }
154 }
155}
156
157impl<St, Si> Multiplexer<St, Si>
158where
159 St: TryStream + Unpin,
160 St::Ok: Clone,
161 Si: Sink<St::Ok> + Unpin,
162 Si::Error: std::fmt::Debug,
163{
164 /// Creates a Multiplexer
165 pub fn new() -> Self {
166 let (channel_change_tx, channel_change_rx) = async_channel::unbounded();
167
168 Self {
169 channel_change_rx,
170 channel_change_tx,
171 channels: Arc::new(DashMap::new()),
172 sinks: Arc::new(RwLock::new(Slab::new())),
173 stream_controls: Arc::new(DashMap::new()),
174 }
175 }
176
177 /// Adds a channel id to the internal table used for validation checks.
178 ///
179 /// Returns an error if `channel` already exists.
180 pub fn add_channel(
181 &mut self,
182 channel_id: ChannelId,
183 ) -> Result<(), MultiplexerError<Si::Error>> {
184 // Ensure that the channel does not already exist
185 if self.has_channel(channel_id) {
186 return Err(MultiplexerError::DuplicateChannel(channel_id));
187 }
188
189 // Create the channel and store it
190 let (add_tx, channel) = Channel::new(channel_id);
191 let channel = Mutex::new(channel);
192 self.channels.insert(channel_id, (add_tx, channel));
193
194 Ok(())
195 }
196
197 /// Returns true if the channel id exists.
198 #[inline]
199 pub fn has_channel(&self, channel: ChannelId) -> bool {
200 self.channels.contains_key(&channel)
201 }
202
203 /// Removes a channel.
204 ///
205 /// Returns `true` if found, `false` otherwise.
206 pub fn remove_channel(&mut self, channel: ChannelId) -> bool {
207 // FIXME: What do we do with the sockets in that channel?
208 self.channels.remove(&channel)
209 }
210
211 /// Sends `data` to a set of `stream_ids`, waiting for them to be sent.
212 pub async fn send(
213 &self,
214 stream_ids: impl IntoIterator<Item = StreamId>,
215 data: St::Ok,
216 ) -> Vec<Result<StreamId, MultiplexerError<Si::Error>>> {
217 let futures: FuturesUnordered<_> = stream_ids
218 .into_iter()
219 .map(|stream_id| {
220 // Clone the data before moving it into the async block.
221 let data = data.clone();
222
223 // Async block is the return value.
224 async move {
225 let sink = {
226 // Fetch the write-half of the stream from the slab
227 let slab_read_guard = self.sinks.read(); // read guard
228 let sink = slab_read_guard.get(stream_id); // shard guard
229 sink.map(|s| Arc::clone(&s))
230 };
231
232 match sink {
233 Some(sink) => {
234 // Sending data via the stream
235 sink.lock()
236 .await
237 .send(data)
238 .await
239 // to match the return type
240 .map(|()| stream_id)
241 .map_err(|err| MultiplexerError::SendError(stream_id, err))
242 }
243 None => {
244 // It's possible for the stream to not be in the slab, should be a rare case.
245 Err(MultiplexerError::UnknownStream(stream_id))
246 }
247 }
248 }
249 })
250 .collect();
251
252 // Waiting for all of the sends to complete
253 futures.collect().await
254 }
255
256 /// Adds `stream` to the `channel` and stores `sink`.
257 ///
258 /// Returns a `StreamId` that represents the pair. It can be used in functions such as `send()` or `change_stream_channel()`.
259 pub fn add_stream_pair(
260 &mut self,
261 sink: Si,
262 stream: St,
263 channel_id: ChannelId,
264 ) -> Result<StreamId, MultiplexerError<Si::Error>> {
265 // Check that we have the channel before we commit the stream to the slab.
266 let channel = self
267 .channels
268 .get(&channel_id)
269 .ok_or_else(|| MultiplexerError::UnknownChannel(channel_id))?;
270
271 // Add the write-half of the stream to the slab, can't return the stream if there isn't room.
272 let stream_id = self
273 .sinks
274 .write()
275 .insert(Arc::new(Mutex::new(sink)))
276 .ok_or_else(move || MultiplexerError::ChannelFull(channel_id))?;
277
278 // wrap the stream in a dropper so that it can be ejected
279 let (stream_control, stream_dropper) =
280 StreamDropControl::wrap(stream_id, stream, self.channel_change_tx.clone());
281 self.stream_controls.insert(stream_id, stream_control);
282
283 // Add the read-half of the stream to the channel
284 channel
285 .0
286 .try_send(stream_dropper)
287 .map_err(|_err| MultiplexerError::ChannelAdd(stream_id, channel_id))?;
288
289 Ok(stream_id)
290 }
291
292 /// Signals to the stream that it should move to a given channel.
293 ///
294 /// The channel change is not instantaneous. Calling `.recv()` on the stream's current channel
295 /// may result in that channel receiving more of the stream's data.
296 pub fn change_stream_channel(
297 &self,
298 stream_id: StreamId,
299 channel_id: ChannelId,
300 ) -> Result<(), MultiplexerError<Si::Error>> {
301 // Ensure that the channel exists
302 if !self.has_channel(channel_id) {
303 return Err(MultiplexerError::UnknownChannel(channel_id));
304 }
305
306 let control = self
307 .stream_controls
308 .get(&stream_id)
309 .ok_or_else(|| MultiplexerError::UnknownStream(stream_id))?;
310
311 control.change_channel(channel_id);
312
313 Ok(())
314 }
315
316 /// Removes the stream from the multiplexer
317 pub fn remove_stream(&mut self, stream_id: StreamId) -> bool {
318 log::trace!("Attempting to remove stream {}", stream_id);
319
320 // Removing it first from the sinks because the sinks are checked when changing channels
321 let res = self.sinks.write().take(stream_id).is_some();
322 log::trace!("Stream {} exists", stream_id);
323
324 // If the stream is changing channels, it may not have a control and will be dropped in process_add_channel
325 if let Some(control) = self.stream_controls.get(&stream_id) {
326 log::trace!("Stream {} is getting dropped()", stream_id);
327 control.drop_stream();
328 }
329
330 res
331 }
332
333 /// Receives the next packet available from a channel:
334 ///
335 /// Returns a tuple of the stream's ID and it's data or an error.
336 pub async fn recv(
337 &mut self,
338 channel_id: ChannelId,
339 ) -> Result<(StreamId, Option<St::Item>), MultiplexerError<Si::Error>> {
340 log::debug!("recv({})", channel_id);
341 match self.channels.get(&channel_id) {
342 Some(channel_guard) => {
343 log::debug!("recv({}) before loop {{}}", channel_id);
344 let mut channel = channel_guard.value().1.lock().await;
345 loop {
346 log::debug!("recv({}) awaiting the select", channel_id);
347 futures_util::select! {
348 channel_next = channel.next().fuse() => {
349 return Ok(channel_next);
350 }
351 add_res = self.channel_change_rx.next().fuse() => {
352 log::debug!("recv({}) channel_change has message.", channel_id);
353 if let Some(add_res) = add_res {
354 // If the stream fails to get added to this channel, bail
355 self.process_add_channel(add_res)?;
356 }
357 }
358 }
359 }
360 }
361 None => Err(MultiplexerError::UnknownChannel(channel_id)),
362 }
363 }
364
365 fn process_add_channel(
366 &mut self,
367 channel_change: ChannelChange<St>,
368 ) -> Result<(), MultiplexerError<Si::Error>> {
369 let ChannelChange {
370 next_channel_id,
371 stream_id,
372 stream,
373 } = channel_change;
374
375 // Check that we have the channel before we commit the stream to the slab.
376 let channel = self
377 .channels
378 .get(&next_channel_id)
379 .ok_or_else(|| MultiplexerError::ChannelAdd(stream_id, next_channel_id))?;
380
381 // If the sink has been dropped, don't add this half to a channel
382 if self.sinks.read().contains(stream_id) {
383 // Wrap the stream in another drop control
384 let (stream_control, stream_dropper) =
385 StreamDropControl::wrap(stream_id, stream, self.channel_change_tx.clone());
386 self.stream_controls.insert(stream_id, stream_control);
387
388 // FIXME: If a channel is removed while a stream is being transitioned to it why are we
389 // returning the error to some other channel's next() call?
390 channel
391 .0
392 .try_send(stream_dropper)
393 .map_err(|_err| MultiplexerError::ChannelAdd(stream_id, next_channel_id))?;
394 }
395
396 Ok(())
397 }
398}