1use crate::error::FaucetError;
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::collections::VecDeque;
10use std::time::Duration;
11
/// Serde default for `AdaptiveBatchConfig::controller`: `"aimd"`, the only
/// controller name that `validate` accepts.
fn default_controller() -> String {
    String::from("aimd")
}
/// Serde default for `AdaptiveBatchConfig::min`, the lower bound on the batch size.
fn default_min() -> usize {
    const MIN_BATCH: usize = 100;
    MIN_BATCH
}
/// Serde default for `AdaptiveBatchConfig::max`, the upper bound on the batch size.
fn default_max() -> usize {
    const MAX_BATCH: usize = 50_000;
    MAX_BATCH
}
/// Serde default for `AdaptiveBatchConfig::increase_step`, the amount added to
/// the batch size on each growth.
fn default_increase_step() -> usize {
    const STEP: usize = 250;
    STEP
}
/// Serde default for `AdaptiveBatchConfig::decrease_factor`, the multiplier
/// applied to the batch size on each shrink (0.5 halves it).
fn default_decrease_factor() -> f64 {
    const FACTOR: f64 = 0.5;
    FACTOR
}
/// Serde default for `AdaptiveBatchConfig::cooldown_batches`, the cooldown
/// length armed after a shrink.
fn default_cooldown_batches() -> usize {
    const COOLDOWN: usize = 5;
    COOLDOWN
}
/// Serde default for `AdaptiveBatchConfig::latency_window`, the number of
/// latency samples kept for the median.
fn default_latency_window() -> usize {
    const WINDOW: usize = 10;
    WINDOW
}
/// Serde default for `AdaptiveBatchConfig::error_threshold`: an error rate
/// above 1% of the batch triggers a shrink.
fn default_error_threshold() -> f64 {
    const THRESHOLD: f64 = 0.01;
    THRESHOLD
}
/// Serde default helper for boolean fields that are on unless configured
/// otherwise; used by `AdaptiveBatchConfig::respect_source_max`.
fn default_true() -> bool {
    const ON: bool = true;
    ON
}
/// Serde default for `AdaptiveBatchConfig::log_every`: one log line per 50
/// adjustments.
fn default_log_every() -> usize {
    const EVERY: usize = 50;
    EVERY
}
42
// Tuning knobs for adaptive batch sizing. Every field has a serde default, so
// an empty object deserializes; `validate` enforces the supported ranges.
//
// NOTE(review): these are plain `//` comments on purpose. `schemars` copies
// `///` doc comments into the generated JSON Schema as descriptions, which
// would change the derived schema. Promote them to `///` if that is wanted.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AdaptiveBatchConfig {
    // Defaults to `false` when omitted. Not read anywhere in this file;
    // presumably the caller's on/off switch — confirm against the call site.
    #[serde(default)]
    pub enabled: bool,
    // Controller algorithm name. `validate` accepts only "aimd" (the default).
    #[serde(default = "default_controller")]
    pub controller: String,
    // Smallest batch size the controller will use. Must be >= 1 and <= `max`.
    // Default 100.
    #[serde(default = "default_min")]
    pub min: usize,
    // Largest batch size the controller will use. Must be <=
    // `crate::MAX_BATCH_SIZE`. Default 50_000.
    #[serde(default = "default_max")]
    pub max: usize,
    // Amount added to the batch size on each growth. Must be in
    // `1..=crate::MAX_BATCH_SIZE`. Default 250.
    #[serde(default = "default_increase_step")]
    pub increase_step: usize,
    // Multiplier applied to the batch size on each shrink. Must lie strictly
    // between 0 and 1. Default 0.5.
    #[serde(default = "default_decrease_factor")]
    pub decrease_factor: f64,
    // Cooldown armed after a shrink: that many following observations that do
    // not trip the error threshold produce no adjustment. 0 disables the
    // cooldown. Default 5.
    #[serde(default = "default_cooldown_batches")]
    pub cooldown_batches: usize,
    // Optional latency target in milliseconds. When set, the controller
    // shrinks if the median latency exceeds 1.2x the target and grows if it is
    // below 0.5x the target. Must be > 0 when set. Default: unset.
    #[serde(default)]
    pub target_latency_ms: Option<u64>,
    // Number of most recent latency samples the median is computed over.
    // Must be >= 1. Default 10.
    #[serde(default = "default_latency_window")]
    pub latency_window: usize,
    // Error rate (`errors / batch_len`) above which the controller shrinks.
    // Must be in [0, 1]. Default 0.01.
    #[serde(default = "default_error_threshold")]
    pub error_threshold: f64,
    // Must stay `true` (the default); `validate` rejects `false` as
    // unsupported.
    #[serde(default = "default_true")]
    pub respect_source_max: bool,
    // Emit an info log on every N-th adjustment; 0 disables these logs.
    // Default 50.
    #[serde(default = "default_log_every")]
    pub log_every: usize,
}
91
92impl AdaptiveBatchConfig {
93 pub fn validate(&self) -> Result<(), FaucetError> {
95 if self.controller != "aimd" {
96 return Err(FaucetError::Config(format!(
97 "adaptive_batch_size.controller '{}' is not supported (only 'aimd')",
98 self.controller
99 )));
100 }
101 if self.min < 1 {
102 return Err(FaucetError::Config(
103 "adaptive_batch_size.min must be >= 1".into(),
104 ));
105 }
106 if self.min > self.max {
107 return Err(FaucetError::Config(format!(
108 "adaptive_batch_size.min ({}) must be <= max ({})",
109 self.min, self.max
110 )));
111 }
112 if self.max > crate::MAX_BATCH_SIZE {
113 return Err(FaucetError::Config(format!(
114 "adaptive_batch_size.max ({}) must be <= {} (MAX_BATCH_SIZE)",
115 self.max,
116 crate::MAX_BATCH_SIZE
117 )));
118 }
119 if self.increase_step > crate::MAX_BATCH_SIZE {
120 return Err(FaucetError::Config(format!(
121 "adaptive_batch_size.increase_step ({}) must be <= {} (MAX_BATCH_SIZE)",
122 self.increase_step,
123 crate::MAX_BATCH_SIZE
124 )));
125 }
126 if !(self.decrease_factor > 0.0 && self.decrease_factor < 1.0) {
127 return Err(FaucetError::Config(
128 "adaptive_batch_size.decrease_factor must be in (0, 1)".into(),
129 ));
130 }
131 if self.increase_step < 1 {
132 return Err(FaucetError::Config(
133 "adaptive_batch_size.increase_step must be >= 1".into(),
134 ));
135 }
136 if !(0.0..=1.0).contains(&self.error_threshold) {
137 return Err(FaucetError::Config(
138 "adaptive_batch_size.error_threshold must be in [0, 1]".into(),
139 ));
140 }
141 if self.latency_window < 1 {
142 return Err(FaucetError::Config(
143 "adaptive_batch_size.latency_window must be >= 1".into(),
144 ));
145 }
146 if let Some(t) = self.target_latency_ms
147 && t == 0
148 {
149 return Err(FaucetError::Config(
150 "adaptive_batch_size.target_latency_ms must be > 0 when set".into(),
151 ));
152 }
153 if !self.respect_source_max {
154 return Err(FaucetError::Config(
155 "adaptive_batch_size.respect_source_max=false is not supported \
156 (cross-page buffering would violate the O(batch_size) memory \
157 guarantee); remove the field or set it to true"
158 .into(),
159 ));
160 }
161 Ok(())
162 }
163}
164
/// Direction in which the controller moved the batch size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AdjustDirection {
    /// The batch size was increased.
    Up,
    /// The batch size was decreased.
    Down,
}

