arkflow_plugin/processor/
batch.rs1use arkflow_core::processor::{register_processor_builder, Processor, ProcessorBuilder};
6use arkflow_core::{Content, Error, MessageBatch};
7use async_trait::async_trait;
8use datafusion::arrow;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::{Mutex, RwLock};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct BatchProcessorConfig {
16 pub count: usize,
18 pub timeout_ms: u64,
20 pub data_type: String,
22}
23
24pub struct BatchProcessor {
26 config: BatchProcessorConfig,
27 batch: Arc<RwLock<Vec<MessageBatch>>>,
28 last_batch_time: Arc<Mutex<std::time::Instant>>,
29}
30
31impl BatchProcessor {
32 pub fn new(config: BatchProcessorConfig) -> Result<Self, Error> {
34 Ok(Self {
35 config: config.clone(),
36 batch: Arc::new(RwLock::new(Vec::with_capacity(config.count))),
37 last_batch_time: Arc::new(Mutex::new(std::time::Instant::now())),
38 })
39 }
40
41 async fn should_flush(&self) -> bool {
43 let batch = self.batch.read().await;
45 if batch.len() >= self.config.count {
46 return true;
47 }
48 let last_batch_time = self.last_batch_time.lock().await;
49 if !batch.is_empty()
51 && last_batch_time.elapsed().as_millis() >= self.config.timeout_ms as u128
52 {
53 return true;
54 }
55
56 false
57 }
58
59 async fn flush(&self) -> Result<Vec<MessageBatch>, Error> {
61 let mut batch = self.batch.write().await;
62
63 if batch.is_empty() {
64 return Ok(vec![]);
65 }
66
67 let new_batch = match self.config.data_type.as_str() {
69 "arrow" => {
70 let mut combined_content = Vec::new();
71
72 for msg in batch.iter() {
73 if let Content::Arrow(v) = &msg.content {
74 combined_content.push(v.clone());
75 }
76 }
77 let schema = combined_content[0].schema();
78 let batch = arrow::compute::concat_batches(&schema, &combined_content)
79 .map_err(|e| Error::Process(format!("Merge batches failed: {}", e)))?;
80 Ok(vec![MessageBatch::new_arrow(batch)])
81 }
82 "binary" => {
83 let mut combined_content = Vec::new();
84
85 for msg in batch.iter() {
86 if let Content::Binary(v) = &msg.content {
87 combined_content.extend(v.clone());
88 }
89 }
90 Ok(vec![MessageBatch::new_binary(combined_content)])
91 }
92 _ => Err(Error::Process("Invalid data type".to_string())),
93 };
94
95 batch.clear();
96 let mut last_batch_time = self.last_batch_time.lock().await;
97
98 *last_batch_time = std::time::Instant::now();
99
100 new_batch
101 }
102}
103
104#[async_trait]
105impl Processor for BatchProcessor {
106 async fn process(&self, msg: MessageBatch) -> Result<Vec<MessageBatch>, Error> {
107 match &msg.content {
108 Content::Arrow(_) => {
109 if self.config.data_type != "arrow" {
110 return Err(Error::Process("Invalid data type".to_string()));
111 }
112 }
113 Content::Binary(_) => {
114 if self.config.data_type != "binary" {
115 return Err(Error::Process("Invalid data type".to_string()));
116 }
117 }
118 }
119
120 {
121 let mut batch = self.batch.write().await;
122
123 batch.push(msg);
125 }
126
127 if self.should_flush().await {
129 self.flush().await
130 } else {
131 Ok(vec![])
133 }
134 }
135
136 async fn close(&self) -> Result<(), Error> {
137 let mut batch = self.batch.write().await;
138
139 batch.clear();
140 Ok(())
141 }
142}
143
144pub(crate) struct BatchProcessorBuilder;
145impl ProcessorBuilder for BatchProcessorBuilder {
146 fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Processor>, Error> {
147 if config.is_none() {
148 return Err(Error::Config(
149 "Batch processor configuration is missing".to_string(),
150 ));
151 }
152 let config: BatchProcessorConfig = serde_json::from_value(config.clone().unwrap())?;
153 Ok(Arc::new(BatchProcessor::new(config)?))
154 }
155}
156
157pub fn init() {
158 register_processor_builder("batch", Arc::new(BatchProcessorBuilder));
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Builds a config with the given flush threshold, timeout and data type.
    fn create_test_config(count: usize, timeout_ms: u64, data_type: &str) -> BatchProcessorConfig {
        BatchProcessorConfig {
            count,
            timeout_ms,
            data_type: data_type.to_string(),
        }
    }

    #[tokio::test]
    async fn test_batch_size_control() -> Result<(), Error> {
        let config = create_test_config(2, 1000, "binary");
        let processor = BatchProcessor::new(config)?;

        let msg1 = MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]);
        let result1 = processor.process(msg1).await?;
        assert!(result1.is_empty(), "First message should not trigger flush");

        let msg2 = MessageBatch::new_binary(vec![vec![4u8, 5u8, 6u8]]);
        let result2 = processor.process(msg2).await?;

        assert_eq!(result2.len(), 1, "Should return one combined batch");
        if let Content::Binary(data) = &result2[0].content {
            assert_eq!(
                data,
                &vec![vec![1u8, 2u8, 3u8], vec![4u8, 5u8, 6u8]],
                "Combined binary data should match"
            );
        } else {
            panic!("Expected binary content");
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_timeout_flush() -> Result<(), Error> {
        let config = create_test_config(5, 100, "binary");
        let processor = BatchProcessor::new(config)?;

        let msg = MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]);
        let result1 = processor.process(msg).await?;
        assert!(result1.is_empty(), "First message should not trigger flush");

        // Wait past `timeout_ms` so the next message triggers a time-based flush.
        sleep(Duration::from_millis(150)).await;

        let msg2 = MessageBatch::new_binary(vec![vec![4u8, 5u8, 6u8]]);
        let result2 = processor.process(msg2).await?;

        assert_eq!(result2.len(), 1, "Should return one combined batch");
        if let Content::Binary(data) = &result2[0].content {
            assert_eq!(
                data,
                &vec![vec![1u8, 2u8, 3u8], vec![4u8, 5u8, 6u8]],
                "Timeout flush should contain both messages"
            );
        } else {
            // Without this branch a non-binary flush result would pass silently.
            panic!("Expected binary content");
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_invalid_data_type() -> Result<(), Error> {
        let config = create_test_config(2, 1000, "arrow");
        let processor = BatchProcessor::new(config)?;

        let msg = MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]);
        let result = processor.process(msg).await;

        assert!(result.is_err(), "Should return error for invalid data type");
        assert!(
            matches!(result, Err(Error::Process(_))),
            "Should be processing error"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_close() -> Result<(), Error> {
        let config = create_test_config(2, 1000, "binary");
        let processor = BatchProcessor::new(config)?;

        let msg = MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]);
        processor.process(msg).await?;

        processor.close().await?;

        let batch = processor.batch.read().await;
        assert!(batch.is_empty(), "Batch should be empty after close");

        Ok(())
    }
}