aa_storage_redis/
rate_limit.rs1use aa_storage::{RateLimitCounter, Result};
4use async_trait::async_trait;
5use deadpool_redis::Pool;
6
7use crate::error::backend;
8
/// Lua script executed by Redis to bump a counter and start its expiry window.
///
/// Inputs: `KEYS[1]` is the counter key, `ARGV[1]` the increment amount and
/// `ARGV[2]` the expiry in seconds.
///
/// The script applies `INCRBY` to the key, and only when the resulting value
/// equals the increment amount does it call `EXPIRE` with `ARGV[2]`. The
/// post-increment value is returned to the caller.
///
/// NOTE(review): the equality test is also true when the key already existed
/// holding `0`, in which case the expiry is set again — confirm this is
/// acceptable for callers that increment by zero.
const INCREMENT_SCRIPT: &str = r"
local current = redis.call('INCRBY', KEYS[1], ARGV[1])
if tonumber(current) == tonumber(ARGV[1]) then
 redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return current
";
23
/// Rate-limit counter that keeps its state in Redis, reached through a
/// `deadpool_redis` connection pool.
#[derive(Clone)]
pub struct RedisRateLimitCounter {
    /// Pool from which every operation checks out its connection.
    pool: Pool,
}
32
33impl RedisRateLimitCounter {
34 pub fn new(pool: Pool) -> Self {
36 Self { pool }
37 }
38}
39
/// Returns the Redis key for a rate-limit counter: `key` prefixed with
/// `aa:ratelimit:`.
fn counter_key(key: &str) -> String {
    const PREFIX: &str = "aa:ratelimit:";

    // Size the buffer once so the two appends never reallocate.
    let mut namespaced = String::with_capacity(PREFIX.len() + key.len());
    namespaced.push_str(PREFIX);
    namespaced.push_str(key);
    namespaced
}
43
44#[async_trait]
45impl RateLimitCounter for RedisRateLimitCounter {
46 async fn increment(&self, key: &str, amount: u64, window_secs: u64) -> Result<u64> {
47 let mut conn = self.pool.get().await.map_err(backend)?;
48 let total: i64 = redis::Script::new(INCREMENT_SCRIPT)
49 .key(counter_key(key))
50 .arg(amount)
51 .arg(window_secs)
52 .invoke_async(&mut conn)
53 .await
54 .map_err(backend)?;
55 Ok(u64::try_from(total).unwrap_or(0))
56 }
57
58 async fn current(&self, key: &str) -> Result<u64> {
59 let mut conn = self.pool.get().await.map_err(backend)?;
60 let value: Option<u64> = redis::cmd("GET")
61 .arg(counter_key(key))
62 .query_async(&mut conn)
63 .await
64 .map_err(backend)?;
65 Ok(value.unwrap_or(0))
66 }
67
68 async fn reset(&self, key: &str) -> Result<()> {
69 let mut conn = self.pool.get().await.map_err(backend)?;
70 let _: () = redis::cmd("DEL")
71 .arg(counter_key(key))
72 .query_async(&mut conn)
73 .await
74 .map_err(backend)?;
75 Ok(())
76 }
77}