netbeam/
multiplex.rs

//! # Network Stream Multiplexing
//!
//! Provides a robust multiplexing system for network streams, allowing multiple logical connections
//! to share a single underlying network connection. This module implements connection multiplexing
//! with automatic stream management and bidirectional communication support.
//!
//! ## Features
//!
//! * Dynamic stream multiplexing over a single connection
//! * Automatic stream ID generation and management
//! * Bidirectional communication channels
//! * Thread-safe subscription management
//! * Pre-action and post-action hooks for the stream lifecycle
//! * Support for custom connection key types
//! * Reliable, ordered message delivery
//! * Automatic cleanup of dropped streams
//!
//! ## Important Notes
//!
//! * Streams are automatically cleaned up when dropped
//! * All operations are thread-safe and async-aware
//! * Messages within each multiplexed stream maintain order
//! * Custom key types must implement the `MultiplexedConnKey` trait
//! * Pre-action signals ensure proper stream initialization
//! * Post-action signals handle graceful stream closure
//!
//! ## Related Components
//!
//! * `reliable_conn`: Underlying reliable stream implementation
//! * `sync::subscription`: Stream subscription management
//! * `sync::network_application`: Network action handling
//! * `sync::RelativeNodeType`: Node type identification
//!

35use crate::reliable_conn::{ReliableOrderedStreamToTarget, ReliableOrderedStreamToTargetExt};
36use crate::sync::network_application::{PostActionChannel, PreActionChannel, INITIAL_CAPACITY};
37use crate::sync::subscription::{
38    close_sequence_for_multiplexed_bistream, Subscribable, SubscriptionBiStream,
39};
40use crate::sync::{RelativeNodeType, SymmetricConvID};
41use anyhow::Error;
42use async_trait::async_trait;
43use citadel_io::tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
44use citadel_io::tokio::sync::Mutex;
45use citadel_io::RwLock;
46use serde::de::DeserializeOwned;
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use std::fmt::Debug;
50use std::hash::Hash;
51use std::ops::Deref;
52use std::sync::atomic::{AtomicU64, Ordering};
53use std::sync::Arc;
54
55/// A trait representing a multiplexed connection key.
56pub trait MultiplexedConnKey:
57    Debug + Eq + Hash + Copy + Send + Sync + Serialize + DeserializeOwned + IDGen<Self>
58{
59}
60impl<T: Debug + Eq + Hash + Copy + Send + Sync + Serialize + DeserializeOwned + IDGen<Self>>
61    MultiplexedConnKey for T
62{
63}
64
65/// A trait for generating IDs for multiplexed connections.
66pub trait IDGen<Key: MultiplexedConnKey> {
67    /// The type of container used to generate IDs.
68    type Container: Send + Sync;
69    /// Generates a new container for ID generation.
70    fn generate_container() -> Self::Container;
71    /// Generates the next ID in the sequence.
72    fn generate_next(container: &Self::Container) -> Self;
73    /// Gets the proposed next ID in the sequence.
74    fn get_proposed_next(container: &Self::Container) -> Key;
75}
76
77impl IDGen<SymmetricConvID> for SymmetricConvID {
78    type Container = Arc<AtomicU64>;
79
80    fn generate_container() -> Self::Container {
81        Arc::new(AtomicU64::new(0))
82    }
83
84    fn generate_next(container: &Self::Container) -> SymmetricConvID {
85        (1 + container.fetch_add(1, Ordering::Relaxed)).into()
86    }
87
88    fn get_proposed_next(container: &Self::Container) -> SymmetricConvID {
89        (1 + container.load(Ordering::Relaxed)).into()
90    }
91}
92
93/// A multiplexed connection.
94pub struct MultiplexedConn<K: MultiplexedConnKey = SymmetricConvID> {
95    inner: Arc<MultiplexedConnInner<K>>,
96}
97
98/// The inner implementation of a multiplexed connection.
99pub struct MultiplexedConnInner<K: MultiplexedConnKey> {
100    /// The underlying reliable connection.
101    pub(crate) conn: Arc<dyn ReliableOrderedStreamToTarget>,
102    /// A map of subscribers.
103    subscribers: RwLock<HashMap<K, MemorySender>>,
104    /// A channel for pre-action signals.
105    pre_open_container: PreActionChannel<K>,
106    /// A channel for post-action signals.
107    post_close_container: PostActionChannel<K>,
108    /// The ID generator.
109    id_gen: K::Container,
110    /// The current latest subscribed ID.
111    current_latest_subscribed: K::Container,
112    /// The node type.
113    node_type: RelativeNodeType,
114}
115
116/// A memory sender.
117pub struct MemorySender {
118    /// The sender.
119    tx: UnboundedSender<Vec<u8>>,
120    /// The pre-reserved receiver.
121    pre_reserved_rx: Option<UnboundedReceiver<Vec<u8>>>,
122}
123
124impl Deref for MemorySender {
125    type Target = UnboundedSender<Vec<u8>>;
126
127    fn deref(&self) -> &Self::Target {
128        &self.tx
129    }
130}
131
132impl<K: MultiplexedConnKey> Deref for MultiplexedConn<K> {
133    type Target = MultiplexedConnInner<K>;
134
135    fn deref(&self) -> &Self::Target {
136        self.inner.deref()
137    }
138}
139
140/// A multiplexed packet.
141#[derive(Serialize, Deserialize)]
142#[serde(bound = "")]
143pub(crate) enum MultiplexedPacket<K: MultiplexedConnKey> {
144    /// An application layer packet.
145    ApplicationLayer { id: K, payload: Vec<u8> },
146    /// A post-drop packet.
147    PostDrop { id: K },
148    /// A pre-create packet.
149    PreCreate { id: K },
150    /// A greeter packet.
151    Greeter,
152}
153
154impl<K: MultiplexedConnKey> MultiplexedConn<K> {
155    /// Creates a new multiplexed connection.
156    pub fn new<T: ReliableOrderedStreamToTarget + 'static>(
157        node_type: RelativeNodeType,
158        conn: T,
159    ) -> Self {
160        let id_gen = K::generate_container();
161        let ids: Vec<K> = (0..INITIAL_CAPACITY)
162            .map(|_| <K as IDGen<K>>::generate_next(&id_gen))
163            .collect();
164        // the next two lines will generate a list of pre-established bistreams
165        let post_close_container = PostActionChannel::new(&ids);
166        let mut subscribers = HashMap::new();
167
168        for id in ids {
169            let (tx, pre_reserved_rx) = citadel_io::tokio::sync::mpsc::unbounded_channel();
170            subscribers.insert(
171                id,
172                MemorySender {
173                    tx,
174                    pre_reserved_rx: Some(pre_reserved_rx),
175                },
176            );
177        }
178
179        let current_latest_subscribed = K::generate_container();
180
181        Self {
182            inner: Arc::new(MultiplexedConnInner {
183                conn: Arc::new(conn),
184                subscribers: RwLock::new(subscribers),
185                pre_open_container: PreActionChannel::new(),
186                post_close_container,
187                current_latest_subscribed,
188                id_gen,
189                node_type,
190            }),
191        }
192    }
193}
194
195impl<K: MultiplexedConnKey> Clone for MultiplexedConn<K> {
196    fn clone(&self) -> Self {
197        Self {
198            inner: self.inner.clone(),
199        }
200    }
201}
202
203/// A multiplexed subscription.
204pub struct MultiplexedSubscription<'a, K: MultiplexedConnKey = SymmetricConvID> {
205    /// A reference to the multiplexed connection.
206    ptr: &'a MultiplexedConn<K>,
207    /// The receiver.
208    receiver: Option<Mutex<UnboundedReceiver<Vec<u8>>>>,
209    /// The ID.
210    id: K,
211}
212
213impl<K: MultiplexedConnKey> SubscriptionBiStream for MultiplexedSubscription<'_, K> {
214    type Conn = Arc<dyn ReliableOrderedStreamToTarget + 'static>;
215    type ID = K;
216
217    fn conn(&self) -> &Self::Conn {
218        &self.ptr.conn
219    }
220
221    fn receiver(&self) -> &Mutex<UnboundedReceiver<Vec<u8>>> {
222        self.receiver.as_ref().unwrap()
223    }
224
225    fn id(&self) -> Self::ID {
226        self.id
227    }
228
229    fn node_type(&self) -> RelativeNodeType {
230        self.ptr.node_type
231    }
232}
233
234impl<K: MultiplexedConnKey> From<MultiplexedSubscription<'_, K>>
235    for OwnedMultiplexedSubscription<K>
236{
237    fn from(mut this: MultiplexedSubscription<'_, K>) -> Self {
238        let ret = Self {
239            ptr: this.ptr.clone(),
240            receiver: this.receiver.take().unwrap(),
241            id: this.id,
242        };
243
244        // prevent destructor from running
245        std::mem::forget(this);
246        ret
247    }
248}
249
250/// An owned multiplexed subscription.
251pub struct OwnedMultiplexedSubscription<K: MultiplexedConnKey + 'static = SymmetricConvID> {
252    /// The multiplexed connection.
253    ptr: MultiplexedConn<K>,
254    /// The receiver.
255    receiver: Mutex<UnboundedReceiver<Vec<u8>>>,
256    /// The ID.
257    id: K,
258}
259
260impl<K: MultiplexedConnKey> SubscriptionBiStream for OwnedMultiplexedSubscription<K> {
261    type Conn = Arc<dyn ReliableOrderedStreamToTarget + 'static>;
262    type ID = K;
263
264    fn conn(&self) -> &Self::Conn {
265        &self.ptr.conn
266    }
267
268    fn receiver(&self) -> &Mutex<UnboundedReceiver<Vec<u8>>> {
269        &self.receiver
270    }
271
272    fn id(&self) -> Self::ID {
273        self.id
274    }
275
276    fn node_type(&self) -> RelativeNodeType {
277        self.ptr.node_type
278    }
279}
280
281#[async_trait]
282impl<K: MultiplexedConnKey + 'static> Subscribable for MultiplexedConn<K> {
283    type ID = K;
284    type UnderlyingConn = Arc<dyn ReliableOrderedStreamToTarget + 'static>;
285    type SubscriptionType = OwnedMultiplexedSubscription<K>;
286    //type BorrowedSubscriptionType<'a> = MultiplexedSubscription<'a, K>;
287    type BorrowedSubscriptionType = OwnedMultiplexedSubscription<K>;
288
289    fn underlying_conn(&self) -> &Self::UnderlyingConn {
290        &self.conn
291    }
292
293    fn subscriptions(&self) -> &RwLock<HashMap<Self::ID, MemorySender>> {
294        &self.subscribers
295    }
296
297    fn post_close_container(&self) -> &PostActionChannel<Self::ID> {
298        &self.post_close_container
299    }
300
301    fn pre_action_container(&self) -> &PreActionChannel<Self::ID> {
302        &self.pre_open_container
303    }
304
305    async fn recv_post_close_signal_from_stream(&self, id: Self::ID) -> Result<(), Error> {
306        self.post_close_container.recv(id).await
307    }
308
309    async fn send_post_close_signal(&self, id: Self::ID) -> Result<(), Error> {
310        Ok(self
311            .conn
312            .send_serialized(MultiplexedPacket::PostDrop { id })
313            .await?)
314    }
315
316    async fn send_pre_open_signal(&self, id: Self::ID) -> Result<(), Error> {
317        Ok(self
318            .conn
319            .send_serialized(MultiplexedPacket::PreCreate { id })
320            .await?)
321    }
322
323    fn node_type(&self) -> RelativeNodeType {
324        self.node_type
325    }
326
327    fn get_next_prereserved(&self) -> Option<Self::BorrowedSubscriptionType> {
328        let mut lock = self.subscribers.write();
329        let next_key = K::get_proposed_next(&self.current_latest_subscribed);
330        let pre_reserved_stream = lock.get_mut(&next_key)?;
331        let sub = MultiplexedSubscription {
332            ptr: self,
333            receiver: Some(Mutex::new(pre_reserved_stream.pre_reserved_rx.take()?)),
334            id: next_key,
335        };
336        assert_eq!(K::generate_next(&self.current_latest_subscribed), next_key);
337        Some(sub.into())
338    }
339
340    fn subscribe(&self, id: Self::ID) -> Self::BorrowedSubscriptionType {
341        let mut lock = self.subscribers.write();
342        let (tx, receiver) = unbounded_channel();
343        let sub = MultiplexedSubscription {
344            ptr: self,
345            receiver: Some(Mutex::new(receiver)),
346            id,
347        };
348        assert!(lock
349            .insert(
350                id,
351                MemorySender {
352                    tx,
353                    pre_reserved_rx: None
354                }
355            )
356            .is_none());
357        assert_eq!(K::generate_next(&self.current_latest_subscribed), id);
358        // TODO: on GAT stabalization, remove into
359        sub.into()
360    }
361
362    fn owned_subscription(&self, id: Self::ID) -> Self::SubscriptionType {
363        self.subscribe(id)
364    }
365
366    fn get_next_id(&self) -> Self::ID {
367        <K as IDGen<K>>::generate_next(&self.id_gen)
368    }
369}
370
371impl<K: MultiplexedConnKey + 'static> Drop for OwnedMultiplexedSubscription<K> {
372    fn drop(&mut self) {
373        close_sequence_for_multiplexed_bistream(self.id, self.ptr.clone())
374    }
375}
376
#[cfg(test)]
mod tests {
    use crate::multiplex::OwnedMultiplexedSubscription;
    use crate::reliable_conn::ReliableOrderedStreamToTargetExt;
    use crate::sync::network_application::NetworkApplication;
    use crate::sync::subscription::{Subscribable, SubscriptionBiStreamExt};
    use crate::sync::test_utils::create_streams;
    use crate::sync::SymmetricConvID;
    use async_recursion::async_recursion;
    use citadel_io::tokio;
    use serde::{Deserialize, Serialize};

