use std::collections::VecDeque;
use std::pin::Pin;
use std::time::Duration;
use futures::Stream;
/// A pinned, heap-allocated, type-erased [`Stream`] yielding `T` that is `Send`.
pub type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
pub struct StreamSender<T: Send + 'static> {
inner: tokio::sync::mpsc::Sender<T>,
}
impl<T: Send + 'static> StreamSender<T> {
pub fn new(inner: tokio::sync::mpsc::Sender<T>) -> Self {
Self { inner }
}
pub async fn send(&self, item: T) -> Result<(), StreamSendError> {
self.inner
.send(item)
.await
.map_err(|_| StreamSendError::ConsumerDropped)
}
pub fn try_send(&self, item: T) -> Result<(), StreamSendError> {
self.inner.try_send(item).map_err(|e| match e {
tokio::sync::mpsc::error::TrySendError::Full(_) => StreamSendError::Full,
tokio::sync::mpsc::error::TrySendError::Closed(_) => StreamSendError::ConsumerDropped,
})
}
pub fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}
/// Reasons a send into a stream channel can fail.
///
/// The type is a plain, field-less enum, so it is cheap to copy and can be
/// compared directly (e.g. `assert_eq!(err, StreamSendError::Full)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StreamSendError {
    /// The channel is closed, so the item can no longer be delivered.
    ConsumerDropped,
    /// The channel buffer had no free slot for a non-waiting send
    /// (see `StreamSender::try_send`).
    Full,
}

impl std::fmt::Display for StreamSendError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ConsumerDropped => "stream consumer dropped",
            Self::Full => "stream buffer full",
        };
        f.write_str(message)
    }
}

impl std::error::Error for StreamSendError {}
/// Receiving half of a stream channel, wrapping a `tokio::sync::mpsc::Receiver`.
pub struct StreamReceiver<T: Send + 'static> {
    inner: tokio::sync::mpsc::Receiver<T>,
}

impl<T: Send + 'static> StreamReceiver<T> {
    /// Wraps an existing Tokio mpsc receiver.
    pub fn new(inner: tokio::sync::mpsc::Receiver<T>) -> Self {
        StreamReceiver { inner }
    }

    /// Awaits the next item from the underlying receiver.
    ///
    /// Returns whatever the underlying `recv` yields: `Some(item)` for a
    /// received value, `None` once the receiver reports no more items.
    pub async fn recv(&mut self) -> Option<T> {
        let Self { inner } = self;
        inner.recv().await
    }

    /// Consumes the receiver and exposes it as a boxed [`Stream`] of items.
    pub fn into_stream(self) -> BoxStream<T> {
        use tokio_stream::wrappers::ReceiverStream;

        let stream = ReceiverStream::new(self.inner);
        Box::pin(stream)
    }
}
/// Limits that decide when a batch is considered complete.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Number of buffered items at which a batch is flushed.
    pub max_items: usize,
    /// Longest time a batch may stay buffered before it is due for a flush.
    pub max_delay: Duration,
}

impl Default for BatchConfig {
    /// Batches of up to 64 items, held for at most 5 milliseconds.
    fn default() -> Self {
        Self::new(64, Duration::from_millis(5))
    }
}

impl BatchConfig {
    /// Builds a configuration, raising a `max_items` of zero to one so that a
    /// batch always holds at least one item.
    pub fn new(max_items: usize, max_delay: Duration) -> Self {
        let max_items = if max_items == 0 { 1 } else { max_items };
        Self {
            max_items,
            max_delay,
        }
    }
}
/// Accumulates items and forwards them to a channel as `Vec<T>` batches.
///
/// A batch is sent when the buffer reaches `config.max_items` (checked in
/// [`push`](Self::push)), when [`flush`](Self::flush) is called, or when
/// [`check_deadline`](Self::check_deadline) observes that `config.max_delay`
/// has elapsed since the first item of the current batch was buffered.
///
/// The writer owns no timer: the time-based flush only happens when the owner
/// calls `check_deadline` (or `flush`).
pub struct BatchWriter<T: Send + 'static> {
    sender: tokio::sync::mpsc::Sender<Vec<T>>,
    config: BatchConfig,
    buffer: Vec<T>,
    /// Set when the first item of a batch is buffered; cleared on flush.
    flush_deadline: Option<tokio::time::Instant>,
}

