Skip to main content

cognate_core/
ratelimit.rs

//! Token bucket rate limiting implementation.
//!
//! This module provides a simple token bucket, safe to share between async
//! tasks, for rate limiting LLM provider requests.
5
6use std::sync::Arc;
7use tokio::sync::Mutex;
8use std::time::{Duration, Instant};
9
10/// A token bucket rate limiter
11#[derive(Debug, Clone)]
12pub struct TokenBucket {
13    state: Arc<Mutex<BucketState>>,
14}
15
/// Mutable bookkeeping behind a [`TokenBucket`], kept inside its mutex.
#[derive(Debug)]
struct BucketState {
    /// Instant at which tokens were last credited by `refill`.
    last_refill: Instant,
    /// Tokens currently available; `refill` clamps this to `capacity`.
    tokens: f64,
    /// Upper bound applied to `tokens` on every refill.
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    fill_rate: f64,
}
23
24impl TokenBucket {
25    /// Create a new token bucket
26    ///
27    /// # Arguments
28    /// * `capacity` - Maximum number of tokens the bucket can hold
29    /// * `fill_rate` - Number of tokens added to the bucket per second
30    pub fn new(capacity: f64, fill_rate: f64) -> Self {
31        Self {
32            state: Arc::new(Mutex::new(BucketState {
33                last_refill: Instant::now(),
34                tokens: capacity,
35                capacity,
36                fill_rate,
37            })),
38        }
39    }
40
41    /// Try to acquire tokens from the bucket
42    ///
43    /// Returns true if tokens were acquired, false otherwise.
44    pub async fn try_acquire(&self, amount: f64) -> bool {
45        let mut state = self.state.lock().await;
46        state.refill();
47
48        if state.tokens >= amount {
49            state.tokens -= amount;
50            true
51        } else {
52            false
53        }
54    }
55
56    /// Wait until tokens are available and acquire them
57    pub async fn acquire(&self, amount: f64) -> Duration {
58        loop {
59            let mut state = self.state.lock().await;
60            state.refill();
61
62            if state.tokens >= amount {
63                state.tokens -= amount;
64                return Duration::from_secs(0);
65            }
66
67            let tokens_needed = amount - state.tokens;
68            let wait_time = Duration::from_secs_f64(tokens_needed / state.fill_rate);
69            
70            // Drop lock before sleeping
71            drop(state);
72            tokio::time::sleep(wait_time).await;
73        }
74    }
75}
76
77impl BucketState {
78    fn refill(&mut self) {
79        let now = Instant::now();
80        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
81        
82        self.tokens = (self.tokens + elapsed * self.fill_rate).min(self.capacity);
83        self.last_refill = now;
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn test_token_bucket() {
        // Capacity 10, refilling at one token per second; starts full.
        let limiter = TokenBucket::new(10.0, 1.0);

        // The initial 10 tokens can be drained by two immediate 5-token grabs.
        for _ in 0..2 {
            assert!(limiter.try_acquire(5.0).await);
        }

        // Almost nothing has refilled yet, so even a single token is refused.
        assert!(!limiter.try_acquire(1.0).await);

        // After 1.1s roughly 1.1 tokens have accrued, enough for one more.
        tokio::time::sleep(Duration::from_millis(1100)).await;
        assert!(limiter.try_acquire(1.0).await);
    }
}
109}