1use async_trait::async_trait;
2use pingora::prelude::*;
3use std::collections::HashMap;
4use std::time::Instant;
5
6#[async_trait]
7pub trait Middleware: Send + Sync {
8 async fn before_request(&self, session: &mut Session) -> Result<bool>;
9 async fn after_response(&self, session: &mut Session) -> Result<()>;
10}
11
12pub struct MiddlewareStack {
13 middlewares: Vec<Box<dyn Middleware>>,
14}
15
16impl MiddlewareStack {
17 pub fn new() -> Self {
18 Self {
19 middlewares: Vec::new(),
20 }
21 }
22
23 pub fn add_middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
24 self.middlewares.push(Box::new(middleware));
25 self
26 }
27
28 pub async fn before_request(&self, session: &mut Session) -> Result<bool> {
29 for middleware in &self.middlewares {
30 if !middleware.before_request(session).await? {
31 return Ok(false);
32 }
33 }
34 Ok(true)
35 }
36
37 pub async fn after_response(&self, session: &mut Session) -> Result<()> {
38 for middleware in self.middlewares.iter().rev() {
40 middleware.after_response(session).await?;
41 }
42 Ok(())
43 }
44}
45
46pub struct LoggingMiddleware {
48 log_requests: bool,
49}
50
51impl LoggingMiddleware {
52 pub fn new(log_requests: bool) -> Self {
53 Self { log_requests }
54 }
55}
56
57#[async_trait]
58impl Middleware for LoggingMiddleware {
59 async fn before_request(&self, session: &mut Session) -> Result<bool> {
60 if self.log_requests {
61 let _start_time = Instant::now();
62 log::info!(
66 "Request started: {} {} from {}",
67 session.req_header().method,
68 session.req_header().uri,
69 session
70 .client_addr()
71 .map(|addr| addr.to_string())
72 .unwrap_or_else(|| "unknown".to_string())
73 );
74 }
75 Ok(true)
76 }
77
78 async fn after_response(&self, session: &mut Session) -> Result<()> {
79 if self.log_requests {
80 let status = session
82 .response_written()
83 .map(|r| r.status.as_u16())
84 .unwrap_or(0);
85
86 log::info!(
87 "Request completed: {} {} (status: {})",
88 session.req_header().method,
89 session.req_header().uri,
90 status
91 );
92 }
93 Ok(())
94 }
95}
96
97#[allow(dead_code)]
99pub struct RateLimitMiddleware {
100 requests_per_minute: u32,
101 burst_size: u32,
102 clients: HashMap<String, ClientInfo>,
103}
104
105#[derive(Debug, Clone)]
106#[allow(dead_code)]
107struct ClientInfo {
108 request_count: u32,
109 last_reset: Instant,
110 tokens: u32,
111}
112
113impl RateLimitMiddleware {
114 pub fn new(requests_per_minute: u32, burst_size: u32) -> Self {
115 Self {
116 requests_per_minute,
117 burst_size,
118 clients: HashMap::new(),
119 }
120 }
121
122 fn get_client_ip(&self, session: &Session) -> String {
123 if let Some(forwarded_for) = session.req_header().headers.get("X-Forwarded-For") {
125 if let Ok(forwarded_str) = forwarded_for.to_str() {
126 if let Some(ip) = forwarded_str.split(',').next() {
127 return ip.trim().to_string();
128 }
129 }
130 }
131
132 if let Some(real_ip) = session.req_header().headers.get("X-Real-IP") {
133 if let Ok(ip_str) = real_ip.to_str() {
134 return ip_str.to_string();
135 }
136 }
137
138 session
139 .client_addr()
140 .map(|addr| addr.to_string())
141 .unwrap_or_else(|| "unknown".to_string())
142 }
143
144 #[allow(dead_code)]
145 fn is_allowed(&mut self, client_ip: &str) -> bool {
146 let now = Instant::now();
147 let client_info = self
148 .clients
149 .entry(client_ip.to_string())
150 .or_insert(ClientInfo {
151 request_count: 0,
152 last_reset: now,
153 tokens: self.burst_size,
154 });
155
156 if now.duration_since(client_info.last_reset).as_secs() >= 60 {
158 client_info.request_count = 0;
159 client_info.last_reset = now;
160 client_info.tokens = self.burst_size;
161 }
162
163 let seconds_elapsed = now.duration_since(client_info.last_reset).as_secs();
165 let tokens_to_add = (seconds_elapsed * self.requests_per_minute as u64) / 60;
166 client_info.tokens = (client_info.tokens + tokens_to_add as u32).min(self.burst_size);
167
168 if client_info.tokens > 0 {
170 client_info.tokens -= 1;
171 client_info.request_count += 1;
172 true
173 } else {
174 false
175 }
176 }
177}
178
179#[async_trait]
180impl Middleware for RateLimitMiddleware {
181 async fn before_request(&self, session: &mut Session) -> Result<bool> {
182 let client_ip = self.get_client_ip(session);
183
184 log::debug!("Rate limit check for client: {}", client_ip);
188 Ok(true)
189 }
190
191 async fn after_response(&self, _session: &mut Session) -> Result<()> {
192 Ok(())
193 }
194}
195
196pub struct SecurityHeadersMiddleware {
198 headers: HashMap<String, String>,
199}
200
201impl SecurityHeadersMiddleware {
202 pub fn new() -> Self {
203 let mut headers = HashMap::new();
204 headers.insert("X-Frame-Options".to_string(), "DENY".to_string());
205 headers.insert("X-Content-Type-Options".to_string(), "nosniff".to_string());
206 headers.insert("X-XSS-Protection".to_string(), "1; mode=block".to_string());
207 headers.insert(
208 "Referrer-Policy".to_string(),
209 "strict-origin-when-cross-origin".to_string(),
210 );
211
212 Self { headers }
213 }
214
215 pub fn with_header(mut self, name: String, value: String) -> Self {
216 self.headers.insert(name, value);
217 self
218 }
219
220 pub fn with_hsts(mut self, max_age: u32, include_subdomains: bool) -> Self {
221 let value = if include_subdomains {
222 format!("max-age={}; includeSubDomains", max_age)
223 } else {
224 format!("max-age={}", max_age)
225 };
226 self.headers
227 .insert("Strict-Transport-Security".to_string(), value);
228 self
229 }
230}
231
232#[async_trait]
233impl Middleware for SecurityHeadersMiddleware {
234 async fn before_request(&self, _session: &mut Session) -> Result<bool> {
235 Ok(true)
236 }
237
238 async fn after_response(&self, _session: &mut Session) -> Result<()> {
239 log::debug!("Security headers middleware executed (headers would be added to response)");
242 for (name, value) in &self.headers {
243 log::debug!("Would add header: {}: {}", name, value);
244 }
245 Ok(())
246 }
247}
248
249pub struct CorsMiddleware {
251 allow_origins: Vec<String>,
252 allow_methods: Vec<String>,
253 allow_headers: Vec<String>,
254 allow_credentials: bool,
255 max_age: u32,
256}
257
258impl CorsMiddleware {
259 pub fn new() -> Self {
260 Self {
261 allow_origins: vec!["*".to_string()],
262 allow_methods: vec!["GET".to_string(), "HEAD".to_string(), "OPTIONS".to_string()],
263 allow_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
264 allow_credentials: false,
265 max_age: 86400,
266 }
267 }
268
269 pub fn allow_origins(mut self, origins: Vec<String>) -> Self {
270 self.allow_origins = origins;
271 self
272 }
273
274 pub fn allow_methods(mut self, methods: Vec<String>) -> Self {
275 self.allow_methods = methods;
276 self
277 }
278
279 pub fn allow_headers(mut self, headers: Vec<String>) -> Self {
280 self.allow_headers = headers;
281 self
282 }
283
284 pub fn allow_credentials(mut self, allow: bool) -> Self {
285 self.allow_credentials = allow;
286 self
287 }
288
289 pub fn max_age(mut self, age: u32) -> Self {
290 self.max_age = age;
291 self
292 }
293}
294
295#[async_trait]
296impl Middleware for CorsMiddleware {
297 async fn before_request(&self, session: &mut Session) -> Result<bool> {
298 if session.req_header().method == "OPTIONS" {
300 log::debug!("Handling CORS preflight request");
301 return Ok(true);
304 }
305 Ok(true)
306 }
307
308 async fn after_response(&self, _session: &mut Session) -> Result<()> {
309 log::debug!("CORS middleware executed (headers would be added to response)");
311 log::debug!(
312 "Access-Control-Allow-Origin: {}",
313 self.allow_origins.join(", ")
314 );
315 log::debug!(
316 "Access-Control-Allow-Methods: {}",
317 self.allow_methods.join(", ")
318 );
319 log::debug!(
320 "Access-Control-Allow-Headers: {}",
321 self.allow_headers.join(", ")
322 );
323 Ok(())
324 }
325}
326
327impl Default for MiddlewareStack {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
333impl Default for SecurityHeadersMiddleware {
334 fn default() -> Self {
335 Self::new()
336 }
337}
338
339impl Default for CorsMiddleware {
340 fn default() -> Self {
341 Self::new()
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_middleware_stack_creation() {
        let empty = MiddlewareStack::new();
        assert!(empty.middlewares.is_empty());
    }

    #[test]
    fn test_middleware_stack_with_middleware() {
        let logging = LoggingMiddleware::new(true);
        let security = SecurityHeadersMiddleware::new();
        let stack = MiddlewareStack::new()
            .add_middleware(logging)
            .add_middleware(security);

        assert_eq!(2, stack.middlewares.len());
    }

    #[test]
    fn test_rate_limit_middleware_creation() {
        let limiter = RateLimitMiddleware::new(60, 10);
        assert_eq!(
            (limiter.requests_per_minute, limiter.burst_size),
            (60, 10)
        );
    }

    #[test]
    fn test_security_headers_middleware() {
        let configured = SecurityHeadersMiddleware::new()
            .with_header("Custom-Header".to_string(), "Custom-Value".to_string())
            .with_hsts(31536000, true);

        for expected in ["Custom-Header", "Strict-Transport-Security"] {
            assert!(configured.headers.contains_key(expected));
        }
    }

    #[test]
    fn test_cors_middleware_configuration() {
        let cors = CorsMiddleware::new()
            .allow_origins(vec!["https://example.com".to_string()])
            .allow_methods(vec!["GET".to_string(), "POST".to_string()])
            .allow_credentials(true)
            .max_age(7200);

        assert_eq!(cors.allow_origins, vec!["https://example.com"]);
        assert_eq!(cors.allow_methods, vec!["GET", "POST"]);
        assert!(cors.allow_credentials);
        assert_eq!(7200, cors.max_age);
    }
}