use futures::Stream;
use std::pin::Pin;
use std::time::Duration;
use tokio::time::Instant;
/// Thresholds that decide when a [`BufferedStream`] emits its buffered text.
///
/// The buffer is flushed as soon as *any* of the conditions below holds.
#[ derive( Debug, Clone, PartialEq, Eq ) ]
pub struct BufferConfig
{
  /// Flush once the buffer holds at least this many bytes (UTF-8 length,
  /// not characters).
  pub min_buffer_size : usize,
  /// Flush once this much time has passed since the previous flush (or since
  /// the stream was created, before the first flush).
  pub max_buffer_time : Duration,
  /// Flush whenever the buffer contains a `'\n'`.
  pub flush_on_newline : bool,
}
/// Default thresholds: flush at 50 bytes, after 100 ms, or on any newline.
impl Default for BufferConfig
{
  fn default() -> Self
  {
    let max_buffer_time = Duration::from_millis( 100 );
    Self
    {
      flush_on_newline : true,
      min_buffer_size : 50,
      max_buffer_time,
    }
  }
}
/// Builder-style constructors; each `with_*` method consumes the config and
/// returns it with one field replaced.
impl BufferConfig
{
  /// Creates a configuration holding the [`Default`] thresholds.
  #[ must_use ]
  pub fn new() -> Self
  {
    < Self as Default >::default()
  }

  /// Replaces the size threshold (in bytes).
  #[ must_use ]
  pub fn with_min_buffer_size( self, size : usize ) -> Self
  {
    Self { min_buffer_size : size, ..self }
  }

  /// Replaces the maximum time text may wait between flushes.
  #[ must_use ]
  pub fn with_max_buffer_time( self, duration : Duration ) -> Self
  {
    Self { max_buffer_time : duration, ..self }
  }

  /// Enables or disables flushing whenever the buffer contains a newline.
  #[ must_use ]
  pub fn with_flush_on_newline( self, enabled : bool ) -> Self
  {
    Self { flush_on_newline : enabled, ..self }
  }
}
#[ derive( Debug ) ]
pub struct BufferedStream< S >
where
S : Stream< Item = String > + Unpin,
{
inner : S,
config : BufferConfig,
buffer : String,
last_flush : Instant,
}
impl< S > BufferedStream< S >
where
S : Stream< Item = String > + Unpin,
{
pub fn new( stream : S, config : BufferConfig ) -> Self
{
Self
{
inner : stream,
config,
buffer : String::new(),
last_flush : Instant::now(),
}
}
fn should_flush( &self ) -> bool
{
if self.buffer.len() >= self.config.min_buffer_size
{
return true;
}
if self.last_flush.elapsed() >= self.config.max_buffer_time
{
return true;
}
if self.config.flush_on_newline && self.buffer.contains( '\n' )
{
return true;
}
false
}
fn flush( &mut self ) -> Option< String >
{
if self.buffer.is_empty()
{
return None;
}
let content = self.buffer.clone();
self.buffer.clear();
self.last_flush = Instant::now();
Some( content )
}
}
impl< S > Stream for BufferedStream< S >
where
S : Stream< Item = String > + Unpin,
{
type Item = String;
fn poll_next(
mut self : Pin< &mut Self >,
cx : &mut std::task::Context< '_ >,
) -> std::task::Poll< Option< Self::Item > >
{
use std::task::Poll;
loop
{
match Pin::new( &mut self.inner ).poll_next( cx )
{
Poll::Ready( Some( chunk ) ) =>
{
self.buffer.push_str( &chunk );
if self.should_flush()
{
if let Some( content ) = self.flush()
{
return Poll::Ready( Some( content ) );
}
}
continue;
}
Poll::Ready( None ) =>
{
return Poll::Ready( self.flush() );
}
Poll::Pending =>
{
if self.should_flush()
{
if let Some( content ) = self.flush()
{
return Poll::Ready( Some( content ) );
}
}
return Poll::Pending;
}
}
}
}
}
/// Extension methods that wrap a `String` stream in a [`BufferedStream`].
///
/// Note: `futures::StreamExt` also defines a method named `buffered` (for
/// streams of futures). If both traits are in scope, call this one
/// explicitly as `BufferedStreamExt::buffered( stream, config )`.
pub trait BufferedStreamExt : Stream< Item = String > + Sized + Unpin
{
  /// Coalesces this stream's chunks according to `config`.
  #[ must_use = "streams do nothing unless polled" ]
  fn buffered( self, config : BufferConfig ) -> BufferedStream< Self >
  {
    BufferedStream::new( self, config )
  }

