1use std::sync::Arc;
4use tokio::sync::{OnceCell, RwLock};
5
6mod config;
7mod gc;
8mod limiter;
9mod types;
10
11pub use config::*;
12use limiter::RateLimiter;
13pub use types::*;
14
/// Process-wide rate limiter shared by the whole crate.
///
/// It is set exactly once by `initialize_limiter` (a second attempt panics there) and read by
/// `check_limit` / `check_limit_override`. It is wrapped in a tokio `RwLock` because those
/// checks take the write lock to run the limiter's `check_limit`.
static GLOBAL_LIMITER: OnceCell<Arc<RwLock<RateLimiter>>> = OnceCell::const_new();
17
/// Builds a [`LimiterConfig`] and installs it as the global rate limiter.
///
/// Arguments, in this order:
/// - `default:` the rule applied to routes without a specific rule (required).
/// - `max_memory:` an `Option`; when it is `Some(v)`, `v` is passed to
///   `LimiterConfig::with_max_memory`, and when it is `None` the setting is left untouched (optional).
/// - `routes: [(route, rule), ...]`: per-route rules registered via
///   `LimiterConfig::add_route_rule`, in the order given (optional).
///
/// A trailing comma is accepted after the last argument.
///
/// The macro evaluates to the future returned by `initialize_limiter`, so it must be
/// `.await`ed for the limiter to actually be installed.
///
/// # Panics
///
/// The awaited future panics if the limiter has already been initialized.
///
/// # Examples
///
/// ```ignore
/// init_rate_limiter!(
///     default: RuleConfig::new(Duration::seconds(1), 10),
///     max_memory: Some(max_mem),
///     routes: [("/login", RuleConfig::new(Duration::seconds(60), 5))],
/// )
/// .await;
/// ```
#[macro_export]
macro_rules! init_rate_limiter {
    (
        default: $default_rule:expr
        $(, max_memory: $max_memory:expr)?
        $(, routes: [ $(($route:expr, $rule:expr)),* $(,)? ])?
        $(,)?
    ) => {
        {
            let config = $crate::LimiterConfig::new($default_rule);

            // Shadowing instead of `let mut` + reassignment avoids an `unused_mut`
            // warning when none of the optional arguments are supplied.
            $(
                let config = match $max_memory {
                    Some(mem) => config.with_max_memory(mem),
                    None => config,
                };
            )?

            $(
                $(
                    let config = config.add_route_rule($route, $rule);
                )*
            )?

            $crate::initialize_limiter(config)
        }
    };
}
68
/// Checks `who` against the rule for `route` using the global rate limiter.
///
/// Expands to `$crate::check_limit(who, route)`, which is a future resolving to `true` when
/// the request is allowed and `false` when it is rate limited. The caller must `.await` it.
/// A trailing comma is accepted.
///
/// # Panics
///
/// The awaited future panics if the limiter has not been initialized with
/// `init_rate_limiter!`.
#[macro_export]
macro_rules! limit {
    ($who:expr, $route:expr $(,)?) => {
        $crate::check_limit($who, $route)
    };
}
80
/// Like `limit!`, but expands to `$crate::check_limit_override(who, route)`.
///
/// That call passes `true` as the limiter's override flag; the flag's exact effect is
/// defined by the internal `RateLimiter::check_limit`. The expansion is a future resolving
/// to `bool` and must be `.await`ed. A trailing comma is accepted.
///
/// # Panics
///
/// The awaited future panics if the limiter has not been initialized with
/// `init_rate_limiter!`.
#[macro_export]
macro_rules! limit_override {
    ($who:expr, $route:expr $(,)?) => {
        $crate::check_limit_override($who, $route)
    };
}
92
93pub async fn initialize_limiter(config: LimiterConfig) {
95 let limiter = RateLimiter::new(config).await;
96 if GLOBAL_LIMITER.set(Arc::new(RwLock::new(limiter))).is_err() {
97 panic!("Rate limiter has already been initialized.");
98 }
99}
100
101pub async fn check_limit(who: &str, route: &str) -> bool {
103 if let Some(limiter) = GLOBAL_LIMITER.get() {
104 let mut limiter = limiter.write().await;
105 limiter.check_limit(who, route, false).await
106 } else {
107 panic!("Rate limiter not initialized! Call init_rate_limiter! first.");
108 }
109}
110
111pub async fn check_limit_override(who: &str, route: &str) -> bool {
113 if let Some(limiter) = GLOBAL_LIMITER.get() {
114 let mut limiter = limiter.write().await;
115 limiter.check_limit(who, route, true).await
116 } else {
117 panic!("Rate limiter not initialized! Call init_rate_limiter! first.");
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Duration;
    use std::time::Duration as StdDuration;

    /// With a limit of two requests per one-second window, the first two checks pass,
    /// the third is rejected, and the budget is available again after the window.
    #[tokio::test]
    async fn test_basic_rate_limiting() {
        let config = LimiterConfig::new(RuleConfig::new(Duration::seconds(1), 2));
        let limiter = RateLimiter::new(config).await;
        // GLOBAL_LIMITER is shared by every test in this binary. If another test installed
        // a limiter first, this `set` fails silently and the assertions below run against
        // that other configuration.
        let _ = GLOBAL_LIMITER.set(Arc::new(RwLock::new(limiter)));

        let who = "test_ip";
        let route = "/test";

        assert!(check_limit(who, route).await);
        assert!(check_limit(who, route).await);

        assert!(!check_limit(who, route).await);

        // Sleep slightly past the 1 s window so the final check does not depend on
        // exactly where (or with which comparison) the limiter places its boundary.
        tokio::time::sleep(StdDuration::from_millis(1_100)).await;
        assert!(check_limit(who, route).await);
    }
}