impl AdjustDirection {
    /// Lowercase label for this direction (`"up"` / `"down"`), used as a
    /// structured log field.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Up => "up",
            Self::Down => "down",
        }
    }
}
179
/// Signal that caused the controller to change the batch size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AdjustReason {
    /// A batch stayed within the error threshold and no latency target is
    /// configured.
    Success,
    /// The batch's error rate exceeded the configured threshold.
    Error,
    /// The median latency left the band around the configured target.
    Latency,
}

impl AdjustReason {
    /// Lowercase label for this reason (`"success"` / `"error"` /
    /// `"latency"`), used as a structured log field.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Success => "success",
            Self::Error => "error",
            Self::Latency => "latency",
        }
    }
}
196
/// Outcome of one processed batch, fed to [`AimdController::observe`].
#[derive(Debug, Clone, Copy)]
pub struct Observation {
    /// Number of items in the batch; the denominator of the error rate.
    /// When this is 0 the error-rate check is skipped.
    pub batch_len: usize,
    /// Number of errors seen for the batch; the numerator of the error rate.
    pub errors: usize,
    /// Latency measured for the batch. Only consulted when a latency target
    /// is configured, where it is recorded in whole milliseconds.
    pub latency: Duration,
}
204
/// A batch-size change made by the controller in response to an observation.
#[derive(Debug, Clone, Copy)]
pub struct Adjustment {
    /// Batch size in effect after the change.
    pub new_size: usize,
    /// Whether the size grew or shrank.
    pub direction: AdjustDirection,
    /// Signal that triggered the change.
    pub reason: AdjustReason,
}
212
/// Additive-increase / multiplicative-decrease controller for the batch size.
///
/// The first group of fields is copied from `AdaptiveBatchConfig` when the
/// controller is built and never changes; the second group is the running
/// state updated by each observation.
#[derive(Debug)]
pub struct AimdController {
    /// Lower bound for `current`.
    min: usize,
    /// Upper bound for `current`.
    max: usize,
    /// Amount added to `current` on each growth.
    increase_step: usize,
    /// Multiplier applied to `current` on each shrink.
    decrease_factor: f64,
    /// Value `cooldown` is reset to whenever a shrink is attempted.
    cooldown_batches: usize,
    /// Latency target in milliseconds; `None` selects pure success/error mode.
    target_latency_ms: Option<u64>,
    /// Maximum number of samples kept in `latencies`.
    latency_window: usize,
    /// Error rate above which a shrink is triggered.
    error_threshold: f64,
    /// Log every N-th adjustment; 0 disables the log line.
    log_every: usize,

