1use aspect_core::{Aspect, AspectError, ProceedingJoinPoint};
4use parking_lot::Mutex;
5use std::any::Any;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10#[derive(Clone)]
32pub struct RateLimitAspect {
33 state: Arc<Mutex<RateLimitState>>,
34}
35
36struct RateLimitState {
37 tokens: f64,
38 max_tokens: f64,
39 refill_rate: f64, last_refill: Instant,
41 per_function: bool,
42 function_states: HashMap<String, FunctionRateLimit>,
43}
44
/// Token bucket belonging to a single function name in per-function mode.
struct FunctionRateLimit {
    /// Instant at which `tokens` was last topped up.
    last_refill: Instant,
    /// Tokens currently available to this function.
    tokens: f64,
}
49
50impl RateLimitAspect {
51 pub fn new(max_requests: u64, window: Duration) -> Self {
63 let refill_rate = max_requests as f64 / window.as_secs_f64();
64
65 Self {
66 state: Arc::new(Mutex::new(RateLimitState {
67 tokens: max_requests as f64,
68 max_tokens: max_requests as f64,
69 refill_rate,
70 last_refill: Instant::now(),
71 per_function: false,
72 function_states: HashMap::new(),
73 })),
74 }
75 }
76
77 pub fn per_function(self) -> Self {
81 self.state.lock().per_function = true;
82 self
83 }
84
85 fn try_acquire(&self, function_name: Option<&str>) -> bool {
87 let mut state = self.state.lock();
88 let now = Instant::now();
89
90 if state.per_function {
91 if let Some(name) = function_name {
92 let max_tokens = state.max_tokens;
95 let refill_rate = state.refill_rate;
96
97 let func_state = state
98 .function_states
99 .entry(name.to_string())
100 .or_insert_with(|| FunctionRateLimit {
101 tokens: max_tokens,
102 last_refill: now,
103 });
104
105 let elapsed = now.duration_since(func_state.last_refill).as_secs_f64();
107 func_state.tokens = (func_state.tokens + elapsed * refill_rate).min(max_tokens);
108 func_state.last_refill = now;
109
110 if func_state.tokens >= 1.0 {
111 func_state.tokens -= 1.0;
112 true
113 } else {
114 false
115 }
116 } else {
117 false
118 }
119 } else {
120 let elapsed = now.duration_since(state.last_refill).as_secs_f64();
122 state.tokens = (state.tokens + elapsed * state.refill_rate).min(state.max_tokens);
123 state.last_refill = now;
124
125 if state.tokens >= 1.0 {
126 state.tokens -= 1.0;
127 true
128 } else {
129 false
130 }
131 }
132 }
133
134 pub fn available_tokens(&self) -> f64 {
136 let mut state = self.state.lock();
137 let now = Instant::now();
138 let elapsed = now.duration_since(state.last_refill).as_secs_f64();
139 state.tokens = (state.tokens + elapsed * state.refill_rate).min(state.max_tokens);
140 state.last_refill = now;
141 state.tokens
142 }
143}
144
145impl Aspect for RateLimitAspect {
146 fn around(&self, pjp: ProceedingJoinPoint) -> Result<Box<dyn Any>, AspectError> {
147 let function_name = pjp.context().function_name;
148
149 if self.try_acquire(Some(function_name)) {
150 pjp.proceed()
151 } else {
152 Err(AspectError::execution(format!(
153 "Rate limit exceeded for {}",
154 function_name
155 )))
156 }
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_basic() {
        let limiter = RateLimitAspect::new(5, Duration::from_secs(1));

        for _ in 0..5 {
            assert!(limiter.try_acquire(Some("test")));
        }

        assert!(!limiter.try_acquire(Some("test")));
    }

    #[test]
    fn test_rate_limit_refill() {
        let limiter = RateLimitAspect::new(2, Duration::from_millis(100));

        assert!(limiter.try_acquire(Some("test")));
        assert!(limiter.try_acquire(Some("test")));
        assert!(!limiter.try_acquire(Some("test")));

        // 150 ms at 20 tokens/s earns more than one token.
        std::thread::sleep(Duration::from_millis(150));

        assert!(limiter.try_acquire(Some("test")));
    }

    #[test]
    fn test_per_function_limiting() {
        let limiter = RateLimitAspect::new(2, Duration::from_secs(1)).per_function();

        assert!(limiter.try_acquire(Some("func_a")));
        assert!(limiter.try_acquire(Some("func_a")));
        assert!(!limiter.try_acquire(Some("func_a")));

        assert!(limiter.try_acquire(Some("func_b")));
        assert!(limiter.try_acquire(Some("func_b")));
        assert!(!limiter.try_acquire(Some("func_b")));
    }

    #[test]
    fn test_global_mode_shares_bucket_across_functions() {
        // Long window keeps refill negligible for the duration of the test.
        let limiter = RateLimitAspect::new(2, Duration::from_secs(60));

        assert!(limiter.try_acquire(Some("func_a")));
        assert!(limiter.try_acquire(Some("func_b")));
        assert!(!limiter.try_acquire(Some("func_a")));
        assert!(!limiter.try_acquire(Some("func_b")));
    }

    #[test]
    fn test_clones_share_state() {
        let limiter = RateLimitAspect::new(1, Duration::from_secs(60));
        let clone = limiter.clone();

        assert!(limiter.try_acquire(Some("test")));
        assert!(!clone.try_acquire(Some("test")));
    }

    #[test]
    fn test_available_tokens() {
        let limiter = RateLimitAspect::new(10, Duration::from_secs(1));

        let initial = limiter.available_tokens();
        assert!((initial - 10.0).abs() < 0.01);

        // The acquire must succeed for the 9-token expectation to be meaningful.
        assert!(limiter.try_acquire(Some("test")));

        let after = limiter.available_tokens();
        assert!((after - 9.0).abs() < 0.01);
    }
}