1use crate::error::{IoError, IoResult};
7use std::collections::VecDeque;
8use std::time::{Duration, Instant};
9
/// Tuning knobs that decide when an accumulated batch is flushed.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Flush as soon as this many samples are buffered.
    pub max_size: usize,
    /// Flush once this long has elapsed since the last flush, provided at
    /// least `min_size` samples are buffered.
    pub max_wait_time: Duration,
    /// Minimum number of buffered samples required for a time-based flush.
    pub min_size: usize,
}
20
21impl Default for BatchConfig {
22 fn default() -> Self {
23 Self {
24 max_size: 1024,
25 max_wait_time: Duration::from_millis(100),
26 min_size: 1,
27 }
28 }
29}
30
31impl BatchConfig {
32 pub fn low_latency() -> Self {
34 Self {
35 max_size: 256,
36 max_wait_time: Duration::from_millis(10),
37 min_size: 1,
38 }
39 }
40
41 pub fn high_throughput() -> Self {
43 Self {
44 max_size: 4096,
45 max_wait_time: Duration::from_millis(500),
46 min_size: 512,
47 }
48 }
49}
50
/// Buffers samples until a size or time threshold (per `BatchConfig`) says
/// the pending batch should be flushed.
pub struct BatchAccumulator {
    // Thresholds that drive `should_flush`.
    config: BatchConfig,
    // Pending samples, in arrival order.
    buffer: VecDeque<f32>,
    // Set at construction, `flush`, and `clear`; anchors the time-based
    // flush condition.
    last_flush: Instant,
}
57
58impl BatchAccumulator {
59 pub fn new(config: BatchConfig) -> Self {
61 let capacity = config.max_size;
62 Self {
63 config,
64 buffer: VecDeque::with_capacity(capacity),
65 last_flush: Instant::now(),
66 }
67 }
68
69 pub fn push(&mut self, sample: f32) {
71 self.buffer.push_back(sample);
72 }
73
74 pub fn push_slice(&mut self, samples: &[f32]) {
76 self.buffer.extend(samples.iter());
77 }
78
79 pub fn should_flush(&self) -> bool {
81 if self.buffer.len() >= self.config.max_size {
83 return true;
84 }
85
86 if self.buffer.len() >= self.config.min_size
88 && self.last_flush.elapsed() >= self.config.max_wait_time
89 {
90 return true;
91 }
92
93 false
94 }
95
96 pub fn flush(&mut self) -> Vec<f32> {
98 let samples: Vec<f32> = self.buffer.drain(..).collect();
99 self.last_flush = Instant::now();
100 samples
101 }
102
103 pub fn len(&self) -> usize {
105 self.buffer.len()
106 }
107
108 pub fn is_empty(&self) -> bool {
110 self.buffer.is_empty()
111 }
112
113 pub fn clear(&mut self) {
115 self.buffer.clear();
116 self.last_flush = Instant::now();
117 }
118}
119
/// Couples a `BatchAccumulator` with a callback that is invoked on every
/// flushed batch of samples.
pub struct BatchProcessor<F>
where
    F: FnMut(&[f32]) -> IoResult<Vec<f32>>,
{
    // Gathers samples until the configured flush condition is met.
    accumulator: BatchAccumulator,
    // Invoked with each drained batch; its output is returned to the caller.
    processor: F,
}
128
129impl<F> BatchProcessor<F>
130where
131 F: FnMut(&[f32]) -> IoResult<Vec<f32>>,
132{
133 pub fn new(config: BatchConfig, processor: F) -> Self {
135 Self {
136 accumulator: BatchAccumulator::new(config),
137 processor,
138 }
139 }
140
141 pub fn process(&mut self, sample: f32) -> IoResult<Option<Vec<f32>>> {
143 self.accumulator.push(sample);
144
145 if self.accumulator.should_flush() {
146 let batch = self.accumulator.flush();
147 let result = (self.processor)(&batch)?;
148 return Ok(Some(result));
149 }
150
151 Ok(None)
152 }
153
154 pub fn process_batch(&mut self, samples: &[f32]) -> IoResult<Vec<Vec<f32>>> {
156 let mut results = Vec::new();
157
158 for &sample in samples {
159 if let Some(result) = self.process(sample)? {
160 results.push(result);
161 }
162 }
163
164 Ok(results)
165 }
166
167 pub fn flush(&mut self) -> IoResult<Option<Vec<f32>>> {
169 if self.accumulator.is_empty() {
170 return Ok(None);
171 }
172
173 let batch = self.accumulator.flush();
174 let result = (self.processor)(&batch)?;
175 Ok(Some(result))
176 }
177
178 pub fn pending(&self) -> usize {
180 self.accumulator.len()
181 }
182}
183
/// Fans a sample slice out across spawned tokio tasks, one chunk per
/// configured thread, running a shared processing closure on each chunk.
pub struct ParallelBatchProcessor<F>
where
    F: Fn(&[f32]) -> IoResult<Vec<f32>> + Send + Sync + Clone + 'static,
{
    // Cloned into each spawned task (`Clone + 'static` make that possible).
    processor: F,
    // Upper bound on concurrent chunks; clamped to >= 1 in `new`.
    num_threads: usize,
}
192
193impl<F> ParallelBatchProcessor<F>
194where
195 F: Fn(&[f32]) -> IoResult<Vec<f32>> + Send + Sync + Clone + 'static,
196{
197 pub fn new(_config: BatchConfig, processor: F, num_threads: usize) -> Self {
199 Self {
200 processor,
201 num_threads: num_threads.max(1),
202 }
203 }
204
205 pub async fn process_parallel(&mut self, samples: &[f32]) -> IoResult<Vec<f32>> {
207 if samples.is_empty() {
208 return Ok(Vec::new());
209 }
210
211 let chunk_size = samples.len().div_ceil(self.num_threads);
212 let chunks: Vec<&[f32]> = samples.chunks(chunk_size).collect();
213
214 let mut tasks = Vec::new();
215 for chunk in chunks {
216 let chunk_vec = chunk.to_vec();
217 let processor = self.processor.clone();
218
219 let task = tokio::spawn(async move { processor(&chunk_vec) });
220
221 tasks.push(task);
222 }
223
224 let mut results = Vec::new();
226 for task in tasks {
227 let result = task
228 .await
229 .map_err(|e| IoError::SignalError(format!("Task failed: {}", e)))??;
230 results.extend(result);
231 }
232
233 Ok(results)
234 }
235}
236
/// Re-frames a sample stream into fixed-size windows where consecutive
/// windows share a configurable number of samples.
pub struct WindowedBatchProcessor {
    // Number of samples emitted per window.
    window_size: usize,
    // Samples shared between consecutive windows; enforced < window_size.
    overlap: usize,
    // Samples awaiting completion of the next window.
    buffer: VecDeque<f32>,
}
243
244impl WindowedBatchProcessor {
245 pub fn new(window_size: usize, overlap: usize) -> IoResult<Self> {
247 if overlap >= window_size {
248 return Err(IoError::InvalidConfig(
249 "Overlap must be less than window size".to_string(),
250 ));
251 }
252
253 Ok(Self {
254 window_size,
255 overlap,
256 buffer: VecDeque::with_capacity(window_size),
257 })
258 }
259
260 pub fn with_half_overlap(window_size: usize) -> IoResult<Self> {
262 Self::new(window_size, window_size / 2)
263 }
264
265 pub fn push(&mut self, sample: f32) -> Option<Vec<f32>> {
267 self.buffer.push_back(sample);
268
269 if self.buffer.len() >= self.window_size {
270 let window: Vec<f32> = self.buffer.iter().copied().collect();
271
272 let to_remove = self.window_size - self.overlap;
274 for _ in 0..to_remove {
275 self.buffer.pop_front();
276 }
277
278 return Some(window);
279 }
280
281 None
282 }
283
284 pub fn push_batch(&mut self, samples: &[f32]) -> Vec<Vec<f32>> {
286 let mut windows = Vec::new();
287
288 for &sample in samples {
289 if let Some(window) = self.push(sample) {
290 windows.push(window);
291 }
292 }
293
294 windows
295 }
296
297 pub fn buffered(&self) -> usize {
299 self.buffer.len()
300 }
301
302 pub fn clear(&mut self) {
304 self.buffer.clear();
305 }
306}
307
/// Running statistics over processed batches.
///
/// `min_batch_size` starts at `usize::MAX` so the first `update` records the
/// true minimum; it is only meaningful once `total_batches > 0`.
#[derive(Debug, Clone)]
pub struct BatchStats {
    pub total_batches: u64,
    pub total_samples: u64,
    pub avg_batch_size: f32,
    pub min_batch_size: usize,
    pub max_batch_size: usize,
}

