// warpdrive_proxy/middleware/concurrency.rs

//! Concurrency limiter middleware
//!
//! Limits the number of concurrent requests to prevent resource exhaustion.
//! Uses a tokio `Semaphore` for efficient async concurrency control. A request
//! that cannot obtain a permit immediately is rejected with 503 rather than
//! queued.
//!
//! # Trusted Sources
//!
//! Requests from trusted IP ranges (flagged by the trusted_ranges middleware)
//! bypass concurrency limits entirely. This allows proxies/CDNs to send high
//! volumes without being throttled.

use async_trait::async_trait;
use pingora::prelude::*;
use std::sync::Arc;
use tokio::sync::Semaphore;
use tracing::{debug, warn};

use super::{Middleware, MiddlewareContext};

20/// Concurrency limiter middleware
21///
22/// Enforces a maximum number of concurrent requests. When the limit is reached,
23/// additional requests are queued (awaiting a permit) or rejected with 503.
24pub struct ConcurrencyMiddleware {
25    /// Semaphore for concurrency control
26    semaphore: Arc<Semaphore>,
27    /// Whether concurrency limiting is enabled
28    enabled: bool,
29    /// Maximum concurrent requests (0 = unlimited)
30    max_concurrent: usize,
31}
32
33impl ConcurrencyMiddleware {
34    /// Create a new concurrency limiter middleware
35    ///
36    /// # Arguments
37    ///
38    /// * `enabled` - Whether to enable concurrency limiting
39    /// * `max_concurrent` - Maximum concurrent requests (0 = unlimited/disabled)
40    pub fn new(enabled: bool, max_concurrent: usize) -> Self {
41        let semaphore = if max_concurrent > 0 {
42            Arc::new(Semaphore::new(max_concurrent))
43        } else {
44            // Use a very large limit for "unlimited" (1 million concurrent requests)
45            // Note: Tokio's semaphore has a MAX_PERMITS limit
46            Arc::new(Semaphore::new(1_000_000))
47        };
48
49        debug!(
50            "Concurrency limiter initialized: enabled={}, max={}",
51            enabled, max_concurrent
52        );
53
54        Self {
55            semaphore,
56            enabled,
57            max_concurrent,
58        }
59    }
60}
61
62#[async_trait]
63impl Middleware for ConcurrencyMiddleware {
64    /// Check concurrency limits before processing request
65    async fn request_filter(
66        &self,
67        session: &mut Session,
68        ctx: &mut MiddlewareContext,
69    ) -> Result<()> {
70        if !self.enabled || self.max_concurrent == 0 {
71            return Ok(());
72        }
73
74        // Skip concurrency limiting for trusted sources (proxies/CDNs)
75        if ctx.trusted_source {
76            debug!("Skipping concurrency limit for trusted source");
77            return Ok(());
78        }
79
80        // Try to acquire an owned permit (can be stored in context)
81        match self.semaphore.clone().try_acquire_owned() {
82            Ok(permit) => {
83                debug!("Concurrency permit acquired");
84
85                // Store permit in context - it will be held for the request duration
86                // and automatically released when the context is dropped
87                ctx.concurrency_permit = Some(permit);
88                Ok(())
89            }
90            Err(_) => {
91                warn!(
92                    "Concurrency limit reached (max: {}), rejecting request",
93                    self.max_concurrent
94                );
95
96                // Return 503 Service Unavailable
97                session.respond_error(503).await?;
98
99                Err(Error::explain(
100                    ErrorType::HTTPStatus(503),
101                    format!("Concurrency limit reached: {}", self.max_concurrent),
102                ))
103            }
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_concurrency_middleware_creation() {
        let mw = ConcurrencyMiddleware::new(true, 100);
        assert!(mw.enabled);
        assert_eq!(mw.max_concurrent, 100);
    }

    #[test]
    fn test_concurrency_middleware_disabled() {
        let mw = ConcurrencyMiddleware::new(false, 100);
        assert!(!mw.enabled);
    }

    #[test]
    fn test_concurrency_middleware_unlimited() {
        let mw = ConcurrencyMiddleware::new(true, 0);
        assert!(mw.enabled);
        assert_eq!(mw.max_concurrent, 0);
    }

    #[tokio::test]
    async fn test_concurrency_permit_acquire() {
        let mw = ConcurrencyMiddleware::new(true, 2);

        // Both permits are bound (not `_`), so they stay held until the end
        // of the test.
        let permit1 = mw.semaphore.clone().try_acquire_owned();
        assert!(permit1.is_ok());

        let permit2 = mw.semaphore.clone().try_acquire_owned();
        assert!(permit2.is_ok());

        // Limit reached: a third acquire must fail.
        let permit3 = mw.semaphore.clone().try_acquire_owned();
        assert!(permit3.is_err());
    }

    #[test]
    fn test_concurrency_permit_released_on_drop() {
        let mw = ConcurrencyMiddleware::new(true, 1);

        let permit = mw
            .semaphore
            .clone()
            .try_acquire_owned()
            .expect("the only permit should be free");
        assert!(mw.semaphore.clone().try_acquire_owned().is_err());

        // Releasing the permit must restore capacity; the middleware relies on
        // this when the request context is dropped.
        drop(permit);
        assert_eq!(mw.semaphore.available_permits(), 1);
        assert!(mw.semaphore.clone().try_acquire_owned().is_ok());
    }
}