Skip to main content

aa_storage_memory/
rate_limit_counter.rs

1//! In-memory [`RateLimitCounter`] backed by a `DashMap` of windowed counters.
2
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use aa_storage::{RateLimitCounter, Result};
7use async_trait::async_trait;
8use dashmap::DashMap;
9
/// A single key's counter and the window it is bucketed into.
///
/// The window spans `start .. start + window` on the monotonic clock.
#[derive(Debug, Clone, Copy)]
struct Window {
    /// Sum of the amounts added during the current window.
    count: u64,
    /// Monotonic timestamp at which the current window opened.
    start: Instant,
    /// Length of the current window.
    window: Duration,
}
16
17/// A `DashMap`-backed [`RateLimitCounter`].
18///
19/// Each key tracks a count and the wall-clock start of its current window;
20/// once the window elapses the count rolls over to zero on the next access.
21/// `DashMap`'s per-key entry lock makes the read-modify-write in
22/// [`increment`](RateLimitCounter::increment) atomic across concurrent callers.
23/// Cloning shares the same underlying map.
24#[derive(Clone, Default)]
25pub struct MemoryRateLimitCounter {
26    counters: Arc<DashMap<String, Window>>,
27}
28
29impl MemoryRateLimitCounter {
30    /// Create an empty counter set.
31    pub fn new() -> Self {
32        Self::default()
33    }
34}
35
36#[async_trait]
37impl RateLimitCounter for MemoryRateLimitCounter {
38    async fn increment(&self, key: &str, amount: u64, window_secs: u64) -> Result<u64> {
39        let window = Duration::from_secs(window_secs);
40        let now = Instant::now();
41        let mut entry = self.counters.entry(key.to_owned()).or_insert_with(|| Window {
42            count: 0,
43            start: now,
44            window,
45        });
46        if now.duration_since(entry.start) >= entry.window {
47            entry.count = 0;
48            entry.start = now;
49            entry.window = window;
50        }
51        entry.count = entry.count.saturating_add(amount);
52        Ok(entry.count)
53    }
54
55    async fn current(&self, key: &str) -> Result<u64> {
56        match self.counters.get(key) {
57            Some(entry) if Instant::now().duration_since(entry.start) < entry.window => Ok(entry.count),
58            _ => Ok(0),
59        }
60    }
61
62    async fn reset(&self, key: &str) -> Result<()> {
63        self.counters.remove(key);
64        Ok(())
65    }
66}