Skip to main content

axum_reverse_proxy/
rfc9110.rs

//! RFC9110 (HTTP Semantics) Compliance Layer
//!
//! This module implements middleware for RFC9110 compliance, focusing on:
//! - [RFC9110 Section 7: Message Routing](https://www.rfc-editor.org/rfc/rfc9110.html#section-7)
//! - [RFC9110 Section 7.6: Message Forwarding](https://www.rfc-editor.org/rfc/rfc9110.html#section-7.6)
//!
//! Key compliance points:
//! 1. Connection header handling (Section 7.6.1)
//!    - Remove Connection header and all headers listed within it
//!    - Remove standard hop-by-hop headers
//!
//! 2. Via header handling (Section 7.6.3)
//!    - Add Via header entries for request/response
//!    - Support protocol version, pseudonym, and comments
//!    - Optional combining of multiple entries
//!
//! 3. Max-Forwards handling (Section 7.6.2)
//!    - Process for TRACE and OPTIONS methods
//!    - Decrement value or respond directly if zero
//!
//! 4. Loop detection (Section 7.3)
//!    - Detect loops using Via headers
//!    - Check server names/aliases
//!    - Return 508 Loop Detected status
//!
//! 5. End-to-end and Hop-by-hop Headers (Section 7.6.1)
//!    - Preserve end-to-end headers
//!    - Remove hop-by-hop headers
29
30use axum::{
31    body::Body,
32    http::{HeaderValue, Method, Request, Response, StatusCode, header::HeaderName},
33};
34use std::{
35    collections::HashSet,
36    future::Future,
37    pin::Pin,
38    str::FromStr,
39    task::{Context, Poll},
40};
41
/// Header fields this proxy always treats as hop-by-hop.
///
/// Both the request and the response pipeline strip these names before a
/// message is forwarded (`connection` and `upgrade` are spared while a
/// WebSocket handshake is being preserved). Entries are fed to
/// `HeaderName::from_static`, so they must stay lower-case.
static HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "transfer-encoding",
    "te",
    "trailer",
    "upgrade",
];
52use tower::{Layer, Service};
53
/// One element of a `Via` header field, e.g. `1.1 proxy1:8080 (Proxy Software 1.0)`.
#[derive(Debug, Clone)]
#[allow(dead_code)] // Fields are used for future extensibility
struct ViaEntry {
    /// First whitespace-separated token of the element, e.g. `"1.1"`.
    protocol: String,
    /// Part of the second token before the first `:`, e.g. `"proxy1"`.
    pseudonym: String,
    /// Part of the second token after the first `:`, e.g. `"8080"`.
    port: Option<String>,
    /// Trailing comment, parentheses included, e.g. `"(Proxy Software 1.0)"`.
    comment: Option<String>,
}

impl ViaEntry {
    /// Parses a single `Via` element (the caller has already split on `,`).
    ///
    /// Returns `None` when the element has fewer than two whitespace-separated
    /// tokens, i.e. when the protocol or the pseudonym is missing.
    ///
    /// The comment is the span from the first `(` to the last `)` that comes
    /// *after* it. An element without such a pair (including malformed input
    /// like `") ("`) simply has no comment instead of triggering a slice panic.
    fn parse(entry: &str) -> Option<Self> {
        let mut tokens = entry.split_whitespace();

        let protocol = tokens.next()?.to_string();

        // "pseudonym" or "pseudonym:port"
        let received_by = tokens.next()?;
        let (pseudonym, port) = match received_by.split_once(':') {
            Some((name, port)) => (name.to_string(), Some(port.to_string())),
            None => (received_by.to_string(), None),
        };

        // Look for the closing parenthesis only to the right of the opening
        // one; searching the whole entry could yield an end before the start.
        let comment = entry.find('(').and_then(|open| {
            entry[open..]
                .rfind(')')
                .map(|close| entry[open..=open + close].to_string())
        });

        Some(ViaEntry {
            protocol,
            pseudonym,
            port,
            comment,
        })
    }
}

