1use core::sync::atomic::{AtomicU32, AtomicU64, Ordering};
22use core::time::Duration;
23
24use clock_lib::{Clock, Monotonic, SystemClock};
25use tokio::sync::Notify;
26
/// How a request admitted by the limiter finished.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Outcome {
    /// The request completed successfully.
    Success {
        /// Time between the permit being issued and success being reported.
        rtt: Duration,
    },
    /// The request failed; a permit dropped without being settled also
    /// reports this outcome.
    Failure,
}
43
44pub trait AdaptiveStrategy: Send + Sync {
50 fn adjust(&self, current: u32, in_flight: u32, outcome: Outcome) -> u32;
54}
55
/// Additive-increase / multiplicative-decrease limit strategy.
///
/// A success observed while the limiter is saturated grows the limit by
/// `increase`. A failure multiplies it by `decrease`, never going below 1.
#[derive(Debug, Clone, Copy)]
pub struct Aimd {
    /// Amount added to the limit on a saturated success; always at least 1.
    increase: u32,
    /// Factor applied to the limit on failure; always within `[0.0, 1.0]`.
    decrease: f64,
}

impl Aimd {
    /// Multiplicative factor used by `Default` and substituted for a NaN
    /// `decrease`.
    const DEFAULT_DECREASE: f64 = 0.5;

    /// Creates an AIMD strategy.
    ///
    /// `increase` is raised to at least 1, and `decrease` is clamped to
    /// `[0.0, 1.0]`. `f64::clamp` returns NaN unchanged, and a NaN factor
    /// would make every failure collapse the limit straight to 1. A NaN
    /// `decrease` is therefore replaced with the default factor of 0.5.
    #[must_use]
    pub fn new(increase: u32, decrease: f64) -> Self {
        let decrease = if decrease.is_nan() {
            Self::DEFAULT_DECREASE
        } else {
            decrease.clamp(0.0, 1.0)
        };
        Self {
            increase: increase.max(1),
            decrease,
        }
    }
}

impl Default for Aimd {
    /// Grows by 1 on a saturated success and halves on failure.
    fn default() -> Self {
        Self::new(1, Self::DEFAULT_DECREASE)
    }
}
98
99impl AdaptiveStrategy for Aimd {
100 fn adjust(&self, current: u32, in_flight: u32, outcome: Outcome) -> u32 {
101 match outcome {
102 Outcome::Success { .. } if in_flight >= current => {
105 current.saturating_add(self.increase)
106 }
107 Outcome::Success { .. } => current,
108 Outcome::Failure => {
109 let cut = (f64::from(current) * self.decrease) as u32;
110 cut.max(1)
111 }
112 }
113 }
114}
115
/// Latency-based limit strategy in the style of TCP Vegas.
///
/// Remembers the smallest round-trip time observed so far. Each new sample is
/// compared against it to estimate how many requests are queueing.
#[derive(Debug)]
pub struct Vegas {
    /// The limit grows while the estimated queue is below this value.
    alpha: u32,
    /// The limit shrinks while the estimated queue is above this value;
    /// always `>= alpha`.
    beta: u32,
    /// Smallest RTT seen so far, in nanoseconds (`u64::MAX` before the first
    /// sample).
    min_rtt_ns: AtomicU64,
}

impl Vegas {
    /// Creates a Vegas strategy with queue thresholds `alpha` and `beta`.
    ///
    /// If `beta` is smaller than `alpha`, it is raised to `alpha`.
    #[must_use]
    pub fn new(alpha: u32, beta: u32) -> Self {
        let beta = core::cmp::max(beta, alpha);
        Self {
            alpha,
            beta,
            min_rtt_ns: AtomicU64::new(u64::MAX),
        }
    }
}

