Skip to main content

aspect_std/
ratelimit.rs

//! Rate limiting aspect using token bucket algorithm.
2
3use aspect_core::{Aspect, AspectError, ProceedingJoinPoint};
4use parking_lot::Mutex;
5use std::any::Any;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10/// Rate limiting aspect with token bucket algorithm.
11///
12/// Limits the rate at which functions can be called, useful for API throttling,
13/// resource protection, and preventing abuse.
14///
15/// # Example
16///
17/// ```rust,ignore
18/// use aspect_std::RateLimitAspect;
19/// use aspect_macros::aspect;
20/// use std::time::Duration;
21///
22/// // Allow 10 calls per second
23/// let limiter = RateLimitAspect::new(10, Duration::from_secs(1));
24///
25/// #[aspect(limiter.clone())]
26/// fn api_call(data: String) -> Result<(), String> {
27///     // This function is rate-limited
28///     Ok(())
29/// }
30/// ```
31#[derive(Clone)]
32pub struct RateLimitAspect {
33    state: Arc<Mutex<RateLimitState>>,
34}
35
36struct RateLimitState {
37    tokens: f64,
38    max_tokens: f64,
39    refill_rate: f64, // tokens per second
40    last_refill: Instant,
41    per_function: bool,
42    function_states: HashMap<String, FunctionRateLimit>,
43}
44
/// Token bucket for a single function in per-function mode.
///
/// Capacity and refill rate are not stored here; they come from the
/// enclosing `RateLimitState`.
struct FunctionRateLimit {
    /// Tokens currently available to this function.
    tokens: f64,
    /// When this bucket was last topped up.
    last_refill: Instant,
}
49
50impl RateLimitAspect {
51    /// Create a new rate limiter.
52    ///
53    /// # Arguments
54    /// * `max_requests` - Maximum number of requests allowed
55    /// * `window` - Time window for the limit
56    ///
57    /// # Example
58    /// ```rust,ignore
59    /// // 100 requests per minute
60    /// let limiter = RateLimitAspect::new(100, Duration::from_secs(60));
61    /// ```
62    pub fn new(max_requests: u64, window: Duration) -> Self {
63        let refill_rate = max_requests as f64 / window.as_secs_f64();
64
65        Self {
66            state: Arc::new(Mutex::new(RateLimitState {
67                tokens: max_requests as f64,
68                max_tokens: max_requests as f64,
69                refill_rate,
70                last_refill: Instant::now(),
71                per_function: false,
72                function_states: HashMap::new(),
73            })),
74        }
75    }
76
77    /// Enable per-function rate limiting.
78    ///
79    /// When enabled, each function gets its own token bucket.
80    pub fn per_function(self) -> Self {
81        self.state.lock().per_function = true;
82        self
83    }
84
85    /// Check if a request is allowed (consumes a token if available).
86    fn try_acquire(&self, function_name: Option<&str>) -> bool {
87        let mut state = self.state.lock();
88        let now = Instant::now();
89
90        if state.per_function {
91            if let Some(name) = function_name {
92                // Per-function rate limiting
93                // Capture values before borrowing function_states
94                let max_tokens = state.max_tokens;
95                let refill_rate = state.refill_rate;
96
97                let func_state = state
98                    .function_states
99                    .entry(name.to_string())
100                    .or_insert_with(|| FunctionRateLimit {
101                        tokens: max_tokens,
102                        last_refill: now,
103                    });
104
105                // Refill tokens
106                let elapsed = now.duration_since(func_state.last_refill).as_secs_f64();
107                func_state.tokens = (func_state.tokens + elapsed * refill_rate).min(max_tokens);
108                func_state.last_refill = now;
109
110                if func_state.tokens >= 1.0 {
111                    func_state.tokens -= 1.0;
112                    true
113                } else {
114                    false
115                }
116            } else {
117                false
118            }
119        } else {
120            // Global rate limiting
121            let elapsed = now.duration_since(state.last_refill).as_secs_f64();
122            state.tokens = (state.tokens + elapsed * state.refill_rate).min(state.max_tokens);
123            state.last_refill = now;
124
125            if state.tokens >= 1.0 {
126                state.tokens -= 1.0;
127                true
128            } else {
129                false
130            }
131        }
132    }
133
134    /// Get current token count.
135    pub fn available_tokens(&self) -> f64 {
136        let mut state = self.state.lock();
137        let now = Instant::now();
138        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
139        state.tokens = (state.tokens + elapsed * state.refill_rate).min(state.max_tokens);
140        state.last_refill = now;
141        state.tokens
142    }
143}
144
145impl Aspect for RateLimitAspect {
146    fn around(&self, pjp: ProceedingJoinPoint) -> Result<Box<dyn Any>, AspectError> {
147        let function_name = pjp.context().function_name;
148
149        if self.try_acquire(Some(function_name)) {
150            pjp.proceed()
151        } else {
152            Err(AspectError::execution(format!(
153                "Rate limit exceeded for {}",
154                function_name
155            )))
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// A window long enough that refill during a test run is negligible, so
    /// assertions do not depend on scheduler timing.
    const LONG_WINDOW: Duration = Duration::from_secs(3600);

    #[test]
    fn test_rate_limit_basic() {
        let limiter = RateLimitAspect::new(5, LONG_WINDOW);

        // Should allow 5 calls
        for _ in 0..5 {
            assert!(limiter.try_acquire(Some("test")));
        }

        // 6th call should be denied
        assert!(!limiter.try_acquire(Some("test")));
    }

    #[test]
    fn test_global_bucket_is_shared_across_names() {
        let limiter = RateLimitAspect::new(2, LONG_WINDOW);

        assert!(limiter.try_acquire(Some("a")));
        assert!(limiter.try_acquire(Some("b")));
        // Global mode: every name (and no name) draws from the same bucket.
        assert!(!limiter.try_acquire(Some("c")));
        assert!(!limiter.try_acquire(None));
    }

    #[test]
    fn test_rate_limit_refill() {
        let limiter = RateLimitAspect::new(2, Duration::from_millis(100));

        // Consume both tokens
        assert!(limiter.try_acquire(Some("test")));
        assert!(limiter.try_acquire(Some("test")));
        assert!(!limiter.try_acquire(Some("test")));

        // Wait for refill
        std::thread::sleep(Duration::from_millis(150));

        // Should have at least 1 token now
        assert!(limiter.try_acquire(Some("test")));
    }

    #[test]
    fn test_per_function_limiting() {
        let limiter = RateLimitAspect::new(2, LONG_WINDOW).per_function();

        // Function A consumes its quota
        assert!(limiter.try_acquire(Some("func_a")));
        assert!(limiter.try_acquire(Some("func_a")));
        assert!(!limiter.try_acquire(Some("func_a")));

        // Function B should still have its quota
        assert!(limiter.try_acquire(Some("func_b")));
        assert!(limiter.try_acquire(Some("func_b")));
        assert!(!limiter.try_acquire(Some("func_b")));
    }

    #[test]
    fn test_per_function_rejects_anonymous_calls() {
        let limiter = RateLimitAspect::new(2, LONG_WINDOW).per_function();

        assert!(!limiter.try_acquire(None));
        // The global bucket is not drawn from in per-function mode.
        assert!((limiter.available_tokens() - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_per_function_applies_to_all_clones() {
        let limiter = RateLimitAspect::new(1, LONG_WINDOW);
        let _configured = limiter.clone().per_function();

        // State is shared via Arc, so the original handle switched modes too.
        assert!(!limiter.try_acquire(None));
        assert!(limiter.try_acquire(Some("f")));
    }

    #[test]
    fn test_zero_window_never_limits() {
        let limiter = RateLimitAspect::new(1, Duration::from_secs(0));

        for _ in 0..10 {
            assert!(limiter.try_acquire(Some("test")));
        }
    }

    #[test]
    fn test_zero_max_requests_denies_everything() {
        let limiter = RateLimitAspect::new(0, Duration::from_secs(1));

        assert!(!limiter.try_acquire(Some("test")));
        assert!(limiter.available_tokens().abs() < 1e-9);
    }

    #[test]
    fn test_available_tokens() {
        let limiter = RateLimitAspect::new(10, LONG_WINDOW);

        let initial = limiter.available_tokens();
        assert!((initial - 10.0).abs() < 0.01);

        assert!(limiter.try_acquire(Some("test")));

        let after = limiter.available_tokens();
        assert!((after - 9.0).abs() < 0.01);
    }
}
220}