axum_reverse_proxy/
rfc9110.rs

//! RFC 9110 (HTTP Semantics) compliance layer.
//!
//! This module implements middleware for RFC 9110 compliance, focusing on:
//! - [RFC 9110 Section 7: Message Routing](https://www.rfc-editor.org/rfc/rfc9110.html#section-7)
//! - [RFC 9110 Section 7.6: Message Forwarding](https://www.rfc-editor.org/rfc/rfc9110.html#section-7.6)
//!
//! Key compliance points:
//! 1. Connection header handling (Section 7.6.1)
//!    - Remove the Connection header and every header listed within it
//!    - Remove standard hop-by-hop headers
//!
//! 2. Via header handling (Section 7.6.3)
//!    - Add Via header entries to requests and responses
//!    - Emit the received protocol version and a configurable pseudonym
//!    - Optionally combine multiple entries into a single one
//!
//! 3. Max-Forwards handling (Section 7.6.2)
//!    - Applies only to TRACE and OPTIONS requests
//!    - Decrement the value, or respond directly when it is zero
//!
//! 4. Loop detection (Section 7.6.3)
//!    - Detect loops using Via headers
//!    - Check the request's target host against configured server names
//!    - Respond with 508 Loop Detected (status code defined in RFC 5842)
//!
//! 5. End-to-end and hop-by-hop headers (Section 7.6.1)
//!    - Preserve end-to-end headers
//!    - Remove hop-by-hop headers
29
30use axum::{
31    body::Body,
32    http::{HeaderValue, Method, Request, Response, StatusCode, header::HeaderName},
33};
34use std::{
35    collections::HashSet,
36    future::Future,
37    pin::Pin,
38    str::FromStr,
39    task::{Context, Poll},
40};
41
/// Header fields that are always stripped when a message is forwarded.
///
/// This is `connection` itself, the fields RFC 9110 §7.6.1 says intermediaries should remove
/// (`proxy-connection`, `keep-alive`, `te`, `transfer-encoding`, `upgrade`), and `trailer`.
/// Every entry must be lowercase because the names are passed to `HeaderName::from_static`.
static HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
52use tower::{Layer, Service};
53
/// A single parsed member of a `Via` header field list.
///
/// Follows the RFC 9110 §7.6.3 shape `received-protocol received-by [ comment ]`, with
/// `received-by` further split into a host/pseudonym and an optional port.
#[derive(Debug, Clone)]
#[allow(dead_code)] // Fields are used for future extensibility
struct ViaEntry {
    /// Received protocol, e.g. `"1.1"`.
    protocol: String,
    /// Host or pseudonym of the recipient, e.g. `"proxy1"` or `"[::1]"`.
    pseudonym: String,
    /// Port from `received-by`, without the leading colon, e.g. `"8080"`.
    port: Option<String>,
    /// Trailing comment including its parentheses, e.g. `"(Proxy Software 1.0)"`.
    comment: Option<String>,
}

impl ViaEntry {
    /// Parses one `Via` list member such as `"1.1 proxy1:8080 (Proxy Software 1.0)"`.
    ///
    /// Returns `None` when the protocol or the `received-by` token is missing. A comment
    /// is recognised only when a `(` precedes the last `)`; malformed parentheses are
    /// ignored instead of causing a slice panic.
    fn parse(entry: &str) -> Option<Self> {
        let mut parts = entry.split_whitespace();

        let protocol = parts.next()?.to_string();
        let (pseudonym, port) = Self::split_received_by(parts.next()?);

        // Everything from the first '(' to the last ')', but only if they are in order:
        // for input like "1.1 p ) x (" the unchecked slice `start..=end` would panic.
        let comment = entry.find('(').and_then(|start| {
            entry
                .rfind(')')
                .filter(|&end| end > start)
                .map(|end| entry[start..=end].to_string())
        });

        Some(ViaEntry {
            protocol,
            pseudonym: pseudonym.to_string(),
            port: port.map(str::to_string),
            comment,
        })
    }

    /// Splits `received-by` into host/pseudonym and optional port.
    ///
    /// Bracketed IPv6 literals (`[::1]:8080`) are kept intact; any other value is split
    /// at its first colon.
    fn split_received_by(received_by: &str) -> (&str, Option<&str>) {
        let ipv6_end = received_by
            .starts_with('[')
            .then(|| received_by.find(']'))
            .flatten();
        if let Some(close) = ipv6_end {
            let (host, rest) = received_by.split_at(close + 1);
            return (host, rest.strip_prefix(':'));
        }
        match received_by.split_once(':') {
            Some((host, port)) => (host, Some(port)),
            None => (received_by, None),
        }
    }
}

