phantom_protocol/transport/
multiplexer.rs1use crate::transport::types::SequenceNumber;
8use bytes::Bytes;
9use dashmap::DashMap;
10use std::sync::atomic::{AtomicU32, Ordering};
11use tokio::sync::mpsc;
12
13#[derive(Debug, Clone, PartialEq, Eq)]
15pub enum StreamMessage {
16 Data(Bytes),
18 Ack(SequenceNumber),
20 Close,
22}
23
24pub struct StreamDemultiplexer {
38 streams: DashMap<u32, mpsc::Sender<StreamMessage>>,
40 control_tx: mpsc::Sender<Bytes>,
42 next_stream_id: AtomicU32,
44}
45
46pub struct StreamHandle {
48 pub stream_id: u32,
50 pub rx: mpsc::Receiver<StreamMessage>,
52}
53
54impl StreamDemultiplexer {
55 pub fn new(control_buffer: usize) -> (Self, mpsc::Receiver<Bytes>) {
60 let (control_tx, control_rx) = mpsc::channel(control_buffer);
61 let mux = Self {
62 streams: DashMap::new(),
63 control_tx,
64 next_stream_id: AtomicU32::new(2), };
66 (mux, control_rx)
67 }
68
69 pub fn open_stream(&self, buffer_size: usize) -> StreamHandle {
73 let stream_id = self.next_stream_id.fetch_add(1, Ordering::Relaxed);
74 let (tx, rx) = mpsc::channel(buffer_size);
75 self.streams.insert(stream_id, tx);
76 StreamHandle { stream_id, rx }
77 }
78
79 pub fn register_stream(&self, stream_id: u32, buffer_size: usize) -> StreamHandle {
81 let (tx, rx) = mpsc::channel(buffer_size);
82 self.streams.insert(stream_id, tx);
83 let _ = self
85 .next_stream_id
86 .fetch_max(stream_id + 1, Ordering::Relaxed);
87 StreamHandle { stream_id, rx }
88 }
89
90 pub fn close_stream(&self, stream_id: u32) {
92 self.streams.remove(&stream_id);
93 }
94
95 pub fn route_data(&self, stream_id: u32, payload: Bytes) -> bool {
100 if stream_id == 0 {
101 return self.control_tx.try_send(payload).is_ok();
103 }
104
105 if let Some(sender) = self.streams.get(&stream_id) {
106 sender.try_send(StreamMessage::Data(payload)).is_ok()
107 } else {
108 log::warn!(
109 "StreamDemultiplexer: dropping data for unknown stream_id={}",
110 stream_id
111 );
112 false
113 }
114 }
115
116 pub async fn route_data_async(&self, stream_id: u32, payload: Bytes) -> bool {
118 if stream_id == 0 {
119 return self.control_tx.send(payload).await.is_ok();
120 }
121
122 if let Some(sender) = self.streams.get(&stream_id) {
123 sender.send(StreamMessage::Data(payload)).await.is_ok()
124 } else {
125 log::warn!(
126 "StreamDemultiplexer: dropping data for unknown stream_id={}",
127 stream_id
128 );
129 false
130 }
131 }
132
133 pub fn route_ack(&self, stream_id: u32, seq: SequenceNumber) -> bool {
138 if stream_id == 0 {
139 return false;
140 }
141 if let Some(sender) = self.streams.get(&stream_id) {
142 sender.try_send(StreamMessage::Ack(seq)).is_ok()
143 } else {
144 false
145 }
146 }
147
148 pub fn route_close(&self, stream_id: u32) -> bool {
150 if stream_id == 0 {
151 return false;
152 }
153 if let Some(sender) = self.streams.get(&stream_id) {
154 sender.try_send(StreamMessage::Close).is_ok()
155 } else {
156 false
157 }
158 }
159
160 pub async fn route_ack_async(&self, stream_id: u32, seq: SequenceNumber) -> bool {
162 if stream_id == 0 {
163 return false;
164 }
165
166 if let Some(sender) = self.streams.get(&stream_id) {
167 sender.send(StreamMessage::Ack(seq)).await.is_ok()
168 } else {
169 log::warn!(
170 "StreamDemultiplexer: dropping ACK for unknown stream_id={}",
171 stream_id
172 );
173 false
174 }
175 }
176
177 pub async fn route_close_async(&self, stream_id: u32) -> bool {
179 if stream_id == 0 {
180 return false;
181 }
182
183 if let Some(sender) = self.streams.get(&stream_id) {
184 sender.send(StreamMessage::Close).await.is_ok()
185 } else {
186 log::warn!(
187 "StreamDemultiplexer: dropping CLOSE for unknown stream_id={}",
188 stream_id
189 );
190 false
191 }
192 }
193
194 pub fn active_stream_count(&self) -> usize {
196 self.streams.len()
197 }
198
199 pub fn has_stream(&self, stream_id: u32) -> bool {
201 self.streams.contains_key(&stream_id)
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Capacity used for every channel created by these tests.
    const CAPACITY: usize = 16;

    /// A freshly opened stream is registered and receives routed data.
    #[tokio::test]
    async fn test_demux_open_and_route() {
        let (demux, _ctrl_rx) = StreamDemultiplexer::new(CAPACITY);
        let StreamHandle { stream_id, mut rx } = demux.open_stream(CAPACITY);

        assert!(demux.has_stream(stream_id));
        assert_eq!(demux.active_stream_count(), 1);

        let payload = Bytes::from_static(b"hello stream");
        assert!(demux.route_data(stream_id, payload.clone()));

        let delivered = rx.recv().await.unwrap();
        assert_eq!(delivered, StreamMessage::Data(payload));
    }

    /// Data addressed to stream id 0 arrives on the control receiver.
    #[tokio::test]
    async fn test_demux_control_channel() {
        let (demux, mut ctrl_rx) = StreamDemultiplexer::new(CAPACITY);

        let payload = Bytes::from_static(b"control msg");
        assert!(demux.route_data(0, payload.clone()));

        let delivered = ctrl_rx.recv().await.unwrap();
        assert_eq!(delivered, payload);
    }

    /// Routing to an id that was never registered reports failure.
    #[tokio::test]
    async fn test_demux_unknown_stream() {
        let (demux, _ctrl_rx) = StreamDemultiplexer::new(CAPACITY);

        let payload = Bytes::from_static(b"lost");
        assert!(!demux.route_data(999, payload));
    }

    /// Closing a stream removes its registration.
    #[tokio::test]
    async fn test_demux_close_stream() {
        let (demux, _ctrl_rx) = StreamDemultiplexer::new(CAPACITY);

        let opened = demux.open_stream(CAPACITY);
        let sid = opened.stream_id;
        assert!(demux.has_stream(sid));

        demux.close_stream(sid);
        assert!(!demux.has_stream(sid));
        assert_eq!(demux.active_stream_count(), 0);
    }

    /// Consecutively opened streams get distinct ids and are all counted.
    #[tokio::test]
    async fn test_demux_multiple_streams() {
        let (demux, _ctrl_rx) = StreamDemultiplexer::new(CAPACITY);

        let handles: Vec<StreamHandle> =
            (0..3).map(|_| demux.open_stream(CAPACITY)).collect();

        assert_ne!(handles[0].stream_id, handles[1].stream_id);
        assert_ne!(handles[1].stream_id, handles[2].stream_id);
        assert_eq!(demux.active_stream_count(), 3);
    }
}