Skip to main content

do_over/
rate_limit.rs

//! Rate limiter policy based on a simple token bucket.
//!
//! The rate limiter controls the rate of operations by maintaining a bucket
//! of tokens that is refilled to full capacity once per interval (in effect a
//! fixed-window counter rather than a continuously refilling bucket).
//!
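//! # How It Works
//!
//! - Each operation consumes one token
//! - If no tokens are available, the operation is rejected immediately with
//!   `DoOverError::BulkheadFull` and is not run
//! - Once the interval has elapsed since the last refill, the next operation
//!   refills the bucket to capacity and starts a new interval from that moment;
//!   unused tokens do not carry over between intervals
//!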
12//! # Examples
13//!
14//! ```rust
15//! use do_over::{policy::Policy, rate_limit::RateLimiter, error::DoOverError};
16//! use std::time::Duration;
17//!
18//! # async fn example() -> Result<(), DoOverError<std::io::Error>> {
19//! // Allow 100 requests per second
20//! let limiter = RateLimiter::new(100, Duration::from_secs(1));
21//!
22//! match limiter.execute(|| async {
23//!     Ok::<_, DoOverError<std::io::Error>>("completed")
24//! }).await {
25//!     Ok(result) => println!("Success: {}", result),
26//!     Err(DoOverError::BulkheadFull) => println!("Rate limit exceeded"),
27//!     Err(e) => println!("Error: {:?}", e),
28//! }
29//! # Ok(())
30//! # }
31//! ```
32
33use std::time::{Duration, Instant};
34use std::sync::Arc;
35use tokio::sync::Mutex;
36use crate::{policy::Policy, error::DoOverError};
37
/// A token bucket rate limiter.
///
/// The rate limiter maintains a bucket of up to `capacity` tokens. Each
/// operation consumes one token. When tokens are depleted, operations are
/// rejected until `interval` has elapsed since the last refill, at which point
/// the next operation refills the bucket to full capacity.
///
/// Cloning a `RateLimiter` is cheap: the clone shares the same bucket, so all
/// clones draw from a single pool of tokens.
///
/// # Examples
///
/// ```rust
/// use do_over::{policy::Policy, rate_limit::RateLimiter, error::DoOverError};
/// use std::time::Duration;
///
/// # async fn example() {
/// // 10 requests per second
/// let limiter = RateLimiter::new(10, Duration::from_secs(1));
///
/// // 1000 requests per minute
/// let limiter = RateLimiter::new(1000, Duration::from_secs(60));
///
/// // Clones share the same token bucket.
/// let _shared = limiter.clone();
/// # }
/// ```
// `Clone` is derived: it copies `capacity`/`interval` and `Arc::clone`s the
// state, so clones share one bucket, exactly as the former hand-written impl did.
#[derive(Clone, Debug)]
pub struct RateLimiter {
    /// Number of tokens the bucket is refilled to at each refill.
    capacity: u64,
    /// Minimum time that must pass since the last refill before the next one.
    interval: Duration,
    /// `(remaining tokens, instant of the last refill)`, shared between clones.
    state: Arc<Mutex<(u64, Instant)>>,
}
73
74impl RateLimiter {
75    /// Create a new rate limiter.
76    ///
77    /// # Arguments
78    ///
79    /// * `capacity` - Number of tokens (requests) allowed per interval
80    /// * `interval` - Duration after which tokens refill
81    ///
82    /// # Examples
83    ///
84    /// ```rust
85    /// use do_over::rate_limit::RateLimiter;
86    /// use std::time::Duration;
87    ///
88    /// // Allow 100 requests per second
89    /// let limiter = RateLimiter::new(100, Duration::from_secs(1));
90    ///
91    /// // Allow 5 requests per 100ms (burst limiting)
92    /// let limiter = RateLimiter::new(5, Duration::from_millis(100));
93    /// ```
94    pub fn new(capacity: u64, interval: Duration) -> Self {
95        Self {
96            capacity,
97            interval,
98            state: Arc::new(Mutex::new((capacity, Instant::now()))),
99        }
100    }
101}
102
103#[async_trait::async_trait]
104impl<E> Policy<DoOverError<E>> for RateLimiter
105where
106    E: Send + Sync,
107{
108    async fn execute<F, Fut, T>(&self, f: F) -> Result<T, DoOverError<E>>
109    where
110        F: Fn() -> Fut + Send + Sync,
111        Fut: std::future::Future<Output = Result<T, DoOverError<E>>> + Send,
112        T: Send,
113    {
114        let mut state = self.state.lock().await;
115        if state.1.elapsed() >= self.interval {
116            state.0 = self.capacity;
117            state.1 = Instant::now();
118        }
119
120        if state.0 == 0 {
121            return Err(DoOverError::BulkheadFull);
122        }
123
124        state.0 -= 1;
125        drop(state);
126        f().await
127    }
128}