warpdrive_proxy/middleware/
headers.rs1use async_trait::async_trait;
25use pingora::prelude::*;
26use std::sync::Arc;
27use tracing::debug;
28
29use super::{Middleware, MiddlewareContext};
30use crate::config::Config;
31
32pub struct HeadersMiddleware {
36 pub(crate) config: Arc<Config>,
37}
38
39impl HeadersMiddleware {
40 pub fn new(config: Arc<Config>) -> Self {
42 Self { config }
43 }
44
45 fn get_protocol(session: &Session) -> &'static str {
51 if let Some(scheme) = session.req_header().uri.scheme_str() {
53 if scheme == "https" {
54 return "https";
55 }
56 }
57
58 if let Some(digest) = session.digest() {
61 if digest.ssl_digest.is_some() {
62 return "https";
63 }
64 }
65
66 "http"
67 }
68}
69
70#[async_trait]
71impl Middleware for HeadersMiddleware {
72 async fn request_filter(
76 &self,
77 session: &mut Session,
78 ctx: &mut MiddlewareContext,
79 ) -> Result<()> {
80 let forward_headers = self.config.forward_headers;
81
82 let client_ip = ctx.real_client_ip.to_string();
84
85 let existing_xff = session
87 .req_header()
88 .headers
89 .get("x-forwarded-for")
90 .and_then(|v| v.to_str().ok())
91 .map(|s| s.to_string());
92
93 let existing_proto = session
94 .req_header()
95 .headers
96 .get("x-forwarded-proto")
97 .and_then(|v| v.to_str().ok())
98 .map(|s| s.to_string());
99
100 let existing_host = session
101 .req_header()
102 .headers
103 .get("x-forwarded-host")
104 .and_then(|v| v.to_str().ok())
105 .map(|s| s.to_string());
106
107 let current_host = session
108 .req_header()
109 .headers
110 .get("host")
111 .and_then(|v| v.to_str().ok())
112 .map(|s| s.to_string());
113
114 let protocol = Self::get_protocol(session);
115
116 let mut xff_value = if forward_headers { existing_xff } else { None };
120
121 xff_value = Some(match xff_value {
123 Some(existing) => format!("{}, {}", existing, client_ip),
124 None => client_ip.clone(),
125 });
126
127 if let Some(xff) = xff_value {
129 session
130 .req_header_mut()
131 .insert_header("X-Forwarded-For", &xff)?;
132 debug!("Set X-Forwarded-For: {}", xff);
133 }
134
135 let proto = if forward_headers {
137 existing_proto.as_deref().unwrap_or(protocol)
138 } else {
139 protocol
140 };
141
142 session
143 .req_header_mut()
144 .insert_header("X-Forwarded-Proto", proto)?;
145 debug!("Set X-Forwarded-Proto: {}", proto);
146
147 let forwarded_host = if forward_headers {
149 existing_host.or(current_host)
150 } else {
151 current_host
152 };
153
154 if let Some(host) = forwarded_host {
155 session
156 .req_header_mut()
157 .insert_header("X-Forwarded-Host", &host)?;
158 debug!("Set X-Forwarded-Host: {}", host);
159 }
160
161 session
163 .req_header_mut()
164 .insert_header("X-Real-IP", &client_ip)?;
165 debug!("Set X-Real-IP: {}", client_ip);
166
167 Ok(())
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Local re-statement of the `X-Forwarded-For` joining rule.
    /// Appends `hop` to an existing chain, or starts a new chain with it.
    fn append_hop(previous: Option<String>, hop: &str) -> String {
        previous.map_or_else(|| hop.to_string(), |chain| format!("{}, {}", chain, hop))
    }

    #[test]
    fn test_headers_middleware_creation() {
        let shared = Arc::new(Config::default());
        let middleware = HeadersMiddleware::new(Arc::clone(&shared));

        assert_eq!(middleware.config.forward_headers, shared.forward_headers);
    }

    #[test]
    fn test_xff_chain_building() {
        let hop = "192.168.1.1";

        assert_eq!(
            append_hop(Some("10.0.0.1, 10.0.0.2".to_string()), hop),
            "10.0.0.1, 10.0.0.2, 192.168.1.1"
        );
        assert_eq!(append_hop(None, hop), "192.168.1.1");
    }
}