pylon_plugin/builtin/
rate_limit.rs1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
5use crate::{Plugin, PluginError, RequestMeta};
6use pylon_auth::AuthContext;
7
8pub struct RateLimitPlugin {
10 max_requests: u32,
11 window: Duration,
12 counters: Mutex<HashMap<String, (u32, Instant)>>,
13}
14
15impl RateLimitPlugin {
16 pub fn new(max_requests: u32, window: Duration) -> Self {
17 Self {
18 max_requests,
19 window,
20 counters: Mutex::new(HashMap::new()),
21 }
22 }
23
24 fn check(&self, key: &str) -> Result<(), PluginError> {
25 let mut counters = self.counters.lock().unwrap();
26 let now = Instant::now();
27
28 let entry = counters.entry(key.to_string()).or_insert((0, now));
29
30 if now.duration_since(entry.1) > self.window {
32 *entry = (0, now);
33 }
34
35 entry.0 += 1;
36
37 if entry.0 > self.max_requests {
38 Err(PluginError {
39 code: "RATE_LIMITED".into(),
40 message: format!(
41 "Too many requests. Limit: {} per {:?}",
42 self.max_requests, self.window
43 ),
44 status: 429,
45 })
46 } else {
47 Ok(())
48 }
49 }
50
51 pub fn check_request(&self, user_id: Option<&str>, peer_ip: &str) -> Result<(), PluginError> {
61 let key = match user_id {
62 Some(u) if !u.is_empty() => format!("user:{u}"),
63 _ if !peer_ip.is_empty() => format!("ip:{peer_ip}"),
64 _ => "__anon__".to_string(),
65 };
66 self.check(&key)
67 }
68}
69
70impl Plugin for RateLimitPlugin {
71 fn name(&self) -> &str {
72 "rate-limit"
73 }
74
75 fn on_request(
76 &self,
77 _method: &str,
78 _path: &str,
79 auth: &AuthContext,
80 ) -> Result<(), PluginError> {
81 let key = auth.user_id.as_deref().unwrap_or("__anon__").to_string();
86 self.check(&key)
87 }
88
89 fn on_request_with_meta(
90 &self,
91 _method: &str,
92 _path: &str,
93 auth: &AuthContext,
94 meta: &RequestMeta<'_>,
95 ) -> Result<(), PluginError> {
96 self.check_request(auth.user_id.as_deref(), meta.peer_ip)
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends one GET through the plain `on_request` hook.
    fn hit(limiter: &RateLimitPlugin, path: &str, auth: &AuthContext) -> Result<(), PluginError> {
        limiter.on_request("GET", path, auth)
    }

    #[test]
    fn allows_under_limit() {
        let limiter = RateLimitPlugin::new(3, Duration::from_secs(60));
        let anon = AuthContext::anonymous();
        for _ in 0..3 {
            assert!(hit(&limiter, "/api/test", &anon).is_ok());
        }
    }

    #[test]
    fn different_ips_use_different_buckets() {
        let limiter = RateLimitPlugin::new(2, Duration::from_secs(60));
        let (first, second) = ("1.1.1.1", "2.2.2.2");

        // Exhaust the first address's budget...
        assert!(limiter.check_request(None, first).is_ok());
        assert!(limiter.check_request(None, first).is_ok());
        assert!(limiter.check_request(None, first).is_err());

        // ...which must leave the second address's budget untouched.
        for _ in 0..2 {
            assert!(limiter.check_request(None, second).is_ok());
        }
    }

    #[test]
    fn user_id_preferred_over_ip() {
        let limiter = RateLimitPlugin::new(2, Duration::from_secs(60));
        // Same user from three different addresses shares one budget.
        let outcomes: Vec<bool> = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
            .iter()
            .map(|ip| limiter.check_request(Some("alice"), ip).is_ok())
            .collect();
        assert_eq!(outcomes, [true, true, false]);
    }

    #[test]
    fn blocks_over_limit() {
        let limiter = RateLimitPlugin::new(2, Duration::from_secs(60));
        let anon = AuthContext::anonymous();
        for _ in 0..2 {
            assert!(hit(&limiter, "/", &anon).is_ok());
        }
        match hit(&limiter, "/", &anon) {
            Err(err) => assert_eq!(err.code, "RATE_LIMITED"),
            Ok(()) => panic!("third request should have been rate limited"),
        }
    }

    #[test]
    fn separate_users_separate_limits() {
        let limiter = RateLimitPlugin::new(1, Duration::from_secs(60));
        let users = [
            AuthContext::authenticated("alice".into()),
            AuthContext::authenticated("bob".into()),
        ];
        // Each user's first request is admitted independently...
        for user in &users {
            assert!(hit(&limiter, "/", user).is_ok());
        }
        // ...and each user's second request is rejected.
        for user in &users {
            assert!(hit(&limiter, "/", user).is_err());
        }
    }
}