impl std::fmt::Display for ViaEntry {
    /// Formats the entry back into `Via` member syntax: `protocol pseudonym[:port] [comment]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} {}", self.protocol, self.pseudonym)?;
        if let Some(port) = &self.port {
            write!(f, ":{port}")?;
        }
        if let Some(comment) = &self.comment {
            write!(f, " {comment}")?;
        }
        Ok(())
    }
}
106
107/// Parse a Via header value into a vector of ViaEntry structs
108#[allow(dead_code)] // Kept for future extensibility
109fn parse_via_header(header: &str) -> Vec<ViaEntry> {
110    header
111        .split(',')
112        .filter_map(|entry| ViaEntry::parse(entry.trim()))
113        .collect()
114}
115
/// Configuration for the RFC 9110 middleware layer.
#[derive(Clone, Debug)]
pub struct Rfc9110Config {
    /// Host names that identify this proxy, used for loop detection.
    ///
    /// A request whose URI host matches one of these names is answered with
    /// `508 Loop Detected`. Only request targets that carry a host in the URI
    /// (absolute-form) are checked. `None` disables this check.
    pub server_names: Option<HashSet<String>>,
    /// Pseudonym written as the `received-by` part of this proxy's `Via` entry, and
    /// searched for in incoming `Via` headers to detect loops. `None` means `"proxy"`.
    pub pseudonym: Option<String>,
    /// Controls `Via` rewriting. Takes effect only when `pseudonym` is set:
    /// - `true`: if every incoming `Via` entry starts with the request's protocol version,
    ///   they are all collapsed into this proxy's single entry;
    /// - `false`: "firewall mode". The request and response `Via` are replaced with the
    ///   fixed value `1.1 firewall`, hiding all other hops.
    ///
    /// Default: `true`
    pub combine_via: bool,
    /// Whether to preserve WebSocket upgrade headers (`Connection`, `Upgrade`).
    ///
    /// When true, requests carrying both `Sec-WebSocket-Key` and `Sec-WebSocket-Version`
    /// keep these hop-by-hop headers, as does the corresponding `101 Switching Protocols`
    /// response, so that the upgrade can complete.
    ///
    /// Default: `true`
    pub preserve_websocket_headers: bool,
}

