use bytes::Bytes;
use dashmap::DashMap;
use event_listener::EventListener;
use futures::Stream;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use super::StreamId;
use super::streaming_buffer::LockFreeStreamBuffer;
/// Longest time [`StreamHandle::assemble`] waits for a buffer notification
/// before giving up with [`StreamError::InactivityTimeout`]. The window
/// restarts after every notification, so it bounds idle time, not total time.
pub const STREAM_INACTIVITY_TIMEOUT: std::time::Duration =
    std::time::Duration::from_millis(5_000);
/// Failures reported by stream handles, readers and the registry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamError {
    /// The handle was cancelled.
    Cancelled,
    /// No stream with the requested id is registered.
    NotFound,
    /// The fragment buffer rejected a fragment; `message` is the buffer's
    /// own error text.
    InvalidFragment { message: String },
    /// No notification arrived within [`STREAM_INACTIVITY_TIMEOUT`] and the
    /// stream still could not be assembled.
    InactivityTimeout,
}

impl std::fmt::Display for StreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed messages share one write path; only the fragment error
        // needs formatting.
        let fixed = match self {
            Self::Cancelled => "stream was cancelled",
            Self::NotFound => "stream not found in registry",
            Self::InactivityTimeout => "no fragments received within inactivity timeout",
            Self::InvalidFragment { message } => {
                return write!(f, "invalid fragment: {}", message);
            }
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for StreamError {}
/// Per-handle cancellation flag plus a list of wakers to wake on change.
///
/// NOTE(review): nothing in this file ever pushes into `wakers`, so
/// `wake_all` currently drains an empty list; task wakeups actually flow
/// through the buffer's notifier. Confirm before relying on this list.
struct SyncState {
    /// Set by `StreamHandle::cancel` (and copied into forks); nothing in
    /// this file resets it to `false`.
    cancelled: bool,
    /// Wakers to be woken (and removed) by `wake_all`.
    wakers: Vec<Waker>,
}

impl SyncState {
    /// Not-cancelled state with no registered wakers.
    fn new() -> Self {
        SyncState {
            wakers: Vec::new(),
            cancelled: false,
        }
    }

    /// Wakes every registered waker and leaves the list empty.
    fn wake_all(&mut self) {
        self.wakers.drain(..).for_each(Waker::wake);
    }
}
/// Shared handle to one inbound fragmented stream.
///
/// Clones (via `Clone`) share both the fragment buffer and the cancellation
/// state. [`StreamHandle::fork`] shares only the buffer and gets its own
/// cancellation state.
#[derive(Clone)]
pub struct StreamHandle {
    /// Fragment storage; shared by clones and forks.
    buffer: Arc<LockFreeStreamBuffer>,
    /// Cancellation state; shared by clones, replaced (not shared) by `fork`.
    sync: Arc<parking_lot::RwLock<SyncState>>,
    stream_id: StreamId,
    /// Declared total payload size, fixed at construction.
    total_bytes: u64,
}

impl StreamHandle {
    /// Creates a not-yet-cancelled handle for `stream_id` whose buffer is
    /// sized from `total_bytes`.
    pub fn new(stream_id: StreamId, total_bytes: u64) -> Self {
        Self {
            buffer: Arc::new(LockFreeStreamBuffer::new(total_bytes)),
            sync: Arc::new(parking_lot::RwLock::new(SyncState::new())),
            stream_id,
            total_bytes,
        }
    }

    /// Identifier this handle was created with.
    // `dead_code` is allowed because, within this file, only tests call it.
    #[allow(dead_code)]
    pub(crate) fn stream_id(&self) -> StreamId {
        self.stream_id
    }

    /// Total payload size declared at construction.
    pub fn total_bytes(&self) -> u64 {
        self.total_bytes
    }

    /// Whether the underlying buffer reports the stream as complete.
    ///
    /// Completion does not guarantee assembly succeeds: [`Self::assemble`]
    /// may still have to wait for an overflow fragment.
    pub fn is_complete(&self) -> bool {
        self.buffer.is_complete()
    }

    /// Number of fragments currently held by the buffer.
    ///
    /// Readers created with [`Self::stream_with_reclaim`] take fragments out
    /// of the buffer as they consume them, which lowers this count.
    pub fn received_fragments(&self) -> usize {
        self.buffer.inserted_count()
    }

    /// Number of fragments for this stream, as reported by the buffer.
    pub fn total_fragments(&self) -> usize {
        self.buffer.total_fragments()
    }

    /// Returns a reader that yields fragments in order, starting at
    /// fragment 1. Fragments are cloned out, so they stay in the buffer for
    /// other readers.
    pub fn stream(&self) -> StreamingInboundStream {
        self.make_inbound_stream(false)
    }

    /// Like [`Self::stream`], but each yielded fragment is taken out of the
    /// shared buffer (`take`) instead of cloned. Other readers of the same
    /// buffer, including forks, can no longer read consumed fragments.
    pub fn stream_with_reclaim(&self) -> StreamingInboundStream {
        self.make_inbound_stream(true)
    }

    /// Builds a reader positioned at fragment 1 with the given reclaim mode.
    fn make_inbound_stream(&self, auto_reclaim: bool) -> StreamingInboundStream {
        StreamingInboundStream {
            handle: self.clone(),
            next_fragment: 1,
            bytes_read: 0,
            auto_reclaim,
            listener: None,
        }
    }

    /// Returns a handle that shares this handle's fragment buffer but has its
    /// own cancellation state, initialised from this handle's current state.
    ///
    /// Cancelling one side afterwards does not cancel the other. Because the
    /// buffer notifier is shared, `cancel` still wakes the other side's
    /// waiting readers; they re-check their own state and keep waiting.
    pub fn fork(&self) -> Self {
        let mut sync = SyncState::new();
        sync.cancelled = self.sync.read().cancelled;
        Self {
            buffer: Arc::clone(&self.buffer),
            sync: Arc::new(parking_lot::RwLock::new(sync)),
            stream_id: self.stream_id,
            total_bytes: self.total_bytes,
        }
    }

    /// Stores fragment `fragment_number` (1-based).
    ///
    /// Returns `Ok(true)` if the fragment was newly inserted and `Ok(false)`
    /// if the buffer already held that fragment number.
    ///
    /// # Errors
    ///
    /// * [`StreamError::Cancelled`] if this handle has been cancelled.
    /// * [`StreamError::InvalidFragment`] if the buffer rejects the fragment
    ///   (e.g. fragment number 0); the buffer's error text is kept.
    pub(crate) fn push_fragment(
        &self,
        fragment_number: u32,
        data: Bytes,
    ) -> Result<bool, StreamError> {
        if self.sync.read().cancelled {
            return Err(StreamError::Cancelled);
        }
        let inserted = self
            .buffer
            .insert(fragment_number, data)
            .map_err(|e| StreamError::InvalidFragment {
                message: e.to_string(),
            })?;
        if inserted {
            self.sync.write().wake_all();
        }
        Ok(inserted)
    }

    /// Marks this handle (and its clones) cancelled and wakes all waiters.
    ///
    /// The buffer notifier is signalled after the state lock is released.
    /// Because the buffer is shared with forks, their readers are woken too.
    pub(crate) fn cancel(&self) {
        let mut sync = self.sync.write();
        sync.cancelled = true;
        sync.wake_all();
        drop(sync);
        self.buffer.notifier().notify(usize::MAX);
    }

    /// Assembles the payload if the buffer can do so right now; never waits.
    pub fn try_assemble(&self) -> Option<Vec<u8>> {
        self.buffer.assemble()
    }

    /// Waits until the stream can be assembled and returns the payload.
    ///
    /// # Errors
    ///
    /// * [`StreamError::Cancelled`] if this handle is cancelled before
    ///   assembly succeeds.
    /// * [`StreamError::InactivityTimeout`] if no buffer notification arrives
    ///   within [`STREAM_INACTIVITY_TIMEOUT`] and the stream still cannot be
    ///   assembled. The window restarts after every notification.
    pub async fn assemble(&self) -> Result<Vec<u8>, StreamError> {
        loop {
            // Register interest *before* inspecting state. If the listener
            // were created after the checks, a fragment or `cancel()` landing
            // in between would notify nobody, and this task would sleep for
            // the whole inactivity window.
            let listener = self.buffer.notifier().listen();

            if self.sync.read().cancelled {
                return Err(StreamError::Cancelled);
            }
            if self.buffer.is_complete() {
                if let Some(data) = self.buffer.assemble() {
                    return Ok(data);
                }
                tracing::debug!(
                    stream_id = %self.stream_id.0,
                    total_bytes = self.total_bytes,
                    received_fragments = self.buffer.inserted_count(),
                    total_fragments = self.buffer.total_fragments(),
                    "Stream marked complete but assembly insufficient, waiting for overflow fragment"
                );
            }

            match tokio::time::timeout(STREAM_INACTIVITY_TIMEOUT, listener).await {
                // Notified: loop around and re-check state.
                Ok(()) => {}
                Err(_elapsed) => {
                    // One last look before giving up, in case state changed
                    // without a notification reaching us.
                    if self.sync.read().cancelled {
                        return Err(StreamError::Cancelled);
                    }
                    if let Some(data) = self.buffer.assemble() {
                        return Ok(data);
                    }
                    tracing::warn!(
                        stream_id = %self.stream_id.0,
                        total_bytes = self.total_bytes,
                        received_fragments = self.buffer.inserted_count(),
                        total_fragments = self.buffer.total_fragments(),
                        timeout_secs = STREAM_INACTIVITY_TIMEOUT.as_secs(),
                        "Stream assembly timed out — no fragments received within inactivity window"
                    );
                    return Err(StreamError::InactivityTimeout);
                }
            }
        }
    }
}
impl std::fmt::Debug for StreamHandle {
    /// Renders a point-in-time snapshot of the handle: identity, buffer
    /// progress, completion and cancellation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let buffer = &self.buffer;
        let received = buffer.inserted_count();
        let total_fragments = buffer.total_fragments();
        let complete = buffer.is_complete();
        let cancelled = self.sync.read().cancelled;

        f.debug_struct("StreamHandle")
            .field("stream_id", &self.stream_id)
            .field("total_bytes", &self.total_bytes)
            .field("received", &received)
            .field("total_fragments", &total_fragments)
            .field("complete", &complete)
            .field("cancelled", &cancelled)
            .finish()
    }
}
/// In-order reader over a [`StreamHandle`]'s fragments, created by
/// [`StreamHandle::stream`] or [`StreamHandle::stream_with_reclaim`].
pub struct StreamingInboundStream {
    handle: StreamHandle,
    /// 1-based number of the next fragment to yield.
    next_fragment: u32,
    /// Total length of all fragments yielded so far.
    bytes_read: u64,
    /// `true`: yielded fragments are taken out of the buffer;
    /// `false`: they are cloned and left in place.
    auto_reclaim: bool,
    /// Notifier listener kept across polls while waiting for data.
    listener: Option<Pin<Box<EventListener>>>,
}

