use std::time::Duration;
use tokio::sync::mpsc;
/// Tuning knobs shared by a [`BatchAccumulator`] / [`BatchDrainer`] pair.
#[derive(Clone, Debug)]
pub struct AccumulatorConfig {
    /// Capacity of the bounded producer → drainer channel; once that many
    /// items are waiting, `BatchAccumulator::push` is rejected.
    pub channel_capacity: usize,
    /// Item-count flush threshold for a batch.
    pub max_items: usize,
    /// Flush threshold on the sum of the caller-reported byte sizes in a batch.
    pub max_bytes: usize,
    /// Time-based flush trigger for a non-empty batch that has not reached
    /// either size threshold (see `BatchDrainer::next_batch`).
    pub max_wait: Duration,
}

impl Default for AccumulatorConfig {
    /// 10 000-slot channel, 100 items, 1 MiB, and a `max_wait` of 10 ms.
    fn default() -> Self {
        let one_mib = 1 << 20;
        Self {
            channel_capacity: 10_000,
            max_items: 100,
            max_bytes: one_mib,
            max_wait: Duration::from_millis(10),
        }
    }
}
/// Cloneable producer handle that feeds items into a [`BatchDrainer`].
///
/// Every handle holds a sender of the same bounded channel; the drainer
/// observes the channel as closed once all handles have been dropped.
pub struct BatchAccumulator<T> {
    tx: mpsc::Sender<(T, usize)>,
}

