sentinel_common/
limits.rs

1//! Limits and rate limiting for Sentinel proxy
2//!
3//! This module implements bounded limits for all resources to ensure predictable
4//! behavior and prevent resource exhaustion - core to "sleepable ops".
5
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12use crate::errors::{LimitType, SentinelError, SentinelResult};
13
14/// System-wide limits configuration
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct Limits {
17    // Header limits
18    pub max_header_size_bytes: usize,
19    pub max_header_count: usize,
20    pub max_header_name_bytes: usize,
21    pub max_header_value_bytes: usize,
22
23    // Body limits
24    pub max_body_size_bytes: usize,
25    pub max_body_buffer_bytes: usize,
26    pub max_body_inspection_bytes: usize,
27
28    // Decompression limits
29    pub max_decompression_ratio: f32,
30    pub max_decompressed_size_bytes: usize,
31
32    // Connection limits
33    pub max_connections_per_client: usize,
34    pub max_connections_per_route: usize,
35    pub max_total_connections: usize,
36    pub max_idle_connections_per_upstream: usize,
37
38    // Request limits
39    pub max_in_flight_requests: usize,
40    pub max_in_flight_requests_per_worker: usize,
41    pub max_queued_requests: usize,
42
43    // Agent limits
44    pub max_agent_queue_depth: usize,
45    pub max_agent_body_bytes: usize,
46    pub max_agent_response_bytes: usize,
47
48    // Rate limits
49    pub max_requests_per_second_global: Option<u32>,
50    pub max_requests_per_second_per_client: Option<u32>,
51    pub max_requests_per_second_per_route: Option<u32>,
52
53    // Memory limits
54    pub max_memory_bytes: Option<usize>,
55    pub max_memory_percent: Option<f32>,
56}
57
58impl Default for Limits {
59    fn default() -> Self {
60        Self {
61            // Conservative header limits
62            max_header_size_bytes: 8192,  // 8KB total headers
63            max_header_count: 100,        // Max 100 headers
64            max_header_name_bytes: 256,   // 256 bytes per header name
65            max_header_value_bytes: 4096, // 4KB per header value
66
67            // Body limits - 10MB default, 1MB buffer
68            max_body_size_bytes: 10 * 1024 * 1024,
69            max_body_buffer_bytes: 1024 * 1024,
70            max_body_inspection_bytes: 1024 * 1024,
71
72            // Decompression protection
73            max_decompression_ratio: 100.0,
74            max_decompressed_size_bytes: 100 * 1024 * 1024, // 100MB
75
76            // Connection limits
77            max_connections_per_client: 100,
78            max_connections_per_route: 1000,
79            max_total_connections: 10000,
80            max_idle_connections_per_upstream: 100,
81
82            // Request concurrency
83            max_in_flight_requests: 10000,
84            max_in_flight_requests_per_worker: 1000,
85            max_queued_requests: 1000,
86
87            // Agent communication
88            max_agent_queue_depth: 100,
89            max_agent_body_bytes: 1024 * 1024,   // 1MB to agents
90            max_agent_response_bytes: 10 * 1024, // 10KB from agents
91
92            // Rate limits (optional by default)
93            max_requests_per_second_global: None,
94            max_requests_per_second_per_client: None,
95            max_requests_per_second_per_route: None,
96
97            // Memory limits (optional by default)
98            max_memory_bytes: None,
99            max_memory_percent: None,
100        }
101    }
102}
103
104impl Limits {
105    /// Create limits suitable for testing (more permissive)
106    pub fn for_testing() -> Self {
107        Self {
108            max_header_size_bytes: 16384,
109            max_header_count: 200,
110            max_body_size_bytes: 100 * 1024 * 1024, // 100MB
111            max_in_flight_requests: 100000,
112            ..Default::default()
113        }
114    }
115
116    /// Create limits suitable for production (more restrictive)
117    pub fn for_production() -> Self {
118        Self {
119            max_header_size_bytes: 4096,
120            max_header_count: 50,
121            max_body_size_bytes: 1024 * 1024, // 1MB
122            max_in_flight_requests: 5000,
123            max_requests_per_second_global: Some(10000),
124            max_requests_per_second_per_client: Some(100),
125            max_memory_percent: Some(80.0),
126            ..Default::default()
127        }
128    }
129
130    /// Validate the limits configuration
131    pub fn validate(&self) -> SentinelResult<()> {
132        if self.max_header_size_bytes == 0 {
133            return Err(SentinelError::Config {
134                message: "max_header_size_bytes must be greater than 0".to_string(),
135                source: None,
136            });
137        }
138
139        if self.max_header_count == 0 {
140            return Err(SentinelError::Config {
141                message: "max_header_count must be greater than 0".to_string(),
142                source: None,
143            });
144        }
145
146        if self.max_body_buffer_bytes > self.max_body_size_bytes {
147            return Err(SentinelError::Config {
148                message: "max_body_buffer_bytes cannot exceed max_body_size_bytes".to_string(),
149                source: None,
150            });
151        }
152
153        if self.max_decompression_ratio <= 0.0 {
154            return Err(SentinelError::Config {
155                message: "max_decompression_ratio must be positive".to_string(),
156                source: None,
157            });
158        }
159
160        if let Some(pct) = self.max_memory_percent {
161            if pct <= 0.0 || pct > 100.0 {
162                return Err(SentinelError::Config {
163                    message: "max_memory_percent must be between 0 and 100".to_string(),
164                    source: None,
165                });
166            }
167        }
168
169        Ok(())
170    }
171
172    /// Check if a header size exceeds limits
173    pub fn check_header_size(&self, size: usize) -> SentinelResult<()> {
174        if size > self.max_header_size_bytes {
175            return Err(SentinelError::limit_exceeded(
176                LimitType::HeaderSize,
177                size,
178                self.max_header_size_bytes,
179            ));
180        }
181        Ok(())
182    }
183
184    /// Check if header count exceeds limits
185    pub fn check_header_count(&self, count: usize) -> SentinelResult<()> {
186        if count > self.max_header_count {
187            return Err(SentinelError::limit_exceeded(
188                LimitType::HeaderCount,
189                count,
190                self.max_header_count,
191            ));
192        }
193        Ok(())
194    }
195
196    /// Check if body size exceeds limits
197    pub fn check_body_size(&self, size: usize) -> SentinelResult<()> {
198        if size > self.max_body_size_bytes {
199            return Err(SentinelError::limit_exceeded(
200                LimitType::BodySize,
201                size,
202                self.max_body_size_bytes,
203            ));
204        }
205        Ok(())
206    }
207}
208
209/// Token bucket rate limiter implementation
210#[derive(Debug)]
211pub struct RateLimiter {
212    capacity: u32,
213    tokens: Arc<RwLock<f64>>,
214    refill_rate: f64,
215    last_refill: Arc<RwLock<Instant>>,
216}
217
218impl RateLimiter {
219    /// Create a new rate limiter with specified capacity and refill rate
220    pub fn new(capacity: u32, refill_per_second: u32) -> Self {
221        Self {
222            capacity,
223            tokens: Arc::new(RwLock::new(capacity as f64)),
224            refill_rate: refill_per_second as f64,
225            last_refill: Arc::new(RwLock::new(Instant::now())),
226        }
227    }
228
229    /// Try to acquire tokens, returns true if successful
230    pub fn try_acquire(&self, tokens: u32) -> bool {
231        self.refill();
232
233        let mut available_tokens = self.tokens.write();
234        if *available_tokens >= tokens as f64 {
235            *available_tokens -= tokens as f64;
236            true
237        } else {
238            false
239        }
240    }
241
242    /// Check if tokens are available without consuming
243    pub fn check(&self, tokens: u32) -> bool {
244        self.refill();
245        let available_tokens = self.tokens.read();
246        *available_tokens >= tokens as f64
247    }
248
249    /// Get current available tokens
250    pub fn available(&self) -> u32 {
251        self.refill();
252        let tokens = self.tokens.read();
253        *tokens as u32
254    }
255
256    /// Refill tokens based on elapsed time
257    fn refill(&self) {
258        let now = Instant::now();
259        let mut last_refill = self.last_refill.write();
260        let elapsed = now.duration_since(*last_refill).as_secs_f64();
261
262        if elapsed > 0.0 {
263            let mut tokens = self.tokens.write();
264            let tokens_to_add = elapsed * self.refill_rate;
265            *tokens = (*tokens + tokens_to_add).min(self.capacity as f64);
266            *last_refill = now;
267        }
268    }
269
270    /// Reset the rate limiter to full capacity
271    pub fn reset(&self) {
272        let mut tokens = self.tokens.write();
273        *tokens = self.capacity as f64;
274        let mut last_refill = self.last_refill.write();
275        *last_refill = Instant::now();
276    }
277}
278
279/// Multi-level rate limiter for different scopes
280pub struct MultiRateLimiter {
281    global: Option<RateLimiter>,
282    per_client: Arc<RwLock<HashMap<String, RateLimiter>>>,
283    per_route: Arc<RwLock<HashMap<String, RateLimiter>>>,
284    client_limit: Option<(u32, u32)>, // (capacity, refill_per_second)
285    route_limit: Option<(u32, u32)>,  // (capacity, refill_per_second)
286}
287
288impl MultiRateLimiter {
289    /// Create a new multi-level rate limiter
290    pub fn new(limits: &Limits) -> Self {
291        let global = limits
292            .max_requests_per_second_global
293            .map(|rps| RateLimiter::new(rps * 10, rps)); // 10 second burst
294
295        let client_limit = limits
296            .max_requests_per_second_per_client
297            .map(|rps| (rps * 10, rps));
298
299        let route_limit = limits
300            .max_requests_per_second_per_route
301            .map(|rps| (rps * 10, rps));
302
303        Self {
304            global,
305            per_client: Arc::new(RwLock::new(HashMap::new())),
306            per_route: Arc::new(RwLock::new(HashMap::new())),
307            client_limit,
308            route_limit,
309        }
310    }
311
312    /// Check if request is allowed for client and route
313    pub fn check_request(&self, client_id: &str, route: &str) -> SentinelResult<()> {
314        // Check global rate limit
315        if let Some(ref limiter) = self.global {
316            if !limiter.try_acquire(1) {
317                return Err(SentinelError::RateLimit {
318                    message: "Global rate limit exceeded".to_string(),
319                    limit: limiter.capacity,
320                    window_seconds: 10,
321                    retry_after_seconds: Some(1),
322                });
323            }
324        }
325
326        // Check per-client rate limit
327        if let Some((capacity, refill)) = self.client_limit {
328            let mut limiters = self.per_client.write();
329            let limiter = limiters
330                .entry(client_id.to_string())
331                .or_insert_with(|| RateLimiter::new(capacity, refill));
332
333            if !limiter.try_acquire(1) {
334                return Err(SentinelError::RateLimit {
335                    message: format!("Rate limit exceeded for client {}", client_id),
336                    limit: capacity,
337                    window_seconds: 10,
338                    retry_after_seconds: Some(1),
339                });
340            }
341        }
342
343        // Check per-route rate limit
344        if let Some((capacity, refill)) = self.route_limit {
345            let mut limiters = self.per_route.write();
346            let limiter = limiters
347                .entry(route.to_string())
348                .or_insert_with(|| RateLimiter::new(capacity, refill));
349
350            if !limiter.try_acquire(1) {
351                return Err(SentinelError::RateLimit {
352                    message: format!("Rate limit exceeded for route {}", route),
353                    limit: capacity,
354                    window_seconds: 10,
355                    retry_after_seconds: Some(1),
356                });
357            }
358        }
359
360        Ok(())
361    }
362
363    /// Clean up old rate limiters that haven't been used recently
364    pub fn cleanup(&self, _max_age: Duration) {
365        // TODO: Implement cleanup of unused rate limiters
366        // This would track last access time and remove old entries
367    }
368}
369
370/// Connection limiter for managing concurrent connections
371pub struct ConnectionLimiter {
372    per_client: Arc<RwLock<HashMap<String, usize>>>,
373    per_route: Arc<RwLock<HashMap<String, usize>>>,
374    total: Arc<RwLock<usize>>,
375    limits: Limits,
376}
377
378impl ConnectionLimiter {
379    pub fn new(limits: Limits) -> Self {
380        Self {
381            per_client: Arc::new(RwLock::new(HashMap::new())),
382            per_route: Arc::new(RwLock::new(HashMap::new())),
383            total: Arc::new(RwLock::new(0)),
384            limits,
385        }
386    }
387
388    /// Try to acquire a connection slot
389    pub fn try_acquire(&self, client_id: &str, route: &str) -> SentinelResult<ConnectionGuard<'_>> {
390        // Check total connections
391        {
392            let mut total = self.total.write();
393            if *total >= self.limits.max_total_connections {
394                return Err(SentinelError::limit_exceeded(
395                    LimitType::ConnectionCount,
396                    *total,
397                    self.limits.max_total_connections,
398                ));
399            }
400            *total += 1;
401        }
402
403        // Check per-client connections
404        {
405            let mut per_client = self.per_client.write();
406            let client_count = per_client.entry(client_id.to_string()).or_insert(0);
407            if *client_count >= self.limits.max_connections_per_client {
408                // Rollback total count
409                *self.total.write() -= 1;
410                return Err(SentinelError::limit_exceeded(
411                    LimitType::ConnectionCount,
412                    *client_count,
413                    self.limits.max_connections_per_client,
414                ));
415            }
416            *client_count += 1;
417        }
418
419        // Check per-route connections
420        {
421            let mut per_route = self.per_route.write();
422            let route_count = per_route.entry(route.to_string()).or_insert(0);
423            if *route_count >= self.limits.max_connections_per_route {
424                // Rollback counts
425                *self.total.write() -= 1;
426                *self.per_client.write().get_mut(client_id).unwrap() -= 1;
427                return Err(SentinelError::limit_exceeded(
428                    LimitType::ConnectionCount,
429                    *route_count,
430                    self.limits.max_connections_per_route,
431                ));
432            }
433            *route_count += 1;
434        }
435
436        Ok(ConnectionGuard {
437            limiter: self,
438            client_id: client_id.to_string(),
439            route: route.to_string(),
440        })
441    }
442
443    /// Release a connection slot
444    fn release(&self, client_id: &str, route: &str) {
445        *self.total.write() -= 1;
446
447        if let Some(count) = self.per_client.write().get_mut(client_id) {
448            *count = count.saturating_sub(1);
449        }
450
451        if let Some(count) = self.per_route.write().get_mut(route) {
452            *count = count.saturating_sub(1);
453        }
454    }
455
456    /// Get current connection statistics
457    pub fn stats(&self) -> ConnectionStats {
458        ConnectionStats {
459            total: *self.total.read(),
460            per_client_count: self.per_client.read().len(),
461            per_route_count: self.per_route.read().len(),
462        }
463    }
464}
465
466/// RAII guard for connection slots
467pub struct ConnectionGuard<'a> {
468    limiter: &'a ConnectionLimiter,
469    client_id: String,
470    route: String,
471}
472
473impl<'a> Drop for ConnectionGuard<'a> {
474    fn drop(&mut self) {
475        self.limiter.release(&self.client_id, &self.route);
476    }
477}
478
479/// Connection statistics
480#[derive(Debug, Clone, Serialize)]
481pub struct ConnectionStats {
482    pub total: usize,
483    pub per_client_count: usize,
484    pub per_route_count: usize,
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// `validate` accepts the defaults and rejects each broken invariant.
    #[test]
    fn test_limits_validation() {
        let mut limits = Limits::default();
        assert!(limits.validate().is_ok());

        limits.max_header_size_bytes = 0;
        assert!(limits.validate().is_err());

        limits = Limits::default();
        limits.max_header_count = 0;
        assert!(limits.validate().is_err());

        limits = Limits::default();
        limits.max_body_buffer_bytes = limits.max_body_size_bytes + 1;
        assert!(limits.validate().is_err());
    }

    /// A full bucket allows a burst of `capacity`, then refills over time.
    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new(10, 10);

        // Should allow initial burst
        for _ in 0..10 {
            assert!(limiter.try_acquire(1));
        }

        // Should be exhausted
        assert!(!limiter.try_acquire(1));

        // Wait for refill
        thread::sleep(Duration::from_millis(200));

        // Should have some tokens refilled (approximately 2)
        assert!(limiter.try_acquire(1));
        assert!(limiter.available() > 0);
    }

    /// Granted slots are counted and handed back when their guards drop.
    #[test]
    fn test_connection_limiter() {
        let limits = Limits {
            max_total_connections: 100,
            max_connections_per_client: 10,
            max_connections_per_route: 50,
            ..Default::default()
        };

        let limiter = ConnectionLimiter::new(limits);

        // Acquire connections
        let guard1 = limiter.try_acquire("client1", "route1").unwrap();
        let guard2 = limiter.try_acquire("client1", "route1").unwrap();

        let stats = limiter.stats();
        assert_eq!(stats.total, 2);

        // Dropping the guards must release both slots.
        drop(guard1);
        assert_eq!(limiter.stats().total, 1);
        drop(guard2);
        assert_eq!(limiter.stats().total, 0);
    }

    /// A client at its cap is rejected without affecting other clients or
    /// the total count.
    #[test]
    fn test_connection_limiter_enforces_per_client_limit() {
        let limits = Limits {
            max_total_connections: 100,
            max_connections_per_client: 2,
            max_connections_per_route: 50,
            ..Default::default()
        };

        let limiter = ConnectionLimiter::new(limits);

        let _guard1 = limiter.try_acquire("client1", "route1").unwrap();
        let _guard2 = limiter.try_acquire("client1", "route1").unwrap();

        assert!(limiter.try_acquire("client1", "route1").is_err());
        assert_eq!(limiter.stats().total, 2);

        let _guard3 = limiter.try_acquire("client2", "route1").unwrap();
        assert_eq!(limiter.stats().total, 3);
    }
}