1mod dynamic;
33#[cfg(test)]
34mod tests;
35
36pub use dynamic::{
37 DynamicBatchConfig, DynamicBatchConfigBuilder, DynamicBatchProcessor, DynamicBatchStats,
38 PaddingStrategy, PriorityLevel,
39};
40
41use crate::error::{InferenceError, MlError, Result};
42use crate::models::Model;
43use indicatif::{ProgressBar, ProgressStyle};
44use oxigdal_core::buffer::RasterBuffer;
45use std::collections::VecDeque;
46use std::sync::{Arc, Mutex};
47use std::time::{Duration, Instant};
48use sysinfo::System;
49use tracing::{debug, info, warn};
50
/// Configuration knobs for batch inference processing.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of inputs grouped into a single batch (default: 32).
    pub max_batch_size: usize,
    /// Maximum time a queued request may wait before a (possibly partial)
    /// batch is formed, in milliseconds (default: 100).
    pub batch_timeout_ms: u64,
    /// When `true`, single `infer` calls are routed through a shared queue so
    /// that concurrent requests can be grouped into batches (default: true).
    pub dynamic_batching: bool,
    /// Number of chunks a large batch is split into for parallel processing
    /// (default: 4).
    pub parallel_batches: usize,
    // NOTE(review): not referenced anywhere in this module — presumably
    // consumed by code elsewhere; confirm before removing (default: true).
    pub memory_pooling: bool,
}
65
impl Default for BatchConfig {
    /// Default configuration: batch size 32, 100 ms timeout, dynamic
    /// batching and memory pooling enabled, 4 parallel batches.
    ///
    /// Keep these values in sync with the fallbacks used by
    /// `BatchConfigBuilder::build`.
    fn default() -> Self {
        Self {
            max_batch_size: 32,
            batch_timeout_ms: 100,
            dynamic_batching: true,
            parallel_batches: 4,
            memory_pooling: true,
        }
    }
}
77
78impl BatchConfig {
79 #[must_use]
81 pub fn builder() -> BatchConfigBuilder {
82 BatchConfigBuilder::default()
83 }
84
85 #[must_use]
94 pub fn auto_tune_batch_size(sample_size_bytes: usize, memory_fraction: f32) -> usize {
95 let memory_fraction = memory_fraction.clamp(0.1, 0.9);
96
97 let mut system = System::new_all();
98 system.refresh_all();
99
100 let available_memory = system.available_memory() as usize;
101 let usable_memory = (available_memory as f32 * memory_fraction) as usize;
102
103 let batch_size = usable_memory
104 .checked_div(sample_size_bytes)
105 .map(|v| v.clamp(1, 256))
106 .unwrap_or(32);
107
108 info!(
109 "Auto-tuned batch size: {} (available memory: {} MB, sample size: {} MB)",
110 batch_size,
111 available_memory / (1024 * 1024),
112 sample_size_bytes / (1024 * 1024)
113 );
114
115 batch_size
116 }
117}
118
/// Builder for [`BatchConfig`]; any field left unset falls back to the
/// corresponding default in `build()`.
#[derive(Debug, Default)]
pub struct BatchConfigBuilder {
    max_batch_size: Option<usize>,
    batch_timeout_ms: Option<u64>,
    dynamic_batching: Option<bool>,
    parallel_batches: Option<usize>,
    memory_pooling: Option<bool>,
}
128
129impl BatchConfigBuilder {
130 #[must_use]
132 pub fn max_batch_size(mut self, size: usize) -> Self {
133 self.max_batch_size = Some(size);
134 self
135 }
136
137 #[must_use]
139 pub fn batch_timeout_ms(mut self, ms: u64) -> Self {
140 self.batch_timeout_ms = Some(ms);
141 self
142 }
143
144 #[must_use]
146 pub fn dynamic_batching(mut self, enable: bool) -> Self {
147 self.dynamic_batching = Some(enable);
148 self
149 }
150
151 #[must_use]
153 pub fn parallel_batches(mut self, count: usize) -> Self {
154 self.parallel_batches = Some(count);
155 self
156 }
157
158 #[must_use]
160 pub fn memory_pooling(mut self, enable: bool) -> Self {
161 self.memory_pooling = Some(enable);
162 self
163 }
164
165 #[must_use]
167 pub fn build(self) -> BatchConfig {
168 BatchConfig {
169 max_batch_size: self.max_batch_size.unwrap_or(32),
170 batch_timeout_ms: self.batch_timeout_ms.unwrap_or(100),
171 dynamic_batching: self.dynamic_batching.unwrap_or(true),
172 parallel_batches: self.parallel_batches.unwrap_or(4),
173 memory_pooling: self.memory_pooling.unwrap_or(true),
174 }
175 }
176}
177
/// Thread-safe wrapper that runs a [`Model`] with request batching.
///
/// Model, request queue, and statistics each sit behind their own
/// `Arc<Mutex<_>>` so the processor can be shared across threads.
pub struct BatchProcessor<M: Model> {
    model: Arc<Mutex<M>>,
    config: BatchConfig,
    // Pending single-inference requests awaiting dynamic batching.
    queue: Arc<Mutex<VecDeque<BatchRequest>>>,
    stats: Arc<Mutex<BatchStats>>,
}
185
/// A queued inference input together with its enqueue time.
struct BatchRequest {
    input: RasterBuffer,
    // When the request was created; drives timeout-based batching decisions.
    timestamp: Instant,
}
191
192impl BatchRequest {
193 fn new(input: RasterBuffer) -> Self {
194 Self {
195 input,
196 timestamp: Instant::now(),
197 }
198 }
199
200 fn age(&self) -> Duration {
201 self.timestamp.elapsed()
202 }
203}
204
205impl<M: Model> BatchProcessor<M> {
206 #[must_use]
208 pub fn new(model: M, config: BatchConfig) -> Self {
209 info!(
210 "Creating batch processor with max_batch_size={}, timeout={}ms",
211 config.max_batch_size, config.batch_timeout_ms
212 );
213
214 Self {
215 model: Arc::new(Mutex::new(model)),
216 config,
217 queue: Arc::new(Mutex::new(VecDeque::new())),
218 stats: Arc::new(Mutex::new(BatchStats::default())),
219 }
220 }
221
222 pub fn infer(&self, input: RasterBuffer) -> Result<RasterBuffer> {
231 let start_time = Instant::now();
233 let request = BatchRequest::new(input);
234
235 let result = if self.config.dynamic_batching {
239 let mut queue = self
241 .queue
242 .lock()
243 .map_err(|e| MlError::InvalidConfig(format!("Failed to lock queue: {}", e)))?;
244 queue.push_back(request);
245
246 let timeout = Duration::from_millis(self.config.batch_timeout_ms);
248 let should_batch = queue.len() >= self.config.max_batch_size
249 || queue.front().map(|r| r.age() >= timeout).unwrap_or(false);
250
251 if should_batch {
252 let batch_size = queue.len().min(self.config.max_batch_size);
254 let batch: Vec<_> = queue.drain(..batch_size).map(|r| r.input).collect();
255 drop(queue); let results = {
258 let mut model = self.model.lock().map_err(|e| {
259 MlError::InvalidConfig(format!("Failed to lock model: {}", e))
260 })?;
261 model.predict_batch(&batch)?
262 };
263
264 results.into_iter().next().ok_or_else(|| {
266 MlError::Inference(InferenceError::Failed {
267 reason: "No results returned from batch".to_string(),
268 })
269 })?
270 } else {
271 let our_request = queue.pop_back().ok_or_else(|| {
273 MlError::Inference(InferenceError::Failed {
274 reason: "Request disappeared from queue".to_string(),
275 })
276 })?;
277 drop(queue); let mut model = self
280 .model
281 .lock()
282 .map_err(|e| MlError::InvalidConfig(format!("Failed to lock model: {}", e)))?;
283 model.predict(&our_request.input)?
284 }
285 } else {
286 let mut model = self
288 .model
289 .lock()
290 .map_err(|e| MlError::InvalidConfig(format!("Failed to lock model: {}", e)))?;
291 model.predict(&request.input)?
292 };
293
294 if let Ok(mut stats) = self.stats.lock() {
296 stats.total_requests += 1;
297 stats.total_latency_ms += start_time.elapsed().as_millis() as u64;
298 }
299
300 Ok(result)
301 }
302
303 pub fn infer_batch(&self, inputs: Vec<RasterBuffer>) -> Result<Vec<RasterBuffer>> {
308 self.infer_batch_with_progress(inputs, false)
309 }
310
311 pub fn infer_batch_with_progress(
320 &self,
321 inputs: Vec<RasterBuffer>,
322 show_progress: bool,
323 ) -> Result<Vec<RasterBuffer>> {
324 let batch_size = inputs.len();
325 debug!("Processing batch of size {}", batch_size);
326
327 let start = Instant::now();
328
329 let progress = if show_progress {
331 let pb = ProgressBar::new(batch_size as u64);
332 pb.set_style(
333 ProgressStyle::default_bar()
334 .template(
335 "[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} ({per_sec}) {msg}",
336 )
337 .map_err(|e| crate::error::MlError::InvalidConfig(e.to_string()))?,
338 );
339 Some(pb)
340 } else {
341 None
342 };
343
344 let results = if self.config.parallel_batches > 1 && batch_size > 1 {
346 self.parallel_batch_inference_with_progress(inputs, progress.as_ref())?
347 } else {
348 let mut model = self.model.lock().map_err(|e| {
349 crate::error::MlError::InvalidConfig(format!("Failed to lock model: {}", e))
350 })?;
351 model.predict_batch(&inputs)?
352 };
353
354 if let Some(pb) = progress {
355 pb.finish_with_message("Batch inference complete");
356 }
357
358 if let Ok(mut stats) = self.stats.lock() {
360 stats.total_requests += batch_size;
361 stats.total_batches += 1;
362 stats.total_latency_ms += start.elapsed().as_millis() as u64;
363
364 if batch_size > stats.max_batch_size {
365 stats.max_batch_size = batch_size;
366 }
367 }
368
369 Ok(results)
370 }
371
372 fn parallel_batch_inference_with_progress(
374 &self,
375 inputs: Vec<RasterBuffer>,
376 progress: Option<&ProgressBar>,
377 ) -> Result<Vec<RasterBuffer>> {
378 use rayon::prelude::*;
379
380 let chunk_size =
381 (inputs.len() + self.config.parallel_batches - 1) / self.config.parallel_batches;
382
383 debug!(
384 "Splitting batch into {} chunks of ~{} items",
385 self.config.parallel_batches, chunk_size
386 );
387
388 let results: Result<Vec<_>> = inputs
389 .par_chunks(chunk_size)
390 .map(|chunk| {
391 let chunk_results: Result<Vec<_>> = chunk
392 .iter()
393 .map(|input| {
394 let result = {
395 let mut model = self.model.lock().map_err(|e| {
396 crate::error::MlError::InvalidConfig(format!(
397 "Failed to lock model: {}",
398 e
399 ))
400 })?;
401 model.predict(input)
402 };
403 if let Some(pb) = progress {
404 pb.inc(1);
405 }
406 result
407 })
408 .collect();
409 chunk_results
410 })
411 .collect();
412
413 results.map(|chunks| chunks.into_iter().flatten().collect())
414 }
415
416 #[must_use]
418 pub fn stats(&self) -> BatchStats {
419 self.stats.lock().map(|s| s.clone()).unwrap_or_default()
420 }
421
422 pub fn reset_stats(&self) {
424 if let Ok(mut stats) = self.stats.lock() {
425 *stats = BatchStats::default();
426 }
427 }
428}
429
/// Counters accumulated across inference calls.
#[derive(Debug, Clone, Default)]
pub struct BatchStats {
    /// Total individual inputs processed (single and batched).
    pub total_requests: usize,
    /// Number of explicit batch calls; single `infer` calls do not increment
    /// this counter.
    pub total_batches: usize,
    /// Largest batch size observed so far.
    pub max_batch_size: usize,
    /// Sum of wall-clock latencies across calls, in milliseconds.
    pub total_latency_ms: u64,
}
442
443impl BatchStats {
444 #[must_use]
446 pub fn avg_batch_size(&self) -> f32 {
447 if self.total_batches > 0 {
448 self.total_requests as f32 / self.total_batches as f32
449 } else {
450 0.0
451 }
452 }
453
454 #[must_use]
456 pub fn avg_latency_ms(&self) -> f32 {
457 if self.total_requests > 0 {
458 self.total_latency_ms as f32 / self.total_requests as f32
459 } else {
460 0.0
461 }
462 }
463
464 #[must_use]
466 pub fn throughput(&self) -> f32 {
467 if self.total_latency_ms > 0 {
468 (self.total_requests as f32 * 1000.0) / self.total_latency_ms as f32
469 } else {
470 0.0
471 }
472 }
473}
474
/// Standalone request queue that decides when pending inputs should be
/// flushed into a batch, based on [`BatchConfig`] size/timeout limits.
pub struct BatchScheduler {
    config: BatchConfig,
    // Requests waiting to be grouped into the next batch, oldest at front.
    pending: VecDeque<BatchRequest>,
    // Instant at which the most recent batch was formed.
    last_batch: Instant,
}
481
482impl BatchScheduler {
483 #[must_use]
485 pub fn new(config: BatchConfig) -> Self {
486 Self {
487 config,
488 pending: VecDeque::new(),
489 last_batch: Instant::now(),
490 }
491 }
492
493 pub fn add_request(&mut self, input: RasterBuffer) {
495 self.pending.push_back(BatchRequest::new(input));
496 }
497
498 #[must_use]
500 pub fn should_form_batch(&self) -> bool {
501 if self.pending.len() >= self.config.max_batch_size {
503 return true;
504 }
505
506 if !self.pending.is_empty() {
508 let timeout = Duration::from_millis(self.config.batch_timeout_ms);
509 if self.last_batch.elapsed() >= timeout {
510 return true;
511 }
512 }
513
514 false
515 }
516
517 #[must_use]
519 pub fn form_batch(&mut self) -> Vec<RasterBuffer> {
520 let batch_size = self.pending.len().min(self.config.max_batch_size);
521 let batch: Vec<_> = self
522 .pending
523 .drain(..batch_size)
524 .map(|req| {
525 let age = req.age();
526 if age.as_millis() > 500 {
527 warn!("Request aged {}ms before batching", age.as_millis());
528 }
529 req.input
530 })
531 .collect();
532
533 self.last_batch = Instant::now();
534 batch
535 }
536
537 #[must_use]
539 pub fn pending_count(&self) -> usize {
540 self.pending.len()
541 }
542}