actix_limiter/
lib.rs

1use std::{borrow::Cow, fmt, sync::Arc, time::{Duration, SystemTime, UNIX_EPOCH}};
2
3use actix_web::dev::ServiceRequest;
4use deadpool_redis::{Pool};
5
6mod builder;
7mod errors;
8mod middleware;
9mod status;
10
11pub use self::{builder::Builder, errors::Error, middleware::RateLimiter, status::Status};
12
/// Lua script run atomically via `EVAL` for every counted request.
///
/// Inputs: `KEYS[1]` = counter key, `ARGV[1]` = request limit, `ARGV[2]` = window length in
/// seconds, `ARGV[3]` = current Unix time in seconds.
/// Reply: `{limited (0 or 1), remaining, reset}` where `reset = now + TTL of the counter`.
const LUA: &str = r#"
local key   = KEYS[1]
local limit = tonumber(ARGV[1])
local win   = tonumber(ARGV[2])
local now   = tonumber(ARGV[3])

-- EXPIRE with a non-positive timeout deletes the key, which would reset the
-- counter on every call and disable limiting; enforce a 1-second minimum.
if win < 1 then win = 1 end

local cnt = redis.call("INCR", key)
local ttl = redis.call("TTL", key)

-- Arm the window on the first hit, and also re-arm it when the key has no
-- expiry at all (TTL < 0); otherwise the counter could never reset and the
-- key would stay limited forever.
if cnt == 1 or ttl < 0 then
    redis.call("EXPIRE", key, win)
    ttl = win
end

local limited = cnt > limit and 1 or 0
local remaining = limited == 1 and 0 or (limit - cnt)
return {limited, remaining, now + ttl}
"#;
31
/// Request limit that [`Limiter::builder`] starts from.
pub const DEFAULT_REQUEST_LIMIT: usize = 5_000;

/// Period length, in seconds, that [`Limiter::builder`] starts from (one hour).
pub const DEFAULT_PERIOD_SECS: u64 = 60 * 60;

/// Cookie name that [`Limiter::builder`] starts from.
pub const DEFAULT_COOKIE_NAME: &str = "sid";

/// Session key that [`Limiter::builder`] starts from (only with the `session` feature).
#[cfg(feature = "session")]
pub const DEFAULT_SESSION_KEY: &str = "rate-api-id";
44
45/// Helper trait to impl Debug on GetKeyFn type
46trait GetKeyFnT: Fn(&ServiceRequest) -> Option<String> {}
47
48impl<T> GetKeyFnT for T where T: Fn(&ServiceRequest) -> Option<String> {}
49
50/// Get key function type with auto traits
51type GetKeyFn = dyn GetKeyFnT + Send + Sync;
52
53/// Get key resolver function type
54impl fmt::Debug for GetKeyFn {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        write!(f, "GetKeyFn")
57    }
58}
59
60/// Wrapped Get key function Trait
61type GetArcBoxKeyFn = Arc<GetKeyFn>;
62
63/// Rate limiter.
64#[derive(Debug, Clone)]
65pub struct Limiter {
66    pool: Arc<Pool>,
67    limit: usize,
68    period: Duration,
69    get_key_fn: GetArcBoxKeyFn,
70}
71
72impl Limiter {
73    /// Construct rate limiter builder with defaults.
74    ///
75    /// See [`redis-rs` docs](https://docs.rs/redis/0.21/redis/#connection-parameters) on connection
76    /// parameters for how to set the Redis URL.
77    #[must_use]
78    pub fn builder(r: Arc<Pool>) -> Builder {
79        Builder {
80            redis: r,
81            limit: DEFAULT_REQUEST_LIMIT,
82            period: Duration::from_secs(DEFAULT_PERIOD_SECS),
83            get_key_fn: None,
84            cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME),
85            #[cfg(feature = "session")]
86            session_key: Cow::Borrowed(DEFAULT_SESSION_KEY),
87        }
88    }
89
90    /// Consumes one rate limit unit, returning the status.
91    pub async fn count(&self, key: impl Into<String>) -> Result<(bool, usize, usize), Error> {
92        let key = key.into();
93        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as usize;
94        let win = self.period.as_secs() as usize;
95
96        let mut conn = self.pool.get().await?;
97        let res: Vec<i64> = redis::cmd("EVAL")
98            .arg(LUA)
99            .arg(1)                       // number of keys
100            .arg(&key)                    // KEYS[1]
101            .arg(self.limit as i64)       // ARGV[1]
102            .arg(win as i64)              // ARGV[2]
103            .arg(now as i64)              // ARGV[3]
104            .query_async(&mut *conn)   
105            .await?;
106
107        let limited = res[0] == 1;
108        let remaining = res[1] as usize;
109        let reset = res[2] as usize;
110        Ok((limited, remaining, reset))
111    }
112}