// `dead_code` is allowed on the accessors below because, within this file,
// they have no non-test callers.
impl StreamingInboundStream {
    /// Number of the fragment the next successful poll will yield.
    #[allow(dead_code)]
    pub fn next_fragment_number(&self) -> u32 {
        self.next_fragment
    }

    /// Bytes yielded so far.
    #[allow(dead_code)]
    pub fn bytes_read(&self) -> u64 {
        self.bytes_read
    }

    /// Id of the stream this reader consumes.
    #[allow(dead_code)]
    pub(crate) fn stream_id(&self) -> StreamId {
        self.handle.stream_id
    }

    /// Whether yielded fragments are removed from the shared buffer.
    #[allow(dead_code)]
    pub fn is_auto_reclaim(&self) -> bool {
        self.auto_reclaim
    }

    /// Non-blocking fetch of fragment `idx`: cloned when not reclaiming,
    /// taken out of the buffer otherwise. `None` if the buffer lacks it.
    fn try_get_fragment(&self, idx: u32) -> Option<Bytes> {
        let buffer = &self.handle.buffer;
        if !self.auto_reclaim {
            return buffer.get(idx).cloned();
        }
        buffer.take(idx)
    }
}
impl Stream for StreamingInboundStream {
    type Item = Result<Bytes, StreamError>;

