Skip to main content

ai_lib_rust/resilience/
rate_limiter.rs

1use crate::Result;
2use std::time::{Duration, Instant};
3use tokio::sync::Mutex;
4
/// Point-in-time view of a rate limiter's configuration and bucket level.
#[derive(Clone, Debug)]
pub struct RateLimiterSnapshot {
    /// Configured refill rate, in tokens per second.
    pub rps: f64,
    /// Configured bucket capacity, in tokens.
    pub burst: f64,
    /// Tokens in the bucket at the moment the snapshot was taken.
    pub tokens: f64,
    /// Estimated time, in milliseconds, before a token can be handed out.
    ///
    /// `None` when no wait is expected; `Some` when the bucket holds less than
    /// one token or an external block is still in effect.
    pub estimated_wait_ms: Option<u64>,
}
13
/// Static parameters of the token bucket.
#[derive(Clone, Debug)]
pub struct RateLimiterConfig {
    /// Refill rate: tokens added per second.
    pub rps: f64,
    /// Bucket capacity: the most tokens that can be held at once.
    pub burst: f64,
}

impl RateLimiterConfig {
    /// Builds a configuration from a requests-per-second figure.
    ///
    /// The burst defaults to one second's worth of tokens, but never less
    /// than a single token.
    ///
    /// Returns `None` when `rps` is NaN, infinite, or negative. Zero is
    /// accepted.
    pub fn from_rps(rps: f64) -> Option<Self> {
        let acceptable = rps.is_finite() && rps >= 0.0;
        if !acceptable {
            return None;
        }

        let burst = f64::max(rps, 1.0);
        Some(Self { rps, burst })
    }
}
33
/// Mutable bookkeeping of the limiter, kept behind the limiter's mutex.
#[derive(Debug)]
struct State {
    /// Tokens currently in the bucket (fractional while refilling).
    tokens: f64,
    /// Instant of the most recent refill.
    last: Instant,
    /// Instant until which callers must wait because the provider reported an
    /// exhausted budget; `None` when no such block is in effect.
    blocked_until: Option<Instant>,
    /// Most recent remaining-budget figure reported by the provider, if any.
    remaining: Option<u64>,
}
43
44/// Minimal token-bucket rate limiter (opt-in).
45///
46/// - Default disabled unless configured
47/// - Best-effort fairness for async tasks
48pub struct RateLimiter {
49    cfg: RateLimiterConfig,
50    state: Mutex<State>,
51}
52
53impl RateLimiter {
54    pub fn new(cfg: RateLimiterConfig) -> Self {
55        let burst = cfg.burst;
56        let state = Mutex::new(State {
57            tokens: burst,
58            last: Instant::now(),
59            blocked_until: None,
60            remaining: None,
61        });
62        Self { cfg, state }
63    }
64
65    fn refill_locked(cfg: &RateLimiterConfig, st: &mut State) {
66        let now = Instant::now();
67        let elapsed = now.duration_since(st.last).as_secs_f64();
68        if elapsed > 0.0 {
69            st.tokens = (st.tokens + elapsed * cfg.rps).min(cfg.burst);
70            st.last = now;
71        }
72    }
73
74    /// Acquire one token (may sleep).
75    pub async fn acquire(&self) -> Result<()> {
76        let cfg = &self.cfg;
77
78        loop {
79            let wait_duration = {
80                let mut st = self.state.lock().await;
81                let now = Instant::now();
82
83                // 1. Check if we are explicitly blocked by an external signal
84                if let Some(until) = st.blocked_until {
85                    if until > now {
86                        // Remain in loop and wait
87                        until.duration_since(now)
88                    } else {
89                        st.blocked_until = None;
90                        Duration::from_millis(0)
91                    }
92                } else {
93                    if cfg.rps <= 0.0 {
94                        return Ok(());
95                    }
96
97                    Self::refill_locked(cfg, &mut st);
98
99                    // 2. If we have local tokens and aren't hearing "remaining: 0" from provider, go.
100                    if st.tokens >= 1.0 && st.remaining.unwrap_or(1) > 0 {
101                        st.tokens -= 1.0;
102                        if let Some(rem) = st.remaining.as_mut() {
103                            *rem = rem.saturating_sub(1);
104                        }
105                        return Ok(());
106                    }
107
108                    // 3. Compute wait time until next token or reset
109                    let missing = 1.0 - st.tokens;
110                    Duration::from_secs_f64(missing / cfg.rps)
111                }
112            };
113
114            if wait_duration.as_millis() > 0 {
115                tokio::time::sleep(wait_duration).await;
116            }
117        }
118    }
119
120    /// Update rate limiter state based on external signals (e.g. HTTP headers)
121    pub async fn update_budget(
122        &self,
123        remaining: Option<u64>,
124        reset_after: Option<std::time::Duration>,
125    ) {
126        let mut st = self.state.lock().await;
127        if let Some(rem) = remaining {
128            st.remaining = Some(rem);
129            if rem == 0 {
130                // If 0 remaining, we must wait until reset or a default backoff
131                let after = reset_after.unwrap_or(std::time::Duration::from_secs(1));
132                st.blocked_until = Some(Instant::now() + after);
133            } else {
134                st.blocked_until = None;
135            }
136        }
137    }
138
139    pub async fn snapshot(&self) -> RateLimiterSnapshot {
140        let cfg = &self.cfg;
141        let mut st = self.state.lock().await;
142        let now = Instant::now();
143
144        // 1. Check external block first
145        let mut wait_ms = None;
146        if let Some(until) = st.blocked_until {
147            if until > now {
148                wait_ms = Some(until.duration_since(now).as_millis() as u64);
149            }
150        }
151
152        // 2. Then check local token bucket if no external block or if longer
153        if cfg.rps > 0.0 {
154            Self::refill_locked(cfg, &mut st);
155            if st.tokens < 1.0 {
156                let missing = 1.0 - st.tokens;
157                let local_wait_ms = (missing / cfg.rps * 1000.0) as u64;
158                wait_ms = Some(wait_ms.unwrap_or(0).max(local_wait_ms));
159            }
160        }
161
162        RateLimiterSnapshot {
163            rps: cfg.rps,
164            burst: cfg.burst,
165            tokens: st.tokens,
166            estimated_wait_ms: wait_ms,
167        }
168    }
169}