// bws_web_server/server/middleware.rs

1use async_trait::async_trait;
2use pingora::prelude::*;
3use std::collections::HashMap;
4use std::time::Instant;
5
6#[async_trait]
7pub trait Middleware: Send + Sync {
8    async fn before_request(&self, session: &mut Session) -> Result<bool>;
9    async fn after_response(&self, session: &mut Session) -> Result<()>;
10}
11
12pub struct MiddlewareStack {
13    middlewares: Vec<Box<dyn Middleware>>,
14}
15
16impl MiddlewareStack {
17    pub fn new() -> Self {
18        Self {
19            middlewares: Vec::new(),
20        }
21    }
22
23    pub fn add_middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
24        self.middlewares.push(Box::new(middleware));
25        self
26    }
27
28    pub async fn before_request(&self, session: &mut Session) -> Result<bool> {
29        for middleware in &self.middlewares {
30            if !middleware.before_request(session).await? {
31                return Ok(false);
32            }
33        }
34        Ok(true)
35    }
36
37    pub async fn after_response(&self, session: &mut Session) -> Result<()> {
38        // Execute in reverse order
39        for middleware in self.middlewares.iter().rev() {
40            middleware.after_response(session).await?;
41        }
42        Ok(())
43    }
44}
45
// Logging middleware

/// Logs the start and completion of each request at `info` level.
///
/// Nothing is logged when `log_requests` is `false`.
pub struct LoggingMiddleware {
    log_requests: bool,
}
50
51impl LoggingMiddleware {
52    pub fn new(log_requests: bool) -> Self {
53        Self { log_requests }
54    }
55}
56
57#[async_trait]
58impl Middleware for LoggingMiddleware {
59    async fn before_request(&self, session: &mut Session) -> Result<bool> {
60        if self.log_requests {
61            let _start_time = Instant::now();
62            // TODO: Store start time when session variables are available
63            // session.set_var("request_start_time", start_time);
64
65            log::info!(
66                "Request started: {} {} from {}",
67                session.req_header().method,
68                session.req_header().uri,
69                session
70                    .client_addr()
71                    .map(|addr| addr.to_string())
72                    .unwrap_or_else(|| "unknown".to_string())
73            );
74        }
75        Ok(true)
76    }
77
78    async fn after_response(&self, session: &mut Session) -> Result<()> {
79        if self.log_requests {
80            // TODO: Implement proper request timing when session variables are available
81            let status = session
82                .response_written()
83                .map(|r| r.status.as_u16())
84                .unwrap_or(0);
85
86            log::info!(
87                "Request completed: {} {} (status: {})",
88                session.req_header().method,
89                session.req_header().uri,
90                status
91            );
92        }
93        Ok(())
94    }
95}
96
97// Rate limiting middleware
98#[allow(dead_code)]
99pub struct RateLimitMiddleware {
100    requests_per_minute: u32,
101    burst_size: u32,
102    clients: HashMap<String, ClientInfo>,
103}
104
/// Rate-limiting state kept for one client key.
#[derive(Debug, Clone)]
// Only `RateLimitMiddleware::is_allowed` reads these fields, and it is not
// yet called from the request path.
#[allow(dead_code)]
struct ClientInfo {
    /// Requests allowed since the client's state was last reset.
    request_count: u32,
    /// Reference instant `is_allowed` measures elapsed time from.
    last_reset: Instant,
    /// Requests the client may still make right now.
    tokens: u32,
}
112
113impl RateLimitMiddleware {
114    pub fn new(requests_per_minute: u32, burst_size: u32) -> Self {
115        Self {
116            requests_per_minute,
117            burst_size,
118            clients: HashMap::new(),
119        }
120    }
121
122    fn get_client_ip(&self, session: &Session) -> String {
123        // Try to get real IP from headers (for reverse proxy setups)
124        if let Some(forwarded_for) = session.req_header().headers.get("X-Forwarded-For") {
125            if let Ok(forwarded_str) = forwarded_for.to_str() {
126                if let Some(ip) = forwarded_str.split(',').next() {
127                    return ip.trim().to_string();
128                }
129            }
130        }
131
132        if let Some(real_ip) = session.req_header().headers.get("X-Real-IP") {
133            if let Ok(ip_str) = real_ip.to_str() {
134                return ip_str.to_string();
135            }
136        }
137
138        session
139            .client_addr()
140            .map(|addr| addr.to_string())
141            .unwrap_or_else(|| "unknown".to_string())
142    }
143
144    #[allow(dead_code)]
145    fn is_allowed(&mut self, client_ip: &str) -> bool {
146        let now = Instant::now();
147        let client_info = self
148            .clients
149            .entry(client_ip.to_string())
150            .or_insert(ClientInfo {
151                request_count: 0,
152                last_reset: now,
153                tokens: self.burst_size,
154            });
155
156        // Reset counters if a minute has passed
157        if now.duration_since(client_info.last_reset).as_secs() >= 60 {
158            client_info.request_count = 0;
159            client_info.last_reset = now;
160            client_info.tokens = self.burst_size;
161        }
162
163        // Add tokens based on time elapsed
164        let seconds_elapsed = now.duration_since(client_info.last_reset).as_secs();
165        let tokens_to_add = (seconds_elapsed * self.requests_per_minute as u64) / 60;
166        client_info.tokens = (client_info.tokens + tokens_to_add as u32).min(self.burst_size);
167
168        // Check if request is allowed
169        if client_info.tokens > 0 {
170            client_info.tokens -= 1;
171            client_info.request_count += 1;
172            true
173        } else {
174            false
175        }
176    }
177}
178
179#[async_trait]
180impl Middleware for RateLimitMiddleware {
181    async fn before_request(&self, session: &mut Session) -> Result<bool> {
182        let client_ip = self.get_client_ip(session);
183
184        // Note: This implementation has a concurrency issue with the mutable reference.
185        // In a real implementation, you'd use Arc<Mutex<>> or similar
186        // For now, we'll allow all requests
187        log::debug!("Rate limit check for client: {}", client_ip);
188        Ok(true)
189    }
190
191    async fn after_response(&self, _session: &mut Session) -> Result<()> {
192        Ok(())
193    }
194}
195
// Security headers middleware