    /// Batch size currently recommended.
    current: usize,
    /// Remaining observations to skip before adjusting again.
    cooldown: usize,
    /// Most recent latency samples in milliseconds, oldest first.
    latencies: VecDeque<u64>,
    /// Set once the "at floor and still seeing errors" warning has been logged.
    floor_warned: bool,
    /// Total number of size changes applied so far.
    adjustments: u64,
}
231
232impl AimdController {
233 pub fn new(cfg: &AdaptiveBatchConfig, initial: usize) -> Self {
236 Self {
237 min: cfg.min,
238 max: cfg.max,
239 increase_step: cfg.increase_step,
240 decrease_factor: cfg.decrease_factor,
241 cooldown_batches: cfg.cooldown_batches,
242 target_latency_ms: cfg.target_latency_ms,
243 latency_window: cfg.latency_window.max(1),
244 error_threshold: cfg.error_threshold,
245 log_every: cfg.log_every,
246 current: initial.clamp(cfg.min, cfg.max),
247 cooldown: 0,
248 latencies: VecDeque::new(),
249 floor_warned: false,
250 adjustments: 0,
251 }
252 }
253
254 pub fn current(&self) -> usize {
255 self.current
256 }
257 pub fn cooldown_active(&self) -> bool {
258 self.cooldown > 0
259 }
260
261 pub fn p50_latency_ms(&self) -> Option<u64> {
263 if self.latencies.is_empty() {
264 return None;
265 }
266 let mut v: Vec<u64> = self.latencies.iter().copied().collect();
267 v.sort_unstable();
268 Some(v[v.len() / 2])
269 }
270
271 pub fn observe(&mut self, obs: Observation) -> Option<Adjustment> {
275 if obs.batch_len > 0 {
277 let rate = obs.errors as f64 / obs.batch_len as f64;
278 if rate > self.error_threshold {
279 return self.shrink(AdjustReason::Error);
280 }
281 }
282 if self.cooldown > 0 {
284 self.cooldown -= 1;
285 return None;
286 }
287 if let Some(target) = self.target_latency_ms {
289 self.latencies.push_back(obs.latency.as_millis() as u64);
290 while self.latencies.len() > self.latency_window {
291 self.latencies.pop_front();
292 }
293 let p50 = self.p50_latency_ms().unwrap_or(0) as f64;
294 let t = target as f64;
295 if p50 > t * 1.2 {
296 return self.shrink(AdjustReason::Latency);
297 } else if p50 < t * 0.5 {
298 return self.grow(AdjustReason::Latency);
299 }
300 return None;
301 }
302 self.grow(AdjustReason::Success)
304 }
305
306 fn shrink(&mut self, reason: AdjustReason) -> Option<Adjustment> {
307 let new = ((self.current as f64 * self.decrease_factor).floor() as usize).max(self.min);
308 self.cooldown = self.cooldown_batches;
309 if new == self.current {
310 if reason == AdjustReason::Error && self.current == self.min && !self.floor_warned {
311 tracing::warn!(
312 batch_size = self.current,
313 "adaptive batch size at floor (min) and still seeing errors; \
314 consider lowering `min` or investigating the sink"
315 );
316 self.floor_warned = true;
317 }
318 return None;
319 }
320 self.current = new;
321 self.bump_log(AdjustDirection::Down, reason);
322 Some(Adjustment {
323 new_size: new,
324 direction: AdjustDirection::Down,
325 reason,
326 })
327 }
328
329 fn grow(&mut self, reason: AdjustReason) -> Option<Adjustment> {
330 let new = self
333 .current
334 .saturating_add(self.increase_step)
335 .min(self.max);
336 if new == self.current {
337 return None;
338 }
339 self.current = new;
340 self.bump_log(AdjustDirection::Up, reason);
341 Some(Adjustment {
342 new_size: new,
343 direction: AdjustDirection::Up,
344 reason,
345 })
346 }
347
348 fn bump_log(&mut self, direction: AdjustDirection, reason: AdjustReason) {
349 self.adjustments += 1;
350 if self.log_every > 0 && self.adjustments.is_multiple_of(self.log_every as u64) {
351 tracing::info!(
352 current = self.current,
353 direction = direction.as_str(),
354 reason = reason.as_str(),
355 adjustments = self.adjustments,
356 "adaptive batch size adjusted"
357 );
358 }
359 }
360}
361
#[cfg(test)]
mod config_tests {
    use super::*;

