reliability_toolkit/
rate_limiter.rs1use std::sync::Arc;
24use std::time::Duration;
25
26use tokio::sync::Mutex;
27use tokio::time::{sleep, Instant};
28
29#[derive(Clone, Debug)]
31pub struct RateLimiter {
32 inner: Arc<Inner>,
33}
34
35#[derive(Debug)]
36struct Inner {
37 rate_per_second: f64,
38 burst: f64,
39 state: Mutex<State>,
40}
41
/// Mutable contents of the bucket; only ever touched while the limiter's mutex is held.
#[derive(Debug)]
struct State {
    /// Tokens currently available. May be fractional; stays within `0.0..=burst`.
    tokens: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}
47
48impl RateLimiter {
49 pub fn new(rate_per_second: f64, burst: u32) -> Self {
55 assert!(rate_per_second > 0.0, "rate_per_second must be positive");
56 assert!(burst > 0, "burst must be non-zero");
57 let burst_f = f64::from(burst);
58 Self {
59 inner: Arc::new(Inner {
60 rate_per_second,
61 burst: burst_f,
62 state: Mutex::new(State {
63 tokens: burst_f,
64 last_refill: Instant::now(),
65 }),
66 }),
67 }
68 }
69
70 pub async fn acquire(&self) {
72 self.acquire_n(1).await;
73 }
74
75 pub async fn acquire_n(&self, n: u32) {
81 let needed = f64::from(n);
82 assert!(
83 needed <= self.inner.burst,
84 "requested {n} tokens but burst is {}",
85 self.inner.burst
86 );
87
88 loop {
89 let wait = {
90 let mut state = self.inner.state.lock().await;
91 self.refill(&mut state);
92 if state.tokens >= needed {
93 state.tokens -= needed;
94 return;
95 }
96 let deficit = needed - state.tokens;
97 let seconds = deficit / self.inner.rate_per_second;
98 Duration::from_secs_f64(seconds)
99 };
100 sleep(wait).await;
101 }
102 }
103
104 pub async fn try_acquire(&self) -> bool {
106 self.try_acquire_n(1).await
107 }
108
109 pub async fn try_acquire_n(&self, n: u32) -> bool {
111 let needed = f64::from(n);
112 let mut state = self.inner.state.lock().await;
113 self.refill(&mut state);
114 if state.tokens >= needed {
115 state.tokens -= needed;
116 true
117 } else {
118 false
119 }
120 }
121
122 pub async fn tokens(&self) -> f64 {
124 let mut state = self.inner.state.lock().await;
125 self.refill(&mut state);
126 state.tokens
127 }
128
129 fn refill(&self, state: &mut State) {
130 let now = Instant::now();
131 let elapsed = now.duration_since(state.last_refill).as_secs_f64();
132 if elapsed > 0.0 {
133 state.tokens =
134 (state.tokens + elapsed * self.inner.rate_per_second).min(self.inner.burst);
135 state.last_refill = now;
136 }
137 }
138}