    /// Yields the next fragment in order.
    ///
    /// * `Ready(Some(Err(Cancelled)))` whenever the owning handle is cancelled.
    /// * `Ready(None)` once `next_fragment` is past the buffer's fragment
    ///   count, or the buffer is complete but the next fragment is absent.
    /// * `Pending` is returned only after a notifier listener has been
    ///   registered *and* state was re-checked with it in place, and only
    ///   when that listener's `poll` itself returned `Pending`. The waker is
    ///   therefore always registered. (Previously a listener that was already
    ///   `Ready` could be dropped while returning `Pending`, which hung the
    ///   task.)
    ///
    /// NOTE(review): unless `total_fragments()` grows, a fragment numbered
    /// beyond it (the "overflow" case `StreamHandle::assemble` waits for) is
    /// never yielded here. Confirm streaming readers do not need it.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let next_idx = self.next_fragment;
        loop {
            if self.handle.sync.read().cancelled {
                self.listener = None;
                return Poll::Ready(Some(Err(StreamError::Cancelled)));
            }

            let total_fragments = self.handle.buffer.total_fragments() as u32;
            if next_idx > total_fragments {
                self.listener = None;
                return Poll::Ready(None);
            }
            if self.handle.buffer.is_complete() && self.handle.buffer.get(next_idx).is_none() {
                self.listener = None;
                return Poll::Ready(None);
            }

            if let Some(data) = self.try_get_fragment(next_idx) {
                self.listener = None;
                self.next_fragment = next_idx + 1;
                self.bytes_read += data.len() as u64;
                return Poll::Ready(Some(Ok(data)));
            }

            match self.listener.as_mut() {
                // No listener yet: register one, then loop to re-check state
                // so that an event racing with registration is not lost.
                None => {
                    self.listener = Some(Box::pin(self.handle.buffer.notifier().listen()));
                }
                Some(listener) => match listener.as_mut().poll(cx) {
                    // Notified: drop the spent listener and re-check state.
                    Poll::Ready(()) => self.listener = None,
                    // The listener now holds our waker; safe to yield.
                    Poll::Pending => return Poll::Pending,
                },
            }
        }
    }
}
/// Concurrent map from [`StreamId`] to the [`StreamHandle`] receiving it.
pub struct StreamRegistry {
    streams: DashMap<StreamId, StreamHandle>,
}

