cognate_core/
ratelimit.rs1use std::sync::Arc;
7use tokio::sync::Mutex;
8use std::time::{Duration, Instant};
9
10#[derive(Debug, Clone)]
12pub struct TokenBucket {
13 state: Arc<Mutex<BucketState>>,
14}
15
/// Mutable state kept behind a `TokenBucket`'s lock.
#[derive(Debug)]
struct BucketState {
    /// Instant at which `tokens` was last brought up to date by `refill`.
    last_refill: Instant,
    /// Tokens currently available; fractional amounts are allowed.
    tokens: f64,
    /// Upper bound that `refill` clamps `tokens` to.
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    fill_rate: f64,
}
23
24impl TokenBucket {
25 pub fn new(capacity: f64, fill_rate: f64) -> Self {
31 Self {
32 state: Arc::new(Mutex::new(BucketState {
33 last_refill: Instant::now(),
34 tokens: capacity,
35 capacity,
36 fill_rate,
37 })),
38 }
39 }
40
41 pub async fn try_acquire(&self, amount: f64) -> bool {
45 let mut state = self.state.lock().await;
46 state.refill();
47
48 if state.tokens >= amount {
49 state.tokens -= amount;
50 true
51 } else {
52 false
53 }
54 }
55
56 pub async fn acquire(&self, amount: f64) -> Duration {
58 loop {
59 let mut state = self.state.lock().await;
60 state.refill();
61
62 if state.tokens >= amount {
63 state.tokens -= amount;
64 return Duration::from_secs(0);
65 }
66
67 let tokens_needed = amount - state.tokens;
68 let wait_time = Duration::from_secs_f64(tokens_needed / state.fill_rate);
69
70 drop(state);
72 tokio::time::sleep(wait_time).await;
73 }
74 }
75}
76
77impl BucketState {
78 fn refill(&mut self) {
79 let now = Instant::now();
80 let elapsed = now.duration_since(self.last_refill).as_secs_f64();
81
82 self.tokens = (self.tokens + elapsed * self.fill_rate).min(self.capacity);
83 self.last_refill = now;
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    #[tokio::test]
    async fn test_token_bucket() {
        // 4 tokens/s keeps the refill wait short while leaving headroom
        // against scheduler stalls between the drain and the empty check.
        let bucket = TokenBucket::new(10.0, 4.0);

        // Drain the initially full bucket in two halves.
        assert!(bucket.try_acquire(5.0).await);
        assert!(bucket.try_acquire(5.0).await);

        // Empty: only a negligible fraction of a token has accrued since.
        assert!(!bucket.try_acquire(1.0).await);

        // 300 ms at 4 tokens/s restores ~1.2 tokens.
        tokio::time::sleep(Duration::from_millis(300)).await;
        assert!(bucket.try_acquire(1.0).await);
    }

    #[tokio::test]
    async fn acquire_is_immediate_when_tokens_available() {
        let bucket = TokenBucket::new(2.0, 1.0);
        assert_eq!(bucket.acquire(2.0).await, Duration::ZERO);
    }

    #[tokio::test]
    async fn acquire_waits_for_refill() {
        // 20 tokens/s: one token takes ~50 ms to accrue after draining.
        let bucket = TokenBucket::new(1.0, 20.0);
        let start = Instant::now();
        assert!(bucket.try_acquire(1.0).await);

        bucket.acquire(1.0).await;
        assert!(start.elapsed() >= Duration::from_millis(45));
    }

    #[tokio::test]
    async fn clones_share_one_bucket() {
        let a = TokenBucket::new(1.0, 0.001);
        let b = a.clone();
        assert!(a.try_acquire(1.0).await);
        assert!(!b.try_acquire(1.0).await);
    }
}