Skip to main content

phantom_protocol/transport/
multiplexer.rs

1//! Stream Demultiplexer
2//!
3//! Routes incoming packets to their target streams based on `stream_id`
4//! extracted from `PhantomPacket` headers. Replaces the old smoltcp-based
5//! multiplexer with a lightweight, zero-copy routing table.
6
7use crate::transport::types::SequenceNumber;
8use bytes::Bytes;
9use dashmap::DashMap;
10use std::sync::atomic::{AtomicU32, Ordering};
11use tokio::sync::mpsc;
12
13/// Messages routed to a stream by the demultiplexer.
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub enum StreamMessage {
16    /// Normal data payload
17    Data(Bytes),
18    /// Acknowledgment of a specific sequence number
19    Ack(SequenceNumber),
20    /// Stream closure signal
21    Close,
22}
23
24/// A lightweight stream demultiplexer that routes packets to registered streams.
25///
26/// # Design
27///
28/// ```text
29///   UDP Socket → StreamDemultiplexer → Stream[0] (reliable)
30///                                    → Stream[1] (reliable)
31///                                    → Stream[2] (unreliable)
32///                                    → control channel
33/// ```
34///
35/// Each stream is identified by a `u32` stream ID extracted from the packet header.
36/// Unrecognized stream IDs are dropped (with a log warning).
37pub struct StreamDemultiplexer {
38    /// Active stream senders: stream_id → sender channel
39    streams: DashMap<u32, mpsc::Sender<StreamMessage>>,
40    /// Control channel for session-level messages (stream_id = 0)
41    control_tx: mpsc::Sender<Bytes>,
42    /// Next stream ID to allocate
43    next_stream_id: AtomicU32,
44}
45
46/// Handle returned when a stream is registered with the demultiplexer.
47pub struct StreamHandle {
48    /// The stream ID assigned to this stream
49    pub stream_id: u32,
50    /// Receiver end for incoming packets
51    pub rx: mpsc::Receiver<StreamMessage>,
52}
53
54impl StreamDemultiplexer {
55    /// Create a new demultiplexer with a control channel.
56    ///
57    /// The control channel (stream_id = 0) receives session-level packets
58    /// such as keepalives, migration signals, and stream management.
59    pub fn new(control_buffer: usize) -> (Self, mpsc::Receiver<Bytes>) {
60        let (control_tx, control_rx) = mpsc::channel(control_buffer);
61        let mux = Self {
62            streams: DashMap::new(),
63            control_tx,
64            next_stream_id: AtomicU32::new(2), // 0 = control, 1 = raw-app session channel
65        };
66        (mux, control_rx)
67    }
68
69    /// Register a new stream and get back a handle with the assigned ID.
70    ///
71    /// `buffer_size` controls the depth of the per-stream receive buffer.
72    pub fn open_stream(&self, buffer_size: usize) -> StreamHandle {
73        let stream_id = self.next_stream_id.fetch_add(1, Ordering::Relaxed);
74        let (tx, rx) = mpsc::channel(buffer_size);
75        self.streams.insert(stream_id, tx);
76        StreamHandle { stream_id, rx }
77    }
78
79    /// Register a stream with a specific ID (e.g., for accepting remote-initiated streams).
80    pub fn register_stream(&self, stream_id: u32, buffer_size: usize) -> StreamHandle {
81        let (tx, rx) = mpsc::channel(buffer_size);
82        self.streams.insert(stream_id, tx);
83        // Update next_stream_id if necessary to avoid collisions
84        let _ = self
85            .next_stream_id
86            .fetch_max(stream_id + 1, Ordering::Relaxed);
87        StreamHandle { stream_id, rx }
88    }
89
90    /// Remove a stream from the routing table.
91    pub fn close_stream(&self, stream_id: u32) {
92        self.streams.remove(&stream_id);
93    }
94
95    /// Route data payload to the appropriate stream.
96    ///
97    /// Returns `true` if the packet was successfully delivered,
98    /// `false` if the stream was not found or the buffer was full.
99    pub fn route_data(&self, stream_id: u32, payload: Bytes) -> bool {
100        if stream_id == 0 {
101            // Control channel
102            return self.control_tx.try_send(payload).is_ok();
103        }
104
105        if let Some(sender) = self.streams.get(&stream_id) {
106            sender.try_send(StreamMessage::Data(payload)).is_ok()
107        } else {
108            log::warn!(
109                "StreamDemultiplexer: dropping data for unknown stream_id={}",
110                stream_id
111            );
112            false
113        }
114    }
115
116    /// Route data asynchronously (waits if buffer is full).
117    pub async fn route_data_async(&self, stream_id: u32, payload: Bytes) -> bool {
118        if stream_id == 0 {
119            return self.control_tx.send(payload).await.is_ok();
120        }
121
122        if let Some(sender) = self.streams.get(&stream_id) {
123            sender.send(StreamMessage::Data(payload)).await.is_ok()
124        } else {
125            log::warn!(
126                "StreamDemultiplexer: dropping data for unknown stream_id={}",
127                stream_id
128            );
129            false
130        }
131    }
132
133    /// Route an ACK signal to a stream **without blocking**. Returns
134    /// `false` if the stream is unknown or its buffer is full — the recv pump
135    /// uses this on its never-block path, where a vestigial/absent stream
136    /// consumer must not stall inbound ACK/control processing.
137    pub fn route_ack(&self, stream_id: u32, seq: SequenceNumber) -> bool {
138        if stream_id == 0 {
139            return false;
140        }
141        if let Some(sender) = self.streams.get(&stream_id) {
142            sender.try_send(StreamMessage::Ack(seq)).is_ok()
143        } else {
144            false
145        }
146    }
147
148    /// Route a stream-closure signal **without blocking** (see [`Self::route_ack`]).
149    pub fn route_close(&self, stream_id: u32) -> bool {
150        if stream_id == 0 {
151            return false;
152        }
153        if let Some(sender) = self.streams.get(&stream_id) {
154            sender.try_send(StreamMessage::Close).is_ok()
155        } else {
156            false
157        }
158    }
159
160    /// Route an ACK signal to a stream asynchronously.
161    pub async fn route_ack_async(&self, stream_id: u32, seq: SequenceNumber) -> bool {
162        if stream_id == 0 {
163            return false;
164        }
165
166        if let Some(sender) = self.streams.get(&stream_id) {
167            sender.send(StreamMessage::Ack(seq)).await.is_ok()
168        } else {
169            log::warn!(
170                "StreamDemultiplexer: dropping ACK for unknown stream_id={}",
171                stream_id
172            );
173            false
174        }
175    }
176
177    /// Route a stream closure signal asynchronously.
178    pub async fn route_close_async(&self, stream_id: u32) -> bool {
179        if stream_id == 0 {
180            return false;
181        }
182
183        if let Some(sender) = self.streams.get(&stream_id) {
184            sender.send(StreamMessage::Close).await.is_ok()
185        } else {
186            log::warn!(
187                "StreamDemultiplexer: dropping CLOSE for unknown stream_id={}",
188                stream_id
189            );
190            false
191        }
192    }
193
194    /// Number of active streams (excluding control channel).
195    pub fn active_stream_count(&self) -> usize {
196        self.streams.len()
197    }
198
199    /// Check if a stream is registered.
200    pub fn has_stream(&self, stream_id: u32) -> bool {
201        self.streams.contains_key(&stream_id)
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Channel capacity used by every test below.
    const BUF: usize = 16;

    /// A locally opened stream is registered and receives the data routed to it.
    #[tokio::test]
    async fn test_demux_open_and_route() {
        let (demux, _ctrl_rx) = StreamDemultiplexer::new(BUF);
        let StreamHandle { stream_id, mut rx } = demux.open_stream(BUF);

        assert!(demux.has_stream(stream_id));
        assert_eq!(demux.active_stream_count(), 1);

        // Push one payload through the routing table...
        let payload = Bytes::from_static(b"hello stream");
        assert!(demux.route_data(stream_id, payload.clone()));

        // ...and expect it on the stream's receiver.
        let delivered = rx.recv().await.unwrap();
        assert_eq!(delivered, StreamMessage::Data(payload));
    }

    /// Data addressed to stream ID 0 arrives on the control channel.
    #[tokio::test]
    async fn test_demux_control_channel() {
        let (demux, mut ctrl_rx) = StreamDemultiplexer::new(BUF);

        let payload = Bytes::from_static(b"control msg");
        assert!(demux.route_data(0, payload.clone()));

        let delivered = ctrl_rx.recv().await.unwrap();
        assert_eq!(delivered, payload);
    }

    /// Routing to an ID that was never registered reports failure.
    #[tokio::test]
    async fn test_demux_unknown_stream() {
        let (demux, _ctrl_rx) = StreamDemultiplexer::new(BUF);

        let payload = Bytes::from_static(b"lost");
        assert!(!demux.route_data(999, payload));
    }

    /// Closing a stream removes it from the routing table.
    #[tokio::test]
    async fn test_demux_close_stream() {
        let (demux, _ctrl_rx) = StreamDemultiplexer::new(BUF);

        // Keep the receiver alive for the whole test, as a real consumer would.
        let StreamHandle { stream_id, rx: _rx } = demux.open_stream(BUF);
        assert!(demux.has_stream(stream_id));

        demux.close_stream(stream_id);
        assert!(!demux.has_stream(stream_id));
        assert_eq!(demux.active_stream_count(), 0);
    }

    /// Consecutively opened streams get distinct IDs and are all counted.
    #[tokio::test]
    async fn test_demux_multiple_streams() {
        let (demux, _ctrl_rx) = StreamDemultiplexer::new(BUF);

        let handles: Vec<StreamHandle> = (0..3).map(|_| demux.open_stream(BUF)).collect();

        assert_ne!(handles[0].stream_id, handles[1].stream_id);
        assert_ne!(handles[1].stream_id, handles[2].stream_id);
        assert_eq!(demux.active_stream_count(), 3);
    }
}