impl<T: Send + 'static> BatchWriter<T> {
    /// Upper bound on how many slots are preallocated for a new batch buffer.
    ///
    /// `max_items` is caller-controlled and may be enormous (e.g. `usize::MAX`
    /// to mean "flush on time only"); preallocating that much would panic or
    /// waste memory. Larger batches simply grow the buffer on demand.
    const MAX_PREALLOC: usize = 1024;

    /// Allocates an empty batch buffer sized for `config`, capped at
    /// [`Self::MAX_PREALLOC`].
    fn new_buffer(config: &BatchConfig) -> Vec<T> {
        Vec::with_capacity(config.max_items.min(Self::MAX_PREALLOC))
    }

    /// Creates a writer that sends completed batches through `sender`.
    pub fn new(sender: tokio::sync::mpsc::Sender<Vec<T>>, config: BatchConfig) -> Self {
        let buffer = Self::new_buffer(&config);
        Self {
            sender,
            config,
            buffer,
            flush_deadline: None,
        }
    }

    /// Buffers `item`, flushing the batch if it has reached `max_items`.
    ///
    /// The first item of a batch also starts the `max_delay` clock used by
    /// [`check_deadline`](Self::check_deadline).
    ///
    /// # Errors
    ///
    /// Returns [`StreamSendError::ConsumerDropped`] if a flush was triggered
    /// and the channel is closed.
    pub async fn push(&mut self, item: T) -> Result<(), StreamSendError> {
        self.buffer.push(item);
        if self.flush_deadline.is_none() {
            // `checked_add` avoids the overflow panic of `Instant + Duration`
            // for huge delays (e.g. `Duration::MAX`); an unrepresentable
            // deadline is treated as "no deadline".
            self.flush_deadline = tokio::time::Instant::now().checked_add(self.config.max_delay);
        }
        if self.buffer.len() >= self.config.max_items {
            self.flush().await?;
        }
        Ok(())
    }

    /// Sends all buffered items as one batch; does nothing if the buffer is empty.
    ///
    /// Channel capacity is reserved *before* the buffer is detached, so if
    /// this future is dropped while waiting for capacity (for instance in a
    /// `select!`), the buffered items and their deadline stay in the writer
    /// instead of being lost.
    ///
    /// # Errors
    ///
    /// Returns [`StreamSendError::ConsumerDropped`] if the channel is closed.
    /// The buffered items are discarded in that case, since they can no
    /// longer be delivered.
    pub async fn flush(&mut self) -> Result<(), StreamSendError> {
        if self.buffer.is_empty() {
            return Ok(());
        }

        let permit = match self.sender.reserve().await {
            Ok(permit) => permit,
            Err(_) => {
                // Nobody can receive any more: drop the pending items so the
                // buffer does not grow without bound if the caller keeps pushing.
                self.buffer.clear();
                self.flush_deadline = None;
                return Err(StreamSendError::ConsumerDropped);
            }
        };

        // No await points below: once a slot is reserved the hand-off cannot
        // be interrupted.
        let replacement = Self::new_buffer(&self.config);
        let batch = std::mem::replace(&mut self.buffer, replacement);
        self.flush_deadline = None;
        permit.send(batch);
        Ok(())
    }

    /// Flushes the current batch if its deadline has been reached.
    ///
    /// Does nothing when no deadline is pending or it has not yet passed.
    ///
    /// # Errors
    ///
    /// Propagates the error of [`flush`](Self::flush).
    pub async fn check_deadline(&mut self) -> Result<(), StreamSendError> {
        match self.flush_deadline {
            Some(deadline) if tokio::time::Instant::now() >= deadline => self.flush().await,
            _ => Ok(()),
        }
    }

    /// Number of items currently buffered and not yet sent.
    pub fn buffered_count(&self) -> usize {
        self.buffer.len()
    }

    /// The configured maximum time a batch may stay buffered.
    pub fn max_delay(&self) -> Duration {
        self.config.max_delay
    }
}
/// Receives `Vec<T>` batches from a channel and hands the items out one at a time.
pub struct BatchReader<T: Send + 'static> {
    receiver: tokio::sync::mpsc::Receiver<Vec<T>>,
    /// Items of the most recently received batch that have not been returned yet.
    current_batch: VecDeque<T>,
}

