#![allow(dead_code)]
use bytes::Bytes;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use tokio::sync::{Notify, Semaphore};
use super::StreamId;
use crate::simulation::{RealTime, TimeSource};
/// Limits applied to a [`PipedStream`]'s out-of-order buffer and its
/// per-target send concurrency.
#[derive(Debug, Clone)]
pub struct PipedStreamConfig {
    /// Maximum number of out-of-order fragments held at once.
    pub max_buffered_fragments: usize,
    /// Maximum total payload bytes held in the out-of-order buffer.
    pub max_buffered_bytes: usize,
    /// Maximum in-flight sends per target (semaphore permits per target).
    pub max_concurrent_sends: usize,
}
impl Default for PipedStreamConfig {
fn default() -> Self {
Self {
max_buffered_fragments: 100,
max_buffered_bytes: 1024 * 1024,
max_concurrent_sends: 10,
}
}
}
/// Errors surfaced by [`PipedStream`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PipedStreamError {
    /// The out-of-order buffer hit its fragment-count or byte limit.
    BufferFull {
        buffered_fragments: usize,
        buffered_bytes: usize,
    },
    /// Fragment number was 0 or greater than the stream's total.
    InvalidFragment { index: u32, total: u32 },
    /// The stream was cancelled.
    Cancelled,
    /// Sending to one specific target failed.
    SendFailed {
        target_index: usize,
        message: String,
    },
    /// Every target failed.
    AllTargetsFailed,
}
impl std::fmt::Display for PipedStreamError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PipedStreamError::BufferFull {
buffered_fragments,
buffered_bytes,
} => {
write!(
f,
"buffer full: {} fragments, {} bytes",
buffered_fragments, buffered_bytes
)
}
PipedStreamError::InvalidFragment { index, total } => {
write!(
f,
"invalid fragment index {} (total fragments: {})",
index, total
)
}
PipedStreamError::Cancelled => write!(f, "stream cancelled"),
PipedStreamError::SendFailed {
target_index,
message,
} => {
write!(f, "send to target {} failed: {}", target_index, message)
}
PipedStreamError::AllTargetsFailed => write!(f, "all targets failed"),
}
}
}
impl std::error::Error for PipedStreamError {}
/// A fragment that has become forwardable, in order, to the stream's targets.
#[derive(Debug, Clone)]
pub(crate) struct ForwardFragment {
    // Stream the fragment belongs to.
    pub(crate) stream_id: StreamId,
    // 1-based position within the stream.
    pub(crate) fragment_number: u32,
    // Total stream size, copied onto every forwarded fragment.
    pub(crate) total_bytes: u64,
    // Fragment payload bytes.
    pub(crate) payload: Bytes,
}
/// Re-orders incoming stream fragments and gates forwarding to targets.
///
/// Shared-reference safe: state is kept in atomics plus a mutex-guarded
/// out-of-order buffer, so a shared `&PipedStream` can be used from
/// multiple tasks.
pub struct PipedStream<T: TimeSource = RealTime> {
    // Identifier of the stream these fragments belong to.
    stream_id: StreamId,
    // Total payload size announced for the stream.
    total_bytes: u64,
    // Fragment count derived from `total_bytes` at construction.
    total_fragments: u32,
    // Next fragment number (1-based) expected for in-order forwarding.
    next_to_forward: AtomicU32,
    // Fragments that arrived ahead of `next_to_forward`, keyed by number.
    out_of_order: parking_lot::Mutex<BTreeMap<u32, Bytes>>,
    // Byte total of the payloads currently held in `out_of_order`.
    buffered_bytes: AtomicU64,
    // Buffering and concurrency limits.
    config: PipedStreamConfig,
    // One semaphore per target, bounding concurrent sends to each.
    send_semaphores: Vec<Arc<Semaphore>>,
    // Set once by `cancel()`; checked before every operation.
    cancelled: std::sync::atomic::AtomicBool,
    // Wakes permit waiters when the stream is cancelled.
    cancel_notify: Notify,
    // Injected clock for simulation; only stored at construction in this
    // file — presumably read by code elsewhere (NOTE(review): confirm).
    time_source: T,
}
impl PipedStream<RealTime> {
    /// Creates a stream driven by wall-clock time.
    ///
    /// Convenience wrapper over [`PipedStream::new_with_time_source`];
    /// tests use the simulated-time constructor directly.
    pub fn new(
        stream_id: StreamId,
        total_bytes: u64,
        num_targets: usize,
        config: PipedStreamConfig,
    ) -> Self {
        Self::new_with_time_source(stream_id, total_bytes, num_targets, config, RealTime::new())
    }
}
impl<T: TimeSource> PipedStream<T> {
    /// Builds a stream with an explicit time source (simulation-friendly).
    ///
    /// `total_fragments` is derived from `total_bytes` and the per-fragment
    /// payload capacity; a zero-byte stream has zero fragments and is
    /// immediately complete.
    pub fn new_with_time_source(
        stream_id: StreamId,
        total_bytes: u64,
        num_targets: usize,
        config: PipedStreamConfig,
        time_source: T,
    ) -> Self {
        // Payload capacity per fragment: max packet data minus header overhead.
        const FRAGMENT_PAYLOAD_SIZE: usize = crate::transport::packet_data::MAX_DATA_SIZE - 40;
        let total_fragments = if total_bytes == 0 {
            0
        } else {
            (total_bytes as usize).div_ceil(FRAGMENT_PAYLOAD_SIZE) as u32
        };
        // One semaphore per target so a slow target throttles independently.
        let send_semaphores = (0..num_targets)
            .map(|_| Arc::new(Semaphore::new(config.max_concurrent_sends)))
            .collect();
        Self {
            stream_id,
            total_bytes,
            total_fragments,
            // Fragment numbering is 1-based.
            next_to_forward: AtomicU32::new(1),
            out_of_order: parking_lot::Mutex::new(BTreeMap::new()),
            buffered_bytes: AtomicU64::new(0),
            config,
            send_semaphores,
            cancelled: std::sync::atomic::AtomicBool::new(false),
            cancel_notify: Notify::new(),
            time_source,
        }
    }
    /// Identifier of this stream.
    pub fn stream_id(&self) -> StreamId {
        self.stream_id
    }
    /// Total payload size of the stream in bytes.
    pub fn total_bytes(&self) -> u64 {
        self.total_bytes
    }
    /// Number of fragments the stream is split into.
    pub fn total_fragments(&self) -> u32 {
        self.total_fragments
    }
    /// True once every fragment has been forwarded.
    pub fn is_complete(&self) -> bool {
        self.next_to_forward.load(Ordering::Acquire) > self.total_fragments
    }
    /// Number of out-of-order fragments currently buffered.
    pub fn buffered_count(&self) -> usize {
        self.out_of_order.lock().len()
    }
    /// Total bytes currently held in the out-of-order buffer.
    pub fn buffered_bytes(&self) -> u64 {
        self.buffered_bytes.load(Ordering::Relaxed)
    }
    /// Next fragment number (1-based) expected for in-order forwarding.
    pub fn next_expected(&self) -> u32 {
        self.next_to_forward.load(Ordering::Acquire)
    }
    /// Cancels the stream and wakes any tasks waiting on a send permit.
    pub fn cancel(&self) {
        self.cancelled
            .store(true, std::sync::atomic::Ordering::Release);
        self.cancel_notify.notify_waiters();
    }
    /// Whether [`Self::cancel`] has been called.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(std::sync::atomic::Ordering::Acquire)
    }
    /// Accepts an incoming fragment and returns every fragment (in order)
    /// that became forwardable as a result.
    ///
    /// Duplicates of already-forwarded fragments are dropped silently
    /// (`Ok(vec![])`); out-of-order fragments are buffered subject to the
    /// configured limits.
    ///
    /// # Errors
    /// [`PipedStreamError::Cancelled`] after cancellation,
    /// [`PipedStreamError::InvalidFragment`] for number 0 or past the total,
    /// [`PipedStreamError::BufferFull`] when a buffer limit would be exceeded.
    pub fn push_fragment(
        &self,
        fragment_number: u32,
        payload: Bytes,
    ) -> Result<Vec<ForwardFragment>, PipedStreamError> {
        if self.is_cancelled() {
            return Err(PipedStreamError::Cancelled);
        }
        if fragment_number == 0 || fragment_number > self.total_fragments {
            return Err(PipedStreamError::InvalidFragment {
                index: fragment_number,
                total: self.total_fragments,
            });
        }
        // Take the buffer lock *before* reading `next_to_forward` and hold it
        // across the whole decision. Previously the comparison happened
        // outside the lock, so a concurrent forward could advance
        // `next_to_forward` between the compare and the insert, stranding the
        // fragment in the buffer forever and leaking its byte accounting.
        let mut buffer = self.out_of_order.lock();
        let next = self.next_to_forward.load(Ordering::Acquire);
        if fragment_number < next {
            // Duplicate of an already-forwarded fragment.
            return Ok(vec![]);
        }
        if fragment_number == next {
            return Ok(self.forward_contiguous_locked(&mut buffer, fragment_number, payload));
        }
        self.buffer_fragment_locked(&mut buffer, fragment_number, payload)?;
        Ok(vec![])
    }
    /// Buffers an out-of-order fragment; caller holds the buffer lock.
    fn buffer_fragment_locked(
        &self,
        buffer: &mut BTreeMap<u32, Bytes>,
        fragment_number: u32,
        payload: Bytes,
    ) -> Result<(), PipedStreamError> {
        let payload_len = payload.len() as u64;
        let current_bytes = self.buffered_bytes.load(Ordering::Relaxed);
        if current_bytes + payload_len > self.config.max_buffered_bytes as u64 {
            return Err(PipedStreamError::BufferFull {
                buffered_fragments: buffer.len(),
                buffered_bytes: current_bytes as usize,
            });
        }
        if buffer.len() >= self.config.max_buffered_fragments {
            return Err(PipedStreamError::BufferFull {
                buffered_fragments: buffer.len(),
                buffered_bytes: current_bytes as usize,
            });
        }
        // Keep the first copy of a duplicate fragment. The previous code
        // replaced the stored payload without adjusting `buffered_bytes`,
        // which desynced the accounting whenever the duplicate differed in
        // size (and could underflow the counter when the run was drained).
        if let std::collections::btree_map::Entry::Vacant(entry) = buffer.entry(fragment_number) {
            entry.insert(payload);
            self.buffered_bytes
                .fetch_add(payload_len, Ordering::Relaxed);
        }
        Ok(())
    }
    /// Forwards `fragment_number` plus any contiguous run buffered after it
    /// and advances `next_to_forward`; caller holds the buffer lock.
    fn forward_contiguous_locked(
        &self,
        buffer: &mut BTreeMap<u32, Bytes>,
        fragment_number: u32,
        payload: Bytes,
    ) -> Vec<ForwardFragment> {
        let mut to_forward = vec![ForwardFragment {
            stream_id: self.stream_id,
            fragment_number,
            total_bytes: self.total_bytes,
            payload,
        }];
        let mut current = fragment_number + 1;
        // Drain the contiguous run this fragment unblocked.
        while let Some(buffered_payload) = buffer.remove(&current) {
            self.buffered_bytes
                .fetch_sub(buffered_payload.len() as u64, Ordering::Relaxed);
            to_forward.push(ForwardFragment {
                stream_id: self.stream_id,
                fragment_number: current,
                total_bytes: self.total_bytes,
                payload: buffered_payload,
            });
            current += 1;
        }
        // Publish the new watermark while still holding the lock so
        // concurrent `push_fragment` calls observe a consistent state.
        self.next_to_forward.store(current, Ordering::Release);
        to_forward
    }
    /// Waits for a send permit for `target_index`.
    ///
    /// # Errors
    /// [`PipedStreamError::Cancelled`] if the stream is (or becomes)
    /// cancelled, [`PipedStreamError::SendFailed`] for an out-of-range
    /// target index.
    pub async fn acquire_send_permit(
        &self,
        target_index: usize,
    ) -> Result<tokio::sync::OwnedSemaphorePermit, PipedStreamError> {
        if self.is_cancelled() {
            return Err(PipedStreamError::Cancelled);
        }
        if target_index >= self.send_semaphores.len() {
            return Err(PipedStreamError::SendFailed {
                target_index,
                message: "invalid target index".into(),
            });
        }
        let semaphore = self.send_semaphores[target_index].clone();
        // Register for the cancellation notification *before* re-checking the
        // flag. Previously a `cancel()` landing between the flag check and
        // the `select!` registration was missed (`notify_waiters` only wakes
        // already-registered waiters), leaving this task parked on the
        // semaphore forever.
        let cancelled = self.cancel_notify.notified();
        tokio::pin!(cancelled);
        cancelled.as_mut().enable();
        if self.is_cancelled() {
            return Err(PipedStreamError::Cancelled);
        }
        tokio::select! {
            biased;
            _ = &mut cancelled => {
                Err(PipedStreamError::Cancelled)
            }
            result = semaphore.acquire_owned() => {
                // A cancel can race the acquisition; prefer reporting it.
                if self.is_cancelled() {
                    return Err(PipedStreamError::Cancelled);
                }
                result.map_err(|_| PipedStreamError::Cancelled)
            }
        }
    }
    /// Currently available send permits for `target_index`
    /// (0 for an out-of-range index).
    pub fn available_permits(&self, target_index: usize) -> usize {
        if target_index >= self.send_semaphores.len() {
            return 0;
        }
        self.send_semaphores[target_index].available_permits()
    }
}
impl<T: TimeSource> std::fmt::Debug for PipedStream<T> {
    // Manual impl: atomics and the mutex-guarded buffer are sampled live so
    // the output shows current counters rather than raw field internals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PipedStream")
            .field("stream_id", &self.stream_id)
            .field("total_bytes", &self.total_bytes)
            .field("total_fragments", &self.total_fragments)
            .field(
                "next_to_forward",
                &self.next_to_forward.load(Ordering::Relaxed),
            )
            .field("buffered_fragments", &self.out_of_order.lock().len())
            .field(
                "buffered_bytes",
                &self.buffered_bytes.load(Ordering::Relaxed),
            )
            .field("cancelled", &self.is_cancelled())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::GlobalExecutor;
    use crate::simulation::VirtualTime;
    use std::time::Duration;

