Skip to main content

ai_lib_rust/batch/
collector.rs

//! Batch collector: queues items and reports when the current batch is due to be flushed, by size or by age.
2
3use std::collections::VecDeque;
4use std::sync::{Arc, RwLock};
5use std::time::{Duration, Instant};
6
/// Settings that control when a [`BatchCollector`] considers a batch ready.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Queue length at which a batch counts as full.
    pub max_batch_size: usize,
    /// Age of the current batch after which `BatchCollector::should_flush` reports it as due.
    pub max_wait_time: Duration,
    /// Whether `BatchCollector::add` reports `ShouldFlush` once the batch is full.
    pub auto_flush: bool,
}

impl Default for BatchConfig {
    /// Batches of 10 items, a 5 second wait limit, and auto-flush enabled.
    fn default() -> Self {
        Self {
            max_batch_size: 10,
            max_wait_time: Duration::from_secs(5),
            auto_flush: true,
        }
    }
}

impl BatchConfig {
    /// Creates a configuration holding the [`Default`] values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with `max_batch_size` replaced by `s`.
    pub fn with_max_batch_size(mut self, s: usize) -> Self {
        self.max_batch_size = s;
        self
    }

    /// Returns the configuration with `max_wait_time` replaced by `wait`.
    ///
    /// Completes the builder: without it the wait limit could only be changed
    /// by writing the public field directly.
    pub fn with_max_wait_time(mut self, wait: Duration) -> Self {
        self.max_wait_time = wait;
        self
    }

    /// Returns the configuration with `auto_flush` replaced by `a`.
    pub fn with_auto_flush(mut self, a: bool) -> Self {
        self.auto_flush = a;
        self
    }
}
15
/// A payload queued for batching, together with bookkeeping metadata.
#[derive(Debug, Clone)]
pub struct BatchItem<T> {
    /// The caller's payload.
    pub data: T,
    /// Moment the item was constructed via [`BatchItem::new`].
    pub added_at: Instant,
    /// Optional identifier attached by the caller; `None` unless set.
    pub request_id: Option<String>,
    /// Caller-assigned priority; `0` unless set.
    pub priority: i32,
}

impl<T> BatchItem<T> {
    /// Wraps `data`, stamping it with the current instant, no request id and priority `0`.
    pub fn new(data: T) -> Self {
        let added_at = Instant::now();
        Self {
            data,
            added_at,
            request_id: None,
            priority: 0,
        }
    }

    /// Returns the item with its request id set to `id`.
    pub fn with_request_id(self, id: impl Into<String>) -> Self {
        Self {
            request_id: Some(id.into()),
            ..self
        }
    }

    /// Returns the item with its priority set to `p`.
    pub fn with_priority(self, p: i32) -> Self {
        Self { priority: p, ..self }
    }
}
23
24pub struct BatchCollector<T> { config: BatchConfig, items: Arc<RwLock<VecDeque<BatchItem<T>>>>, batch_start: Arc<RwLock<Option<Instant>>> }
25
26impl<T: Clone> BatchCollector<T> {
27    pub fn new(config: BatchConfig) -> Self { Self { config, items: Arc::new(RwLock::new(VecDeque::new())), batch_start: Arc::new(RwLock::new(None)) } }
28
29    pub fn add(&self, item: BatchItem<T>) -> BatchAddResult {
30        let mut items = self.items.write().unwrap();
31        let mut start = self.batch_start.write().unwrap();
32        if items.is_empty() { *start = Some(Instant::now()); }
33        items.push_back(item);
34        let count = items.len();
35        if self.config.auto_flush && count >= self.config.max_batch_size { BatchAddResult::ShouldFlush { count } } else { BatchAddResult::Added { count } }
36    }
37
38    pub fn add_data(&self, data: T) -> BatchAddResult { self.add(BatchItem::new(data)) }
39
40    pub fn should_flush(&self) -> bool {
41        let items = self.items.read().unwrap();
42        let start = self.batch_start.read().unwrap();
43        if items.is_empty() { return false; }
44        if items.len() >= self.config.max_batch_size { return true; }
45        if let Some(s) = *start { if s.elapsed() >= self.config.max_wait_time { return true; } }
46        false
47    }
48
49    pub fn drain(&self) -> Vec<BatchItem<T>> {
50        let mut items = self.items.write().unwrap();
51        let mut start = self.batch_start.write().unwrap();
52        *start = None;
53        items.drain(..).collect()
54    }
55
56    pub fn len(&self) -> usize { self.items.read().unwrap().len() }
57    pub fn is_empty(&self) -> bool { self.len() == 0 }
58    pub fn clear(&self) { self.items.write().unwrap().clear(); *self.batch_start.write().unwrap() = None; }
59}
60
/// Outcome of adding an item to a batch; both variants carry the queue length after the add.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BatchAddResult {
    /// The item was queued and no flush is being signalled.
    Added { count: usize },
    /// The item was queued and the batch should now be flushed.
    ShouldFlush { count: usize },
}

impl BatchAddResult {
    /// `true` only for the `ShouldFlush` variant.
    pub fn should_flush(&self) -> bool {
        match self {
            Self::ShouldFlush { .. } => true,
            Self::Added { .. } => false,
        }
    }

