// arkflow_plugin/processor/batch.rs

1//! Batch Processor Components
2//!
3//! Batch multiple messages into one or more messages
4
5use arkflow_core::processor::{register_processor_builder, Processor, ProcessorBuilder};
6use arkflow_core::{Content, Error, MessageBatch};
7use async_trait::async_trait;
8use datafusion::arrow;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::{Mutex, RwLock};
12
/// Batch processor configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchProcessorConfig {
    /// Maximum number of message batches buffered before a size-based flush
    pub count: usize,
    /// Maximum age (milliseconds) of a non-empty buffer before a timeout-based flush
    pub timeout_ms: u64,
    /// Expected payload type of incoming messages: "arrow" or "binary"
    pub data_type: String,
}
23
/// Batch Processor Components
pub struct BatchProcessor {
    // Static settings: batch size, timeout, and expected payload type.
    config: BatchProcessorConfig,
    // Messages accumulated since the last flush.
    batch: Arc<RwLock<Vec<MessageBatch>>>,
    // Instant of the most recent flush; drives the timeout check in should_flush.
    last_batch_time: Arc<Mutex<std::time::Instant>>,
}
30
31impl BatchProcessor {
32    /// Create a new batch processor component
33    pub fn new(config: BatchProcessorConfig) -> Result<Self, Error> {
34        Ok(Self {
35            config: config.clone(),
36            batch: Arc::new(RwLock::new(Vec::with_capacity(config.count))),
37            last_batch_time: Arc::new(Mutex::new(std::time::Instant::now())),
38        })
39    }
40
41    /// Check if the batch should be refreshed
42    async fn should_flush(&self) -> bool {
43        // 如果批处理已满,则刷新
44        let batch = self.batch.read().await;
45        if batch.len() >= self.config.count {
46            return true;
47        }
48        let last_batch_time = self.last_batch_time.lock().await;
49        // 如果超过超时时间且批处理不为空,则刷新
50        if !batch.is_empty()
51            && last_batch_time.elapsed().as_millis() >= self.config.timeout_ms as u128
52        {
53            return true;
54        }
55
56        false
57    }
58
59    /// Refresh the batch
60    async fn flush(&self) -> Result<Vec<MessageBatch>, Error> {
61        let mut batch = self.batch.write().await;
62
63        if batch.is_empty() {
64            return Ok(vec![]);
65        }
66
67        // Create a new batch message
68        let new_batch = match self.config.data_type.as_str() {
69            "arrow" => {
70                let mut combined_content = Vec::new();
71
72                for msg in batch.iter() {
73                    if let Content::Arrow(v) = &msg.content {
74                        combined_content.push(v.clone());
75                    }
76                }
77                let schema = combined_content[0].schema();
78                let batch = arrow::compute::concat_batches(&schema, &combined_content)
79                    .map_err(|e| Error::Process(format!("Merge batches failed: {}", e)))?;
80                Ok(vec![MessageBatch::new_arrow(batch)])
81            }
82            "binary" => {
83                let mut combined_content = Vec::new();
84
85                for msg in batch.iter() {
86                    if let Content::Binary(v) = &msg.content {
87                        combined_content.extend(v.clone());
88                    }
89                }
90                Ok(vec![MessageBatch::new_binary(combined_content)])
91            }
92            _ => Err(Error::Process("Invalid data type".to_string())),
93        };
94
95        batch.clear();
96        let mut last_batch_time = self.last_batch_time.lock().await;
97
98        *last_batch_time = std::time::Instant::now();
99
100        new_batch
101    }
102}
103
104#[async_trait]
105impl Processor for BatchProcessor {
106    async fn process(&self, msg: MessageBatch) -> Result<Vec<MessageBatch>, Error> {
107        match &msg.content {
108            Content::Arrow(_) => {
109                if self.config.data_type != "arrow" {
110                    return Err(Error::Process("Invalid data type".to_string()));
111                }
112            }
113            Content::Binary(_) => {
114                if self.config.data_type != "binary" {
115                    return Err(Error::Process("Invalid data type".to_string()));
116                }
117            }
118        }
119
120        {
121            let mut batch = self.batch.write().await;
122
123            // Add messages to a batch
124            batch.push(msg);
125        }
126
127        // Check if the batch should be refreshed
128        if self.should_flush().await {
129            self.flush().await
130        } else {
131            // If it is not refreshed, an empty result is returned
132            Ok(vec![])
133        }
134    }
135
136    async fn close(&self) -> Result<(), Error> {
137        let mut batch = self.batch.write().await;
138
139        batch.clear();
140        Ok(())
141    }
142}
143
144pub(crate) struct BatchProcessorBuilder;
145impl ProcessorBuilder for BatchProcessorBuilder {
146    fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Processor>, Error> {
147        if config.is_none() {
148            return Err(Error::Config(
149                "Batch processor configuration is missing".to_string(),
150            ));
151        }
152        let config: BatchProcessorConfig = serde_json::from_value(config.clone().unwrap())?;
153        Ok(Arc::new(BatchProcessor::new(config)?))
154    }
155}
156
157pub fn init() {
158    register_processor_builder("batch", Arc::new(BatchProcessorBuilder));
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    // Helper function to create test configuration
    fn create_test_config(count: usize, timeout_ms: u64, data_type: &str) -> BatchProcessorConfig {
        BatchProcessorConfig {
            count,
            timeout_ms,
            data_type: data_type.to_string(),
        }
    }

    #[tokio::test]
    async fn test_batch_size_control() -> Result<(), Error> {
        // Test batch size control with binary data
        let config = create_test_config(2, 1000, "binary");
        let processor = BatchProcessor::new(config)?;

        // Process first message (u8 payloads)
        let msg1 = MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]);
        let result1 = processor.process(msg1).await?;
        assert!(result1.is_empty(), "First message should not trigger flush");

        // Process second message - should trigger flush due to batch size
        let msg2 = MessageBatch::new_binary(vec![vec![4u8, 5u8, 6u8]]);
        let result2 = processor.process(msg2).await?;

        assert_eq!(result2.len(), 1, "Should return one combined batch");
        if let Content::Binary(data) = &result2[0].content {
            assert_eq!(
                data,
                &vec![vec![1u8, 2u8, 3u8], vec![4u8, 5u8, 6u8]],
                "Combined binary data should match"
            );
        } else {
            panic!("Expected binary content");
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_timeout_flush() -> Result<(), Error> {
        // Test timeout-based flush with short timeout
        let config = create_test_config(5, 100, "binary");
        let processor = BatchProcessor::new(config)?;

        // Process one message and wait for timeout (u8 payloads)
        let msg = MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]);
        let result1 = processor.process(msg).await?;
        assert!(result1.is_empty(), "First message should not trigger flush");

        // Wait for timeout
        sleep(Duration::from_millis(150)).await;

        // Process another message - should trigger flush due to timeout
        let msg2 = MessageBatch::new_binary(vec![vec![4u8, 5u8, 6u8]]);
        let result2 = processor.process(msg2).await?;

        assert_eq!(result2.len(), 1, "Should return one combined batch");
        if let Content::Binary(data) = &result2[0].content {
            assert_eq!(
                data,
                &vec![vec![1u8, 2u8, 3u8], vec![4u8, 5u8, 6u8]],
                "Timeout flush should contain both messages"
            );
        } else {
            // Without this branch the test silently passed on a wrong
            // content variant; fail loudly instead.
            panic!("Expected binary content");
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_invalid_data_type() -> Result<(), Error> {
        // Test error handling for mismatched data types
        let config = create_test_config(2, 1000, "arrow");
        let processor = BatchProcessor::new(config)?;

        // Try to process binary message with arrow configuration (u8 payloads)
        let msg = MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]);
        let result = processor.process(msg).await;

        assert!(result.is_err(), "Should return error for invalid data type");
        assert!(
            matches!(result, Err(Error::Process(_))),
            "Should be processing error"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_close() -> Result<(), Error> {
        // Test processor cleanup
        let config = create_test_config(2, 1000, "binary");
        let processor = BatchProcessor::new(config)?;

        // Add a message to the batch (u8 payloads)
        let msg = MessageBatch::new_binary(vec![vec![1u8, 2u8, 3u8]]);
        processor.process(msg).await?;

        // Close the processor
        processor.close().await?;

        // Verify batch is cleared
        let batch = processor.batch.read().await;
        assert!(batch.is_empty(), "Batch should be empty after close");

        Ok(())
    }
}