// bws_web_server/middleware/mod.rs

pub mod compression;

use async_trait::async_trait;
use pingora::prelude::*;
use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};

8#[async_trait]
9pub trait Middleware: Send + Sync {
10    async fn before_request(&self, session: &mut Session) -> Result<bool>;
11    async fn after_response(&self, session: &mut Session) -> Result<()>;
12}
13
14pub struct MiddlewareStack {
15    middlewares: Vec<Box<dyn Middleware>>,
16}
17
18impl MiddlewareStack {
19    pub fn new() -> Self {
20        Self {
21            middlewares: Vec::new(),
22        }
23    }
24
25    pub fn add_middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
26        self.middlewares.push(Box::new(middleware));
27        self
28    }
29
30    pub async fn before_request(&self, session: &mut Session) -> Result<bool> {
31        for middleware in &self.middlewares {
32            if !middleware.before_request(session).await? {
33                return Ok(false);
34            }
35        }
36        Ok(true)
37    }
38
39    pub async fn after_response(&self, session: &mut Session) -> Result<()> {
40        // Execute in reverse order
41        for middleware in self.middlewares.iter().rev() {
42            middleware.after_response(session).await?;
43        }
44        Ok(())
45    }
46}
47
// Logging middleware

/// Middleware that logs the start and the completion of each request.
pub struct LoggingMiddleware {
    /// Master switch: when `false`, nothing is logged.
    log_requests: bool,
}

impl LoggingMiddleware {
    /// Creates the middleware; `log_requests` controls whether it logs anything.
    pub fn new(log_requests: bool) -> Self {
        LoggingMiddleware { log_requests }
    }
}

59#[async_trait]
60impl Middleware for LoggingMiddleware {
61    async fn before_request(&self, session: &mut Session) -> Result<bool> {
62        if self.log_requests {
63            let _start_time = Instant::now();
64            // TODO: Store start time when session variables are available
65            // session.set_var("request_start_time", start_time);
66
67            log::info!(
68                "Request started: {} {} from {}",
69                session.req_header().method,
70                session.req_header().uri,
71                session
72                    .client_addr()
73                    .map(|addr| addr.to_string())
74                    .unwrap_or_else(|| "unknown".to_string())
75            );
76        }
77        Ok(true)
78    }
79
80    async fn after_response(&self, session: &mut Session) -> Result<()> {
81        if self.log_requests {
82            // TODO: Implement proper request timing when session variables are available
83            let status = session
84                .response_written()
85                .map(|r| r.status.as_u16())
86                .unwrap_or(0);
87
88            log::info!(
89                "Request completed: {} {} (status: {})",
90                session.req_header().method,
91                session.req_header().uri,
92                status
93            );
94        }
95        Ok(())
96    }
97}
98
99// Rate limiting middleware (currently unused)
100#[allow(dead_code)]
101pub struct RateLimitMiddleware {
102    requests_per_minute: u32,
103    burst_size: u32,
104    clients: HashMap<String, ClientInfo>,
105}
106
107#[derive(Debug, Clone)]
108#[allow(dead_code)]
109struct ClientInfo {
110    request_count: u32,
111    last_reset: Instant,
112    tokens: u32,
113}
114
115#[allow(dead_code)]
116impl RateLimitMiddleware {
117    pub fn new(requests_per_minute: u32, burst_size: u32) -> Self {
118        Self {
119            requests_per_minute,
120            burst_size,
121            clients: HashMap::new(),
122        }
123    }
124
125    fn get_client_ip(&self, session: &Session) -> String {
126        // Try to get real IP from headers (for reverse proxy setups)
127        if let Some(forwarded_for) = session.req_header().headers.get("X-Forwarded-For") {
128            if let Ok(forwarded_str) = forwarded_for.to_str() {
129                if let Some(ip) = forwarded_str.split(',').next() {
130                    return ip.trim().to_string();
131                }
132            }
133        }
134
135        if let Some(real_ip) = session.req_header().headers.get("X-Real-IP") {
136            if let Ok(ip_str) = real_ip.to_str() {
137                return ip_str.to_string();
138            }
139        }
140
141        session
142            .client_addr()
143            .map(|addr| addr.to_string())
144            .unwrap_or_else(|| "unknown".to_string())
145    }
146
147    fn is_allowed(&mut self, client_ip: &str) -> bool {
148        let now = Instant::now();
149        let client_info = self
150            .clients
151            .entry(client_ip.to_string())
152            .or_insert(ClientInfo {
153                request_count: 0,
154                last_reset: now,
155                tokens: self.burst_size,
156            });
157
158        // Reset counters if a minute has passed
159        if now.duration_since(client_info.last_reset).as_secs() >= 60 {
160            client_info.request_count = 0;
161            client_info.last_reset = now;
162            client_info.tokens = self.burst_size;
163        }
164
165        // Add tokens based on time elapsed
166        let seconds_elapsed = now.duration_since(client_info.last_reset).as_secs();
167        let tokens_to_add = (seconds_elapsed * self.requests_per_minute as u64) / 60;
168        client_info.tokens = (client_info.tokens + tokens_to_add as u32).min(self.burst_size);
169
170        // Check if request is allowed
171        if client_info.tokens > 0 {
172            client_info.tokens -= 1;
173            client_info.request_count += 1;
174            true
175        } else {
176            false
177        }
178    }
179}
180
181#[async_trait]
182impl Middleware for RateLimitMiddleware {
183    async fn before_request(&self, session: &mut Session) -> Result<bool> {
184        let client_ip = self.get_client_ip(session);
185
186        // Note: This implementation has a concurrency issue with the mutable reference.
187        // In a real implementation, you'd use Arc<Mutex<>> or similar
188        // For now, we'll allow all requests
189        log::debug!("Rate limit check for client: {}", client_ip);
190        Ok(true)
191    }
192
193    async fn after_response(&self, _session: &mut Session) -> Result<()> {
194        Ok(())
195    }
196}
197
// Security headers middleware