impl std::fmt::Display for ViaEntry {
    /// Renders the entry as `protocol pseudonym[:port][ comment]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} {}", self.protocol, self.pseudonym)?;
        if let Some(port) = &self.port {
            write!(f, ":{port}")?;
        }
        if let Some(comment) = &self.comment {
            write!(f, " {comment}")?;
        }
        Ok(())
    }
}
106
107/// Parse a Via header value into a vector of ViaEntry structs
108#[allow(dead_code)] // Kept for future extensibility
109fn parse_via_header(header: &str) -> Vec<ViaEntry> {
110    header
111        .split(',')
112        .filter_map(|entry| ViaEntry::parse(entry.trim()))
113        .collect()
114}
115
/// Settings that control the RFC9110 middleware.
#[derive(Clone, Debug)]
pub struct Rfc9110Config {
    /// Host names that identify this proxy. A request whose URI host is in
    /// this set is answered with `508 Loop Detected` instead of being forwarded.
    pub server_names: Option<HashSet<String>>,
    /// Name this proxy writes into `Via` entries; `"proxy"` is used when unset.
    pub pseudonym: Option<String>,
    /// Whether existing `Via` entries that share the request's protocol
    /// version may be collapsed into a single entry for this proxy.
    /// Setting this to `false` together with a pseudonym selects "firewall
    /// mode", in which the `Via` value is always `1.1 firewall`.
    pub combine_via: bool,
    /// Whether to keep the `Connection` and `Upgrade` headers on WebSocket
    /// upgrade requests (and their `101` responses) so the handshake survives
    /// hop-by-hop header stripping.
    /// Default: true
    pub preserve_websocket_headers: bool,
}

