// adk_anthropic/backoff.rs
//! A perfect backoff algorithm.
//!
//! This algorithm is based on the following insight: the integral of system headroom across the
//! recovery window must be at least as large as the integral of system downtime during an
//! outage.
//!
//! Plotting throughput against time, it looks like this:
//! ```text
//! │
//! │                            HHHHHHHHHHHHHHHHHHHHH
//! │                            HHHHHHHHHHHHHHHHHHHHH
//! ├────────────┐              ┌─────────────────────
//! │            │DDDDDDDDDDDDDD│
//! │            │DDDDDDDDDDDDDD│
//! │            │DDDDDDDDDDDDDD│
//! │            └──────────────┘
//! └────────────────────────────────────────────────
//! ```
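//!
//! The area of downtime, D, must be less than or equal to the area of headroom, H, for the system
//! to be able to absorb the downtime. Let t_D be the duration of downtime, t_R the duration of
//! recovery, T_N the nominal throughput of the system, and T_H the throughput kept in reserve as
//! headroom. Assuming the outage loses the full nominal throughput, D = t_D * T_N and
//! H = t_R * T_H. The shortest recovery window that absorbs the outage therefore satisfies
//! t_D * T_N = t_R * T_H, i.e. t_R = t_D * T_N / T_H.
//!
//! This module provides an `ExponentialBackoff` struct that implements a randomized backoff
//! based on this insight. It treats the time since it was created as t_D and draws each delay
//! uniformly from `[0, t_R)`. When callers sleep for each returned delay, every retry lengthens
//! the observed outage, so the expected delay grows geometrically from one retry to the next.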
//!
//! Here is an example that shows how to use this struct. With these arguments, the recovery
//! window is ten times the time elapsed since `exp_backoff` was created:
//!
//! ```ignore
//! let exp_backoff = ExponentialBackoff::new(1_000.0, 100.0);
//! let result = loop {
//!     match try_some_operation().await {
//!         Ok(result) => break result,
//!         Err(e) if e.is_recoverable() => tokio::time::sleep(exp_backoff.next()).await,
//!         Err(e) => return Err(e),
//!     }
//! };
//! // process the result
//! ```

47use std::collections::hash_map::RandomState;
48use std::hash::BuildHasher;
49use std::time::{Duration, Instant};

//////////////////////////////////////// ExponentialBackoff ////////////////////////////////////////

/// Randomized backoff whose recovery window grows with the time elapsed since construction.
///
/// The recovery window is `elapsed * throughput_ops_sec / reserve_capacity`, where `elapsed` is
/// the time since [`ExponentialBackoff::new`] was called. Each call to
/// [`ExponentialBackoff::next`] returns a pseudo-random duration in `[0, recovery_window)`.
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
    /// Nominal throughput of the system (T_N in the module docs).
    throughput_ops_sec: f64,
    /// Throughput kept in reserve as headroom (T_H in the module docs).
    reserve_capacity: f64,
    /// When this backoff was created; `next` measures elapsed time from here.
    start: Instant,
}

impl ExponentialBackoff {
    /// Creates a backoff whose clock starts now.
    ///
    /// `throughput_ops_sec` is the nominal throughput of the system and `reserve_capacity` is the
    /// throughput held in reserve as headroom, expressed in the same units. Only their ratio
    /// matters: the recovery window is the elapsed time scaled by
    /// `throughput_ops_sec / reserve_capacity`.
    ///
    /// # Panics
    ///
    /// Panics if `throughput_ops_sec` is negative or not finite, or if `reserve_capacity` is not
    /// strictly positive and finite. Such values would otherwise produce delays that are always
    /// zero (negative or NaN ratios saturate to 0) or effectively unbounded (an infinite ratio
    /// saturates to `u64::MAX` microseconds).
    pub fn new(throughput_ops_sec: impl Into<f64>, reserve_capacity: impl Into<f64>) -> Self {
        let throughput_ops_sec = throughput_ops_sec.into();
        let reserve_capacity = reserve_capacity.into();
        assert!(
            throughput_ops_sec.is_finite() && throughput_ops_sec >= 0.0,
            "throughput_ops_sec must be finite and non-negative, got {throughput_ops_sec}"
        );
        assert!(
            reserve_capacity.is_finite() && reserve_capacity > 0.0,
            "reserve_capacity must be finite and positive, got {reserve_capacity}"
        );
        Self { throughput_ops_sec, reserve_capacity, start: Instant::now() }
    }

    /// Returns how long to wait before the next attempt.
    ///
    /// The recovery window is the time elapsed since construction, scaled by
    /// `throughput_ops_sec / reserve_capacity` and truncated to whole microseconds. The returned
    /// delay is that window multiplied by a pseudo-random factor in `[0, 1)`, again truncated to
    /// whole microseconds.
    pub fn next(&self) -> Duration {
        let elapsed = self.start.elapsed();
        // The units on throughput_ops_sec and reserve_capacity cancel out, so we simply scale
        // elapsed.as_micros() by the ratio of the two. `as u64` saturates on overflow.
        let recovery_window = Duration::from_micros(
            (elapsed.as_micros() as f64 * self.throughput_ops_sec / self.reserve_capacity) as u64,
        );
        // Scale the recovery window by a random factor in [0, 1).
        Duration::from_micros((recovery_window.as_micros() as f64 * random_unit()) as u64)
    }
}

/// Returns a pseudo-random `f64` in `[0, 1)`.
///
/// Hashes the current time with a freshly keyed std `RandomState`, which avoids depending on an
/// RNG crate. Not suitable for cryptographic use.
fn random_unit() -> f64 {
    // Keep the low 53 bits (an f64 mantissa's precision) and divide by 2^53, so every result is
    // exactly representable and strictly less than 1.
    const MANTISSA_MASK: u64 = (1u64 << f64::MANTISSA_DIGITS) - 1;
    let random = RandomState::new().hash_one(Instant::now());
    (random & MANTISSA_MASK) as f64 / (1u64 << f64::MANTISSA_DIGITS) as f64
}

/////////////////////////////////////////////// tests //////////////////////////////////////////////

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn with_exponential_backoff() {
        let exp_backoff = ExponentialBackoff::new(1_000.0, 100.0);
        // Right after construction the elapsed time is tiny, so even a 10x window is well
        // under a second.
        assert!(exp_backoff.next() < Duration::from_secs(1));
        assert!(exp_backoff.next() < Duration::from_secs(1));
        assert!(exp_backoff.next() < Duration::from_secs(1));
        // Backdate the clock by 10 s instead of sleeping for it. With a 1000/100 ratio the window
        // is just over 100 s, so uniform samples average roughly 50 s.
        let ten_secs_ago = Instant::now()
            .checked_sub(Duration::from_secs(10))
            .expect("monotonic clock should be at least 10 s past its epoch");
        let exp_backoff = ExponentialBackoff { start: ten_secs_ago, ..exp_backoff };
        let durations = (0..100).map(|_| exp_backoff.next()).collect::<Vec<_>>();
        assert!(durations.iter().all(|d| *d < Duration::from_secs(101)));
        assert!(
            durations.iter().sum::<Duration>() / durations.len() as u32 > Duration::from_secs(10)
        );
    }
}