use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Capacity, in bytes, used by [`BoundedStream::default_size`] (10 MiB).
pub const DEFAULT_STREAM_MAX_SIZE: usize = 10 << 20;
/// A capacity-limited byte buffer that keeps only the most recently written
/// bytes, guarded by an async `RwLock`.
///
/// Cloning is cheap: the only field is an `Arc`, so clones share the same
/// underlying buffer and counters.
#[derive(Clone)]
pub struct BoundedStream {
    // Shared, lock-protected state; each method below acquires this lock.
    inner: Arc<RwLock<BoundedStreamInner>>,
}
/// Mutable state protected by the lock inside `BoundedStream`.
struct BoundedStreamInner {
    /// Retained bytes, oldest at the front; `write` keeps its length <= `max_size`.
    buffer: VecDeque<u8>,
    /// Capacity limit, in bytes, for `buffer`.
    max_size: usize,
    /// Bytes accepted by `write` while the stream was open, including bytes that
    /// were then dropped to respect `max_size`.
    total_written: u64,
    /// Bytes discarded to respect `max_size`. `write` maintains
    /// `total_written - bytes_evicted == buffer.len()`.
    bytes_evicted: u64,
    /// Once set by `close`, further writes are ignored.
    closed: bool,
}
impl BoundedStream {
    /// Creates an open, empty stream that retains at most `max_size` bytes.
    pub fn new(max_size: usize) -> Self {
        // Pre-allocate at most 8 KiB up front; the deque grows on demand.
        let initial_capacity = std::cmp::min(max_size, 8192);
        let state = BoundedStreamInner {
            buffer: VecDeque::with_capacity(initial_capacity),
            max_size,
            total_written: 0,
            bytes_evicted: 0,
            closed: false,
        };
        Self {
            inner: Arc::new(RwLock::new(state)),
        }
    }

    /// Creates a stream whose capacity is [`DEFAULT_STREAM_MAX_SIZE`].
    pub fn default_size() -> Self {
        Self::new(DEFAULT_STREAM_MAX_SIZE)
    }

    /// Appends `data`, discarding the oldest bytes so that at most `max_size`
    /// bytes remain.
    ///
    /// Writes to a closed stream are ignored and leave every counter untouched.
    /// Otherwise all of `data` counts toward `total_written`. Every byte dropped
    /// counts toward `bytes_evicted`: previously buffered bytes, or the leading
    /// part of an oversized `data`.
    pub async fn write(&self, data: &[u8]) {
        let mut state = self.inner.write().await;
        if state.closed {
            return;
        }

        let incoming = data.len();
        state.total_written += incoming as u64;
        let limit = state.max_size;

        if incoming >= limit {
            // `data` alone fills the buffer: keep only its final `limit` bytes
            // and discard everything that was buffered before.
            let skipped = incoming - limit;
            let dropped = state.buffer.len() as u64 + skipped as u64;
            state.bytes_evicted += dropped;
            state.buffer.clear();
            state.buffer.extend(&data[skipped..]);
            return;
        }

        let free = limit.saturating_sub(state.buffer.len());
        if incoming > free {
            // Make room by discarding just enough bytes from the front.
            let excess = std::cmp::min(incoming - free, state.buffer.len());
            state.buffer.drain(..excess);
            state.bytes_evicted += excess as u64;
        }
        state.buffer.extend(data);
    }

    /// Returns a copy of the retained bytes, oldest first.
    pub async fn read(&self) -> Vec<u8> {
        let state = self.inner.read().await;
        // The ring buffer may wrap around; stitch both halves together in order.
        let (front, back) = state.buffer.as_slices();
        let mut out = Vec::with_capacity(state.buffer.len());
        out.extend_from_slice(front);
        out.extend_from_slice(back);
        out
    }

    /// Returns the retained bytes as text. Invalid UTF-8 (for example, a
    /// multi-byte character cut by byte-wise eviction) becomes U+FFFD.
    pub async fn read_string(&self) -> String {
        let bytes = self.read().await;
        String::from_utf8_lossy(&bytes).into_owned()
    }

    /// Marks the stream closed; subsequent `write` calls are ignored.
    pub async fn close(&self) {
        self.inner.write().await.closed = true;
    }

    /// Reports whether `close` has been called.
    pub async fn is_closed(&self) -> bool {
        self.inner.read().await.closed
    }

    /// Number of bytes currently retained.
    pub async fn len(&self) -> usize {
        self.inner.read().await.buffer.len()
    }