/// Manual `Default` mirroring `BatchStats::new`: the previously derived impl
/// started `min_batch_size` at 0, which pinned the reported minimum to 0
/// forever for `default()`-constructed stats.
impl Default for BatchStats {
    fn default() -> Self {
        Self {
            total_batches: 0,
            total_samples: 0,
            avg_batch_size: 0.0,
            min_batch_size: usize::MAX,
            max_batch_size: 0,
        }
    }
}
317
318impl BatchStats {
319 pub fn new() -> Self {
321 Self {
322 total_batches: 0,
323 total_samples: 0,
324 avg_batch_size: 0.0,
325 min_batch_size: usize::MAX,
326 max_batch_size: 0,
327 }
328 }
329
330 pub fn update(&mut self, batch_size: usize) {
332 self.total_batches += 1;
333 self.total_samples += batch_size as u64;
334 self.min_batch_size = self.min_batch_size.min(batch_size);
335 self.max_batch_size = self.max_batch_size.max(batch_size);
336 self.avg_batch_size = self.total_samples as f32 / self.total_batches as f32;
337 }
338
339 pub fn reset(&mut self) {
341 *self = Self::new();
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_accumulator() {
        let cfg = BatchConfig {
            max_size: 5,
            max_wait_time: Duration::from_secs(1),
            min_size: 1,
        };
        let mut accumulator = BatchAccumulator::new(cfg);

        // Below max_size and within the wait window: no flush yet.
        for value in 0..4 {
            accumulator.push(value as f32);
            assert!(!accumulator.should_flush());
        }

        // Hitting max_size triggers a flush.
        accumulator.push(4.0);
        assert!(accumulator.should_flush());

        let drained = accumulator.flush();
        assert_eq!(drained, vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        assert!(accumulator.is_empty());
    }

    #[test]
    fn test_batch_processor() {
        let cfg = BatchConfig {
            max_size: 3,
            max_wait_time: Duration::from_secs(1),
            min_size: 1,
        };
        // Each flushed batch collapses into a single-element sum.
        let mut summing = BatchProcessor::new(cfg, |batch| Ok(vec![batch.iter().sum()]));

        assert!(summing.process(1.0).unwrap().is_none());
        assert!(summing.process(2.0).unwrap().is_none());

        // The third sample completes the batch of 3.
        let output = summing.process(3.0).unwrap();
        assert!(output.is_some());
        assert_eq!(output.unwrap(), vec![6.0]);
    }

    #[test]
    fn test_windowed_batch_processor() {
        let mut windower = WindowedBatchProcessor::new(3, 1).unwrap();

        assert!(windower.push(1.0).is_none());
        assert!(windower.push(2.0).is_none());

        // First full window of 3.
        let first = windower.push(3.0).unwrap();
        assert_eq!(first, vec![1.0, 2.0, 3.0]);

        // Overlap of 1 keeps the last sample; two more complete the next window.
        assert!(windower.push(4.0).is_none());
        let second = windower.push(5.0).unwrap();
        assert_eq!(second, vec![3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_batch_stats() {
        let mut stats = BatchStats::new();
        for size in [10, 20, 30] {
            stats.update(size);
        }

        assert_eq!(stats.total_batches, 3);
        assert_eq!(stats.total_samples, 60);
        assert_eq!(stats.min_batch_size, 10);
        assert_eq!(stats.max_batch_size, 30);
        assert_eq!(stats.avg_batch_size, 20.0); // 60 / 3 is exactly representable
    }

    #[tokio::test]
    async fn test_parallel_batch_processor() {
        let cfg = BatchConfig::default();
        let doubler =
            |batch: &[f32]| -> IoResult<Vec<f32>> { Ok(batch.iter().map(|x| x * 2.0).collect()) };

        let mut parallel = ParallelBatchProcessor::new(cfg, doubler, 4);

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let doubled = parallel.process_parallel(&input).await.unwrap();
        assert_eq!(doubled, vec![2.0, 4.0, 6.0, 8.0, 10.0]);
    }
}