dataforge/filling/
batch.rs

//! Batch processing optimization module: batching strategies and a processor
//! that tracks per-batch throughput.
2
3use std::collections::HashMap;
4use std::time::Instant;
5use serde_json::Value;
6use crate::error::Result;
7
/// Strategy for splitting input data into processing batches.
#[derive(Debug, Clone, PartialEq)]
pub enum BatchStrategy {
    /// Fixed number of rows per batch.
    FixedSize(usize),
    /// Batch size derived from a per-batch memory budget (bytes).
    MemoryBased(usize),
    /// Adaptive batch size tuned from observed throughput, capped at `max_size`.
    Adaptive { initial_size: usize, max_size: usize },
}

impl Default for BatchStrategy {
    /// Defaults to fixed batches of 1000 rows.
    fn default() -> Self {
        Self::FixedSize(1000)
    }
}
24
/// Batch processor: splits row data into batches, feeds each batch to a
/// caller-supplied callback, and records per-batch timing so the adaptive
/// strategy can retune the batch size.
pub struct BatchProcessor {
    strategy: BatchStrategy,                    // how batch boundaries are chosen
    current_batch_size: usize,                  // effective size; mutated only by the adaptive strategy
    performance_history: Vec<BatchPerformance>, // recent per-batch timings (capped at 100 entries)
}
31
/// Timing record for one processed batch.
#[derive(Debug, Clone)]
struct BatchPerformance {
    batch_size: usize,       // number of rows in the batch
    processing_time_ms: u64, // wall-clock time spent in the processor callback
    rows_per_second: f64,    // derived throughput (0.0 when timing rounded down to 0 ms)
}
39
40impl BatchProcessor {
41    /// 创建新的批量处理器
42    pub fn new(strategy: BatchStrategy) -> Self {
43        let initial_size = match &strategy {
44            BatchStrategy::FixedSize(size) => *size,
45            BatchStrategy::MemoryBased(_) => 1000,
46            BatchStrategy::Adaptive { initial_size, .. } => *initial_size,
47        };
48
49        Self {
50            strategy,
51            current_batch_size: initial_size,
52            performance_history: Vec::new(),
53        }
54    }
55
56    /// 处理数据批次
57    pub fn process_batches<F>(&mut self, data: Vec<HashMap<String, Value>>, mut processor: F) -> Result<()>
58    where
59        F: FnMut(&[HashMap<String, Value>]) -> Result<()>,
60    {
61        if data.is_empty() {
62            return Ok(());
63        }
64
65        let batches = self.create_batches(&data)?;
66        
67        for batch in batches {
68            let start_time = Instant::now();
69            
70            processor(&batch)?;
71            
72            let processing_time = start_time.elapsed().as_millis() as u64;
73            self.record_performance(batch.len(), processing_time);
74            
75            // 如果使用自适应策略,调整批次大小
76            if matches!(self.strategy, BatchStrategy::Adaptive { .. }) {
77                self.adjust_batch_size();
78            }
79        }
80
81        Ok(())
82    }
83
84    /// 创建数据批次
85    fn create_batches(&self, data: &[HashMap<String, Value>]) -> Result<Vec<Vec<HashMap<String, Value>>>> {
86        let batch_size = match &self.strategy {
87            BatchStrategy::FixedSize(size) => *size,
88            BatchStrategy::MemoryBased(max_bytes) => {
89                self.calculate_memory_based_batch_size(data, *max_bytes)?
90            },
91            BatchStrategy::Adaptive { .. } => self.current_batch_size,
92        };
93
94        let mut batches = Vec::new();
95        let mut current_batch = Vec::new();
96
97        for row in data {
98            current_batch.push(row.clone());
99            
100            if current_batch.len() >= batch_size {
101                batches.push(current_batch);
102                current_batch = Vec::new();
103            }
104        }
105
106        // 添加最后一个批次(如果有剩余数据)
107        if !current_batch.is_empty() {
108            batches.push(current_batch);
109        }
110
111        Ok(batches)
112    }
113
114    /// 计算基于内存的批次大小
115    fn calculate_memory_based_batch_size(&self, data: &[HashMap<String, Value>], max_bytes: usize) -> Result<usize> {
116        if data.is_empty() {
117            return Ok(1000);
118        }
119
120        // 估算单行数据的内存使用量
121        let sample_row = &data[0];
122        let estimated_row_size = self.estimate_row_memory_size(sample_row);
123        
124        if estimated_row_size == 0 {
125            return Ok(1000);
126        }
127
128        let batch_size = max_bytes / estimated_row_size;
129        Ok(batch_size.max(1).min(10000)) // 限制在1-10000之间
130    }
131
132    /// 估算行数据的内存使用量
133    fn estimate_row_memory_size(&self, row: &HashMap<String, Value>) -> usize {
134        let mut size = 0;
135        
136        for (key, value) in row {
137            size += key.len(); // 键的大小
138            size += match value {
139                Value::String(s) => s.len(),
140                Value::Number(_) => 8, // 假设数字占8字节
141                Value::Bool(_) => 1,
142                Value::Array(arr) => arr.len() * 50, // 粗略估算
143                Value::Object(obj) => obj.len() * 100, // 粗略估算
144                Value::Null => 0,
145            };
146        }
147        
148        size
149    }
150
151    /// 记录性能数据
152    fn record_performance(&mut self, batch_size: usize, processing_time_ms: u64) {
153        let rows_per_second = if processing_time_ms > 0 {
154            (batch_size as f64 * 1000.0) / processing_time_ms as f64
155        } else {
156            0.0
157        };
158
159        let performance = BatchPerformance {
160            batch_size,
161            processing_time_ms,
162            rows_per_second,
163        };
164
165        self.performance_history.push(performance);
166        
167        // 保持历史记录在合理范围内
168        if self.performance_history.len() > 100 {
169            self.performance_history.remove(0);
170        }
171    }
172
173    /// 调整批次大小(自适应策略)
174    fn adjust_batch_size(&mut self) {
175        if let BatchStrategy::Adaptive { initial_size: _, max_size } = &self.strategy {
176            if self.performance_history.len() < 3 {
177                return;
178            }
179
180            // 获取最近的性能数据
181            let recent_performances = &self.performance_history[self.performance_history.len() - 3..];
182            let _avg_performance = recent_performances.iter()
183                .map(|p| p.rows_per_second)
184                .sum::<f64>() / recent_performances.len() as f64;
185
186            // 如果性能在提升,增加批次大小
187            if recent_performances.windows(2).all(|w| w[1].rows_per_second >= w[0].rows_per_second) {
188                self.current_batch_size = (self.current_batch_size * 110 / 100).min(*max_size);
189            }
190            // 如果性能在下降,减少批次大小
191            else if recent_performances.windows(2).all(|w| w[1].rows_per_second < w[0].rows_per_second) {
192                self.current_batch_size = (self.current_batch_size * 90 / 100).max(100);
193            }
194        }
195    }
196
    /// Returns the batch size currently in effect (this can change between
    /// calls under the adaptive strategy).
    pub fn current_batch_size(&self) -> usize {
        self.current_batch_size
    }
201
202    /// 获取性能统计
203    pub fn get_performance_stats(&self) -> BatchPerformanceStats {
204        if self.performance_history.is_empty() {
205            return BatchPerformanceStats::default();
206        }
207
208        let total_rows: usize = self.performance_history.iter().map(|p| p.batch_size).sum();
209        let total_time: u64 = self.performance_history.iter().map(|p| p.processing_time_ms).sum();
210        let avg_rows_per_second = self.performance_history.iter()
211            .map(|p| p.rows_per_second)
212            .sum::<f64>() / self.performance_history.len() as f64;
213
214        let max_rows_per_second = self.performance_history.iter()
215            .map(|p| p.rows_per_second)
216            .fold(0.0, f64::max);
217
218        let min_rows_per_second = self.performance_history.iter()
219            .map(|p| p.rows_per_second)
220            .fold(f64::INFINITY, f64::min);
221
222        BatchPerformanceStats {
223            total_batches: self.performance_history.len(),
224            total_rows,
225            total_processing_time_ms: total_time,
226            avg_rows_per_second,
227            max_rows_per_second,
228            min_rows_per_second,
229            current_batch_size: self.current_batch_size,
230        }
231    }
232}
233
/// Aggregate statistics over all batches processed so far.
#[derive(Debug, Clone, Default)]
pub struct BatchPerformanceStats {
    pub total_batches: usize,          // number of batches processed
    pub total_rows: usize,             // total rows across all batches
    pub total_processing_time_ms: u64, // summed wall-clock processing time
    pub avg_rows_per_second: f64,      // mean of per-batch throughputs
    pub max_rows_per_second: f64,      // best per-batch throughput
    pub min_rows_per_second: f64,      // worst per-batch throughput
    pub current_batch_size: usize,     // processor's batch size at query time
}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` rows shaped like {"id": <n>, "name": "User<n>"}.
    fn create_test_data(count: usize) -> Vec<HashMap<String, Value>> {
        (0..count).map(|i| {
            let mut row = HashMap::new();
            row.insert("id".to_string(), Value::Number(serde_json::Number::from(i)));
            row.insert("name".to_string(), Value::String(format!("User{}", i)));
            row
        }).collect()
    }

    /// Fixed-size batching splits 10 rows into 3+3+3+1.
    #[test]
    fn test_fixed_size_batch_processor() {
        let mut processor = BatchProcessor::new(BatchStrategy::FixedSize(3));
        let data = create_test_data(10);

        let mut processed_batches = Vec::new();
        let result = processor.process_batches(data, |batch| {
            processed_batches.push(batch.len());
            Ok(())
        });

        assert!(result.is_ok());
        assert_eq!(processed_batches, vec![3, 3, 3, 1]); // 3 full batches + 1 remainder batch
    }

    /// Memory-based batching produces at least one batch for non-empty data.
    #[test]
    fn test_memory_based_batch_processor() {
        let mut processor = BatchProcessor::new(BatchStrategy::MemoryBased(1024));
        let data = create_test_data(5);

        let mut batch_count = 0;
        let result = processor.process_batches(data, |_batch| {
            batch_count += 1;
            Ok(())
        });

        assert!(result.is_ok());
        assert!(batch_count > 0);
    }

    /// Adaptive batching starts at `initial_size` and records statistics.
    #[test]
    fn test_adaptive_batch_processor() {
        let mut processor = BatchProcessor::new(BatchStrategy::Adaptive {
            initial_size: 2,
            max_size: 10,
        });
        let data = create_test_data(20);

        let initial_size = processor.current_batch_size();

        let result = processor.process_batches(data, |_batch| {
            // Simulate per-batch processing time.
            std::thread::sleep(std::time::Duration::from_millis(1));
            Ok(())
        });

        assert!(result.is_ok());
        assert_eq!(initial_size, 2);

        let stats = processor.get_performance_stats();
        assert!(stats.total_batches > 0);
        assert!(stats.total_rows > 0);
    }

    /// The per-row memory heuristic reports a positive size for non-empty rows.
    #[test]
    fn test_estimate_row_memory_size() {
        let processor = BatchProcessor::new(BatchStrategy::default());

        let mut row = HashMap::new();
        row.insert("id".to_string(), Value::Number(serde_json::Number::from(1)));
        row.insert("name".to_string(), Value::String("test".to_string()));

        let size = processor.estimate_row_memory_size(&row);
        assert!(size > 0);
    }

    /// Stats aggregation counts all rows and reports non-zero throughput.
    #[test]
    fn test_performance_stats() {
        let mut processor = BatchProcessor::new(BatchStrategy::FixedSize(2));
        let data = create_test_data(5);

        let result = processor.process_batches(data, |_batch| {
            std::thread::sleep(std::time::Duration::from_millis(10));
            Ok(())
        });

        assert!(result.is_ok());

        let stats = processor.get_performance_stats();
        assert!(stats.total_batches > 0);
        assert_eq!(stats.total_rows, 5);
        assert!(stats.avg_rows_per_second > 0.0);
    }
}