Skip to main content

pylon_plugin/builtin/
rate_limit.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
5use crate::{Plugin, PluginError, RequestMeta};
6use pylon_auth::AuthContext;
7
8/// Rate limiting plugin. Limits requests per IP/user within a time window.
9pub struct RateLimitPlugin {
10    max_requests: u32,
11    window: Duration,
12    counters: Mutex<HashMap<String, (u32, Instant)>>,
13}
14
15impl RateLimitPlugin {
16    pub fn new(max_requests: u32, window: Duration) -> Self {
17        Self {
18            max_requests,
19            window,
20            counters: Mutex::new(HashMap::new()),
21        }
22    }
23
24    fn check(&self, key: &str) -> Result<(), PluginError> {
25        let mut counters = self.counters.lock().unwrap();
26        let now = Instant::now();
27
28        let entry = counters.entry(key.to_string()).or_insert((0, now));
29
30        // Reset if window expired.
31        if now.duration_since(entry.1) > self.window {
32            *entry = (0, now);
33        }
34
35        entry.0 += 1;
36
37        if entry.0 > self.max_requests {
38            Err(PluginError {
39                code: "RATE_LIMITED".into(),
40                message: format!(
41                    "Too many requests. Limit: {} per {:?}",
42                    self.max_requests, self.window
43                ),
44                status: 429,
45            })
46        } else {
47            Ok(())
48        }
49    }
50
51    /// Rate-limit by user id when present, otherwise by peer IP. Prefer
52    /// this over the Plugin trait's `on_request` hook: that hook has no
53    /// access to peer IP and collapses every unauthenticated caller into
54    /// a single `__anon__` bucket, which means one attacker can DoS the
55    /// entire anonymous client population.
56    ///
57    /// Call from the HTTP layer where peer IP is available. Pass `""` for
58    /// `peer_ip` if unknown — the fallback is the same shared `__anon__`
59    /// bucket as before (not worse than the old behavior).
60    pub fn check_request(&self, user_id: Option<&str>, peer_ip: &str) -> Result<(), PluginError> {
61        let key = match user_id {
62            Some(u) if !u.is_empty() => format!("user:{u}"),
63            _ if !peer_ip.is_empty() => format!("ip:{peer_ip}"),
64            _ => "__anon__".to_string(),
65        };
66        self.check(&key)
67    }
68}
69
70impl Plugin for RateLimitPlugin {
71    fn name(&self) -> &str {
72        "rate-limit"
73    }
74
75    fn on_request(
76        &self,
77        _method: &str,
78        _path: &str,
79        auth: &AuthContext,
80    ) -> Result<(), PluginError> {
81        // Legacy path (no peer_ip available). Keys by user_id or a shared
82        // `__anon__` bucket — kept for callers that still invoke
83        // `on_request` directly. New callers should prefer
84        // `on_request_with_meta` so the IP dimension takes effect.
85        let key = auth.user_id.as_deref().unwrap_or("__anon__").to_string();
86        self.check(&key)
87    }
88
89    fn on_request_with_meta(
90        &self,
91        _method: &str,
92        _path: &str,
93        auth: &AuthContext,
94        meta: &RequestMeta<'_>,
95    ) -> Result<(), PluginError> {
96        // Per-IP bucket for anonymous traffic fixes the "one attacker
97        // DoSes every anon user" collapse we used to have when all anon
98        // callers shared a single `__anon__` bucket.
99        self.check_request(auth.user_id.as_deref(), meta.peer_ip)
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn allows_under_limit() {
        let plugin = RateLimitPlugin::new(3, Duration::from_secs(60));
        let auth = AuthContext::anonymous();
        assert!(plugin.on_request("GET", "/api/test", &auth).is_ok());
        assert!(plugin.on_request("GET", "/api/test", &auth).is_ok());
        assert!(plugin.on_request("GET", "/api/test", &auth).is_ok());
    }

    #[test]
    fn different_ips_use_different_buckets() {
        let plugin = RateLimitPlugin::new(2, Duration::from_secs(60));
        // Two anonymous clients from different IPs each get their own bucket
        // under check_request. If both collapsed into `__anon__`, one could
        // burn the other's quota.
        assert!(plugin.check_request(None, "1.1.1.1").is_ok());
        assert!(plugin.check_request(None, "1.1.1.1").is_ok());
        assert!(plugin.check_request(None, "1.1.1.1").is_err());
        // Second IP is untouched.
        assert!(plugin.check_request(None, "2.2.2.2").is_ok());
        assert!(plugin.check_request(None, "2.2.2.2").is_ok());
    }

    #[test]
    fn user_id_preferred_over_ip() {
        let plugin = RateLimitPlugin::new(2, Duration::from_secs(60));
        // Same user id from different IPs uses one bucket.
        assert!(plugin.check_request(Some("alice"), "1.1.1.1").is_ok());
        assert!(plugin.check_request(Some("alice"), "2.2.2.2").is_ok());
        assert!(plugin.check_request(Some("alice"), "3.3.3.3").is_err());
    }

    #[test]
    fn empty_user_id_falls_back_to_ip() {
        let plugin = RateLimitPlugin::new(1, Duration::from_secs(60));
        // An empty user id counts as "no user", so the IP bucket applies.
        assert!(plugin.check_request(Some(""), "1.1.1.1").is_ok());
        assert!(plugin.check_request(None, "1.1.1.1").is_err());
    }

    #[test]
    fn unknown_ip_uses_shared_anon_bucket() {
        let plugin = RateLimitPlugin::new(1, Duration::from_secs(60));
        // With neither a user id nor an IP, everything shares `__anon__`.
        assert!(plugin.check_request(None, "").is_ok());
        assert!(plugin.check_request(Some(""), "").is_err());
    }

    #[test]
    fn blocks_over_limit() {
        let plugin = RateLimitPlugin::new(2, Duration::from_secs(60));
        let auth = AuthContext::anonymous();
        assert!(plugin.on_request("GET", "/", &auth).is_ok());
        assert!(plugin.on_request("GET", "/", &auth).is_ok());
        let result = plugin.on_request("GET", "/", &auth);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, "RATE_LIMITED");
    }

    #[test]
    fn separate_users_separate_limits() {
        let plugin = RateLimitPlugin::new(1, Duration::from_secs(60));
        let alice = AuthContext::authenticated("alice".into());
        let bob = AuthContext::authenticated("bob".into());
        // Alice uses her single request and is then limited...
        assert!(plugin.on_request("GET", "/", &alice).is_ok());
        assert!(plugin.on_request("GET", "/", &alice).is_err());
        // ...which must not affect Bob, who still gets his own single request.
        assert!(plugin.on_request("GET", "/", &bob).is_ok());
        assert!(plugin.on_request("GET", "/", &bob).is_err());
    }

    #[test]
    fn window_expiry_resets_counter() {
        let plugin = RateLimitPlugin::new(1, Duration::from_millis(100));
        assert!(plugin.check_request(None, "1.1.1.1").is_ok());
        assert!(plugin.check_request(None, "1.1.1.1").is_err());
        // Once the window has elapsed the bucket starts over.
        std::thread::sleep(Duration::from_millis(150));
        assert!(plugin.check_request(None, "1.1.1.1").is_ok());
    }
}