/// Middleware holding a set of security-related response headers.
///
/// The set starts with conservative defaults and can be extended or overridden
/// with the builder methods. Header names are compared ASCII case-insensitively,
/// as HTTP does, so overriding a default never leaves a duplicate behind.
pub struct SecurityHeadersMiddleware {
    headers: HashMap<String, String>,
}

impl SecurityHeadersMiddleware {
    /// Creates the middleware with the default header set:
    /// `X-Frame-Options: DENY`, `X-Content-Type-Options: nosniff`,
    /// `X-XSS-Protection: 0` and `Referrer-Policy: strict-origin-when-cross-origin`.
    pub fn new() -> Self {
        // `X-XSS-Protection` is disabled ("0") rather than "1; mode=block": the
        // legacy XSS auditor it controls is gone from modern browsers, and its
        // blocking mode could itself be abused in older ones.
        let headers = [
            ("X-Frame-Options", "DENY"),
            ("X-Content-Type-Options", "nosniff"),
            ("X-XSS-Protection", "0"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
        ]
        .iter()
        .map(|&(name, value)| (name.to_string(), value.to_string()))
        .collect();

        Self { headers }
    }

    /// Adds or replaces a header.
    ///
    /// Any existing entry whose name equals `name` ignoring ASCII case is removed
    /// first, so `"x-frame-options"` overrides the default `"X-Frame-Options"`.
    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let name = name.into();
        self.headers
            .retain(|existing, _| !existing.eq_ignore_ascii_case(&name));
        self.headers.insert(name, value.into());
        self
    }

    /// Sets `Strict-Transport-Security` to `max-age=<max_age>`, appending
    /// `; includeSubDomains` when `include_subdomains` is true.
    pub fn with_hsts(self, max_age: u32, include_subdomains: bool) -> Self {
        let mut value = format!("max-age={}", max_age);
        if include_subdomains {
            value.push_str("; includeSubDomains");
        }
        self.with_header("Strict-Transport-Security", value)
    }
}

234#[async_trait]
235impl Middleware for SecurityHeadersMiddleware {
236    async fn before_request(&self, _session: &mut Session) -> Result<bool> {
237        Ok(true)
238    }
239
240    async fn after_response(&self, _session: &mut Session) -> Result<()> {
241        // Note: Adding headers after response is sent is not possible in this context
242        // This middleware would need to be integrated differently in the actual response handling
243        log::debug!("Security headers middleware executed (headers would be added to response)");
244        for (name, value) in &self.headers {
245            log::debug!("Would add header: {}: {}", name, value);
246        }
247        Ok(())
248    }
249}
250
// CORS middleware

/// Cross-Origin Resource Sharing (CORS) settings, configured with chained setters.
pub struct CorsMiddleware {
    /// Origins allowed to make cross-origin requests; defaults to `["*"]`.
    allow_origins: Vec<String>,
    /// Methods listed in `Access-Control-Allow-Methods`.
    allow_methods: Vec<String>,
    /// Request headers listed in `Access-Control-Allow-Headers`.
    allow_headers: Vec<String>,
    /// Whether credentialed requests are allowed; defaults to `false`.
    allow_credentials: bool,
    /// Intended `Access-Control-Max-Age` value; defaults to 86400.
    max_age: u32,
}

impl CorsMiddleware {
    /// Creates the default configuration: origin `*`, methods `GET`/`HEAD`/`OPTIONS`,
    /// headers `Content-Type`/`Authorization`, no credentials, max-age 86400.
    pub fn new() -> Self {
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| item.to_string()).collect()
        }

        Self {
            allow_origins: owned(&["*"]),
            allow_methods: owned(&["GET", "HEAD", "OPTIONS"]),
            allow_headers: owned(&["Content-Type", "Authorization"]),
            allow_credentials: false,
            max_age: 86_400,
        }
    }

    /// Replaces the list of allowed origins.
    pub fn allow_origins(self, origins: Vec<String>) -> Self {
        Self {
            allow_origins: origins,
            ..self
        }
    }

    /// Replaces the list of allowed methods.
    pub fn allow_methods(self, methods: Vec<String>) -> Self {
        Self {
            allow_methods: methods,
            ..self
        }
    }

    /// Replaces the list of allowed request headers.
    pub fn allow_headers(self, headers: Vec<String>) -> Self {
        Self {
            allow_headers: headers,
            ..self
        }
    }

    /// Sets whether credentialed requests are allowed.
    pub fn allow_credentials(self, allow: bool) -> Self {
        Self {
            allow_credentials: allow,
            ..self
        }
    }

    /// Sets the max-age value.
    pub fn max_age(self, age: u32) -> Self {
        Self {
            max_age: age,
            ..self
        }
    }
}

