Skip to main content

aa_storage_redis/
rate_limit.rs

1//! [`RateLimitCounter`] using an atomic Lua `INCRBY` + `EXPIRE` script.
2
3use aa_storage::{RateLimitCounter, Result};
4use async_trait::async_trait;
5use deadpool_redis::Pool;
6
7use crate::error::backend;
8
/// Lua source executed atomically by Redis for each
/// [`increment`](RateLimitCounter::increment).
///
/// `KEYS[1]` is the counter key, `ARGV[1]` the amount to add and `ARGV[2]`
/// the window length in seconds. The script returns the counter's new total.
///
/// `INCRBY` the counter, then arm `EXPIRE` whenever the key carries no TTL
/// (`TTL` reports a negative value). For a key this call just created, that is
/// exactly the first-increment case. So the **fixed** window still starts at
/// the first increment, and later increments never extend it.
///
/// Deciding on the TTL, rather than on "the total equals the amount just
/// added", also:
///
/// * heals a counter that exists without an expiry. Such a key may have been
///   written by another client, `PERSIST`ed, or left behind by an `EXPIRE`
///   that errored. Redis does not roll back the earlier `INCRBY` when a later
///   command in a script fails. Without this, the counter would never reset
///   and would rate-limit its key forever;
/// * stops a counter whose value happened to be `0` (e.g. after an increment
///   by `0`) from having its window re-armed by the next increment.
///
/// Running all commands inside one script makes the read-modify-write atomic
/// with respect to concurrent callers.
const INCREMENT_SCRIPT: &str = r"
local current = redis.call('INCRBY', KEYS[1], ARGV[1])
if redis.call('TTL', KEYS[1]) < 0 then
    redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return current
";
23
24/// Redis-backed [`RateLimitCounter`].
25///
26/// Counters live at `aa:ratelimit:<key>`. Cheap to [`Clone`] — clones share
27/// the underlying [`Pool`].
28#[derive(Clone)]
29pub struct RedisRateLimitCounter {
30    pool: Pool,
31}
32
33impl RedisRateLimitCounter {
34    /// Create a counter over an existing connection pool.
35    pub fn new(pool: Pool) -> Self {
36        Self { pool }
37    }
38}
39
/// Namespace a caller-supplied rate-limit key as `aa:ratelimit:<key>`.
fn counter_key(key: &str) -> String {
    const PREFIX: &str = "aa:ratelimit:";
    [PREFIX, key].concat()
}
43
44#[async_trait]
45impl RateLimitCounter for RedisRateLimitCounter {
46    async fn increment(&self, key: &str, amount: u64, window_secs: u64) -> Result<u64> {
47        let mut conn = self.pool.get().await.map_err(backend)?;
48        let total: i64 = redis::Script::new(INCREMENT_SCRIPT)
49            .key(counter_key(key))
50            .arg(amount)
51            .arg(window_secs)
52            .invoke_async(&mut conn)
53            .await
54            .map_err(backend)?;
55        Ok(u64::try_from(total).unwrap_or(0))
56    }
57
58    async fn current(&self, key: &str) -> Result<u64> {
59        let mut conn = self.pool.get().await.map_err(backend)?;
60        let value: Option<u64> = redis::cmd("GET")
61            .arg(counter_key(key))
62            .query_async(&mut conn)
63            .await
64            .map_err(backend)?;
65        Ok(value.unwrap_or(0))
66    }
67
68    async fn reset(&self, key: &str) -> Result<()> {
69        let mut conn = self.pool.get().await.map_err(backend)?;
70        let _: () = redis::cmd("DEL")
71            .arg(counter_key(key))
72            .query_async(&mut conn)
73            .await
74            .map_err(backend)?;
75        Ok(())
76    }
77}