/// Holds a set of security-related response headers, keyed by header name.
pub struct SecurityHeadersMiddleware {
    headers: HashMap<String, String>,
}
200
201impl SecurityHeadersMiddleware {
202    pub fn new() -> Self {
203        let mut headers = HashMap::new();
204        headers.insert("X-Frame-Options".to_string(), "DENY".to_string());
205        headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
206        headers.insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
207        headers.insert(
208            "Referrer-Policy".to_string(),
209            "strict-origin-when-cross-origin".to_string(),
210        );
211
212        Self { headers }
213    }
214
215    pub fn with_header(mut self, name: String, value: String) -> Self {
216        self.headers.insert(name, value);
217        self
218    }
219
220    pub fn with_hsts(mut self, max_age: u32, include_subdomains: bool) -> Self {
221        let value = if include_subdomains {
222            format!("max-age={}; includeSubDomains", max_age)
223        } else {
224            format!("max-age={}", max_age)
225        };
226        self.headers
227            .insert("Strict-Transport-Security".to_string(), value);
228        self
229    }
230}
231
232#[async_trait]
233impl Middleware for SecurityHeadersMiddleware {
234    async fn before_request(&self, _session: &mut Session) -> Result<bool> {
235        Ok(true)
236    }
237
238    async fn after_response(&self, _session: &mut Session) -> Result<()> {
239        // Note: Adding headers after response is sent is not possible in this context
240        // This middleware would need to be integrated differently in the actual response handling
241        log::debug!("Security headers middleware executed (headers would be added to response)");
242        for (name, value) in &self.headers {
243            log::debug!("Would add header: {}: {}", name, value);
244        }
245        Ok(())
246    }
247}
248
// CORS middleware