    /// Returns `true` when no bytes are retained.
    pub async fn is_empty(&self) -> bool {
        self.inner.read().await.buffer.is_empty()
    }

    /// Takes a consistent snapshot of the size and counters under one lock.
    pub async fn stats(&self) -> StreamStats {
        let guard = self.inner.read().await;
        let BoundedStreamInner {
            buffer,
            max_size,
            total_written,
            bytes_evicted,
            closed,
        } = &*guard;
        StreamStats {
            current_size: buffer.len(),
            max_size: *max_size,
            total_written: *total_written,
            bytes_evicted: *bytes_evicted,
            closed: *closed,
        }
    }
}
/// Opaque `Debug` output: the state lives behind an async lock that cannot be
/// awaited from a synchronous formatter, so only a placeholder is printed.
impl std::fmt::Debug for BoundedStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const PLACEHOLDER: &str = "<locked>";
        let mut builder = f.debug_struct("BoundedStream");
        builder.field("inner", &PLACEHOLDER);
        builder.finish()
    }
}
/// Point-in-time snapshot of a `BoundedStream`'s size and counters.
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared directly
/// (for example in assertions or change detection).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamStats {
    /// Number of bytes currently retained.
    pub current_size: usize,
    /// Configured capacity in bytes.
    pub max_size: usize,
    /// Total bytes accepted while the stream was open, including evicted ones.
    pub total_written: u64,
    /// Bytes discarded to stay within `max_size`.
    pub bytes_evicted: u64,
    /// Whether the stream has been closed.
    pub closed: bool,
}
/// Copies everything readable from `reader` into `stream`, then closes `stream`.
///
/// Reads in chunks of up to 8 KiB until end-of-file. A read that fails with
/// `std::io::ErrorKind::Interrupted` is retried. Any other read error is logged
/// at `warn` level and ends the copy early. The stream is closed in every case.
pub async fn drain_to_stream<R>(mut reader: R, stream: Arc<BoundedStream>)
where
    R: tokio::io::AsyncRead + Unpin,
{
    use tokio::io::AsyncReadExt;

    let mut buf = [0u8; 8192];
    loop {
        match reader.read(&mut buf).await {
            Ok(0) => break,
            Ok(n) => stream.write(&buf[..n]).await,
            // Interrupted is transient by std::io convention; retrying avoids
            // truncating the captured output on a spurious wakeup/signal.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => {
                tracing::warn!("drain_to_stream read error: {}", e);
                break;
            }
        }
    }
    stream.close().await;
}
#[cfg(test)]
mod tests {
    // `Arc` and the items under test arrive via this glob import.
    use super::*;

    #[tokio::test]
    async fn test_basic_write_read() {
        let stream = BoundedStream::new(100);
        stream.write(b"hello").await;
        assert_eq!(stream.read().await, b"hello");
    }

    #[tokio::test]
    async fn test_multiple_writes() {
        let stream = BoundedStream::new(100);
        stream.write(b"hello ").await;
        stream.write(b"world").await;
        assert_eq!(stream.read().await, b"hello world");
    }

    #[tokio::test]
    async fn test_eviction_on_overflow() {
        let stream = BoundedStream::new(10);
        stream.write(b"12345").await;
        stream.write(b"67890").await;
        assert_eq!(stream.len().await, 10);
        stream.write(b"ABCDE").await;
        assert_eq!(stream.read().await, b"67890ABCDE");
        let stats = stream.stats().await;
        assert_eq!(stats.bytes_evicted, 5);
        assert_eq!(stats.total_written, 15);
    }

    #[tokio::test]
    async fn test_large_write_exceeds_buffer() {
        let stream = BoundedStream::new(10);
        stream.write(b"0123456789ABCDEFGHIJ").await;
        assert_eq!(stream.read().await, b"ABCDEFGHIJ");
        // The leading overflow of an oversized write counts as evicted.
        let stats = stream.stats().await;
        assert_eq!(stats.total_written, 20);
        assert_eq!(stats.bytes_evicted, 10);
    }

    #[tokio::test]
    async fn test_large_write_replaces_existing_content() {
        let stream = BoundedStream::new(4);
        stream.write(b"ab").await;
        stream.write(b"123456").await;
        assert_eq!(stream.read().await, b"3456");
        let stats = stream.stats().await;
        assert_eq!(stats.total_written, 8);
        // 2 previously buffered bytes + 2 leading bytes of the oversized write.
        assert_eq!(stats.bytes_evicted, 4);
        assert_eq!(stats.current_size, 4);
    }