// Implemented by hand on purpose: `#[derive(Clone)]` would generate
// `impl<T: Clone> Clone`, making the handle uncloneable for non-`Clone`
// payloads even though `mpsc::Sender<_>` is always `Clone`.
impl<T> Clone for BatchAccumulator<T> {
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
        }
    }
}
/// Consumer half: pulls items off the channel and groups them into batches.
///
/// Created together with its producer handle by [`BatchAccumulator::new`].
pub struct BatchDrainer<T> {
    /// Receiving end of the channel shared with every producer handle.
    rx: mpsc::Receiver<(T, usize)>,
    /// Flush thresholds (item count, byte total, wait time).
    config: AccumulatorConfig,
    /// Items received but not yet handed out as a batch.
    buffer: Vec<T>,
    /// Running total of the caller-reported sizes of the items in `buffer`.
    buffer_bytes: usize,
}
#[derive(Debug, thiserror::Error)]
#[error("accumulator full — backpressure active ({capacity} items buffered)")]
pub struct AccumulatorFull {
pub capacity: usize,
}
impl<T: Send + 'static> BatchAccumulator<T> {
    /// Creates a producer handle and the matching [`BatchDrainer`].
    ///
    /// Both halves share one bounded channel of `config.channel_capacity`
    /// slots; clone the returned handle to feed the drainer from several
    /// producers.
    ///
    /// # Panics
    ///
    /// Panics if `config.channel_capacity` is 0, which
    /// `tokio::sync::mpsc::channel` rejects.
    #[must_use]
    pub fn new(config: AccumulatorConfig) -> (Self, BatchDrainer<T>) {
        // Cap the up-front reservation so a "no limit" setting such as
        // `max_items: usize::MAX` cannot trigger a capacity-overflow panic or
        // a huge allocation before any item has arrived.
        const MAX_PREALLOC: usize = 1024;

        let (tx, rx) = mpsc::channel(config.channel_capacity);
        let drainer = BatchDrainer {
            rx,
            buffer: Vec::with_capacity(config.max_items.min(MAX_PREALLOC)),
            buffer_bytes: 0,
            // The producer side never needs the config, so move it rather
            // than cloning it.
            config,
        };
        (Self { tx }, drainer)
    }

    /// Enqueues `item` for batching without waiting for channel space.
    ///
    /// `byte_size` is the caller-reported size of `item`. It only feeds the
    /// drainer's `max_bytes` threshold and is not verified.
    ///
    /// Although declared `async` (kept for API compatibility), this never
    /// awaits. It uses `try_send`, so a full channel is reported immediately
    /// and backpressure is pushed to the caller.
    ///
    /// # Errors
    ///
    /// Returns [`AccumulatorFull`] if `channel_capacity` items are already
    /// waiting in the channel. The same error is returned if the drainer's
    /// receiver has been dropped or closed; use [`is_closed`](Self::is_closed)
    /// to tell the two apart. In both cases the rejected `item` is dropped.
    pub async fn push(&self, item: T, byte_size: usize) -> Result<(), AccumulatorFull> {
        self.tx
            .try_send((item, byte_size))
            .map_err(|_| AccumulatorFull {
                // `capacity()` counts *free* slots and is therefore 0 when the
                // channel is full; report the configured bound instead.
                capacity: self.tx.max_capacity(),
            })
    }

    /// Returns `true` once the drainer's receiver has been dropped or closed,
    /// after which every [`push`](Self::push) fails.
    #[must_use]
    pub fn is_closed(&self) -> bool {
        self.tx.is_closed()
    }
}
impl<T> BatchDrainer<T> {
    /// Waits for and returns the next batch of items, in arrival order.
    ///
    /// A batch is flushed as soon as any of these holds:
    ///
    /// * it holds at least `max_items` items, or the caller-reported sizes of
    ///   its items add up to at least `max_bytes`;
    /// * `max_wait` has elapsed since the batch was started, i.e. since its
    ///   first item was received, or since this call began if items were
    ///   already buffered;
    /// * every [`BatchAccumulator`] handle has been dropped and the channel
    ///   holds nothing more.
    ///
    /// While nothing is buffered, the drainer waits for the first item with no
    /// timer armed, so an idle drainer does not wake up every `max_wait`.
    ///
    /// Returns an empty `Vec` only when the channel is closed and fully
    /// drained, so callers can treat it as the shutdown signal.
    ///
    /// # Cancel safety
    ///
    /// Received items are moved into the internal buffer immediately. If this
    /// future is dropped, those items are returned by a later call, which
    /// restarts the `max_wait` deadline.
    pub async fn next_batch(&mut self) -> Vec<T> {
        if self.threshold_met() {
            return self.take_buffer();
        }

        // Nothing buffered yet: block until the first item arrives or the
        // channel closes. No timer is armed here because `max_wait` bounds the
        // age of a batch, not idle time.
        if self.buffer.is_empty() {
            match self.rx.recv().await {
                Some((val, size)) => self.push_item(val, size),
                None => return self.take_buffer(),
            }
            if self.threshold_met() {
                return self.take_buffer();
            }
        }

        // Arm the deadline once per batch. Re-creating it per received item
        // would let a steady trickle of items postpone the flush indefinitely.
        let deadline = tokio::time::sleep(self.config.max_wait);
        tokio::pin!(deadline);

        loop {
            tokio::select! {
                biased;
                () = &mut deadline => {
                    return self.take_buffer();
                }
                item = self.rx.recv() => {
                    match item {
                        Some((val, size)) => {
                            self.push_item(val, size);
                            if self.threshold_met() {
                                return self.take_buffer();
                            }
                        }
                        // All producers are gone: flush what we have.
                        None => {
                            return self.take_buffer();
                        }
                    }
                }
            }
        }
    }

    /// Returns everything already buffered plus every item currently waiting
    /// in the channel, without waiting for more.
    ///
    /// Thresholds are ignored, so the result may exceed `max_items` or
    /// `max_bytes`. The channel is not closed: live producer handles can still
    /// push afterwards. For a final flush, drop all [`BatchAccumulator`]
    /// handles first.
    pub fn drain_remaining(&mut self) -> Vec<T> {
        while let Ok((val, size)) = self.rx.try_recv() {
            self.push_item(val, size);
        }
        self.take_buffer()
    }

    /// Appends one received item and adds its reported size to the running
    /// byte total.
    fn push_item(&mut self, val: T, size: usize) {
        // Saturate instead of overflowing. A saturated total simply trips the
        // byte threshold, because `usize::MAX >= max_bytes` always holds.
        self.buffer_bytes = self.buffer_bytes.saturating_add(size);
        self.buffer.push(val);
    }

    /// `true` when the buffer is non-empty and has reached the item or byte
    /// limit.
    ///
    /// An empty buffer never qualifies, so a `max_items` or `max_bytes` of 0
    /// cannot make `next_batch` hand out empty batches while the channel is
    /// still open.
    fn threshold_met(&self) -> bool {
        !self.buffer.is_empty()
            && (self.buffer.len() >= self.config.max_items
                || self.buffer_bytes >= self.config.max_bytes)
    }

