1use std::collections::HashMap;
4use std::time::Instant;
5use serde_json::Value;
6use crate::error::Result;
7
/// Controls how input rows are grouped into batches for processing.
#[derive(Debug, Clone, PartialEq)]
pub enum BatchStrategy {
    /// Every batch holds exactly this many rows (the final batch may be short).
    FixedSize(usize),
    /// Batch size is derived per call from an estimated per-row memory
    /// footprint, targeting this many bytes per batch.
    MemoryBased(usize),
    /// Batch size starts at `initial_size` and is re-tuned at runtime from
    /// observed throughput, never exceeding `max_size`.
    Adaptive { initial_size: usize, max_size: usize },
}
18
19impl Default for BatchStrategy {
20 fn default() -> Self {
21 BatchStrategy::FixedSize(1000)
22 }
23}
24
/// Splits row data into batches and drives a caller-supplied processor over
/// them, recording per-batch timing so adaptive strategies can tune the size.
pub struct BatchProcessor {
    // How batch boundaries are chosen (fixed, memory-based, or adaptive).
    strategy: BatchStrategy,
    // Batch size currently in effect; only the adaptive strategy mutates it.
    current_batch_size: usize,
    // Rolling window of recent batch timings, capped at 100 entries.
    performance_history: Vec<BatchPerformance>,
}
31
/// Timing record for a single processed batch.
#[derive(Debug, Clone)]
struct BatchPerformance {
    // Number of rows in the batch.
    batch_size: usize,
    // Wall-clock time spent in the caller's processor, in milliseconds.
    processing_time_ms: u64,
    // Throughput derived from the two fields above when the batch is recorded.
    rows_per_second: f64,
}
39
40impl BatchProcessor {
41 pub fn new(strategy: BatchStrategy) -> Self {
43 let initial_size = match &strategy {
44 BatchStrategy::FixedSize(size) => *size,
45 BatchStrategy::MemoryBased(_) => 1000,
46 BatchStrategy::Adaptive { initial_size, .. } => *initial_size,
47 };
48
49 Self {
50 strategy,
51 current_batch_size: initial_size,
52 performance_history: Vec::new(),
53 }
54 }
55
56 pub fn process_batches<F>(&mut self, data: Vec<HashMap<String, Value>>, mut processor: F) -> Result<()>
58 where
59 F: FnMut(&[HashMap<String, Value>]) -> Result<()>,
60 {
61 if data.is_empty() {
62 return Ok(());
63 }
64
65 let batches = self.create_batches(&data)?;
66
67 for batch in batches {
68 let start_time = Instant::now();
69
70 processor(&batch)?;
71
72 let processing_time = start_time.elapsed().as_millis() as u64;
73 self.record_performance(batch.len(), processing_time);
74
75 if matches!(self.strategy, BatchStrategy::Adaptive { .. }) {
77 self.adjust_batch_size();
78 }
79 }
80
81 Ok(())
82 }
83
84 fn create_batches(&self, data: &[HashMap<String, Value>]) -> Result<Vec<Vec<HashMap<String, Value>>>> {
86 let batch_size = match &self.strategy {
87 BatchStrategy::FixedSize(size) => *size,
88 BatchStrategy::MemoryBased(max_bytes) => {
89 self.calculate_memory_based_batch_size(data, *max_bytes)?
90 },
91 BatchStrategy::Adaptive { .. } => self.current_batch_size,
92 };
93
94 let mut batches = Vec::new();
95 let mut current_batch = Vec::new();
96
97 for row in data {
98 current_batch.push(row.clone());
99
100 if current_batch.len() >= batch_size {
101 batches.push(current_batch);
102 current_batch = Vec::new();
103 }
104 }
105
106 if !current_batch.is_empty() {
108 batches.push(current_batch);
109 }
110
111 Ok(batches)
112 }
113
114 fn calculate_memory_based_batch_size(&self, data: &[HashMap<String, Value>], max_bytes: usize) -> Result<usize> {
116 if data.is_empty() {
117 return Ok(1000);
118 }
119
120 let sample_row = &data[0];
122 let estimated_row_size = self.estimate_row_memory_size(sample_row);
123
124 if estimated_row_size == 0 {
125 return Ok(1000);
126 }
127
128 let batch_size = max_bytes / estimated_row_size;
129 Ok(batch_size.max(1).min(10000)) }
131
132 fn estimate_row_memory_size(&self, row: &HashMap<String, Value>) -> usize {
134 let mut size = 0;
135
136 for (key, value) in row {
137 size += key.len(); size += match value {
139 Value::String(s) => s.len(),
140 Value::Number(_) => 8, Value::Bool(_) => 1,
142 Value::Array(arr) => arr.len() * 50, Value::Object(obj) => obj.len() * 100, Value::Null => 0,
145 };
146 }
147
148 size
149 }
150
151 fn record_performance(&mut self, batch_size: usize, processing_time_ms: u64) {
153 let rows_per_second = if processing_time_ms > 0 {
154 (batch_size as f64 * 1000.0) / processing_time_ms as f64
155 } else {
156 0.0
157 };
158
159 let performance = BatchPerformance {
160 batch_size,
161 processing_time_ms,
162 rows_per_second,
163 };
164
165 self.performance_history.push(performance);
166
167 if self.performance_history.len() > 100 {
169 self.performance_history.remove(0);
170 }
171 }
172
173 fn adjust_batch_size(&mut self) {
175 if let BatchStrategy::Adaptive { initial_size: _, max_size } = &self.strategy {
176 if self.performance_history.len() < 3 {
177 return;
178 }
179
180 let recent_performances = &self.performance_history[self.performance_history.len() - 3..];
182 let _avg_performance = recent_performances.iter()
183 .map(|p| p.rows_per_second)
184 .sum::<f64>() / recent_performances.len() as f64;
185
186 if recent_performances.windows(2).all(|w| w[1].rows_per_second >= w[0].rows_per_second) {
188 self.current_batch_size = (self.current_batch_size * 110 / 100).min(*max_size);
189 }
190 else if recent_performances.windows(2).all(|w| w[1].rows_per_second < w[0].rows_per_second) {
192 self.current_batch_size = (self.current_batch_size * 90 / 100).max(100);
193 }
194 }
195 }
196
197 pub fn current_batch_size(&self) -> usize {
199 self.current_batch_size
200 }
201
202 pub fn get_performance_stats(&self) -> BatchPerformanceStats {
204 if self.performance_history.is_empty() {
205 return BatchPerformanceStats::default();
206 }
207
208 let total_rows: usize = self.performance_history.iter().map(|p| p.batch_size).sum();
209 let total_time: u64 = self.performance_history.iter().map(|p| p.processing_time_ms).sum();
210 let avg_rows_per_second = self.performance_history.iter()
211 .map(|p| p.rows_per_second)
212 .sum::<f64>() / self.performance_history.len() as f64;
213
214 let max_rows_per_second = self.performance_history.iter()
215 .map(|p| p.rows_per_second)
216 .fold(0.0, f64::max);
217
218 let min_rows_per_second = self.performance_history.iter()
219 .map(|p| p.rows_per_second)
220 .fold(f64::INFINITY, f64::min);
221
222 BatchPerformanceStats {
223 total_batches: self.performance_history.len(),
224 total_rows,
225 total_processing_time_ms: total_time,
226 avg_rows_per_second,
227 max_rows_per_second,
228 min_rows_per_second,
229 current_batch_size: self.current_batch_size,
230 }
231 }
232}
233
/// Summary of the recorded batch history, as returned by
/// `BatchProcessor::get_performance_stats`. All fields are zero/default when
/// no batches have been processed.
#[derive(Debug, Clone, Default)]
pub struct BatchPerformanceStats {
    /// Number of batches in the rolling history window.
    pub total_batches: usize,
    /// Sum of rows across those batches.
    pub total_rows: usize,
    /// Sum of per-batch processing times, in milliseconds.
    pub total_processing_time_ms: u64,
    /// Mean of the per-batch throughput samples.
    pub avg_rows_per_second: f64,
    /// Highest per-batch throughput observed.
    pub max_rows_per_second: f64,
    /// Lowest per-batch throughput observed.
    pub min_rows_per_second: f64,
    /// Batch size in effect when the stats were taken.
    pub current_batch_size: usize,
}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` rows shaped like `{ "id": <i>, "name": "User<i>" }`.
    fn create_test_data(count: usize) -> Vec<HashMap<String, Value>> {
        let mut rows = Vec::with_capacity(count);
        for i in 0..count {
            let mut row = HashMap::new();
            row.insert("id".to_string(), Value::Number(serde_json::Number::from(i)));
            row.insert("name".to_string(), Value::String(format!("User{}", i)));
            rows.push(row);
        }
        rows
    }

    #[test]
    fn test_fixed_size_batch_processor() {
        let mut processor = BatchProcessor::new(BatchStrategy::FixedSize(3));

        let mut observed_sizes = Vec::new();
        let outcome = processor.process_batches(create_test_data(10), |batch| {
            observed_sizes.push(batch.len());
            Ok(())
        });

        assert!(outcome.is_ok());
        // Ten rows in batches of three: three full batches plus one remainder.
        assert_eq!(observed_sizes, vec![3, 3, 3, 1]);
    }

    #[test]
    fn test_memory_based_batch_processor() {
        let mut processor = BatchProcessor::new(BatchStrategy::MemoryBased(1024));

        let mut batches_seen = 0;
        let outcome = processor.process_batches(create_test_data(5), |_batch| {
            batches_seen += 1;
            Ok(())
        });

        assert!(outcome.is_ok());
        assert!(batches_seen > 0);
    }

    #[test]
    fn test_adaptive_batch_processor() {
        let mut processor = BatchProcessor::new(BatchStrategy::Adaptive {
            initial_size: 2,
            max_size: 10,
        });

        let starting_size = processor.current_batch_size();

        // Sleep briefly so each batch records a non-zero processing time.
        let outcome = processor.process_batches(create_test_data(20), |_batch| {
            std::thread::sleep(std::time::Duration::from_millis(1));
            Ok(())
        });

        assert!(outcome.is_ok());
        assert_eq!(starting_size, 2);

        let stats = processor.get_performance_stats();
        assert!(stats.total_batches > 0);
        assert!(stats.total_rows > 0);
    }

    #[test]
    fn test_estimate_row_memory_size() {
        let processor = BatchProcessor::new(BatchStrategy::default());

        let mut row = HashMap::new();
        row.insert("id".to_string(), Value::Number(serde_json::Number::from(1)));
        row.insert("name".to_string(), Value::String("test".to_string()));

        // Any non-empty row must have a positive estimated footprint.
        assert!(processor.estimate_row_memory_size(&row) > 0);
    }

    #[test]
    fn test_performance_stats() {
        let mut processor = BatchProcessor::new(BatchStrategy::FixedSize(2));

        let outcome = processor.process_batches(create_test_data(5), |_batch| {
            std::thread::sleep(std::time::Duration::from_millis(10));
            Ok(())
        });

        assert!(outcome.is_ok());

        let stats = processor.get_performance_stats();
        assert!(stats.total_batches > 0);
        assert_eq!(stats.total_rows, 5);
        assert!(stats.avg_rows_per_second > 0.0);
    }
}