Skip to main content

oxigdal_ml/batch/
mod.rs

//! Advanced batch inference for efficient model serving
//!
//! This module provides sophisticated batch processing capabilities including
//! dynamic batching, batch scheduling, memory pooling, and parallel execution.
//!
//! # Dynamic Batching
//!
//! Dynamic batching optimizes model inference by grouping requests together:
//! - **Priority Queuing**: Requests can have different priority levels (Critical, High, Normal, Low)
//! - **Adaptive Batch Size**: Batch size adjusts based on request patterns and system load
//! - **Variable Input Handling**: Inputs of different sizes are padded to batch together
//! - **Latency-Throughput Balance**: Configurable timeout to balance latency vs throughput
//!
//! # Example
//!
//! ```ignore
//! use oxigdal_ml::batch::{DynamicBatchConfig, DynamicBatchProcessor, PriorityLevel};
//!
//! let config = DynamicBatchConfig::builder()
//!     .max_batch_size(32)
//!     .min_batch_size(4)
//!     .batch_timeout_ms(100)
//!     .enable_adaptive_sizing(true)
//!     .build();
//!
//! let processor = DynamicBatchProcessor::new(model, config);
//!
//! // Submit a high-priority request
//! let result = processor.submit(input, PriorityLevel::High)?;
//! ```
31
32mod dynamic;
33#[cfg(test)]
34mod tests;
35
36pub use dynamic::{
37    DynamicBatchConfig, DynamicBatchConfigBuilder, DynamicBatchProcessor, DynamicBatchStats,
38    PaddingStrategy, PriorityLevel,
39};
40
41use crate::error::{InferenceError, MlError, Result};
42use crate::models::Model;
43use indicatif::{ProgressBar, ProgressStyle};
44use oxigdal_core::buffer::RasterBuffer;
45use std::collections::VecDeque;
46use std::sync::{Arc, Mutex};
47use std::time::{Duration, Instant};
48use sysinfo::System;
49use tracing::{debug, info, warn};
50
/// Settings consumed by the batch inference types in this module
#[derive(Clone, Debug)]
pub struct BatchConfig {
    /// Upper bound on the number of requests grouped into one batch
    pub max_batch_size: usize,
    /// How long batch formation may wait, in milliseconds
    pub batch_timeout_ms: u64,
    /// Whether dynamic batching is enabled
    pub dynamic_batching: bool,
    /// Number of batches processed in parallel
    pub parallel_batches: usize,
    /// Whether memory pooling is enabled
    pub memory_pooling: bool,
}

