1use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tracing::{debug, trace, warn};
12
13use crate::errors::{LimitType, SentinelError, SentinelResult};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Limits {
18 pub max_header_size_bytes: usize,
20 pub max_header_count: usize,
21 pub max_header_name_bytes: usize,
22 pub max_header_value_bytes: usize,
23
24 pub max_body_size_bytes: usize,
26 pub max_body_buffer_bytes: usize,
27 pub max_body_inspection_bytes: usize,
28
29 pub max_decompression_ratio: f32,
31 pub max_decompressed_size_bytes: usize,
32
33 pub max_connections_per_client: usize,
35 pub max_connections_per_route: usize,
36 pub max_total_connections: usize,
37 pub max_idle_connections_per_upstream: usize,
38
39 pub max_in_flight_requests: usize,
41 pub max_in_flight_requests_per_worker: usize,
42 pub max_queued_requests: usize,
43
44 pub max_agent_queue_depth: usize,
46 pub max_agent_body_bytes: usize,
47 pub max_agent_response_bytes: usize,
48
49 pub max_requests_per_second_global: Option<u32>,
51 pub max_requests_per_second_per_client: Option<u32>,
52 pub max_requests_per_second_per_route: Option<u32>,
53
54 pub max_memory_bytes: Option<usize>,
56 pub max_memory_percent: Option<f32>,
57}
58
59impl Default for Limits {
60 fn default() -> Self {
61 Self {
62 max_header_size_bytes: 8192, max_header_count: 100, max_header_name_bytes: 256, max_header_value_bytes: 4096, max_body_size_bytes: 10 * 1024 * 1024,
70 max_body_buffer_bytes: 1024 * 1024,
71 max_body_inspection_bytes: 1024 * 1024,
72
73 max_decompression_ratio: 100.0,
75 max_decompressed_size_bytes: 100 * 1024 * 1024, max_connections_per_client: 100,
79 max_connections_per_route: 1000,
80 max_total_connections: 10000,
81 max_idle_connections_per_upstream: 100,
82
83 max_in_flight_requests: 10000,
85 max_in_flight_requests_per_worker: 1000,
86 max_queued_requests: 1000,
87
88 max_agent_queue_depth: 100,
90 max_agent_body_bytes: 1024 * 1024, max_agent_response_bytes: 10 * 1024, max_requests_per_second_global: None,
95 max_requests_per_second_per_client: None,
96 max_requests_per_second_per_route: None,
97
98 max_memory_bytes: None,
100 max_memory_percent: None,
101 }
102 }
103}
104
105impl Limits {
106 pub fn for_testing() -> Self {
108 Self {
109 max_header_size_bytes: 16384,
110 max_header_count: 200,
111 max_body_size_bytes: 100 * 1024 * 1024, max_in_flight_requests: 100000,
113 ..Default::default()
114 }
115 }
116
117 pub fn for_production() -> Self {
119 Self {
120 max_header_size_bytes: 4096,
121 max_header_count: 50,
122 max_body_size_bytes: 1024 * 1024, max_in_flight_requests: 5000,
124 max_requests_per_second_global: Some(10000),
125 max_requests_per_second_per_client: Some(100),
126 max_memory_percent: Some(80.0),
127 ..Default::default()
128 }
129 }
130
131 pub fn validate(&self) -> SentinelResult<()> {
133 if self.max_header_size_bytes == 0 {
134 return Err(SentinelError::Config {
135 message: "max_header_size_bytes must be greater than 0".to_string(),
136 source: None,
137 });
138 }
139
140 if self.max_header_count == 0 {
141 return Err(SentinelError::Config {
142 message: "max_header_count must be greater than 0".to_string(),
143 source: None,
144 });
145 }
146
147 if self.max_body_buffer_bytes > self.max_body_size_bytes {
148 return Err(SentinelError::Config {
149 message: "max_body_buffer_bytes cannot exceed max_body_size_bytes".to_string(),
150 source: None,
151 });
152 }
153
154 if self.max_decompression_ratio <= 0.0 {
155 return Err(SentinelError::Config {
156 message: "max_decompression_ratio must be positive".to_string(),
157 source: None,
158 });
159 }
160
161 if let Some(pct) = self.max_memory_percent {
162 if pct <= 0.0 || pct > 100.0 {
163 return Err(SentinelError::Config {
164 message: "max_memory_percent must be between 0 and 100".to_string(),
165 source: None,
166 });
167 }
168 }
169
170 Ok(())
171 }
172
173 pub fn check_header_size(&self, size: usize) -> SentinelResult<()> {
175 if size > self.max_header_size_bytes {
176 return Err(SentinelError::limit_exceeded(
177 LimitType::HeaderSize,
178 size,
179 self.max_header_size_bytes,
180 ));
181 }
182 Ok(())
183 }
184
185 pub fn check_header_count(&self, count: usize) -> SentinelResult<()> {
187 if count > self.max_header_count {
188 return Err(SentinelError::limit_exceeded(
189 LimitType::HeaderCount,
190 count,
191 self.max_header_count,
192 ));
193 }
194 Ok(())
195 }
196
197 pub fn check_body_size(&self, size: usize) -> SentinelResult<()> {
199 if size > self.max_body_size_bytes {
200 return Err(SentinelError::limit_exceeded(
201 LimitType::BodySize,
202 size,
203 self.max_body_size_bytes,
204 ));
205 }
206 Ok(())
207 }
208}
209
210#[derive(Debug)]
212pub struct RateLimiter {
213 capacity: u32,
214 tokens: Arc<RwLock<f64>>,
215 refill_rate: f64,
216 last_refill: Arc<RwLock<Instant>>,
217}
218
219impl RateLimiter {
220 pub fn new(capacity: u32, refill_per_second: u32) -> Self {
222 trace!(
223 capacity = capacity,
224 refill_per_second = refill_per_second,
225 "Creating rate limiter"
226 );
227 Self {
228 capacity,
229 tokens: Arc::new(RwLock::new(capacity as f64)),
230 refill_rate: refill_per_second as f64,
231 last_refill: Arc::new(RwLock::new(Instant::now())),
232 }
233 }
234
235 pub fn try_acquire(&self, tokens: u32) -> bool {
237 self.refill();
238
239 let mut available_tokens = self.tokens.write();
240 if *available_tokens >= tokens as f64 {
241 *available_tokens -= tokens as f64;
242 trace!(
243 tokens_requested = tokens,
244 tokens_remaining = *available_tokens as u32,
245 "Rate limiter: tokens acquired"
246 );
247 true
248 } else {
249 trace!(
250 tokens_requested = tokens,
251 tokens_available = *available_tokens as u32,
252 "Rate limiter: insufficient tokens"
253 );
254 false
255 }
256 }
257
258 pub fn check(&self, tokens: u32) -> bool {
260 self.refill();
261 let available_tokens = self.tokens.read();
262 *available_tokens >= tokens as f64
263 }
264
265 pub fn available(&self) -> u32 {
267 self.refill();
268 let tokens = self.tokens.read();
269 *tokens as u32
270 }
271
272 fn refill(&self) {
274 let now = Instant::now();
275 let mut last_refill = self.last_refill.write();
276 let elapsed = now.duration_since(*last_refill).as_secs_f64();
277
278 if elapsed > 0.0 {
279 let mut tokens = self.tokens.write();
280 let tokens_to_add = elapsed * self.refill_rate;
281 *tokens = (*tokens + tokens_to_add).min(self.capacity as f64);
282 *last_refill = now;
283 }
284 }
285
286 pub fn reset(&self) {
288 let mut tokens = self.tokens.write();
289 *tokens = self.capacity as f64;
290 let mut last_refill = self.last_refill.write();
291 *last_refill = Instant::now();
292 }
293
294 pub fn last_accessed(&self) -> Instant {
296 *self.last_refill.read()
297 }
298}
299
300pub struct MultiRateLimiter {
302 global: Option<RateLimiter>,
303 per_client: Arc<RwLock<HashMap<String, RateLimiter>>>,
304 per_route: Arc<RwLock<HashMap<String, RateLimiter>>>,
305 client_limit: Option<(u32, u32)>, route_limit: Option<(u32, u32)>, }
308
309impl MultiRateLimiter {
310 pub fn new(limits: &Limits) -> Self {
312 let global = limits
313 .max_requests_per_second_global
314 .map(|rps| RateLimiter::new(rps * 10, rps)); let client_limit = limits
317 .max_requests_per_second_per_client
318 .map(|rps| (rps * 10, rps));
319
320 let route_limit = limits
321 .max_requests_per_second_per_route
322 .map(|rps| (rps * 10, rps));
323
324 Self {
325 global,
326 per_client: Arc::new(RwLock::new(HashMap::new())),
327 per_route: Arc::new(RwLock::new(HashMap::new())),
328 client_limit,
329 route_limit,
330 }
331 }
332
333 pub fn check_request(&self, client_id: &str, route: &str) -> SentinelResult<()> {
335 trace!(
336 client_id = %client_id,
337 route = %route,
338 "Checking rate limits"
339 );
340
341 if let Some(ref limiter) = self.global {
343 if !limiter.try_acquire(1) {
344 warn!(
345 client_id = %client_id,
346 route = %route,
347 "Global rate limit exceeded"
348 );
349 return Err(SentinelError::RateLimit {
350 message: "Global rate limit exceeded".to_string(),
351 limit: limiter.capacity,
352 window_seconds: 10,
353 retry_after_seconds: Some(1),
354 });
355 }
356 }
357
358 if let Some((capacity, refill)) = self.client_limit {
360 let mut limiters = self.per_client.write();
361 let limiter = limiters
362 .entry(client_id.to_string())
363 .or_insert_with(|| RateLimiter::new(capacity, refill));
364
365 if !limiter.try_acquire(1) {
366 warn!(
367 client_id = %client_id,
368 route = %route,
369 "Per-client rate limit exceeded"
370 );
371 return Err(SentinelError::RateLimit {
372 message: format!("Rate limit exceeded for client {}", client_id),
373 limit: capacity,
374 window_seconds: 10,
375 retry_after_seconds: Some(1),
376 });
377 }
378 }
379
380 if let Some((capacity, refill)) = self.route_limit {
382 let mut limiters = self.per_route.write();
383 let limiter = limiters
384 .entry(route.to_string())
385 .or_insert_with(|| RateLimiter::new(capacity, refill));
386
387 if !limiter.try_acquire(1) {
388 warn!(
389 client_id = %client_id,
390 route = %route,
391 "Per-route rate limit exceeded"
392 );
393 return Err(SentinelError::RateLimit {
394 message: format!("Rate limit exceeded for route {}", route),
395 limit: capacity,
396 window_seconds: 10,
397 retry_after_seconds: Some(1),
398 });
399 }
400 }
401
402 trace!(
403 client_id = %client_id,
404 route = %route,
405 "Rate limits check passed"
406 );
407 Ok(())
408 }
409
410 pub fn cleanup(&self, max_age: Duration) -> (usize, usize) {
414 let now = Instant::now();
415
416 let clients_before = self.per_client.read().len();
418 self.per_client.write().retain(|client_id, limiter| {
419 let age = now.duration_since(limiter.last_accessed());
420 let keep = age < max_age;
421 if !keep {
422 trace!(
423 client_id = %client_id,
424 age_secs = age.as_secs(),
425 "Removing idle client rate limiter"
426 );
427 }
428 keep
429 });
430 let clients_removed = clients_before - self.per_client.read().len();
431
432 let routes_before = self.per_route.read().len();
434 self.per_route.write().retain(|route, limiter| {
435 let age = now.duration_since(limiter.last_accessed());
436 let keep = age < max_age;
437 if !keep {
438 trace!(
439 route = %route,
440 age_secs = age.as_secs(),
441 "Removing idle route rate limiter"
442 );
443 }
444 keep
445 });
446 let routes_removed = routes_before - self.per_route.read().len();
447
448 if clients_removed > 0 || routes_removed > 0 {
449 debug!(
450 clients_removed = clients_removed,
451 routes_removed = routes_removed,
452 clients_remaining = self.per_client.read().len(),
453 routes_remaining = self.per_route.read().len(),
454 "Rate limiter cleanup completed"
455 );
456 }
457
458 (clients_removed, routes_removed)
459 }
460
461 pub fn entry_counts(&self) -> (usize, usize) {
463 (self.per_client.read().len(), self.per_route.read().len())
464 }
465}
466
467pub struct ConnectionLimiter {
469 per_client: Arc<RwLock<HashMap<String, usize>>>,
470 per_route: Arc<RwLock<HashMap<String, usize>>>,
471 total: Arc<RwLock<usize>>,
472 limits: Limits,
473}
474
475impl ConnectionLimiter {
476 pub fn new(limits: Limits) -> Self {
477 debug!(
478 max_total = limits.max_total_connections,
479 max_per_client = limits.max_connections_per_client,
480 max_per_route = limits.max_connections_per_route,
481 "Creating connection limiter"
482 );
483 Self {
484 per_client: Arc::new(RwLock::new(HashMap::new())),
485 per_route: Arc::new(RwLock::new(HashMap::new())),
486 total: Arc::new(RwLock::new(0)),
487 limits,
488 }
489 }
490
491 pub fn try_acquire(&self, client_id: &str, route: &str) -> SentinelResult<ConnectionGuard<'_>> {
493 trace!(
494 client_id = %client_id,
495 route = %route,
496 "Attempting to acquire connection slot"
497 );
498
499 {
501 let mut total = self.total.write();
502 if *total >= self.limits.max_total_connections {
503 warn!(
504 current = *total,
505 max = self.limits.max_total_connections,
506 "Total connection limit exceeded"
507 );
508 return Err(SentinelError::limit_exceeded(
509 LimitType::ConnectionCount,
510 *total,
511 self.limits.max_total_connections,
512 ));
513 }
514 *total += 1;
515 }
516
517 {
519 let mut per_client = self.per_client.write();
520 let client_count = per_client.entry(client_id.to_string()).or_insert(0);
521 if *client_count >= self.limits.max_connections_per_client {
522 *self.total.write() -= 1;
524 warn!(
525 client_id = %client_id,
526 current = *client_count,
527 max = self.limits.max_connections_per_client,
528 "Per-client connection limit exceeded"
529 );
530 return Err(SentinelError::limit_exceeded(
531 LimitType::ConnectionCount,
532 *client_count,
533 self.limits.max_connections_per_client,
534 ));
535 }
536 *client_count += 1;
537 }
538
539 {
541 let mut per_route = self.per_route.write();
542 let route_count = per_route.entry(route.to_string()).or_insert(0);
543 if *route_count >= self.limits.max_connections_per_route {
544 *self.total.write() -= 1;
546 *self.per_client.write().get_mut(client_id).unwrap() -= 1;
547 warn!(
548 route = %route,
549 current = *route_count,
550 max = self.limits.max_connections_per_route,
551 "Per-route connection limit exceeded"
552 );
553 return Err(SentinelError::limit_exceeded(
554 LimitType::ConnectionCount,
555 *route_count,
556 self.limits.max_connections_per_route,
557 ));
558 }
559 *route_count += 1;
560 }
561
562 trace!(
563 client_id = %client_id,
564 route = %route,
565 "Connection slot acquired"
566 );
567
568 Ok(ConnectionGuard {
569 limiter: self,
570 client_id: client_id.to_string(),
571 route: route.to_string(),
572 })
573 }
574
575 fn release(&self, client_id: &str, route: &str) {
577 trace!(
578 client_id = %client_id,
579 route = %route,
580 "Releasing connection slot"
581 );
582
583 *self.total.write() -= 1;
584
585 if let Some(count) = self.per_client.write().get_mut(client_id) {
586 *count = count.saturating_sub(1);
587 }
588
589 if let Some(count) = self.per_route.write().get_mut(route) {
590 *count = count.saturating_sub(1);
591 }
592 }
593
594 pub fn stats(&self) -> ConnectionStats {
596 ConnectionStats {
597 total: *self.total.read(),
598 per_client_count: self.per_client.read().len(),
599 per_route_count: self.per_route.read().len(),
600 }
601 }
602}
603
604pub struct ConnectionGuard<'a> {
606 limiter: &'a ConnectionLimiter,
607 client_id: String,
608 route: String,
609}
610
611impl Drop for ConnectionGuard<'_> {
612 fn drop(&mut self) {
613 self.limiter.release(&self.client_id, &self.route);
614 }
615}
616
617#[derive(Debug, Clone, Serialize)]
619pub struct ConnectionStats {
620 pub total: usize,
621 pub per_client_count: usize,
622 pub per_route_count: usize,
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Limits with per-client (100 rps) and per-route (1000 rps) rate
    /// limiting switched on; everything else is left at its default.
    fn keyed_rate_limits() -> Limits {
        Limits {
            max_requests_per_second_per_client: Some(100),
            max_requests_per_second_per_route: Some(1000),
            ..Default::default()
        }
    }

    #[test]
    fn test_limits_validation() {
        // Defaults are self-consistent.
        assert!(Limits::default().validate().is_ok());

        // A zero header size is rejected.
        let zero_header = Limits {
            max_header_size_bytes: 0,
            ..Default::default()
        };
        assert!(zero_header.validate().is_err());

        // The buffer limit may not exceed the body size limit.
        let defaults = Limits::default();
        let oversized_buffer = Limits {
            max_body_buffer_bytes: defaults.max_body_size_bytes + 1,
            ..defaults
        };
        assert!(oversized_buffer.validate().is_err());
    }

    #[test]
    fn test_rate_limiter() {
        let bucket = RateLimiter::new(10, 10);

        // The full burst capacity can be drained...
        assert!((0..10).all(|_| bucket.try_acquire(1)));

        // ...after which the bucket is empty.
        assert!(!bucket.try_acquire(1));

        // Give the bucket time to refill.
        thread::sleep(Duration::from_millis(200));

        assert!(bucket.try_acquire(1));
        assert!(bucket.available() > 0);
    }

    #[test]
    fn test_connection_limiter() {
        let conn_limiter = ConnectionLimiter::new(Limits {
            max_total_connections: 100,
            max_connections_per_client: 10,
            max_connections_per_route: 50,
            ..Default::default()
        });

        // Both guards stay alive until the end of the test.
        let _guard1 = conn_limiter.try_acquire("client1", "route1").unwrap();
        let _guard2 = conn_limiter.try_acquire("client1", "route1").unwrap();

        assert_eq!(conn_limiter.stats().total, 2);
    }

    #[test]
    fn test_rate_limiter_last_accessed() {
        let bucket = RateLimiter::new(10, 10);
        let before = Instant::now();

        bucket.try_acquire(1);

        let touched = bucket.last_accessed();
        assert!(touched >= before);
        assert!(touched <= Instant::now());
    }

    #[test]
    fn test_multi_rate_limiter_entry_counts() {
        let multi = MultiRateLimiter::new(&keyed_rate_limits());

        // Nothing is tracked before the first request.
        assert_eq!(multi.entry_counts(), (0, 0));

        let requests = [
            ("client1", "route1"),
            ("client2", "route1"),
            ("client1", "route2"),
        ];
        for &(client, route) in &requests {
            let _ = multi.check_request(client, route);
        }

        // Two distinct clients and two distinct routes were seen.
        assert_eq!(multi.entry_counts(), (2, 2));
    }

    #[test]
    fn test_multi_rate_limiter_cleanup() {
        let multi = MultiRateLimiter::new(&keyed_rate_limits());

        for &(client, route) in &[("client1", "route1"), ("client2", "route2")] {
            let _ = multi.check_request(client, route);
        }
        assert_eq!(multi.entry_counts(), (2, 2));

        // A generous max age keeps every entry.
        let (clients_removed, routes_removed) = multi.cleanup(Duration::from_secs(3600));
        assert_eq!(clients_removed, 0);
        assert_eq!(routes_removed, 0);
        assert_eq!(multi.entry_counts(), (2, 2));

        // Let the entries age past the short threshold used below.
        thread::sleep(Duration::from_millis(50));

        let (clients_removed, routes_removed) = multi.cleanup(Duration::from_millis(10));
        assert_eq!(clients_removed, 2);
        assert_eq!(routes_removed, 2);
        assert_eq!(multi.entry_counts(), (0, 0));
    }

    #[test]
    fn test_multi_rate_limiter_cleanup_partial() {
        let multi = MultiRateLimiter::new(&keyed_rate_limits());

        let _ = multi.check_request("old_client", "old_route");

        // Age the first pair before the second pair is created.
        thread::sleep(Duration::from_millis(60));

        let _ = multi.check_request("new_client", "new_route");
        assert_eq!(multi.entry_counts(), (2, 2));

        // Only the older pair exceeds the 30 ms threshold.
        let (clients_removed, routes_removed) = multi.cleanup(Duration::from_millis(30));
        assert_eq!(clients_removed, 1);
        assert_eq!(routes_removed, 1);
        assert_eq!(multi.entry_counts(), (1, 1));
    }
}