#![cfg(native)]
use super::error::SignalError;
use super::signal::Signal;
use parking_lot::Mutex;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Notify;
use tokio::time::interval;
/// Thresholds that decide when a `SignalBatcher` emits its pending items.
///
/// Built with [`BatchConfig::new`] (or `Default`) and refined with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Number of pending items at which `queue` flushes right away.
    max_batch_size: usize,
    /// Longest time items may wait; also the period of the background flush timer.
    flush_interval: Duration,
}

impl BatchConfig {
    /// Batch size used by [`BatchConfig::new`].
    const DEFAULT_MAX_BATCH_SIZE: usize = 50;
    /// Flush interval used by [`BatchConfig::new`].
    const DEFAULT_FLUSH_INTERVAL: Duration = Duration::from_secs(1);

    /// Returns a configuration with a batch size of 50 and a 1 second flush interval.
    pub fn new() -> Self {
        Self {
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
            flush_interval: Self::DEFAULT_FLUSH_INTERVAL,
        }
    }

    /// Replaces the item count that triggers an immediate flush.
    pub fn with_max_batch_size(self, size: usize) -> Self {
        Self {
            max_batch_size: size,
            ..self
        }
    }

    /// Replaces the maximum time between flushes.
    pub fn with_flush_interval(self, interval: Duration) -> Self {
        Self {
            flush_interval: interval,
            ..self
        }
    }

    /// Item count that triggers an immediate flush.
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    /// Maximum time between flushes.
    pub fn flush_interval(&self) -> Duration {
        self.flush_interval
    }
}

impl Default for BatchConfig {
    /// Same as [`BatchConfig::new`].
    fn default() -> Self {
        BatchConfig::new()
    }
}
/// Pending items plus the time of the last flush.
///
/// `SignalBatcher` keeps this behind a mutex shared by its clones and its
/// background flush task.
struct BatchState<T> {
    /// Items queued since the last `take`, in insertion order.
    items: Vec<T>,
    /// Time of the last `take`, or of construction if nothing was taken yet.
    last_flush: Instant,
}

impl<T> BatchState<T> {
    /// Empty batch whose flush clock starts now.
    fn new() -> Self {
        let last_flush = Instant::now();
        Self {
            items: Vec::new(),
            last_flush,
        }
    }

    /// Appends one item to the pending batch.
    fn add(&mut self, item: T) {
        self.items.push(item);
    }

    /// True when the batch is full or the flush interval has elapsed.
    fn should_flush(&self, config: &BatchConfig) -> bool {
        let is_full = self.items.len() >= config.max_batch_size;
        // Short-circuits, so the clock is only read when the batch is not full.
        is_full || self.last_flush.elapsed() >= config.flush_interval
    }

    /// Moves all pending items out and restarts the flush clock.
    fn take(&mut self) -> Vec<T> {
        let batch = std::mem::take(&mut self.items);
        self.last_flush = Instant::now();
        batch
    }

