klauthed_data/rate_limit/
memory.rs1use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::Duration as StdDuration;
6
7use async_trait::async_trait;
8use klauthed_core::time::{Clock, Duration, SystemClock, Timestamp};
9
10use super::{RateLimitOutcome, RateLimiter};
11use crate::error::DataError;
12
13#[derive(Debug, Clone, Copy)]
15struct Window {
16 started: Timestamp,
17 count: u32,
18}
19
20#[derive(Clone)]
28pub struct InMemoryRateLimiter {
29 windows: Arc<Mutex<HashMap<String, Window>>>,
30 clock: Arc<dyn Clock>,
31}
32
33impl std::fmt::Debug for InMemoryRateLimiter {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 let len = self.windows.lock().map(|m| m.len()).unwrap_or(0);
36 f.debug_struct("InMemoryRateLimiter").field("keys", &len).finish_non_exhaustive()
37 }
38}
39
40impl InMemoryRateLimiter {
41 #[must_use]
43 pub fn new(clock: Arc<dyn Clock>) -> Self {
44 Self { windows: Arc::new(Mutex::new(HashMap::new())), clock }
45 }
46
47 #[must_use]
49 pub fn system() -> Self {
50 Self::new(Arc::new(SystemClock))
51 }
52
53 #[must_use]
56 pub fn len(&self) -> usize {
57 self.windows.lock().unwrap_or_else(std::sync::PoisonError::into_inner).len()
58 }
59
60 #[must_use]
62 pub fn is_empty(&self) -> bool {
63 self.windows.lock().unwrap_or_else(std::sync::PoisonError::into_inner).is_empty()
64 }
65}
66
67#[async_trait]
68impl RateLimiter for InMemoryRateLimiter {
69 async fn check(
70 &self,
71 key: &str,
72 max: u32,
73 window: StdDuration,
74 ) -> Result<RateLimitOutcome, DataError> {
75 let max = max.max(1);
76 let window_core = Duration::milliseconds(window.as_millis().min(i64::MAX as u128) as i64);
77 let now = self.clock.now();
78
79 let mut windows = self.windows.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
80 let entry = windows.entry(key.to_owned()).or_insert(Window { started: now, count: 0 });
81
82 if now.duration_since(entry.started) >= window_core {
84 entry.started = now;
85 entry.count = 0;
86 }
87
88 if entry.count >= max {
89 let elapsed = now.duration_since(entry.started);
90 let remaining = window_core - elapsed;
91 let retry_after =
92 StdDuration::from_millis(remaining.whole_milliseconds().max(0) as u64);
93 Ok(RateLimitOutcome::Limited { retry_after })
94 } else {
95 entry.count += 1;
96 Ok(RateLimitOutcome::Allowed { remaining: max - entry.count })
97 }
98 }
99}
100
101#[derive(Debug, Clone, Copy)]
103struct Bucket {
104 tokens: f64,
106 refilled_at: Timestamp,
108}
109
110#[derive(Clone)]
122pub struct InMemoryTokenBucket {
123 buckets: Arc<Mutex<HashMap<String, Bucket>>>,
124 clock: Arc<dyn Clock>,
125}
126
127impl std::fmt::Debug for InMemoryTokenBucket {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 let len = self.buckets.lock().map(|m| m.len()).unwrap_or(0);
130 f.debug_struct("InMemoryTokenBucket").field("keys", &len).finish_non_exhaustive()
131 }
132}
133
134impl InMemoryTokenBucket {
135 #[must_use]
137 pub fn new(clock: Arc<dyn Clock>) -> Self {
138 Self { buckets: Arc::new(Mutex::new(HashMap::new())), clock }
139 }
140
141 #[must_use]
143 pub fn system() -> Self {
144 Self::new(Arc::new(SystemClock))
145 }
146}
147
148#[async_trait]
149impl RateLimiter for InMemoryTokenBucket {
150 async fn check(
151 &self,
152 key: &str,
153 max: u32,
154 window: StdDuration,
155 ) -> Result<RateLimitOutcome, DataError> {
156 let capacity = f64::from(max.max(1));
157 let refill_per_sec = capacity / window.as_secs_f64().max(f64::MIN_POSITIVE);
159 let now = self.clock.now();
160
161 let mut buckets = self.buckets.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
162 let bucket =
164 buckets.entry(key.to_owned()).or_insert(Bucket { tokens: capacity, refilled_at: now });
165
166 let elapsed = now.duration_since(bucket.refilled_at).as_seconds_f64().max(0.0);
167 bucket.tokens = (bucket.tokens + elapsed * refill_per_sec).min(capacity);
168 bucket.refilled_at = now;
169
170 if bucket.tokens >= 1.0 {
171 bucket.tokens -= 1.0;
172 Ok(RateLimitOutcome::Allowed { remaining: bucket.tokens as u32 })
173 } else {
174 let secs_until_token = (1.0 - bucket.tokens) / refill_per_sec;
175 Ok(RateLimitOutcome::Limited {
176 retry_after: StdDuration::from_secs_f64(secs_until_token),
177 })
178 }
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use klauthed_core::time::FixedClock;

    /// Fixed-window limiter plus a handle to the clock that drives it.
    fn limiter_at(millis: i64) -> (Arc<FixedClock>, InMemoryRateLimiter) {
        let clock = Arc::new(FixedClock::at_unix_millis(millis));
        let dyn_clock: Arc<dyn Clock> = clock.clone();
        (clock, InMemoryRateLimiter::new(dyn_clock))
    }

    /// Token bucket limiter plus a handle to the clock that drives it.
    fn bucket_at(millis: i64) -> (Arc<FixedClock>, InMemoryTokenBucket) {
        let clock = Arc::new(FixedClock::at_unix_millis(millis));
        let dyn_clock: Arc<dyn Clock> = clock.clone();
        (clock, InMemoryTokenBucket::new(dyn_clock))
    }

    #[tokio::test]
    async fn token_bucket_allows_initial_burst_up_to_capacity() {
        let (_clock, tb) = bucket_at(0);
        let window = StdDuration::from_secs(10);

        // A fresh bucket starts full: exactly `max` calls succeed back to back.
        for _ in 0..2 {
            assert!(tb.check("k", 2, window).await.unwrap().is_allowed());
        }
        assert!(!tb.check("k", 2, window).await.unwrap().is_allowed());
    }

    #[tokio::test]
    async fn token_bucket_refills_continuously() {
        let (clock, tb) = bucket_at(0);
        let window = StdDuration::from_secs(10);

        for _ in 0..2 {
            tb.check("k", 2, window).await.unwrap();
        }
        assert!(!tb.check("k", 2, window).await.unwrap().is_allowed());

        // 2 tokens per 10 s: half a window restores exactly one token.
        clock.advance(Duration::seconds(5));
        assert!(tb.check("k", 2, window).await.unwrap().is_allowed());
        assert!(!tb.check("k", 2, window).await.unwrap().is_allowed());
    }

    #[tokio::test]
    async fn token_bucket_limited_reports_retry_after() {
        let (_clock, tb) = bucket_at(0);
        let window = StdDuration::from_secs(10);

        tb.check("k", 1, window).await.unwrap();
        let outcome = tb.check("k", 1, window).await.unwrap();
        match outcome {
            RateLimitOutcome::Limited { retry_after } => assert_eq!(retry_after.as_secs(), 10),
            other => panic!("expected Limited, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn allows_up_to_max_then_limits_then_resets() {
        let (clock, limiter) = limiter_at(0);
        let window = StdDuration::from_secs(10);

        for expected in [1, 0] {
            let outcome = limiter.check("k", 2, window).await.unwrap();
            assert_eq!(outcome, RateLimitOutcome::Allowed { remaining: expected });
        }
        assert!(!limiter.check("k", 2, window).await.unwrap().is_allowed());

        // One full window later the counter starts over.
        clock.advance(Duration::seconds(10));
        assert!(limiter.check("k", 2, window).await.unwrap().is_allowed());
    }

    #[tokio::test]
    async fn keys_are_independent() {
        let (_clock, limiter) = limiter_at(0);
        let window = StdDuration::from_secs(10);

        assert!(limiter.check("a", 1, window).await.unwrap().is_allowed());
        assert!(!limiter.check("a", 1, window).await.unwrap().is_allowed());
        // Exhausting "a" must not affect "b".
        assert!(limiter.check("b", 1, window).await.unwrap().is_allowed());

        assert_eq!(limiter.len(), 2);
    }

    #[tokio::test]
    async fn limited_reports_time_until_reset() {
        let (clock, limiter) = limiter_at(0);
        let window = StdDuration::from_secs(60);

        limiter.check("k", 1, window).await.unwrap();
        clock.advance(Duration::seconds(20));

        let outcome = limiter.check("k", 1, window).await.unwrap();
        match outcome {
            RateLimitOutcome::Limited { retry_after } => {
                assert_eq!(retry_after, StdDuration::from_secs(40));
            }
            other => panic!("expected Limited, got {other:?}"),
        }
    }
}