use arkflow_core::processor::{register_processor_builder, Processor, ProcessorBuilder};
use arkflow_core::{Content, Error, MessageBatch};
use async_trait::async_trait;
use datafusion::arrow;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
/// Configuration for the batch processor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchProcessorConfig {
    /// Maximum number of message batches to buffer before a flush is forced.
    pub count: usize,
    /// Flush a non-empty buffer once this many milliseconds have elapsed
    /// since the last flush (the timeout is checked on each incoming message,
    /// not by a background timer).
    pub timeout_ms: u64,
    /// Expected content kind of incoming messages: "arrow" or "binary".
    pub data_type: String,
}
/// Processor that accumulates incoming message batches and emits them as a
/// single merged batch once a size threshold or a timeout is reached.
pub struct BatchProcessor {
    /// Static configuration: size threshold, timeout, expected content kind.
    config: BatchProcessorConfig,
    /// Buffer of pending message batches awaiting a flush.
    batch: Arc<RwLock<Vec<MessageBatch>>>,
    /// Instant of the last flush; drives the timeout check.
    last_batch_time: Arc<Mutex<std::time::Instant>>,
}
impl BatchProcessor {
pub fn new(config: BatchProcessorConfig) -> Result<Self, Error> {
Ok(Self {
config: config.clone(),
batch: Arc::new(RwLock::new(Vec::with_capacity(config.count))),
last_batch_time: Arc::new(Mutex::new(std::time::Instant::now())),
})
}
async fn should_flush(&self) -> bool {
let batch = self.batch.read().await;
if batch.len() >= self.config.count {
return true;
}
let last_batch_time = self.last_batch_time.lock().await;
if !batch.is_empty()
&& last_batch_time.elapsed().as_millis() >= self.config.timeout_ms as u128
{
return true;
}
false
}
async fn flush(&self) -> Result<Vec<MessageBatch>, Error> {
let mut batch = self.batch.write().await;
if batch.is_empty() {
return Ok(vec![]);
}
let new_batch = match self.config.data_type.as_str() {
"arrow" => {
let mut combined_content = Vec::new();
for msg in batch.iter() {
if let Content::Arrow(v) = &msg.content {
combined_content.push(v.clone());
}
}
let schema = combined_content[0].schema();
let batch = arrow::compute::concat_batches(&schema, &combined_content)
.map_err(|e| Error::Process(format!("Merge batches failed: {}", e)))?;
Ok(vec![MessageBatch::new_arrow(batch)])
}
"binary" => {
let mut combined_content = Vec::new();
for msg in batch.iter() {
if let Content::Binary(v) = &msg.content {
combined_content.extend(v.clone());
}
}
Ok(vec![MessageBatch::new_binary(combined_content)])
}
_ => Err(Error::Process("Invalid data type".to_string())),
};
batch.clear();
let mut last_batch_time = self.last_batch_time.lock().await;
*last_batch_time = std::time::Instant::now();
new_batch
}
}
#[async_trait]
impl Processor for BatchProcessor {
    /// Buffers `msg` and, when a flush is due, returns the merged batch;
    /// otherwise returns an empty vector.
    ///
    /// # Errors
    /// Rejects messages whose content kind does not match the configured
    /// `data_type`.
    async fn process(&self, msg: MessageBatch) -> Result<Vec<MessageBatch>, Error> {
        // Map the content variant to the `data_type` string it corresponds to.
        let expected = match &msg.content {
            Content::Arrow(_) => "arrow",
            Content::Binary(_) => "binary",
        };
        if self.config.data_type != expected {
            return Err(Error::Process("Invalid data type".to_string()));
        }
        // The write guard is a temporary here, so it is dropped before
        // `should_flush`/`flush` take their own locks.
        self.batch.write().await.push(msg);
        if self.should_flush().await {
            self.flush().await
        } else {
            Ok(vec![])
        }
    }