impl Default for BatchConfig {
    /// Returns the stock configuration: batches of up to 32 requests, a
    /// 100 ms formation timeout, 4 parallel batches, and both dynamic
    /// batching and memory pooling switched on.
    fn default() -> Self {
        const DEFAULT_MAX_BATCH_SIZE: usize = 32;
        const DEFAULT_TIMEOUT_MS: u64 = 100;
        const DEFAULT_PARALLEL_BATCHES: usize = 4;

        BatchConfig {
            max_batch_size: DEFAULT_MAX_BATCH_SIZE,
            batch_timeout_ms: DEFAULT_TIMEOUT_MS,
            dynamic_batching: true,
            parallel_batches: DEFAULT_PARALLEL_BATCHES,
            memory_pooling: true,
        }
    }
}
77
78impl BatchConfig {
79    /// Creates a new batch configuration builder
80    #[must_use]
81    pub fn builder() -> BatchConfigBuilder {
82        BatchConfigBuilder::default()
83    }
84
85    /// Auto-tunes batch size based on available memory
86    ///
87    /// # Arguments
88    /// * `sample_size_bytes` - Estimated memory per sample in bytes
89    /// * `memory_fraction` - Fraction of available memory to use (0.0-1.0)
90    ///
91    /// # Returns
92    /// Recommended batch size based on system memory
93    #[must_use]
94    pub fn auto_tune_batch_size(sample_size_bytes: usize, memory_fraction: f32) -> usize {
95        let memory_fraction = memory_fraction.clamp(0.1, 0.9);
96
97        let mut system = System::new_all();
98        system.refresh_all();
99
100        let available_memory = system.available_memory() as usize;
101        let usable_memory = (available_memory as f32 * memory_fraction) as usize;
102
103        let batch_size = usable_memory
104            .checked_div(sample_size_bytes)
105            .map(|v| v.clamp(1, 256))
106            .unwrap_or(32);
107
108        info!(
109            "Auto-tuned batch size: {} (available memory: {} MB, sample size: {} MB)",
110            batch_size,
111            available_memory / (1024 * 1024),
112            sample_size_bytes / (1024 * 1024)
113        );
114
115        batch_size
116    }
117}
118
119/// Builder for batch configuration
120#[derive(Debug, Default)]
121pub struct BatchConfigBuilder {
122    max_batch_size: Option<usize>,
123    batch_timeout_ms: Option<u64>,
124    dynamic_batching: Option<bool>,
125    parallel_batches: Option<usize>,
126    memory_pooling: Option<bool>,
127}
128
129impl BatchConfigBuilder {
130    /// Sets the maximum batch size
131    #[must_use]
132    pub fn max_batch_size(mut self, size: usize) -> Self {
133        self.max_batch_size = Some(size);
134        self
135    }
136
137    /// Sets the batch timeout
138    #[must_use]
139    pub fn batch_timeout_ms(mut self, ms: u64) -> Self {
140        self.batch_timeout_ms = Some(ms);
141        self
142    }
143
144    /// Enables dynamic batching
145    #[must_use]
146    pub fn dynamic_batching(mut self, enable: bool) -> Self {
147        self.dynamic_batching = Some(enable);
148        self
149    }
150
151    /// Sets the number of parallel batches
152    #[must_use]
153    pub fn parallel_batches(mut self, count: usize) -> Self {
154        self.parallel_batches = Some(count);
155        self
156    }
157
158    /// Enables memory pooling
159    #[must_use]
160    pub fn memory_pooling(mut self, enable: bool) -> Self {
161        self.memory_pooling = Some(enable);
162        self
163    }
164
165    /// Builds the configuration
166    #[must_use]
167    pub fn build(self) -> BatchConfig {
168        BatchConfig {
169            max_batch_size: self.max_batch_size.unwrap_or(32),
170            batch_timeout_ms: self.batch_timeout_ms.unwrap_or(100),
171            dynamic_batching: self.dynamic_batching.unwrap_or(true),
172            parallel_batches: self.parallel_batches.unwrap_or(4),
173            memory_pooling: self.memory_pooling.unwrap_or(true),
174        }
175    }
176}
177
178/// Batch processor for efficient model serving
179pub struct BatchProcessor<M: Model> {
180    model: Arc<Mutex<M>>,
181    config: BatchConfig,
182    queue: Arc<Mutex<VecDeque<BatchRequest>>>,
183    stats: Arc<Mutex<BatchStats>>,
184}
185
186/// A single batch request
187struct BatchRequest {
188    input: RasterBuffer,
189    timestamp: Instant,
190}
191
192impl BatchRequest {
193    fn new(input: RasterBuffer) -> Self {
194        Self {
195            input,
196            timestamp: Instant::now(),
197        }
198    }
199
200    fn age(&self) -> Duration {
201        self.timestamp.elapsed()
202    }
203}
204
205impl<M: Model> BatchProcessor<M> {
206    /// Creates a new batch processor
207    #[must_use]
208    pub fn new(model: M, config: BatchConfig) -> Self {
209        info!(
210            "Creating batch processor with max_batch_size={}, timeout={}ms",
211            config.max_batch_size, config.batch_timeout_ms
212        );
213
214        Self {
215            model: Arc::new(Mutex::new(model)),
216            config,
217            queue: Arc::new(Mutex::new(VecDeque::new())),
218            stats: Arc::new(Mutex::new(BatchStats::default())),
219        }
220    }
221
222    /// Submits a request for batch inference
223    ///
224    /// When dynamic batching is enabled, requests are collected and processed
225    /// together to improve throughput. For single requests, use `DynamicBatchProcessor`
226    /// for more advanced batching with priority queuing and adaptive sizing.
227    ///
228    /// # Errors
229    /// Returns an error if inference fails
230    pub fn infer(&self, input: RasterBuffer) -> Result<RasterBuffer> {
231        // Capture start time before creating request for statistics
232        let start_time = Instant::now();
233        let request = BatchRequest::new(input);
234
235        // For synchronous single-request processing, run immediately
236        // For true dynamic batching with priority queuing and adaptive sizing,
237        // use DynamicBatchProcessor instead
238        let result = if self.config.dynamic_batching {
239            // Add to queue and check if we should form a batch
240            let mut queue = self
241                .queue
242                .lock()
243                .map_err(|e| MlError::InvalidConfig(format!("Failed to lock queue: {}", e)))?;
244            queue.push_back(request);
245
246            // Check if we should form a batch now
247            let timeout = Duration::from_millis(self.config.batch_timeout_ms);
248            let should_batch = queue.len() >= self.config.max_batch_size
249                || queue.front().map(|r| r.age() >= timeout).unwrap_or(false);
250
251            if should_batch {
252                // Form and process batch
253                let batch_size = queue.len().min(self.config.max_batch_size);
254                let batch: Vec<_> = queue.drain(..batch_size).map(|r| r.input).collect();
255                drop(queue); // Release lock before inference
256
257                let results = {
258                    let mut model = self.model.lock().map_err(|e| {
259                        MlError::InvalidConfig(format!("Failed to lock model: {}", e))
260                    })?;
261                    model.predict_batch(&batch)?
262                };
263
264                // Return the first result (our request)
265                results.into_iter().next().ok_or_else(|| {
266                    MlError::Inference(InferenceError::Failed {
267                        reason: "No results returned from batch".to_string(),
268                    })
269                })?
270            } else {
271                // Not enough requests yet - process just this one
272                let our_request = queue.pop_back().ok_or_else(|| {
273                    MlError::Inference(InferenceError::Failed {
274                        reason: "Request disappeared from queue".to_string(),
275                    })
276                })?;
277                drop(queue); // Release lock before inference
278
279                let mut model = self
280                    .model
281                    .lock()
282                    .map_err(|e| MlError::InvalidConfig(format!("Failed to lock model: {}", e)))?;
283                model.predict(&our_request.input)?
284            }
285        } else {
286            // Dynamic batching disabled - process immediately
287            let mut model = self
288                .model
289                .lock()
290                .map_err(|e| MlError::InvalidConfig(format!("Failed to lock model: {}", e)))?;
291            model.predict(&request.input)?
292        };
293
294        // Update statistics
295        if let Ok(mut stats) = self.stats.lock() {
296            stats.total_requests += 1;
297            stats.total_latency_ms += start_time.elapsed().as_millis() as u64;
298        }
299
300        Ok(result)
301    }
302
303    /// Processes a batch of inputs
304    ///
305    /// # Errors
306    /// Returns an error if inference fails
307    pub fn infer_batch(&self, inputs: Vec<RasterBuffer>) -> Result<Vec<RasterBuffer>> {
308        self.infer_batch_with_progress(inputs, false)
309    }
310
311    /// Processes a batch of inputs with optional progress tracking
312    ///
313    /// # Arguments
314    /// * `inputs` - Input raster buffers to process
315    /// * `show_progress` - Whether to display a progress bar
316    ///
317    /// # Errors
318    /// Returns an error if inference fails
319    pub fn infer_batch_with_progress(
320        &self,
321        inputs: Vec<RasterBuffer>,
322        show_progress: bool,
323    ) -> Result<Vec<RasterBuffer>> {
324        let batch_size = inputs.len();
325        debug!("Processing batch of size {}", batch_size);
326
327        let start = Instant::now();
328
329        // Create progress bar if requested
330        let progress = if show_progress {
331            let pb = ProgressBar::new(batch_size as u64);
332            pb.set_style(
333                ProgressStyle::default_bar()
334                    .template(
335                        "[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} ({per_sec}) {msg}",
336                    )
337                    .map_err(|e| crate::error::MlError::InvalidConfig(e.to_string()))?,
338            );
339            Some(pb)
340        } else {
341            None
342        };
343
344        // Use parallel processing if configured
345        let results = if self.config.parallel_batches > 1 && batch_size > 1 {
346            self.parallel_batch_inference_with_progress(inputs, progress.as_ref())?
347        } else {
348            let mut model = self.model.lock().map_err(|e| {
349                crate::error::MlError::InvalidConfig(format!("Failed to lock model: {}", e))
350            })?;
351            model.predict_batch(&inputs)?
352        };
353
354        if let Some(pb) = progress {
355            pb.finish_with_message("Batch inference complete");
356        }
357
358        // Update statistics
359        if let Ok(mut stats) = self.stats.lock() {
360            stats.total_requests += batch_size;
361            stats.total_batches += 1;
362            stats.total_latency_ms += start.elapsed().as_millis() as u64;
363
364            if batch_size > stats.max_batch_size {
365                stats.max_batch_size = batch_size;
366            }
367        }
368
369        Ok(results)
370    }
371
372    /// Performs parallel batch inference with progress tracking
373    fn parallel_batch_inference_with_progress(
374        &self,
375        inputs: Vec<RasterBuffer>,
376        progress: Option<&ProgressBar>,
377    ) -> Result<Vec<RasterBuffer>> {
378        use rayon::prelude::*;
379
380        let chunk_size =
381            (inputs.len() + self.config.parallel_batches - 1) / self.config.parallel_batches;
382
383        debug!(
384            "Splitting batch into {} chunks of ~{} items",
385            self.config.parallel_batches, chunk_size
386        );
387
388        let results: Result<Vec<_>> = inputs
389            .par_chunks(chunk_size)
390            .map(|chunk| {
391                let chunk_results: Result<Vec<_>> = chunk
392                    .iter()
393                    .map(|input| {
394                        let result = {
395                            let mut model = self.model.lock().map_err(|e| {
396                                crate::error::MlError::InvalidConfig(format!(
397                                    "Failed to lock model: {}",
398                                    e
399                                ))
400                            })?;
401                            model.predict(input)
402                        };
403                        if let Some(pb) = progress {
404                            pb.inc(1);
405                        }
406                        result
407                    })
408                    .collect();
409                chunk_results
410            })
411            .collect();
412
413        results.map(|chunks| chunks.into_iter().flatten().collect())
414    }
415
416    /// Returns the batch statistics
417    #[must_use]
418    pub fn stats(&self) -> BatchStats {
419        self.stats.lock().map(|s| s.clone()).unwrap_or_default()
420    }
421
422    /// Resets the statistics
423    pub fn reset_stats(&self) {
424        if let Ok(mut stats) = self.stats.lock() {
425            *stats = BatchStats::default();
426        }
427    }
428}
429
/// Running counters describing the work a batch processor has done
#[derive(Clone, Debug, Default)]
pub struct BatchStats {
    /// Number of requests processed so far
    pub total_requests: usize,
    /// Number of batches processed so far
    pub total_batches: usize,
    /// Largest batch size seen so far
    pub max_batch_size: usize,
    /// Accumulated latency in milliseconds
    pub total_latency_ms: u64,
}