impl<T: Send + 'static> BatchReader<T> {
    /// Creates a reader over `receiver` with no items pending.
    pub fn new(receiver: tokio::sync::mpsc::Receiver<Vec<T>>) -> Self {
        let current_batch = VecDeque::new();
        Self {
            receiver,
            current_batch,
        }
    }

    /// Returns the next item, awaiting further batches as needed.
    ///
    /// Items come out in batch order, front to back. Empty batches are
    /// skipped. Returns `None` once the pending items are exhausted and the
    /// receiver yields no further batch.
    pub async fn recv(&mut self) -> Option<T> {
        // Refill until there is something to pop; an empty batch just loops again.
        while self.current_batch.is_empty() {
            let next_batch = self.receiver.recv().await?;
            self.current_batch = next_batch.into();
        }
        self.current_batch.pop_front()
    }

    /// Consumes the reader and exposes its items as a boxed [`Stream`].
    ///
    /// The stream ends when [`recv`](Self::recv) returns `None`.
    pub fn into_stream(self) -> BoxStream<T> {
        let items = futures::stream::unfold(self, |mut state| async move {
            let item = state.recv().await?;
            Some((item, state))
        });
        Box::pin(items)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_stream::StreamExt;

    /// `BatchConfig::default()` yields 64 items / 5 ms.
    #[tokio::test]
    async fn test_batch_config_default() {
        let defaults = BatchConfig::default();
        assert_eq!(defaults.max_items, 64);
        assert_eq!(defaults.max_delay, Duration::from_millis(5));
    }

    /// A requested `max_items` of zero is raised to one.
    #[tokio::test]
    async fn test_batch_config_clamps_to_one() {
        let clamped = BatchConfig::new(0, Duration::from_millis(1));
        assert_eq!(clamped.max_items, 1);
    }

    /// Filling a batch flushes it, and the reader yields the items in order.
    #[tokio::test]
    async fn test_batch_writer_reader_roundtrip() {
        let (batch_tx, batch_rx) = tokio::sync::mpsc::channel(16);
        let settings = BatchConfig::new(3, Duration::from_secs(10));
        let mut writer = BatchWriter::new(batch_tx, settings);
        let mut reader = BatchReader::new(batch_rx);

        for value in 1..=3 {
            writer.push(value).await.unwrap();
        }

        for expected in 1..=3 {
            assert_eq!(reader.recv().await, Some(expected));
        }
    }

    /// An explicit flush sends a partially filled batch.
    #[tokio::test]
    async fn test_batch_writer_flush_explicit() {
        let (batch_tx, batch_rx) = tokio::sync::mpsc::channel(16);
        let settings = BatchConfig::new(100, Duration::from_secs(10));
        let mut writer = BatchWriter::new(batch_tx, settings);
        let mut reader = BatchReader::new(batch_rx);

        writer.push(42).await.unwrap();
        assert_eq!(writer.buffered_count(), 1);

        writer.flush().await.unwrap();
        assert_eq!(writer.buffered_count(), 0);

        assert_eq!(reader.recv().await, Some(42));
    }

    /// Once the deadline has passed, `check_deadline` flushes the batch.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn test_batch_writer_flush_on_deadline() {
        let (batch_tx, batch_rx) = tokio::sync::mpsc::channel(16);
        let settings = BatchConfig::new(100, Duration::from_millis(10));
        let mut writer = BatchWriter::new(batch_tx, settings);
        let mut reader = BatchReader::new(batch_rx);

        writer.push(42).await.unwrap();
        assert_eq!(writer.buffered_count(), 1);

        // Move the (paused) clock past the 10 ms deadline.
        tokio::time::sleep(Duration::from_millis(20)).await;
        writer.check_deadline().await.unwrap();

        assert_eq!(reader.recv().await, Some(42));
    }

    /// The stream view flattens consecutive batches and ends when the writer is gone.
    #[tokio::test]
    async fn test_batch_reader_as_stream() {
        let (batch_tx, batch_rx) = tokio::sync::mpsc::channel(16);
        let settings = BatchConfig::new(2, Duration::from_secs(10));
        let mut writer = BatchWriter::new(batch_tx, settings);

        for value in [10, 20, 30, 40] {
            writer.push(value).await.unwrap();
        }
        // Dropping the writer closes the channel so the stream can terminate.
        drop(writer);

        let reader = BatchReader::new(batch_rx);
        let collected: Vec<i32> = reader.into_stream().collect().await;
        assert_eq!(collected, vec![10, 20, 30, 40]);
    }

    /// Flushing an empty writer succeeds and leaves it empty.
    #[tokio::test]
    async fn test_batch_empty_flush_is_noop() {
        let (batch_tx, _batch_rx) = tokio::sync::mpsc::channel::<Vec<i32>>(16);
        let mut writer = BatchWriter::new(batch_tx, BatchConfig::default());

        writer.flush().await.unwrap();

        assert_eq!(writer.buffered_count(), 0);
    }

    /// A flush triggered after the receiver is gone reports an error.
    #[tokio::test]
    async fn test_batch_writer_consumer_dropped() {
        let (batch_tx, batch_rx) = tokio::sync::mpsc::channel(1);
        let settings = BatchConfig::new(2, Duration::from_secs(10));
        let mut writer = BatchWriter::new(batch_tx, settings);
        drop(batch_rx);

        // The first push only buffers; the second fills the batch and must flush.
        writer.push(1).await.unwrap();
        let outcome = writer.push(2).await;
        assert!(outcome.is_err());
    }

    /// Empty batches on the channel are skipped by the reader.
    #[tokio::test]
    async fn test_batch_reader_empty_batch_skipped() {
        let (batch_tx, batch_rx) = tokio::sync::mpsc::channel(16);
        batch_tx.send(vec![]).await.unwrap();
        batch_tx.send(vec![1, 2]).await.unwrap();
        drop(batch_tx);

        let mut reader = BatchReader::new(batch_rx);
        assert_eq!(reader.recv().await, Some(1));
        assert_eq!(reader.recv().await, Some(2));
        assert_eq!(reader.recv().await, None);
    }
}