1use crate::error::MlError;
20
/// A single inference request queued for adaptive batching.
#[derive(Debug, Clone)]
pub struct PredictionRequest {
    /// Caller-assigned identifier used to correlate this request with its
    /// eventual `PredictionResult`.
    pub id: u64,
    /// Flattened input tensors, one `Vec<f32>` per model input.
    pub inputs: Vec<Vec<f32>>,
    /// Shape of each entry in `inputs` — presumably parallel to `inputs`
    /// (one shape per tensor); TODO(review): confirm against callers.
    pub input_shapes: Vec<Vec<usize>>,
}
31
/// The outcome of serving one `PredictionRequest`.
#[derive(Debug, Clone)]
pub struct PredictionResult {
    /// Identifier copied from the originating `PredictionRequest`.
    pub id: u64,
    /// Flattened output tensors, one `Vec<f32>` per model output.
    pub outputs: Vec<Vec<f32>>,
    /// Shape of each entry in `outputs` — presumably parallel to `outputs`;
    /// TODO(review): confirm against producers.
    pub output_shapes: Vec<Vec<usize>>,
    /// Wall-clock latency of the prediction, in milliseconds.
    pub latency_ms: f64,
}
44
/// Tuning knobs for the latency-driven batch sizing performed by
/// `AdaptiveBatcher`.
#[derive(Debug, Clone)]
pub struct AdaptiveBatchConfig {
    /// Lower bound for the dynamic batch size; `validate` requires >= 1.
    pub min_batch_size: usize,
    /// Upper bound for the dynamic batch size; `validate` requires
    /// >= `min_batch_size`.
    pub max_batch_size: usize,
    /// Latency (milliseconds) the controller steers toward; `validate`
    /// requires a positive value.
    pub target_latency_ms: f64,
    /// Step-size factor in `[0.0, 1.0]` controlling how aggressively the
    /// batch size reacts to latency error.
    pub adaptation_rate: f64,
}
60
61impl Default for AdaptiveBatchConfig {
62 fn default() -> Self {
63 Self {
64 min_batch_size: 1,
65 max_batch_size: 64,
66 target_latency_ms: 50.0,
67 adaptation_rate: 0.1,
68 }
69 }
70}
71
72impl AdaptiveBatchConfig {
73 pub fn validate(&self) -> Result<(), MlError> {
76 if self.min_batch_size == 0 {
77 return Err(MlError::InvalidConfig(
78 "min_batch_size must be at least 1".into(),
79 ));
80 }
81 if self.max_batch_size < self.min_batch_size {
82 return Err(MlError::InvalidConfig(
83 "max_batch_size must be >= min_batch_size".into(),
84 ));
85 }
86 if !(0.0..=1.0).contains(&self.adaptation_rate) {
87 return Err(MlError::InvalidConfig(
88 "adaptation_rate must be in [0.0, 1.0]".into(),
89 ));
90 }
91 if self.target_latency_ms <= 0.0 {
92 return Err(MlError::InvalidConfig(
93 "target_latency_ms must be positive".into(),
94 ));
95 }
96 Ok(())
97 }
98}
99
/// Latency-adaptive request batcher: keeps a sliding window of recent
/// batch latencies and grows or shrinks the recommended batch size to
/// steer the window average toward `config.target_latency_ms`.
pub struct AdaptiveBatcher {
    /// Bounds and tuning parameters (not re-validated by this type).
    config: AdaptiveBatchConfig,
    /// Batch size currently recommended to callers.
    current_batch_size: usize,
    /// Sliding window of the most recent batch latencies in ms, newest last.
    recent_latencies: Vec<f64>,
    /// Lifetime count of batches reported via `update_latency`.
    total_batches: u64,
    /// Lifetime count of items summed across all reported batches.
    total_items: u64,
    /// Maximum number of samples retained in `recent_latencies`.
    window_size: usize,
}
114
115impl AdaptiveBatcher {
116 pub fn new(config: AdaptiveBatchConfig) -> Self {
118 let start = config.min_batch_size;
119 Self {
120 config,
121 current_batch_size: start,
122 recent_latencies: Vec::new(),
123 total_batches: 0,
124 total_items: 0,
125 window_size: 10,
126 }
127 }
128
129 pub fn recommended_batch_size(&self) -> usize {
131 self.current_batch_size
132 }
133
134 pub fn update_latency(&mut self, latency_ms: f64, batch_size: usize) {
141 self.recent_latencies.push(latency_ms);
143 if self.recent_latencies.len() > self.window_size {
144 self.recent_latencies.remove(0);
145 }
146
147 self.total_batches += 1;
148 self.total_items += batch_size as u64;
149
150 let avg = self.average_latency_ms();
151 let target = self.config.target_latency_ms;
152 let rate = self.config.adaptation_rate;
153 let min_bs = self.config.min_batch_size as f64;
154 let max_bs = self.config.max_batch_size as f64;
155 let current = self.current_batch_size as f64;
156
157 let new_size = if avg > target {
158 let reduction = (current * rate * (avg - target) / target).max(1.0);
160 (current - reduction).max(min_bs)
161 } else {
162 let gain = (current * rate * (target - avg) / target).max(1.0);
164 (current + gain).min(max_bs)
165 };
166
167 self.current_batch_size = (new_size.round() as usize)
168 .max(self.config.min_batch_size)
169 .min(self.config.max_batch_size);
170 }
171
172 pub fn create_batches(&self, requests: Vec<PredictionRequest>) -> Vec<Vec<PredictionRequest>> {
175 if requests.is_empty() {
176 return Vec::new();
177 }
178 let bs = self.current_batch_size.max(1);
179 requests.chunks(bs).map(|chunk| chunk.to_vec()).collect()
180 }
181
182 pub fn average_latency_ms(&self) -> f64 {
186 if self.recent_latencies.is_empty() {
187 return 0.0;
188 }
189 self.recent_latencies.iter().sum::<f64>() / self.recent_latencies.len() as f64
190 }
191
192 pub fn total_batches(&self) -> u64 {
194 self.total_batches
195 }
196
197 pub fn total_items(&self) -> u64 {
199 self.total_items
200 }
201
202 pub fn config(&self) -> &AdaptiveBatchConfig {
204 &self.config
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Batcher built from the default configuration.
    fn default_batcher() -> AdaptiveBatcher {
        AdaptiveBatcher::new(AdaptiveBatchConfig::default())
    }

    /// Minimal request fixture with a single 3-element input.
    fn make_request(id: u64) -> PredictionRequest {
        PredictionRequest {
            id,
            inputs: vec![vec![1.0, 2.0, 3.0]],
            input_shapes: vec![vec![3]],
        }
    }

    #[test]
    fn test_construction_with_default_config() {
        let batcher = default_batcher();
        assert_eq!(
            batcher.recommended_batch_size(),
            AdaptiveBatchConfig::default().min_batch_size
        );
    }

    #[test]
    fn test_recommended_batch_size_starts_at_min() {
        let config = AdaptiveBatchConfig {
            min_batch_size: 4,
            max_batch_size: 64,
            ..Default::default()
        };
        let batcher = AdaptiveBatcher::new(config);
        assert_eq!(batcher.recommended_batch_size(), 4);
    }

    #[test]
    fn test_update_latency_adjusts_up_when_fast() {
        let mut batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 1,
            max_batch_size: 128,
            target_latency_ms: 100.0,
            adaptation_rate: 0.5,
        });
        let initial = batcher.recommended_batch_size();
        batcher.update_latency(10.0, initial);
        assert!(
            batcher.recommended_batch_size() > initial,
            "batch size should grow when latency is well below target"
        );
    }

    #[test]
    fn test_update_latency_adjusts_down_when_slow() {
        let mut batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 1,
            max_batch_size: 64,
            target_latency_ms: 50.0,
            adaptation_rate: 0.5,
        });
        // Warm up with fast batches so the size climbs above the minimum.
        for _ in 0..10 {
            let sz = batcher.recommended_batch_size();
            batcher.update_latency(10.0, sz);
        }
        let high = batcher.recommended_batch_size();
        batcher.update_latency(9999.0, high);
        assert!(
            batcher.recommended_batch_size() < high,
            "batch size should shrink when latency exceeds target"
        );
    }

    #[test]
    fn test_batch_size_does_not_exceed_max() {
        let mut batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 1,
            max_batch_size: 8,
            target_latency_ms: 1000.0,
            adaptation_rate: 1.0,
        });
        // Consistently fast batches push the size up; it must stay capped.
        for _ in 0..100 {
            let sz = batcher.recommended_batch_size();
            batcher.update_latency(0.001, sz);
        }
        assert!(batcher.recommended_batch_size() <= 8);
    }

    #[test]
    fn test_batch_size_does_not_go_below_min() {
        let mut batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 4,
            max_batch_size: 64,
            target_latency_ms: 1.0,
            adaptation_rate: 1.0,
        });
        // Consistently slow batches push the size down; it must stay floored.
        for _ in 0..100 {
            let sz = batcher.recommended_batch_size();
            batcher.update_latency(99999.0, sz);
        }
        assert!(batcher.recommended_batch_size() >= 4);
    }

    #[test]
    fn test_create_batches_splits_correctly() {
        let mut batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 3,
            max_batch_size: 3,
            ..Default::default()
        });
        batcher.current_batch_size = 3;

        let requests: Vec<PredictionRequest> = (0..7).map(make_request).collect();
        let batches = batcher.create_batches(requests);

        assert_eq!(batches.len(), 3, "7 items / 3 = 3 batches (3, 3, 1)");
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 1);
    }

    #[test]
    fn test_create_batches_fewer_than_batch_size() {
        let batcher = AdaptiveBatcher::new(AdaptiveBatchConfig {
            min_batch_size: 16,
            max_batch_size: 64,
            ..Default::default()
        });
        let requests: Vec<PredictionRequest> = (0..5).map(make_request).collect();
        let batches = batcher.create_batches(requests);
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 5);
    }

    #[test]
    fn test_create_batches_empty_input() {
        let batcher = default_batcher();
        let batches = batcher.create_batches(vec![]);
        assert!(batches.is_empty());
    }

    #[test]
    fn test_average_latency_ms_no_observations() {
        let batcher = default_batcher();
        assert_eq!(batcher.average_latency_ms(), 0.0);
    }

    #[test]
    fn test_average_latency_ms_single_observation() {
        let mut batcher = default_batcher();
        batcher.update_latency(42.0, 1);
        assert!((batcher.average_latency_ms() - 42.0).abs() < 1e-9);
    }

    #[test]
    fn test_average_latency_ms_multiple_observations() {
        let mut batcher = default_batcher();
        batcher.update_latency(10.0, 1);
        batcher.update_latency(20.0, 1);
        batcher.update_latency(30.0, 1);
        assert!((batcher.average_latency_ms() - 20.0).abs() < 1e-9);
    }

    #[test]
    fn test_total_batches_and_items_tracking() {
        let mut batcher = default_batcher();
        batcher.update_latency(50.0, 8);
        batcher.update_latency(50.0, 4);
        assert_eq!(batcher.total_batches(), 2);
        assert_eq!(batcher.total_items(), 12);
    }

    #[test]
    fn test_config_validation_invalid_min_batch() {
        let config = AdaptiveBatchConfig {
            min_batch_size: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_max_less_than_min() {
        let config = AdaptiveBatchConfig {
            min_batch_size: 10,
            max_batch_size: 5,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_invalid_adaptation_rate() {
        let config = AdaptiveBatchConfig {
            adaptation_rate: 1.5,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_prediction_result_fields() {
        let result = PredictionResult {
            id: 42,
            outputs: vec![vec![0.9, 0.1]],
            output_shapes: vec![vec![2]],
            latency_ms: 12.5,
        };
        assert_eq!(result.id, 42);
        assert!((result.latency_ms - 12.5).abs() < 1e-9);
    }
}