// oxigdal_ml/batch_predict.rs
1//! Adaptive batch prediction for geospatial ML inference
2//!
3//! This module provides:
4//! - [`PredictionRequest`] / [`PredictionResult`] — typed request/response pairs
5//! - [`AdaptiveBatcher`] — adjusts batch size based on observed latency so that
6//!   throughput is maximised while keeping per-batch latency near a configurable
7//!   target.
8//!
9//! # Algorithm
10//!
11//! After each batch completes the batcher computes a rolling average over the
12//! last `N` observations and compares it to `target_latency_ms`:
13//!
14//! - **Too slow** (avg > target): shrink towards `min_batch_size`
15//! - **Too fast** (avg < target): grow towards `max_batch_size`
16//!
17//! The magnitude of each adjustment is controlled by `adaptation_rate`.
18
19use crate::error::MlError;
20
/// Single prediction request carrying raw float tensors.
///
/// `inputs` and `input_shapes` are parallel vectors: `input_shapes[i]`
/// presumably describes the layout of the flat data in `inputs[i]` —
/// NOTE(review): nothing here enforces that the two stay in sync or that
/// the shape's product matches the vector length; confirm at call sites.
#[derive(Debug, Clone)]
pub struct PredictionRequest {
    /// Caller-assigned identifier (used to correlate results)
    pub id: u64,
    /// Input tensors as flat float vectors
    pub inputs: Vec<Vec<f32>>,
    /// Shape of each input tensor (e.g. `[3, 256, 256]` for a CHW image)
    pub input_shapes: Vec<Vec<usize>>,
}
31
/// Single prediction result produced by an inference run.
///
/// `outputs` and `output_shapes` are parallel vectors, mirroring the
/// `inputs` / `input_shapes` pairing on [`PredictionRequest`].
#[derive(Debug, Clone)]
pub struct PredictionResult {
    /// Matches the `id` field of the originating [`PredictionRequest`]
    pub id: u64,
    /// Output tensors as flat float vectors
    pub outputs: Vec<Vec<f32>>,
    /// Shape of each output tensor
    pub output_shapes: Vec<Vec<usize>>,
    /// Wall-clock latency of the inference call in milliseconds
    pub latency_ms: f64,
}
44
/// Adaptive batch sizing configuration
///
/// Construction does not validate the invariants below; call
/// [`AdaptiveBatchConfig::validate`] explicitly before use.
#[derive(Debug, Clone)]
pub struct AdaptiveBatchConfig {
    /// Minimum allowed batch size (≥ 1)
    pub min_batch_size: usize,
    /// Maximum allowed batch size (must be ≥ `min_batch_size`)
    pub max_batch_size: usize,
    /// Target latency per batch in milliseconds (must be positive)
    pub target_latency_ms: f64,
    /// Learning rate for batch-size adaptation (0.0 – 1.0)
    ///
    /// Higher values cause larger adjustments; lower values yield smoother
    /// adaptation.
    pub adaptation_rate: f64,
}
60
61impl Default for AdaptiveBatchConfig {
62    fn default() -> Self {
63        Self {
64            min_batch_size: 1,
65            max_batch_size: 64,
66            target_latency_ms: 50.0,
67            adaptation_rate: 0.1,
68        }
69    }
70}
71
72impl AdaptiveBatchConfig {
73    /// Validate the configuration.  Returns an error if any invariant is
74    /// violated.
75    pub fn validate(&self) -> Result<(), MlError> {
76        if self.min_batch_size == 0 {
77            return Err(MlError::InvalidConfig(
78                "min_batch_size must be at least 1".into(),
79            ));
80        }
81        if self.max_batch_size < self.min_batch_size {
82            return Err(MlError::InvalidConfig(
83                "max_batch_size must be >= min_batch_size".into(),
84            ));
85        }
86        if !(0.0..=1.0).contains(&self.adaptation_rate) {
87            return Err(MlError::InvalidConfig(
88                "adaptation_rate must be in [0.0, 1.0]".into(),
89            ));
90        }
91        if self.target_latency_ms <= 0.0 {
92            return Err(MlError::InvalidConfig(
93                "target_latency_ms must be positive".into(),
94            ));
95        }
96        Ok(())
97    }
98}
99
/// Adaptive batch size controller
///
/// Tracks recent inference latencies and adjusts the recommended batch size
/// to keep per-batch latency near the configured target.
pub struct AdaptiveBatcher {
    /// Sizing bounds, latency target, and adaptation rate
    config: AdaptiveBatchConfig,
    /// Current recommendation; kept within
    /// `[config.min_batch_size, config.max_batch_size]` by `update_latency`
    current_batch_size: usize,
    /// Rolling window (oldest first) of the most recent latency
    /// observations in milliseconds; trimmed to `window_size` entries
    recent_latencies: Vec<f64>,
    /// Lifetime count of completed batches
    total_batches: u64,
    /// Lifetime count of items across all completed batches
    total_items: u64,
    /// Maximum number of latency samples to keep for the rolling average
    window_size: usize,
}
114
115impl AdaptiveBatcher {
116    /// Create a new `AdaptiveBatcher` starting at `min_batch_size`.
117    pub fn new(config: AdaptiveBatchConfig) -> Self {
118        let start = config.min_batch_size;
119        Self {
120            config,
121            current_batch_size: start,
122            recent_latencies: Vec::new(),
123            total_batches: 0,
124            total_items: 0,
125            window_size: 10,
126        }
127    }
128
129    /// Return the current recommended batch size.
130    pub fn recommended_batch_size(&self) -> usize {
131        self.current_batch_size
132    }
133
134    /// Update the batch-size estimate based on the observed latency for a
135    /// completed batch.
136    ///
137    /// # Parameters
138    /// - `latency_ms`: wall-clock time the batch took in milliseconds
139    /// - `batch_size`: number of items that were in the completed batch
140    pub fn update_latency(&mut self, latency_ms: f64, batch_size: usize) {
141        // Maintain a rolling window of latency observations
142        self.recent_latencies.push(latency_ms);
143        if self.recent_latencies.len() > self.window_size {
144            self.recent_latencies.remove(0);
145        }
146
147        self.total_batches += 1;
148        self.total_items += batch_size as u64;
149
150        let avg = self.average_latency_ms();
151        let target = self.config.target_latency_ms;
152        let rate = self.config.adaptation_rate;
153        let min_bs = self.config.min_batch_size as f64;
154        let max_bs = self.config.max_batch_size as f64;
155        let current = self.current_batch_size as f64;
156
157        let new_size = if avg > target {
158            // Too slow — reduce batch size; always move at least 1 step down
159            let reduction = (current * rate * (avg - target) / target).max(1.0);
160            (current - reduction).max(min_bs)
161        } else {
162            // Too fast — increase batch size; always move at least 1 step up
163            let gain = (current * rate * (target - avg) / target).max(1.0);
164            (current + gain).min(max_bs)
165        };
166
167        self.current_batch_size = (new_size.round() as usize)
168            .max(self.config.min_batch_size)
169            .min(self.config.max_batch_size);
170    }
171
172    /// Group a flat list of requests into batches, each of at most
173    /// `recommended_batch_size()` items.
174    pub fn create_batches(&self, requests: Vec<PredictionRequest>) -> Vec<Vec<PredictionRequest>> {
175        if requests.is_empty() {
176            return Vec::new();
177        }
178        let bs = self.current_batch_size.max(1);
179        requests.chunks(bs).map(|chunk| chunk.to_vec()).collect()
180    }
181
182    /// Rolling average latency over the recent observation window.
183    ///
184    /// Returns `0.0` when no observations have been recorded yet.
185    pub fn average_latency_ms(&self) -> f64 {
186        if self.recent_latencies.is_empty() {
187            return 0.0;
188        }
189        self.recent_latencies.iter().sum::<f64>() / self.recent_latencies.len() as f64
190    }
191
192    /// Total number of completed batches.
193    pub fn total_batches(&self) -> u64 {
194        self.total_batches
195    }
196
197    /// Total number of individual items processed across all batches.
198    pub fn total_items(&self) -> u64 {
199        self.total_items
200    }
201
202    /// Return a reference to the current configuration.
203    pub fn config(&self) -> &AdaptiveBatchConfig {
204        &self.config
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience: batcher built from the default configuration.
    fn default_batcher() -> AdaptiveBatcher {
        AdaptiveBatcher::new(AdaptiveBatchConfig::default())
    }

    // Minimal one-tensor request (3 floats, shape [3]) with the given id.
    fn make_request(id: u64) -> PredictionRequest {
        PredictionRequest {
            id,
            inputs: vec![vec![1.0, 2.0, 3.0]],
            input_shapes: vec![vec![3]],
        }
    }

    #[test]
    fn test_construction_with_default_config() {
        let batcher = default_batcher();
        // A fresh batcher starts at the configured minimum.
        assert_eq!(
            batcher.recommended_batch_size(),
            AdaptiveBatchConfig::default().min_batch_size
        );
    }

    #[test]
    fn test_recommended_batch_size_starts_at_min() {
        let config = AdaptiveBatchConfig {
            min_batch_size: 4,
            max_batch_size: 64,
            ..Default::default()
        };
        let batcher = AdaptiveBatcher::new(config);
        assert_eq!(batcher.recommended_batch_size(), 4);
    }

    #[test]
    fn test_update_latency_adjusts_up_when_fast() {
        let mut batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 1,
            max_batch_size: 128,
            target_latency_ms: 100.0,
            adaptation_rate: 0.5,
        });
        let initial = batcher.recommended_batch_size();
        // Very fast batch (10 ms vs 100 ms target) — should grow
        batcher.update_latency(10.0, initial);
        assert!(
            batcher.recommended_batch_size() > initial,
            "batch size should grow when latency is well below target"
        );
    }

    #[test]
    fn test_update_latency_adjusts_down_when_slow() {
        let mut batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 1,
            max_batch_size: 64,
            target_latency_ms: 50.0,
            adaptation_rate: 0.5,
        });
        // Force the current size up first with repeated fast batches
        for _ in 0..10 {
            let sz = batcher.recommended_batch_size();
            batcher.update_latency(10.0, sz);
        }
        let high = batcher.recommended_batch_size();
        // Now feed a very slow batch; the rolling average jumps past target
        batcher.update_latency(9999.0, high);
        assert!(
            batcher.recommended_batch_size() < high,
            "batch size should shrink when latency exceeds target"
        );
    }

    #[test]
    fn test_batch_size_does_not_exceed_max() {
        let mut batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 1,
            max_batch_size: 8,
            target_latency_ms: 1000.0, // very long target → always growing
            adaptation_rate: 1.0,
        });
        // Grow far more than enough iterations to saturate the cap
        for _ in 0..100 {
            let sz = batcher.recommended_batch_size();
            batcher.update_latency(0.001, sz);
        }
        assert!(batcher.recommended_batch_size() <= 8);
    }

    #[test]
    fn test_batch_size_does_not_go_below_min() {
        let mut batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 4,
            max_batch_size: 64,
            target_latency_ms: 1.0, // very short target → always shrinking
            adaptation_rate: 1.0,
        });
        // Shrink far more than enough iterations to hit the floor
        for _ in 0..100 {
            let sz = batcher.recommended_batch_size();
            batcher.update_latency(99999.0, sz);
        }
        assert!(batcher.recommended_batch_size() >= 4);
    }

    #[test]
    fn test_create_batches_splits_correctly() {
        let mut batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 3,
            max_batch_size: 3,
            ..Default::default()
        });
        // Force batch size to 3 (field access is visible within the module)
        batcher.current_batch_size = 3;

        let requests: Vec<PredictionRequest> = (0..7).map(make_request).collect();
        let batches = batcher.create_batches(requests);

        assert_eq!(batches.len(), 3, "7 items / 3 = 3 batches (3, 3, 1)");
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 1);
    }

    #[test]
    fn test_create_batches_fewer_than_batch_size() {
        let batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 16,
            max_batch_size: 64,
            ..Default::default()
        });
        // 5 requests with batch size 16 → a single partial batch
        let requests: Vec<PredictionRequest> = (0..5).map(make_request).collect();
        let batches = batcher.create_batches(requests);
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 5);
    }

    #[test]
    fn test_create_batches_empty_input() {
        let batcher = default_batcher();
        let batches = batcher.create_batches(vec![]);
        assert!(batches.is_empty());
    }

    #[test]
    fn test_average_latency_ms_no_observations() {
        // Documented sentinel: 0.0 when nothing has been recorded yet
        let batcher = default_batcher();
        assert_eq!(batcher.average_latency_ms(), 0.0);
    }

    #[test]
    fn test_average_latency_ms_single_observation() {
        let mut batcher = default_batcher();
        batcher.update_latency(42.0, 1);
        assert!((batcher.average_latency_ms() - 42.0).abs() < 1e-9);
    }

    #[test]
    fn test_average_latency_ms_multiple_observations() {
        let mut batcher = default_batcher();
        batcher.update_latency(10.0, 1);
        batcher.update_latency(20.0, 1);
        batcher.update_latency(30.0, 1);
        // (10 + 20 + 30) / 3 = 20
        assert!((batcher.average_latency_ms() - 20.0).abs() < 1e-9);
    }

    #[test]
    fn test_total_batches_and_items_tracking() {
        let mut batcher = default_batcher();
        batcher.update_latency(50.0, 8);
        batcher.update_latency(50.0, 4);
        assert_eq!(batcher.total_batches(), 2);
        assert_eq!(batcher.total_items(), 12);
    }

    #[test]
    fn test_config_validation_invalid_min_batch() {
        let config = AdaptiveBatchConfig {
            min_batch_size: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_max_less_than_min() {
        let config = AdaptiveBatchConfig {
            min_batch_size: 10,
            max_batch_size: 5,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_invalid_adaptation_rate() {
        let config = AdaptiveBatchConfig {
            adaptation_rate: 1.5,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_prediction_result_fields() {
        // Plain-data smoke test: fields round-trip as assigned
        let result = PredictionResult {
            id: 42,
            outputs: vec![vec![0.9, 0.1]],
            output_shapes: vec![vec![2]],
            latency_ms: 12.5,
        };
        assert_eq!(result.id, 42);
        assert!((result.latency_ms - 12.5).abs() < 1e-9);
    }
}