/// Cross-Origin Resource Sharing policy settings.
pub struct CorsMiddleware {
    /// Origins permitted to make cross-origin requests.
    allow_origins: Vec<String>,
    /// Request methods permitted for cross-origin requests.
    allow_methods: Vec<String>,
    /// Request headers permitted for cross-origin requests.
    allow_headers: Vec<String>,
    /// Whether credentialed requests are permitted.
    allow_credentials: bool,
    /// Preflight cache lifetime (presumably seconds — confirm).
    max_age: u32,
}
257
258impl CorsMiddleware {
259    pub fn new() -> Self {
260        Self {
261            allow_origins: vec!["*".to_string()],
262            allow_methods: vec!["GET".to_string(), "HEAD".to_string(), "OPTIONS".to_string()],
263            allow_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
264            allow_credentials: false,
265            max_age: 86400,
266        }
267    }
268
269    pub fn allow_origins(mut self, origins: Vec<String>) -> Self {
270        self.allow_origins = origins;
271        self
272    }
273
274    pub fn allow_methods(mut self, methods: Vec<String>) -> Self {
275        self.allow_methods = methods;
276        self
277    }
278
279    pub fn allow_headers(mut self, headers: Vec<String>) -> Self {
280        self.allow_headers = headers;
281        self
282    }
283
284    pub fn allow_credentials(mut self, allow: bool) -> Self {
285        self.allow_credentials = allow;
286        self
287    }
288
289    pub fn max_age(mut self, age: u32) -> Self {
290        self.max_age = age;
291        self
292    }
293}
294
295#[async_trait]
296impl Middleware for CorsMiddleware {
297    async fn before_request(&self, session: &mut Session) -> Result<bool> {
298        // Handle preflight requests
299        if session.req_header().method == "OPTIONS" {
300            log::debug!("Handling CORS preflight request");
301            // In a real implementation, we'd send the CORS response here
302            // For now, just log that we would handle it
303            return Ok(true);
304        }
305        Ok(true)
306    }
307
308    async fn after_response(&self, _session: &mut Session) -> Result<()> {
309        // Add CORS headers to response
310        log::debug!("CORS middleware executed (headers would be added to response)");
311        log::debug!(
312            "Access-Control-Allow-Origin: {}",
313            self.allow_origins.join(", ")
314        );
315        log::debug!(
316            "Access-Control-Allow-Methods: {}",
317            self.allow_methods.join(", ")
318        );
319        log::debug!(
320            "Access-Control-Allow-Headers: {}",
321            self.allow_headers.join(", ")
322        );
323        Ok(())
324    }
325}
326
327impl Default for MiddlewareStack {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
333impl Default for SecurityHeadersMiddleware {
334    fn default() -> Self {
335        Self::new()
336    }
337}
338
339impl Default for CorsMiddleware {
340    fn default() -> Self {
341        Self::new()
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_middleware_stack_creation() {
        let empty = MiddlewareStack::new();
        assert!(empty.middlewares.is_empty());
    }

    #[test]
    fn test_middleware_stack_with_middleware() {
        let logging = LoggingMiddleware::new(true);
        let security = SecurityHeadersMiddleware::new();
        let stack = MiddlewareStack::new()
            .add_middleware(logging)
            .add_middleware(security);

        assert_eq!(stack.middlewares.len(), 2);
    }

    #[test]
    fn test_rate_limit_middleware_creation() {
        let RateLimitMiddleware {
            requests_per_minute,
            burst_size,
            ..
        } = RateLimitMiddleware::new(60, 10);
        assert_eq!((requests_per_minute, burst_size), (60, 10));
    }

    #[test]
    fn test_security_headers_middleware() {
        let middleware = SecurityHeadersMiddleware::new()
            .with_header("Custom-Header".to_string(), "Custom-Value".to_string())
            .with_hsts(31536000, true);

        for expected in ["Custom-Header", "Strict-Transport-Security"] {
            assert!(
                middleware.headers.contains_key(expected),
                "missing header {}",
                expected
            );
        }
    }

    #[test]
    fn test_cors_middleware_configuration() {
        let cors = CorsMiddleware::new()
            .allow_origins(vec!["https://example.com".to_string()])
            .allow_methods(vec!["GET".to_string(), "POST".to_string()])
            .allow_credentials(true)
            .max_age(7200);

        assert_eq!(cors.allow_origins, vec!["https://example.com"]);
        assert_eq!(cors.allow_methods, vec!["GET", "POST"]);
        assert!(cors.allow_credentials);
        assert_eq!(cors.max_age, 7200);
    }
}