1use std::sync::Arc;
2use std::time::{Duration, Instant};
3use tokio::sync::{Mutex, Notify};
4use tracing::{debug, info, error};
5
6use crate::batch::{BatchConfig, BatchHandler, BatchStatus, MessageBatch, ReceivedBatchMessage};
7use crate::error::{WorkerError, WorkerResult};
8use crate::message::ReceivedMessage;
9
10struct QueuedMessage {
12 message: ReceivedMessage<serde_json::Value>,
13}
14
15pub struct BatchProcessor {
17 handler: Arc<dyn BatchHandler>,
19
20 config: BatchConfig,
22
23 queue: Arc<Mutex<Vec<QueuedMessage>>>,
25
26 notify: Arc<Notify>,
28
29 shutdown_notify: Arc<Notify>,
31
32 _task_handle: Option<tokio::task::JoinHandle<()>>,
34}
35
36impl BatchProcessor {
37 pub fn new(handler: Arc<dyn BatchHandler>, config: BatchConfig) -> Self {
39 Self {
40 handler,
41 config,
42 queue: Arc::new(Mutex::new(Vec::new())),
43 notify: Arc::new(Notify::new()),
44 shutdown_notify: Arc::new(Notify::new()),
45 _task_handle: None,
46 }
47 }
48
49 pub async fn start(&mut self) -> WorkerResult<()> {
51 info!(
52 "Starting batch processor with batch_size={}, flush_interval={:?}",
53 self.config.batch_size,
54 self.config.flush_interval
55 );
56
57 let queue = self.queue.clone();
58 let notify = self.notify.clone();
59 let shutdown_notify = self.shutdown_notify.clone();
60 let handler = self.handler.clone();
61 let config = self.config.clone();
62
63 let task_handle = tokio::spawn(async move {
64 Self::processing_loop(queue.clone(), notify, shutdown_notify, handler, config).await;
65 });
66
67 self._task_handle = Some(task_handle);
68
69 Ok(())
70 }
71
72 pub async fn enqueue(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
74 let mut queue = self.queue.lock().await;
75
76 let queued_msg = QueuedMessage {
77 message,
78 };
79
80 queue.push(queued_msg);
81
82 self.notify.notify_one();
84
85 debug!("Message enqueued, queue size: {}", queue.len());
86
87 Ok(())
88 }
89
90 pub async fn shutdown(&self) -> WorkerResult<()> {
92 info!("Shutting down batch processor...");
93
94 self.shutdown_notify.notify_one();
96
97 self.flush_remaining().await?;
99
100 Ok(())
101 }
102
103 async fn flush_remaining(&self) -> WorkerResult<()> {
105 let mut queue = self.queue.lock().await;
106
107 if !queue.is_empty() {
108 let count = queue.len();
109 info!("Flushing {} remaining messages before shutdown", count);
110
111 let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = queue
113 .drain(..)
114 .enumerate()
115 .map(|(idx, qm)| ReceivedBatchMessage {
116 message: qm.message.message,
117 batch_index: idx,
118 })
119 .collect();
120
121 drop(queue); if !batch_messages.is_empty() {
124 let batch_id = format!("flush-{}", uuid::Uuid::new_v4());
125 let batch = MessageBatch::new(batch_id, batch_messages);
126
127 match self.process_batch_with_retry(&batch).await {
128 Ok(_) => {
129 info!("Successfully flushed {} messages", count);
130 }
131 Err(e) => {
132 error!("Failed to flush remaining messages: {:?}", e);
133 }
134 }
135 }
136 }
137
138 Ok(())
139 }
140
141 async fn processing_loop(
143 queue: Arc<Mutex<Vec<QueuedMessage>>>,
144 notify: Arc<Notify>,
145 shutdown_notify: Arc<Notify>,
146 handler: Arc<dyn BatchHandler>,
147 config: BatchConfig,
148 ) {
149 let mut last_flush = Instant::now();
150
151 loop {
152 tokio::select! {
153 _ = notify.notified() => {
155 let queue_len = queue.lock().await.len();
157
158 if queue_len >= config.batch_size {
159 if let Err(e) = Self::process_full_batch(&queue, &handler, &config).await {
161 error!("Failed to process batch: {:?}", e);
162 }
163 last_flush = Instant::now();
164 }
165 }
166
167 _ = tokio::time::sleep(config.flush_interval) => {
169 if !config.wait_for_full_batch {
170 let elapsed = last_flush.elapsed();
171
172 if elapsed >= config.flush_interval {
173 debug!("Flush interval reached, checking for partial batch");
174
175 if let Err(e) = Self::flush_partial_batch(&queue, &handler, &config, BatchStatus::TimeoutFlush).await {
176 error!("Failed to flush partial batch: {:?}", e);
177 }
178 last_flush = Instant::now();
179 }
180 }
181 }
182
183 _ = shutdown_notify.notified() => {
185 info!("Batch processor received shutdown signal");
186 break;
187 }
188 }
189 }
190 }
191
192 async fn process_full_batch(
194 queue: &Mutex<Vec<QueuedMessage>>,
195 handler: &Arc<dyn BatchHandler>,
196 config: &BatchConfig,
197 ) -> WorkerResult<()> {
198 let mut queue_guard = queue.lock().await;
199
200 if queue_guard.len() < config.batch_size {
201 return Ok(());
202 }
203
204 let batch_data: Vec<QueuedMessage> = queue_guard.drain(..config.batch_size).collect();
206 drop(queue_guard); let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = batch_data
210 .into_iter()
211 .enumerate()
212 .map(|(idx, qm)| ReceivedBatchMessage {
213 message: qm.message.message,
214 batch_index: idx,
215 })
216 .collect();
217
218 let batch_id = format!("batch-{}", uuid::Uuid::new_v4());
219 let batch = MessageBatch::new(batch_id, batch_messages);
220
221 info!("Processing full batch {} with {} messages", batch.id, batch.len());
222
223 Self::process_batch_with_timeout(&batch, handler, config.processing_timeout).await
224 }
225
226 async fn flush_partial_batch(
228 queue: &Arc<Mutex<Vec<QueuedMessage>>>,
229 handler: &Arc<dyn BatchHandler>,
230 config: &BatchConfig,
231 status: BatchStatus,
232 ) -> WorkerResult<()> {
233 let mut queue_guard = queue.lock().await;
234
235 if queue_guard.is_empty() {
236 return Ok(());
237 }
238
239 let batch_data: Vec<QueuedMessage> = queue_guard.drain(..).collect();
241 drop(queue_guard);
242
243 let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = batch_data
244 .into_iter()
245 .enumerate()
246 .map(|(idx, qm)| ReceivedBatchMessage {
247 message: qm.message.message,
248 batch_index: idx,
249 })
250 .collect();
251
252 let batch_id = format!("partial-{}", uuid::Uuid::new_v4());
253 let mut batch = MessageBatch::new(batch_id, batch_messages);
254 batch.metadata.status = status.clone();
255
256 info!(
257 "Flushing partial batch {} with {} messages (reason: {:?})",
258 batch.id,
259 batch.len(),
260 status
261 );
262
263 Self::process_batch_with_timeout(&batch, handler, config.processing_timeout).await
264 }
265
266 async fn process_batch_with_timeout(
268 batch: &MessageBatch<serde_json::Value>,
269 handler: &Arc<dyn BatchHandler>,
270 timeout: Duration,
271 ) -> WorkerResult<()> {
272 match tokio::time::timeout(timeout, handler.process_batch(batch.clone())).await {
273 Ok(result) => result,
274 Err(_) => {
275 Err(WorkerError::ProcessingFailed(format!(
276 "Batch {} processing timed out after {:?}",
277 batch.id,
278 timeout
279 )))
280 }
281 }
282 }
283
284 async fn process_batch_with_retry(
286 &self,
287 batch: &MessageBatch<serde_json::Value>,
288 ) -> WorkerResult<()> {
289 Self::process_batch_with_timeout(batch, &self.handler, self.config.processing_timeout).await
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use async_trait::async_trait;

    /// Ack handle that accepts every ack and nack without doing anything.
    #[derive(Debug)]
    struct MockAckHandle;

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            Ok(())
        }

        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Batch handler that only prints what it received.
    struct TestBatchHandler;

    #[async_trait]
    impl BatchHandler for TestBatchHandler {
        async fn process_batch(&self, batch: MessageBatch<serde_json::Value>) -> WorkerResult<()> {
            println!("Processed batch {} with {} messages", batch.id, batch.len());
            Ok(())
        }
    }

    /// Builds an unstarted processor with the default config and the printing handler.
    fn default_processor() -> BatchProcessor {
        BatchProcessor::new(Arc::new(TestBatchHandler), BatchConfig::default())
    }

    #[tokio::test]
    async fn test_batch_processor_creation() {
        let processor = default_processor();

        assert_eq!(processor.config.batch_size, 50);
    }

    #[tokio::test]
    async fn test_enqueue_message() {
        let processor = default_processor();

        let payload = serde_json::json!({"test": "data"});
        let inner = Message {
            id: "test-1".to_string(),
            payload,
            metadata: MessageMetadata::new("test-queue"),
        };
        let received = ReceivedMessage::new(inner, Arc::new(MockAckHandle));

        processor.enqueue(received).await.unwrap();

        assert_eq!(processor.queue.lock().await.len(), 1);
    }
}