impl Default for Vegas {
    /// Thresholds `alpha = 3`, `beta = 6`.
    fn default() -> Self {
        Self::new(3, 6)
    }
}
163
164impl AdaptiveStrategy for Vegas {
165 fn adjust(&self, current: u32, _in_flight: u32, outcome: Outcome) -> u32 {
166 let rtt = match outcome {
167 Outcome::Failure => return (current / 2).max(1),
168 Outcome::Success { rtt } => rtt,
169 };
170 let rtt_ns = u64::try_from(rtt.as_nanos()).unwrap_or(u64::MAX).max(1);
171 let min_ns = self
173 .min_rtt_ns
174 .fetch_min(rtt_ns, Ordering::AcqRel)
175 .min(rtt_ns);
176
177 let queue = u64::from(current).saturating_mul(rtt_ns.saturating_sub(min_ns)) / rtt_ns;
179 if queue < u64::from(self.alpha) {
180 current.saturating_add(1)
181 } else if queue > u64::from(self.beta) {
182 current.saturating_sub(1)
183 } else {
184 current
185 }
186 }
187}
188
189pub struct AdaptiveLimiter<S, C = SystemClock>
213where
214 C: Clock,
215{
216 strategy: S,
217 limit: AtomicU32,
218 in_flight: AtomicU32,
219 floor: u32,
220 ceiling: u32,
221 notify: Notify,
222 clock: C,
223}
224
225impl AdaptiveLimiter<core::convert::Infallible> {
226 #[must_use]
228 pub fn builder() -> AdaptiveLimiterBuilder {
229 AdaptiveLimiterBuilder::new()
230 }
231}
232
233impl<S, C> AdaptiveLimiter<S, C>
234where
235 S: AdaptiveStrategy,
236 C: Clock + Clone,
237{
238 fn new(strategy: S, floor: u32, ceiling: u32, initial: u32, clock: C) -> Self {
239 let floor = floor.max(1);
240 let ceiling = ceiling.max(floor);
241 Self {
242 strategy,
243 limit: AtomicU32::new(initial.clamp(floor, ceiling)),
244 in_flight: AtomicU32::new(0),
245 floor,
246 ceiling,
247 notify: Notify::new(),
248 clock,
249 }
250 }
251
252 #[must_use]
255 pub fn with_clock<C2>(self, clock: C2) -> AdaptiveLimiter<S, C2>
256 where
257 C2: Clock + Clone,
258 {
259 AdaptiveLimiter::new(
260 self.strategy,
261 self.floor,
262 self.ceiling,
263 self.limit.load(Ordering::Acquire),
264 clock,
265 )
266 }
267
268 #[must_use]
270 pub fn current_limit(&self) -> u32 {
271 self.limit.load(Ordering::Acquire)
272 }
273
274 #[must_use]
276 pub fn in_flight(&self) -> u32 {
277 self.in_flight.load(Ordering::Acquire)
278 }
279
280 #[must_use]
282 pub fn ceiling(&self) -> u32 {
283 self.ceiling
284 }
285
286 fn try_reserve(&self) -> bool {
288 loop {
289 let in_flight = self.in_flight.load(Ordering::Acquire);
290 if in_flight >= self.limit.load(Ordering::Acquire) {
291 return false;
292 }
293 if self
294 .in_flight
295 .compare_exchange_weak(
296 in_flight,
297 in_flight + 1,
298 Ordering::AcqRel,
299 Ordering::Acquire,
300 )
301 .is_ok()
302 {
303 return true;
304 }
305 }
306 }
307
308 #[must_use]
311 pub fn try_acquire(&self) -> Option<AdaptivePermit<'_, S, C>> {
312 self.try_reserve().then(|| AdaptivePermit::new(self))
313 }
314
315 fn settle(&self, outcome: Outcome) {
317 let in_flight = self.in_flight.load(Ordering::Acquire);
320 let current = self.limit.load(Ordering::Acquire);
321 let proposed = self.strategy.adjust(current, in_flight, outcome);
322 let new = proposed.clamp(self.floor, self.ceiling);
323 self.limit.store(new, Ordering::Release);
324 if new != current {
325 crate::obs::rate_change(current, new);
326 }
327 let _ = self.in_flight.fetch_sub(1, Ordering::AcqRel);
328 self.notify.notify_waiters();
330 }
331
332 fn rtt_since(&self, started: Monotonic) -> Duration {
334 self.clock.now().saturating_duration_since(started)
335 }
336
337 #[inline]
338 fn now(&self) -> Monotonic {
339 self.clock.now()
340 }
341}
342
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl<S, C> AdaptiveLimiter<S, C>
where
    S: AdaptiveStrategy,
    C: Clock + Clone,
{
    /// Waits until a slot is free under the current limit and returns a
    /// permit for it.
    ///
    /// A slot is only reserved at the moment the permit is returned, so
    /// dropping this future before it completes reserves nothing.
    pub async fn acquire(&self) -> AdaptivePermit<'_, S, C> {
        loop {
            // Create and enable the `Notified` future *before* checking for a
            // slot, so a release that lands between the check and the
            // `.await` below is not missed.
            let notified = self.notify.notified();
            tokio::pin!(notified);
            let _ = notified.as_mut().enable();
            if self.try_reserve() {
                return AdaptivePermit::new(self);
            }
            // Settling a permit wakes every waiter; loop to re-check because
            // another waiter may claim the freed slot first.
            notified.await;
        }
    }
}
367
368#[must_use = "settle the permit with `.success()` or `.failure()`; dropping it counts as a failure"]
372pub struct AdaptivePermit<'a, S, C>
373where
374 S: AdaptiveStrategy,
375 C: Clock + Clone,
376{
377 limiter: &'a AdaptiveLimiter<S, C>,
378 started: Monotonic,
379 settled: bool,
380}
381
382impl<'a, S, C> AdaptivePermit<'a, S, C>
383where
384 S: AdaptiveStrategy,
385 C: Clock + Clone,
386{
387 fn new(limiter: &'a AdaptiveLimiter<S, C>) -> Self {
388 Self {
389 started: limiter.now(),
390 limiter,
391 settled: false,
392 }
393 }
394
395 pub fn success(mut self) {
398 let rtt = self.limiter.rtt_since(self.started);
399 self.limiter.settle(Outcome::Success { rtt });
400 self.settled = true;
401 }
402
403 pub fn failure(mut self) {
405 self.limiter.settle(Outcome::Failure);
406 self.settled = true;
407 }
408}
409
410impl<S, C> Drop for AdaptivePermit<'_, S, C>
411where
412 S: AdaptiveStrategy,
413 C: Clock + Clone,
414{
415 fn drop(&mut self) {
416 if !self.settled {
417 self.limiter.settle(Outcome::Failure);
418 }
419 }
420}
421
/// Configuration for an [`AdaptiveLimiter`]. Obtain one from
/// `AdaptiveLimiter::builder()` or `AdaptiveLimiterBuilder::new()`.
#[derive(Clone, Copy, Debug)]
pub struct AdaptiveLimiterBuilder {
    /// Lowest value the limit may take.
    floor: u32,
    /// Highest value the limit may take.
    ceiling: u32,
    /// Starting limit; `None` means start at `floor`.
    initial: Option<u32>,
}
429
430impl Default for AdaptiveLimiterBuilder {
431 fn default() -> Self {
432 Self::new()
433 }
434}
435
436impl AdaptiveLimiterBuilder {
437 #[must_use]
439 pub fn new() -> Self {
440 Self {
441 floor: 1,
442 ceiling: 100,
443 initial: None,
444 }
445 }
446
447 #[must_use]
449 pub fn floor(mut self, floor: u32) -> Self {
450 self.floor = floor.max(1);
451 self
452 }
453
454 #[must_use]
457 pub fn ceiling(mut self, ceiling: u32) -> Self {
458 self.ceiling = ceiling;
459 self
460 }
461
462 #[must_use]
464 pub fn initial(mut self, initial: u32) -> Self {
465 self.initial = Some(initial);
466 self
467 }
468
469 #[must_use]
472 pub fn build<S>(self, strategy: S) -> AdaptiveLimiter<S, SystemClock>
473 where
474 S: AdaptiveStrategy,
475 {
476 let initial = self.initial.unwrap_or(self.floor);
477 AdaptiveLimiter::new(
478 strategy,
479 self.floor,
480 self.ceiling,
481 initial,
482 SystemClock::new(),
483 )
484 }
485}
486
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used)]

    use super::{AdaptiveLimiter, AdaptiveStrategy, Aimd, Outcome, Vegas};
    use clock_lib::ManualClock;
    use core::time::Duration;
    use std::sync::Arc;

    /// Compiles only if `T` can be shared across threads.
    fn assert_send_sync<T: Send + Sync>() {}

    /// A successful outcome with a zero round-trip time.
    const INSTANT_SUCCESS: Outcome = Outcome::Success { rtt: Duration::ZERO };

    #[test]
    fn test_adaptive_is_send_sync() {
        assert_send_sync::<AdaptiveLimiter<Aimd>>();
        assert_send_sync::<AdaptiveLimiter<Vegas>>();
    }

    #[test]
    fn test_aimd_adjust_rules() {
        let aimd = Aimd::new(2, 0.5);
        // Success while saturated: additive increase.
        assert_eq!(aimd.adjust(10, 10, INSTANT_SUCCESS), 12);
        // Success with spare capacity: unchanged.
        assert_eq!(aimd.adjust(10, 3, INSTANT_SUCCESS), 10);
        // Failure: multiplicative decrease.
        assert_eq!(aimd.adjust(10, 10, Outcome::Failure), 5);
    }

    #[test]
    fn test_degradation_drives_limit_to_floor() {
        let limiter = AdaptiveLimiter::builder()
            .floor(4)
            .ceiling(100)
            .initial(64)
            .build(Aimd::new(4, 0.5));

        (0..10).for_each(|_| {
            limiter
                .try_acquire()
                .expect("a slot under the limit")
                .failure();
        });
        assert_eq!(limiter.current_limit(), 4);
    }

    #[test]
    fn test_recovery_drives_limit_up_bounded_by_ceiling() {
        let limiter = AdaptiveLimiter::builder()
            .floor(1)
            .ceiling(8)
            .initial(1)
            .build(Aimd::new(1, 0.5));

        for _round in 0..50 {
            // Take every available slot so the limiter is saturated.
            let mut held: Vec<_> = core::iter::from_fn(|| limiter.try_acquire()).collect();
            // Settling one permit while all the others are still held counts
            // as a saturated success.
            if let Some(last) = held.pop() {
                last.success();
            }
            held.into_iter().for_each(|permit| permit.success());
        }
        assert_eq!(limiter.current_limit(), 8, "grows to the ceiling");

        for _ in 0..20 {
            limiter.try_acquire().expect("slot").success();
        }
        assert_eq!(limiter.current_limit(), 8, "never exceeds the ceiling");
    }

    #[test]
    fn test_never_admits_more_than_the_limit() {
        let limiter = AdaptiveLimiter::builder()
            .floor(3)
            .ceiling(3)
            .initial(3)
            .build(Aimd::default());

        let permits = [
            limiter.try_acquire().expect("1"),
            limiter.try_acquire().expect("2"),
            limiter.try_acquire().expect("3"),
        ];
        assert_eq!(limiter.in_flight(), 3);
        assert!(limiter.try_acquire().is_none());
        drop(permits);
    }

    #[test]
    fn test_dropping_permit_counts_as_failure() {
        let limiter = AdaptiveLimiter::builder()
            .floor(1)
            .ceiling(100)
            .initial(10)
            .build(Aimd::new(1, 0.5));

        let permit = limiter.try_acquire().expect("slot");
        drop(permit);
        assert_eq!(limiter.current_limit(), 5);
        assert_eq!(limiter.in_flight(), 0, "the slot is released");
    }

    #[test]
    fn test_vegas_grows_on_low_latency_shrinks_on_high() {
        let clock = Arc::new(ManualClock::new());
        let limiter = AdaptiveLimiter::builder()
            .floor(1)
            .ceiling(100)
            .initial(20)
            .build(Vegas::new(3, 6))
            .with_clock(Arc::clone(&clock));

        let fast = limiter.try_acquire().expect("slot");
        clock.advance(Duration::from_millis(10));
        fast.success();
        assert_eq!(limiter.current_limit(), 21);

        let slow = limiter.try_acquire().expect("slot");
        clock.advance(Duration::from_millis(200));
        slow.success();
        assert!(
            limiter.current_limit() < 21,
            "high latency shrinks the limit"
        );
    }

    #[cfg(feature = "tokio")]
    #[tokio::test]
    async fn test_async_acquire_waits_for_a_freed_slot() {
        let limiter = Arc::new(
            AdaptiveLimiter::builder()
                .floor(1)
                .ceiling(1)
                .initial(1)
                .build(Aimd::default()),
        );

        let held = limiter.try_acquire().expect("the one slot");
        assert!(limiter.try_acquire().is_none());

        let waiter = {
            let limiter = Arc::clone(&limiter);
            tokio::spawn(async move { limiter.acquire().await.success() })
        };
        // Give the spawned waiter a chance to start waiting before the slot
        // is released.
        tokio::task::yield_now().await;
        held.success();
        waiter.await.unwrap();
    }
}