impl StreamRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            streams: DashMap::new(),
        }
    }

    /// Returns the handle for `stream_id`, creating one sized for
    /// `total_bytes` if none is registered yet.
    ///
    /// For an already-registered id, the existing handle is returned and
    /// `total_bytes` is ignored.
    pub(crate) fn register(&self, stream_id: StreamId, total_bytes: u64) -> StreamHandle {
        self.streams
            .entry(stream_id)
            .or_insert_with(|| StreamHandle::new(stream_id, total_bytes))
            .clone()
    }

    /// Returns a clone of the handle registered for `stream_id`, if any.
    // `dead_code` allowed: within this file only tests call it.
    #[allow(dead_code)]
    pub(crate) fn get(&self, stream_id: StreamId) -> Option<StreamHandle> {
        self.streams.get(&stream_id).map(|r| r.clone())
    }

    /// Removes and returns the handle registered for `stream_id`, if any.
    /// The removed handle is *not* cancelled.
    // `dead_code` allowed: within this file only tests call it.
    #[allow(dead_code)]
    pub(crate) fn remove(&self, stream_id: StreamId) -> Option<StreamHandle> {
        self.streams.remove(&stream_id).map(|(_, h)| h)
    }

    /// Number of registered streams.
    pub fn stream_count(&self) -> usize {
        self.streams.len()
    }

    /// Cancels every registered stream and removes it from the registry.
    pub fn cancel_all(&self) {
        // Cancel and remove each entry in the same `retain` pass. With the
        // previous iterate-then-`clear` approach, a stream registered between
        // the two steps was removed without being cancelled, leaving its
        // readers waiting until the inactivity timeout.
        self.streams.retain(|_, handle| {
            handle.cancel();
            false
        });
    }
}

impl Default for StreamRegistry {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::GlobalExecutor;
    use futures::StreamExt;

    /// Allocates a fresh stream id for each test.
    fn make_stream_id() -> StreamId {
        StreamId::next()
    }

    #[test]
    fn test_stream_handle_creation() {
        let id = make_stream_id();
        let handle = StreamHandle::new(id, 1000);
        assert_eq!(handle.stream_id(), id);
        assert_eq!(handle.total_bytes(), 1000);
        assert!(!handle.is_complete());
    }

    #[test]
    fn test_stream_handle_push_fragment() {
        let handle = StreamHandle::new(make_stream_id(), 100);
        let result = handle.push_fragment(1, Bytes::from_static(b"hello"));
        assert!(result.is_ok());
        assert!(result.unwrap());
        // A duplicate fragment number succeeds but reports "not inserted".
        let result = handle.push_fragment(1, Bytes::from_static(b"world"));
        assert!(result.is_ok());
        assert!(!result.unwrap());
    }

    #[test]
    fn test_stream_handle_invalid_fragment() {
        let handle = StreamHandle::new(make_stream_id(), 100);
        let result = handle.push_fragment(0, Bytes::from_static(b"hello"));
        assert!(matches!(result, Err(StreamError::InvalidFragment { .. })));
        let result = handle.push_fragment(3, Bytes::from_static(b"hello"));
        assert!(matches!(result, Err(StreamError::InvalidFragment { .. })));
    }

    #[test]
    fn test_stream_handle_cancel() {
        let handle = StreamHandle::new(make_stream_id(), 100);
        handle.cancel();
        let result = handle.push_fragment(1, Bytes::from_static(b"hello"));
        assert!(matches!(result, Err(StreamError::Cancelled)));
    }

    #[test]
    fn test_stream_handle_fork() {
        let handle = StreamHandle::new(make_stream_id(), 100);
        handle
            .push_fragment(1, Bytes::from_static(b"hello"))
            .unwrap();
        let forked = handle.fork();
        assert!(handle.is_complete());
        assert!(forked.is_complete());
        assert_eq!(handle.try_assemble(), forked.try_assemble());
    }

