lazy_limit/
lib.rs

1/* src/lib.rs */
2
3use std::sync::Arc;
4use tokio::sync::{OnceCell, RwLock};
5
6mod config;
7mod gc;
8mod limiter;
9mod types;
10
11pub use config::*;
12use limiter::RateLimiter;
13pub use types::*;
14
15// Global rate limiter instance, initialized once.
16static GLOBAL_LIMITER: OnceCell<Arc<RwLock<RateLimiter>>> = OnceCell::const_new();
17
/// Initialize the global rate limiter with a default rule, an optional memory
/// cap and optional route-specific rules.
///
/// This must be called once, typically at application startup, before any
/// calls to [`limit!`] or [`limit_override!`]. The macro expands to the future
/// returned by [`initialize_limiter`], so it must be `.await`ed.
///
/// Arguments, which must appear in this order:
///
/// * `default:` — the default rule, passed to `LimiterConfig::new`.
/// * `max_memory:` *(optional)* — an `Option`. `Some(mem)` is passed to
///   `LimiterConfig::with_max_memory`; `None` leaves the configuration as is.
/// * `routes:` *(optional)* — a bracketed list of `(route, rule)` pairs, each
///   registered in order with `LimiterConfig::add_route_rule`.
///
/// A trailing comma is accepted after the last argument and after the last
/// route pair.
///
/// # Panics
///
/// The resulting future panics if the limiter has already been initialized.
///
/// # Examples
///
/// ```rust,ignore
/// use lazy_limit::*;
///
/// #[tokio::main]
/// async fn main() {
///     init_rate_limiter!(
///         default: RuleConfig::new(Duration::seconds(1), 5),
///         max_memory: Some(64 * 1024 * 1024), // 64MB
///         routes: [
///             ("/api/login", RuleConfig::new(Duration::minutes(1), 3)),
///             ("/api/public", RuleConfig::new(Duration::seconds(1), 10)),
///         ]
///     ).await;
/// }
/// ```
#[macro_export]
macro_rules! init_rate_limiter {
    (
        default: $default_rule:expr
        $(, max_memory: $max_memory:expr)?
        $(, routes: [ $(($route:expr, $rule:expr)),* $(,)? ])?
        $(,)?
    ) => {{
        let config = $crate::LimiterConfig::new($default_rule);

        // Each step shadows `config` instead of mutating it, so an invocation
        // with only `default:` does not produce an `unused_mut` warning.
        $(
            // Fully qualified so a user-defined `Some`/`None` in scope at the
            // call site cannot change the meaning of the pattern.
            let config = match $max_memory {
                ::core::option::Option::Some(mem) => config.with_max_memory(mem),
                ::core::option::Option::None => config,
            };
        )?

        $(
            $(
                let config = config.add_route_rule($route, $rule);
            )*
        )?

        $crate::initialize_limiter(config)
    }};
}
68
/// Check whether a request from `who` to `route` should be allowed under the
/// configured rate limiting rules.
///
/// Expands to a call to [`check_limit`], an `async fn`, so the result must be
/// `.await`ed to obtain the decision: `true` means the request is allowed.
/// A trailing comma after `route` is accepted.
///
/// # Panics
///
/// The returned future panics when polled if the rate limiter has not been
/// initialized with [`init_rate_limiter!`].
///
/// # Examples
///
/// ```rust,ignore
/// if !limit!("203.0.113.7", "/api/login").await {
///     // reject the request
/// }
/// ```
#[macro_export]
macro_rules! limit {
    ($who:expr, $route:expr $(,)?) => {
        $crate::check_limit($who, $route)
    };
}
80
/// Check rate limit with override mode (only applies route-specific rules).
///
/// Expands to a call to [`check_limit_override`], an `async fn`, so the result
/// must be `.await`ed to obtain the decision: `true` means the request is
/// allowed. A trailing comma after `route` is accepted.
///
/// # Panics
///
/// The returned future panics when polled if the rate limiter has not been
/// initialized with [`init_rate_limiter!`].
///
/// # Examples
///
/// ```rust,ignore
/// if !limit_override!("203.0.113.7", "/api/login").await {
///     // reject the request
/// }
/// ```
#[macro_export]
macro_rules! limit_override {
    ($who:expr, $route:expr $(,)?) => {
        $crate::check_limit_override($who, $route)
    };
}
92
93/// Initialize the global rate limiter. Should be called only once.
94pub async fn initialize_limiter(config: LimiterConfig) {
95    let limiter = RateLimiter::new(config).await;
96    if GLOBAL_LIMITER.set(Arc::new(RwLock::new(limiter))).is_err() {
97        panic!("Rate limiter has already been initialized.");
98    }
99}
100
101/// Check if a request should be allowed.
102pub async fn check_limit(who: &str, route: &str) -> bool {
103    if let Some(limiter) = GLOBAL_LIMITER.get() {
104        let mut limiter = limiter.write().await;
105        limiter.check_limit(who, route, false).await
106    } else {
107        panic!("Rate limiter not initialized! Call init_rate_limiter! first.");
108    }
109}
110
111/// Check rate limit with override mode.
112pub async fn check_limit_override(who: &str, route: &str) -> bool {
113    if let Some(limiter) = GLOBAL_LIMITER.get() {
114        let mut limiter = limiter.write().await;
115        limiter.check_limit(who, route, true).await
116    } else {
117        panic!("Rate limiter not initialized! Call init_rate_limiter! first.");
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Duration;
    use std::time::Duration as StdDuration;

    #[tokio::test]
    async fn test_basic_rate_limiting() {
        // `GLOBAL_LIMITER` is a process-wide `OnceCell` and cannot be reset
        // between tests. `set` does NOT overwrite: it returns an error, which
        // is deliberately ignored here, if the cell is already filled. This
        // test therefore assumes it is the only test in the crate that
        // installs the global limiter.
        let config = LimiterConfig::new(RuleConfig::new(Duration::seconds(1), 2));
        let limiter = RateLimiter::new(config).await;
        let _ = GLOBAL_LIMITER.set(Arc::new(RwLock::new(limiter)));

        let who = "test_ip";
        let route = "/test";

        // Rule under test: 2 requests per 1-second window.
        assert!(check_limit(who, route).await, "1st request should be allowed");
        assert!(check_limit(who, route).await, "2nd request should be allowed");
        assert!(!check_limit(who, route).await, "3rd request should be rate limited");

        // Sleep past the window with a margin, so the assertion depends on
        // neither boundary inclusiveness nor timer jitter.
        tokio::time::sleep(StdDuration::from_millis(1_100)).await;
        assert!(
            check_limit(who, route).await,
            "request after the window elapsed should be allowed"
        );
    }
}