// ai_lib_rust/resilience/rate_limiter.rs

use std::time::{Duration, Instant};

use tokio::sync::Mutex;

use crate::Result;

/// Point-in-time view of a `RateLimiter`, as produced by `RateLimiter::snapshot`.
#[derive(Debug, Clone)]
pub struct RateLimiterSnapshot {
    /// Configured refill rate, in tokens per second.
    pub rps: f64,
    /// Configured bucket capacity.
    pub burst: f64,
    /// Tokens in the bucket when the snapshot was taken (may be fractional).
    pub tokens: f64,
    /// Estimated milliseconds before a permit could be granted; `None` when no wait
    /// is expected.
    pub estimated_wait_ms: Option<u64>,
}

/// Configuration for a `RateLimiter`.
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Sustained refill rate in tokens (requests) per second. A value `<= 0.0`
    /// disables token-bucket throttling in `RateLimiter::acquire`.
    pub rps: f64,
    /// Bucket capacity: how many permits can be taken back to back from a full bucket.
    pub burst: f64,
}

impl RateLimiterConfig {
    /// Builds a config allowing `rps` requests per second, with a burst of
    /// `max(rps, 1.0)` so the bucket can always hold at least one whole permit.
    ///
    /// Returns `None` when `rps` is negative, NaN, or infinite. `0.0` is accepted.
    pub fn from_rps(rps: f64) -> Option<Self> {
        // `is_finite` rejects NaN and both infinities; negatives are rejected explicitly.
        if !rps.is_finite() || rps < 0.0 {
            return None;
        }
        Some(Self {
            rps,
            burst: rps.max(1.0),
        })
    }
}

/// Mutable limiter state, kept behind the `RateLimiter::state` mutex.
#[derive(Debug)]
struct State {
    /// Tokens currently in the bucket (fractional).
    tokens: f64,
    /// Instant of the last refill; tokens accrue from here.
    last: Instant,
    /// While set and in the future, `acquire` waits until this instant.
    blocked_until: Option<Instant>,
    /// Last reported request budget, decremented once per granted permit;
    /// `None` when no budget is known.
    remaining: Option<u64>,
}

44pub struct RateLimiter {
49 cfg: RateLimiterConfig,
50 state: Mutex<State>,
51}
52
53impl RateLimiter {
54 pub fn new(cfg: RateLimiterConfig) -> Self {
55 let burst = cfg.burst;
56 let state = Mutex::new(State {
57 tokens: burst,
58 last: Instant::now(),
59 blocked_until: None,
60 remaining: None,
61 });
62 Self { cfg, state }
63 }
64
65 fn refill_locked(cfg: &RateLimiterConfig, st: &mut State) {
66 let now = Instant::now();
67 let elapsed = now.duration_since(st.last).as_secs_f64();
68 if elapsed > 0.0 {
69 st.tokens = (st.tokens + elapsed * cfg.rps).min(cfg.burst);
70 st.last = now;
71 }
72 }
73
74 pub async fn acquire(&self) -> Result<()> {
76 let cfg = &self.cfg;
77
78 loop {
79 let wait_duration = {
80 let mut st = self.state.lock().await;
81 let now = Instant::now();
82
83 if let Some(until) = st.blocked_until {
85 if until > now {
86 until.duration_since(now)
88 } else {
89 st.blocked_until = None;
90 Duration::from_millis(0)
91 }
92 } else {
93 if cfg.rps <= 0.0 {
94 return Ok(());
95 }
96
97 Self::refill_locked(cfg, &mut st);
98
99 if st.tokens >= 1.0 && st.remaining.unwrap_or(1) > 0 {
101 st.tokens -= 1.0;
102 if let Some(rem) = st.remaining.as_mut() {
103 *rem = rem.saturating_sub(1);
104 }
105 return Ok(());
106 }
107
108 let missing = 1.0 - st.tokens;
110 Duration::from_secs_f64(missing / cfg.rps)
111 }
112 };
113
114 if wait_duration.as_millis() > 0 {
115 tokio::time::sleep(wait_duration).await;
116 }
117 }
118 }
119
120 pub async fn update_budget(
122 &self,
123 remaining: Option<u64>,
124 reset_after: Option<std::time::Duration>,
125 ) {
126 let mut st = self.state.lock().await;
127 if let Some(rem) = remaining {
128 st.remaining = Some(rem);
129 if rem == 0 {
130 let after = reset_after.unwrap_or(std::time::Duration::from_secs(1));
132 st.blocked_until = Some(Instant::now() + after);
133 } else {
134 st.blocked_until = None;
135 }
136 }
137 }
138
139 pub async fn snapshot(&self) -> RateLimiterSnapshot {
140 let cfg = &self.cfg;
141 let mut st = self.state.lock().await;
142 let now = Instant::now();
143
144 let mut wait_ms = None;
146 if let Some(until) = st.blocked_until {
147 if until > now {
148 wait_ms = Some(until.duration_since(now).as_millis() as u64);
149 }
150 }
151
152 if cfg.rps > 0.0 {
154 Self::refill_locked(cfg, &mut st);
155 if st.tokens < 1.0 {
156 let missing = 1.0 - st.tokens;
157 let local_wait_ms = (missing / cfg.rps * 1000.0) as u64;
158 wait_ms = Some(wait_ms.unwrap_or(0).max(local_wait_ms));
159 }
160 }
161
162 RateLimiterSnapshot {
163 rps: cfg.rps,
164 burst: cfg.burst,
165 tokens: st.tokens,
166 estimated_wait_ms: wait_ms,
167 }
168 }
169}