sentinel_common/
limits.rs

1//! Limits and rate limiting for Sentinel proxy
2//!
3//! This module implements bounded limits for all resources to ensure predictable
4//! behavior and prevent resource exhaustion - core to "sleepable ops".
5
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tracing::{debug, trace, warn};
12
13use crate::errors::{LimitType, SentinelError, SentinelResult};
14
15/// System-wide limits configuration
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Limits {
18    // Header limits
19    pub max_header_size_bytes: usize,
20    pub max_header_count: usize,
21    pub max_header_name_bytes: usize,
22    pub max_header_value_bytes: usize,
23
24    // Body limits
25    pub max_body_size_bytes: usize,
26    pub max_body_buffer_bytes: usize,
27    pub max_body_inspection_bytes: usize,
28
29    // Decompression limits
30    pub max_decompression_ratio: f32,
31    pub max_decompressed_size_bytes: usize,
32
33    // Connection limits
34    pub max_connections_per_client: usize,
35    pub max_connections_per_route: usize,
36    pub max_total_connections: usize,
37    pub max_idle_connections_per_upstream: usize,
38
39    // Request limits
40    pub max_in_flight_requests: usize,
41    pub max_in_flight_requests_per_worker: usize,
42    pub max_queued_requests: usize,
43
44    // Agent limits
45    pub max_agent_queue_depth: usize,
46    pub max_agent_body_bytes: usize,
47    pub max_agent_response_bytes: usize,
48
49    // Rate limits
50    pub max_requests_per_second_global: Option<u32>,
51    pub max_requests_per_second_per_client: Option<u32>,
52    pub max_requests_per_second_per_route: Option<u32>,
53
54    // Memory limits
55    pub max_memory_bytes: Option<usize>,
56    pub max_memory_percent: Option<f32>,
57}
58
59impl Default for Limits {
60    fn default() -> Self {
61        Self {
62            // Conservative header limits
63            max_header_size_bytes: 8192,  // 8KB total headers
64            max_header_count: 100,        // Max 100 headers
65            max_header_name_bytes: 256,   // 256 bytes per header name
66            max_header_value_bytes: 4096, // 4KB per header value
67
68            // Body limits - 10MB default, 1MB buffer
69            max_body_size_bytes: 10 * 1024 * 1024,
70            max_body_buffer_bytes: 1024 * 1024,
71            max_body_inspection_bytes: 1024 * 1024,
72
73            // Decompression protection
74            max_decompression_ratio: 100.0,
75            max_decompressed_size_bytes: 100 * 1024 * 1024, // 100MB
76
77            // Connection limits
78            max_connections_per_client: 100,
79            max_connections_per_route: 1000,
80            max_total_connections: 10000,
81            max_idle_connections_per_upstream: 100,
82
83            // Request concurrency
84            max_in_flight_requests: 10000,
85            max_in_flight_requests_per_worker: 1000,
86            max_queued_requests: 1000,
87
88            // Agent communication
89            max_agent_queue_depth: 100,
90            max_agent_body_bytes: 1024 * 1024,   // 1MB to agents
91            max_agent_response_bytes: 10 * 1024, // 10KB from agents
92
93            // Rate limits (optional by default)
94            max_requests_per_second_global: None,
95            max_requests_per_second_per_client: None,
96            max_requests_per_second_per_route: None,
97
98            // Memory limits (optional by default)
99            max_memory_bytes: None,
100            max_memory_percent: None,
101        }
102    }
103}
104
105impl Limits {
106    /// Create limits suitable for testing (more permissive)
107    pub fn for_testing() -> Self {
108        Self {
109            max_header_size_bytes: 16384,
110            max_header_count: 200,
111            max_body_size_bytes: 100 * 1024 * 1024, // 100MB
112            max_in_flight_requests: 100000,
113            ..Default::default()
114        }
115    }
116
117    /// Create limits suitable for production (more restrictive)
118    pub fn for_production() -> Self {
119        Self {
120            max_header_size_bytes: 4096,
121            max_header_count: 50,
122            max_body_size_bytes: 1024 * 1024, // 1MB
123            max_in_flight_requests: 5000,
124            max_requests_per_second_global: Some(10000),
125            max_requests_per_second_per_client: Some(100),
126            max_memory_percent: Some(80.0),
127            ..Default::default()
128        }
129    }
130
131    /// Validate the limits configuration
132    pub fn validate(&self) -> SentinelResult<()> {
133        if self.max_header_size_bytes == 0 {
134            return Err(SentinelError::Config {
135                message: "max_header_size_bytes must be greater than 0".to_string(),
136                source: None,
137            });
138        }
139
140        if self.max_header_count == 0 {
141            return Err(SentinelError::Config {
142                message: "max_header_count must be greater than 0".to_string(),
143                source: None,
144            });
145        }
146
147        if self.max_body_buffer_bytes > self.max_body_size_bytes {
148            return Err(SentinelError::Config {
149                message: "max_body_buffer_bytes cannot exceed max_body_size_bytes".to_string(),
150                source: None,
151            });
152        }
153
154        if self.max_decompression_ratio <= 0.0 {
155            return Err(SentinelError::Config {
156                message: "max_decompression_ratio must be positive".to_string(),
157                source: None,
158            });
159        }
160
161        if let Some(pct) = self.max_memory_percent {
162            if pct <= 0.0 || pct > 100.0 {
163                return Err(SentinelError::Config {
164                    message: "max_memory_percent must be between 0 and 100".to_string(),
165                    source: None,
166                });
167            }
168        }
169
170        Ok(())
171    }
172
173    /// Check if a header size exceeds limits
174    pub fn check_header_size(&self, size: usize) -> SentinelResult<()> {
175        if size > self.max_header_size_bytes {
176            return Err(SentinelError::limit_exceeded(
177                LimitType::HeaderSize,
178                size,
179                self.max_header_size_bytes,
180            ));
181        }
182        Ok(())
183    }
184
185    /// Check if header count exceeds limits
186    pub fn check_header_count(&self, count: usize) -> SentinelResult<()> {
187        if count > self.max_header_count {
188            return Err(SentinelError::limit_exceeded(
189                LimitType::HeaderCount,
190                count,
191                self.max_header_count,
192            ));
193        }
194        Ok(())
195    }
196
197    /// Check if body size exceeds limits
198    pub fn check_body_size(&self, size: usize) -> SentinelResult<()> {
199        if size > self.max_body_size_bytes {
200            return Err(SentinelError::limit_exceeded(
201                LimitType::BodySize,
202                size,
203                self.max_body_size_bytes,
204            ));
205        }
206        Ok(())
207    }
208}
209
210/// Token bucket rate limiter implementation
211#[derive(Debug)]
212pub struct RateLimiter {
213    capacity: u32,
214    tokens: Arc<RwLock<f64>>,
215    refill_rate: f64,
216    last_refill: Arc<RwLock<Instant>>,
217}
218
219impl RateLimiter {
220    /// Create a new rate limiter with specified capacity and refill rate
221    pub fn new(capacity: u32, refill_per_second: u32) -> Self {
222        trace!(
223            capacity = capacity,
224            refill_per_second = refill_per_second,
225            "Creating rate limiter"
226        );
227        Self {
228            capacity,
229            tokens: Arc::new(RwLock::new(capacity as f64)),
230            refill_rate: refill_per_second as f64,
231            last_refill: Arc::new(RwLock::new(Instant::now())),
232        }
233    }
234
235    /// Try to acquire tokens, returns true if successful
236    pub fn try_acquire(&self, tokens: u32) -> bool {
237        self.refill();
238
239        let mut available_tokens = self.tokens.write();
240        if *available_tokens >= tokens as f64 {
241            *available_tokens -= tokens as f64;
242            trace!(
243                tokens_requested = tokens,
244                tokens_remaining = *available_tokens as u32,
245                "Rate limiter: tokens acquired"
246            );
247            true
248        } else {
249            trace!(
250                tokens_requested = tokens,
251                tokens_available = *available_tokens as u32,
252                "Rate limiter: insufficient tokens"
253            );
254            false
255        }
256    }
257
258    /// Check if tokens are available without consuming
259    pub fn check(&self, tokens: u32) -> bool {
260        self.refill();
261        let available_tokens = self.tokens.read();
262        *available_tokens >= tokens as f64
263    }
264
265    /// Get current available tokens
266    pub fn available(&self) -> u32 {
267        self.refill();
268        let tokens = self.tokens.read();
269        *tokens as u32
270    }
271
272    /// Refill tokens based on elapsed time
273    fn refill(&self) {
274        let now = Instant::now();
275        let mut last_refill = self.last_refill.write();
276        let elapsed = now.duration_since(*last_refill).as_secs_f64();
277
278        if elapsed > 0.0 {
279            let mut tokens = self.tokens.write();
280            let tokens_to_add = elapsed * self.refill_rate;
281            *tokens = (*tokens + tokens_to_add).min(self.capacity as f64);
282            *last_refill = now;
283        }
284    }
285
286    /// Reset the rate limiter to full capacity
287    pub fn reset(&self) {
288        let mut tokens = self.tokens.write();
289        *tokens = self.capacity as f64;
290        let mut last_refill = self.last_refill.write();
291        *last_refill = Instant::now();
292    }
293
294    /// Get the time of last activity (used for cleanup of idle limiters)
295    pub fn last_accessed(&self) -> Instant {
296        *self.last_refill.read()
297    }
298}
299
300/// Multi-level rate limiter for different scopes
301pub struct MultiRateLimiter {
302    global: Option<RateLimiter>,
303    per_client: Arc<RwLock<HashMap<String, RateLimiter>>>,
304    per_route: Arc<RwLock<HashMap<String, RateLimiter>>>,
305    client_limit: Option<(u32, u32)>, // (capacity, refill_per_second)
306    route_limit: Option<(u32, u32)>,  // (capacity, refill_per_second)
307}
308
309impl MultiRateLimiter {
310    /// Create a new multi-level rate limiter
311    pub fn new(limits: &Limits) -> Self {
312        let global = limits
313            .max_requests_per_second_global
314            .map(|rps| RateLimiter::new(rps * 10, rps)); // 10 second burst
315
316        let client_limit = limits
317            .max_requests_per_second_per_client
318            .map(|rps| (rps * 10, rps));
319
320        let route_limit = limits
321            .max_requests_per_second_per_route
322            .map(|rps| (rps * 10, rps));
323
324        Self {
325            global,
326            per_client: Arc::new(RwLock::new(HashMap::new())),
327            per_route: Arc::new(RwLock::new(HashMap::new())),
328            client_limit,
329            route_limit,
330        }
331    }
332
333    /// Check if request is allowed for client and route
334    pub fn check_request(&self, client_id: &str, route: &str) -> SentinelResult<()> {
335        trace!(
336            client_id = %client_id,
337            route = %route,
338            "Checking rate limits"
339        );
340
341        // Check global rate limit
342        if let Some(ref limiter) = self.global {
343            if !limiter.try_acquire(1) {
344                warn!(
345                    client_id = %client_id,
346                    route = %route,
347                    "Global rate limit exceeded"
348                );
349                return Err(SentinelError::RateLimit {
350                    message: "Global rate limit exceeded".to_string(),
351                    limit: limiter.capacity,
352                    window_seconds: 10,
353                    retry_after_seconds: Some(1),
354                });
355            }
356        }
357
358        // Check per-client rate limit
359        if let Some((capacity, refill)) = self.client_limit {
360            let mut limiters = self.per_client.write();
361            let limiter = limiters
362                .entry(client_id.to_string())
363                .or_insert_with(|| RateLimiter::new(capacity, refill));
364
365            if !limiter.try_acquire(1) {
366                warn!(
367                    client_id = %client_id,
368                    route = %route,
369                    "Per-client rate limit exceeded"
370                );
371                return Err(SentinelError::RateLimit {
372                    message: format!("Rate limit exceeded for client {}", client_id),
373                    limit: capacity,
374                    window_seconds: 10,
375                    retry_after_seconds: Some(1),
376                });
377            }
378        }
379
380        // Check per-route rate limit
381        if let Some((capacity, refill)) = self.route_limit {
382            let mut limiters = self.per_route.write();
383            let limiter = limiters
384                .entry(route.to_string())
385                .or_insert_with(|| RateLimiter::new(capacity, refill));
386
387            if !limiter.try_acquire(1) {
388                warn!(
389                    client_id = %client_id,
390                    route = %route,
391                    "Per-route rate limit exceeded"
392                );
393                return Err(SentinelError::RateLimit {
394                    message: format!("Rate limit exceeded for route {}", route),
395                    limit: capacity,
396                    window_seconds: 10,
397                    retry_after_seconds: Some(1),
398                });
399            }
400        }
401
402        trace!(
403            client_id = %client_id,
404            route = %route,
405            "Rate limits check passed"
406        );
407        Ok(())
408    }
409
410    /// Clean up old rate limiters that haven't been used recently
411    ///
412    /// Returns the number of entries removed (clients, routes).
413    pub fn cleanup(&self, max_age: Duration) -> (usize, usize) {
414        let now = Instant::now();
415
416        // Clean up per-client limiters
417        let clients_before = self.per_client.read().len();
418        self.per_client.write().retain(|client_id, limiter| {
419            let age = now.duration_since(limiter.last_accessed());
420            let keep = age < max_age;
421            if !keep {
422                trace!(
423                    client_id = %client_id,
424                    age_secs = age.as_secs(),
425                    "Removing idle client rate limiter"
426                );
427            }
428            keep
429        });
430        let clients_removed = clients_before - self.per_client.read().len();
431
432        // Clean up per-route limiters
433        let routes_before = self.per_route.read().len();
434        self.per_route.write().retain(|route, limiter| {
435            let age = now.duration_since(limiter.last_accessed());
436            let keep = age < max_age;
437            if !keep {
438                trace!(
439                    route = %route,
440                    age_secs = age.as_secs(),
441                    "Removing idle route rate limiter"
442                );
443            }
444            keep
445        });
446        let routes_removed = routes_before - self.per_route.read().len();
447
448        if clients_removed > 0 || routes_removed > 0 {
449            debug!(
450                clients_removed = clients_removed,
451                routes_removed = routes_removed,
452                clients_remaining = self.per_client.read().len(),
453                routes_remaining = self.per_route.read().len(),
454                "Rate limiter cleanup completed"
455            );
456        }
457
458        (clients_removed, routes_removed)
459    }
460
461    /// Get the current number of tracked clients and routes
462    pub fn entry_counts(&self) -> (usize, usize) {
463        (self.per_client.read().len(), self.per_route.read().len())
464    }
465}
466
467/// Connection limiter for managing concurrent connections
468pub struct ConnectionLimiter {
469    per_client: Arc<RwLock<HashMap<String, usize>>>,
470    per_route: Arc<RwLock<HashMap<String, usize>>>,
471    total: Arc<RwLock<usize>>,
472    limits: Limits,
473}
474
475impl ConnectionLimiter {
476    pub fn new(limits: Limits) -> Self {
477        debug!(
478            max_total = limits.max_total_connections,
479            max_per_client = limits.max_connections_per_client,
480            max_per_route = limits.max_connections_per_route,
481            "Creating connection limiter"
482        );
483        Self {
484            per_client: Arc::new(RwLock::new(HashMap::new())),
485            per_route: Arc::new(RwLock::new(HashMap::new())),
486            total: Arc::new(RwLock::new(0)),
487            limits,
488        }
489    }
490
491    /// Try to acquire a connection slot
492    pub fn try_acquire(&self, client_id: &str, route: &str) -> SentinelResult<ConnectionGuard<'_>> {
493        trace!(
494            client_id = %client_id,
495            route = %route,
496            "Attempting to acquire connection slot"
497        );
498
499        // Check total connections
500        {
501            let mut total = self.total.write();
502            if *total >= self.limits.max_total_connections {
503                warn!(
504                    current = *total,
505                    max = self.limits.max_total_connections,
506                    "Total connection limit exceeded"
507                );
508                return Err(SentinelError::limit_exceeded(
509                    LimitType::ConnectionCount,
510                    *total,
511                    self.limits.max_total_connections,
512                ));
513            }
514            *total += 1;
515        }
516
517        // Check per-client connections
518        {
519            let mut per_client = self.per_client.write();
520            let client_count = per_client.entry(client_id.to_string()).or_insert(0);
521            if *client_count >= self.limits.max_connections_per_client {
522                // Rollback total count
523                *self.total.write() -= 1;
524                warn!(
525                    client_id = %client_id,
526                    current = *client_count,
527                    max = self.limits.max_connections_per_client,
528                    "Per-client connection limit exceeded"
529                );
530                return Err(SentinelError::limit_exceeded(
531                    LimitType::ConnectionCount,
532                    *client_count,
533                    self.limits.max_connections_per_client,
534                ));
535            }
536            *client_count += 1;
537        }
538
539        // Check per-route connections
540        {
541            let mut per_route = self.per_route.write();
542            let route_count = per_route.entry(route.to_string()).or_insert(0);
543            if *route_count >= self.limits.max_connections_per_route {
544                // Rollback counts
545                *self.total.write() -= 1;
546                *self.per_client.write().get_mut(client_id).unwrap() -= 1;
547                warn!(
548                    route = %route,
549                    current = *route_count,
550                    max = self.limits.max_connections_per_route,
551                    "Per-route connection limit exceeded"
552                );
553                return Err(SentinelError::limit_exceeded(
554                    LimitType::ConnectionCount,
555                    *route_count,
556                    self.limits.max_connections_per_route,
557                ));
558            }
559            *route_count += 1;
560        }
561
562        trace!(
563            client_id = %client_id,
564            route = %route,
565            "Connection slot acquired"
566        );
567
568        Ok(ConnectionGuard {
569            limiter: self,
570            client_id: client_id.to_string(),
571            route: route.to_string(),
572        })
573    }
574
575    /// Release a connection slot
576    fn release(&self, client_id: &str, route: &str) {
577        trace!(
578            client_id = %client_id,
579            route = %route,
580            "Releasing connection slot"
581        );
582
583        *self.total.write() -= 1;
584
585        if let Some(count) = self.per_client.write().get_mut(client_id) {
586            *count = count.saturating_sub(1);
587        }
588
589        if let Some(count) = self.per_route.write().get_mut(route) {
590            *count = count.saturating_sub(1);
591        }
592    }
593
594    /// Get current connection statistics
595    pub fn stats(&self) -> ConnectionStats {
596        ConnectionStats {
597            total: *self.total.read(),
598            per_client_count: self.per_client.read().len(),
599            per_route_count: self.per_route.read().len(),
600        }
601    }
602}
603
604/// RAII guard for connection slots
605pub struct ConnectionGuard<'a> {
606    limiter: &'a ConnectionLimiter,
607    client_id: String,
608    route: String,
609}
610
611impl Drop for ConnectionGuard<'_> {
612    fn drop(&mut self) {
613        self.limiter.release(&self.client_id, &self.route);
614    }
615}
616
617/// Connection statistics
618#[derive(Debug, Clone, Serialize)]
619pub struct ConnectionStats {
620    pub total: usize,
621    pub per_client_count: usize,
622    pub per_route_count: usize,
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Limits with only the per-client and per-route rate limits enabled.
    fn keyed_rate_limits() -> Limits {
        Limits {
            max_requests_per_second_per_client: Some(100),
            max_requests_per_second_per_route: Some(1000),
            ..Default::default()
        }
    }

    /// Defaults validate; a zero header size or an oversized buffer does not.
    #[test]
    fn test_limits_validation() {
        assert!(Limits::default().validate().is_ok());

        let zero_header_size = Limits {
            max_header_size_bytes: 0,
            ..Default::default()
        };
        assert!(zero_header_size.validate().is_err());

        let baseline = Limits::default();
        let oversized_buffer = Limits {
            max_body_buffer_bytes: baseline.max_body_size_bytes + 1,
            ..baseline
        };
        assert!(oversized_buffer.validate().is_err());
    }

    /// A full bucket allows a burst, then runs dry, then refills over time.
    #[test]
    fn test_rate_limiter() {
        let bucket = RateLimiter::new(10, 10);

        // The initial burst drains the bucket exactly.
        assert!((0..10).all(|_| bucket.try_acquire(1)));
        assert!(!bucket.try_acquire(1));

        // Roughly two tokens come back after 200 ms at 10 tokens/s.
        thread::sleep(Duration::from_millis(200));

        assert!(bucket.try_acquire(1));
        assert!(bucket.available() > 0);
    }

    /// Every live guard is counted in the total.
    #[test]
    fn test_connection_limiter() {
        let conn_limiter = ConnectionLimiter::new(Limits {
            max_total_connections: 100,
            max_connections_per_client: 10,
            max_connections_per_route: 50,
            ..Default::default()
        });

        // Both guards stay alive until the end of the test.
        let _first = conn_limiter.try_acquire("client1", "route1").unwrap();
        let _second = conn_limiter.try_acquire("client1", "route1").unwrap();

        assert_eq!(conn_limiter.stats().total, 2);
    }

    /// Using the limiter moves its last-accessed timestamp forward.
    #[test]
    fn test_rate_limiter_last_accessed() {
        let bucket = RateLimiter::new(10, 10);
        let started = Instant::now();

        bucket.try_acquire(1);

        let touched = bucket.last_accessed();
        assert!(touched >= started);
        assert!(touched <= Instant::now());
    }

    /// Entries are created lazily, one per distinct client and route.
    #[test]
    fn test_multi_rate_limiter_entry_counts() {
        let multi = MultiRateLimiter::new(&keyed_rate_limits());

        assert_eq!(multi.entry_counts(), (0, 0));

        for (client, route) in [
            ("client1", "route1"),
            ("client2", "route1"),
            ("client1", "route2"),
        ] {
            let _ = multi.check_request(client, route);
        }

        // Two distinct clients and two distinct routes were seen.
        assert_eq!(multi.entry_counts(), (2, 2));
    }

    /// A generous max age keeps everything; a tiny one removes everything.
    #[test]
    fn test_multi_rate_limiter_cleanup() {
        let multi = MultiRateLimiter::new(&keyed_rate_limits());

        let _ = multi.check_request("client1", "route1");
        let _ = multi.check_request("client2", "route2");
        assert_eq!(multi.entry_counts(), (2, 2));

        // Nothing is an hour old yet.
        let removed = multi.cleanup(Duration::from_secs(3600));
        assert_eq!(removed.0, 0);
        assert_eq!(removed.1, 0);
        assert_eq!(multi.entry_counts(), (2, 2));

        thread::sleep(Duration::from_millis(50));

        // Everything is older than 10 ms by now.
        let removed = multi.cleanup(Duration::from_millis(10));
        assert_eq!(removed.0, 2);
        assert_eq!(removed.1, 2);
        assert_eq!(multi.entry_counts(), (0, 0));
    }

    /// Only entries older than the max age are removed.
    #[test]
    fn test_multi_rate_limiter_cleanup_partial() {
        let multi = MultiRateLimiter::new(&keyed_rate_limits());

        // The old entry is created first and then left to age.
        let _ = multi.check_request("old_client", "old_route");
        thread::sleep(Duration::from_millis(60));

        // The fresh entry is created right before the cleanup.
        let _ = multi.check_request("new_client", "new_route");
        assert_eq!(multi.entry_counts(), (2, 2));

        // 30 ms sits between the ages of the two entries.
        let removed = multi.cleanup(Duration::from_millis(30));
        assert_eq!(removed.0, 1);
        assert_eq!(removed.1, 1);

        // The recently used entries are the ones left.
        assert_eq!(multi.entry_counts(), (1, 1));
    }
}