    // Fresh unique stream id for each test.
    fn make_stream_id() -> StreamId {
        StreamId::next()
    }

    // Fragments arriving 1, 2, 3 are each forwarded immediately; nothing buffers.
    #[test]
    fn test_in_order_forwarding() {
        let stream = PipedStream::new(
            make_stream_id(),
            3000,
            1,
            PipedStreamConfig::default(),
        );
        assert_eq!(stream.total_fragments(), 3);
        let result = stream
            .push_fragment(1, Bytes::from_static(b"fragment 1"))
            .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].fragment_number, 1);
        assert_eq!(stream.next_expected(), 2);
        let result = stream
            .push_fragment(2, Bytes::from_static(b"fragment 2"))
            .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].fragment_number, 2);
        assert_eq!(stream.next_expected(), 3);
        let result = stream
            .push_fragment(3, Bytes::from_static(b"fragment 3"))
            .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].fragment_number, 3);
        assert!(stream.is_complete());
        assert_eq!(stream.buffered_count(), 0);
    }

    // Fragments arriving 3, 2 buffer; fragment 1 releases all three in order.
    #[test]
    fn test_out_of_order_buffering() {
        let stream = PipedStream::new(make_stream_id(), 3000, 1, PipedStreamConfig::default());
        assert_eq!(stream.total_fragments(), 3);
        let result = stream
            .push_fragment(3, Bytes::from_static(b"fragment 3"))
            .unwrap();
        assert!(result.is_empty());
        assert_eq!(stream.buffered_count(), 1);
        assert_eq!(stream.next_expected(), 1);
        let result = stream
            .push_fragment(2, Bytes::from_static(b"fragment 2"))
            .unwrap();
        assert!(result.is_empty());
        assert_eq!(stream.buffered_count(), 2);
        let result = stream
            .push_fragment(1, Bytes::from_static(b"fragment 1"))
            .unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].fragment_number, 1);
        assert_eq!(result[1].fragment_number, 2);
        assert_eq!(result[2].fragment_number, 3);
        assert!(stream.is_complete());
        assert_eq!(stream.buffered_count(), 0);
    }

    // Duplicates of forwarded fragments are dropped; duplicates of buffered
    // fragments do not grow the buffer.
    #[test]
    fn test_duplicate_fragments() {
        let stream = PipedStream::new(make_stream_id(), 4000, 1, PipedStreamConfig::default());
        let result = stream
            .push_fragment(1, Bytes::from_static(b"fragment 1"))
            .unwrap();
        assert_eq!(result.len(), 1);
        let result = stream
            .push_fragment(1, Bytes::from_static(b"fragment 1 dup"))
            .unwrap();
        assert!(result.is_empty());
        let _ = stream
            .push_fragment(3, Bytes::from_static(b"fragment 3"))
            .unwrap();
        assert_eq!(stream.buffered_count(), 1);
        let _ = stream
            .push_fragment(3, Bytes::from_static(b"fragment 3 dup"))
            .unwrap();
        assert_eq!(stream.buffered_count(), 1);
    }

    // Buffering past max_buffered_fragments is rejected with BufferFull.
    #[test]
    fn test_buffer_fragment_limit() {
        let config = PipedStreamConfig {
            max_buffered_fragments: 2,
            max_buffered_bytes: 1024 * 1024,
            max_concurrent_sends: 10,
        };
        let stream = PipedStream::new(make_stream_id(), 100_000, 1, config);
        stream
            .push_fragment(3, Bytes::from_static(b"frag 3"))
            .unwrap();
        stream
            .push_fragment(4, Bytes::from_static(b"frag 4"))
            .unwrap();
        assert_eq!(stream.buffered_count(), 2);
        let result = stream.push_fragment(5, Bytes::from_static(b"frag 5"));
        assert!(matches!(result, Err(PipedStreamError::BufferFull { .. })));
    }

    // Buffering past max_buffered_bytes is rejected with BufferFull.
    #[test]
    fn test_buffer_bytes_limit() {
        let config = PipedStreamConfig {
            max_buffered_fragments: 100,
            max_buffered_bytes: 20,
            max_concurrent_sends: 10,
        };
        let stream = PipedStream::new(make_stream_id(), 100_000, 1, config);
        stream
            .push_fragment(2, Bytes::from_static(b"0123456789"))
            .unwrap();
        assert_eq!(stream.buffered_bytes(), 10);
        let result = stream.push_fragment(3, Bytes::from_static(b"012345678901234"));
        assert!(matches!(result, Err(PipedStreamError::BufferFull { .. })));
    }

    // Fragment number 0 and numbers past the total are rejected.
    #[test]
    fn test_invalid_fragment_index() {
        let stream = PipedStream::new(make_stream_id(), 4000, 1, PipedStreamConfig::default());
        let result = stream.push_fragment(0, Bytes::from_static(b"bad"));
        assert!(matches!(
            result,
            Err(PipedStreamError::InvalidFragment { index: 0, .. })
        ));
        let result = stream.push_fragment(100, Bytes::from_static(b"bad"));
        assert!(matches!(
            result,
            Err(PipedStreamError::InvalidFragment { index: 100, .. })
        ));
    }

    // After cancel(), pushes fail with Cancelled.
    #[test]
    fn test_cancellation() {
        let stream = PipedStream::new(make_stream_id(), 4000, 1, PipedStreamConfig::default());
        assert!(!stream.is_cancelled());
        stream.cancel();
        assert!(stream.is_cancelled());
        let result = stream.push_fragment(1, Bytes::from_static(b"data"));
        assert!(matches!(result, Err(PipedStreamError::Cancelled)));
    }

    // Permits are per-target and restored when dropped.
    #[tokio::test]
    async fn test_send_permits() {
        let config = PipedStreamConfig {
            max_buffered_fragments: 100,
            max_buffered_bytes: 1024 * 1024,
            max_concurrent_sends: 2,
        };
        let stream = PipedStream::new(make_stream_id(), 4000, 2, config);
        assert_eq!(stream.available_permits(0), 2);
        let permit1 = stream.acquire_send_permit(0).await.unwrap();
        assert_eq!(stream.available_permits(0), 1);
        let permit2 = stream.acquire_send_permit(0).await.unwrap();
        assert_eq!(stream.available_permits(0), 0);
        drop(permit1);
        assert_eq!(stream.available_permits(0), 1);
        assert_eq!(stream.available_permits(1), 2);
        drop(permit2);
    }

    // buffered_bytes tracks buffered payload sizes and drops to 0 on drain.
    #[test]
    fn test_buffered_bytes_tracking() {
        let stream = PipedStream::new(make_stream_id(), 100_000, 1, PipedStreamConfig::default());
        stream
            .push_fragment(2, Bytes::from_static(b"12345"))
            .unwrap();
        stream
            .push_fragment(3, Bytes::from_static(b"1234567890"))
            .unwrap();
        assert_eq!(stream.buffered_bytes(), 15);
        let result = stream.push_fragment(1, Bytes::from_static(b"abc")).unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(stream.buffered_bytes(), 0);
        assert_eq!(stream.buffered_count(), 0);
    }

    // A gap (missing fragment 4) stops the cascade; filling it releases the rest.
    #[test]
    fn test_partial_cascade() {
        let stream = PipedStream::new(make_stream_id(), 6500, 1, PipedStreamConfig::default());
        assert_eq!(stream.total_fragments(), 6);
        stream
            .push_fragment(3, Bytes::from_static(b"frag 3"))
            .unwrap();
        stream
            .push_fragment(5, Bytes::from_static(b"frag 5"))
            .unwrap();
        stream
            .push_fragment(6, Bytes::from_static(b"frag 6"))
            .unwrap();
        assert_eq!(stream.buffered_count(), 3);
        stream
            .push_fragment(1, Bytes::from_static(b"frag 1"))
            .unwrap();
        let result = stream
            .push_fragment(2, Bytes::from_static(b"frag 2"))
            .unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].fragment_number, 2);
        assert_eq!(result[1].fragment_number, 3);
        assert_eq!(stream.next_expected(), 4);
        assert_eq!(stream.buffered_count(), 2);
        let result = stream
            .push_fragment(4, Bytes::from_static(b"frag 4"))
            .unwrap();
        assert_eq!(result.len(), 3);
        assert!(stream.is_complete());
    }

    // Concurrent pushers of all fragments: every fragment is forwarded exactly once.
    #[tokio::test]
    async fn test_concurrent_push_fragments() {
        use std::sync::Arc;
        let stream = Arc::new(PipedStream::new(
            make_stream_id(),
            28000,
            1,
            PipedStreamConfig::default(),
        ));
        let num_fragments = stream.total_fragments();
        assert!(num_fragments >= 10);
        let mut handles = Vec::new();
        for frag_num in 1..=num_fragments {
            let stream = Arc::clone(&stream);
            handles.push(GlobalExecutor::spawn(async move {
                let payload = Bytes::from(format!("fragment {}", frag_num));
                stream.push_fragment(frag_num, payload)
            }));
        }
        let mut total_forwarded = 0;
        for handle in handles {
            if let Ok(Ok(fragments)) = handle.await {
                total_forwarded += fragments.len();
            }
        }
        assert_eq!(total_forwarded, num_fragments as usize);
        assert!(stream.is_complete());
        assert_eq!(stream.buffered_count(), 0);
    }

    // A task blocked on a permit is woken promptly by cancel() (virtual time).
    #[tokio::test]
    async fn test_cancel_during_permit_wait() {
        use std::sync::Arc;
        let config = PipedStreamConfig {
            max_buffered_fragments: 100,
            max_buffered_bytes: 1024 * 1024,
            max_concurrent_sends: 1,
        };
        let time_source = VirtualTime::new();
        let stream = Arc::new(PipedStream::new_with_time_source(
            make_stream_id(),
            4000,
            1,
            config,
            time_source.clone(),
        ));
        let _permit = stream.acquire_send_permit(0).await.unwrap();
        assert_eq!(stream.available_permits(0), 0);
        let stream_clone = Arc::clone(&stream);
        let waiter =
            GlobalExecutor::spawn(async move { stream_clone.acquire_send_permit(0).await });
        time_source.advance(Duration::from_millis(10));
        stream.cancel();
        let result = tokio::select! {
            waiter_result = waiter => {
                Some(waiter_result.expect("task should not panic"))
            }
            _ = time_source.sleep(Duration::from_millis(100)) => {
                None
            }
        };
        assert!(result.is_some(), "waiter should not timeout");
        assert!(matches!(result.unwrap(), Err(PipedStreamError::Cancelled)));
    }

    // Cancel racing a pushing task: pushes after cancel fail with Cancelled.
    #[tokio::test]
    async fn test_concurrent_push_and_cancel() {
        use std::sync::Arc;
        let time_source = VirtualTime::new();
        let stream = Arc::new(PipedStream::new_with_time_source(
            make_stream_id(),
            14000,
            1,
            PipedStreamConfig::default(),
            time_source.clone(),
        ));
        let stream_clone = Arc::clone(&stream);
        let time_source_clone = time_source.clone();
        let pusher = GlobalExecutor::spawn(async move {
            for i in 1..=10 {
                if stream_clone.is_cancelled() {
                    break;
                }
                let _push = stream_clone.push_fragment(i, Bytes::from(format!("frag {}", i)));
                time_source_clone.sleep(Duration::from_millis(1)).await;
            }
        });
        time_source.advance(Duration::from_millis(5));
        stream.cancel();
        let _join = pusher.await;
        assert!(stream.is_cancelled());
        let result = stream.push_fragment(1, Bytes::from_static(b"data"));
        assert!(matches!(result, Err(PipedStreamError::Cancelled)));
    }

    // After BufferFull, draining the buffer lets buffering resume normally.
    #[test]
    fn test_buffer_full_recovery() {
        let config = PipedStreamConfig {
            max_buffered_fragments: 2,
            max_buffered_bytes: 1024 * 1024,
            max_concurrent_sends: 10,
        };
        let stream = PipedStream::new(make_stream_id(), 10_000, 1, config);
        stream
            .push_fragment(3, Bytes::from_static(b"frag 3"))
            .unwrap();
        stream
            .push_fragment(4, Bytes::from_static(b"frag 4"))
            .unwrap();
        assert_eq!(stream.buffered_count(), 2);
        let result = stream.push_fragment(5, Bytes::from_static(b"frag 5"));
        assert!(matches!(result, Err(PipedStreamError::BufferFull { .. })));
        stream
            .push_fragment(1, Bytes::from_static(b"frag 1"))
            .unwrap();
        let result = stream
            .push_fragment(2, Bytes::from_static(b"frag 2"))
            .unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(stream.buffered_count(), 0);
        stream
            .push_fragment(6, Bytes::from_static(b"frag 6"))
            .unwrap();
        assert_eq!(stream.buffered_count(), 1);
        let result = stream
            .push_fragment(5, Bytes::from_static(b"frag 5"))
            .unwrap();
        assert_eq!(result.len(), 2);
    }

    // A waiter blocked on an exhausted semaphore proceeds once a permit drops.
    #[tokio::test]
    async fn test_all_permits_exhausted_then_released() {
        use std::sync::Arc;
        let config = PipedStreamConfig {
            max_buffered_fragments: 100,
            max_buffered_bytes: 1024 * 1024,
            max_concurrent_sends: 2,
        };
        let stream = Arc::new(PipedStream::new(make_stream_id(), 4000, 1, config));
        let permit1 = stream.acquire_send_permit(0).await.unwrap();
        let permit2 = stream.acquire_send_permit(0).await.unwrap();
        assert_eq!(stream.available_permits(0), 0);
        let stream_clone = Arc::clone(&stream);
        let waiter =
            GlobalExecutor::spawn(async move { stream_clone.acquire_send_permit(0).await });
        drop(permit1);
        let result = waiter.await.unwrap();
        assert!(result.is_ok());
        drop(permit2);
    }
}