impl Default for Rfc9110Config {
    /// No server names, no pseudonym, Via combining on, WebSocket preservation on.
    fn default() -> Self {
        Rfc9110Config {
            server_names: None,
            pseudonym: None,
            combine_via: true,
            preserve_websocket_headers: true,
        }
    }
}
142
143/// Layer that applies RFC9110 middleware
144#[derive(Clone)]
145pub struct Rfc9110Layer {
146    config: Rfc9110Config,
147}
148
149impl Default for Rfc9110Layer {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155impl Rfc9110Layer {
156    /// Create a new RFC9110 layer with default configuration
157    pub fn new() -> Self {
158        Self {
159            config: Rfc9110Config::default(),
160        }
161    }
162
163    /// Create a new RFC9110 layer with custom configuration
164    pub fn with_config(config: Rfc9110Config) -> Self {
165        Self { config }
166    }
167}
168
169impl<S> Layer<S> for Rfc9110Layer {
170    type Service = Rfc9110<S>;
171
172    fn layer(&self, inner: S) -> Self::Service {
173        Rfc9110 {
174            inner,
175            config: self.config.clone(),
176        }
177    }
178}
179
180/// RFC9110 middleware service
181#[derive(Clone)]
182pub struct Rfc9110<S> {
183    inner: S,
184    config: Rfc9110Config,
185}
186
187impl<S> Service<Request<Body>> for Rfc9110<S>
188where
189    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
190    S::Future: Send + 'static,
191{
192    type Response = S::Response;
193    type Error = S::Error;
194    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
195
196    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
197        self.inner.poll_ready(cx)
198    }
199
200    fn call(&mut self, mut request: Request<Body>) -> Self::Future {
201        let clone = self.inner.clone();
202        let mut inner = std::mem::replace(&mut self.inner, clone);
203        let config = self.config.clone();
204
205        Box::pin(async move {
206            // 1. Check for loops
207            if let Some(response) = detect_loop(&request, &config) {
208                return Ok(response);
209            }
210
211            // Save original Max-Forwards value for non-TRACE/OPTIONS methods
212            let original_max_forwards =
213                if request.method() != Method::TRACE && request.method() != Method::OPTIONS {
214                    request.headers().get(http::header::MAX_FORWARDS).cloned()
215                } else {
216                    None
217                };
218
219            // 2. Process Max-Forwards
220            if let Some(response) = process_max_forwards(&mut request) {
221                return Ok(response);
222            }
223
224            // Save the Max-Forwards value after processing
225            let max_forwards = request.headers().get(http::header::MAX_FORWARDS).cloned();
226
227            // Detect WebSocket upgrade request to preserve necessary headers
228            let is_websocket =
229                config.preserve_websocket_headers && is_websocket_upgrade_request(&request);
230
231            // 3. Process Connection header and remove hop-by-hop headers
232            process_connection_header(&mut request, is_websocket);
233
234            // 4. Add Via header and save it for the response
235            let via_header = add_via_header(&mut request, &config);
236
237            // 5. Forward the request
238            let mut response = inner.call(request).await?;
239
240            // 6. Process response headers (preserve WebSocket headers for 101 responses too)
241            let is_websocket_response =
242                is_websocket && response.status() == StatusCode::SWITCHING_PROTOCOLS;
243            process_response_headers(&mut response, is_websocket_response);
244
245            // 7. Add Via header to response (use the same one we set in the request)
246            if let Some(via) = via_header {
247                // In firewall mode, always use "1.1 firewall"
248                if config.pseudonym.is_some() && !config.combine_via {
249                    response
250                        .headers_mut()
251                        .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
252                } else {
253                    response.headers_mut().insert(http::header::VIA, via);
254                }
255            }
256
257            // 8. Restore Max-Forwards header
258            if let Some(max_forwards) = original_max_forwards {
259                // For non-TRACE/OPTIONS methods, restore original value
260                response
261                    .headers_mut()
262                    .insert(http::header::MAX_FORWARDS, max_forwards);
263            } else if let Some(max_forwards) = max_forwards {
264                // For TRACE/OPTIONS, copy the decremented value to the response
265                response
266                    .headers_mut()
267                    .insert(http::header::MAX_FORWARDS, max_forwards);
268            }
269
270            Ok(response)
271        })
272    }
273}
274
275/// Detect request loops based on Via headers and server names
276fn detect_loop(request: &Request<Body>, config: &Rfc9110Config) -> Option<Response<Body>> {
277    // 1. Check if the target host matches any of our server names
278    if let Some(server_names) = &config.server_names
279        && let Some(host) = request.uri().host()
280        && server_names.contains(host)
281    {
282        let mut response = Response::new(Body::empty());
283        *response.status_mut() = StatusCode::LOOP_DETECTED;
284        return Some(response);
285    }
286
287    // 2. Check for loops in Via headers
288    if let Some(via) = request.headers().get(http::header::VIA)
289        && let Ok(via_str) = via.to_str()
290    {
291        let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
292        let via_entries: Vec<&str> = via_str.split(',').map(str::trim).collect();
293
294        // Check if our pseudonym appears in any Via header
295        for entry in via_entries {
296            let parts: Vec<&str> = entry.split_whitespace().collect();
297            if parts.len() >= 2 && parts[1] == pseudonym {
298                let mut response = Response::new(Body::empty());
299                *response.status_mut() = StatusCode::LOOP_DETECTED;
300                return Some(response);
301            }
302        }
303    }
304
305    None
306}
307
308/// Process Max-Forwards header for TRACE and OPTIONS methods
309fn process_max_forwards(request: &mut Request<Body>) -> Option<Response<Body>> {
310    let method = request.method();
311
312    // Only process Max-Forwards for TRACE and OPTIONS
313    if let Some(max_forwards) = request.headers().get(http::header::MAX_FORWARDS) {
314        if *method != Method::TRACE && *method != Method::OPTIONS {
315            // For other methods, just preserve the header
316            return None;
317        }
318
319        if let Ok(value_str) = max_forwards.to_str() {
320            if let Ok(value) = value_str.parse::<u32>() {
321                if value == 0 {
322                    let mut response = Response::new(Body::empty());
323                    if *method == Method::TRACE {
324                        *response.body_mut() = Body::from(format!("{request:?}"));
325                    } else {
326                        // For OPTIONS, return 200 OK with Allow header
327                        response.headers_mut().insert(
328                            http::header::ALLOW,
329                            HeaderValue::from_static("GET, HEAD, OPTIONS, TRACE"),
330                        );
331                    }
332                    *response.status_mut() = StatusCode::OK;
333                    Some(response)
334                } else {
335                    // Decrement Max-Forwards
336                    let new_value = value - 1;
337                    request.headers_mut().insert(
338                        http::header::MAX_FORWARDS,
339                        HeaderValue::from_str(&new_value.to_string()).unwrap(),
340                    );
341                    None
342                }
343            } else {
344                None // Invalid number format
345            }
346        } else {
347            None // Invalid header value format
348        }
349    } else {
350        None // No Max-Forwards header
351    }
352}
353
/// Hop-by-hop headers that are left in place while a WebSocket upgrade is
/// being preserved, because the handshake depends on them.
static WEBSOCKET_HEADERS: &[&str] = &["connection", "upgrade"];
356
357/// Process Connection header and remove hop-by-hop headers
358fn process_connection_header(request: &mut Request<Body>, preserve_websocket: bool) {
359    let mut headers_to_remove = HashSet::new();
360
361    // Add standard hop-by-hop headers
362    for &name in HOP_BY_HOP_HEADERS {
363        // Skip connection and upgrade headers if this is a WebSocket upgrade
364        if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
365            continue;
366        }
367        headers_to_remove.insert(HeaderName::from_static(name));
368    }
369
370    // Get headers listed in Connection header
371    if let Some(connection) = request
372        .headers()
373        .get_all(http::header::CONNECTION)
374        .iter()
375        .next()
376        && let Ok(connection_str) = connection.to_str()
377    {
378        for header in connection_str.split(',') {
379            let header = header.trim();
380            // Skip connection and upgrade headers if this is a WebSocket upgrade
381            if preserve_websocket
382                && WEBSOCKET_HEADERS
383                    .iter()
384                    .any(|h| header.eq_ignore_ascii_case(h))
385            {
386                continue;
387            }
388            if let Ok(header_name) = HeaderName::from_str(header)
389                && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
390            {
391                headers_to_remove.insert(header_name);
392            }
393        }
394    }
395
396    // Remove all identified headers (case-insensitive)
397    let headers_to_remove = headers_to_remove; // Make immutable
398    let headers_to_remove: Vec<_> = request
399        .headers()
400        .iter()
401        .filter(|(k, _)| {
402            headers_to_remove
403                .iter()
404                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
405        })
406        .map(|(k, _)| k.clone())
407        .collect();
408
409    for header in headers_to_remove {
410        request.headers_mut().remove(&header);
411    }
412}
413
414/// Add Via header to the request
415fn add_via_header(request: &mut Request<Body>, config: &Rfc9110Config) -> Option<HeaderValue> {
416    // Get the protocol version from the request
417    let protocol_version = match request.version() {
418        http::Version::HTTP_09 => "0.9",
419        http::Version::HTTP_10 => "1.0",
420        http::Version::HTTP_11 => "1.1",
421        http::Version::HTTP_2 => "2.0",
422        http::Version::HTTP_3 => "3.0",
423        _ => "1.1", // Default to HTTP/1.1 for unknown versions
424    };
425
426    // Get the pseudonym from the config or use the default
427    let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
428
429    // If we're in firewall mode, always use "1.1 firewall"
430    if config.pseudonym.is_some() && !config.combine_via {
431        let via = HeaderValue::from_static("1.1 firewall");
432        request.headers_mut().insert(http::header::VIA, via.clone());
433        return Some(via);
434    }
435
436    // Get any existing Via headers
437    let mut via_values = Vec::new();
438    if let Some(existing_via) = request.headers().get(http::header::VIA)
439        && let Ok(existing_via_str) = existing_via.to_str()
440    {
441        // If we're combining Via headers and have a pseudonym, replace all entries with our protocol version
442        if config.combine_via && config.pseudonym.is_some() {
443            let entries: Vec<_> = existing_via_str.split(',').map(|s| s.trim()).collect();
444            let all_same_protocol = entries.iter().all(|s| s.starts_with(protocol_version));
445            if all_same_protocol {
446                let via = HeaderValue::from_str(&format!(
447                    "{} {}",
448                    protocol_version,
449                    config.pseudonym.as_ref().unwrap()
450                ))
451                .ok()?;
452                request.headers_mut().insert(http::header::VIA, via.clone());
453                return Some(via);
454            }
455        }
456        via_values.extend(existing_via_str.split(',').map(|s| s.trim().to_string()));
457    }
458
459    // Add our new Via header value
460    let new_value = format!("{protocol_version} {pseudonym}");
461    via_values.push(new_value);
462
463    // Create the combined Via header value
464    let combined_via = via_values.join(", ");
465    let via = HeaderValue::from_str(&combined_via).ok()?;
466    request.headers_mut().insert(http::header::VIA, via.clone());
467    Some(via)
468}
469
470/// Process response headers according to RFC9110
471fn process_response_headers(response: &mut Response<Body>, preserve_websocket: bool) {
472    let mut headers_to_remove = HashSet::new();
473
474    // Add standard hop-by-hop headers
475    for &name in HOP_BY_HOP_HEADERS {
476        // Skip connection and upgrade headers if this is a WebSocket upgrade response
477        if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
478            continue;
479        }
480        headers_to_remove.insert(HeaderName::from_static(name));
481    }
482
483    // Get headers listed in Connection header
484    if let Some(connection) = response
485        .headers()
486        .get_all(http::header::CONNECTION)
487        .iter()
488        .next()
489        && let Ok(connection_str) = connection.to_str()
490    {
491        for header in connection_str.split(',') {
492            let header = header.trim();
493            // Skip connection and upgrade headers if this is a WebSocket upgrade response
494            if preserve_websocket
495                && WEBSOCKET_HEADERS
496                    .iter()
497                    .any(|h| header.eq_ignore_ascii_case(h))
498            {
499                continue;
500            }
501            if let Ok(header_name) = HeaderName::from_str(header)
502                && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
503            {
504                headers_to_remove.insert(header_name);
505            }
506        }
507    }
508
509    // Remove all identified headers (case-insensitive)
510    let headers_to_remove = headers_to_remove; // Make immutable
511    let headers_to_remove: Vec<_> = response
512        .headers()
513        .iter()
514        .filter(|(k, _)| {
515            headers_to_remove
516                .iter()
517                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
518        })
519        .map(|(k, _)| k.clone())
520        .collect();
521
522    for header in headers_to_remove {
523        response.headers_mut().remove(&header);
524    }
525
526    // Handle Via header in response - if in firewall mode, replace all entries with "1.1 firewall"
527    if let Some(via) = response.headers().get(http::header::VIA)
528        && let Ok(via_str) = via.to_str()
529        && via_str.contains("firewall")
530    {
531        response
532            .headers_mut()
533            .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
534    }
535}
536
537/// Check if a header is a hop-by-hop header
538fn is_hop_by_hop_header(name: &HeaderName) -> bool {
539    HOP_BY_HOP_HEADERS
540        .iter()
541        .any(|h| name.as_str().eq_ignore_ascii_case(h))
542        || name.as_str().eq_ignore_ascii_case("via")
543}
544
545/// Check if a header is a known end-to-end header
546fn is_end_to_end_header(name: &HeaderName) -> bool {
547    matches!(
548        name.as_str(),
549        "cache-control"
550            | "authorization"
551            | "content-length"
552            | "content-type"
553            | "content-encoding"
554            | "accept"
555            | "accept-encoding"
556            | "accept-language"
557            | "range"
558            | "cookie"
559            | "set-cookie"
560            | "etag"
561    )
562}
563
564/// Check if a request appears to be a WebSocket upgrade request.
565/// This is detected by the presence of sec-websocket-key and sec-websocket-version headers,
566/// which are required for all WebSocket handshakes per RFC 6455.
567fn is_websocket_upgrade_request(request: &Request<Body>) -> bool {
568    request.headers().contains_key("sec-websocket-key")
569        && request.headers().contains_key("sec-websocket-version")
570}