Skip to main content

foxtive_worker/middleware/
batch.rs

1use std::sync::Arc;
2use std::time::Instant;
3use tokio::sync::{Mutex, Notify};
4use tracing::{debug, error, info};
5
6use crate::batch::{BatchConfig, BatchHandler, BatchStatus, MessageBatch, ReceivedBatchMessage};
7use crate::error::{WorkerError, WorkerResult};
8use crate::message::ReceivedMessage;
9use crate::middleware::{MessageHandler, Middleware};
10
11/// Internal queued message for batch assembly
12struct QueuedMessage {
13    received_message: ReceivedMessage<serde_json::Value>,
14}
15
16/// Middleware that automatically batches messages before processing.
17///
18/// This middleware intercepts individual messages and assembles them into batches.
19/// When a batch is complete (by size or timeout), it processes all messages together
20/// using a BatchHandler, then acknowledges each message individually.
21///
22/// # Example
23/// ```rust,no_run
24/// use foxtive_worker::middleware::batch::BatchMiddleware;
25/// use foxtive_worker::{BatchConfig, BatchHandler, ReceivedBatchMessage, MessageBatch};
26/// use foxtive_worker::error::WorkerResult;
27/// use std::sync::Arc;
28///
29/// // Create a batch handler
30/// struct MyBatchHandler;
31///
32/// #[async_trait::async_trait]
33/// impl BatchHandler for MyBatchHandler {
34///     async fn process_batch(&self, _batch: MessageBatch<serde_json::Value>) -> WorkerResult<()> {
35///         Ok(())
36///     }
37/// }
38///
39/// // Configure batching
40/// let config = BatchConfig::default()
41///     .with_batch_size(10)
42///     .with_flush_interval(std::time::Duration::from_secs(5));
43///
44/// // Create middleware
45/// let batch_middleware = BatchMiddleware::new(Arc::new(MyBatchHandler), config);
46/// ```
47pub struct BatchMiddleware {
48    /// The handler that processes completed batches
49    handler: Arc<dyn BatchHandler>,
50
51    /// Configuration for batching behavior
52    config: BatchConfig,
53
54    /// Queue of messages waiting to be batched
55    queue: Arc<Mutex<Vec<QueuedMessage>>>,
56
57    /// Signal when new messages arrive
58    notify: Arc<Notify>,
59
60    /// Background task handle
61    _task_handle: Option<tokio::task::JoinHandle<()>>,
62}
63
64impl BatchMiddleware {
65    /// Create a new batch middleware
66    pub fn new(handler: Arc<dyn BatchHandler>, config: BatchConfig) -> Self {
67        Self {
68            handler,
69            config,
70            queue: Arc::new(Mutex::new(Vec::new())),
71            notify: Arc::new(Notify::new()),
72            _task_handle: None,
73        }
74    }
75
76    /// Start the background batch processing task
77    pub async fn start(&mut self) -> WorkerResult<()> {
78        info!(
79            "Starting batch middleware with batch_size={}, flush_interval={:?}",
80            self.config.batch_size, self.config.flush_interval
81        );
82
83        let queue = self.queue.clone();
84        let notify = self.notify.clone();
85        let handler = self.handler.clone();
86        let config = self.config.clone();
87
88        let task_handle = tokio::spawn(async move {
89            Self::processing_loop(queue, notify, handler, config).await;
90        });
91
92        self._task_handle = Some(task_handle);
93
94        Ok(())
95    }
96
97    /// Add a message to the batch queue
98    async fn enqueue_message(
99        &self,
100        message: ReceivedMessage<serde_json::Value>,
101    ) -> Result<(), WorkerError> {
102        let mut queue = self.queue.lock().await;
103
104        let queued_msg = QueuedMessage {
105            received_message: message,
106        };
107
108        queue.push(queued_msg);
109
110        // Notify the processing loop
111        self.notify.notify_one();
112
113        debug!("Message enqueued for batching, queue size: {}", queue.len());
114
115        Ok(())
116    }
117
118    /// Main processing loop
119    async fn processing_loop(
120        queue: Arc<Mutex<Vec<QueuedMessage>>>,
121        notify: Arc<Notify>,
122        handler: Arc<dyn BatchHandler>,
123        config: BatchConfig,
124    ) {
125        let mut last_flush = Instant::now();
126
127        loop {
128            tokio::select! {
129                // Wait for new messages
130                _ = notify.notified() => {
131                    let queue_len = queue.lock().await.len();
132
133                    if queue_len >= config.batch_size {
134                        // Process a full batch
135                        if let Err(e) = Self::process_full_batch(&queue, &handler, &config).await {
136                            error!("Failed to process batch: {:?}", e);
137                        }
138                        last_flush = Instant::now();
139                    }
140                }
141
142                // Periodic flush based on timeout
143                _ = tokio::time::sleep(config.flush_interval) => {
144                    if !config.wait_for_full_batch {
145                        let elapsed = last_flush.elapsed();
146
147                        if elapsed >= config.flush_interval {
148                            debug!("Flush interval reached, checking for partial batch");
149
150                            if let Err(e) = Self::flush_partial_batch(&queue, &handler, &config, BatchStatus::TimeoutFlush).await {
151                                error!("Failed to flush partial batch: {:?}", e);
152                            }
153                            last_flush = Instant::now();
154                        }
155                    }
156                }
157            }
158        }
159    }
160
161    /// Process a full batch when queue reaches batch_size
162    async fn process_full_batch(
163        queue: &Mutex<Vec<QueuedMessage>>,
164        handler: &Arc<dyn BatchHandler>,
165        config: &BatchConfig,
166    ) -> WorkerResult<()> {
167        let mut queue_guard = queue.lock().await;
168
169        if queue_guard.len() < config.batch_size {
170            return Ok(());
171        }
172
173        // Extract messages for the batch
174        let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = queue_guard
175            .drain(..config.batch_size)
176            .enumerate()
177            .map(|(idx, qm)| ReceivedBatchMessage {
178                message: qm.received_message.message,
179                batch_index: idx,
180            })
181            .collect();
182
183        // Store received messages for acknowledgment after processing
184        let received_messages: Vec<ReceivedMessage<serde_json::Value>> = queue_guard
185            .iter()
186            .take(config.batch_size)
187            .map(|qm| qm.received_message.clone())
188            .collect();
189
190        drop(queue_guard); // Release lock before processing
191
192        if batch_messages.is_empty() {
193            return Ok(());
194        }
195
196        let batch_id = format!("batch-{}", uuid::Uuid::new_v4());
197        let mut batch = MessageBatch::new(batch_id.clone(), batch_messages);
198        batch.metadata.status = BatchStatus::Ready;
199
200        info!(
201            "Processing batch {} with {} messages",
202            batch_id,
203            batch.len()
204        );
205
206        // Process the batch
207        match handler.process_batch(batch).await {
208            Ok(_) => {
209                info!("Batch {} processed successfully", batch_id);
210
211                // Acknowledge all messages in the batch
212                for received_msg in received_messages {
213                    if let Err(e) = received_msg.ack().await {
214                        error!("Failed to acknowledge message in batch: {:?}", e);
215                    }
216                }
217
218                Ok(())
219            }
220            Err(e) => {
221                error!("Batch {} processing failed: {:?}", batch_id, e);
222
223                // Nack all messages in the batch
224                for received_msg in received_messages {
225                    if let Err(e) = received_msg.nack(true).await {
226                        error!("Failed to nack message in batch: {:?}", e);
227                    }
228                }
229
230                Err(e)
231            }
232        }
233    }
234
235    /// Flush a partial batch (timeout or shutdown)
236    async fn flush_partial_batch(
237        queue: &Mutex<Vec<QueuedMessage>>,
238        handler: &Arc<dyn BatchHandler>,
239        _config: &BatchConfig,
240        status: BatchStatus,
241    ) -> WorkerResult<()> {
242        let mut queue_guard = queue.lock().await;
243
244        if queue_guard.is_empty() {
245            return Ok(());
246        }
247
248        let count = queue_guard.len();
249        debug!("Flushing partial batch with {} messages", count);
250
251        // Extract all remaining messages
252        let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = queue_guard
253            .drain(..)
254            .enumerate()
255            .map(|(idx, qm)| ReceivedBatchMessage {
256                message: qm.received_message.message,
257                batch_index: idx,
258            })
259            .collect();
260
261        let received_messages: Vec<ReceivedMessage<serde_json::Value>> = queue_guard
262            .iter()
263            .map(|qm| qm.received_message.clone())
264            .collect();
265
266        drop(queue_guard);
267
268        if batch_messages.is_empty() {
269            return Ok(());
270        }
271
272        let batch_id = format!("partial-{}", uuid::Uuid::new_v4());
273        let mut batch = MessageBatch::new(batch_id.clone(), batch_messages);
274        batch.metadata.status = status.clone();
275
276        info!(
277            "Processing partial batch {} with {} messages (status: {:?})",
278            batch_id,
279            batch.len(),
280            status
281        );
282
283        // Process the batch
284        match handler.process_batch(batch).await {
285            Ok(_) => {
286                info!("Partial batch {} processed successfully", batch_id);
287
288                for received_msg in received_messages {
289                    if let Err(e) = received_msg.ack().await {
290                        error!("Failed to acknowledge message in partial batch: {:?}", e);
291                    }
292                }
293
294                Ok(())
295            }
296            Err(e) => {
297                error!("Partial batch {} processing failed: {:?}", batch_id, e);
298
299                for received_msg in received_messages {
300                    if let Err(e) = received_msg.nack(true).await {
301                        error!("Failed to nack message in partial batch: {:?}", e);
302                    }
303                }
304
305                Err(e)
306            }
307        }
308    }
309}
310
311#[async_trait::async_trait]
312impl Middleware for BatchMiddleware {
313    fn name(&self) -> &str {
314        "BatchMiddleware"
315    }
316
317    async fn handle(
318        &self,
319        message: ReceivedMessage<serde_json::Value>,
320        _next: Box<dyn MessageHandler>,
321    ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
322        // Enqueue the message for batching instead of processing immediately
323        // The 'next' handler is not used because we're intercepting for batching
324        self.enqueue_message(message).await?;
325        Ok(crate::middleware::MiddlewareResult::Acknowledged)
326    }
327}