    #[test]
    fn test_fork_independent_cancellation() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = FRAGMENT_PAYLOAD_SIZE as u64;
        let data = vec![42u8; FRAGMENT_PAYLOAD_SIZE];
        let handle = StreamHandle::new(make_stream_id(), total);
        handle.push_fragment(1, Bytes::from(data.clone())).unwrap();
        let forked = handle.fork();
        handle.cancel();
        assert!(handle.sync.read().cancelled);
        assert!(matches!(
            handle.push_fragment(1, Bytes::from(vec![0u8; 10])),
            Err(StreamError::Cancelled)
        ));
        assert!(!forked.sync.read().cancelled);
        assert!(forked.is_complete());
        assert_eq!(forked.try_assemble(), Some(data));
    }

    #[tokio::test]
    async fn test_fork_stream_reads_after_original_cancel() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 2) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        handle
            .push_fragment(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        handle
            .push_fragment(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let forked = handle.fork();
        let mut stream = forked.stream();
        handle.cancel();
        let chunk1 = stream.next().await;
        assert!(chunk1.is_some());
        assert_eq!(chunk1.unwrap().unwrap().len(), FRAGMENT_PAYLOAD_SIZE);
        let chunk2 = stream.next().await;
        assert!(chunk2.is_some());
        assert_eq!(chunk2.unwrap().unwrap().len(), FRAGMENT_PAYLOAD_SIZE);
        let chunk3 = stream.next().await;
        assert!(chunk3.is_none());
    }

    #[tokio::test]
    async fn test_fork_incremental_wakeup() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        handle
            .push_fragment(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let forked = handle.fork();
        let mut stream = forked.stream();
        let chunk1 = stream.next().await;
        assert!(chunk1.is_some());
        assert_eq!(chunk1.unwrap().unwrap(), vec![1u8; FRAGMENT_PAYLOAD_SIZE]);
        let handle_clone = handle.clone();
        // Keep the JoinHandle: if the producer panics, the test should report
        // that, not a misleading "fork should wake up" timeout.
        let producer = tokio::spawn(async move {
            tokio::task::yield_now().await;
            handle_clone
                .push_fragment(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        });
        let chunk2 = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next()).await;
        producer.await.expect("producer task panicked");
        assert!(
            chunk2.is_ok(),
            "fork should wake up when fragment arrives via original handle"
        );
        let chunk2 = chunk2.unwrap();
        assert!(chunk2.is_some());
        assert_eq!(chunk2.unwrap().unwrap(), vec![2u8; FRAGMENT_PAYLOAD_SIZE]);
        handle
            .push_fragment(3, Bytes::from(vec![3u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let chunk3 = stream.next().await;
        assert!(chunk3.is_some());
        assert_eq!(chunk3.unwrap().unwrap(), vec![3u8; FRAGMENT_PAYLOAD_SIZE]);
        let end = stream.next().await;
        assert!(end.is_none());
    }

    #[tokio::test]
    async fn test_streaming_inbound_stream_basic() {
        let handle = StreamHandle::new(make_stream_id(), 15);
        handle
            .push_fragment(1, Bytes::from_static(b"hello"))
            .unwrap();
        let mut stream = handle.stream();
        let chunk = stream.next().await;
        assert!(chunk.is_some());
        assert_eq!(chunk.unwrap().unwrap(), Bytes::from_static(b"hello"));
        let chunk = stream.next().await;
        assert!(chunk.is_none());
    }

    #[tokio::test]
    async fn test_streaming_inbound_stream_incremental() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        let mut stream = handle.stream();
        let handle_clone = handle.clone();
        let producer = GlobalExecutor::spawn(async move {
            for i in 1..=3 {
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                handle_clone
                    .push_fragment(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                    .unwrap();
            }
        });
        let mut count = 0;
        while let Some(result) = stream.next().await {
            count += 1;
            let bytes = result.unwrap();
            assert_eq!(bytes.len(), FRAGMENT_PAYLOAD_SIZE);
        }
        producer.await.unwrap();
        assert_eq!(count, 3);
    }

    #[tokio::test]
    async fn test_streaming_cancelled() {
        let handle = StreamHandle::new(make_stream_id(), 100);
        let mut stream = handle.stream();
        handle.cancel();
        let result = stream.next().await;
        assert!(matches!(result, Some(Err(StreamError::Cancelled))));
    }

    #[tokio::test]
    async fn test_stream_registry_register_and_get() {
        let registry = StreamRegistry::new();
        let id = make_stream_id();
        let handle = registry.register(id, 1000);
        assert_eq!(handle.stream_id(), id);
        let retrieved = registry.get(id);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().stream_id(), id);
    }

    #[tokio::test]
    async fn test_stream_registry_register_existing() {
        let registry = StreamRegistry::new();
        let id = make_stream_id();
        let handle1 = registry.register(id, 1000);
        // Re-registering returns the existing handle; the new size is ignored.
        let handle2 = registry.register(id, 2000);
        assert_eq!(handle1.total_bytes(), handle2.total_bytes());
    }

    #[tokio::test]
    async fn test_stream_registry_remove() {
        let registry = StreamRegistry::new();
        let id = make_stream_id();
        registry.register(id, 1000);
        assert_eq!(registry.stream_count(), 1);
        let removed = registry.remove(id);
        assert!(removed.is_some());
        assert_eq!(registry.stream_count(), 0);
        assert!(registry.get(id).is_none());
    }

    #[tokio::test]
    async fn test_stream_handle_assemble_async() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 2) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        let handle_clone = handle.clone();
        let producer = GlobalExecutor::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            handle_clone
                .push_fragment(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            handle_clone
                .push_fragment(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        });
        let result = handle.assemble().await;
        producer.await.unwrap();
        assert!(result.is_ok());
        let data = result.unwrap();
        assert_eq!(data.len(), total as usize);
    }

    #[tokio::test]
    async fn test_assemble_inactivity_timeout() {
        let handle = StreamHandle::new(make_stream_id(), 1000);
        handle
            .push_fragment(1, Bytes::from(vec![0u8; 100]))
            .unwrap();
        tokio::time::pause();
        let result = handle.assemble().await;
        assert!(
            matches!(result, Err(StreamError::InactivityTimeout)),
            "Expected InactivityTimeout, got: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_assemble_timeout_resets_per_fragment() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        tokio::time::pause();
        let handle_clone = handle.clone();
        // Fragments arrive every 3s: slower than nothing, but each gap stays
        // inside the 5s inactivity window.
        let producer = GlobalExecutor::spawn(async move {
            for i in 1..=3u32 {
                tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
                handle_clone
                    .push_fragment(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                    .unwrap();
            }
        });
        let result = handle.assemble().await;
        producer.await.unwrap();
        assert!(
            result.is_ok(),
            "Slow-but-active stream should succeed: {:?}",
            result
        );
        assert_eq!(result.unwrap().len(), total as usize);
    }

    #[tokio::test]
    async fn test_stalled_stream_killed_within_inactivity_timeout() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 5) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        tokio::time::pause();
        handle
            .push_fragment(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        handle
            .push_fragment(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let start = tokio::time::Instant::now();
        let result = handle.assemble().await;
        assert!(
            matches!(result, Err(StreamError::InactivityTimeout)),
            "Stalled stream should be killed by inactivity timeout, got: {:?}",
            result
        );
        let elapsed = start.elapsed();
        assert!(
            elapsed < std::time::Duration::from_secs(10),
            "Inactivity timeout should fire within ~5s, took {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_multiple_independent_streams() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 2) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        handle
            .push_fragment(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let mut stream1 = handle.stream();
        let mut stream2 = handle.stream();
        let chunk1 = stream1.next().await.unwrap().unwrap();
        let chunk2 = stream2.next().await.unwrap().unwrap();
        assert_eq!(chunk1, chunk2);
        assert_eq!(stream1.next_fragment_number(), 2);
        assert_eq!(stream2.next_fragment_number(), 2);
        handle
            .push_fragment(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let chunk1 = stream1.next().await.unwrap().unwrap();
        let chunk2 = stream2.next().await.unwrap().unwrap();
        assert_eq!(chunk1, chunk2);
        assert!(stream1.next().await.is_none());
        assert!(stream2.next().await.is_none());
    }

    #[test]
    fn test_zero_byte_stream() {
        let handle = StreamHandle::new(make_stream_id(), 0);
        assert_eq!(handle.total_bytes(), 0);
        assert_eq!(handle.total_fragments(), 0);
        assert!(handle.is_complete());
        assert_eq!(handle.try_assemble(), Some(vec![]));
    }

    #[tokio::test]
    async fn test_zero_byte_stream_streaming() {
        let handle = StreamHandle::new(make_stream_id(), 0);
        let mut stream = handle.stream();
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_assemble_cancelled_stream() {
        let handle = StreamHandle::new(make_stream_id(), 100);
        handle.cancel();
        let result = handle.assemble().await;
        assert!(matches!(result, Err(StreamError::Cancelled)));
    }

    #[tokio::test]
    async fn test_cancel_during_assemble() {
        let handle = StreamHandle::new(make_stream_id(), 100);
        let handle_clone = handle.clone();
        let assemble_task = GlobalExecutor::spawn(async move { handle_clone.assemble().await });
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        handle.cancel();
        let result = tokio::time::timeout(tokio::time::Duration::from_millis(100), assemble_task)
            .await
            .expect("timeout")
            .expect("join");
        assert!(matches!(result, Err(StreamError::Cancelled)));
    }

    #[tokio::test]
    async fn test_poll_after_stream_exhausted() {
        let handle = StreamHandle::new(make_stream_id(), 5);
        handle
            .push_fragment(1, Bytes::from_static(b"hello"))
            .unwrap();
        let mut stream = handle.stream();
        let chunk = stream.next().await;
        assert!(chunk.is_some());
        assert!(stream.next().await.is_none());
        assert!(stream.next().await.is_none());
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_out_of_order_fragments_stream_waits() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        handle
            .push_fragment(3, Bytes::from(vec![3u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let mut stream = handle.stream();
        let handle_clone = handle.clone();
        let read_task = GlobalExecutor::spawn(async move {
            let mut s = handle_clone.stream();
            s.next().await
        });
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        // Fragment 1 is missing, so the reader must still be waiting.
        assert!(!read_task.is_finished());
        handle
            .push_fragment(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let result = tokio::time::timeout(tokio::time::Duration::from_millis(100), read_task)
            .await
            .expect("timeout")
            .expect("join");
        assert!(result.is_some());
        assert!(result.unwrap().is_ok());
        let chunk = stream.next().await;
        assert!(chunk.is_some());
    }

    #[tokio::test]
    async fn test_cancel_while_streaming() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        handle
            .push_fragment(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let mut stream = handle.stream();
        let chunk = stream.next().await;
        assert!(chunk.is_some());
        assert!(chunk.unwrap().is_ok());
        let handle_clone = handle.clone();
        let cancel_task = GlobalExecutor::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            handle_clone.cancel();
        });
        let result = stream.next().await;
        assert!(matches!(result, Some(Err(StreamError::Cancelled))));
        cancel_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_bytes_read_tracking() {
        let handle = StreamHandle::new(make_stream_id(), 15);
        handle
            .push_fragment(1, Bytes::from_static(b"hello world!!!"))
            .unwrap();
        let mut stream = handle.stream();
        assert_eq!(stream.bytes_read(), 0);
        let chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(stream.bytes_read(), chunk.len() as u64);
    }

    #[tokio::test]
    async fn test_registry_cancel_all() {
        let registry = StreamRegistry::new();
        let id1 = make_stream_id();
        let id2 = make_stream_id();
        let handle1 = registry.register(id1, 100);
        let handle2 = registry.register(id2, 200);
        registry.cancel_all();
        let result1 = handle1.push_fragment(1, Bytes::from_static(b"test"));
        let result2 = handle2.push_fragment(1, Bytes::from_static(b"test"));
        assert!(matches!(result1, Err(StreamError::Cancelled)));
        assert!(matches!(result2, Err(StreamError::Cancelled)));
        assert_eq!(registry.stream_count(), 0);
    }

    #[tokio::test]
    async fn test_registry_cleanup_on_normal_completion() {
        let registry = StreamRegistry::new();
        let id1 = make_stream_id();
        let id2 = make_stream_id();
        let id3 = make_stream_id();
        let handle1 = registry.register(id1, 100);
        let _handle2 = registry.register(id2, 200);
        let _handle3 = registry.register(id3, 300);
        assert_eq!(registry.stream_count(), 3);
        handle1
            .push_fragment(1, Bytes::from_static(b"complete data here!"))
            .unwrap();
        assert!(handle1.is_complete());
        let removed = registry.remove(id1);
        assert!(removed.is_some());
        assert_eq!(registry.stream_count(), 2);
        let removed = registry.remove(id2);
        assert!(removed.is_some());
        assert_eq!(registry.stream_count(), 1);
        let removed = registry.remove(id3);
        assert!(removed.is_some());
        assert_eq!(registry.stream_count(), 0);
        assert!(registry.get(id1).is_none());
        assert!(registry.get(id2).is_none());
        assert!(registry.get(id3).is_none());
    }

    #[test]
    fn test_registry_get_nonexistent() {
        let registry = StreamRegistry::new();
        let id = make_stream_id();
        let result = registry.get(id);
        assert!(result.is_none());
    }

    #[test]
    fn test_registry_remove_nonexistent() {
        let registry = StreamRegistry::new();
        let id = make_stream_id();
        let result = registry.remove(id);
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_concurrent_push_and_stream() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        use std::sync::atomic::{AtomicUsize, Ordering};
        let total = (FRAGMENT_PAYLOAD_SIZE * 10) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        let fragments_received = Arc::new(AtomicUsize::new(0));
        let handle_producer = handle.clone();
        let producer = GlobalExecutor::spawn(async move {
            for i in 1..=10 {
                tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
                handle_producer
                    .push_fragment(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                    .unwrap();
            }
        });
        let mut consumers = Vec::new();
        for _ in 0..3 {
            let h = handle.clone();
            let counter = Arc::clone(&fragments_received);
            consumers.push(GlobalExecutor::spawn(async move {
                let mut stream = h.stream();
                let mut local_count = 0;
                while let Some(result) = stream.next().await {
                    result.unwrap();
                    local_count += 1;
                }
                counter.fetch_add(local_count, Ordering::SeqCst);
                local_count
            }));
        }
        producer.await.unwrap();
        for consumer in consumers {
            let count = consumer.await.unwrap();
            assert_eq!(count, 10);
        }
        // Every consumer sees every fragment: 3 readers x 10 fragments.
        assert_eq!(fragments_received.load(Ordering::SeqCst), 30);
    }

    #[test]
    fn test_stream_handle_debug() {
        let handle = StreamHandle::new(make_stream_id(), 100);
        handle
            .push_fragment(1, Bytes::from_static(b"test"))
            .unwrap();
        let debug_str = format!("{:?}", handle);
        assert!(debug_str.contains("StreamHandle"));
        assert!(debug_str.contains("total_bytes"));
        assert!(debug_str.contains("complete"));
    }

    #[tokio::test]
    async fn test_try_assemble_before_complete() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 2) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        handle
            .push_fragment(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        assert!(handle.try_assemble().is_none());
        handle
            .push_fragment(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let assembled = handle.try_assemble();
        assert!(assembled.is_some());
        assert_eq!(assembled.unwrap().len(), total as usize);
    }

    #[test]
    fn test_stream_error_display() {
        let cancelled = StreamError::Cancelled;
        assert_eq!(format!("{}", cancelled), "stream was cancelled");
        let not_found = StreamError::NotFound;
        assert_eq!(format!("{}", not_found), "stream not found in registry");
        let invalid = StreamError::InvalidFragment {
            message: "test error".into(),
        };
        assert_eq!(format!("{}", invalid), "invalid fragment: test error");
        let timeout = StreamError::InactivityTimeout;
        assert_eq!(
            format!("{}", timeout),
            "no fragments received within inactivity timeout"
        );
    }

    #[tokio::test]
    async fn test_stream_with_reclaim_basic() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        for i in 1..=3 {
            handle
                .push_fragment(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        }
        assert_eq!(handle.buffer.inserted_count(), 3);
        let mut stream = handle.stream_with_reclaim();
        assert!(stream.is_auto_reclaim());
        let chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(chunk[0], 1);
        assert_eq!(handle.buffer.inserted_count(), 2);
        let _ = stream.next().await.unwrap().unwrap();
        let _ = stream.next().await.unwrap().unwrap();
        assert_eq!(handle.buffer.inserted_count(), 0);
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_without_reclaim_preserves_data() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 2) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        for i in 1..=2 {
            handle
                .push_fragment(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        }
        let mut stream = handle.stream();
        assert!(!stream.is_auto_reclaim());
        let _ = stream.next().await.unwrap().unwrap();
        let _ = stream.next().await.unwrap().unwrap();
        assert_eq!(handle.buffer.inserted_count(), 2);
    }

    #[tokio::test]
    async fn test_reclaim_incremental_with_delayed_fragments() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        let mut stream = handle.stream_with_reclaim();
        handle
            .push_fragment(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(chunk[0], 1);
        assert_eq!(handle.buffer.inserted_count(), 0);
        handle
            .push_fragment(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        handle
            .push_fragment(3, Bytes::from(vec![3u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        assert_eq!(handle.buffer.inserted_count(), 2);
        let _ = stream.next().await.unwrap().unwrap();
        assert_eq!(handle.buffer.inserted_count(), 1);
        let _ = stream.next().await.unwrap().unwrap();
        assert_eq!(handle.buffer.inserted_count(), 0);
    }

    #[tokio::test]
    async fn test_reclaim_vs_fork_conflict() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 2) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        for i in 1..=2 {
            handle
                .push_fragment(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        }
        let mut reclaim_stream = handle.stream_with_reclaim();
        let _ = reclaim_stream.next().await.unwrap().unwrap();
        let forked = handle.fork();
        let _forked_stream = forked.stream();
        // Only fragment 1 was reclaimed; fragment 2 is still shared.
        assert!(handle.buffer.get(2).is_some());
    }

    #[tokio::test]
    async fn test_stream_memory_efficiency() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let num_fragments = 20u32;
        let total = (FRAGMENT_PAYLOAD_SIZE * num_fragments as usize) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        let handle_clone = handle.clone();
        for i in 1..=num_fragments {
            handle
                .push_fragment(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        }
        assert_eq!(handle_clone.buffer.inserted_count(), num_fragments as usize);
        let mut stream = handle_clone.stream_with_reclaim();
        for _ in 0..10 {
            let _ = stream.next().await.unwrap().unwrap();
        }
        assert_eq!(handle.buffer.inserted_count(), 10);
        for _ in 0..10 {
            let _ = stream.next().await.unwrap().unwrap();
        }
        assert_eq!(handle.buffer.inserted_count(), 0);
    }

    #[tokio::test]
    async fn test_assemble_waits_for_overflow_fragment() {
        use super::super::streaming_buffer::FRAGMENT_PAYLOAD_SIZE;
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let handle = StreamHandle::new(make_stream_id(), total);
        // Fragment 1 carries only half a payload, so the data spills into a
        // fourth "overflow" fragment.
        let reduced_payload = FRAGMENT_PAYLOAD_SIZE / 2;
        let frag1_data = Bytes::from(vec![1u8; reduced_payload]);
        let remaining = total as usize - reduced_payload;
        let frag2_data = Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]);
        let frag3_data = Bytes::from(vec![3u8; FRAGMENT_PAYLOAD_SIZE]);
        let overflow_size = remaining - 2 * FRAGMENT_PAYLOAD_SIZE;
        let frag4_data = Bytes::from(vec![4u8; overflow_size]);
        handle.push_fragment(1, frag1_data).unwrap();
        handle.push_fragment(2, frag2_data).unwrap();
        handle.push_fragment(3, frag3_data).unwrap();
        assert!(handle.is_complete());
        assert!(handle.try_assemble().is_none());
        let handle_clone = handle.clone();
        let assemble_task = GlobalExecutor::spawn(async move { handle_clone.assemble().await });
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        assert!(!assemble_task.is_finished());
        handle.push_fragment(4, frag4_data).unwrap();
        let result = tokio::time::timeout(tokio::time::Duration::from_millis(200), assemble_task)
            .await
            .expect("timeout waiting for assemble")
            .expect("join error");
        assert!(result.is_ok(), "assemble should succeed: {:?}", result);
        let data = result.unwrap();
        assert_eq!(data.len(), total as usize);
    }
}