celers_protocol/
batch.rs

1//! Batch message processing utilities
2//!
3//! This module provides utilities for efficient batch processing of messages.
4
5use crate::{Message, ValidationError};
6use std::collections::HashMap;
7
8/// Batch of messages for efficient processing
9#[derive(Debug, Clone)]
10pub struct MessageBatch {
11    messages: Vec<Message>,
12    max_size: usize,
13}
14
15impl MessageBatch {
16    /// Create a new message batch with default max size (100)
17    pub fn new() -> Self {
18        Self {
19            messages: Vec::new(),
20            max_size: 100,
21        }
22    }
23
24    /// Create a new message batch with specified max size
25    pub fn with_capacity(max_size: usize) -> Self {
26        Self {
27            messages: Vec::with_capacity(max_size),
28            max_size,
29        }
30    }
31
32    /// Add a message to the batch
33    ///
34    /// Returns `true` if the message was added, `false` if the batch is full
35    pub fn push(&mut self, message: Message) -> bool {
36        if self.messages.len() < self.max_size {
37            self.messages.push(message);
38            true
39        } else {
40            false
41        }
42    }
43
44    /// Get the number of messages in the batch
45    #[inline]
46    pub fn len(&self) -> usize {
47        self.messages.len()
48    }
49
50    /// Check if the batch is empty
51    #[inline]
52    pub fn is_empty(&self) -> bool {
53        self.messages.is_empty()
54    }
55
56    /// Check if the batch is full
57    #[inline]
58    pub fn is_full(&self) -> bool {
59        self.messages.len() >= self.max_size
60    }
61
62    /// Get the messages in the batch
63    pub fn messages(&self) -> &[Message] {
64        &self.messages
65    }
66
67    /// Take all messages from the batch, leaving it empty
68    pub fn drain(&mut self) -> Vec<Message> {
69        std::mem::take(&mut self.messages)
70    }
71
72    /// Validate all messages in the batch
73    pub fn validate(&self) -> Result<(), ValidationError> {
74        for msg in &self.messages {
75            msg.validate()?;
76        }
77        Ok(())
78    }
79
80    /// Split the batch into smaller batches of the specified size
81    pub fn split(self, chunk_size: usize) -> Vec<MessageBatch> {
82        self.messages
83            .chunks(chunk_size)
84            .map(|chunk| {
85                let mut batch = MessageBatch::with_capacity(chunk_size);
86                for msg in chunk {
87                    batch.push(msg.clone());
88                }
89                batch
90            })
91            .collect()
92    }
93
94    /// Merge another batch into this one
95    ///
96    /// Returns the messages that didn't fit if the combined size exceeds max_size
97    pub fn merge(&mut self, other: MessageBatch) -> Vec<Message> {
98        let mut overflow = Vec::new();
99        for msg in other.messages {
100            if !self.push(msg.clone()) {
101                overflow.push(msg);
102            }
103        }
104        overflow
105    }
106}
107
108impl Default for MessageBatch {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl FromIterator<Message> for MessageBatch {
115    fn from_iter<T: IntoIterator<Item = Message>>(iter: T) -> Self {
116        let messages: Vec<_> = iter.into_iter().collect();
117        let max_size = messages.len().max(100);
118        Self { messages, max_size }
119    }
120}
121
122impl IntoIterator for MessageBatch {
123    type Item = Message;
124    type IntoIter = std::vec::IntoIter<Message>;
125
126    #[inline]
127    fn into_iter(self) -> Self::IntoIter {
128        self.messages.into_iter()
129    }
130}
131
132impl<'a> IntoIterator for &'a MessageBatch {
133    type Item = &'a Message;
134    type IntoIter = std::slice::Iter<'a, Message>;
135
136    #[inline]
137    fn into_iter(self) -> Self::IntoIter {
138        self.messages.iter()
139    }
140}
141
142impl<'a> IntoIterator for &'a mut MessageBatch {
143    type Item = &'a mut Message;
144    type IntoIter = std::slice::IterMut<'a, Message>;
145
146    #[inline]
147    fn into_iter(self) -> Self::IntoIter {
148        self.messages.iter_mut()
149    }
150}
151
152impl std::ops::Index<usize> for MessageBatch {
153    type Output = Message;
154
155    #[inline]
156    fn index(&self, index: usize) -> &Self::Output {
157        &self.messages[index]
158    }
159}
160
161impl std::ops::IndexMut<usize> for MessageBatch {
162    #[inline]
163    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
164        &mut self.messages[index]
165    }
166}
167
168impl Extend<Message> for MessageBatch {
169    fn extend<T: IntoIterator<Item = Message>>(&mut self, iter: T) {
170        for msg in iter {
171            if !self.push(msg) {
172                break; // Stop if batch is full
173            }
174        }
175    }
176}
177
178impl AsRef<[Message]> for MessageBatch {
179    #[inline]
180    fn as_ref(&self) -> &[Message] {
181        &self.messages
182    }
183}
184
185impl AsMut<[Message]> for MessageBatch {
186    #[inline]
187    fn as_mut(&mut self) -> &mut [Message] {
188        &mut self.messages
189    }
190}
191
192/// Batch processor for processing messages in groups
193pub struct BatchProcessor {
194    batch_size: usize,
195    timeout_ms: u64,
196}
197
198impl BatchProcessor {
199    /// Create a new batch processor with default settings
200    pub fn new() -> Self {
201        Self {
202            batch_size: 100,
203            timeout_ms: 1000,
204        }
205    }
206
207    /// Set the batch size
208    #[must_use]
209    pub fn with_batch_size(mut self, size: usize) -> Self {
210        self.batch_size = size;
211        self
212    }
213
214    /// Set the timeout in milliseconds
215    #[must_use]
216    pub fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
217        self.timeout_ms = timeout_ms;
218        self
219    }
220
221    /// Create batches from a vector of messages
222    pub fn create_batches(&self, messages: Vec<Message>) -> Vec<MessageBatch> {
223        messages
224            .chunks(self.batch_size)
225            .map(|chunk| {
226                let mut batch = MessageBatch::with_capacity(self.batch_size);
227                for msg in chunk {
228                    batch.push(msg.clone());
229                }
230                batch
231            })
232            .collect()
233    }
234
235    /// Process messages in batches with a callback function
236    pub fn process<F>(&self, messages: Vec<Message>, mut callback: F) -> Result<(), String>
237    where
238        F: FnMut(&[Message]) -> Result<(), String>,
239    {
240        for batch in self.create_batches(messages) {
241            callback(batch.messages())?;
242        }
243        Ok(())
244    }
245}
246
247impl Default for BatchProcessor {
248    fn default() -> Self {
249        Self::new()
250    }
251}
252
/// Running totals gathered while batches are processed
#[derive(Debug, Clone, Default)]
pub struct BatchStats {
    /// Sum of the sizes of all recorded batches
    pub total_messages: usize,
    /// How many batches have been recorded
    pub total_batches: usize,
    /// Sum of the reported successes
    pub successful: usize,
    /// Sum of the reported failures
    pub failed: usize,
}