297#[async_trait]
298impl Middleware for CorsMiddleware {
299    async fn before_request(&self, session: &mut Session) -> Result<bool> {
300        // Handle preflight requests
301        if session.req_header().method == "OPTIONS" {
302            log::debug!("Handling CORS preflight request");
303            // In a real implementation, we'd send the CORS response here
304            // For now, just log that we would handle it
305            return Ok(true);
306        }
307        Ok(true)
308    }
309
310    async fn after_response(&self, _session: &mut Session) -> Result<()> {
311        // Add CORS headers to response
312        log::debug!("CORS middleware executed (headers would be added to response)");
313        log::debug!(
314            "Access-Control-Allow-Origin: {}",
315            self.allow_origins.join(", ")
316        );
317        log::debug!(
318            "Access-Control-Allow-Methods: {}",
319            self.allow_methods.join(", ")
320        );
321        log::debug!(
322            "Access-Control-Allow-Headers: {}",
323            self.allow_headers.join(", ")
324        );
325        Ok(())
326    }
327}
328
329impl Default for MiddlewareStack {
330    fn default() -> Self {
331        Self::new()
332    }
333}
334
335impl Default for SecurityHeadersMiddleware {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341impl Default for CorsMiddleware {
342    fn default() -> Self {
343        Self::new()
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_middleware_stack_creation() {
        let stack = MiddlewareStack::new();
        assert_eq!(stack.middlewares.len(), 0);
    }

    #[test]
    fn test_middleware_stack_default_is_empty() {
        assert!(MiddlewareStack::default().middlewares.is_empty());
    }

    #[test]
    fn test_middleware_stack_with_middleware() {
        let stack = MiddlewareStack::new()
            .add_middleware(LoggingMiddleware::new(true))
            .add_middleware(SecurityHeadersMiddleware::new());

        assert_eq!(stack.middlewares.len(), 2);
    }

    #[test]
    fn test_logging_middleware_flag() {
        assert!(LoggingMiddleware::new(true).log_requests);
        assert!(!LoggingMiddleware::new(false).log_requests);
    }

    #[test]
    fn test_rate_limit_middleware_creation() {
        let middleware = RateLimitMiddleware::new(60, 10);
        assert_eq!(middleware.requests_per_minute, 60);
        assert_eq!(middleware.burst_size, 10);
    }

    #[test]
    fn test_security_headers_middleware() {
        let middleware = SecurityHeadersMiddleware::new()
            .with_header("Custom-Header".to_string(), "Custom-Value".to_string())
            .with_hsts(31536000, true);

        assert!(middleware.headers.contains_key("Custom-Header"));
        assert!(middleware.headers.contains_key("Strict-Transport-Security"));
    }

    #[test]
    fn test_security_headers_defaults() {
        let middleware = SecurityHeadersMiddleware::default();
        let header = |name: &str| middleware.headers.get(name).map(String::as_str);

        assert_eq!(header("X-Frame-Options"), Some("DENY"));
        assert_eq!(header("X-Content-Type-Options"), Some("nosniff"));
        assert_eq!(
            header("Referrer-Policy"),
            Some("strict-origin-when-cross-origin")
        );
        assert_eq!(header("Strict-Transport-Security"), None);
    }

    #[test]
    fn test_security_headers_hsts_without_subdomains() {
        let middleware = SecurityHeadersMiddleware::new().with_hsts(60, false);
        assert_eq!(
            middleware
                .headers
                .get("Strict-Transport-Security")
                .map(String::as_str),
            Some("max-age=60")
        );
    }

    #[test]
    fn test_cors_middleware_defaults() {
        let middleware = CorsMiddleware::default();

        assert_eq!(middleware.allow_origins, vec!["*"]);
        assert_eq!(middleware.allow_methods, vec!["GET", "HEAD", "OPTIONS"]);
        assert_eq!(middleware.allow_headers, vec!["Content-Type", "Authorization"]);
        assert!(!middleware.allow_credentials);
        assert_eq!(middleware.max_age, 86400);
    }

    #[test]
    fn test_cors_middleware_configuration() {
        let middleware = CorsMiddleware::new()
            .allow_origins(vec!["https://example.com".to_string()])
            .allow_methods(vec!["GET".to_string(), "POST".to_string()])
            .allow_credentials(true)
            .max_age(7200);

        assert_eq!(middleware.allow_origins, vec!["https://example.com"]);
        assert_eq!(middleware.allow_methods, vec!["GET", "POST"]);
        assert!(middleware.allow_credentials);
        assert_eq!(middleware.max_age, 7200);
    }
}