    /// Discards any buffered messages without emitting them.
    async fn close(&self) -> Result<(), Error> {
        self.batch.write().await.clear();
        Ok(())
    }
}
/// Builder that constructs [`BatchProcessor`] instances from JSON configuration.
pub(crate) struct BatchProcessorBuilder;
impl ProcessorBuilder for BatchProcessorBuilder {
fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Processor>, Error> {
if config.is_none() {
return Err(Error::Config(
"Batch processor configuration is missing".to_string(),
));
}
let config: BatchProcessorConfig = serde_json::from_value(config.clone().unwrap())?;
Ok(Arc::new(BatchProcessor::new(config)?))
}
}
/// Registers the batch processor builder under the "batch" key so pipelines
/// can instantiate it from configuration.
pub fn init() {
    register_processor_builder("batch", Arc::new(BatchProcessorBuilder));
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Convenience constructor for a test configuration.
    fn create_test_config(count: usize, timeout_ms: u64, data_type: &str) -> BatchProcessorConfig {
        BatchProcessorConfig {
            count,
            timeout_ms,
            data_type: data_type.to_string(),
        }
    }

    /// Messages are held until `count` is reached, then emitted as one
    /// combined batch.
    #[tokio::test]
    async fn test_batch_size_control() -> Result<(), Error> {
        let processor = BatchProcessor::new(create_test_config(2, 1000, "binary"))?;

        let first = MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]);
        let result1 = processor.process(first).await?;
        assert!(result1.is_empty(), "First message should not trigger flush");

        let second = MessageBatch::new_binary(vec![vec![4u8, 5u8, 6u8]]);
        let result2 = processor.process(second).await?;
        assert_eq!(result2.len(), 1, "Should return one combined batch");

        match &result2[0].content {
            Content::Binary(data) => assert_eq!(
                data,
                &vec![vec![1u8, 2u8, 3u8], vec![4u8, 5u8, 6u8]],
                "Combined binary data should match"
            ),
            _ => panic!("Expected binary content"),
        }
        Ok(())
    }

    /// A non-empty buffer flushes once the timeout has elapsed, even before
    /// `count` messages have arrived.
    #[tokio::test]
    async fn test_timeout_flush() -> Result<(), Error> {
        let processor = BatchProcessor::new(create_test_config(5, 100, "binary"))?;

        let first = MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]);
        let result1 = processor.process(first).await?;
        assert!(result1.is_empty(), "First message should not trigger flush");

        // Sleep past the 100 ms timeout so the next message forces a flush.
        sleep(Duration::from_millis(150)).await;

        let second = MessageBatch::new_binary(vec![vec![4u8, 5u8, 6u8]]);
        let result2 = processor.process(second).await?;
        assert_eq!(result2.len(), 1, "Should return one combined batch");

        if let Content::Binary(data) = &result2[0].content {
            assert_eq!(
                data,
                &vec![vec![1u8, 2u8, 3u8], vec![4u8, 5u8, 6u8]],
                "Timeout flush should contain both messages"
            );
        }
        Ok(())
    }

    /// A message whose content kind differs from the configured `data_type`
    /// is rejected with a processing error.
    #[tokio::test]
    async fn test_invalid_data_type() -> Result<(), Error> {
        let processor = BatchProcessor::new(create_test_config(2, 1000, "arrow"))?;

        let msg = MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]);
        let result = processor.process(msg).await;
        assert!(result.is_err(), "Should return error for invalid data type");
        assert!(
            matches!(result, Err(Error::Process(_))),
            "Should be processing error"
        );
        Ok(())
    }

    /// Closing the processor drops any buffered messages.
    #[tokio::test]
    async fn test_close() -> Result<(), Error> {
        let processor = BatchProcessor::new(create_test_config(2, 1000, "binary"))?;

        processor
            .process(MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]))
            .await?;
        processor.close().await?;

        let batch = processor.batch.read().await;
        assert!(batch.is_empty(), "Batch should be empty after close");
        Ok(())
    }
}