bws_web_server/middleware/
mod.rs1pub mod compression;
2
3use async_trait::async_trait;
4use pingora::prelude::*;
5use std::collections::HashMap;
6use std::time::Instant;
7
8#[async_trait]
9pub trait Middleware: Send + Sync {
10 async fn before_request(&self, session: &mut Session) -> Result<bool>;
11 async fn after_response(&self, session: &mut Session) -> Result<()>;
12}
13
14pub struct MiddlewareStack {
15 middlewares: Vec<Box<dyn Middleware>>,
16}
17
18impl MiddlewareStack {
19 pub fn new() -> Self {
20 Self {
21 middlewares: Vec::new(),
22 }
23 }
24
25 pub fn add_middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
26 self.middlewares.push(Box::new(middleware));
27 self
28 }
29
30 pub async fn before_request(&self, session: &mut Session) -> Result<bool> {
31 for middleware in &self.middlewares {
32 if !middleware.before_request(session).await? {
33 return Ok(false);
34 }
35 }
36 Ok(true)
37 }
38
39 pub async fn after_response(&self, session: &mut Session) -> Result<()> {
40 for middleware in self.middlewares.iter().rev() {
42 middleware.after_response(session).await?;
43 }
44 Ok(())
45 }
46}
47
/// Logs the start and the completion of each request, when enabled.
pub struct LoggingMiddleware {
    /// When `false`, both hooks do nothing.
    log_requests: bool,
}

impl LoggingMiddleware {
    /// Builds the middleware; `log_requests` switches request logging on or off.
    pub fn new(log_requests: bool) -> Self {
        LoggingMiddleware { log_requests }
    }
}
58
59#[async_trait]
60impl Middleware for LoggingMiddleware {
61 async fn before_request(&self, session: &mut Session) -> Result<bool> {
62 if self.log_requests {
63 let _start_time = Instant::now();
64 log::info!(
68 "Request started: {} {} from {}",
69 session.req_header().method,
70 session.req_header().uri,
71 session
72 .client_addr()
73 .map(|addr| addr.to_string())
74 .unwrap_or_else(|| "unknown".to_string())
75 );
76 }
77 Ok(true)
78 }
79
80 async fn after_response(&self, session: &mut Session) -> Result<()> {
81 if self.log_requests {
82 let status = session
84 .response_written()
85 .map(|r| r.status.as_u16())
86 .unwrap_or(0);
87
88 log::info!(
89 "Request completed: {} {} (status: {})",
90 session.req_header().method,
91 session.req_header().uri,
92 status
93 );
94 }
95 Ok(())
96 }
97}
98
99#[allow(dead_code)]
101pub struct RateLimitMiddleware {
102 requests_per_minute: u32,
103 burst_size: u32,
104 clients: HashMap<String, ClientInfo>,
105}
106
107#[derive(Debug, Clone)]
108#[allow(dead_code)]
109struct ClientInfo {
110 request_count: u32,
111 last_reset: Instant,
112 tokens: u32,
113}
114
115#[allow(dead_code)]
116impl RateLimitMiddleware {
117 pub fn new(requests_per_minute: u32, burst_size: u32) -> Self {
118 Self {
119 requests_per_minute,
120 burst_size,
121 clients: HashMap::new(),
122 }
123 }
124
125 fn get_client_ip(&self, session: &Session) -> String {
126 if let Some(forwarded_for) = session.req_header().headers.get("X-Forwarded-For") {
128 if let Ok(forwarded_str) = forwarded_for.to_str() {
129 if let Some(ip) = forwarded_str.split(',').next() {
130 return ip.trim().to_string();
131 }
132 }
133 }
134
135 if let Some(real_ip) = session.req_header().headers.get("X-Real-IP") {
136 if let Ok(ip_str) = real_ip.to_str() {
137 return ip_str.to_string();
138 }
139 }
140
141 session
142 .client_addr()
143 .map(|addr| addr.to_string())
144 .unwrap_or_else(|| "unknown".to_string())
145 }
146
147 fn is_allowed(&mut self, client_ip: &str) -> bool {
148 let now = Instant::now();
149 let client_info = self
150 .clients
151 .entry(client_ip.to_string())
152 .or_insert(ClientInfo {
153 request_count: 0,
154 last_reset: now,
155 tokens: self.burst_size,
156 });
157
158 if now.duration_since(client_info.last_reset).as_secs() >= 60 {
160 client_info.request_count = 0;
161 client_info.last_reset = now;
162 client_info.tokens = self.burst_size;
163 }
164
165 let seconds_elapsed = now.duration_since(client_info.last_reset).as_secs();
167 let tokens_to_add = (seconds_elapsed * self.requests_per_minute as u64) / 60;
168 client_info.tokens = (client_info.tokens + tokens_to_add as u32).min(self.burst_size);
169
170 if client_info.tokens > 0 {
172 client_info.tokens -= 1;
173 client_info.request_count += 1;
174 true
175 } else {
176 false
177 }
178 }
179}
180
181#[async_trait]
182impl Middleware for RateLimitMiddleware {
183 async fn before_request(&self, session: &mut Session) -> Result<bool> {
184 let client_ip = self.get_client_ip(session);
185
186 log::debug!("Rate limit check for client: {}", client_ip);
190 Ok(true)
191 }
192
193 async fn after_response(&self, _session: &mut Session) -> Result<()> {
194 Ok(())
195 }
196}
197
/// Holds a set of security-related response headers, keyed by header name.
pub struct SecurityHeadersMiddleware {
    /// Header name -> header value.
    headers: HashMap<String, String>,
}

