use crate::transport::types::SequenceNumber;
use bytes::Bytes;
use dashmap::DashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::sync::mpsc;
/// Message delivered to the receiving half of a single stream's channel.
///
/// The demultiplexer wraps every routed item in one of these variants before
/// pushing it into the per-stream `mpsc` channel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamMessage {
/// Payload bytes routed to the stream.
Data(Bytes),
/// Acknowledgement carrying a sequence number.
Ack(SequenceNumber),
/// Close notification for the stream.
Close,
}
/// Routes incoming items to per-stream channels keyed by numeric stream id.
///
/// Stream id 0 is treated specially by the routing methods: its payloads go to
/// the dedicated control channel instead of the stream table.
pub struct StreamDemultiplexer {
// Sending halves of the per-stream channels, keyed by stream id.
streams: DashMap<u32, mpsc::Sender<StreamMessage>>,
// Sending half of the control channel (receives payloads addressed to stream id 0).
control_tx: mpsc::Sender<Bytes>,
// Counter used to pick the id for the next locally opened stream.
next_stream_id: AtomicU32,
}
/// Consumer-side handle returned when a stream is opened or registered.
pub struct StreamHandle {
/// Id under which the stream is stored in the demultiplexer.
pub stream_id: u32,
/// Receiving half of the stream's channel; yields the messages routed to `stream_id`.
pub rx: mpsc::Receiver<StreamMessage>,
}
impl StreamDemultiplexer {
pub fn new(control_buffer: usize) -> (Self, mpsc::Receiver<Bytes>) {
let (control_tx, control_rx) = mpsc::channel(control_buffer);
let mux = Self {
streams: DashMap::new(),
control_tx,
next_stream_id: AtomicU32::new(2), };
(mux, control_rx)
}
pub fn open_stream(&self, buffer_size: usize) -> StreamHandle {
let stream_id = self.next_stream_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = mpsc::channel(buffer_size);
self.streams.insert(stream_id, tx);
StreamHandle { stream_id, rx }
}
pub fn register_stream(&self, stream_id: u32, buffer_size: usize) -> StreamHandle {
let (tx, rx) = mpsc::channel(buffer_size);
self.streams.insert(stream_id, tx);
let _ = self
.next_stream_id
.fetch_max(stream_id + 1, Ordering::Relaxed);
StreamHandle { stream_id, rx }
}
pub fn close_stream(&self, stream_id: u32) {
self.streams.remove(&stream_id);
}
pub fn route_data(&self, stream_id: u32, payload: Bytes) -> bool {
if stream_id == 0 {
return self.control_tx.try_send(payload).is_ok();
}
if let Some(sender) = self.streams.get(&stream_id) {
sender.try_send(StreamMessage::Data(payload)).is_ok()
} else {
log::warn!(
"StreamDemultiplexer: dropping data for unknown stream_id={}",
stream_id
);
false
}
}
pub async fn route_data_async(&self, stream_id: u32, payload: Bytes) -> bool {
if stream_id == 0 {
return self.control_tx.send(payload).await.is_ok();
}
if let Some(sender) = self.streams.get(&stream_id) {
sender.send(StreamMessage::Data(payload)).await.is_ok()
} else {
log::warn!(
"StreamDemultiplexer: dropping data for unknown stream_id={}",
stream_id
);
false
}
}
pub fn route_ack(&self, stream_id: u32, seq: SequenceNumber) -> bool {
if stream_id == 0 {
return false;
}
if let Some(sender) = self.streams.get(&stream_id) {
sender.try_send(StreamMessage::Ack(seq)).is_ok()
} else {
false
}
}
pub fn route_close(&self, stream_id: u32) -> bool {
if stream_id == 0 {
return false;
}
if let Some(sender) = self.streams.get(&stream_id) {
sender.try_send(StreamMessage::Close).is_ok()
} else {
false
}
}
pub async fn route_ack_async(&self, stream_id: u32, seq: SequenceNumber) -> bool {
if stream_id == 0 {
return false;
}
if let Some(sender) = self.streams.get(&stream_id) {
sender.send(StreamMessage::Ack(seq)).await.is_ok()
} else {
log::warn!(
"StreamDemultiplexer: dropping ACK for unknown stream_id={}",
stream_id
);
false
}
}
pub async fn route_close_async(&self, stream_id: u32) -> bool {
if stream_id == 0 {
return false;
}
if let Some(sender) = self.streams.get(&stream_id) {
sender.send(StreamMessage::Close).await.is_ok()
} else {
log::warn!(
"StreamDemultiplexer: dropping CLOSE for unknown stream_id={}",
stream_id
);
false
}
}
pub fn active_stream_count(&self) -> usize {
self.streams.len()
}
pub fn has_stream(&self, stream_id: u32) -> bool {
self.streams.contains_key(&stream_id)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An opened stream is tracked and receives the data routed to its id.
    #[tokio::test]
    async fn test_demux_open_and_route() {
        let (mux, _control) = StreamDemultiplexer::new(16);
        let StreamHandle {
            stream_id,
            rx: mut stream_rx,
        } = mux.open_stream(16);

        assert!(mux.has_stream(stream_id));
        assert_eq!(mux.active_stream_count(), 1);

        let payload = Bytes::from_static(b"hello stream");
        assert!(mux.route_data(stream_id, payload.clone()));

        let delivered = stream_rx.recv().await.unwrap();
        assert_eq!(delivered, StreamMessage::Data(payload));
    }

    /// Data addressed to stream id 0 arrives on the control receiver.
    #[tokio::test]
    async fn test_demux_control_channel() {
        let (mux, mut control) = StreamDemultiplexer::new(16);
        let payload = Bytes::from_static(b"control msg");

        assert!(mux.route_data(0, payload.clone()));

        let delivered = control.recv().await.unwrap();
        assert_eq!(delivered, payload);
    }

    /// Routing to an id that was never opened reports failure.
    #[tokio::test]
    async fn test_demux_unknown_stream() {
        let (mux, _control) = StreamDemultiplexer::new(16);
        let routed = mux.route_data(999, Bytes::from_static(b"lost"));
        assert!(!routed);
    }

    /// Closing a stream removes it from the table.
    #[tokio::test]
    async fn test_demux_close_stream() {
        let (mux, _control) = StreamDemultiplexer::new(16);
        // Keep the handle alive for the whole test, as the receiver owner would.
        let handle = mux.open_stream(16);
        let id = handle.stream_id;
        assert!(mux.has_stream(id));

        mux.close_stream(id);

        assert!(!mux.has_stream(id));
        assert_eq!(mux.active_stream_count(), 0);
    }

    /// Consecutively opened streams get distinct ids and are all tracked.
    #[tokio::test]
    async fn test_demux_multiple_streams() {
        let (mux, _control) = StreamDemultiplexer::new(16);
        let handles: Vec<StreamHandle> = (0..3).map(|_| mux.open_stream(16)).collect();

        assert_ne!(handles[0].stream_id, handles[1].stream_id);
        assert_ne!(handles[1].stream_id, handles[2].stream_id);
        assert_eq!(mux.active_stream_count(), 3);
    }
}