impl BatchStats {
    /// Mean number of requests per batch, or `0.0` when no batch was recorded
    #[must_use]
    pub fn avg_batch_size(&self) -> f32 {
        match self.total_batches {
            0 => 0.0,
            batches => self.total_requests as f32 / batches as f32,
        }
    }

    /// Mean latency per request in milliseconds, or `0.0` when no request was
    /// recorded
    #[must_use]
    pub fn avg_latency_ms(&self) -> f32 {
        match self.total_requests {
            0 => 0.0,
            requests => self.total_latency_ms as f32 / requests as f32,
        }
    }

    /// Requests per second of accumulated latency, or `0.0` when no latency
    /// was recorded
    #[must_use]
    pub fn throughput(&self) -> f32 {
        if self.total_latency_ms == 0 {
            return 0.0;
        }

        (self.total_requests as f32 * 1000.0) / self.total_latency_ms as f32
    }
}
474
475/// Dynamic batch scheduler
476pub struct BatchScheduler {
477    config: BatchConfig,
478    pending: VecDeque<BatchRequest>,
479    last_batch: Instant,
480}
481
482impl BatchScheduler {
483    /// Creates a new batch scheduler
484    #[must_use]
485    pub fn new(config: BatchConfig) -> Self {
486        Self {
487            config,
488            pending: VecDeque::new(),
489            last_batch: Instant::now(),
490        }
491    }
492
493    /// Adds a request to the pending queue
494    pub fn add_request(&mut self, input: RasterBuffer) {
495        self.pending.push_back(BatchRequest::new(input));
496    }
497
498    /// Checks if a batch should be formed
499    #[must_use]
500    pub fn should_form_batch(&self) -> bool {
501        // Form batch if max size reached
502        if self.pending.len() >= self.config.max_batch_size {
503            return true;
504        }
505
506        // Form batch if timeout elapsed and queue not empty
507        if !self.pending.is_empty() {
508            let timeout = Duration::from_millis(self.config.batch_timeout_ms);
509            if self.last_batch.elapsed() >= timeout {
510                return true;
511            }
512        }
513
514        false
515    }
516
517    /// Forms a batch from pending requests
518    #[must_use]
519    pub fn form_batch(&mut self) -> Vec<RasterBuffer> {
520        let batch_size = self.pending.len().min(self.config.max_batch_size);
521        let batch: Vec<_> = self
522            .pending
523            .drain(..batch_size)
524            .map(|req| {
525                let age = req.age();
526                if age.as_millis() > 500 {
527                    warn!("Request aged {}ms before batching", age.as_millis());
528                }
529                req.input
530            })
531            .collect();
532
533        self.last_batch = Instant::now();
534        batch
535    }
536
537    /// Returns the number of pending requests
538    #[must_use]
539    pub fn pending_count(&self) -> usize {
540        self.pending.len()
541    }
542}