// dscode_extension_host/rate_limiter.rs

//! Rate Limiter
//!
//! Protects the extension host against denial-of-service (DoS) by limiting the rate
//! of requests each extension may make. Limiting is delegated to the `governor` crate,
//! whose GCRA implementation behaves like a token bucket.
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::{Arc, Mutex, MutexGuard};

use governor::{
    clock::DefaultClock, state::InMemoryState, Quota, RateLimiter as GovernorRateLimiter,
};
13
14type GovernorLimiter =
15    Arc<GovernorRateLimiter<governor::state::direct::NotKeyed, InMemoryState, DefaultClock>>;
16
17pub struct RateLimiter {
18    /// Per-extension rate limiters
19    limiters: Arc<Mutex<HashMap<String, GovernorLimiter>>>,
20
21    /// Default quota: 100 requests per second per extension
22    default_quota: Quota,
23}
24
25impl RateLimiter {
26    pub fn new() -> Self {
27        // Default: 100 requests per second
28        let quota = Quota::per_second(NonZeroU32::new(100).expect("100 is nonzero"));
29
30        Self { limiters: Arc::new(Mutex::new(HashMap::new())), default_quota: quota }
31    }
32
33    pub fn with_quota(requests_per_second: u32) -> Self {
34        let quota = Quota::per_second(
35            NonZeroU32::new(requests_per_second)
36                .unwrap_or(NonZeroU32::new(100).expect("100 is nonzero")),
37        );
38
39        Self { limiters: Arc::new(Mutex::new(HashMap::new())), default_quota: quota }
40    }
41
42    /// Check if a request should be allowed
43    pub fn check_rate_limit(&self, extension_id: &str) -> Result<(), String> {
44        let limiter = self.get_or_create_limiter(extension_id);
45
46        match limiter.check() {
47            Ok(_) => Ok(()),
48            Err(_) => Err(format!(
49                "Rate limit exceeded for extension '{}'. Please slow down.",
50                extension_id
51            )),
52        }
53    }
54
55    /// Get or create a rate limiter for an extension
56    fn get_or_create_limiter(&self, extension_id: &str) -> GovernorLimiter {
57        let mut limiters = self.limiters.lock().unwrap_or_else(|e| {
58            tracing::warn!("Rate limiter lock poisoned, recovering: {}", e);
59            e.into_inner()
60        });
61
62        limiters
63            .entry(extension_id.to_string())
64            .or_insert_with(|| Arc::new(GovernorRateLimiter::direct(self.default_quota)))
65            .clone()
66    }
67
68    /// Remove rate limiter for an extension (called when extension unloads)
69    pub fn remove_limiter(&self, extension_id: &str) {
70        let mut limiters = self.limiters.lock().unwrap_or_else(|e| {
71            tracing::warn!("Rate limiter lock poisoned, recovering: {}", e);
72            e.into_inner()
73        });
74        limiters.remove(extension_id);
75    }
76}
77
78impl Default for RateLimiter {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_rate_limiting() {
        let limiter = RateLimiter::with_quota(10); // 10 requests per second
        let ext_id = "test.extension";

        // First 10 requests should succeed
        for _ in 0..10 {
            assert!(limiter.check_rate_limit(ext_id).is_ok());
        }

        // 11th request should fail
        assert!(limiter.check_rate_limit(ext_id).is_err());

        // Wait for quota to refill
        thread::sleep(Duration::from_millis(1100));

        // Should work again
        assert!(limiter.check_rate_limit(ext_id).is_ok());
    }

    #[test]
    fn test_per_extension_limits() {
        let limiter = RateLimiter::with_quota(5);

        // Exhaust quota for extension1
        for _ in 0..5 {
            limiter.check_rate_limit("ext1").unwrap();
        }
        assert!(limiter.check_rate_limit("ext1").is_err());

        // Extension2 should still have quota
        assert!(limiter.check_rate_limit("ext2").is_ok());
    }

    #[test]
    fn test_rate_limiter_allows_within_limit() {
        let limiter = RateLimiter::with_quota(20);

        // All 20 requests within the rate limit should be allowed
        for i in 0..20 {
            let result = limiter.check_rate_limit("within-limit-ext");
            assert!(result.is_ok(), "Request {} should have been allowed", i);
        }
    }

    #[test]
    fn test_rate_limiter_blocks_over_limit() {
        let limiter = RateLimiter::with_quota(5);

        // Exhaust quota
        for _ in 0..5 {
            limiter.check_rate_limit("over-limit-ext").unwrap();
        }

        // Next request should be blocked
        let result = limiter.check_rate_limit("over-limit-ext");
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Rate limit exceeded"));
    }

    #[test]
    fn test_rate_limiter_default() {
        let limiter = RateLimiter::default();
        // Default allows a burst of 100 requests, so every one of them should pass
        // immediately (extra elapsed time only adds capacity, so this cannot flake).
        for i in 0..100 {
            assert!(
                limiter.check_rate_limit("default-ext").is_ok(),
                "Request {} should be allowed under the default quota",
                i
            );
        }
    }

    #[test]
    fn test_rate_limiter_new() {
        let limiter = RateLimiter::new();
        // Same as default: 100 rps
        assert!(limiter.check_rate_limit("new-ext").is_ok());
    }

    #[test]
    fn test_rate_limiter_window_refill() {
        // After the quota window passes, the limiter should allow requests again
        let limiter = RateLimiter::with_quota(3);
        for _ in 0..3 {
            limiter.check_rate_limit("window-ext").unwrap();
        }
        assert!(limiter.check_rate_limit("window-ext").is_err());

        // Wait for the bucket to refill (1 second + small buffer)
        thread::sleep(Duration::from_millis(1100));
        assert!(limiter.check_rate_limit("window-ext").is_ok());
    }

    #[test]
    fn test_rate_limiter_per_extension_isolation() {
        let limiter = RateLimiter::with_quota(2);

        // Exhaust quota for ext-a
        limiter.check_rate_limit("ext-a").unwrap();
        limiter.check_rate_limit("ext-a").unwrap();
        assert!(limiter.check_rate_limit("ext-a").is_err());

        // ext-b should still have its own independent quota
        assert!(limiter.check_rate_limit("ext-b").is_ok());
        assert!(limiter.check_rate_limit("ext-b").is_ok());
        assert!(limiter.check_rate_limit("ext-b").is_err());

        // ext-c is also independent
        assert!(limiter.check_rate_limit("ext-c").is_ok());
    }

    #[test]
    fn test_rate_limiter_remove_limiter() {
        let limiter = RateLimiter::with_quota(2);

        // Use up quota for ext-rm
        limiter.check_rate_limit("ext-rm").unwrap();
        limiter.check_rate_limit("ext-rm").unwrap();
        assert!(limiter.check_rate_limit("ext-rm").is_err());

        // Remove the limiter; next check creates a fresh one
        limiter.remove_limiter("ext-rm");
        assert!(limiter.check_rate_limit("ext-rm").is_ok());
    }

    #[test]
    fn test_rate_limiter_remove_unknown_is_noop() {
        let limiter = RateLimiter::with_quota(1);
        // Removing an extension that never made a request must not panic.
        limiter.remove_limiter("never-seen");
        assert!(limiter.check_rate_limit("never-seen").is_ok());
    }

    #[test]
    fn test_rate_limiter_with_quota_zero_uses_default() {
        // with_quota(0) should fall back to 100 since NonZeroU32::new(0) is None.
        // A single request would pass under any non-zero quota, so check the whole
        // default burst of 100 to actually pin the fallback value.
        let limiter = RateLimiter::with_quota(0);
        for i in 0..100 {
            assert!(
                limiter.check_rate_limit("zero-ext").is_ok(),
                "Request {} should be allowed under the fallback quota of 100",
                i
            );
        }
    }

    #[test]
    fn test_rate_limiter_error_contains_extension_id() {
        let limiter = RateLimiter::with_quota(1);
        limiter.check_rate_limit("error-ext").unwrap();
        let err = limiter.check_rate_limit("error-ext").unwrap_err();
        assert!(err.contains("error-ext"), "Error message should contain extension id");
    }
}