    /// Hands out the current buffer and resets the byte total for the next
    /// batch.
    fn take_buffer(&mut self) -> Vec<T> {
        self.buffer_bytes = 0;
        std::mem::take(&mut self.buffer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Long enough that the time trigger never fires during a test, so only
    /// the item/byte thresholds or channel closure can end a batch.
    const NEVER: Duration = Duration::from_secs(60);

    /// Shorthand for building a config with every limit spelled out.
    fn cfg(
        channel_capacity: usize,
        max_items: usize,
        max_bytes: usize,
        max_wait: Duration,
    ) -> AccumulatorConfig {
        AccumulatorConfig {
            channel_capacity,
            max_items,
            max_bytes,
            max_wait,
        }
    }

    #[tokio::test]
    async fn test_drain_on_item_count() {
        let (producer, mut drainer) = BatchAccumulator::new(cfg(100, 5, usize::MAX, NEVER));
        for value in 0..5 {
            producer.push(value, 1).await.unwrap();
        }
        let got = drainer.next_batch().await;
        assert_eq!(got.len(), 5);
        assert_eq!(got, vec![0, 1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn test_drain_on_byte_threshold() {
        // 3 + 3 + 3 = 9 stays under the 10-byte limit; the fourth item crosses it.
        let (producer, mut drainer) = BatchAccumulator::new(cfg(100, 1000, 10, NEVER));
        for value in 0..4 {
            producer.push(value, 3).await.unwrap();
        }
        assert_eq!(drainer.next_batch().await.len(), 4);
    }

    #[tokio::test]
    async fn test_drain_on_time_threshold() {
        // Neither the count nor the byte limit is reachable, so only `max_wait`
        // can flush this batch.
        let (producer, mut drainer) =
            BatchAccumulator::new(cfg(100, 1000, usize::MAX, Duration::from_millis(50)));
        producer.push(1, 1).await.unwrap();
        producer.push(2, 1).await.unwrap();
        assert_eq!(drainer.next_batch().await.len(), 2);
    }

    #[tokio::test]
    async fn test_backpressure_when_full() {
        // The drainer stays alive but is never polled, so the channel fills up.
        let (producer, _drainer) = BatchAccumulator::<i32>::new(cfg(3, 100, usize::MAX, NEVER));
        for value in 1..=3 {
            producer.push(value, 1).await.unwrap();
        }
        assert!(producer.push(4, 1).await.is_err());
    }

    #[tokio::test]
    async fn test_shutdown_drains_remaining() {
        let (producer, mut drainer) = BatchAccumulator::new(cfg(100, 1000, usize::MAX, NEVER));
        producer.push(10, 1).await.unwrap();
        producer.push(20, 1).await.unwrap();
        drop(producer);
        assert_eq!(drainer.next_batch().await, vec![10, 20]);
        assert!(drainer.next_batch().await.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_batches() {
        let (producer, mut drainer) = BatchAccumulator::new(cfg(100, 3, usize::MAX, NEVER));
        for value in 0..7 {
            producer.push(value, 1).await.unwrap();
        }
        drop(producer);
        // Seven items at three per batch: 3 + 3 + 1, then the closed, empty channel.
        for expected_len in [3, 3, 1] {
            assert_eq!(drainer.next_batch().await.len(), expected_len);
        }
        assert!(drainer.next_batch().await.is_empty());
    }

    #[tokio::test]
    async fn test_push_handle_is_clone() {
        let (first, mut drainer) = BatchAccumulator::new(AccumulatorConfig::default());
        let second = first.clone();
        first.push(1, 1).await.unwrap();
        second.push(2, 1).await.unwrap();
        drop(first);
        drop(second);
        assert_eq!(drainer.next_batch().await.len(), 2);
    }

    #[tokio::test]
    async fn test_drain_remaining_on_shutdown() {
        let (producer, mut drainer) = BatchAccumulator::new(cfg(100, 1000, usize::MAX, NEVER));
        for value in 1..=3 {
            producer.push(value, 1).await.unwrap();
        }
        drop(producer);
        assert_eq!(drainer.drain_remaining(), vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_empty_drain_returns_empty() {
        let (_producer, mut drainer) = BatchAccumulator::<i32>::new(AccumulatorConfig::default());
        assert!(drainer.drain_remaining().is_empty());
    }
}