1use std::sync::Arc;
2use std::time::{Duration, Instant};
3use tokio::sync::{Mutex, Notify};
4use tracing::{debug, error, info};
5
6use crate::batch::{BatchConfig, BatchHandler, BatchStatus, MessageBatch, ReceivedBatchMessage};
7use crate::error::{WorkerError, WorkerResult};
8use crate::message::ReceivedMessage;
9
10struct QueuedMessage {
12 message: ReceivedMessage<serde_json::Value>,
13}
14
15pub struct BatchProcessor {
17 handler: Arc<dyn BatchHandler>,
19
20 config: BatchConfig,
22
23 queue: Arc<Mutex<Vec<QueuedMessage>>>,
25
26 notify: Arc<Notify>,
28
29 shutdown_notify: Arc<Notify>,
31
32 _task_handle: Option<tokio::task::JoinHandle<()>>,
34}
35
36impl BatchProcessor {
37 pub fn new(handler: Arc<dyn BatchHandler>, config: BatchConfig) -> Self {
39 Self {
40 handler,
41 config,
42 queue: Arc::new(Mutex::new(Vec::new())),
43 notify: Arc::new(Notify::new()),
44 shutdown_notify: Arc::new(Notify::new()),
45 _task_handle: None,
46 }
47 }
48
49 pub async fn start(&mut self) -> WorkerResult<()> {
51 info!(
52 "Starting batch processor with batch_size={}, flush_interval={:?}",
53 self.config.batch_size, self.config.flush_interval
54 );
55
56 let queue = self.queue.clone();
57 let notify = self.notify.clone();
58 let shutdown_notify = self.shutdown_notify.clone();
59 let handler = self.handler.clone();
60 let config = self.config.clone();
61
62 let task_handle = tokio::spawn(async move {
63 Self::processing_loop(queue.clone(), notify, shutdown_notify, handler, config).await;
64 });
65
66 self._task_handle = Some(task_handle);
67
68 Ok(())
69 }
70
71 pub async fn enqueue(&self, message: ReceivedMessage<serde_json::Value>) -> WorkerResult<()> {
73 let mut queue = self.queue.lock().await;
74
75 let queued_msg = QueuedMessage { message };
76
77 queue.push(queued_msg);
78
79 self.notify.notify_one();
81
82 debug!("Message enqueued, queue size: {}", queue.len());
83
84 Ok(())
85 }
86
87 pub async fn shutdown(&self) -> WorkerResult<()> {
89 info!("Shutting down batch processor...");
90
91 self.shutdown_notify.notify_one();
93
94 self.flush_remaining().await?;
96
97 Ok(())
98 }
99
100 async fn flush_remaining(&self) -> WorkerResult<()> {
102 let mut queue = self.queue.lock().await;
103
104 if !queue.is_empty() {
105 let count = queue.len();
106 info!("Flushing {} remaining messages before shutdown", count);
107
108 let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = queue
110 .drain(..)
111 .enumerate()
112 .map(|(idx, qm)| ReceivedBatchMessage {
113 message: qm.message.message,
114 batch_index: idx,
115 })
116 .collect();
117
118 drop(queue); if !batch_messages.is_empty() {
121 let batch_id = format!("flush-{}", uuid::Uuid::new_v4());
122 let batch = MessageBatch::new(batch_id, batch_messages);
123
124 match self.process_batch_with_retry(&batch).await {
125 Ok(_) => {
126 info!("Successfully flushed {} messages", count);
127 }
128 Err(e) => {
129 error!("Failed to flush remaining messages: {:?}", e);
130 }
131 }
132 }
133 }
134
135 Ok(())
136 }
137
138 async fn processing_loop(
140 queue: Arc<Mutex<Vec<QueuedMessage>>>,
141 notify: Arc<Notify>,
142 shutdown_notify: Arc<Notify>,
143 handler: Arc<dyn BatchHandler>,
144 config: BatchConfig,
145 ) {
146 let mut last_flush = Instant::now();
147
148 loop {
149 tokio::select! {
150 _ = notify.notified() => {
152 let queue_len = queue.lock().await.len();
154
155 if queue_len >= config.batch_size {
156 if let Err(e) = Self::process_full_batch(&queue, &handler, &config).await {
158 error!("Failed to process batch: {:?}", e);
159 }
160 last_flush = Instant::now();
161 }
162 }
163
164 _ = tokio::time::sleep(config.flush_interval) => {
166 if !config.wait_for_full_batch {
167 let elapsed = last_flush.elapsed();
168
169 if elapsed >= config.flush_interval {
170 debug!("Flush interval reached, checking for partial batch");
171
172 if let Err(e) = Self::flush_partial_batch(&queue, &handler, &config, BatchStatus::TimeoutFlush).await {
173 error!("Failed to flush partial batch: {:?}", e);
174 }
175 last_flush = Instant::now();
176 }
177 }
178 }
179
180 _ = shutdown_notify.notified() => {
182 info!("Batch processor received shutdown signal");
183 break;
184 }
185 }
186 }
187 }
188
189 async fn process_full_batch(
191 queue: &Mutex<Vec<QueuedMessage>>,
192 handler: &Arc<dyn BatchHandler>,
193 config: &BatchConfig,
194 ) -> WorkerResult<()> {
195 let mut queue_guard = queue.lock().await;
196
197 if queue_guard.len() < config.batch_size {
198 return Ok(());
199 }
200
201 let batch_data: Vec<QueuedMessage> = queue_guard.drain(..config.batch_size).collect();
203 drop(queue_guard); let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = batch_data
207 .into_iter()
208 .enumerate()
209 .map(|(idx, qm)| ReceivedBatchMessage {
210 message: qm.message.message,
211 batch_index: idx,
212 })
213 .collect();
214
215 let batch_id = format!("batch-{}", uuid::Uuid::new_v4());
216 let batch = MessageBatch::new(batch_id, batch_messages);
217
218 info!(
219 "Processing full batch {} with {} messages",
220 batch.id,
221 batch.len()
222 );
223
224 Self::process_batch_with_timeout(&batch, handler, config.processing_timeout).await
225 }
226
227 async fn flush_partial_batch(
229 queue: &Arc<Mutex<Vec<QueuedMessage>>>,
230 handler: &Arc<dyn BatchHandler>,
231 config: &BatchConfig,
232 status: BatchStatus,
233 ) -> WorkerResult<()> {
234 let mut queue_guard = queue.lock().await;
235
236 if queue_guard.is_empty() {
237 return Ok(());
238 }
239
240 let batch_data: Vec<QueuedMessage> = queue_guard.drain(..).collect();
242 drop(queue_guard);
243
244 let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = batch_data
245 .into_iter()
246 .enumerate()
247 .map(|(idx, qm)| ReceivedBatchMessage {
248 message: qm.message.message,
249 batch_index: idx,
250 })
251 .collect();
252
253 let batch_id = format!("partial-{}", uuid::Uuid::new_v4());
254 let mut batch = MessageBatch::new(batch_id, batch_messages);
255 batch.metadata.status = status.clone();
256
257 info!(
258 "Flushing partial batch {} with {} messages (reason: {:?})",
259 batch.id,
260 batch.len(),
261 status
262 );
263
264 Self::process_batch_with_timeout(&batch, handler, config.processing_timeout).await
265 }
266
267 async fn process_batch_with_timeout(
269 batch: &MessageBatch<serde_json::Value>,
270 handler: &Arc<dyn BatchHandler>,
271 timeout: Duration,
272 ) -> WorkerResult<()> {
273 match tokio::time::timeout(timeout, handler.process_batch(batch.clone())).await {
274 Ok(result) => result,
275 Err(_) => Err(WorkerError::ProcessingFailed(format!(
276 "Batch {} processing timed out after {:?}",
277 batch.id, timeout
278 ))),
279 }
280 }
281
282 async fn process_batch_with_retry(
284 &self,
285 batch: &MessageBatch<serde_json::Value>,
286 ) -> WorkerResult<()> {
287 Self::process_batch_with_timeout(batch, &self.handler, self.config.processing_timeout).await
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{AckHandle, Message, MessageMetadata};
    use async_trait::async_trait;

    /// Ack handle stub: both acknowledging and rejecting always succeed.
    #[derive(Debug)]
    struct MockAckHandle;

    #[async_trait]
    impl AckHandle for MockAckHandle {
        async fn ack(&self) -> WorkerResult<()> {
            Ok(())
        }

        async fn nack(&self, _requeue: bool) -> WorkerResult<()> {
            Ok(())
        }
    }

    /// Batch handler stub that prints what it received and reports success.
    struct TestBatchHandler;

    #[async_trait]
    impl BatchHandler for TestBatchHandler {
        async fn process_batch(&self, batch: MessageBatch<serde_json::Value>) -> WorkerResult<()> {
            println!("Processed batch {} with {} messages", batch.id, batch.len());
            Ok(())
        }
    }

    /// Builds a processor around the stub handler with the default configuration.
    fn default_processor() -> BatchProcessor {
        BatchProcessor::new(Arc::new(TestBatchHandler), BatchConfig::default())
    }

    /// Builds the single fixture message used by the enqueue test.
    fn sample_message() -> ReceivedMessage<serde_json::Value> {
        let inner = Message {
            id: "test-1".to_string(),
            payload: serde_json::json!({"test": "data"}),
            metadata: MessageMetadata::new("test-queue"),
        };
        ReceivedMessage::new(inner, Arc::new(MockAckHandle))
    }

    /// A freshly built processor keeps the configuration it was given.
    #[tokio::test]
    async fn test_batch_processor_creation() {
        let processor = default_processor();

        assert_eq!(processor.config.batch_size, 50);
    }

    /// Enqueuing one message leaves exactly one entry in the queue.
    #[tokio::test]
    async fn test_enqueue_message() {
        let processor = default_processor();

        processor.enqueue(sample_message()).await.unwrap();

        let queued = processor.queue.lock().await;
        assert_eq!(queued.len(), 1);
    }
}