Skip to main content

foxtive_worker/
batch_processor.rs

1use std::sync::Arc;
2use std::time::{Duration, Instant};
3use tokio::sync::{Mutex, Notify};
4use tracing::{debug, error, info};
5
6use crate::batch::{BatchConfig, BatchHandler, BatchStatus, MessageBatch, ReceivedBatchMessage};
7use crate::error::{WorkerError, WorkerResult};
8use crate::message::ReceivedMessage;
9
10/// Internal message wrapper for the batch queue
11struct QueuedMessage {
12    message: ReceivedMessage<serde_json::Value>,
13}
14
15/// Batch processor that assembles messages into batches and processes them
16pub struct BatchProcessor {
17    /// The handler that processes completed batches
18    handler: Arc<dyn BatchHandler>,
19
20    /// Configuration for batching behavior
21    config: BatchConfig,
22
23    /// Queue of messages waiting to be batched
24    queue: Arc<Mutex<Vec<QueuedMessage>>>,
25
26    /// Signal when new messages arrive
27    notify: Arc<Notify>,
28
29    /// Signal for shutdown
30    shutdown_notify: Arc<Notify>,
31
32    /// Handle to the background processing task
33    _task_handle: Option<tokio::task::JoinHandle<()>>,
34}
35
36impl BatchProcessor {
37    /// Create a new batch processor
38    pub fn new(handler: Arc<dyn BatchHandler>, config: BatchConfig) -> Self {
39        Self {
40            handler,
41            config,
42            queue: Arc::new(Mutex::new(Vec::new())),
43            notify: Arc::new(Notify::new()),
44            shutdown_notify: Arc::new(Notify::new()),
45            _task_handle: None,
46        }
47    }
48
49    /// Start the batch processor background task
50    pub async fn start(&mut self) -> WorkerResult<()> {
51        info!(
52            "Starting batch processor with batch_size={}, flush_interval={:?}",
53            self.config.batch_size, self.config.flush_interval
54        );
55
56        let queue = self.queue.clone();
57        let notify = self.notify.clone();
58        let shutdown_notify = self.shutdown_notify.clone();
59        let handler = self.handler.clone();
60        let config = self.config.clone();
61
62        let task_handle = tokio::spawn(async move {
63            Self::processing_loop(queue.clone(), notify, shutdown_notify, handler, config).await;
64        });
65
66        self._task_handle = Some(task_handle);
67
68        Ok(())
69    }
70
71    /// Add a message to the batch queue
72    pub async fn enqueue(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
73        let mut queue = self.queue.lock().await;
74
75        let queued_msg = QueuedMessage { message };
76
77        queue.push(queued_msg);
78
79        // Notify the processing loop that new messages are available
80        self.notify.notify_one();
81
82        debug!("Message enqueued, queue size: {}", queue.len());
83
84        Ok(())
85    }
86
87    /// Shutdown the batch processor gracefully
88    pub async fn shutdown(&self) -> WorkerResult<()> {
89        info!("Shutting down batch processor...");
90
91        // Signal shutdown
92        self.shutdown_notify.notify_one();
93
94        // Process any remaining messages in the queue
95        self.flush_remaining().await?;
96
97        Ok(())
98    }
99
100    /// Flush any remaining messages in the queue
101    async fn flush_remaining(&self) -> WorkerResult<()> {
102        let mut queue = self.queue.lock().await;
103
104        if !queue.is_empty() {
105            let count = queue.len();
106            info!("Flushing {} remaining messages before shutdown", count);
107
108            // Create a batch from remaining messages
109            let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = queue
110                .drain(..)
111                .enumerate()
112                .map(|(idx, qm)| ReceivedBatchMessage {
113                    message: qm.message.message,
114                    batch_index: idx,
115                })
116                .collect();
117
118            drop(queue); // Release lock before processing
119
120            if !batch_messages.is_empty() {
121                let batch_id = format!("flush-{}", uuid::Uuid::new_v4());
122                let batch = MessageBatch::new(batch_id, batch_messages);
123
124                match self.process_batch_with_retry(&batch).await {
125                    Ok(_) => {
126                        info!("Successfully flushed {} messages", count);
127                    }
128                    Err(e) => {
129                        error!("Failed to flush remaining messages: {:?}", e);
130                    }
131                }
132            }
133        }
134
135        Ok(())
136    }
137
138    /// Main processing loop
139    async fn processing_loop(
140        queue: Arc<Mutex<Vec<QueuedMessage>>>,
141        notify: Arc<Notify>,
142        shutdown_notify: Arc<Notify>,
143        handler: Arc<dyn BatchHandler>,
144        config: BatchConfig,
145    ) {
146        let mut last_flush = Instant::now();
147
148        loop {
149            tokio::select! {
150                // Wait for new messages or shutdown signal
151                _ = notify.notified() => {
152                    // Check if we have enough messages to form a batch
153                    let queue_len = queue.lock().await.len();
154
155                    if queue_len >= config.batch_size {
156                        // Process a full batch
157                        if let Err(e) = Self::process_full_batch(&queue, &handler, &config).await {
158                            error!("Failed to process batch: {:?}", e);
159                        }
160                        last_flush = Instant::now();
161                    }
162                }
163
164                // Periodic flush based on timeout
165                _ = tokio::time::sleep(config.flush_interval) => {
166                    if !config.wait_for_full_batch {
167                        let elapsed = last_flush.elapsed();
168
169                        if elapsed >= config.flush_interval {
170                            debug!("Flush interval reached, checking for partial batch");
171
172                            if let Err(e) = Self::flush_partial_batch(&queue, &handler, &config, BatchStatus::TimeoutFlush).await {
173                                error!("Failed to flush partial batch: {:?}", e);
174                            }
175                            last_flush = Instant::now();
176                        }
177                    }
178                }
179
180                // Shutdown signal
181                _ = shutdown_notify.notified() => {
182                    info!("Batch processor received shutdown signal");
183                    break;
184                }
185            }
186        }
187    }
188
189    /// Process a full batch when queue reaches batch_size
190    async fn process_full_batch(
191        queue: &Mutex<Vec<QueuedMessage>>,
192        handler: &Arc<dyn BatchHandler>,
193        config: &BatchConfig,
194    ) -> WorkerResult<()> {
195        let mut queue_guard = queue.lock().await;
196
197        if queue_guard.len() < config.batch_size {
198            return Ok(());
199        }
200
201        // Extract batch_size messages
202        let batch_data: Vec<QueuedMessage> = queue_guard.drain(..config.batch_size).collect();
203        drop(queue_guard); // Release lock
204
205        // Convert to batch messages
206        let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = batch_data
207            .into_iter()
208            .enumerate()
209            .map(|(idx, qm)| ReceivedBatchMessage {
210                message: qm.message.message,
211                batch_index: idx,
212            })
213            .collect();
214
215        let batch_id = format!("batch-{}", uuid::Uuid::new_v4());
216        let batch = MessageBatch::new(batch_id, batch_messages);
217
218        info!(
219            "Processing full batch {} with {} messages",
220            batch.id,
221            batch.len()
222        );
223
224        Self::process_batch_with_timeout(&batch, handler, config.processing_timeout).await
225    }
226
227    /// Flush a partial batch (timeout or shutdown)
228    async fn flush_partial_batch(
229        queue: &Arc<Mutex<Vec<QueuedMessage>>>,
230        handler: &Arc<dyn BatchHandler>,
231        config: &BatchConfig,
232        status: BatchStatus,
233    ) -> WorkerResult<()> {
234        let mut queue_guard = queue.lock().await;
235
236        if queue_guard.is_empty() {
237            return Ok(());
238        }
239
240        // Extract all messages
241        let batch_data: Vec<QueuedMessage> = queue_guard.drain(..).collect();
242        drop(queue_guard);
243
244        let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = batch_data
245            .into_iter()
246            .enumerate()
247            .map(|(idx, qm)| ReceivedBatchMessage {
248                message: qm.message.message,
249                batch_index: idx,
250            })
251            .collect();
252
253        let batch_id = format!("partial-{}", uuid::Uuid::new_v4());
254        let mut batch = MessageBatch::new(batch_id, batch_messages);
255        batch.metadata.status = status.clone();
256
257        info!(
258            "Flushing partial batch {} with {} messages (reason: {:?})",
259            batch.id,
260            batch.len(),
261            status
262        );
263
264        Self::process_batch_with_timeout(&batch, handler, config.processing_timeout).await
265    }
266
267    /// Process a batch with timeout protection
268    async fn process_batch_with_timeout(
269        batch: &MessageBatch<serde_json::Value>,
270        handler: &Arc<dyn BatchHandler>,
271        timeout: Duration,
272    ) -> WorkerResult<()> {
273        match tokio::time::timeout(timeout, handler.process_batch(batch.clone())).await {
274            Ok(result) => result,
275            Err(_) => Err(WorkerError::ProcessingFailed(format!(
276                "Batch {} processing timed out after {:?}",
277                batch.id, timeout
278            ))),
279        }
280    }
281
282    /// Process a batch with retry logic
283    async fn process_batch_with_retry(
284        &self,
285        batch: &MessageBatch<serde_json::Value>,
286    ) -> WorkerResult<()> {
287        // For now, just process once. Could add retry logic here later.
288        Self::process_batch_with_timeout(batch, &self.handler, self.config.processing_timeout).await
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use async_trait::async_trait;

    /// Ack handle that accepts every ack and nack without doing anything.
    #[derive(Debug)]
    struct MockAckHandle;

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            Ok(())
        }

        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Batch handler that only prints what it received.
    struct TestBatchHandler;

    #[async_trait]
    impl BatchHandler for TestBatchHandler {
        async fn process_batch(&self, batch: MessageBatch<serde_json::Value>) -> WorkerResult<()> {
            println!("Processed batch {} with {} messages", batch.id, batch.len());
            Ok(())
        }
    }

    /// Builds a processor around `TestBatchHandler` using the default config.
    fn default_processor() -> BatchProcessor {
        BatchProcessor::new(Arc::new(TestBatchHandler), BatchConfig::default())
    }

    #[tokio::test]
    async fn test_batch_processor_creation() {
        let processor = default_processor();

        assert_eq!(processor.config.batch_size, 50);
    }

    #[tokio::test]
    async fn test_enqueue_message() {
        let processor = default_processor();

        let inner = Message {
            id: "test-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        let received = ReceivedMessage::new(inner, Arc::new(MockAckHandle));

        processor.enqueue(received).await.unwrap();

        assert_eq!(processor.queue.lock().await.len(), 1);
    }
}