aa_storage_memory/
rate_limit_counter.rs1use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use aa_storage::{RateLimitCounter, Result};
7use async_trait::async_trait;
8use dashmap::DashMap;
9
/// State of one fixed rate-limit window tracked for a single key.
struct Window {
    /// Instant at which the current window began.
    start: Instant,
    /// Length of the current window.
    window: Duration,
    /// Total accumulated since `start`.
    count: u64,
}
16
17#[derive(Clone, Default)]
25pub struct MemoryRateLimitCounter {
26 counters: Arc<DashMap<String, Window>>,
27}
28
29impl MemoryRateLimitCounter {
30 pub fn new() -> Self {
32 Self::default()
33 }
34}
35
36#[async_trait]
37impl RateLimitCounter for MemoryRateLimitCounter {
38 async fn increment(&self, key: &str, amount: u64, window_secs: u64) -> Result<u64> {
39 let window = Duration::from_secs(window_secs);
40 let now = Instant::now();
41 let mut entry = self.counters.entry(key.to_owned()).or_insert_with(|| Window {
42 count: 0,
43 start: now,
44 window,
45 });
46 if now.duration_since(entry.start) >= entry.window {
47 entry.count = 0;
48 entry.start = now;
49 entry.window = window;
50 }
51 entry.count = entry.count.saturating_add(amount);
52 Ok(entry.count)
53 }
54
55 async fn current(&self, key: &str) -> Result<u64> {
56 match self.counters.get(key) {
57 Some(entry) if Instant::now().duration_since(entry.start) < entry.window => Ok(entry.count),
58 _ => Ok(0),
59 }
60 }
61
62 async fn reset(&self, key: &str) -> Result<()> {
63 self.counters.remove(key);
64 Ok(())
65 }
66}