impl BatchStats {
    /// Create statistics with every counter at zero
    pub fn new() -> Self {
        BatchStats {
            total_messages: 0,
            total_batches: 0,
            successful: 0,
            failed: 0,
        }
    }

    /// Add the outcome of one batch to the totals
    ///
    /// The three counts are accumulated independently; they are not checked
    /// against each other.
    pub fn record_batch(&mut self, batch_size: usize, successes: usize, failures: usize) {
        self.total_messages += batch_size;
        self.successful += successes;
        self.failed += failures;
        self.total_batches += 1;
    }

    /// Share of successful messages among all recorded messages, in percent
    ///
    /// Returns 0.0 when no messages have been recorded.
    pub fn success_rate(&self) -> f64 {
        match self.total_messages {
            0 => 0.0,
            total => (self.successful as f64 / total as f64) * 100.0,
        }
    }

    /// Mean number of messages per recorded batch
    ///
    /// Returns 0.0 when no batches have been recorded.
    pub fn average_batch_size(&self) -> f64 {
        match self.total_batches {
            0 => 0.0,
            batches => self.total_messages as f64 / batches as f64,
        }
    }
}
298
299/// Group messages by a key function
300pub fn group_by<F, K>(messages: Vec<Message>, key_fn: F) -> HashMap<K, Vec<Message>>
301where
302    F: Fn(&Message) -> K,
303    K: Eq + std::hash::Hash,
304{
305    let mut groups = HashMap::new();
306    for msg in messages {
307        let key = key_fn(&msg);
308        groups.entry(key).or_insert_with(Vec::new).push(msg);
309    }
310    groups
311}
312
313/// Partition messages into two groups based on a predicate
314pub fn partition<F>(messages: Vec<Message>, predicate: F) -> (Vec<Message>, Vec<Message>)
315where
316    F: Fn(&Message) -> bool,
317{
318    let mut true_group = Vec::new();
319    let mut false_group = Vec::new();
320
321    for msg in messages {
322        if predicate(&msg) {
323            true_group.push(msg);
324        } else {
325            false_group.push(msg);
326        }
327    }
328
329    (true_group, false_group)
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::MessageBuilder;

    /// Builds a message for `task`, panicking if the builder reports an error.
    fn create_test_message(task: &str) -> Message {
        let built = MessageBuilder::new(task).build();
        built.unwrap()
    }

    /// Builds one message per task name, preserving order.
    fn messages_named(tasks: &[&str]) -> Vec<Message> {
        tasks.iter().map(|task| create_test_message(task)).collect()
    }

    /// Pushes one message per task name into a batch made by `MessageBatch::new`.
    fn batch_of(tasks: &[&str]) -> MessageBatch {
        let mut batch = MessageBatch::new();
        for task in tasks {
            batch.push(create_test_message(task));
        }
        batch
    }

    #[test]
    fn test_message_batch_new() {
        let fresh = MessageBatch::new();
        assert_eq!(fresh.len(), 0);
        assert!(fresh.is_empty());
        assert!(!fresh.is_full());
    }

    #[test]
    fn test_message_batch_push() {
        let mut limited = MessageBatch::with_capacity(2);
        let accepted: Vec<bool> = messages_named(&["task1", "task2", "task3"])
            .into_iter()
            .map(|msg| limited.push(msg))
            .collect();

        // The third push exceeds the limit of two and is rejected.
        assert_eq!(accepted, vec![true, true, false]);
        assert_eq!(limited.len(), 2);
        assert!(limited.is_full());
    }

    #[test]
    fn test_message_batch_drain() {
        let mut batch = batch_of(&["task1", "task2"]);

        let drained = batch.drain();
        assert_eq!(drained.len(), 2);
        assert!(batch.is_empty());
    }

    #[test]
    fn test_message_batch_validate() {
        let batch = batch_of(&["task1", "task2"]);

        assert!(batch.validate().is_ok());
    }

    #[test]
    fn test_message_batch_split() {
        let names: Vec<String> = (0..10).map(|i| format!("task{}", i)).collect();
        let mut batch = MessageBatch::new();
        for name in &names {
            batch.push(create_test_message(name));
        }

        // 10 messages in chunks of 3 give four batches: 3, 3, 3 and 1.
        let parts = batch.split(3);
        let sizes: Vec<usize> = parts.iter().map(MessageBatch::len).collect();
        assert_eq!(sizes, vec![3, 3, 3, 1]);
    }

    #[test]
    fn test_message_batch_merge() {
        let mut target = MessageBatch::with_capacity(5);
        target.push(create_test_message("task1"));
        target.push(create_test_message("task2"));

        let incoming = batch_of(&["task3", "task4"]);

        let leftover = target.merge(incoming);
        assert_eq!(target.len(), 4);
        assert!(leftover.is_empty());
    }

    #[test]
    fn test_batch_processor_create_batches() {
        let processor = BatchProcessor::new().with_batch_size(3);
        let input = messages_named(&["task1", "task2", "task3", "task4", "task5"]);

        let batches = processor.create_batches(input);
        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 2);
    }

    #[test]
    fn test_batch_processor_process() {
        let processor = BatchProcessor::new().with_batch_size(2);
        let input = messages_named(&["task1", "task2", "task3"]);

        let mut processed = 0;
        let outcome = processor.process(input, |chunk| {
            processed += chunk.len();
            Ok(())
        });

        assert!(outcome.is_ok());
        assert_eq!(processed, 3);
    }

    #[test]
    fn test_batch_stats() {
        let mut stats = BatchStats::new();
        for &(size, ok, bad) in &[(10, 8, 2), (10, 9, 1)] {
            stats.record_batch(size, ok, bad);
        }

        assert_eq!(stats.total_batches, 2);
        assert_eq!(stats.total_messages, 20);
        assert_eq!(stats.successful, 17);
        assert_eq!(stats.failed, 3);
        assert_eq!(stats.success_rate(), 85.0);
        assert_eq!(stats.average_batch_size(), 10.0);
    }

    #[test]
    fn test_group_by() {
        let input = messages_named(&["tasks.add", "tasks.subtract", "tasks.add", "email.send"]);

        let groups = group_by(input, |msg| msg.headers.task.clone());
        assert_eq!(groups.len(), 3);
        for &(task, expected) in &[("tasks.add", 2), ("tasks.subtract", 1), ("email.send", 1)] {
            assert_eq!(groups.get(task).unwrap().len(), expected);
        }
    }

    #[test]
    fn test_partition() {
        let input = messages_named(&["tasks.add", "email.send", "tasks.subtract"]);

        let (task_messages, other_messages) =
            partition(input, |msg| msg.headers.task.starts_with("tasks."));

        assert_eq!(task_messages.len(), 2);
        assert_eq!(other_messages.len(), 1);
    }

    #[test]
    fn test_message_batch_into_iterator() {
        let batch = batch_of(&["task1", "task2", "task3"]);

        let mut visited = 0;
        for msg in batch {
            assert!(!msg.headers.task.is_empty());
            visited += 1;
        }
        assert_eq!(visited, 3);
    }

    #[test]
    fn test_message_batch_into_iterator_ref() {
        let batch = batch_of(&["task1", "task2"]);

        let mut visited = 0;
        for msg in &batch {
            assert!(!msg.headers.task.is_empty());
            visited += 1;
        }
        assert_eq!(visited, 2);
        // Iterating by reference must leave the batch usable.
        assert_eq!(batch.len(), 2);
    }

    #[test]
    fn test_message_batch_into_iterator_mut() {
        let mut batch = batch_of(&["task1", "task2"]);

        for msg in &mut batch {
            msg.headers.retries = Some(1);
        }

        for msg in &batch {
            assert_eq!(msg.headers.retries, Some(1));
        }
    }

    #[test]
    fn test_message_batch_index() {
        let batch = batch_of(&["task1", "task2", "task3"]);

        for (position, expected) in ["task1", "task2", "task3"].iter().enumerate() {
            assert_eq!(batch[position].headers.task, *expected);
        }
    }

    #[test]
    fn test_message_batch_index_mut() {
        let mut batch = batch_of(&["task1", "task2"]);

        batch[0].headers.retries = Some(5);
        batch[1].headers.retries = Some(10);

        assert_eq!(batch[0].headers.retries, Some(5));
        assert_eq!(batch[1].headers.retries, Some(10));
    }

    #[test]
    fn test_message_batch_extend() {
        let mut batch = MessageBatch::with_capacity(5);
        batch.push(create_test_message("task1"));

        batch.extend(messages_named(&["task2", "task3"]));
        assert_eq!(batch.len(), 3);
    }

    #[test]
    fn test_message_batch_extend_with_capacity_limit() {
        let mut batch = MessageBatch::with_capacity(3);
        batch.push(create_test_message("task1"));

        // "task4" is one more than the batch can take and must not be added.
        batch.extend(messages_named(&["task2", "task3", "task4"]));
        assert_eq!(batch.len(), 3);
        assert!(batch.is_full());
    }

    #[test]
    fn test_message_batch_as_ref() {
        let batch = batch_of(&["task1", "task2"]);

        let view: &[Message] = batch.as_ref();
        assert_eq!(view.len(), 2);
        assert_eq!(view[0].headers.task, "task1");
    }

    #[test]
    fn test_message_batch_as_mut() {
        let mut batch = batch_of(&["task1", "task2"]);

        let view: &mut [Message] = batch.as_mut();
        view[0].headers.retries = Some(99);

        assert_eq!(batch[0].headers.retries, Some(99));
    }

    #[test]
    fn test_message_batch_iterator_chain() {
        let input = messages_named(&["task1", "task2", "task3", "task4"]);

        let batch: MessageBatch = input.into_iter().collect();
        assert_eq!(batch.len(), 4);

        let mut task_names: Vec<String> = Vec::new();
        for msg in batch {
            task_names.push(msg.headers.task.clone());
        }

        assert_eq!(task_names, vec!["task1", "task2", "task3", "task4"]);
    }
}