    /// How many layers of multiplexing-over-multiplexing the nested test builds.
    const NESTING_DEPTH: usize = 50;

    /// Added to the sibling pair's payload so its packets differ from the primary pair's. A
    /// cross-delivered packet then fails the client's payload assertion.
    const SIBLING_PAYLOAD_OFFSET: usize = 10;

    /// Single-value payload sent across each freshly opened substream.
    #[derive(Serialize, Deserialize)]
    struct Packet(usize);

    #[tokio::test]
    async fn nested_multiplexed_stream() {
        let (outer_stream_server, outer_stream_client) = create_streams().await;
        nested(0, NESTING_DEPTH, outer_stream_server, outer_stream_client).await;
    }

    /// Spawns a server task and a client task that each open one substream on the given
    /// application.
    ///
    /// The server sends `Packet(payload)` and the client asserts it received exactly that. A
    /// oneshot keeps the server from upgrading its substream until the client has read the
    /// packet. Both tasks then upgrade their substream into a new multiplexed application and
    /// return it.
    fn spawn_substream_pair(
        server_stream: NetworkApplication,
        client_stream: NetworkApplication,
        payload: usize,
    ) -> (
        citadel_io::tokio::task::JoinHandle<NetworkApplication>,
        citadel_io::tokio::task::JoinHandle<NetworkApplication>,
    ) {
        let (tx, rx) = citadel_io::tokio::sync::oneshot::channel::<()>();

        let server = citadel_io::tokio::spawn(async move {
            // get one substream from the input stream
            let next_stream: OwnedMultiplexedSubscription =
                server_stream.initiate_subscription().await.unwrap();
            next_stream.send_serialized(Packet(payload)).await.unwrap();
            rx.await.unwrap();
            next_stream.multiplex::<SymmetricConvID>().await.unwrap()
        });

        let client = citadel_io::tokio::spawn(async move {
            let next_stream: OwnedMultiplexedSubscription =
                client_stream.initiate_subscription().await.unwrap();
            let val = next_stream.recv_serialized::<Packet>().await.unwrap();
            assert_eq!(val.0, payload);
            tx.send(()).unwrap();
            next_stream.multiplex::<SymmetricConvID>().await.unwrap()
        });

        (server, client)
    }

