use bytes::Bytes;
use event_listener::Event;
use std::ptr;
use std::sync::atomic::{AtomicPtr, AtomicU32, Ordering};
use crate::simulation::{RealTime, TimeSource};
use crate::transport::packet_data;
// Maximum payload bytes carried by a single stream fragment.
// The 41-byte reduction presumably reserves room for per-fragment
// framing/header overhead inside a packet — TODO confirm against the
// packet wire format in `packet_data`.
// Visibility is widened to `pub` only for the `bench` feature so
// benchmarks can size their inputs.
#[cfg(feature = "bench")]
pub const FRAGMENT_PAYLOAD_SIZE: usize = packet_data::MAX_DATA_SIZE - 41;
#[cfg(not(feature = "bench"))]
pub(crate) const FRAGMENT_PAYLOAD_SIZE: usize = packet_data::MAX_DATA_SIZE - 41;
/// A lock-free, fixed-capacity reassembly buffer for a stream split into
/// numbered fragments (1-based). Fragments may arrive out of order and from
/// multiple threads; each slot hands ownership of a heap-allocated `Bytes`
/// around via atomically swapped raw pointers.
pub struct LockFreeStreamBuffer<T: TimeSource = RealTime> {
    // One slot per fragment; null = empty. Non-null pointers originate from
    // `Box::into_raw` in `insert()` and are owned by the buffer until taken,
    // consumed, or dropped.
    fragments: Box<[AtomicPtr<Bytes>]>,
    // Total payload size of the whole stream, in bytes.
    total_size: u64,
    // Number of slots: the required fragment count plus one spare
    // (see `calculate_fragment_count`).
    total_fragments: u32,
    // Contiguous fragments required before the stream counts as complete.
    min_complete_fragments: u32,
    // Length of the contiguous prefix of filled slots; only ever advances.
    contiguous_fragments: AtomicU32,
    // Highest fragment number already released via `mark_consumed()`.
    consumed_frontier: AtomicU32,
    // Notified on every successful insert so waiters can re-poll.
    data_available: Event,
    #[allow(dead_code)]
    time_source: T,
}
impl LockFreeStreamBuffer<RealTime> {
    /// Creates a buffer for a `total_size`-byte stream driven by the
    /// wall-clock time source.
    pub fn new(total_size: u64) -> Self {
        let clock = RealTime::new();
        Self::new_with_time_source(total_size, clock)
    }
}
impl<T: TimeSource> LockFreeStreamBuffer<T> {
pub fn new_with_time_source(total_size: u64, time_source: T) -> Self {
let base_fragments = Self::base_fragment_count(total_size);
let num_fragments = Self::calculate_fragment_count(total_size);
let fragments: Vec<AtomicPtr<Bytes>> = (0..num_fragments)
.map(|_| AtomicPtr::new(ptr::null_mut()))
.collect();
Self {
fragments: fragments.into_boxed_slice(),
total_size,
total_fragments: num_fragments as u32,
min_complete_fragments: base_fragments as u32,
contiguous_fragments: AtomicU32::new(0),
consumed_frontier: AtomicU32::new(0),
data_available: Event::new(),
time_source,
}
}
fn base_fragment_count(total_size: u64) -> usize {
if total_size == 0 {
return 0;
}
(total_size as usize).div_ceil(FRAGMENT_PAYLOAD_SIZE)
}
fn calculate_fragment_count(total_size: u64) -> usize {
let base = Self::base_fragment_count(total_size);
if base == 0 {
return 0;
}
base + 1
}
    /// Inserts `data` as fragment `fragment_number` (1-based).
    ///
    /// Returns `Ok(true)` if the fragment was stored, `Ok(false)` if the
    /// slot was already occupied (duplicate insert — the original data is
    /// kept and `data` is dropped), or an error for out-of-range or
    /// already-consumed fragment numbers.
    pub fn insert(&self, fragment_number: u32, data: Bytes) -> Result<bool, InsertError> {
        // Fragment numbers are 1-based; 0 and anything past the slot count
        // are invalid.
        if fragment_number == 0 || fragment_number > self.total_fragments {
            return Err(InsertError::InvalidNumber {
                number: fragment_number,
                max: self.total_fragments,
            });
        }
        // Reject fragments at or below the consumed frontier.
        // NOTE(review): this check races with a concurrent mark_consumed():
        // a fragment consumed between this load and the CAS below would be
        // re-inserted into its (now-null) slot — confirm whether callers
        // serialize insertion against consumption of the same fragment.
        if fragment_number <= self.consumed_frontier.load(Ordering::Acquire) {
            return Err(InsertError::AlreadyConsumed {
                number: fragment_number,
            });
        }
        let idx = (fragment_number - 1) as usize;
        // Heap-allocate and transfer ownership to the slot via a raw
        // pointer; the CAS only succeeds if the slot is still empty (null).
        let boxed = Box::new(data);
        let new_ptr = Box::into_raw(boxed);
        match self.fragments[idx].compare_exchange(
            ptr::null_mut(),
            new_ptr,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => {
                // Extend the contiguous prefix if possible, then wake every
                // listener waiting for new data.
                self.advance_frontier();
                self.data_available.notify(usize::MAX);
                Ok(true)
            }
            Err(_) => {
                // Lost the race (duplicate insert): reclaim our allocation.
                // SAFETY: the CAS failed, so `new_ptr` was never published
                // and we still hold exclusive ownership of it.
                unsafe {
                    drop(Box::from_raw(new_ptr));
                }
                Ok(false)
            }
        }
    }
fn advance_frontier(&self) {
loop {
let current = self.contiguous_fragments.load(Ordering::Acquire);
let next = current + 1;
if next > self.total_fragments {
return; }
let idx = (next - 1) as usize;
if self.fragments[idx].load(Ordering::Acquire).is_null() {
return; }
match self.contiguous_fragments.compare_exchange_weak(
current,
next,
Ordering::AcqRel,
Ordering::Relaxed,
) {
Ok(_) => continue, Err(_) => continue, }
}
}
    /// True once the contiguous prefix covers every required fragment.
    pub fn is_complete(&self) -> bool {
        self.contiguous_fragments.load(Ordering::Acquire) >= self.min_complete_fragments
    }
    /// Number of slots currently holding a fragment (O(n) scan).
    pub fn inserted_count(&self) -> usize {
        self.fragments
            .iter()
            .filter(|s| !s.load(Ordering::Acquire).is_null())
            .count()
    }
    /// Total slot count (the required fragments plus the spare slot).
    pub fn total_fragments(&self) -> usize {
        self.total_fragments as usize
    }
    /// Total stream payload size in bytes.
    pub fn total_bytes(&self) -> u64 {
        self.total_size
    }
    /// Highest fragment number of the contiguous filled prefix.
    pub fn highest_contiguous(&self) -> u32 {
        self.contiguous_fragments.load(Ordering::Acquire)
    }
    /// Highest fragment number already released via `mark_consumed()`.
    pub fn consumed_frontier(&self) -> u32 {
        self.consumed_frontier.load(Ordering::Acquire)
    }
    /// Returns a reference to the fragment at `fragment_number` (1-based),
    /// or `None` if out of range or the slot is empty.
    ///
    /// NOTE(review): the returned `&Bytes` borrows `self`, but a concurrent
    /// `take()` or `mark_consumed()` on another thread frees the underlying
    /// allocation via `Box::from_raw` — a reader holding this reference at
    /// that moment would be left dangling. Confirm the intended contract is
    /// that consumption of a fragment is serialized with reads of it.
    pub fn get(&self, fragment_number: u32) -> Option<&Bytes> {
        if fragment_number == 0 || fragment_number > self.total_fragments {
            return None;
        }
        let idx = (fragment_number - 1) as usize;
        let ptr = self.fragments[idx].load(Ordering::Acquire);
        if ptr.is_null() {
            return None;
        }
        // SAFETY: non-null slot pointers come from Box::into_raw in
        // insert() and stay valid until reclaimed (see note above).
        Some(unsafe { &*ptr })
    }
pub fn take(&self, fragment_number: u32) -> Option<Bytes> {
if fragment_number == 0 || fragment_number > self.total_fragments {
return None;
}
let idx = (fragment_number - 1) as usize;
let ptr = self.fragments[idx].swap(ptr::null_mut(), Ordering::AcqRel);
if ptr.is_null() {
return None;
}
let boxed = unsafe { Box::from_raw(ptr) };
Some(*boxed)
}
    /// Marks fragments `1..=up_to_index` as consumed, freeing any that are
    /// still present, and advances the consumed frontier. Returns how many
    /// fragments this call actually freed.
    ///
    /// Safe to call concurrently: overlapping callers may scan the same
    /// slots, but the atomic swap guarantees each fragment is freed exactly
    /// once, and `fetch_max` keeps the frontier monotonic.
    ///
    /// # Panics
    /// Panics if `up_to_index` exceeds the total slot count.
    pub fn mark_consumed(&self, up_to_index: u32) -> usize {
        assert!(
            up_to_index <= self.total_fragments,
            "mark_consumed index {} exceeds total_fragments {}",
            up_to_index,
            self.total_fragments
        );
        let mut freed_count = 0;
        // Snapshot of the frontier; slots below it were already handled.
        let current_consumed = self.consumed_frontier.load(Ordering::Acquire);
        for idx in current_consumed..up_to_index {
            let slot_idx = idx as usize;
            let ptr = self.fragments[slot_idx].swap(ptr::null_mut(), Ordering::AcqRel);
            if !ptr.is_null() {
                // SAFETY: the swap transferred exclusive ownership of this
                // Box::into_raw pointer to us.
                unsafe {
                    drop(Box::from_raw(ptr));
                }
                freed_count += 1;
            }
        }
        // fetch_max (not store) so racing or out-of-order calls can never
        // move the frontier backwards.
        self.consumed_frontier
            .fetch_max(up_to_index, Ordering::AcqRel);
        freed_count
    }
    /// Event that is notified (all listeners) after every successful
    /// insert; listen on it to wait for new data.
    pub fn notifier(&self) -> &Event {
        &self.data_available
    }
    /// Assembles the whole stream into a single `Vec<u8>` of exactly
    /// `total_size` bytes. Returns `None` if the stream is incomplete or
    /// any prefix has already been consumed (the data can no longer be
    /// reproduced in full).
    pub fn assemble(&self) -> Option<Vec<u8>> {
        if !self.is_complete() {
            return None;
        }
        if self.consumed_frontier.load(Ordering::Acquire) > 0 {
            return None;
        }
        let target_size = self.total_size as usize;
        let mut result = Vec::with_capacity(target_size);
        // Walk the contiguous prefix; the first empty slot ends the stream.
        for slot in self.fragments.iter() {
            let ptr = slot.load(Ordering::Acquire);
            if ptr.is_null() {
                break;
            }
            // SAFETY: non-null slot pointers come from Box::into_raw in
            // insert(). NOTE(review): a concurrent mark_consumed()/take()
            // could free this allocation mid-read — confirm callers
            // serialize assembly with consumption.
            let data = unsafe { &*ptr };
            result.extend_from_slice(data);
            // The final fragment may overshoot; stop once we have enough.
            if result.len() >= target_size {
                break;
            }
        }
        result.truncate(target_size);
        // Fragments shorter than expected can leave the stream short even
        // when "complete"; in that case report failure rather than return
        // truncated data.
        if result.len() == target_size {
            Some(result)
        } else {
            None
        }
    }
pub fn iter(&self) -> impl Iterator<Item = Option<&Bytes>> {
self.fragments.iter().map(|slot| {
let ptr = slot.load(Ordering::Acquire);
if ptr.is_null() {
None
} else {
Some(unsafe { &*ptr })
}
})
}
pub fn iter_contiguous(&self) -> impl Iterator<Item = &Bytes> {
self.fragments.iter().map_while(|slot| {
let ptr = slot.load(Ordering::Acquire);
if ptr.is_null() {
None
} else {
Some(unsafe { &*ptr })
}
})
}
#[allow(dead_code)]
pub fn collect_contiguous(&self) -> Vec<u8> {
let mut result = Vec::with_capacity(self.total_size as usize);
for data in self.iter_contiguous() {
result.extend_from_slice(data);
}
result
}
}
impl<T: TimeSource> Drop for LockFreeStreamBuffer<T> {
fn drop(&mut self) {
for slot in self.fragments.iter() {
let ptr = slot.load(Ordering::Acquire);
if !ptr.is_null() {
unsafe {
drop(Box::from_raw(ptr));
}
}
}
}
}
/// Errors returned by `LockFreeStreamBuffer::insert`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InsertError {
    /// The fragment number was 0 or exceeded the buffer's slot count.
    InvalidNumber {
        number: u32,
        max: u32,
    },
    /// The fragment sits at or below the consumed frontier and can no
    /// longer be accepted.
    AlreadyConsumed {
        number: u32,
    },
}
impl std::fmt::Display for InsertError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
InsertError::InvalidNumber { number, max } => {
write!(
f,
"fragment number {} is out of bounds (max: {})",
number, max
)
}
InsertError::AlreadyConsumed { number } => {
write!(f, "fragment number {} has already been consumed", number)
}
}
}
}
impl std::error::Error for InsertError {}
#[cfg(test)]
mod tests {
    //! Unit tests covering buffer sizing, insertion (including duplicates
    //! and invalid numbers), assembly, consumption/reclamation, iteration,
    //! notification, and concurrent producer/consumer behavior.
    use super::*;
    use crate::config::GlobalExecutor;
    use crate::simulation::VirtualTime;
    use std::time::Duration;

