1use crate::{
10 error::{RabbitError, Result},
11 metrics::{MetricsTimer, RustRabbitMetrics},
12 publisher::{PublishOptions, Publisher},
13};
14use serde::Serialize;
15use std::time::{Duration, Instant};
16use tokio::sync::mpsc;
17use tokio::time::{interval, Instant as TokioInstant};
18use tracing::{debug, error, info};
19
/// Tuning knobs for batched publishing.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of messages accumulated before a flush is forced.
    pub max_batch_size: usize,
    /// Maximum age of the oldest queued message before a flush is forced.
    pub max_batch_timeout: Duration,
    /// Capacity of the internal channel between the batcher handle and its
    /// background processor task.
    pub buffer_size: usize,
    // NOTE(review): this flag is never read anywhere in this file — confirm
    // whether flush-on-full gating was meant to be implemented.
    pub flush_on_full: bool,
}
32
33impl Default for BatchConfig {
34 fn default() -> Self {
35 Self {
36 max_batch_size: 100,
37 max_batch_timeout: Duration::from_millis(100),
38 buffer_size: 1000,
39 flush_on_full: true,
40 }
41 }
42}
43
/// A single message waiting in the batch buffer.
#[derive(Debug)]
struct BatchMessage {
    /// Destination queue name.
    queue_name: String,
    /// JSON-serialized message body (produced by `serde_json::to_vec`).
    payload: Vec<u8>,
    /// Per-message publish options, if any.
    options: Option<PublishOptions>,
    /// Enqueue time; drives the age-based flush check in the processor.
    timestamp: Instant,
}
52
/// Front-end handle for batched publishing: callers enqueue serialized
/// messages and a spawned `BatchProcessor` task flushes them in batches.
#[derive(Debug)]
pub struct MessageBatcher {
    /// Batching configuration; also used to derive buffer occupancy.
    config: BatchConfig,
    // Kept alongside the processor's own clone but not used by this handle —
    // NOTE(review): confirm whether this field can be dropped.
    #[allow(dead_code)]
    publisher: Publisher,
    /// Sending half of the channel into the background processor.
    sender: mpsc::Sender<BatchMessage>,
    /// Optional metrics sink for "batch_queued" events.
    metrics: Option<RustRabbitMetrics>,
}
62
63impl MessageBatcher {
64 pub async fn new(publisher: Publisher, config: BatchConfig) -> Result<Self> {
66 let (sender, receiver) = mpsc::channel(config.buffer_size);
67
68 let batcher = Self {
69 config: config.clone(),
70 publisher: publisher.clone(),
71 sender,
72 metrics: None,
73 };
74
75 let batch_processor = BatchProcessor::new(publisher, receiver, config, None);
77
78 tokio::spawn(async move {
79 if let Err(e) = batch_processor.run().await {
80 error!("Batch processor error: {}", e);
81 }
82 });
83
84 Ok(batcher)
85 }
86
87 pub async fn with_metrics(
89 publisher: Publisher,
90 config: BatchConfig,
91 metrics: RustRabbitMetrics,
92 ) -> Result<Self> {
93 let (sender, receiver) = mpsc::channel(config.buffer_size);
94
95 let batcher = Self {
96 config: config.clone(),
97 publisher: publisher.clone(),
98 sender,
99 metrics: Some(metrics.clone()),
100 };
101
102 let batch_processor = BatchProcessor::new(publisher, receiver, config, Some(metrics));
104
105 tokio::spawn(async move {
106 if let Err(e) = batch_processor.run().await {
107 error!("Batch processor error: {}", e);
108 }
109 });
110
111 Ok(batcher)
112 }
113
114 pub async fn queue_message<T>(
116 &self,
117 queue_name: &str,
118 message: &T,
119 options: Option<PublishOptions>,
120 ) -> Result<()>
121 where
122 T: Serialize,
123 {
124 let payload = serde_json::to_vec(message)
125 .map_err(|e| RabbitError::SerializationError(e.to_string()))?;
126
127 let batch_message = BatchMessage {
128 queue_name: queue_name.to_string(),
129 payload,
130 options,
131 timestamp: Instant::now(),
132 };
133
134 self.sender
135 .send(batch_message)
136 .await
137 .map_err(|_| RabbitError::ChannelError("Batch queue is closed".to_string()))?;
138
139 if let Some(metrics) = &self.metrics {
141 metrics.record_message_published(queue_name, "", "batch_queued");
142 }
143
144 Ok(())
145 }
146
147 pub fn queue_len(&self) -> usize {
149 self.config
151 .buffer_size
152 .saturating_sub(self.sender.capacity())
153 }
154
155 pub fn is_nearly_full(&self) -> bool {
157 let remaining_capacity = self.sender.capacity();
158 let usage_percentage =
159 (self.config.buffer_size - remaining_capacity) * 100 / self.config.buffer_size;
160 usage_percentage > 80
161 }
162}
163
/// Background task state: receives `BatchMessage`s from a `MessageBatcher`
/// and publishes them in batches grouped by destination queue.
struct BatchProcessor {
    /// Publisher used to deliver each batch.
    publisher: Publisher,
    /// Receiving half of the channel fed by `MessageBatcher::queue_message`.
    receiver: mpsc::Receiver<BatchMessage>,
    /// Batch size/timeout settings.
    config: BatchConfig,
    /// Optional metrics sink for "batch_sent" events and flush durations.
    metrics: Option<RustRabbitMetrics>,
    /// Messages accumulated since the last flush.
    current_batch: Vec<BatchMessage>,
    // NOTE(review): written on every flush but never read — confirm whether
    // this field is needed.
    last_flush: TokioInstant,
}
173
174impl BatchProcessor {
175 fn new(
176 publisher: Publisher,
177 receiver: mpsc::Receiver<BatchMessage>,
178 config: BatchConfig,
179 metrics: Option<RustRabbitMetrics>,
180 ) -> Self {
181 Self {
182 publisher,
183 receiver,
184 config: config.clone(),
185 metrics,
186 current_batch: Vec::with_capacity(config.max_batch_size),
187 last_flush: TokioInstant::now(),
188 }
189 }
190
191 async fn run(mut self) -> Result<()> {
192 let mut flush_interval = interval(self.config.max_batch_timeout);
193
194 info!("Batch processor started with config: {:?}", self.config);
195
196 loop {
197 tokio::select! {
198 message = self.receiver.recv() => {
200 match message {
201 Some(msg) => {
202 self.add_to_batch(msg).await?;
203 }
204 None => {
205 info!("Batch processor channel closed, flushing remaining messages");
207 self.flush_batch().await?;
208 break;
209 }
210 }
211 }
212
213 _ = flush_interval.tick() => {
215 if self.should_flush() {
216 self.flush_batch().await?;
217 }
218 }
219 }
220 }
221
222 Ok(())
223 }
224
225 async fn add_to_batch(&mut self, message: BatchMessage) -> Result<()> {
226 self.current_batch.push(message);
227
228 if self.should_flush() {
230 self.flush_batch().await?;
231 }
232
233 Ok(())
234 }
235
236 fn should_flush(&self) -> bool {
237 if self.current_batch.is_empty() {
238 return false;
239 }
240
241 if self.current_batch.len() >= self.config.max_batch_size {
243 return true;
244 }
245
246 let oldest_message_time = self
248 .current_batch
249 .first()
250 .map(|msg| msg.timestamp)
251 .unwrap_or_else(Instant::now);
252
253 let elapsed = oldest_message_time.elapsed();
254 elapsed >= self.config.max_batch_timeout
255 }
256
257 async fn flush_batch(&mut self) -> Result<()> {
258 if self.current_batch.is_empty() {
259 return Ok(());
260 }
261
262 let batch_size = self.current_batch.len();
263 let timer = MetricsTimer::new();
264
265 debug!("Flushing batch of {} messages", batch_size);
266
267 let mut queue_batches: std::collections::HashMap<String, Vec<&BatchMessage>> =
269 std::collections::HashMap::new();
270
271 for message in &self.current_batch {
272 queue_batches
273 .entry(message.queue_name.clone())
274 .or_default()
275 .push(message);
276 }
277
278 let mut total_published = 0;
280 let mut total_errors = 0;
281
282 for (queue_name, messages) in &queue_batches {
283 match self.publish_queue_batch(queue_name, messages.clone()).await {
284 Ok(count) => total_published += count,
285 Err(e) => {
286 error!("Failed to publish batch for queue {}: {}", queue_name, e);
287 total_errors += messages.len();
288 }
289 }
290 }
291
292 if let Some(metrics) = &self.metrics {
294 let duration = timer.elapsed();
295
296 for (queue_name, messages) in &queue_batches {
298 for _ in messages {
299 metrics.record_message_published(queue_name, "", "batch_sent");
300 }
301 }
302
303 metrics.record_publish_duration("", "batch", duration);
305 }
306
307 info!(
308 "Batch flush completed: {} published, {} errors, took {:?}",
309 total_published,
310 total_errors,
311 timer.elapsed()
312 );
313
314 self.current_batch.clear();
316 self.last_flush = TokioInstant::now();
317
318 Ok(())
319 }
320
321 async fn publish_queue_batch(
322 &self,
323 queue_name: &str,
324 messages: Vec<&BatchMessage>,
325 ) -> Result<usize> {
326 if messages.is_empty() {
327 return Ok(0);
328 }
329
330 let mut published_count = 0;
333
334 for message in messages {
335 let payload_str = String::from_utf8(message.payload.clone())
339 .map_err(|e| RabbitError::SerializationError(e.to_string()))?;
340
341 let json_value: serde_json::Value = serde_json::from_str(&payload_str)
342 .map_err(|e| RabbitError::SerializationError(e.to_string()))?;
343
344 match self
345 .publisher
346 .publish_to_queue(queue_name, &json_value, message.options.clone())
347 .await
348 {
349 Ok(_) => published_count += 1,
350 Err(e) => {
351 error!("Failed to publish message in batch: {}", e);
352 return Err(e);
353 }
354 }
355 }
356
357 Ok(published_count)
358 }
359}
360
/// Fluent builder for [`BatchConfig`], starting from its defaults.
#[derive(Debug)]
pub struct BatchConfigBuilder {
    /// Configuration being assembled; returned by `build`.
    config: BatchConfig,
}
366
367impl BatchConfigBuilder {
368 pub fn new() -> Self {
370 Self {
371 config: BatchConfig::default(),
372 }
373 }
374
375 pub fn max_batch_size(mut self, size: usize) -> Self {
377 self.config.max_batch_size = size;
378 self
379 }
380
381 pub fn max_batch_timeout(mut self, timeout: Duration) -> Self {
383 self.config.max_batch_timeout = timeout;
384 self
385 }
386
387 pub fn buffer_size(mut self, size: usize) -> Self {
389 self.config.buffer_size = size;
390 self
391 }
392
393 pub fn flush_on_full(mut self, flush: bool) -> Self {
395 self.config.flush_on_full = flush;
396 self
397 }
398
399 pub fn build(self) -> BatchConfig {
401 self.config
402 }
403}
404
405impl Default for BatchConfigBuilder {
406 fn default() -> Self {
407 Self::new()
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters should land in the corresponding config fields.
    #[test]
    fn test_batch_config_builder() {
        let built = BatchConfigBuilder::new()
            .max_batch_size(50)
            .max_batch_timeout(Duration::from_millis(200))
            .buffer_size(500)
            .flush_on_full(false)
            .build();

        assert_eq!(
            (
                built.max_batch_size,
                built.max_batch_timeout,
                built.buffer_size,
                built.flush_on_full,
            ),
            (50, Duration::from_millis(200), 500, false)
        );
    }

    /// Defaults documented on `BatchConfig` must not drift.
    #[test]
    fn test_batch_config_default() {
        let defaults = BatchConfig::default();

        assert_eq!(defaults.max_batch_size, 100);
        assert_eq!(defaults.max_batch_timeout, Duration::from_millis(100));
        assert_eq!(defaults.buffer_size, 1000);
        assert!(defaults.flush_on_full);
    }
}