1use crate::{Frame, ProtocolError};
4use bytes::Bytes;
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicU32, Ordering};
7use std::sync::Arc;
8use tokio::sync::{mpsc, Mutex};
9use uuid::Uuid;
10
11pub struct StreamMultiplexer {
13 next_stream_id: AtomicU32,
15 streams: Arc<Mutex<HashMap<u32, StreamInfo>>>,
17 frame_sender: mpsc::UnboundedSender<Frame>,
19 frame_receiver: Arc<Mutex<mpsc::UnboundedReceiver<Frame>>>,
21 flow_control_config: FlowControlConfig,
23}
24
/// Flow-control window sizes used when creating streams.
#[derive(Clone, Debug)]
pub struct FlowControlConfig {
    /// Starting send and receive window (in bytes) for each new stream.
    pub initial_window_size: u32,
    /// Upper bound intended for a stream's window.
    /// NOTE(review): not read by the multiplexer code in this file.
    pub max_window_size: u32,
    /// Connection-wide window.
    /// NOTE(review): not read by the multiplexer code in this file.
    pub connection_window_size: u32,
}
35
36#[derive(Debug)]
38struct StreamInfo {
39 state: StreamState,
41 frame_sender: mpsc::UnboundedSender<Frame>,
43 next_sequence: u32,
45 request_id: Option<Uuid>,
47 flow_control: FlowControlState,
49}
50
/// Byte-credit accounting for a single stream, in both directions.
#[derive(Debug)]
struct FlowControlState {
    /// Remaining credits for outbound data.
    send_window: u32,
    /// Remaining credits for inbound data.
    recv_window: u32,
    /// Window the stream started with; also caps `bytes_in_flight`.
    initial_window_size: u32,
    /// Outbound bytes sent but not yet released by a window update.
    bytes_in_flight: u32,
    /// Inbound bytes received but not yet acknowledged as processed.
    bytes_buffered: u32,
}
65
/// Lifecycle state of a stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamState {
    /// Newly created stream; this is the state `create_stream` assigns.
    Open,
    /// The local handle has sent an end-of-stream frame.
    HalfClosed,
    /// Closed locally, or an end-of-stream frame was routed to the stream.
    Closed,
}
76
77pub struct StreamHandle {
79 stream_id: u32,
81 frame_receiver: mpsc::UnboundedReceiver<Frame>,
83 multiplexer: Arc<StreamMultiplexer>,
85 next_sequence: AtomicU32,
87 state: StreamState,
89}
90
91impl Default for FlowControlConfig {
92 fn default() -> Self {
93 Self {
94 initial_window_size: 65536, max_window_size: 1048576, connection_window_size: 1048576, }
98 }
99}
100
101impl FlowControlState {
102 fn new(initial_window_size: u32) -> Self {
103 Self {
104 send_window: initial_window_size,
105 recv_window: initial_window_size,
106 initial_window_size,
107 bytes_in_flight: 0,
108 bytes_buffered: 0,
109 }
110 }
111
112 fn can_send(&self, size: u32) -> bool {
114 self.send_window >= size && self.bytes_in_flight + size <= self.initial_window_size
115 }
116
117 fn consume_send_credits(&mut self, size: u32) -> Result<(), ProtocolError> {
119 if !self.can_send(size) {
120 return Err(ProtocolError::FlowControlViolation);
121 }
122
123 self.send_window -= size;
124 self.bytes_in_flight += size;
125 Ok(())
126 }
127
128 fn add_recv_credits(&mut self, size: u32) {
130 self.recv_window += size;
131 self.bytes_buffered = self.bytes_buffered.saturating_sub(size);
132 }
133
134 fn consume_recv_credits(&mut self, size: u32) -> Result<(), ProtocolError> {
136 if self.recv_window < size {
137 return Err(ProtocolError::FlowControlViolation);
138 }
139
140 self.recv_window -= size;
141 self.bytes_buffered += size;
142 Ok(())
143 }
144
145 fn update_send_window(&mut self, delta: u32) {
147 self.send_window += delta;
148 self.bytes_in_flight = self.bytes_in_flight.saturating_sub(delta);
149 }
150}
151
152impl StreamMultiplexer {
153 pub fn new() -> Self {
155 Self::with_config(FlowControlConfig::default())
156 }
157
158 pub fn with_config(config: FlowControlConfig) -> Self {
160 let (frame_sender, frame_receiver) = mpsc::unbounded_channel();
161
162 Self {
163 next_stream_id: AtomicU32::new(1),
164 streams: Arc::new(Mutex::new(HashMap::new())),
165 frame_sender,
166 frame_receiver: Arc::new(Mutex::new(frame_receiver)),
167 flow_control_config: config,
168 }
169 }
170
171 pub async fn create_stream(&self, request_id: Option<Uuid>) -> Result<StreamHandle, ProtocolError> {
173 let stream_id = self.next_stream_id.fetch_add(1, Ordering::SeqCst);
174 let (frame_sender, frame_receiver) = mpsc::unbounded_channel();
175
176 let stream_info = StreamInfo {
177 state: StreamState::Open,
178 frame_sender,
179 next_sequence: 0,
180 request_id,
181 flow_control: FlowControlState::new(self.flow_control_config.initial_window_size),
182 };
183
184 {
185 let mut streams = self.streams.lock().await;
186 streams.insert(stream_id, stream_info);
187 }
188
189 Ok(StreamHandle {
190 stream_id,
191 frame_receiver,
192 multiplexer: Arc::new(self.clone()),
193 next_sequence: AtomicU32::new(0),
194 state: StreamState::Open,
195 })
196 }
197
198 pub async fn route_frame(&self, frame: Frame) -> Result<(), ProtocolError> {
200 let stream_id = frame.stream_id;
201
202 let mut streams = self.streams.lock().await;
203
204 if let Some(stream_info) = streams.get_mut(&stream_id) {
205 if frame.sequence != stream_info.next_sequence {
207 return Err(ProtocolError::InvalidFrame);
208 }
209
210 stream_info.next_sequence += 1;
211
212 if frame.is_end_stream() {
214 stream_info.state = StreamState::Closed;
215 }
216
217 if let Err(_) = stream_info.frame_sender.send(frame) {
219 streams.remove(&stream_id);
221 }
222 } else {
223 return Err(ProtocolError::InvalidStreamId(stream_id));
225 }
226
227 Ok(())
228 }
229
230 pub async fn close_stream(&self, stream_id: u32) -> Result<(), ProtocolError> {
232 let mut streams = self.streams.lock().await;
233
234 if let Some(stream_info) = streams.get_mut(&stream_id) {
235 stream_info.state = StreamState::Closed;
236 streams.remove(&stream_id);
237 Ok(())
238 } else {
239 Err(ProtocolError::InvalidStreamId(stream_id))
240 }
241 }
242
243 pub async fn stream_count(&self) -> usize {
245 let streams = self.streams.lock().await;
246 streams.len()
247 }
248
249 pub async fn stream_state(&self, stream_id: u32) -> Option<StreamState> {
251 let streams = self.streams.lock().await;
252 streams.get(&stream_id).map(|info| info.state)
253 }
254
255 pub async fn process_frames(&self) -> Result<(), ProtocolError> {
257 let mut receiver = self.frame_receiver.lock().await;
258
259 while let Some(frame) = receiver.recv().await {
260 self.route_frame(frame).await?;
261 }
262
263 Ok(())
264 }
265
266 pub fn send_frame(&self, frame: Frame) -> Result<(), ProtocolError> {
268 self.frame_sender.send(frame)
269 .map_err(|_| ProtocolError::StreamClosed)
270 }
271
272 pub async fn can_send_data(&self, stream_id: u32, size: u32) -> Result<bool, ProtocolError> {
274 let streams = self.streams.lock().await;
275
276 if let Some(stream_info) = streams.get(&stream_id) {
277 Ok(stream_info.flow_control.can_send(size))
278 } else {
279 Err(ProtocolError::InvalidStreamId(stream_id))
280 }
281 }
282
283 pub async fn update_window(&self, stream_id: u32, delta: u32) -> Result<(), ProtocolError> {
285 let mut streams = self.streams.lock().await;
286
287 if let Some(stream_info) = streams.get_mut(&stream_id) {
288 stream_info.flow_control.update_send_window(delta);
289 Ok(())
290 } else {
291 Err(ProtocolError::InvalidStreamId(stream_id))
292 }
293 }
294
295 pub async fn process_received_data(&self, stream_id: u32, size: u32) -> Result<(), ProtocolError> {
297 let mut streams = self.streams.lock().await;
298
299 if let Some(stream_info) = streams.get_mut(&stream_id) {
300 stream_info.flow_control.consume_recv_credits(size)?;
301 Ok(())
302 } else {
303 Err(ProtocolError::InvalidStreamId(stream_id))
304 }
305 }
306
307 pub async fn ack_processed_data(&self, stream_id: u32, size: u32) -> Result<(), ProtocolError> {
309 let mut streams = self.streams.lock().await;
310
311 if let Some(stream_info) = streams.get_mut(&stream_id) {
312 stream_info.flow_control.add_recv_credits(size);
313 Ok(())
314 } else {
315 Err(ProtocolError::InvalidStreamId(stream_id))
316 }
317 }
318}
319
320impl Clone for StreamMultiplexer {
321 fn clone(&self) -> Self {
322 Self {
323 next_stream_id: AtomicU32::new(self.next_stream_id.load(Ordering::SeqCst)),
324 streams: Arc::clone(&self.streams),
325 frame_sender: self.frame_sender.clone(),
326 frame_receiver: Arc::clone(&self.frame_receiver),
327 flow_control_config: self.flow_control_config.clone(),
328 }
329 }
330}
331
332impl Default for StreamMultiplexer {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
338impl StreamHandle {
339 pub fn stream_id(&self) -> u32 {
341 self.stream_id
342 }
343
344 pub async fn send_data(&mut self, payload: Bytes) -> Result<(), ProtocolError> {
346 if self.state == StreamState::Closed {
347 return Err(ProtocolError::StreamClosed);
348 }
349
350 let payload_size = payload.len() as u32;
351
352 if !self.multiplexer.can_send_data(self.stream_id, payload_size).await? {
354 return Err(ProtocolError::FlowControlViolation);
355 }
356
357 let sequence = self.next_sequence.fetch_add(1, Ordering::SeqCst);
358 let frame = Frame::data(self.stream_id, sequence, payload);
359
360 {
362 let mut streams = self.multiplexer.streams.lock().await;
363 if let Some(stream_info) = streams.get_mut(&self.stream_id) {
364 stream_info.flow_control.consume_send_credits(payload_size)?;
365 }
366 }
367
368 self.multiplexer.send_frame(frame)
369 }
370
371 pub async fn send_end_stream(&mut self) -> Result<(), ProtocolError> {
373 if self.state == StreamState::Closed {
374 return Err(ProtocolError::StreamClosed);
375 }
376
377 let sequence = self.next_sequence.fetch_add(1, Ordering::SeqCst);
378 let frame = Frame::end_stream(self.stream_id, sequence);
379
380 self.state = StreamState::HalfClosed;
381 self.multiplexer.send_frame(frame)
382 }
383
384 pub async fn recv_frame(&mut self) -> Option<Frame> {
386 self.frame_receiver.recv().await
387 }
388
389 pub async fn close(&mut self) -> Result<(), ProtocolError> {
391 if self.state != StreamState::Closed {
392 self.send_end_stream().await?;
393 self.state = StreamState::Closed;
394 self.multiplexer.close_stream(self.stream_id).await?;
395 }
396 Ok(())
397 }
398
399 pub fn state(&self) -> StreamState {
401 self.state
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{timeout, Duration};

    #[tokio::test]
    async fn test_stream_creation() {
        let multiplexer = StreamMultiplexer::new();
        let stream = multiplexer.create_stream(None).await.unwrap();

        assert_eq!(stream.stream_id(), 1);
        assert_eq!(stream.state(), StreamState::Open);
        assert_eq!(multiplexer.stream_count().await, 1);
    }

    #[tokio::test]
    async fn test_multiple_streams() {
        let multiplexer = StreamMultiplexer::new();

        let stream1 = multiplexer.create_stream(None).await.unwrap();
        let stream2 = multiplexer.create_stream(None).await.unwrap();
        let stream3 = multiplexer.create_stream(None).await.unwrap();

        assert_eq!(stream1.stream_id(), 1);
        assert_eq!(stream2.stream_id(), 2);
        assert_eq!(stream3.stream_id(), 3);
        assert_eq!(multiplexer.stream_count().await, 3);
    }

    #[tokio::test]
    async fn test_frame_routing() {
        let multiplexer = StreamMultiplexer::new();
        let mut stream = multiplexer.create_stream(None).await.unwrap();
        let stream_id = stream.stream_id();

        let frame = Frame::data(stream_id, 0, Bytes::from("test data"));
        multiplexer.route_frame(frame.clone()).await.unwrap();

        let received = timeout(Duration::from_millis(100), stream.recv_frame())
            .await
            .unwrap()
            .unwrap();

        assert_eq!(received.stream_id, frame.stream_id);
        assert_eq!(received.payload, frame.payload);
    }

    #[tokio::test]
    async fn test_stream_send_receive() {
        let multiplexer = StreamMultiplexer::new();
        let mut stream1 = multiplexer.create_stream(None).await.unwrap();
        let _stream2 = multiplexer.create_stream(None).await.unwrap();

        let payload = Bytes::from("hello world");
        stream1.send_data(payload).await.unwrap();

        // The handle shares the multiplexer's outbound channel, so the frame it just sent
        // is queued there; take it directly instead of running `process_frames`.
        let outbound = multiplexer
            .frame_receiver
            .lock()
            .await
            .try_recv()
            .expect("send_data should enqueue a frame");
        assert_eq!(outbound.stream_id, stream1.stream_id());
        assert_eq!(outbound.sequence, 0);
        let expected_payload = outbound.payload.clone();

        // Loop the frame back in and confirm it reaches stream1 (not stream2).
        multiplexer.route_frame(outbound).await.unwrap();
        let received = timeout(Duration::from_millis(100), stream1.recv_frame())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received.stream_id, stream1.stream_id());
        assert_eq!(received.payload, expected_payload);
    }

    #[tokio::test]
    async fn test_stream_close() {
        let multiplexer = StreamMultiplexer::new();
        let mut stream = multiplexer.create_stream(None).await.unwrap();

        assert_eq!(multiplexer.stream_count().await, 1);

        stream.close().await.unwrap();

        assert_eq!(stream.state(), StreamState::Closed);
        assert_eq!(multiplexer.stream_count().await, 0);
    }

    #[tokio::test]
    async fn test_invalid_stream_id() {
        let multiplexer = StreamMultiplexer::new();

        let frame = Frame::data(999, 0, Bytes::from("test"));
        let result = multiplexer.route_frame(frame).await;

        assert!(matches!(result, Err(ProtocolError::InvalidStreamId(999))));
    }

    #[tokio::test]
    async fn test_sequence_number_validation() {
        let multiplexer = StreamMultiplexer::new();
        let stream = multiplexer.create_stream(None).await.unwrap();
        let stream_id = stream.stream_id();

        let frame = Frame::data(stream_id, 5, Bytes::from("test"));
        let result = multiplexer.route_frame(frame).await;

        assert!(matches!(result, Err(ProtocolError::InvalidFrame)));
    }

    #[tokio::test]
    async fn test_end_stream_handling() {
        let multiplexer = StreamMultiplexer::new();
        let stream = multiplexer.create_stream(None).await.unwrap();
        let stream_id = stream.stream_id();

        let frame = Frame::end_stream(stream_id, 0);
        multiplexer.route_frame(frame).await.unwrap();

        assert_eq!(multiplexer.stream_state(stream_id).await, Some(StreamState::Closed));
    }

    #[tokio::test]
    async fn test_flow_control_basic() {
        let config = FlowControlConfig {
            initial_window_size: 1000,
            max_window_size: 2000,
            connection_window_size: 5000,
        };
        let multiplexer = StreamMultiplexer::with_config(config);
        let mut stream = multiplexer.create_stream(None).await.unwrap();
        let stream_id = stream.stream_id();

        assert!(multiplexer.can_send_data(stream_id, 500).await.unwrap());

        let payload = Bytes::from(vec![0u8; 500]);
        stream.send_data(payload).await.unwrap();

        assert!(multiplexer.can_send_data(stream_id, 500).await.unwrap());

        assert!(!multiplexer.can_send_data(stream_id, 600).await.unwrap());
    }

    #[tokio::test]
    async fn test_flow_control_violation() {
        let config = FlowControlConfig {
            initial_window_size: 100,
            max_window_size: 200,
            connection_window_size: 500,
        };
        let multiplexer = StreamMultiplexer::with_config(config);
        let mut stream = multiplexer.create_stream(None).await.unwrap();

        let large_payload = Bytes::from(vec![0u8; 200]);
        let result = stream.send_data(large_payload).await;

        assert!(matches!(result, Err(ProtocolError::FlowControlViolation)));
    }

    #[tokio::test]
    async fn test_window_update() {
        let config = FlowControlConfig {
            initial_window_size: 100,
            max_window_size: 200,
            connection_window_size: 500,
        };
        let multiplexer = StreamMultiplexer::with_config(config);
        let mut stream = multiplexer.create_stream(None).await.unwrap();
        let stream_id = stream.stream_id();

        let payload = Bytes::from(vec![0u8; 100]);
        stream.send_data(payload).await.unwrap();

        assert!(!multiplexer.can_send_data(stream_id, 50).await.unwrap());

        multiplexer.update_window(stream_id, 50).await.unwrap();

        assert!(multiplexer.can_send_data(stream_id, 50).await.unwrap());
    }

    #[tokio::test]
    async fn test_receive_flow_control() {
        let multiplexer = StreamMultiplexer::new();
        let stream = multiplexer.create_stream(None).await.unwrap();
        let stream_id = stream.stream_id();

        multiplexer.process_received_data(stream_id, 1000).await.unwrap();

        multiplexer.ack_processed_data(stream_id, 500).await.unwrap();

        multiplexer.process_received_data(stream_id, 500).await.unwrap();

        // 65536 - 1000 + 500 - 500 = 64536 bytes of receive window remain; one more must fail.
        let overrun = multiplexer.process_received_data(stream_id, 64_537).await;
        assert!(matches!(overrun, Err(ProtocolError::FlowControlViolation)));
    }

    use proptest::prelude::*;

    proptest! {
        #[test]
        fn test_stream_id_generation_properties(
            num_streams in 1usize..100
        ) {
            tokio_test::block_on(async {
                let multiplexer = StreamMultiplexer::new();
                let mut stream_ids = Vec::new();

                for _ in 0..num_streams {
                    let stream = multiplexer.create_stream(None).await?;
                    stream_ids.push(stream.stream_id());
                }

                stream_ids.sort();
                stream_ids.dedup();
                prop_assert_eq!(stream_ids.len(), num_streams);

                for (i, &stream_id) in stream_ids.iter().enumerate() {
                    prop_assert_eq!(stream_id, (i + 1) as u32);
                }

                Ok(())
            })?;
        }

        #[test]
        fn test_flow_control_invariants(
            initial_window in 100u32..10000,
            data_sizes in prop::collection::vec(1u32..1000, 1..20)
        ) {
            tokio_test::block_on(async {
                let config = FlowControlConfig {
                    initial_window_size: initial_window,
                    max_window_size: initial_window * 2,
                    connection_window_size: initial_window * 5,
                };
                let multiplexer = StreamMultiplexer::with_config(config);
                let mut stream = multiplexer.create_stream(None).await?;
                let stream_id = stream.stream_id();

                let mut total_sent = 0u32;

                for &size in &data_sizes {
                    let can_send = multiplexer.can_send_data(stream_id, size).await?;

                    if can_send && total_sent + size <= initial_window {
                        let payload = Bytes::from(vec![0u8; size as usize]);
                        stream.send_data(payload).await?;
                        total_sent += size;
                    } else {
                        let payload = Bytes::from(vec![0u8; size as usize]);
                        let result = stream.send_data(payload).await;
                        prop_assert!(result.is_err());
                    }
                }

                prop_assert!(total_sent <= initial_window);

                Ok(())
            })?;
        }

        #[test]
        fn test_concurrent_stream_operations(
            num_streams in 1usize..10,
            operations_per_stream in 1usize..10
        ) {
            tokio_test::block_on(async {
                let multiplexer = StreamMultiplexer::new();
                let mut streams = Vec::new();

                for _ in 0..num_streams {
                    let stream = multiplexer.create_stream(None).await?;
                    streams.push(stream);
                }

                prop_assert_eq!(multiplexer.stream_count().await, num_streams);

                for stream in &mut streams {
                    for i in 0..operations_per_stream {
                        let payload = Bytes::from(format!("data-{}", i));
                        let _ = stream.send_data(payload).await;
                    }
                }

                for stream in &mut streams {
                    stream.close().await?;
                }

                prop_assert_eq!(multiplexer.stream_count().await, 0);

                Ok(())
            })?;
        }

        #[test]
        fn test_sequence_number_properties(
            num_frames in 1usize..50
        ) {
            tokio_test::block_on(async {
                let multiplexer = StreamMultiplexer::new();
                let stream = multiplexer.create_stream(None).await?;
                let stream_id = stream.stream_id();

                for i in 0..num_frames {
                    let frame = Frame::data(stream_id, i as u32, Bytes::from("test"));
                    multiplexer.route_frame(frame).await?;
                }

                let wrong_frame = Frame::data(stream_id, (num_frames + 5) as u32, Bytes::from("wrong"));
                let result = multiplexer.route_frame(wrong_frame).await;
                prop_assert!(result.is_err());

                Ok(())
            })?;
        }

        #[test]
        fn test_window_update_properties(
            initial_window in 100u32..1000,
            updates in prop::collection::vec(1u32..500, 1..10)
        ) {
            tokio_test::block_on(async {
                let config = FlowControlConfig {
                    initial_window_size: initial_window,
                    max_window_size: initial_window * 10,
                    connection_window_size: initial_window * 10,
                };
                let multiplexer = StreamMultiplexer::with_config(config);
                let mut stream = multiplexer.create_stream(None).await?;
                let stream_id = stream.stream_id();

                let payload = Bytes::from(vec![0u8; initial_window as usize]);
                stream.send_data(payload).await?;

                prop_assert!(!multiplexer.can_send_data(stream_id, 1).await?);

                let mut total_updates = 0u32;
                for &update in &updates {
                    multiplexer.update_window(stream_id, update).await?;
                    total_updates += update;

                    if total_updates > 0 {
                        prop_assert!(multiplexer.can_send_data(stream_id, 1).await?);
                    }
                }

                Ok(())
            })?;
        }
    }
}