foxtive_worker/middleware/
batch.rs1use std::sync::Arc;
2use std::time::Instant;
3use tokio::sync::{Mutex, Notify};
4use tracing::{debug, error, info};
5
6use crate::batch::{BatchConfig, BatchHandler, BatchStatus, MessageBatch, ReceivedBatchMessage};
7use crate::error::{WorkerError, WorkerResult};
8use crate::message::ReceivedMessage;
9use crate::middleware::{MessageHandler, Middleware};
10
11struct QueuedMessage {
13 received_message: ReceivedMessage<serde_json::Value>,
14}
15
16pub struct BatchMiddleware {
48 handler: Arc<dyn BatchHandler>,
50
51 config: BatchConfig,
53
54 queue: Arc<Mutex<Vec<QueuedMessage>>>,
56
57 notify: Arc<Notify>,
59
60 _task_handle: Option<tokio::task::JoinHandle<()>>,
62}
63
64impl BatchMiddleware {
65 pub fn new(handler: Arc<dyn BatchHandler>, config: BatchConfig) -> Self {
67 Self {
68 handler,
69 config,
70 queue: Arc::new(Mutex::new(Vec::new())),
71 notify: Arc::new(Notify::new()),
72 _task_handle: None,
73 }
74 }
75
76 pub async fn start(&mut self) -> WorkerResult<()> {
78 info!(
79 "Starting batch middleware with batch_size={}, flush_interval={:?}",
80 self.config.batch_size, self.config.flush_interval
81 );
82
83 let queue = self.queue.clone();
84 let notify = self.notify.clone();
85 let handler = self.handler.clone();
86 let config = self.config.clone();
87
88 let task_handle = tokio::spawn(async move {
89 Self::processing_loop(queue, notify, handler, config).await;
90 });
91
92 self._task_handle = Some(task_handle);
93
94 Ok(())
95 }
96
97 async fn enqueue_message(
99 &self,
100 message: ReceivedMessage<serde_json::Value>,
101 ) -> Result<(), WorkerError> {
102 let mut queue = self.queue.lock().await;
103
104 let queued_msg = QueuedMessage {
105 received_message: message,
106 };
107
108 queue.push(queued_msg);
109
110 self.notify.notify_one();
112
113 debug!("Message enqueued for batching, queue size: {}", queue.len());
114
115 Ok(())
116 }
117
118 async fn processing_loop(
120 queue: Arc<Mutex<Vec<QueuedMessage>>>,
121 notify: Arc<Notify>,
122 handler: Arc<dyn BatchHandler>,
123 config: BatchConfig,
124 ) {
125 let mut last_flush = Instant::now();
126
127 loop {
128 tokio::select! {
129 _ = notify.notified() => {
131 let queue_len = queue.lock().await.len();
132
133 if queue_len >= config.batch_size {
134 if let Err(e) = Self::process_full_batch(&queue, &handler, &config).await {
136 error!("Failed to process batch: {:?}", e);
137 }
138 last_flush = Instant::now();
139 }
140 }
141
142 _ = tokio::time::sleep(config.flush_interval) => {
144 if !config.wait_for_full_batch {
145 let elapsed = last_flush.elapsed();
146
147 if elapsed >= config.flush_interval {
148 debug!("Flush interval reached, checking for partial batch");
149
150 if let Err(e) = Self::flush_partial_batch(&queue, &handler, &config, BatchStatus::TimeoutFlush).await {
151 error!("Failed to flush partial batch: {:?}", e);
152 }
153 last_flush = Instant::now();
154 }
155 }
156 }
157 }
158 }
159 }
160
161 async fn process_full_batch(
163 queue: &Mutex<Vec<QueuedMessage>>,
164 handler: &Arc<dyn BatchHandler>,
165 config: &BatchConfig,
166 ) -> WorkerResult<()> {
167 let mut queue_guard = queue.lock().await;
168
169 if queue_guard.len() < config.batch_size {
170 return Ok(());
171 }
172
173 let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = queue_guard
175 .drain(..config.batch_size)
176 .enumerate()
177 .map(|(idx, qm)| ReceivedBatchMessage {
178 message: qm.received_message.message,
179 batch_index: idx,
180 })
181 .collect();
182
183 let received_messages: Vec<ReceivedMessage<serde_json::Value>> = queue_guard
185 .iter()
186 .take(config.batch_size)
187 .map(|qm| qm.received_message.clone())
188 .collect();
189
190 drop(queue_guard); if batch_messages.is_empty() {
193 return Ok(());
194 }
195
196 let batch_id = format!("batch-{}", uuid::Uuid::new_v4());
197 let mut batch = MessageBatch::new(batch_id.clone(), batch_messages);
198 batch.metadata.status = BatchStatus::Ready;
199
200 info!(
201 "Processing batch {} with {} messages",
202 batch_id,
203 batch.len()
204 );
205
206 match handler.process_batch(batch).await {
208 Ok(_) => {
209 info!("Batch {} processed successfully", batch_id);
210
211 for received_msg in received_messages {
213 if let Err(e) = received_msg.ack().await {
214 error!("Failed to acknowledge message in batch: {:?}", e);
215 }
216 }
217
218 Ok(())
219 }
220 Err(e) => {
221 error!("Batch {} processing failed: {:?}", batch_id, e);
222
223 for received_msg in received_messages {
225 if let Err(e) = received_msg.nack(true).await {
226 error!("Failed to nack message in batch: {:?}", e);
227 }
228 }
229
230 Err(e)
231 }
232 }
233 }
234
235 async fn flush_partial_batch(
237 queue: &Mutex<Vec<QueuedMessage>>,
238 handler: &Arc<dyn BatchHandler>,
239 _config: &BatchConfig,
240 status: BatchStatus,
241 ) -> WorkerResult<()> {
242 let mut queue_guard = queue.lock().await;
243
244 if queue_guard.is_empty() {
245 return Ok(());
246 }
247
248 let count = queue_guard.len();
249 debug!("Flushing partial batch with {} messages", count);
250
251 let batch_messages: Vec<ReceivedBatchMessage<serde_json::Value>> = queue_guard
253 .drain(..)
254 .enumerate()
255 .map(|(idx, qm)| ReceivedBatchMessage {
256 message: qm.received_message.message,
257 batch_index: idx,
258 })
259 .collect();
260
261 let received_messages: Vec<ReceivedMessage<serde_json::Value>> = queue_guard
262 .iter()
263 .map(|qm| qm.received_message.clone())
264 .collect();
265
266 drop(queue_guard);
267
268 if batch_messages.is_empty() {
269 return Ok(());
270 }
271
272 let batch_id = format!("partial-{}", uuid::Uuid::new_v4());
273 let mut batch = MessageBatch::new(batch_id.clone(), batch_messages);
274 batch.metadata.status = status.clone();
275
276 info!(
277 "Processing partial batch {} with {} messages (status: {:?})",
278 batch_id,
279 batch.len(),
280 status
281 );
282
283 match handler.process_batch(batch).await {
285 Ok(_) => {
286 info!("Partial batch {} processed successfully", batch_id);
287
288 for received_msg in received_messages {
289 if let Err(e) = received_msg.ack().await {
290 error!("Failed to acknowledge message in partial batch: {:?}", e);
291 }
292 }
293
294 Ok(())
295 }
296 Err(e) => {
297 error!("Partial batch {} processing failed: {:?}", batch_id, e);
298
299 for received_msg in received_messages {
300 if let Err(e) = received_msg.nack(true).await {
301 error!("Failed to nack message in partial batch: {:?}", e);
302 }
303 }
304
305 Err(e)
306 }
307 }
308 }
309}
310
311#[async_trait::async_trait]
312impl Middleware for BatchMiddleware {
313 fn name(&self) -> &str {
314 "BatchMiddleware"
315 }
316
317 async fn handle(
318 &self,
319 message: ReceivedMessage<serde_json::Value>,
320 _next: Box<dyn MessageHandler>,
321 ) -> Result<crate::middleware::MiddlewareResult, WorkerError> {
322 self.enqueue_message(message).await?;
325 Ok(crate::middleware::MiddlewareResult::Acknowledged)
326 }
327}