// foxtive_worker/batch_processor.rs

use std::sync::Arc;
use std::time::{Duration, Instant};

use tokio::sync::{Mutex, Notify};
use tracing::{debug, error, info, warn};

use crate::batch::{BatchConfig, BatchHandler, BatchStatus, MessageBatch, ReceivedBatchMessage};
use crate::error::{WorkerError, WorkerResult};
use crate::message::ReceivedMessage;

10/// Internal message wrapper for the batch queue
11struct QueuedMessage {
12    message: ReceivedMessage<serde_json::Value>,
13}
14
15/// Batch processor that assembles messages into batches and processes them
16pub struct BatchProcessor {
17    /// The handler that processes completed batches
18    handler: Arc<dyn BatchHandler>,
19    
20    /// Configuration for batching behavior
21    config: BatchConfig,
22    
23    /// Queue of messages waiting to be batched
24    queue: Arc<Mutex<Vec<QueuedMessage>>>,
25    
26    /// Signal when new messages arrive
27    notify: Arc<Notify>,
28    
29    /// Signal for shutdown
30    shutdown_notify: Arc<Notify>,
31    
32    /// Handle to the background processing task
33    _task_handle: Option<tokio::task::JoinHandle<()>>,
34}
35
36impl BatchProcessor {
37    /// Create a new batch processor
38    pub fn new(handler: Arc<dyn BatchHandler>, config: BatchConfig) -> Self {
39        Self {
40            handler,
41            config,
42            queue: Arc::new(Mutex::new(Vec::new())),
43            notify: Arc::new(Notify::new()),
44            shutdown_notify: Arc::new(Notify::new()),
45            _task_handle: None,
46        }
47    }
48
49    /// Start the batch processor background task
50    pub async fn start(&mut self) -> WorkerResult<()> {
51        info!(
52            "Starting batch processor with batch_size={}, flush_interval={:?}",
53            self.config.batch_size,
54            self.config.flush_interval
55        );
56
57        let queue = self.queue.clone();
58        let notify = self.notify.clone();
59        let shutdown_notify = self.shutdown_notify.clone();
60        let handler = self.handler.clone();
61        let config = self.config.clone();
62
63        let task_handle = tokio::spawn(async move {
64            Self::processing_loop(queue.clone(), notify, shutdown_notify, handler, config).await;
65        });
66
67        self._task_handle = Some(task_handle);
68        
69        Ok(())
70    }
71
72    /// Add a message to the batch queue
73    pub async fn enqueue(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
74        let mut queue = self.queue.lock().await;
75        
76        let queued_msg = QueuedMessage {
77            message,
78        };
79        
80        queue.push(queued_msg);
81        
82        // Notify the processing loop that new messages are available
83        self.notify.notify_one();
84        
85        debug!("Message enqueued, queue size: {}", queue.len());
86        
87        Ok(())
88    }
89
90    /// Shutdown the batch processor gracefully
91    pub async fn shutdown(&self) -> WorkerResult<()> {
92        info!("Shutting down batch processor...");
93        
94        // Signal shutdown
95        self.shutdown_notify.notify_one();
96        
97        // Process any remaining messages in the queue
98        self.flush_remaining().await?;
99        
100        Ok(())
101    }
102
103    /// Flush any remaining messages in the queue
104    async fn flush_remaining(&self) -> WorkerResult<()> {
105        let mut queue = self.queue.lock().await;
106        
107        if !queue.is_empty() {
108            let count = queue.len();
109            info!("Flushing {} remaining messages before shutdown", count);
110            
111            // Create a batch from remaining messages
112            let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = queue
113                .drain(..)
114                .enumerate()
115                .map(|(idx, qm)| ReceivedBatchMessage {
116                    message: qm.message.message,
117                    batch_index: idx,
118                })
119                .collect();
120            
121            drop(queue); // Release lock before processing
122            
123            if !batch_messages.is_empty() {
124                let batch_id = format!("flush-{}", uuid::Uuid::new_v4());
125                let batch = MessageBatch::new(batch_id, batch_messages);
126                
127                match self.process_batch_with_retry(&batch).await {
128                    Ok(_) => {
129                        info!("Successfully flushed {} messages", count);
130                    }
131                    Err(e) => {
132                        error!("Failed to flush remaining messages: {:?}", e);
133                    }
134                }
135            }
136        }
137        
138        Ok(())
139    }
140
141    /// Main processing loop
142    async fn processing_loop(
143        queue: Arc<Mutex<Vec<QueuedMessage>>>,
144        notify: Arc<Notify>,
145        shutdown_notify: Arc<Notify>,
146        handler: Arc<dyn BatchHandler>,
147        config: BatchConfig,
148    ) {
149        let mut last_flush = Instant::now();
150        
151        loop {
152            tokio::select! {
153                // Wait for new messages or shutdown signal
154                _ = notify.notified() => {
155                    // Check if we have enough messages to form a batch
156                    let queue_len = queue.lock().await.len();
157                    
158                    if queue_len >= config.batch_size {
159                        // Process a full batch
160                        if let Err(e) = Self::process_full_batch(&queue, &handler, &config).await {
161                            error!("Failed to process batch: {:?}", e);
162                        }
163                        last_flush = Instant::now();
164                    }
165                }
166                
167                // Periodic flush based on timeout
168                _ = tokio::time::sleep(config.flush_interval) => {
169                    if !config.wait_for_full_batch {
170                        let elapsed = last_flush.elapsed();
171                        
172                        if elapsed >= config.flush_interval {
173                            debug!("Flush interval reached, checking for partial batch");
174                            
175                            if let Err(e) = Self::flush_partial_batch(&queue, &handler, &config, BatchStatus::TimeoutFlush).await {
176                                error!("Failed to flush partial batch: {:?}", e);
177                            }
178                            last_flush = Instant::now();
179                        }
180                    }
181                }
182                
183                // Shutdown signal
184                _ = shutdown_notify.notified() => {
185                    info!("Batch processor received shutdown signal");
186                    break;
187                }
188            }
189        }
190    }
191
192    /// Process a full batch when queue reaches batch_size
193    async fn process_full_batch(
194        queue: &Mutex<Vec<QueuedMessage>>,
195        handler: &Arc<dyn BatchHandler>,
196        config: &BatchConfig,
197    ) -> WorkerResult<()> {
198        let mut queue_guard = queue.lock().await;
199        
200        if queue_guard.len() < config.batch_size {
201            return Ok(());
202        }
203        
204        // Extract batch_size messages
205        let batch_data: Vec<QueuedMessage> = queue_guard.drain(..config.batch_size).collect();
206        drop(queue_guard); // Release lock
207        
208        // Convert to batch messages
209        let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = batch_data
210            .into_iter()
211            .enumerate()
212            .map(|(idx, qm)| ReceivedBatchMessage {
213                message: qm.message.message,
214                batch_index: idx,
215            })
216            .collect();
217        
218        let batch_id = format!("batch-{}", uuid::Uuid::new_v4());
219        let batch = MessageBatch::new(batch_id, batch_messages);
220        
221        info!("Processing full batch {} with {} messages", batch.id, batch.len());
222        
223        Self::process_batch_with_timeout(&batch, handler, config.processing_timeout).await
224    }
225
226    /// Flush a partial batch (timeout or shutdown)
227    async fn flush_partial_batch(
228        queue: &Arc<Mutex<Vec<QueuedMessage>>>,
229        handler: &Arc<dyn BatchHandler>,
230        config: &BatchConfig,
231        status: BatchStatus,
232    ) -> WorkerResult<()> {
233        let mut queue_guard = queue.lock().await;
234        
235        if queue_guard.is_empty() {
236            return Ok(());
237        }
238        
239        // Extract all messages
240        let batch_data: Vec<QueuedMessage> = queue_guard.drain(..).collect();
241        drop(queue_guard);
242        
243        let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = batch_data
244            .into_iter()
245            .enumerate()
246            .map(|(idx, qm)| ReceivedBatchMessage {
247                message: qm.message.message,
248                batch_index: idx,
249            })
250            .collect();
251        
252        let batch_id = format!("partial-{}", uuid::Uuid::new_v4());
253        let mut batch = MessageBatch::new(batch_id, batch_messages);
254        batch.metadata.status = status.clone();
255        
256        info!(
257            "Flushing partial batch {} with {} messages (reason: {:?})",
258            batch.id,
259            batch.len(),
260            status
261        );
262        
263        Self::process_batch_with_timeout(&batch, handler, config.processing_timeout).await
264    }
265
266    /// Process a batch with timeout protection
267    async fn process_batch_with_timeout(
268        batch: &MessageBatch<serde_json::Value>,
269        handler: &Arc<dyn BatchHandler>,
270        timeout: Duration,
271    ) -> WorkerResult<()> {
272        match tokio::time::timeout(timeout, handler.process_batch(batch.clone())).await {
273            Ok(result) => result,
274            Err(_) => {
275                Err(WorkerError::ProcessingFailed(format!(
276                    "Batch {} processing timed out after {:?}",
277                    batch.id,
278                    timeout
279                )))
280            }
281        }
282    }
283
284    /// Process a batch with retry logic
285    async fn process_batch_with_retry(
286        &self,
287        batch: &MessageBatch<serde_json::Value>,
288    ) -> WorkerResult<()> {
289        // For now, just process once. Could add retry logic here later.
290        Self::process_batch_with_timeout(batch, &self.handler, self.config.processing_timeout).await
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use async_trait::async_trait;

    /// Ack handle stub whose ack/nack calls always succeed.
    #[derive(Debug)]
    struct MockAckHandle;

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            Ok(())
        }

        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Batch handler that only prints what it received.
    struct TestBatchHandler;

    #[async_trait]
    impl BatchHandler for TestBatchHandler {
        async fn process_batch(&self, batch: MessageBatch<serde_json::Value>) -> WorkerResult<()> {
            println!("Processed batch {} with {} messages", batch.id, batch.len());
            Ok(())
        }
    }

    /// Processor wired to `TestBatchHandler` with the default configuration.
    fn default_processor() -> BatchProcessor {
        BatchProcessor::new(Arc::new(TestBatchHandler), BatchConfig::default())
    }

    #[tokio::test]
    async fn test_batch_processor_creation() {
        let processor = default_processor();

        assert_eq!(processor.config.batch_size, 50);
    }

    #[tokio::test]
    async fn test_enqueue_message() {
        let processor = default_processor();

        let payload = serde_json::json!({"test": "data"});
        let inner = Message {
            id: "test-1".to_string(),
            payload,
            metadata: MessageMetadata::new("test-queue"),
        };
        let received = ReceivedMessage::new(inner, Arc::new(MockAckHandle));

        processor.enqueue(received).await.unwrap();

        assert_eq!(processor.queue.lock().await.len(), 1);
    }
}
348}