    /// The queue length carried by either variant.
    pub fn count(&self) -> usize {
        match *self {
            Self::Added { count } => count,
            Self::ShouldFlush { count } => count,
        }
    }
}
64
/// Unit tests for the batch configuration, items, collector and add results.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_config_defaults() {
        let config = BatchConfig::default();
        assert_eq!(config.max_batch_size, 10);
        assert_eq!(config.max_wait_time, Duration::from_secs(5));
        assert!(config.auto_flush);
    }

    #[test]
    fn test_batch_config_builder() {
        let config = BatchConfig::new()
            .with_max_batch_size(5)
            .with_auto_flush(false);
        assert_eq!(config.max_batch_size, 5);
        assert!(!config.auto_flush);
    }

    #[test]
    fn test_batch_item_creation() {
        let item = BatchItem::new("test data")
            .with_request_id("req-001")
            .with_priority(10);
        assert_eq!(item.data, "test data");
        assert_eq!(item.request_id, Some("req-001".to_string()));
        assert_eq!(item.priority, 10);
    }

    #[test]
    fn test_batch_collector_empty() {
        let config = BatchConfig::new().with_max_batch_size(5);
        let collector: BatchCollector<String> = BatchCollector::new(config);
        assert!(collector.is_empty());
        assert_eq!(collector.len(), 0);
        assert!(!collector.should_flush());
    }

    #[test]
    fn test_batch_collector_add_data() {
        let config = BatchConfig::new().with_max_batch_size(5);
        let collector: BatchCollector<String> = BatchCollector::new(config);

        let result = collector.add_data("item1".to_string());
        assert_eq!(result, BatchAddResult::Added { count: 1 });
        assert_eq!(collector.len(), 1);
        assert!(!collector.is_empty());
    }

    #[test]
    fn test_batch_collector_add_item() {
        let config = BatchConfig::new().with_max_batch_size(5);
        let collector: BatchCollector<String> = BatchCollector::new(config);

        let item = BatchItem::new("item1".to_string()).with_priority(5);
        let result = collector.add(item);
        assert_eq!(result.count(), 1);
    }

    #[test]
    fn test_batch_collector_auto_flush() {
        let config = BatchConfig::new().with_max_batch_size(3).with_auto_flush(true);
        let collector: BatchCollector<i32> = BatchCollector::new(config);

        // Add items below threshold
        assert!(!collector.add_data(1).should_flush());
        assert!(!collector.add_data(2).should_flush());

        // Third item should trigger flush
        let result = collector.add_data(3);
        assert!(result.should_flush());
        assert_eq!(result.count(), 3);
    }

    #[test]
    fn test_batch_collector_no_auto_flush() {
        let config = BatchConfig::new().with_max_batch_size(3).with_auto_flush(false);
        let collector: BatchCollector<i32> = BatchCollector::new(config);

        collector.add_data(1);
        collector.add_data(2);
        let result = collector.add_data(3);

        // Should not report ShouldFlush when auto_flush is disabled
        assert!(!result.should_flush());
        // But should_flush() method checks size
        assert!(collector.should_flush());
    }

    #[test]
    fn test_batch_collector_time_based_flush() {
        // A zero wait limit makes any non-empty batch due immediately, which
        // exercises the age branch of should_flush() without sleeping.
        let config = BatchConfig {
            max_wait_time: Duration::from_secs(0),
            ..BatchConfig::new().with_max_batch_size(100)
        };
        let collector: BatchCollector<i32> = BatchCollector::new(config);

        // Empty collectors are never due, regardless of the wait limit.
        assert!(!collector.should_flush());

        collector.add_data(1);
        assert!(collector.should_flush());
    }

    #[test]
    fn test_batch_collector_drain() {
        let config = BatchConfig::new().with_max_batch_size(10);
        let collector: BatchCollector<String> = BatchCollector::new(config);

        collector.add_data("a".to_string());
        collector.add_data("b".to_string());
        collector.add_data("c".to_string());

        let items = collector.drain();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0].data, "a");
        assert_eq!(items[1].data, "b");
        assert_eq!(items[2].data, "c");

        // Collector should be empty after drain
        assert!(collector.is_empty());
    }

    #[test]
    fn test_batch_collector_drain_resets_flush_state() {
        let config = BatchConfig::new().with_max_batch_size(2);
        let collector: BatchCollector<i32> = BatchCollector::new(config);

        collector.add_data(1);
        collector.add_data(2);
        assert!(collector.should_flush());

        assert_eq!(collector.drain().len(), 2);

        // Nothing is queued any more, so nothing is due and a second drain is empty.
        assert!(!collector.should_flush());
        assert!(collector.drain().is_empty());
    }

    #[test]
    fn test_batch_collector_clear() {
        let config = BatchConfig::new().with_max_batch_size(10);
        let collector: BatchCollector<i32> = BatchCollector::new(config);

        collector.add_data(1);
        collector.add_data(2);
        assert_eq!(collector.len(), 2);

        collector.clear();
        assert!(collector.is_empty());
    }

    #[test]
    fn test_batch_add_result_methods() {
        let added = BatchAddResult::Added { count: 5 };
        let should_flush = BatchAddResult::ShouldFlush { count: 10 };

        assert!(!added.should_flush());
        assert_eq!(added.count(), 5);

        assert!(should_flush.should_flush());
        assert_eq!(should_flush.count(), 10);
    }

    #[test]
    fn test_batch_collector_thread_safe() {
        use std::sync::Arc;
        use std::thread;

        let config = BatchConfig::new().with_max_batch_size(100);
        let collector: Arc<BatchCollector<i32>> = Arc::new(BatchCollector::new(config));

        // Ten writers adding ten items each must leave exactly one hundred queued.
        let mut handles = vec![];
        for i in 0..10 {
            let c = Arc::clone(&collector);
            handles.push(thread::spawn(move || {
                for j in 0..10 {
                    c.add_data(i * 10 + j);
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(collector.len(), 100);
    }
}