impl SecurityHeadersMiddleware {
    /// Creates the middleware pre-populated with `X-Frame-Options`,
    /// `X-Content-Type-Options`, `X-XSS-Protection` and `Referrer-Policy`.
    pub fn new() -> Self {
        let defaults = [
            ("X-Frame-Options", "DENY"),
            ("X-Content-Type-Options", "nosniff"),
            ("X-XSS-Protection", "1; mode=block"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
        ];
        let headers = defaults
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect::<HashMap<String, String>>();
        SecurityHeadersMiddleware { headers }
    }

    /// Adds `name: value`, replacing any existing value for `name`.
    pub fn with_header(mut self, name: String, value: String) -> Self {
        self.headers.insert(name, value);
        self
    }

    /// Sets `Strict-Transport-Security` to `max-age=<max_age>`, appending
    /// `; includeSubDomains` when `include_subdomains` is true.
    pub fn with_hsts(mut self, max_age: u32, include_subdomains: bool) -> Self {
        let mut policy = format!("max-age={}", max_age);
        if include_subdomains {
            policy.push_str("; includeSubDomains");
        }
        self.headers
            .insert("Strict-Transport-Security".to_string(), policy);
        self
    }
}
233
234#[async_trait]
235impl Middleware for SecurityHeadersMiddleware {
236 async fn before_request(&self, _session: &mut Session) -> Result<bool> {
237 Ok(true)
238 }
239
240 async fn after_response(&self, _session: &mut Session) -> Result<()> {
241 log::debug!("Security headers middleware executed (headers would be added to response)");
244 for (name, value) in &self.headers {
245 log::debug!("Would add header: {}: {}", name, value);
246 }
247 Ok(())
248 }
249}
250
/// CORS policy settings, configured builder-style.
pub struct CorsMiddleware {
    /// Allowed origins (default: `*`).
    allow_origins: Vec<String>,
    /// Allowed methods (default: GET, HEAD, OPTIONS).
    allow_methods: Vec<String>,
    /// Allowed request headers (default: Content-Type, Authorization).
    allow_headers: Vec<String>,
    /// Whether credentials are allowed (default: false).
    allow_credentials: bool,
    /// Preflight cache lifetime in seconds (default: 86400).
    max_age: u32,
}

impl CorsMiddleware {
    /// Creates a permissive default policy; see the field docs for values.
    pub fn new() -> Self {
        let owned = |items: &[&str]| -> Vec<String> {
            items.iter().map(|item| item.to_string()).collect()
        };
        CorsMiddleware {
            allow_origins: owned(&["*"]),
            allow_methods: owned(&["GET", "HEAD", "OPTIONS"]),
            allow_headers: owned(&["Content-Type", "Authorization"]),
            allow_credentials: false,
            max_age: 86_400,
        }
    }

    /// Replaces the allowed origins.
    pub fn allow_origins(self, origins: Vec<String>) -> Self {
        CorsMiddleware {
            allow_origins: origins,
            ..self
        }
    }

    /// Replaces the allowed methods.
    pub fn allow_methods(self, methods: Vec<String>) -> Self {
        CorsMiddleware {
            allow_methods: methods,
            ..self
        }
    }

    /// Replaces the allowed request headers.
    pub fn allow_headers(self, headers: Vec<String>) -> Self {
        CorsMiddleware {
            allow_headers: headers,
            ..self
        }
    }

    /// Sets whether credentials are allowed.
    pub fn allow_credentials(self, allow: bool) -> Self {
        CorsMiddleware {
            allow_credentials: allow,
            ..self
        }
    }