    #[tokio::test]
    async fn test_write_exactly_max_size() {
        let stream = BoundedStream::new(5);
        stream.write(b"abc").await;
        stream.write(b"vwxyz").await;
        assert_eq!(stream.read().await, b"vwxyz");
        assert_eq!(stream.stats().await.bytes_evicted, 3);
    }

    #[tokio::test]
    async fn test_zero_capacity_discards_everything() {
        let stream = BoundedStream::new(0);
        stream.write(b"abc").await;
        stream.write(b"").await;
        assert!(stream.is_empty().await);
        let stats = stream.stats().await;
        assert_eq!(stats.total_written, 3);
        assert_eq!(stats.bytes_evicted, 3);
    }

    #[tokio::test]
    async fn test_empty_write_is_noop() {
        let stream = BoundedStream::new(10);
        stream.write(b"").await;
        assert!(stream.is_empty().await);
        assert_eq!(stream.stats().await.total_written, 0);
    }

    #[tokio::test]
    async fn test_close_prevents_writes() {
        let stream = BoundedStream::new(100);
        stream.write(b"before").await;
        stream.close().await;
        stream.write(b"after").await;
        assert_eq!(stream.read().await, b"before");
        // Ignored writes must not be counted either.
        assert_eq!(stream.stats().await.total_written, 6);
    }

    #[tokio::test]
    async fn test_read_string() {
        let stream = BoundedStream::new(100);
        stream.write(b"hello world").await;
        assert_eq!(stream.read_string().await, "hello world");
    }

    #[tokio::test]
    async fn test_concurrent_writes() {
        let stream = Arc::new(BoundedStream::new(1000));
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let s = stream.clone();
                tokio::spawn(async move {
                    for j in 0..10 {
                        s.write(format!("[{}-{}]", i, j).as_bytes()).await;
                    }
                })
            })
            .collect();
        for h in handles {
            h.await.expect("task should not panic");
        }

        // 100 writes of 5 bytes each fit within capacity: nothing is evicted.
        let stats = stream.stats().await;
        assert_eq!(stats.total_written, 500);
        assert_eq!(stats.bytes_evicted, 0);

        let text = stream.read_string().await;
        assert_eq!(text.len(), 500);
        // Each write is applied atomically under the lock, so every fragment is
        // intact, and each task's fragments appear in the order it wrote them.
        for i in 0..10 {
            let mut previous = None;
            for j in 0..10 {
                let fragment = format!("[{}-{}]", i, j);
                let pos = text.find(&fragment).expect("fragment should be intact");
                if let Some(prev) = previous {
                    assert!(pos > prev, "writes from task {} reordered", i);
                }
                previous = Some(pos);
            }
        }
    }

    #[tokio::test]
    async fn test_stats() {
        let stream = BoundedStream::new(10);
        stream.write(b"1234567890").await;
        let stats = stream.stats().await;
        assert_eq!(stats.current_size, 10);
        assert_eq!(stats.max_size, 10);
        assert_eq!(stats.total_written, 10);
        assert_eq!(stats.bytes_evicted, 0);
        assert!(!stats.closed);
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let stream = BoundedStream::new(100);
        assert!(stream.is_empty().await);
        assert_eq!(stream.len().await, 0);
        assert_eq!(stream.read().await, Vec::<u8>::new());
    }

    #[tokio::test]
    async fn test_drain_to_stream() {
        use std::io::Cursor;
        let data = b"test data from reader";
        let cursor = Cursor::new(data.to_vec());
        let stream = Arc::new(BoundedStream::new(100));
        drain_to_stream(cursor, stream.clone()).await;
        assert_eq!(stream.read().await, data);
        assert!(stream.is_closed().await);
    }

    #[tokio::test]
    async fn test_drain_empty_reader_closes_stream() {
        use std::io::Cursor;
        let stream = Arc::new(BoundedStream::new(100));
        drain_to_stream(Cursor::new(Vec::<u8>::new()), stream.clone()).await;
        assert!(stream.is_empty().await);
        assert!(stream.is_closed().await);
    }

    #[tokio::test]
    async fn test_default_size() {
        let stream = BoundedStream::default_size();
        let stats = stream.stats().await;
        assert_eq!(stats.max_size, DEFAULT_STREAM_MAX_SIZE);
    }
}