Skip to main content

klauthed_data/rate_limit/
memory.rs

//! In-process rate limiters: the fixed-window [`InMemoryRateLimiter`] and the
//! token-bucket [`InMemoryTokenBucket`].
2
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::Duration as StdDuration;
6
7use async_trait::async_trait;
8use klauthed_core::time::{Clock, Duration, SystemClock, Timestamp};
9
10use super::{RateLimitOutcome, RateLimiter};
11use crate::error::DataError;
12
13/// One key's counter within the current fixed window.
14#[derive(Debug, Clone, Copy)]
15struct Window {
16    started: Timestamp,
17    count: u32,
18}
19
20/// A per-process, fixed-window [`RateLimiter`] backed by a `Mutex<HashMap>`.
21///
22/// Counters live in this process only — each replica enforces its own budget, so
23/// for a multi-replica deployment with one global budget use a shared backend
24/// such as [`RedisRateLimiter`](super::RedisRateLimiter). "Now" comes from an
25/// injected [`Clock`], so expiry is deterministic under a `FixedClock` in tests.
26/// Cloned handles share the same backing map.
27#[derive(Clone)]
28pub struct InMemoryRateLimiter {
29    windows: Arc<Mutex<HashMap<String, Window>>>,
30    clock: Arc<dyn Clock>,
31}
32
33impl std::fmt::Debug for InMemoryRateLimiter {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        let len = self.windows.lock().map(|m| m.len()).unwrap_or(0);
36        f.debug_struct("InMemoryRateLimiter").field("keys", &len).finish_non_exhaustive()
37    }
38}
39
40impl InMemoryRateLimiter {
41    /// A limiter driven by `clock`.
42    #[must_use]
43    pub fn new(clock: Arc<dyn Clock>) -> Self {
44        Self { windows: Arc::new(Mutex::new(HashMap::new())), clock }
45    }
46
47    /// A limiter driven by the real system clock.
48    #[must_use]
49    pub fn system() -> Self {
50        Self::new(Arc::new(SystemClock))
51    }
52
53    /// Number of distinct keys currently tracked (including windows that have
54    /// elapsed but not yet been overwritten).
55    #[must_use]
56    pub fn len(&self) -> usize {
57        self.windows.lock().unwrap_or_else(std::sync::PoisonError::into_inner).len()
58    }
59
60    /// Whether no keys are currently tracked.
61    #[must_use]
62    pub fn is_empty(&self) -> bool {
63        self.windows.lock().unwrap_or_else(std::sync::PoisonError::into_inner).is_empty()
64    }
65}
66
67#[async_trait]
68impl RateLimiter for InMemoryRateLimiter {
69    async fn check(
70        &self,
71        key: &str,
72        max: u32,
73        window: StdDuration,
74    ) -> Result<RateLimitOutcome, DataError> {
75        let max = max.max(1);
76        let window_core = Duration::milliseconds(window.as_millis().min(i64::MAX as u128) as i64);
77        let now = self.clock.now();
78
79        let mut windows = self.windows.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
80        let entry = windows.entry(key.to_owned()).or_insert(Window { started: now, count: 0 });
81
82        // Reset the window if it has elapsed.
83        if now.duration_since(entry.started) >= window_core {
84            entry.started = now;
85            entry.count = 0;
86        }
87
88        if entry.count >= max {
89            let elapsed = now.duration_since(entry.started);
90            let remaining = window_core - elapsed;
91            let retry_after =
92                StdDuration::from_millis(remaining.whole_milliseconds().max(0) as u64);
93            Ok(RateLimitOutcome::Limited { retry_after })
94        } else {
95            entry.count += 1;
96            Ok(RateLimitOutcome::Allowed { remaining: max - entry.count })
97        }
98    }
99}
100
101/// One key's token-bucket state.
102#[derive(Debug, Clone, Copy)]
103struct Bucket {
104    /// Fractional tokens currently available.
105    tokens: f64,
106    /// When `tokens` was last refilled.
107    refilled_at: Timestamp,
108}
109
110/// A per-process **token-bucket** [`RateLimiter`].
111///
112/// Unlike the fixed-window [`InMemoryRateLimiter`] (which resets in hard steps),
113/// the bucket refills *continuously*: it holds up to `max` tokens (the burst
114/// size) and refills at `max / window` tokens per second, so traffic is smoothed
115/// rather than allowed in bursts at each window boundary. Each request spends one
116/// token; an empty bucket reports the time until the next token.
117///
118/// It implements the same [`RateLimiter`] trait with the same `(max, window)`
119/// parameters, so it is a drop-in alternative wherever a fixed-window limiter is
120/// used. Clock-injected for deterministic tests.
121#[derive(Clone)]
122pub struct InMemoryTokenBucket {
123    buckets: Arc<Mutex<HashMap<String, Bucket>>>,
124    clock: Arc<dyn Clock>,
125}
126
127impl std::fmt::Debug for InMemoryTokenBucket {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        let len = self.buckets.lock().map(|m| m.len()).unwrap_or(0);
130        f.debug_struct("InMemoryTokenBucket").field("keys", &len).finish_non_exhaustive()
131    }
132}
133
134impl InMemoryTokenBucket {
135    /// A token bucket driven by `clock`.
136    #[must_use]
137    pub fn new(clock: Arc<dyn Clock>) -> Self {
138        Self { buckets: Arc::new(Mutex::new(HashMap::new())), clock }
139    }
140
141    /// A token bucket driven by the real system clock.
142    #[must_use]
143    pub fn system() -> Self {
144        Self::new(Arc::new(SystemClock))
145    }
146}
147
148#[async_trait]
149impl RateLimiter for InMemoryTokenBucket {
150    async fn check(
151        &self,
152        key: &str,
153        max: u32,
154        window: StdDuration,
155    ) -> Result<RateLimitOutcome, DataError> {
156        let capacity = f64::from(max.max(1));
157        // Tokens replenished per second so a full `capacity` refills over `window`.
158        let refill_per_sec = capacity / window.as_secs_f64().max(f64::MIN_POSITIVE);
159        let now = self.clock.now();
160
161        let mut buckets = self.buckets.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
162        // New keys start full, permitting an initial burst up to `capacity`.
163        let bucket =
164            buckets.entry(key.to_owned()).or_insert(Bucket { tokens: capacity, refilled_at: now });
165
166        let elapsed = now.duration_since(bucket.refilled_at).as_seconds_f64().max(0.0);
167        bucket.tokens = (bucket.tokens + elapsed * refill_per_sec).min(capacity);
168        bucket.refilled_at = now;
169
170        if bucket.tokens >= 1.0 {
171            bucket.tokens -= 1.0;
172            Ok(RateLimitOutcome::Allowed { remaining: bucket.tokens as u32 })
173        } else {
174            let secs_until_token = (1.0 - bucket.tokens) / refill_per_sec;
175            Ok(RateLimitOutcome::Limited {
176                retry_after: StdDuration::from_secs_f64(secs_until_token),
177            })
178        }
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use klauthed_core::time::FixedClock;

    fn limiter_at(millis: i64) -> (Arc<FixedClock>, InMemoryRateLimiter) {
        let clock = Arc::new(FixedClock::at_unix_millis(millis));
        (clock.clone(), InMemoryRateLimiter::new(clock))
    }

    fn bucket_at(millis: i64) -> (Arc<FixedClock>, InMemoryTokenBucket) {
        let clock = Arc::new(FixedClock::at_unix_millis(millis));
        (clock.clone(), InMemoryTokenBucket::new(clock))
    }

    #[tokio::test]
    async fn token_bucket_allows_initial_burst_up_to_capacity() {
        let (_clock, tb) = bucket_at(0);
        let window = StdDuration::from_secs(10); // 2 tokens / 10s
        // Starts full: two requests succeed immediately, the third is limited.
        assert!(tb.check("k", 2, window).await.unwrap().is_allowed());
        assert!(tb.check("k", 2, window).await.unwrap().is_allowed());
        assert!(!tb.check("k", 2, window).await.unwrap().is_allowed());
    }

    #[tokio::test]
    async fn token_bucket_refills_continuously() {
        let (clock, tb) = bucket_at(0);
        let window = StdDuration::from_secs(10); // refill 0.2 tokens/s
        // Drain the bucket (capacity 2).
        tb.check("k", 2, window).await.unwrap();
        tb.check("k", 2, window).await.unwrap();
        assert!(!tb.check("k", 2, window).await.unwrap().is_allowed());

        // After 5s, 0.2/s * 5s = 1 token refilled -> exactly one more allowed.
        clock.advance(Duration::seconds(5));
        assert!(tb.check("k", 2, window).await.unwrap().is_allowed());
        assert!(!tb.check("k", 2, window).await.unwrap().is_allowed());
    }

    #[tokio::test]
    async fn token_bucket_limited_reports_retry_after() {
        let (_clock, tb) = bucket_at(0);
        let window = StdDuration::from_secs(10); // capacity 1 over 10s -> 0.1 tokens/s
        tb.check("k", 1, window).await.unwrap(); // capacity 1, now empty
        match tb.check("k", 1, window).await.unwrap() {
            RateLimitOutcome::Limited { retry_after } => {
                // One token at 0.1 tokens/s => 10s.
                assert_eq!(retry_after.as_secs(), 10);
            }
            other => panic!("expected Limited, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn token_bucket_keys_are_independent() {
        let (_clock, tb) = bucket_at(0);
        let window = StdDuration::from_secs(10);
        assert!(tb.check("a", 1, window).await.unwrap().is_allowed());
        assert!(!tb.check("a", 1, window).await.unwrap().is_allowed());
        // A different key starts with its own full bucket.
        assert!(tb.check("b", 1, window).await.unwrap().is_allowed());
    }

    #[tokio::test]
    async fn allows_up_to_max_then_limits_then_resets() {
        let (clock, limiter) = limiter_at(0);
        let window = StdDuration::from_secs(10);

        assert_eq!(
            limiter.check("k", 2, window).await.unwrap(),
            RateLimitOutcome::Allowed { remaining: 1 }
        );
        assert_eq!(
            limiter.check("k", 2, window).await.unwrap(),
            RateLimitOutcome::Allowed { remaining: 0 }
        );
        assert!(!limiter.check("k", 2, window).await.unwrap().is_allowed());

        // After the window elapses the budget refreshes.
        clock.advance(Duration::seconds(10));
        assert!(limiter.check("k", 2, window).await.unwrap().is_allowed());
    }

    #[tokio::test]
    async fn zero_max_is_treated_as_one() {
        let (_clock, limiter) = limiter_at(0);
        let window = StdDuration::from_secs(10);
        assert_eq!(
            limiter.check("k", 0, window).await.unwrap(),
            RateLimitOutcome::Allowed { remaining: 0 }
        );
        assert!(!limiter.check("k", 0, window).await.unwrap().is_allowed());
    }

    #[tokio::test]
    async fn keys_are_independent() {
        let (_clock, limiter) = limiter_at(0);
        let window = StdDuration::from_secs(10);
        assert!(limiter.check("a", 1, window).await.unwrap().is_allowed());
        assert!(!limiter.check("a", 1, window).await.unwrap().is_allowed());
        // A different key has its own fresh budget.
        assert!(limiter.check("b", 1, window).await.unwrap().is_allowed());
        assert_eq!(limiter.len(), 2);
    }

    #[tokio::test]
    async fn limited_reports_time_until_reset() {
        let (clock, limiter) = limiter_at(0);
        let window = StdDuration::from_secs(60);
        limiter.check("k", 1, window).await.unwrap();
        clock.advance(Duration::seconds(20));
        match limiter.check("k", 1, window).await.unwrap() {
            RateLimitOutcome::Limited { retry_after } => {
                // 60s window, 20s elapsed -> 40s remaining.
                assert_eq!(retry_after, StdDuration::from_secs(40));
            }
            other => panic!("expected Limited, got {other:?}"),
        }
    }
}