    /// Sets the preflight cache lifetime, in seconds.
    pub fn max_age(self, age: u32) -> Self {
        CorsMiddleware {
            max_age: age,
            ..self
        }
    }
}
296
297#[async_trait]
298impl Middleware for CorsMiddleware {
299 async fn before_request(&self, session: &mut Session) -> Result<bool> {
300 if session.req_header().method == "OPTIONS" {
302 log::debug!("Handling CORS preflight request");
303 return Ok(true);
306 }
307 Ok(true)
308 }
309
310 async fn after_response(&self, _session: &mut Session) -> Result<()> {
311 log::debug!("CORS middleware executed (headers would be added to response)");
313 log::debug!(
314 "Access-Control-Allow-Origin: {}",
315 self.allow_origins.join(", ")
316 );
317 log::debug!(
318 "Access-Control-Allow-Methods: {}",
319 self.allow_methods.join(", ")
320 );
321 log::debug!(
322 "Access-Control-Allow-Headers: {}",
323 self.allow_headers.join(", ")
324 );
325 Ok(())
326 }
327}
328
329impl Default for MiddlewareStack {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
335impl Default for SecurityHeadersMiddleware {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
341impl Default for CorsMiddleware {
342 fn default() -> Self {
343 Self::new()
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_middleware_stack_creation() {
        let stack = MiddlewareStack::new();
        assert_eq!(stack.middlewares.len(), 0);
    }

    #[test]
    fn test_middleware_stack_with_middleware() {
        let stack = MiddlewareStack::new()
            .add_middleware(LoggingMiddleware::new(true))
            .add_middleware(SecurityHeadersMiddleware::new());

        assert_eq!(stack.middlewares.len(), 2);
    }

    #[test]
    fn test_rate_limit_middleware_creation() {
        let middleware = RateLimitMiddleware::new(60, 10);
        assert_eq!(middleware.requests_per_minute, 60);
        assert_eq!(middleware.burst_size, 10);
    }

    #[test]
    fn test_rate_limit_exhausts_burst() {
        // `mut` keeps this valid whether `is_allowed` takes `&self` or `&mut self`.
        #[allow(unused_mut)]
        let mut limiter = RateLimitMiddleware::new(60, 2);
        assert!(limiter.is_allowed("10.0.0.1"));
        assert!(limiter.is_allowed("10.0.0.1"));
        // The burst is spent and (well under a second later) not yet refilled.
        assert!(!limiter.is_allowed("10.0.0.1"));
    }

    #[test]
    fn test_rate_limit_tracks_clients_independently() {
        #[allow(unused_mut)]
        let mut limiter = RateLimitMiddleware::new(60, 1);
        assert!(limiter.is_allowed("a"));
        assert!(!limiter.is_allowed("a"));
        assert!(limiter.is_allowed("b"));
    }

    #[test]
    fn test_rate_limit_zero_burst_rejects() {
        #[allow(unused_mut)]
        let mut limiter = RateLimitMiddleware::new(60, 0);
        assert!(!limiter.is_allowed("a"));
    }

    #[test]
    fn test_security_headers_middleware() {
        let middleware = SecurityHeadersMiddleware::new()
            .with_header("Custom-Header".to_string(), "Custom-Value".to_string())
            .with_hsts(31536000, true);

        assert!(middleware.headers.contains_key("Custom-Header"));
        assert!(middleware.headers.contains_key("Strict-Transport-Security"));
    }

    #[test]
    fn test_security_headers_default_values() {
        let middleware = SecurityHeadersMiddleware::new();
        assert_eq!(middleware.headers["X-Frame-Options"], "DENY");
        assert_eq!(middleware.headers["X-Content-Type-Options"], "nosniff");
        assert_eq!(
            middleware.headers["Referrer-Policy"],
            "strict-origin-when-cross-origin"
        );
    }

    #[test]
    fn test_hsts_value_format() {
        let with_subdomains = SecurityHeadersMiddleware::new().with_hsts(600, true);
        assert_eq!(
            with_subdomains.headers["Strict-Transport-Security"],
            "max-age=600; includeSubDomains"
        );
        let without_subdomains = SecurityHeadersMiddleware::new().with_hsts(600, false);
        assert_eq!(
            without_subdomains.headers["Strict-Transport-Security"],
            "max-age=600"
        );
    }

    #[test]
    fn test_cors_middleware_configuration() {
        let middleware = CorsMiddleware::new()
            .allow_origins(vec!["https://example.com".to_string()])
            .allow_methods(vec!["GET".to_string(), "POST".to_string()])
            .allow_credentials(true)
            .max_age(7200);

        assert_eq!(middleware.allow_origins, vec!["https://example.com"]);
        assert_eq!(middleware.allow_methods, vec!["GET", "POST"]);
        assert!(middleware.allow_credentials);
        assert_eq!(middleware.max_age, 7200);
    }

    #[test]
    fn test_cors_middleware_defaults() {
        let middleware = CorsMiddleware::default();
        assert_eq!(middleware.allow_origins, vec!["*"]);
        assert_eq!(middleware.allow_methods, vec!["GET", "HEAD", "OPTIONS"]);
        assert_eq!(middleware.allow_headers, vec!["Content-Type", "Authorization"]);
        assert!(!middleware.allow_credentials);
        assert_eq!(middleware.max_age, 86400);
    }
}