// warpdrive_proxy/middleware/headers.rs

//! Headers middleware
//!
//! Manages X-Forwarded-* headers to preserve client information through the proxy.
//! This is critical for applications that need to know the original client IP,
//! protocol, and host when running behind a reverse proxy.
//!
//! # Headers Managed
//!
//! - **X-Forwarded-For**: Chain of client/proxy IP addresses, with the client IP as the last hop
//! - **X-Forwarded-Proto**: Original protocol (http/https)
//! - **X-Forwarded-Host**: Original Host header
//! - **X-Real-IP**: Real client IP (non-standard but widely used)
//!
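//! # Behavior
//!
//! When `forward_headers` is true:
//! - Preserves existing X-Forwarded-Proto and X-Forwarded-Host headers from the client,
//!   falling back to values derived from the current request when they are absent
//! - Appends the client IP to the existing X-Forwarded-For chain
//!
//! When `forward_headers` is false:
//! - Replaces X-Forwarded-* headers with values derived from the current request
//! - Starts a new X-Forwarded-For chain with the client IP
//!
//! Only enable `forward_headers` when the downstream peers are themselves trusted
//! proxies: the preserved values are otherwise fully client-controlled.
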
24use async_trait::async_trait;
25use pingora::prelude::*;
26use std::sync::Arc;
27use tracing::debug;
28
29use super::{Middleware, MiddlewareContext};
30use crate::config::Config;
31
32/// Headers middleware
33///
34/// Injects and manages X-Forwarded-* headers for proxy transparency.
35pub struct HeadersMiddleware {
36    pub(crate) config: Arc<Config>,
37}
38
39impl HeadersMiddleware {
40    /// Create new headers middleware
41    pub fn new(config: Arc<Config>) -> Self {
42        Self { config }
43    }
44
45    /// Get request protocol (http/https)
46    ///
47    /// Detects if the downstream connection is using TLS by checking the session digest.
48    /// This is critical for X-Forwarded-Proto header accuracy, which apps use for
49    /// building absolute URLs (redirects, asset URLs, etc.).
50    fn get_protocol(session: &Session) -> &'static str {
51        // Check URI scheme first (for explicit https:// requests)
52        if let Some(scheme) = session.req_header().uri.scheme_str() {
53            if scheme == "https" {
54                return "https";
55            }
56        }
57
58        // Check if downstream connection is TLS via session digest
59        // See: https://github.com/cloudflare/pingora/issues/403
60        if let Some(digest) = session.digest() {
61            if digest.ssl_digest.is_some() {
62                return "https";
63            }
64        }
65
66        "http"
67    }
68}
69
70#[async_trait]
71impl Middleware for HeadersMiddleware {
72    /// Add X-Forwarded-* headers to upstream request
73    ///
74    /// This preserves client information when proxying requests.
75    async fn request_filter(
76        &self,
77        session: &mut Session,
78        ctx: &mut MiddlewareContext,
79    ) -> Result<()> {
80        let forward_headers = self.config.forward_headers;
81
82        // Use real client IP from context (normalized by trusted_ranges middleware)
83        let client_ip = ctx.real_client_ip.to_string();
84
85        // Extract all header values we need before mutating
86        let existing_xff = session
87            .req_header()
88            .headers
89            .get("x-forwarded-for")
90            .and_then(|v| v.to_str().ok())
91            .map(|s| s.to_string());
92
93        let existing_proto = session
94            .req_header()
95            .headers
96            .get("x-forwarded-proto")
97            .and_then(|v| v.to_str().ok())
98            .map(|s| s.to_string());
99
100        let existing_host = session
101            .req_header()
102            .headers
103            .get("x-forwarded-host")
104            .and_then(|v| v.to_str().ok())
105            .map(|s| s.to_string());
106
107        let current_host = session
108            .req_header()
109            .headers
110            .get("host")
111            .and_then(|v| v.to_str().ok())
112            .map(|s| s.to_string());
113
114        let protocol = Self::get_protocol(session);
115
116        // Now we can mutate the session safely
117
118        // Handle X-Forwarded-For
119        let mut xff_value = if forward_headers { existing_xff } else { None };
120
121        // Append client IP to X-Forwarded-For chain
122        xff_value = Some(match xff_value {
123            Some(existing) => format!("{}, {}", existing, client_ip),
124            None => client_ip.clone(),
125        });
126
127        // Set X-Forwarded-For header
128        if let Some(xff) = xff_value {
129            session
130                .req_header_mut()
131                .insert_header("X-Forwarded-For", &xff)?;
132            debug!("Set X-Forwarded-For: {}", xff);
133        }
134
135        // Handle X-Forwarded-Proto
136        let proto = if forward_headers {
137            existing_proto.as_deref().unwrap_or(protocol)
138        } else {
139            protocol
140        };
141
142        session
143            .req_header_mut()
144            .insert_header("X-Forwarded-Proto", proto)?;
145        debug!("Set X-Forwarded-Proto: {}", proto);
146
147        // Handle X-Forwarded-Host
148        let forwarded_host = if forward_headers {
149            existing_host.or(current_host)
150        } else {
151            current_host
152        };
153
154        if let Some(host) = forwarded_host {
155            session
156                .req_header_mut()
157                .insert_header("X-Forwarded-Host", &host)?;
158            debug!("Set X-Forwarded-Host: {}", host);
159        }
160
161        // Set X-Real-IP (non-standard but widely used)
162        session
163            .req_header_mut()
164            .insert_header("X-Real-IP", &client_ip)?;
165        debug!("Set X-Real-IP: {}", client_ip);
166
167        Ok(())
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_headers_middleware_creation() {
        let shared = Arc::new(Config::default());
        let mw = HeadersMiddleware::new(Arc::clone(&shared));

        // The middleware must observe exactly the configuration it was given.
        assert_eq!(mw.config.forward_headers, shared.forward_headers);
    }

    #[test]
    fn test_xff_chain_building() {
        // Mirrors the X-Forwarded-For append rule: "<existing>, <client>" when a
        // chain exists, otherwise just "<client>".
        let append = |existing: Option<String>, ip: &str| match existing {
            Some(chain) => format!("{}, {}", chain, ip),
            None => ip.to_string(),
        };
        let client = "192.168.1.1";

        let extended = append(Some("10.0.0.1, 10.0.0.2".to_string()), client);
        assert_eq!(extended, "10.0.0.1, 10.0.0.2, 192.168.1.1");

        let fresh = append(None, client);
        assert_eq!(fresh, "192.168.1.1");
    }
}