  /// Coalesces this stream's chunks using [`BufferConfig::default`].
  #[ must_use = "streams do nothing unless polled" ]
  fn buffered_default( self ) -> BufferedStream< Self >
  {
    Self::buffered( self, BufferConfig::default() )
  }
}
/// Every `Unpin` stream of `String`s gets the [`BufferedStreamExt`] adapters.
impl< S : Stream< Item = String > + Unpin > BufferedStreamExt for S {}
#[ cfg( test ) ]
mod tests
{
  use super::*;
  use tokio_stream::StreamExt;

  /// A time threshold long enough that it never fires during a test, so
  /// results depend only on the size/newline rules and not on timing.
  const NEVER : Duration = Duration::from_secs( 3600 );

  /// Converts string literals into owned chunks.
  fn chunks( items : &[ &str ] ) -> Vec< String >
  {
    items.iter().map( | s | ( *s ).to_owned() ).collect()
  }

  /// Drains `stream` and returns every item it produced.
  async fn collect_all< S >( mut stream : S ) -> Vec< String >
  where
    S : Stream< Item = String > + Unpin,
  {
    let mut out = Vec::new();
    while let Some( chunk ) = stream.next().await
    {
      out.push( chunk );
    }
    out
  }

  #[ test ]
  fn test_buffer_config_creation()
  {
    let config = BufferConfig::new();
    assert_eq!( config.min_buffer_size, 50 );
    assert_eq!( config.max_buffer_time, Duration::from_millis( 100 ) );
    assert!( config.flush_on_newline );
  }

  #[ test ]
  fn test_buffer_config_builder()
  {
    let config = BufferConfig::new()
      .with_min_buffer_size( 100 )
      .with_max_buffer_time( Duration::from_millis( 200 ) )
      .with_flush_on_newline( false );
    assert_eq!( config.min_buffer_size, 100 );
    assert_eq!( config.max_buffer_time, Duration::from_millis( 200 ) );
    assert!( !config.flush_on_newline );
  }

  #[ tokio::test ]
  async fn test_buffered_stream_basic()
  {
    let stream = tokio_stream::iter( chunks( &[ "a", "b", "c" ] ) );
    let config = BufferConfig::new().with_min_buffer_size( 2 ).with_max_buffer_time( NEVER );
    // "a" + "b" reaches the size threshold; "c" is flushed when the input ends.
    assert_eq!( collect_all( stream.buffered( config ) ).await, chunks( &[ "ab", "c" ] ) );
  }

  #[ tokio::test ]
  async fn test_buffered_stream_flush_on_newline()
  {
    let stream = tokio_stream::iter( chunks( &[ "hello", "\n", "world" ] ) );
    let config = BufferConfig::new()
      .with_min_buffer_size( 100 )
      .with_max_buffer_time( NEVER )
      .with_flush_on_newline( true );
    assert_eq!(
      collect_all( stream.buffered( config ) ).await,
      chunks( &[ "hello\n", "world" ] ),
    );
  }

  #[ tokio::test ]
  async fn test_buffered_stream_newline_flush_disabled()
  {
    let stream = tokio_stream::iter( chunks( &[ "a\n", "b" ] ) );
    let config = BufferConfig::new()
      .with_min_buffer_size( 100 )
      .with_max_buffer_time( NEVER )
      .with_flush_on_newline( false );
    assert_eq!( collect_all( stream.buffered( config ) ).await, chunks( &[ "a\nb" ] ) );
  }

  #[ tokio::test ]
  async fn test_buffered_stream_size_threshold()
  {
    let stream = tokio_stream::iter( vec![ "x".repeat( 60 ) ] );
    let config = BufferConfig::new().with_min_buffer_size( 50 ).with_max_buffer_time( NEVER );
    let mut buffered = stream.buffered( config );
    assert_eq!( buffered.next().await, Some( "x".repeat( 60 ) ) );
    assert_eq!( buffered.next().await, None );
  }

  #[ tokio::test ]
  async fn test_buffered_stream_empty_input()
  {
    let stream = tokio_stream::iter( Vec::< String >::new() );
    assert!( collect_all( stream.buffered_default() ).await.is_empty() );
  }

  #[ tokio::test ]
  async fn test_buffered_default_flushes_on_newline()
  {
    let stream = tokio_stream::iter( chunks( &[ "hi\n" ] ) );
    assert_eq!( collect_all( stream.buffered_default() ).await, chunks( &[ "hi\n" ] ) );
  }
}