impl Default for Rfc9110Config {
    /// No server names, the default `"proxy"` pseudonym, and both `Via` combining and
    /// WebSocket header preservation enabled.
    fn default() -> Self {
        Self {
            server_names: None,
            pseudonym: None,
            combine_via: true,
            preserve_websocket_headers: true,
        }
    }
}
142
143/// Layer that applies RFC9110 middleware
144#[derive(Clone)]
145pub struct Rfc9110Layer {
146    config: Rfc9110Config,
147}
148
149impl Default for Rfc9110Layer {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155impl Rfc9110Layer {
156    /// Create a new RFC9110 layer with default configuration
157    pub fn new() -> Self {
158        Self {
159            config: Rfc9110Config::default(),
160        }
161    }
162
163    /// Create a new RFC9110 layer with custom configuration
164    pub fn with_config(config: Rfc9110Config) -> Self {
165        Self { config }
166    }
167}
168
169impl<S> Layer<S> for Rfc9110Layer {
170    type Service = Rfc9110<S>;
171
172    fn layer(&self, inner: S) -> Self::Service {
173        Rfc9110 {
174            inner,
175            config: self.config.clone(),
176        }
177    }
178}
179
180/// RFC9110 middleware service
181#[derive(Clone)]
182pub struct Rfc9110<S> {
183    inner: S,
184    config: Rfc9110Config,
185}
186
187impl<S> Service<Request<Body>> for Rfc9110<S>
188where
189    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
190    S::Future: Send + 'static,
191{
192    type Response = S::Response;
193    type Error = S::Error;
194    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
195
196    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
197        self.inner.poll_ready(cx)
198    }
199
200    fn call(&mut self, mut request: Request<Body>) -> Self::Future {
201        let mut inner = self.inner.clone();
202        let config = self.config.clone();
203
204        Box::pin(async move {
205            // 1. Check for loops
206            if let Some(response) = detect_loop(&request, &config) {
207                return Ok(response);
208            }
209
210            // Save original Max-Forwards value for non-TRACE/OPTIONS methods
211            let original_max_forwards =
212                if request.method() != Method::TRACE && request.method() != Method::OPTIONS {
213                    request.headers().get(http::header::MAX_FORWARDS).cloned()
214                } else {
215                    None
216                };
217
218            // 2. Process Max-Forwards
219            if let Some(response) = process_max_forwards(&mut request) {
220                return Ok(response);
221            }
222
223            // Save the Max-Forwards value after processing
224            let max_forwards = request.headers().get(http::header::MAX_FORWARDS).cloned();
225
226            // Detect WebSocket upgrade request to preserve necessary headers
227            let is_websocket =
228                config.preserve_websocket_headers && is_websocket_upgrade_request(&request);
229
230            // 3. Process Connection header and remove hop-by-hop headers
231            process_connection_header(&mut request, is_websocket);
232
233            // Save end-to-end headers after processing Connection header
234            let preserved_headers = request.headers().clone();
235
236            // 4. Add Via header and save it for the response
237            let via_header = add_via_header(&mut request, &config);
238
239            // 5. Forward the request
240            let mut response = inner.call(request).await?;
241
242            // 6. Process response headers (preserve WebSocket headers for 101 responses too)
243            let is_websocket_response =
244                is_websocket && response.status() == StatusCode::SWITCHING_PROTOCOLS;
245            process_response_headers(&mut response, is_websocket_response);
246
247            // 7. Add Via header to response (use the same one we set in the request)
248            if let Some(via) = via_header {
249                // In firewall mode, always use "1.1 firewall"
250                if config.pseudonym.is_some() && !config.combine_via {
251                    response
252                        .headers_mut()
253                        .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
254                } else {
255                    response.headers_mut().insert(http::header::VIA, via);
256                }
257            }
258
259            // 8. Restore Max-Forwards header
260            if let Some(max_forwards) = original_max_forwards {
261                // For non-TRACE/OPTIONS methods, restore original value
262                response
263                    .headers_mut()
264                    .insert(http::header::MAX_FORWARDS, max_forwards);
265            } else if let Some(max_forwards) = max_forwards {
266                // For TRACE/OPTIONS, copy the decremented value to the response
267                response
268                    .headers_mut()
269                    .insert(http::header::MAX_FORWARDS, max_forwards);
270            }
271
272            // 9. Restore preserved end-to-end headers
273            for (name, value) in preserved_headers.iter() {
274                if !is_hop_by_hop_header(name) {
275                    response.headers_mut().insert(name, value.clone());
276                }
277            }
278
279            Ok(response)
280        })
281    }
282}
283
284/// Detect request loops based on Via headers and server names
285fn detect_loop(request: &Request<Body>, config: &Rfc9110Config) -> Option<Response<Body>> {
286    // 1. Check if the target host matches any of our server names
287    if let Some(server_names) = &config.server_names
288        && let Some(host) = request.uri().host()
289        && server_names.contains(host)
290    {
291        let mut response = Response::new(Body::empty());
292        *response.status_mut() = StatusCode::LOOP_DETECTED;
293        return Some(response);
294    }
295
296    // 2. Check for loops in Via headers
297    if let Some(via) = request.headers().get(http::header::VIA)
298        && let Ok(via_str) = via.to_str()
299    {
300        let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
301        let via_entries: Vec<&str> = via_str.split(',').map(str::trim).collect();
302
303        // Check if our pseudonym appears in any Via header
304        for entry in via_entries {
305            let parts: Vec<&str> = entry.split_whitespace().collect();
306            if parts.len() >= 2 && parts[1] == pseudonym {
307                let mut response = Response::new(Body::empty());
308                *response.status_mut() = StatusCode::LOOP_DETECTED;
309                return Some(response);
310            }
311        }
312    }
313
314    None
315}
316
317/// Process Max-Forwards header for TRACE and OPTIONS methods
318fn process_max_forwards(request: &mut Request<Body>) -> Option<Response<Body>> {
319    let method = request.method();
320
321    // Only process Max-Forwards for TRACE and OPTIONS
322    if let Some(max_forwards) = request.headers().get(http::header::MAX_FORWARDS) {
323        if *method != Method::TRACE && *method != Method::OPTIONS {
324            // For other methods, just preserve the header
325            return None;
326        }
327
328        if let Ok(value_str) = max_forwards.to_str() {
329            if let Ok(value) = value_str.parse::<u32>() {
330                if value == 0 {
331                    let mut response = Response::new(Body::empty());
332                    if *method == Method::TRACE {
333                        *response.body_mut() = Body::from(format!("{request:?}"));
334                    } else {
335                        // For OPTIONS, return 200 OK with Allow header
336                        response.headers_mut().insert(
337                            http::header::ALLOW,
338                            HeaderValue::from_static("GET, HEAD, OPTIONS, TRACE"),
339                        );
340                    }
341                    *response.status_mut() = StatusCode::OK;
342                    Some(response)
343                } else {
344                    // Decrement Max-Forwards
345                    let new_value = value - 1;
346                    request.headers_mut().insert(
347                        http::header::MAX_FORWARDS,
348                        HeaderValue::from_str(&new_value.to_string()).unwrap(),
349                    );
350                    None
351                }
352            } else {
353                None // Invalid number format
354            }
355        } else {
356            None // Invalid header value format
357        }
358    } else {
359        None // No Max-Forwards header
360    }
361}
362
/// Hop-by-hop header names that are kept when relaying a WebSocket upgrade handshake,
/// because the upgrade cannot complete without them.
static WEBSOCKET_HEADERS: &[&str] = &[
    "connection",
    "upgrade",
];
365
366/// Process Connection header and remove hop-by-hop headers
367fn process_connection_header(request: &mut Request<Body>, preserve_websocket: bool) {
368    let mut headers_to_remove = HashSet::new();
369
370    // Add standard hop-by-hop headers
371    for &name in HOP_BY_HOP_HEADERS {
372        // Skip connection and upgrade headers if this is a WebSocket upgrade
373        if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
374            continue;
375        }
376        headers_to_remove.insert(HeaderName::from_static(name));
377    }
378
379    // Get headers listed in Connection header
380    if let Some(connection) = request
381        .headers()
382        .get_all(http::header::CONNECTION)
383        .iter()
384        .next()
385        && let Ok(connection_str) = connection.to_str()
386    {
387        for header in connection_str.split(',') {
388            let header = header.trim();
389            // Skip connection and upgrade headers if this is a WebSocket upgrade
390            if preserve_websocket
391                && WEBSOCKET_HEADERS
392                    .iter()
393                    .any(|h| header.eq_ignore_ascii_case(h))
394            {
395                continue;
396            }
397            if let Ok(header_name) = HeaderName::from_str(header)
398                && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
399            {
400                headers_to_remove.insert(header_name);
401            }
402        }
403    }
404
405    // Remove all identified headers (case-insensitive)
406    let headers_to_remove = headers_to_remove; // Make immutable
407    let headers_to_remove: Vec<_> = request
408        .headers()
409        .iter()
410        .filter(|(k, _)| {
411            headers_to_remove
412                .iter()
413                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
414        })
415        .map(|(k, _)| k.clone())
416        .collect();
417
418    for header in headers_to_remove {
419        request.headers_mut().remove(&header);
420    }
421}
422
423/// Add Via header to the request
424fn add_via_header(request: &mut Request<Body>, config: &Rfc9110Config) -> Option<HeaderValue> {
425    // Get the protocol version from the request
426    let protocol_version = match request.version() {
427        http::Version::HTTP_09 => "0.9",
428        http::Version::HTTP_10 => "1.0",
429        http::Version::HTTP_11 => "1.1",
430        http::Version::HTTP_2 => "2.0",
431        http::Version::HTTP_3 => "3.0",
432        _ => "1.1", // Default to HTTP/1.1 for unknown versions
433    };
434
435    // Get the pseudonym from the config or use the default
436    let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
437
438    // If we're in firewall mode, always use "1.1 firewall"
439    if config.pseudonym.is_some() && !config.combine_via {
440        let via = HeaderValue::from_static("1.1 firewall");
441        request.headers_mut().insert(http::header::VIA, via.clone());
442        return Some(via);
443    }
444
445    // Get any existing Via headers
446    let mut via_values = Vec::new();
447    if let Some(existing_via) = request.headers().get(http::header::VIA)
448        && let Ok(existing_via_str) = existing_via.to_str()
449    {
450        // If we're combining Via headers and have a pseudonym, replace all entries with our protocol version
451        if config.combine_via && config.pseudonym.is_some() {
452            let entries: Vec<_> = existing_via_str.split(',').map(|s| s.trim()).collect();
453            let all_same_protocol = entries.iter().all(|s| s.starts_with(protocol_version));
454            if all_same_protocol {
455                let via = HeaderValue::from_str(&format!(
456                    "{} {}",
457                    protocol_version,
458                    config.pseudonym.as_ref().unwrap()
459                ))
460                .ok()?;
461                request.headers_mut().insert(http::header::VIA, via.clone());
462                return Some(via);
463            }
464        }
465        via_values.extend(existing_via_str.split(',').map(|s| s.trim().to_string()));
466    }
467
468    // Add our new Via header value
469    let new_value = format!("{protocol_version} {pseudonym}");
470    via_values.push(new_value);
471
472    // Create the combined Via header value
473    let combined_via = via_values.join(", ");
474    let via = HeaderValue::from_str(&combined_via).ok()?;
475    request.headers_mut().insert(http::header::VIA, via.clone());
476    Some(via)
477}
478
479/// Process response headers according to RFC9110
480fn process_response_headers(response: &mut Response<Body>, preserve_websocket: bool) {
481    let mut headers_to_remove = HashSet::new();
482
483    // Add standard hop-by-hop headers
484    for &name in HOP_BY_HOP_HEADERS {
485        // Skip connection and upgrade headers if this is a WebSocket upgrade response
486        if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
487            continue;
488        }
489        headers_to_remove.insert(HeaderName::from_static(name));
490    }
491
492    // Get headers listed in Connection header
493    if let Some(connection) = response
494        .headers()
495        .get_all(http::header::CONNECTION)
496        .iter()
497        .next()
498        && let Ok(connection_str) = connection.to_str()
499    {
500        for header in connection_str.split(',') {
501            let header = header.trim();
502            // Skip connection and upgrade headers if this is a WebSocket upgrade response
503            if preserve_websocket
504                && WEBSOCKET_HEADERS
505                    .iter()
506                    .any(|h| header.eq_ignore_ascii_case(h))
507            {
508                continue;
509            }
510            if let Ok(header_name) = HeaderName::from_str(header)
511                && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
512            {
513                headers_to_remove.insert(header_name);
514            }
515        }
516    }
517
518    // Remove all identified headers (case-insensitive)
519    let headers_to_remove = headers_to_remove; // Make immutable
520    let headers_to_remove: Vec<_> = response
521        .headers()
522        .iter()
523        .filter(|(k, _)| {
524            headers_to_remove
525                .iter()
526                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
527        })
528        .map(|(k, _)| k.clone())
529        .collect();
530
531    for header in headers_to_remove {
532        response.headers_mut().remove(&header);
533    }
534
535    // Handle Via header in response - if in firewall mode, replace all entries with "1.1 firewall"
536    if let Some(via) = response.headers().get(http::header::VIA)
537        && let Ok(via_str) = via.to_str()
538        && via_str.contains("firewall")
539    {
540        response
541            .headers_mut()
542            .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
543    }
544}
545
546/// Check if a header is a hop-by-hop header
547fn is_hop_by_hop_header(name: &HeaderName) -> bool {
548    HOP_BY_HOP_HEADERS
549        .iter()
550        .any(|h| name.as_str().eq_ignore_ascii_case(h))
551        || name.as_str().eq_ignore_ascii_case("via")
552}
553
554/// Check if a header is a known end-to-end header
555fn is_end_to_end_header(name: &HeaderName) -> bool {
556    matches!(
557        name.as_str(),
558        "cache-control"
559            | "authorization"
560            | "content-length"
561            | "content-type"
562            | "content-encoding"
563            | "accept"
564            | "accept-encoding"
565            | "accept-language"
566            | "range"
567            | "cookie"
568            | "set-cookie"
569            | "etag"
570    )
571}
572
573/// Check if a request appears to be a WebSocket upgrade request.
574/// This is detected by the presence of sec-websocket-key and sec-websocket-version headers,
575/// which are required for all WebSocket handshakes per RFC 6455.
576fn is_websocket_upgrade_request(request: &Request<Body>) -> bool {
577    request.headers().contains_key("sec-websocket-key")
578        && request.headers().contains_key("sec-websocket-version")
579}