1use std::{borrow::Cow, fmt, sync::Arc, time::{Duration, SystemTime, UNIX_EPOCH}};
2
3use actix_web::dev::ServiceRequest;
4use deadpool_redis::{Pool};
5
6mod builder;
7mod errors;
8mod middleware;
9mod status;
10
11pub use self::{builder::Builder, errors::Error, middleware::RateLimiter, status::Status};
12
/// Atomic fixed-window counter executed server-side via `EVAL`.
///
/// Inputs: `KEYS[1]` is the counter key; `ARGV[1]` is the request limit,
/// `ARGV[2]` the window length in seconds and `ARGV[3]` the caller's current
/// Unix time in seconds.
///
/// Returns `{limited, remaining, reset}`:
/// - `limited` is `1` once the counter exceeds the limit, else `0`;
/// - `remaining` is how many requests are left in the window (`0` when limited);
/// - `reset` is `now + TTL` of the counter key.
///
/// The expiry is (re)applied whenever the key has no TTL, not only when the
/// counter is first created. Otherwise, a counter that lost or never received its
/// TTL would never expire and its client would stay limited forever.
const LUA: &str = r#"
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local win = tonumber(ARGV[2])
local now = tonumber(ARGV[3])

local cnt = redis.call("INCR", key)

-- TTL is -1 for a key without an expiry, including one INCR just created.
local ttl = redis.call("TTL", key)
if ttl < 0 then
    redis.call("EXPIRE", key, win)
    ttl = win
end

local limited = cnt > limit and 1 or 0
local remaining = limited == 1 and 0 or (limit - cnt)
return {limited, remaining, now + ttl}
"#;
31
/// Number of requests allowed per window unless the builder overrides it.
pub const DEFAULT_REQUEST_LIMIT: usize = 5_000;

/// Window length, in seconds, unless the builder overrides it (one hour).
pub const DEFAULT_PERIOD_SECS: u64 = 60 * 60;

/// Cookie name the builder starts with.
pub const DEFAULT_COOKIE_NAME: &str = "sid";

/// Session key the builder starts with (only with the `session` feature).
#[cfg(feature = "session")]
pub const DEFAULT_SESSION_KEY: &str = "rate-api-id";
44
45trait GetKeyFnT: Fn(&ServiceRequest) -> Option<String> {}
47
48impl<T> GetKeyFnT for T where T: Fn(&ServiceRequest) -> Option<String> {}
49
50type GetKeyFn = dyn GetKeyFnT + Send + Sync;
52
53impl fmt::Debug for GetKeyFn {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 write!(f, "GetKeyFn")
57 }
58}
59
60type GetArcBoxKeyFn = Arc<GetKeyFn>;
62
/// Fixed-window request rate limiter backed by a Redis connection pool.
///
/// Cloning is cheap: the pool and the key function are shared through [`Arc`],
/// and the remaining fields are plain values.
#[derive(Debug, Clone)]
pub struct Limiter {
    /// Pool from which a Redis connection is taken for every `count` call.
    pool: Arc<Pool>,
    /// Requests allowed per window; the request that pushes the counter
    /// past this value is reported as limited.
    limit: usize,
    /// Window length; converted to whole seconds when sent to Redis.
    period: Duration,
    /// Callback mapping a request to an optional key string.
    /// NOTE(review): used outside this file, presumably by the middleware to
    /// pick the key passed to `count` — confirm.
    get_key_fn: GetArcBoxKeyFn,
}
71
72impl Limiter {
73 #[must_use]
78 pub fn builder(r: Arc<Pool>) -> Builder {
79 Builder {
80 redis: r,
81 limit: DEFAULT_REQUEST_LIMIT,
82 period: Duration::from_secs(DEFAULT_PERIOD_SECS),
83 get_key_fn: None,
84 cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME),
85 #[cfg(feature = "session")]
86 session_key: Cow::Borrowed(DEFAULT_SESSION_KEY),
87 }
88 }
89
90 pub async fn count(&self, key: impl Into<String>) -> Result<(bool, usize, usize), Error> {
92 let key = key.into();
93 let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as usize;
94 let win = self.period.as_secs() as usize;
95
96 let mut conn = self.pool.get().await?;
97 let res: Vec<i64> = redis::cmd("EVAL")
98 .arg(LUA)
99 .arg(1) .arg(&key) .arg(self.limit as i64) .arg(win as i64) .arg(now as i64) .query_async(&mut *conn)
105 .await?;
106
107 let limited = res[0] == 1;
108 let remaining = res[1] as usize;
109 let reset = res[2] as usize;
110 Ok((limited, remaining, reset))
111 }
112}