    // --- Sizing: slot counts are base fragments + 1 spare ---------------

    #[test]
    fn test_new_buffer_empty() {
        let buffer = LockFreeStreamBuffer::new(0);
        assert_eq!(buffer.total_fragments(), 0);
        // A zero-byte stream is trivially complete.
        assert!(buffer.is_complete());
    }
    #[test]
    fn test_new_buffer_single_fragment() {
        let buffer = LockFreeStreamBuffer::new(100);
        assert_eq!(buffer.total_fragments(), 2);
        assert!(!buffer.is_complete());
    }
    #[test]
    fn test_new_buffer_multiple_fragments() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        assert_eq!(buffer.total_fragments(), 4);
    }
    #[test]
    fn test_new_buffer_partial_last_fragment() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 2 + FRAGMENT_PAYLOAD_SIZE / 2) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        assert_eq!(buffer.total_fragments(), 4);
    }

    // --- Insertion -------------------------------------------------------

    #[test]
    fn test_insert_single_fragment() {
        let buffer = LockFreeStreamBuffer::new(100);
        let data = Bytes::from_static(b"hello");
        let result = buffer.insert(1, data.clone());
        assert!(result.is_ok());
        assert!(result.unwrap());
        assert_eq!(buffer.inserted_count(), 1);
        assert!(buffer.is_complete());
        assert_eq!(buffer.get(1), Some(&data));
    }
    #[test]
    fn test_insert_duplicate_is_noop() {
        let buffer = LockFreeStreamBuffer::new(100);
        let data1 = Bytes::from_static(b"first");
        let data2 = Bytes::from_static(b"second");
        assert!(buffer.insert(1, data1.clone()).unwrap());
        // Second insert loses the CAS race: first writer wins.
        assert!(!buffer.insert(1, data2).unwrap());
        assert_eq!(buffer.get(1), Some(&data1));
    }
    #[test]
    fn test_insert_invalid_number_zero() {
        let buffer = LockFreeStreamBuffer::new(100);
        let data = Bytes::from_static(b"hello");
        let result = buffer.insert(0, data);
        assert!(matches!(result, Err(InsertError::InvalidNumber { .. })));
    }
    #[test]
    fn test_insert_invalid_number_too_large() {
        let buffer = LockFreeStreamBuffer::new(100);
        let data = Bytes::from_static(b"hello");
        let result = buffer.insert(3, data);
        assert!(matches!(result, Err(InsertError::InvalidNumber { .. })));
    }
    #[test]
    fn test_insert_out_of_order() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        let frag1 = Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]);
        let frag2 = Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]);
        let frag3 = Bytes::from(vec![3u8; FRAGMENT_PAYLOAD_SIZE]);
        // Out-of-order arrival: the frontier only advances across a
        // gap-free prefix.
        assert!(buffer.insert(3, frag3.clone()).unwrap());
        assert!(!buffer.is_complete());
        assert_eq!(buffer.highest_contiguous(), 0);
        assert!(buffer.insert(1, frag1.clone()).unwrap());
        assert!(!buffer.is_complete());
        assert_eq!(buffer.highest_contiguous(), 1);
        assert!(buffer.insert(2, frag2.clone()).unwrap());
        assert!(buffer.is_complete());
        assert_eq!(buffer.highest_contiguous(), 3);
    }

    // --- Assembly --------------------------------------------------------

    #[test]
    fn test_assemble_complete() {
        let buffer = LockFreeStreamBuffer::new(6);
        buffer.insert(1, Bytes::from_static(b"hello ")).unwrap();
        let assembled = buffer.assemble();
        assert!(assembled.is_some());
        assert_eq!(assembled.unwrap(), b"hello ");
    }
    #[test]
    fn test_assemble_incomplete() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 2) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        buffer
            .insert(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        assert!(buffer.assemble().is_none());
    }
    #[test]
    fn test_assemble_truncates_to_total_bytes() {
        let buffer = LockFreeStreamBuffer::new(10);
        // Oversized fragment: assemble must clip to total_size.
        buffer
            .insert(1, Bytes::from_static(b"hello world plus extra"))
            .unwrap();
        let assembled = buffer.assemble().unwrap();
        assert_eq!(assembled.len(), 10);
        assert_eq!(assembled, b"hello worl");
    }

    // --- Iteration -------------------------------------------------------

    #[test]
    fn test_contiguous_iter() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 4) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        buffer
            .insert(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        buffer
            .insert(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        buffer
            .insert(4, Bytes::from(vec![4u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        // Fragment 4 is unreachable past the gap at slot 3.
        let contiguous: Vec<_> = buffer.iter_contiguous().collect();
        assert_eq!(contiguous.len(), 2);
    }

    // --- Concurrency -----------------------------------------------------

    #[test]
    fn test_concurrent_inserts() {
        use std::sync::Arc;
        use std::thread;
        let total = (FRAGMENT_PAYLOAD_SIZE * 100) as u64;
        let buffer = Arc::new(LockFreeStreamBuffer::new(total));
        let handles: Vec<_> = (1..=100)
            .map(|i| {
                let buffer = Arc::clone(&buffer);
                thread::spawn(move || {
                    let data = Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]);
                    buffer.insert(i as u32, data)
                })
            })
            .collect();
        for handle in handles {
            assert!(handle.join().unwrap().unwrap());
        }
        assert!(buffer.is_complete());
        assert_eq!(buffer.inserted_count(), 100);
    }
    #[test]
    fn test_concurrent_duplicate_inserts() {
        use std::sync::Arc;
        use std::thread;
        let buffer = Arc::new(LockFreeStreamBuffer::new(100));
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let buffer = Arc::clone(&buffer);
                thread::spawn(move || {
                    let data = Bytes::from(vec![i as u8; 50]);
                    buffer.insert(1, data)
                })
            })
            .collect();
        let results: Vec<_> = handles
            .into_iter()
            .map(|h| h.join().unwrap().unwrap())
            .collect();
        // Exactly one thread wins the CAS for slot 1.
        let success_count = results.iter().filter(|&&r| r).count();
        assert_eq!(success_count, 1);
        assert_eq!(buffer.inserted_count(), 1);
    }

    // --- Boundary cases --------------------------------------------------

    #[test]
    fn test_single_byte_stream() {
        let buffer = LockFreeStreamBuffer::new(1);
        assert_eq!(buffer.total_fragments(), 2);
        assert_eq!(buffer.total_bytes(), 1);
        buffer.insert(1, Bytes::from_static(b"X")).unwrap();
        assert!(buffer.is_complete());
        let assembled = buffer.assemble().unwrap();
        assert_eq!(assembled, b"X");
    }
    #[test]
    fn test_exact_fragment_boundary() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 2) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        assert_eq!(buffer.total_fragments(), 3);
        buffer
            .insert(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        buffer
            .insert(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let assembled = buffer.assemble().unwrap();
        assert_eq!(assembled.len(), total as usize);
    }
    #[test]
    fn test_iter_returns_all_fragments() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        for i in 1..=3 {
            buffer
                .insert(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        }
        // iter() visits every slot, including the empty spare at the end.
        let fragments: Vec<_> = buffer.iter().collect();
        assert_eq!(fragments.len(), 4);
        assert!(fragments[0].is_some());
        assert!(fragments[1].is_some());
        assert!(fragments[2].is_some());
        assert!(fragments[3].is_none());
    }
    #[test]
    fn test_iter_with_gaps() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        buffer
            .insert(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        buffer
            .insert(3, Bytes::from(vec![3u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let fragments: Vec<_> = buffer.iter().collect();
        assert_eq!(fragments.len(), 4);
        assert!(fragments[0].is_some());
        assert!(fragments[1].is_none());
        assert!(fragments[2].is_some());
        assert!(fragments[3].is_none());
    }
    #[test]
    fn test_get_missing_fragment_returns_none() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 2) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        buffer
            .insert(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        assert!(buffer.get(1).is_some());
        assert!(buffer.get(2).is_none());
    }
    #[test]
    fn test_get_out_of_bounds_returns_none() {
        let buffer = LockFreeStreamBuffer::new(100);
        assert!(buffer.get(0).is_none());
        assert!(buffer.get(2).is_none());
        assert!(buffer.get(100).is_none());
    }
    #[test]
    fn test_highest_contiguous_advances_after_gap_fill() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 4) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        buffer
            .insert(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        buffer
            .insert(3, Bytes::from(vec![3u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        buffer
            .insert(4, Bytes::from(vec![4u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        assert_eq!(buffer.highest_contiguous(), 1);
        // Filling the gap lets the frontier sweep across 2, 3, and 4.
        buffer
            .insert(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        assert_eq!(buffer.highest_contiguous(), 4);
        assert!(buffer.is_complete());
    }
    #[test]
    fn test_collect_contiguous_with_gap() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        buffer
            .insert(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        buffer
            .insert(3, Bytes::from(vec![3u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        let contiguous = buffer.collect_contiguous();
        assert_eq!(contiguous.len(), FRAGMENT_PAYLOAD_SIZE);
        assert!(contiguous.iter().all(|&b| b == 1));
    }

    // --- Notification ----------------------------------------------------

    #[tokio::test]
    async fn test_notifier_is_called_on_insert() {
        use std::sync::Arc;
        use tokio::sync::Barrier;
        let time_source = VirtualTime::new();
        let buffer = Arc::new(LockFreeStreamBuffer::new_with_time_source(
            100,
            time_source.clone(),
        ));
        let buffer_clone = Arc::clone(&buffer);
        // Barrier makes sure the waiter is registered before we insert.
        let barrier = Arc::new(Barrier::new(2));
        let barrier_clone = Arc::clone(&barrier);
        let waiter = GlobalExecutor::spawn(async move {
            barrier_clone.wait().await;
            buffer_clone.notifier().listen().await;
            true
        });
        barrier.wait().await;
        time_source.advance(Duration::from_millis(10));
        buffer.insert(1, Bytes::from_static(b"data")).unwrap();
        let timeout_future = time_source.timeout(Duration::from_millis(100), waiter);
        let result = timeout_future.await;
        assert!(result.is_some());
        assert!(result.unwrap().unwrap());
    }

    // --- Consumption / reclamation ---------------------------------------

    #[test]
    fn test_mark_consumed_basic() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        for i in 1..=3 {
            buffer
                .insert(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        }
        assert_eq!(buffer.inserted_count(), 3);
        assert_eq!(buffer.consumed_frontier(), 0);
        let freed = buffer.mark_consumed(1);
        assert_eq!(freed, 1);
        assert_eq!(buffer.consumed_frontier(), 1);
        assert_eq!(buffer.inserted_count(), 2);
        assert!(buffer.get(1).is_none());
        assert!(buffer.get(2).is_some());
        let freed = buffer.mark_consumed(3);
        assert_eq!(freed, 2);
        assert_eq!(buffer.consumed_frontier(), 3);
        assert_eq!(buffer.inserted_count(), 0);
    }
    #[test]
    fn test_mark_consumed_idempotent() {
        let buffer = LockFreeStreamBuffer::new(100);
        buffer.insert(1, Bytes::from_static(b"hello")).unwrap();
        let freed1 = buffer.mark_consumed(1);
        assert_eq!(freed1, 1);
        // Re-consuming the same range frees nothing further.
        let freed2 = buffer.mark_consumed(1);
        assert_eq!(freed2, 0);
        assert_eq!(buffer.consumed_frontier(), 1);
    }
    #[test]
    fn test_insert_after_consumed_fails() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 2) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        buffer
            .insert(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        buffer.mark_consumed(1);
        let result = buffer.insert(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]));
        assert!(matches!(
            result,
            Err(InsertError::AlreadyConsumed { number: 1 })
        ));
        // Fragments above the frontier are still accepted.
        let result = buffer.insert(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]));
        assert!(result.is_ok());
    }
    #[test]
    fn test_take_fragment() {
        let buffer = LockFreeStreamBuffer::new(100);
        let data = Bytes::from_static(b"hello");
        buffer.insert(1, data.clone()).unwrap();
        let taken = buffer.take(1);
        assert!(taken.is_some());
        assert_eq!(taken.unwrap(), data);
        assert!(buffer.get(1).is_none());
        assert_eq!(buffer.inserted_count(), 0);
        // A second take on the same slot yields nothing.
        assert!(buffer.take(1).is_none());
    }
    #[test]
    fn test_take_invalid_indices() {
        let buffer = LockFreeStreamBuffer::new(100);
        assert!(buffer.take(0).is_none());
        assert!(buffer.take(2).is_none());
    }
    #[test]
    fn test_assemble_fails_after_consumption() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 2) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        buffer
            .insert(1, Bytes::from(vec![1u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        buffer
            .insert(2, Bytes::from(vec![2u8; FRAGMENT_PAYLOAD_SIZE]))
            .unwrap();
        assert!(buffer.assemble().is_some());
        buffer.mark_consumed(1);
        assert!(buffer.assemble().is_none());
    }
    #[test]
    fn test_iter_shows_consumed_as_none() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        for i in 1..=3 {
            buffer
                .insert(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        }
        buffer.mark_consumed(2);
        let fragments: Vec<_> = buffer.iter().collect();
        assert!(fragments[0].is_none());
        assert!(fragments[1].is_none());
        assert!(fragments[2].is_some());
    }
    #[test]
    fn test_iter_contiguous_stops_at_consumed() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        for i in 1..=3 {
            buffer
                .insert(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        }
        buffer.mark_consumed(1);
        // The consumed slot 1 is now empty, so the contiguous walk is empty.
        let contiguous: Vec<_> = buffer.iter_contiguous().collect();
        assert_eq!(contiguous.len(), 0);
    }
    #[test]
    fn test_consumed_frontier_accessor() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 5) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        assert_eq!(buffer.consumed_frontier(), 0);
        for i in 1..=5 {
            buffer
                .insert(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        }
        buffer.mark_consumed(2);
        assert_eq!(buffer.consumed_frontier(), 2);
        buffer.mark_consumed(4);
        assert_eq!(buffer.consumed_frontier(), 4);
        // fetch_max: the frontier never moves backwards.
        buffer.mark_consumed(3);
        assert_eq!(buffer.consumed_frontier(), 4);
    }
    #[test]
    #[should_panic(expected = "mark_consumed index 5 exceeds total_fragments 4")]
    fn test_mark_consumed_panics_on_invalid_index() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 3) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        buffer.mark_consumed(5);
    }
    #[test]
    fn test_concurrent_mark_consumed() {
        use std::sync::Arc;
        use std::thread;
        let total = (FRAGMENT_PAYLOAD_SIZE * 100) as u64;
        let buffer = Arc::new(LockFreeStreamBuffer::new(total));
        for i in 1..=100 {
            buffer
                .insert(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        }
        // Four threads consume overlapping prefixes; every fragment must be
        // freed exactly once and the frontier must settle at the maximum.
        let handles: Vec<_> = (0..4)
            .map(|t| {
                let buffer = Arc::clone(&buffer);
                thread::spawn(move || {
                    let _start = t * 25 + 1;
                    let end = (t + 1) * 25;
                    buffer.mark_consumed(end as u32)
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(buffer.consumed_frontier(), 100);
        assert_eq!(buffer.inserted_count(), 0);
    }
    #[test]
    fn test_memory_reclamation_flow() {
        let total = (FRAGMENT_PAYLOAD_SIZE * 5) as u64;
        let buffer = LockFreeStreamBuffer::new(total);
        for i in 1..=5 {
            buffer
                .insert(i, Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]))
                .unwrap();
        }
        // Read-then-consume, one fragment at a time.
        for i in 1..=5 {
            let data = buffer.get(i);
            assert!(data.is_some());
            assert_eq!(data.unwrap()[0], i as u8);
            let freed = buffer.mark_consumed(i);
            assert_eq!(freed, 1);
            assert!(buffer.get(i).is_none());
            assert_eq!(buffer.consumed_frontier(), i);
        }
        assert_eq!(buffer.inserted_count(), 0);
        assert_eq!(buffer.consumed_frontier(), 5);
    }
    #[test]
    fn test_concurrent_insert_and_consume() {
        use std::sync::Arc;
        use std::thread;
        let total = (FRAGMENT_PAYLOAD_SIZE * 10) as u64;
        let buffer = Arc::new(LockFreeStreamBuffer::new(total));
        let buffer_producer = Arc::clone(&buffer);
        let producer = thread::spawn(move || {
            for i in 1..=10 {
                let data = Bytes::from(vec![i as u8; FRAGMENT_PAYLOAD_SIZE]);
                buffer_producer.insert(i, data).unwrap();
                thread::sleep(std::time::Duration::from_micros(100));
            }
        });
        let buffer_consumer = Arc::clone(&buffer);
        // Consumer polls for the next fragment and consumes it in order.
        let consumer = thread::spawn(move || {
            let mut consumed = 0;
            while consumed < 10 {
                let next = consumed + 1;
                if buffer_consumer.get(next as u32).is_some() {
                    buffer_consumer.mark_consumed(next as u32);
                    consumed = next;
                } else {
                    thread::yield_now();
                }
            }
        });
        producer.join().unwrap();
        consumer.join().unwrap();
        assert_eq!(buffer.consumed_frontier(), 10);
        assert_eq!(buffer.inserted_count(), 0);
    }
}