    /// Config built purely from serde defaults, with `enabled` switched on.
    fn valid() -> AdaptiveBatchConfig {
        let raw = serde_json::json!({"enabled": true});
        serde_json::from_value(raw).unwrap()
    }

    /// Applies `tweak` to a fresh default config and asserts that validation
    /// then fails.
    fn assert_rejected(tweak: impl FnOnce(&mut AdaptiveBatchConfig)) {
        let mut candidate = valid();
        tweak(&mut candidate);
        assert!(candidate.validate().is_err());
    }

    #[test]
    fn defaults_are_sane_and_valid() {
        let defaults = valid();
        assert_eq!(defaults.controller, "aimd");
        assert_eq!(defaults.min, 100);
        assert_eq!(defaults.max, 50_000);
        assert!(defaults.respect_source_max);
        assert!(defaults.target_latency_ms.is_none());
        defaults.validate().unwrap();
    }

    #[test]
    fn rejects_respect_source_max_false() {
        assert_rejected(|c| c.respect_source_max = false);
    }

    #[test]
    fn rejects_unknown_controller() {
        assert_rejected(|c| c.controller = "pid".into());
    }

    #[test]
    fn rejects_min_gt_max_and_zero_min() {
        assert_rejected(|c| {
            c.min = 10;
            c.max = 5;
        });
        assert_rejected(|c| c.min = 0);
    }

    #[test]
    fn rejects_max_and_increase_step_above_max_batch_size() {
        assert_rejected(|c| c.max = crate::MAX_BATCH_SIZE + 1);
        assert_rejected(|c| c.increase_step = crate::MAX_BATCH_SIZE + 1);

        // Exactly at the limit is still accepted.
        let mut at_limit = valid();
        at_limit.max = crate::MAX_BATCH_SIZE;
        at_limit.validate().unwrap();
    }

    #[test]
    fn rejects_out_of_range_factors() {
        assert_rejected(|c| c.decrease_factor = 1.5);
        assert_rejected(|c| c.error_threshold = 2.0);
        assert_rejected(|c| c.increase_step = 0);
        assert_rejected(|c| c.target_latency_ms = Some(0));

        // Both ends of the (0, 1) interval are excluded.
        assert_rejected(|c| c.decrease_factor = 0.0);
        assert_rejected(|c| c.decrease_factor = 1.0);

        assert_rejected(|c| c.latency_window = 0);
    }
}
450
#[cfg(test)]
mod controller_tests {
    use super::*;
    use std::time::Duration;

    /// Small-range config: sizes 100..=1000, step 100, halving on shrink,
    /// two-batch cooldown, 10% error threshold, no latency target.
    fn cfg() -> AdaptiveBatchConfig {
        let raw = serde_json::json!({
            "enabled": true, "min": 100, "max": 1000,
            "increase_step": 100, "decrease_factor": 0.5,
            "cooldown_batches": 2, "error_threshold": 0.1
        });
        serde_json::from_value(raw).unwrap()
    }

    /// Builds an observation from plain numbers.
    fn sample(batch_len: usize, errors: usize, latency_ms: u64) -> Observation {
        Observation {
            batch_len,
            errors,
            latency: Duration::from_millis(latency_ms),
        }
    }

    /// Error-free observation of `len` items with a 1 ms latency.
    fn ok(len: usize) -> Observation {
        sample(len, 0, 1)
    }

    /// Controller starting at 800 with a 500 ms latency target, no cooldown,
    /// and the given latency window.
    fn latency_controller(window: usize) -> AimdController {
        let raw = serde_json::json!({
            "enabled": true, "min": 100, "max": 1000, "increase_step": 100,
            "decrease_factor": 0.5, "cooldown_batches": 0,
            "target_latency_ms": 500, "latency_window": window
        });
        let config: AdaptiveBatchConfig = serde_json::from_value(raw).unwrap();
        AimdController::new(&config, 800)
    }

