dscode_extension_host/
rate_limiter.rs1use governor::{
8 clock::DefaultClock, state::InMemoryState, Quota, RateLimiter as GovernorRateLimiter,
9};
10use std::collections::HashMap;
11use std::num::NonZeroU32;
12use std::sync::{Arc, Mutex};
13
14type GovernorLimiter =
15 Arc<GovernorRateLimiter<governor::state::direct::NotKeyed, InMemoryState, DefaultClock>>;
16
17pub struct RateLimiter {
18 limiters: Arc<Mutex<HashMap<String, GovernorLimiter>>>,
20
21 default_quota: Quota,
23}
24
25impl RateLimiter {
26 pub fn new() -> Self {
27 let quota = Quota::per_second(NonZeroU32::new(100).expect("100 is nonzero"));
29
30 Self { limiters: Arc::new(Mutex::new(HashMap::new())), default_quota: quota }
31 }
32
33 pub fn with_quota(requests_per_second: u32) -> Self {
34 let quota = Quota::per_second(
35 NonZeroU32::new(requests_per_second)
36 .unwrap_or(NonZeroU32::new(100).expect("100 is nonzero")),
37 );
38
39 Self { limiters: Arc::new(Mutex::new(HashMap::new())), default_quota: quota }
40 }
41
42 pub fn check_rate_limit(&self, extension_id: &str) -> Result<(), String> {
44 let limiter = self.get_or_create_limiter(extension_id);
45
46 match limiter.check() {
47 Ok(_) => Ok(()),
48 Err(_) => Err(format!(
49 "Rate limit exceeded for extension '{}'. Please slow down.",
50 extension_id
51 )),
52 }
53 }
54
55 fn get_or_create_limiter(&self, extension_id: &str) -> GovernorLimiter {
57 let mut limiters = self.limiters.lock().unwrap_or_else(|e| {
58 tracing::warn!("Rate limiter lock poisoned, recovering: {}", e);
59 e.into_inner()
60 });
61
62 limiters
63 .entry(extension_id.to_string())
64 .or_insert_with(|| Arc::new(GovernorRateLimiter::direct(self.default_quota)))
65 .clone()
66 }
67
68 pub fn remove_limiter(&self, extension_id: &str) {
70 let mut limiters = self.limiters.lock().unwrap_or_else(|e| {
71 tracing::warn!("Rate limiter lock poisoned, recovering: {}", e);
72 e.into_inner()
73 });
74 limiters.remove(extension_id);
75 }
76}
77
78impl Default for RateLimiter {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Sends `count` requests for `ext_id` and asserts that every one is admitted.
    fn assert_allowed(limiter: &RateLimiter, ext_id: &str, count: usize) {
        for i in 0..count {
            assert!(
                limiter.check_rate_limit(ext_id).is_ok(),
                "request {} for '{}' should have been allowed",
                i,
                ext_id
            );
        }
    }

    #[test]
    fn test_rate_limiting() {
        let limiter = RateLimiter::with_quota(10);
        let ext_id = "test.extension";

        assert_allowed(&limiter, ext_id, 10);
        assert!(limiter.check_rate_limit(ext_id).is_err());

        // Wait past the one-second quota period so capacity has been replenished.
        thread::sleep(Duration::from_millis(1100));

        assert!(limiter.check_rate_limit(ext_id).is_ok());
    }

    #[test]
    fn test_per_extension_limits() {
        let limiter = RateLimiter::with_quota(5);

        assert_allowed(&limiter, "ext1", 5);
        assert!(limiter.check_rate_limit("ext1").is_err());

        // Exhausting one extension must not affect another.
        assert!(limiter.check_rate_limit("ext2").is_ok());
    }

    #[test]
    fn test_rate_limiter_allows_within_limit() {
        let limiter = RateLimiter::with_quota(20);
        assert_allowed(&limiter, "within-limit-ext", 20);
    }

    #[test]
    fn test_rate_limiter_blocks_over_limit() {
        let limiter = RateLimiter::with_quota(5);
        assert_allowed(&limiter, "over-limit-ext", 5);

        let result = limiter.check_rate_limit("over-limit-ext");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Rate limit exceeded"));
    }

    #[test]
    fn test_rate_limiter_default() {
        // The default is 100 requests/second; a fresh limiter must admit the full burst.
        let limiter = RateLimiter::default();
        assert_allowed(&limiter, "default-ext", 100);
    }

    #[test]
    fn test_rate_limiter_new() {
        let limiter = RateLimiter::new();
        assert_allowed(&limiter, "new-ext", 100);
    }

    #[test]
    fn test_rate_limiter_window_refill() {
        let limiter = RateLimiter::with_quota(3);
        assert_allowed(&limiter, "window-ext", 3);
        assert!(limiter.check_rate_limit("window-ext").is_err());

        thread::sleep(Duration::from_millis(1100));
        assert!(limiter.check_rate_limit("window-ext").is_ok());
    }

    #[test]
    fn test_rate_limiter_per_extension_isolation() {
        let limiter = RateLimiter::with_quota(2);

        assert_allowed(&limiter, "ext-a", 2);
        assert!(limiter.check_rate_limit("ext-a").is_err());

        assert_allowed(&limiter, "ext-b", 2);
        assert!(limiter.check_rate_limit("ext-b").is_err());

        assert!(limiter.check_rate_limit("ext-c").is_ok());
    }

    #[test]
    fn test_rate_limiter_remove_limiter() {
        let limiter = RateLimiter::with_quota(2);

        assert_allowed(&limiter, "ext-rm", 2);
        assert!(limiter.check_rate_limit("ext-rm").is_err());

        // Removing the limiter resets the extension's budget.
        limiter.remove_limiter("ext-rm");
        assert!(limiter.check_rate_limit("ext-rm").is_ok());
    }

    #[test]
    fn test_rate_limiter_remove_unknown_limiter_is_noop() {
        let limiter = RateLimiter::with_quota(1);

        limiter.remove_limiter("never-seen");

        assert!(limiter.check_rate_limit("never-seen").is_ok());
        assert!(limiter.check_rate_limit("never-seen").is_err());
    }

    #[test]
    fn test_rate_limiter_with_quota_zero_uses_default() {
        // Zero is not a valid rate. The limiter must fall back to the 100/s default rather
        // than rejecting everything or admitting only a handful of requests.
        let limiter = RateLimiter::with_quota(0);
        assert_allowed(&limiter, "zero-ext", 100);
    }

    #[test]
    fn test_rate_limiter_error_contains_extension_id() {
        let limiter = RateLimiter::with_quota(1);
        limiter.check_rate_limit("error-ext").unwrap();
        let err = limiter.check_rate_limit("error-ext").unwrap_err();
        assert!(err.contains("error-ext"), "Error message should contain extension id");
    }
}