use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use crate::{Error, Message, processor::Processor};
/// Settings for [`BatchProcessor`]; (de)serializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchProcessorConfig {
/// Flush once at least this many messages are buffered. A value of 0 makes every
/// processed message flush immediately.
pub count: usize,
/// Flush a non-empty buffer once this many milliseconds have elapsed since the
/// processor was created or last flushed. In this file the check only runs when a
/// new message is processed; there is no background timer.
pub timeout_ms: u64,
/// NOTE(review): not read anywhere in this file — confirm whether it is consumed
/// elsewhere or is a placeholder for a conditional-flush feature.
pub condition: Option<String>,
}
/// Buffers incoming messages and emits them as one combined message once the batch
/// reaches `count` messages or, if non-empty, once the timeout has elapsed.
pub struct BatchProcessor {
/// Copy of the configuration passed to `BatchProcessor::new`.
config: BatchProcessorConfig,
/// Messages buffered since the last flush, behind an async (tokio) mutex.
batch: Arc<Mutex<Vec<Message>>>,
/// When the processor was created or last emitted a batch; drives the timeout check.
last_batch_time: Arc<Mutex<std::time::Instant>>,
}
impl BatchProcessor {
    /// Creates a processor from a copy of `config`.
    ///
    /// The buffer starts empty with room for `config.count` messages, and the
    /// timeout clock starts at the moment of construction.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`.
    pub fn new(config: &BatchProcessorConfig) -> Result<Self, Error> {
        let settings = config.clone();
        let buffer = Vec::with_capacity(settings.count);
        let clock_start = std::time::Instant::now();
        Ok(Self {
            config: settings,
            batch: Arc::new(Mutex::new(buffer)),
            last_batch_time: Arc::new(Mutex::new(clock_start)),
        })
    }

    /// Returns `true` when the buffered batch is due to be emitted.
    ///
    /// A batch is due once it holds at least `config.count` messages, or once it is
    /// non-empty and at least `config.timeout_ms` milliseconds have passed since the
    /// processor was created or last flushed.
    ///
    /// NOTE(review): the clock is measured from the last flush, not from the first
    /// buffered message, so after an idle period longer than `timeout_ms` the next
    /// message is emitted on its own — confirm this is intended.
    async fn should_flush(&self) -> bool {
        // Lock order (batch, then last_batch_time) matches `flush`.
        let pending = self.batch.lock().await;
        if pending.len() >= self.config.count {
            return true;
        }
        let last_flush = self.last_batch_time.lock().await;
        let waited_ms = last_flush.elapsed().as_millis();
        let timed_out = waited_ms >= u128::from(self.config.timeout_ms);
        !pending.is_empty() && timed_out
    }

    /// Emits everything buffered as a single combined message.
    ///
    /// The combined content is each buffered message's content followed by a `\n`
    /// byte, in the order the messages were pushed. The combined message's metadata
    /// gets `batch_size` set to the number of messages merged. Afterwards the buffer
    /// is cleared and the timeout clock restarts. If nothing is buffered, returns an
    /// empty `Vec` and leaves the clock untouched.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`.
    async fn flush(&self) -> Result<Vec<Message>, Error> {
        let mut pending = self.batch.lock().await;
        if pending.is_empty() {
            return Ok(Vec::new());
        }

        let merged_count = pending.len();
        let payload = pending.iter().fold(Vec::<u8>::new(), |mut bytes, item| {
            bytes.extend_from_slice(item.content());
            bytes.push(b'\n');
            bytes
        });

        let mut combined = Message::new(payload);
        combined
            .metadata_mut()
            .set("batch_size", &merged_count.to_string());

        pending.clear();
        *self.last_batch_time.lock().await = std::time::Instant::now();
        Ok(vec![combined])
    }
}
#[async_trait]
impl Processor for BatchProcessor {
    /// Buffers `msg` and, once the batch is due, returns it as a single combined
    /// message. Returns an empty `Vec` while the batch is still accumulating.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `BatchProcessor::flush`.
    async fn process(&self, msg: Message) -> Result<Vec<Message>, Error> {
        // The buffer guard is a temporary and is dropped at the end of this
        // statement. `should_flush` and `flush` both re-lock `batch`, and
        // `tokio::sync::Mutex` is not reentrant, so holding the guard across those
        // calls deadlocks on the very first message.
        self.batch.lock().await.push(msg);

        // The lock is released between the push and the check. Concurrent callers
        // may therefore add messages before the flush runs, so an emitted batch can
        // contain more than `count` messages.
        if self.should_flush().await {
            self.flush().await
        } else {
            Ok(vec![])
        }
    }

    /// Shuts the processor down, discarding any buffered messages without
    /// emitting them.
    async fn close(&self) -> Result<(), Error> {
        self.batch.lock().await.clear();
        Ok(())
    }
}