    #[test]
    fn cold_start_clamps_initial_to_bounds() {
        // (requested initial size, size the controller starts at)
        let cases = [(50, 100), (99_999, 1000), (500, 500)];
        for (initial, expected) in cases {
            let controller = AimdController::new(&cfg(), initial);
            assert_eq!(controller.current(), expected);
        }
    }

    #[test]
    fn grow_saturates_instead_of_overflowing_usize() {
        let raw = serde_json::json!({
            "enabled": true, "min": 1, "max": usize::MAX,
            "increase_step": usize::MAX, "decrease_factor": 0.5
        });
        let extreme: AdaptiveBatchConfig = serde_json::from_value(raw).unwrap();
        let mut controller = AimdController::new(&extreme, 1);

        let grown = controller.observe(ok(1)).expect("growth should occur");
        assert_eq!(grown.new_size, usize::MAX);
        assert_eq!(controller.current(), usize::MAX);
    }

    #[test]
    fn grows_additively_on_success_up_to_max() {
        let mut controller = AimdController::new(&cfg(), 800);

        let first = controller.observe(ok(800)).unwrap();
        assert_eq!(first.new_size, 900);
        assert_eq!(first.direction, AdjustDirection::Up);
        assert_eq!(first.reason, AdjustReason::Success);

        controller.observe(ok(900));
        assert_eq!(controller.current(), 1000);

        // At the ceiling: no further adjustment is reported.
        assert!(controller.observe(ok(1000)).is_none());
        assert_eq!(controller.current(), 1000);
    }

    #[test]
    fn shrinks_multiplicatively_on_error_and_arms_cooldown() {
        let mut controller = AimdController::new(&cfg(), 800);

        let shrunk = controller.observe(sample(100, 20, 1)).unwrap();
        assert_eq!(shrunk.new_size, 400);
        assert_eq!(shrunk.direction, AdjustDirection::Down);
        assert_eq!(shrunk.reason, AdjustReason::Error);
        assert!(controller.cooldown_active());

        // Two cooldown batches pass without any adjustment...
        for _ in 0..2 {
            assert!(controller.observe(ok(400)).is_none());
        }
        // ...and the next clean batch grows again.
        let regrown = controller.observe(ok(400)).unwrap();
        assert_eq!(regrown.new_size, 500);
    }

    #[test]
    fn does_not_shrink_below_min_and_warns_once() {
        let mut controller = AimdController::new(&cfg(), 100);
        let all_failed = sample(100, 100, 1);

        assert!(controller.observe(all_failed).is_none());
        assert_eq!(controller.current(), 100);
    }

    #[test]
    fn latency_target_shrinks_when_slow_grows_when_fast() {
        let mut controller = latency_controller(1);

        // 700 ms is above 1.2x the 500 ms target.
        let slow = controller.observe(sample(800, 0, 700)).unwrap();
        assert_eq!(slow.reason, AdjustReason::Latency);
        assert_eq!(slow.direction, AdjustDirection::Down);
        assert_eq!(controller.current(), 400);

        // 100 ms is below 0.5x the target.
        let fast = controller.observe(sample(400, 0, 100)).unwrap();
        assert_eq!(fast.direction, AdjustDirection::Up);
        assert_eq!(fast.reason, AdjustReason::Latency);
        assert_eq!(controller.current(), 500);

        // Exactly on target: hold.
        let on_target = controller.observe(sample(500, 0, 500));
        assert!(on_target.is_none());
    }

    #[test]
    fn error_during_cooldown_reshrinks_and_rearms() {
        let mut controller = AimdController::new(&cfg(), 800);
        let bad = sample(100, 50, 1);

        let first = controller.observe(bad).unwrap();
        assert_eq!(first.new_size, 400);
        assert!(controller.cooldown_active());

        let second = controller.observe(bad).unwrap();
        assert_eq!(second.new_size, 200);
        assert_eq!(second.reason, AdjustReason::Error);
        assert!(controller.cooldown_active());
    }

    #[test]
    fn p50_uses_median_of_multi_sample_window() {
        let mut controller = latency_controller(5);

        for _ in 0..5 {
            controller.observe(sample(800, 0, 10));
        }
        assert_eq!(controller.p50_latency_ms(), Some(10));

        // A single outlier must not move the median.
        controller.observe(sample(800, 0, 900));
        assert_eq!(controller.p50_latency_ms(), Some(10));
    }
}