    /// True when nothing is pending.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of pending items.
    fn len(&self) -> usize {
        self.items.len()
    }
}
/// Accumulates items of type `T` and emits them on a `Signal<Vec<T>>` in batches.
///
/// A batch is emitted in any of these cases:
/// - it reaches the configured size;
/// - the flush interval has elapsed (checked by `queue` and by a background timer);
/// - `flush` is called explicitly.
///
/// Clones share one pending batch.
pub struct SignalBatcher<T: Send + Sync + 'static> {
    /// Where flushed batches are sent.
    signal: Signal<Vec<T>>,
    /// Size and time thresholds that trigger a flush.
    config: BatchConfig,
    /// Pending items, shared by all clones and the background flush task.
    state: Arc<Mutex<BatchState<T>>>,
    /// Notified after each successful flush.
    /// NOTE(review): nothing visible in this module waits on it — confirm it is still needed.
    flush_notify: Arc<Notify>,
}
impl<T: Send + Sync + 'static> SignalBatcher<T> {
    /// Creates a batcher that emits accumulated items on `signal` as one `Vec<T>`.
    ///
    /// Unless the flush interval is zero, a background task is also spawned. It
    /// flushes pending items every `config.flush_interval()`. The task stops by
    /// itself after every clone of the batcher is dropped and the remaining
    /// items have been sent.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime and the flush interval is
    /// non-zero, because the background task is started with `tokio::spawn`.
    pub fn new(signal: Signal<Vec<T>>, config: BatchConfig) -> Self {
        let batcher = Self {
            signal,
            config,
            state: Arc::new(Mutex::new(BatchState::new())),
            flush_notify: Arc::new(Notify::new()),
        };
        batcher.start_auto_flush();
        batcher
    }

    /// Adds `item` to the pending batch.
    ///
    /// Flushes right away if the batch has reached `max_batch_size`, or if
    /// `flush_interval` has elapsed since the last flush.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Self::flush`] when an immediate flush was
    /// triggered and sending on the signal failed.
    pub async fn queue(&self, item: T) -> Result<(), SignalError> {
        let should_flush = {
            let mut state = self.state.lock();
            state.add(item);
            state.should_flush(&self.config)
        };
        if should_flush {
            self.flush().await?;
        }
        Ok(())
    }

    /// Sends all pending items on the signal as one batch.
    ///
    /// Does nothing when no items are pending.
    ///
    /// # Errors
    ///
    /// Propagates the error from sending on the signal. The items are removed
    /// from the pending batch before the send, so a failed batch is not retried.
    pub async fn flush(&self) -> Result<(), SignalError> {
        let items = {
            let mut state = self.state.lock();
            if state.is_empty() {
                return Ok(());
            }
            state.take()
        };
        // The lock is released above; it must not be held across this await.
        self.signal.send(items).await?;
        self.flush_notify.notify_one();
        Ok(())
    }

    /// Number of items queued but not yet flushed.
    pub fn current_batch_size(&self) -> usize {
        self.state.lock().len()
    }

    /// Spawns the periodic flush task (if the interval is non-zero).
    ///
    /// The task exits once it holds the only reference to the shared state, so
    /// it no longer outlives every batcher.
    fn start_auto_flush(&self) {
        let flush_interval = self.config.flush_interval;
        // `tokio::time::interval` panics on a zero period. With a zero interval,
        // `should_flush` is always true (`elapsed() >= 0`), so every `queue`
        // call already flushes and no timer is needed.
        if flush_interval.is_zero() {
            return;
        }
        let state = Arc::clone(&self.state);
        let signal = self.signal.clone();
        let flush_notify = Arc::clone(&self.flush_notify);
        tokio::spawn(async move {
            let mut ticker = interval(flush_interval);
            ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            loop {
                ticker.tick().await;
                // A count of 1 means every batcher clone is gone. The count cannot
                // rise again (no other handle exists), so no new items can arrive:
                // drain once more below, then stop instead of ticking forever.
                let orphaned = Arc::strong_count(&state) == 1;
                let items = {
                    let mut guard = state.lock();
                    if guard.is_empty() {
                        None
                    } else {
                        Some(guard.take())
                    }
                };
                if let Some(items) = items {
                    if signal.send(items).await.is_ok() {
                        flush_notify.notify_one();
                    }
                }
                if orphaned {
                    break;
                }
            }
        });
    }
}
/// Written by hand because `#[derive(Clone)]` would add a `T: Clone` bound.
///
/// A clone shares the pending batch and the flush notifier with the original.
/// The signal and config are cloned with their own `Clone` impls.
impl<T: Send + Sync + 'static> Clone for SignalBatcher<T> {
    fn clone(&self) -> Self {
        let Self {
            signal,
            config,
            state,
            flush_notify,
        } = self;
        Self {
            signal: signal.clone(),
            config: config.clone(),
            state: Arc::clone(state),
            flush_notify: Arc::clone(flush_notify),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signals::SignalName;
    use parking_lot::Mutex as ParkingLotMutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Checks `condition` every `interval` until it returns `true`.
    ///
    /// Returns `Err` with a message if `timeout` elapses first.
    async fn poll_until<F, Fut>(
        timeout: std::time::Duration,
        interval: std::time::Duration,
        mut condition: F,
    ) -> Result<(), String>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = bool>,
    {
        let start = std::time::Instant::now();
        while start.elapsed() < timeout {
            if condition().await {
                return Ok(());
            }
            tokio::time::sleep(interval).await;
        }
        Err(format!("Timeout after {:?} waiting for condition", timeout))
    }

    #[test]
    fn test_batch_config() {
        let config = BatchConfig::new()
            .with_max_batch_size(100)
            .with_flush_interval(Duration::from_millis(500));
        assert_eq!(config.max_batch_size(), 100);
        assert_eq!(config.flush_interval(), Duration::from_millis(500));
    }

    #[test]
    fn test_batch_config_default() {
        let config = BatchConfig::default();
        assert_eq!(config.max_batch_size(), 50);
        assert_eq!(config.flush_interval(), Duration::from_secs(1));
    }

    #[tokio::test]
    async fn test_signal_batcher_manual_flush() {
        let signal = Signal::<Vec<i32>>::new(SignalName::custom("test_batch"));
        let received = Arc::new(ParkingLotMutex::new(Vec::new()));
        let received_clone = Arc::clone(&received);
        signal.connect(move |batch| {
            let received = Arc::clone(&received_clone);
            async move {
                received.lock().extend(batch.iter().copied());
                Ok(())
            }
        });
        let config = BatchConfig::new().with_max_batch_size(10);
        let batcher = SignalBatcher::new(signal, config);
        for i in 0..5 {
            batcher.queue(i).await.unwrap();
        }
        assert_eq!(batcher.current_batch_size(), 5);
        batcher.flush().await.unwrap();
        let results = received.lock();
        assert_eq!(results.len(), 5);
        assert_eq!(*results, vec![0, 1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn test_signal_batcher_auto_flush_by_size() {
        let signal = Signal::<Vec<i32>>::new(SignalName::custom("test_auto_batch"));
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        signal.connect(move |batch| {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(batch.len(), Ordering::SeqCst);
                Ok(())
            }
        });
        let config = BatchConfig::new().with_max_batch_size(5);
        let batcher = SignalBatcher::new(signal, config);
        for i in 0..5 {
            batcher.queue(i).await.unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 5);
        assert_eq!(batcher.current_batch_size(), 0);
    }

    #[tokio::test]
    async fn test_signal_batcher_auto_flush_by_time() {
        let signal = Signal::<Vec<i32>>::new(SignalName::custom("test_time_batch"));
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        signal.connect(move |batch| {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(batch.len(), Ordering::SeqCst);
                Ok(())
            }
        });
        let config = BatchConfig::new()
            .with_max_batch_size(100)
            .with_flush_interval(Duration::from_millis(200));
        let batcher = SignalBatcher::new(signal, config);
        for i in 0..3 {
            batcher.queue(i).await.unwrap();
        }
        poll_until(
            Duration::from_millis(400),
            Duration::from_millis(20),
            || async { counter.load(Ordering::SeqCst) == 3 },
        )
        .await
        .expect("Batch should be flushed within 400ms");
    }

    /// Items still pending when the last batcher handle is dropped must be
    /// delivered by the background task before it shuts down.
    #[tokio::test]
    async fn test_signal_batcher_drop_delivers_pending_items() {
        let signal = Signal::<Vec<i32>>::new(SignalName::custom("test_drop_batch"));
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        signal.connect(move |batch| {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(batch.len(), Ordering::SeqCst);
                Ok(())
            }
        });
        let config = BatchConfig::new()
            .with_max_batch_size(100)
            .with_flush_interval(Duration::from_millis(50));
        let batcher = SignalBatcher::new(signal, config);
        for i in 0..3 {
            batcher.queue(i).await.unwrap();
        }
        drop(batcher);
        poll_until(
            Duration::from_millis(500),
            Duration::from_millis(10),
            || async { counter.load(Ordering::SeqCst) == 3 },
        )
        .await
        .expect("Pending items should be delivered after the batcher is dropped");
    }

    /// A zero flush interval makes every `queue` call flush immediately.
    #[tokio::test]
    async fn test_signal_batcher_zero_interval_flushes_every_item() {
        let signal = Signal::<Vec<i32>>::new(SignalName::custom("test_zero_interval"));
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        signal.connect(move |batch| {
            let counter = Arc::clone(&counter_clone);
            async move {
                counter.fetch_add(batch.len(), Ordering::SeqCst);
                Ok(())
            }
        });
        let config = BatchConfig::new().with_flush_interval(Duration::ZERO);
        let batcher = SignalBatcher::new(signal, config);
        for i in 0..3 {
            batcher.queue(i).await.unwrap();
            assert_eq!(batcher.current_batch_size(), 0);
        }
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_signal_batcher_empty_flush() {
        let signal = Signal::<Vec<i32>>::new(SignalName::custom("test_empty"));
        let config = BatchConfig::new();
        let batcher = SignalBatcher::new(signal, config);
        assert!(batcher.flush().await.is_ok());
    }
}