    /// Recursively builds the remaining `max - idx` layers of multiplexing.
    ///
    /// At each level two substream pairs run concurrently on the current applications. The
    /// first pair's upgraded applications seed the next level; the sibling pair's are checked
    /// for failure and then dropped. Returns the innermost server/client applications.
    #[async_recursion]
    async fn nested(
        idx: usize,
        max: usize,
        server_stream: NetworkApplication,
        client_stream: NetworkApplication,
    ) -> (NetworkApplication, NetworkApplication) {
        if idx >= max {
            return (server_stream, client_stream);
        }

        let (server, client) =
            spawn_substream_pair(server_stream.clone(), client_stream.clone(), idx);
        let (sibling_server, sibling_client) =
            spawn_substream_pair(server_stream, client_stream, idx + SIBLING_PAYLOAD_OFFSET);

        let (next_server_stream, next_client_stream, sibling_server, sibling_client) =
            citadel_io::tokio::join!(server, client, sibling_server, sibling_client);

        // A panic inside a spawned task (e.g. a failed payload assertion) surfaces only as a
        // `JoinError`. Ignoring the sibling results would let the test pass despite it. The
        // sibling applications themselves are not needed and are dropped right away.
        drop(sibling_server.expect("sibling server task failed"));
        drop(sibling_client.expect("sibling client task failed"));

        nested(
            idx + 1,
            max,
            next_server_stream.expect("server task failed"),
            next_client_stream.expect("client task failed"),
        )
        .await
    }
}