Skip to main content

axum_reverse_proxy/
rfc9110.rs

//! RFC9110 (HTTP Semantics) Compliance Layer
//!
//! This module implements middleware for RFC9110 compliance, focusing on:
//! - [RFC9110 Section 7: Routing HTTP Messages](https://www.rfc-editor.org/rfc/rfc9110.html#section-7)
//! - [RFC9110 Section 7.6: Message Forwarding](https://www.rfc-editor.org/rfc/rfc9110.html#section-7.6)
//!
//! Key compliance points:
//! 1. Connection header handling (Section 7.6.1)
//!    - Remove the Connection header and all headers listed within it
//!    - Remove standard hop-by-hop headers
//!
//! 2. Via header handling (Section 7.6.3)
//!    - Add Via header entries for requests and responses
//!    - Support protocol version, pseudonym, and comments
//!    - Optionally combine multiple entries
//!
//! 3. Max-Forwards handling (Section 7.6.2)
//!    - Process for TRACE and OPTIONS methods
//!    - Decrement the value, or respond directly if it is zero
//!
//! 4. Loop detection (Section 7.3)
//!    - Detect loops using Via headers
//!    - Check server names/aliases
//!    - Return 508 Loop Detected status (defined in RFC 5842)
//!
//! 5. End-to-end and hop-by-hop headers (Section 7.6.1)
//!    - Preserve end-to-end headers
//!    - Remove hop-by-hop headers
29
30use axum::{
31    body::Body,
32    http::{HeaderValue, Method, Request, Response, StatusCode, header::HeaderName},
33};
34use std::{
35    collections::HashSet,
36    future::Future,
37    pin::Pin,
38    str::FromStr,
39    task::{Context, Poll},
40};
41
/// Hop-by-hop field names that this layer always strips when forwarding.
///
/// RFC 9110 §7.6.1 asks intermediaries to drop Connection and fields such as
/// Proxy-Connection, Keep-Alive, TE, Transfer-Encoding and Upgrade; `trailer`
/// is stripped here as well. Entries are lowercase so they can be passed to
/// `HeaderName::from_static`. Every consumer treats this as a set, so the
/// order is irrelevant.
static HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
52use tower::{Layer, Service};
53
/// One element of a `Via` field value, e.g. `1.1 proxy1:8080 (Proxy Software 1.0)`.
///
/// Mirrors the RFC 9110 §7.6.3 grammar
/// `received-protocol RWS received-by [ RWS comment ]`.
#[derive(Debug, Clone)]
#[allow(dead_code)] // Fields are used for future extensibility
struct ViaEntry {
    protocol: String,        // e.g., "1.1"
    pseudonym: String,       // e.g., "proxy1"
    port: Option<String>,    // e.g., "8080"
    comment: Option<String>, // e.g., "(Proxy Software 1.0)"
}

impl ViaEntry {
    /// Parse a single, already comma-split, Via entry.
    ///
    /// Returns `None` when the protocol or the received-by token is missing.
    /// The received-by token is split at its first `:` into pseudonym and port.
    /// The comment is the text from the first `(` through the last `)`; when no
    /// `)` follows the `(` the entry has no comment.
    fn parse(entry: &str) -> Option<Self> {
        let mut parts = entry.split_whitespace();

        let protocol = parts.next()?.to_string();

        let received_by = parts.next()?;
        let (pseudonym, port) = match received_by.split_once(':') {
            Some((name, port)) => (name.to_string(), Some(port.to_string())),
            None => (received_by.to_string(), None),
        };

        // Only accept a `)` located after the `(`. Slicing `start..=end` with
        // `end < start` panics on inputs such as "1.1 a) b (c".
        let comment = entry.find('(').and_then(|start| {
            entry
                .rfind(')')
                .filter(|&end| end > start)
                .map(|end| entry[start..=end].to_string())
        });

        Some(ViaEntry {
            protocol,
            pseudonym,
            port,
            comment,
        })
    }
}

impl std::fmt::Display for ViaEntry {
    /// Render as `<protocol> <pseudonym>[:<port>][ <comment>]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} {}", self.protocol, self.pseudonym)?;
        if let Some(port) = &self.port {
            write!(f, ":{port}")?;
        }
        if let Some(comment) = &self.comment {
            write!(f, " {comment}")?;
        }
        Ok(())
    }
}

/// Parse a comma-separated Via header value into its entries.
///
/// Entries that `ViaEntry::parse` rejects (e.g. empty list elements) are skipped.
#[allow(dead_code)] // Kept for future extensibility
fn parse_via_header(header: &str) -> Vec<ViaEntry> {
    header
        .split(',')
        .filter_map(|entry| ViaEntry::parse(entry.trim()))
        .collect()
}
115
/// Settings that control the RFC 9110 middleware.
///
/// `Default` leaves host-based loop detection off, leaves the Via pseudonym
/// unset, combines Via entries, and keeps WebSocket upgrade headers intact.
#[derive(Clone, Debug)]
pub struct Rfc9110Config {
    /// Host names (including aliases) that identify this proxy; a request
    /// whose URI host is one of them is treated as a loop.
    pub server_names: Option<HashSet<String>>,
    /// Name this proxy records as the received-by part of Via entries.
    pub pseudonym: Option<String>,
    /// Whether Via entries sharing our protocol version may be combined.
    pub combine_via: bool,
    /// Whether to keep the hop-by-hop `Connection` and `Upgrade` headers on
    /// WebSocket upgrade requests (and the matching 101 responses) so the
    /// handshake still works through the proxy.
    /// Default: true
    pub preserve_websocket_headers: bool,
}

impl Default for Rfc9110Config {
    fn default() -> Self {
        Rfc9110Config {
            combine_via: true,
            preserve_websocket_headers: true,
            pseudonym: None,
            server_names: None,
        }
    }
}
142
143/// Layer that applies RFC9110 middleware
144#[derive(Clone)]
145pub struct Rfc9110Layer {
146    config: Rfc9110Config,
147}
148
149impl Default for Rfc9110Layer {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155impl Rfc9110Layer {
156    /// Create a new RFC9110 layer with default configuration
157    pub fn new() -> Self {
158        Self {
159            config: Rfc9110Config::default(),
160        }
161    }
162
163    /// Create a new RFC9110 layer with custom configuration
164    pub fn with_config(config: Rfc9110Config) -> Self {
165        Self { config }
166    }
167}
168
169impl<S> Layer<S> for Rfc9110Layer {
170    type Service = Rfc9110<S>;
171
172    fn layer(&self, inner: S) -> Self::Service {
173        Rfc9110 {
174            inner,
175            config: self.config.clone(),
176        }
177    }
178}
179
180/// RFC9110 middleware service
181#[derive(Clone)]
182pub struct Rfc9110<S> {
183    inner: S,
184    config: Rfc9110Config,
185}
186
187impl<S> Service<Request<Body>> for Rfc9110<S>
188where
189    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
190    S::Future: Send + 'static,
191{
192    type Response = S::Response;
193    type Error = S::Error;
194    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
195
196    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
197        self.inner.poll_ready(cx)
198    }
199
200    fn call(&mut self, mut request: Request<Body>) -> Self::Future {
201        let mut inner = self.inner.clone();
202        let config = self.config.clone();
203
204        Box::pin(async move {
205            // 1. Check for loops
206            if let Some(response) = detect_loop(&request, &config) {
207                return Ok(response);
208            }
209
210            // Save original Max-Forwards value for non-TRACE/OPTIONS methods
211            let original_max_forwards =
212                if request.method() != Method::TRACE && request.method() != Method::OPTIONS {
213                    request.headers().get(http::header::MAX_FORWARDS).cloned()
214                } else {
215                    None
216                };
217
218            // 2. Process Max-Forwards
219            if let Some(response) = process_max_forwards(&mut request) {
220                return Ok(response);
221            }
222
223            // Save the Max-Forwards value after processing
224            let max_forwards = request.headers().get(http::header::MAX_FORWARDS).cloned();
225
226            // Detect WebSocket upgrade request to preserve necessary headers
227            let is_websocket =
228                config.preserve_websocket_headers && is_websocket_upgrade_request(&request);
229
230            // 3. Process Connection header and remove hop-by-hop headers
231            process_connection_header(&mut request, is_websocket);
232
233            // 4. Add Via header and save it for the response
234            let via_header = add_via_header(&mut request, &config);
235
236            // 5. Forward the request
237            let mut response = inner.call(request).await?;
238
239            // 6. Process response headers (preserve WebSocket headers for 101 responses too)
240            let is_websocket_response =
241                is_websocket && response.status() == StatusCode::SWITCHING_PROTOCOLS;
242            process_response_headers(&mut response, is_websocket_response);
243
244            // 7. Add Via header to response (use the same one we set in the request)
245            if let Some(via) = via_header {
246                // In firewall mode, always use "1.1 firewall"
247                if config.pseudonym.is_some() && !config.combine_via {
248                    response
249                        .headers_mut()
250                        .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
251                } else {
252                    response.headers_mut().insert(http::header::VIA, via);
253                }
254            }
255
256            // 8. Restore Max-Forwards header
257            if let Some(max_forwards) = original_max_forwards {
258                // For non-TRACE/OPTIONS methods, restore original value
259                response
260                    .headers_mut()
261                    .insert(http::header::MAX_FORWARDS, max_forwards);
262            } else if let Some(max_forwards) = max_forwards {
263                // For TRACE/OPTIONS, copy the decremented value to the response
264                response
265                    .headers_mut()
266                    .insert(http::header::MAX_FORWARDS, max_forwards);
267            }
268
269            Ok(response)
270        })
271    }
272}
273
274/// Detect request loops based on Via headers and server names
275fn detect_loop(request: &Request<Body>, config: &Rfc9110Config) -> Option<Response<Body>> {
276    // 1. Check if the target host matches any of our server names
277    if let Some(server_names) = &config.server_names
278        && let Some(host) = request.uri().host()
279        && server_names.contains(host)
280    {
281        let mut response = Response::new(Body::empty());
282        *response.status_mut() = StatusCode::LOOP_DETECTED;
283        return Some(response);
284    }
285
286    // 2. Check for loops in Via headers
287    if let Some(via) = request.headers().get(http::header::VIA)
288        && let Ok(via_str) = via.to_str()
289    {
290        let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
291        let via_entries: Vec<&str> = via_str.split(',').map(str::trim).collect();
292
293        // Check if our pseudonym appears in any Via header
294        for entry in via_entries {
295            let parts: Vec<&str> = entry.split_whitespace().collect();
296            if parts.len() >= 2 && parts[1] == pseudonym {
297                let mut response = Response::new(Body::empty());
298                *response.status_mut() = StatusCode::LOOP_DETECTED;
299                return Some(response);
300            }
301        }
302    }
303
304    None
305}
306
307/// Process Max-Forwards header for TRACE and OPTIONS methods
308fn process_max_forwards(request: &mut Request<Body>) -> Option<Response<Body>> {
309    let method = request.method();
310
311    // Only process Max-Forwards for TRACE and OPTIONS
312    if let Some(max_forwards) = request.headers().get(http::header::MAX_FORWARDS) {
313        if *method != Method::TRACE && *method != Method::OPTIONS {
314            // For other methods, just preserve the header
315            return None;
316        }
317
318        if let Ok(value_str) = max_forwards.to_str() {
319            if let Ok(value) = value_str.parse::<u32>() {
320                if value == 0 {
321                    let mut response = Response::new(Body::empty());
322                    if *method == Method::TRACE {
323                        *response.body_mut() = Body::from(format!("{request:?}"));
324                    } else {
325                        // For OPTIONS, return 200 OK with Allow header
326                        response.headers_mut().insert(
327                            http::header::ALLOW,
328                            HeaderValue::from_static("GET, HEAD, OPTIONS, TRACE"),
329                        );
330                    }
331                    *response.status_mut() = StatusCode::OK;
332                    Some(response)
333                } else {
334                    // Decrement Max-Forwards
335                    let new_value = value - 1;
336                    request.headers_mut().insert(
337                        http::header::MAX_FORWARDS,
338                        HeaderValue::from_str(&new_value.to_string()).unwrap(),
339                    );
340                    None
341                }
342            } else {
343                None // Invalid number format
344            }
345        } else {
346            None // Invalid header value format
347        }
348    } else {
349        None // No Max-Forwards header
350    }
351}
352
/// Hop-by-hop headers kept on WebSocket upgrade requests and on the matching
/// `101 Switching Protocols` responses, because the handshake needs them to
/// reach the other side.
static WEBSOCKET_HEADERS: &[&str] = &[
    "connection", // carries the "Upgrade" option
    "upgrade",    // names the protocol being switched to
];
355
356/// Process Connection header and remove hop-by-hop headers
357fn process_connection_header(request: &mut Request<Body>, preserve_websocket: bool) {
358    let mut headers_to_remove = HashSet::new();
359
360    // Add standard hop-by-hop headers
361    for &name in HOP_BY_HOP_HEADERS {
362        // Skip connection and upgrade headers if this is a WebSocket upgrade
363        if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
364            continue;
365        }
366        headers_to_remove.insert(HeaderName::from_static(name));
367    }
368
369    // Get headers listed in Connection header
370    if let Some(connection) = request
371        .headers()
372        .get_all(http::header::CONNECTION)
373        .iter()
374        .next()
375        && let Ok(connection_str) = connection.to_str()
376    {
377        for header in connection_str.split(',') {
378            let header = header.trim();
379            // Skip connection and upgrade headers if this is a WebSocket upgrade
380            if preserve_websocket
381                && WEBSOCKET_HEADERS
382                    .iter()
383                    .any(|h| header.eq_ignore_ascii_case(h))
384            {
385                continue;
386            }
387            if let Ok(header_name) = HeaderName::from_str(header)
388                && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
389            {
390                headers_to_remove.insert(header_name);
391            }
392        }
393    }
394
395    // Remove all identified headers (case-insensitive)
396    let headers_to_remove = headers_to_remove; // Make immutable
397    let headers_to_remove: Vec<_> = request
398        .headers()
399        .iter()
400        .filter(|(k, _)| {
401            headers_to_remove
402                .iter()
403                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
404        })
405        .map(|(k, _)| k.clone())
406        .collect();
407
408    for header in headers_to_remove {
409        request.headers_mut().remove(&header);
410    }
411}
412
413/// Add Via header to the request
414fn add_via_header(request: &mut Request<Body>, config: &Rfc9110Config) -> Option<HeaderValue> {
415    // Get the protocol version from the request
416    let protocol_version = match request.version() {
417        http::Version::HTTP_09 => "0.9",
418        http::Version::HTTP_10 => "1.0",
419        http::Version::HTTP_11 => "1.1",
420        http::Version::HTTP_2 => "2.0",
421        http::Version::HTTP_3 => "3.0",
422        _ => "1.1", // Default to HTTP/1.1 for unknown versions
423    };
424
425    // Get the pseudonym from the config or use the default
426    let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
427
428    // If we're in firewall mode, always use "1.1 firewall"
429    if config.pseudonym.is_some() && !config.combine_via {
430        let via = HeaderValue::from_static("1.1 firewall");
431        request.headers_mut().insert(http::header::VIA, via.clone());
432        return Some(via);
433    }
434
435    // Get any existing Via headers
436    let mut via_values = Vec::new();
437    if let Some(existing_via) = request.headers().get(http::header::VIA)
438        && let Ok(existing_via_str) = existing_via.to_str()
439    {
440        // If we're combining Via headers and have a pseudonym, replace all entries with our protocol version
441        if config.combine_via && config.pseudonym.is_some() {
442            let entries: Vec<_> = existing_via_str.split(',').map(|s| s.trim()).collect();
443            let all_same_protocol = entries.iter().all(|s| s.starts_with(protocol_version));
444            if all_same_protocol {
445                let via = HeaderValue::from_str(&format!(
446                    "{} {}",
447                    protocol_version,
448                    config.pseudonym.as_ref().unwrap()
449                ))
450                .ok()?;
451                request.headers_mut().insert(http::header::VIA, via.clone());
452                return Some(via);
453            }
454        }
455        via_values.extend(existing_via_str.split(',').map(|s| s.trim().to_string()));
456    }
457
458    // Add our new Via header value
459    let new_value = format!("{protocol_version} {pseudonym}");
460    via_values.push(new_value);
461
462    // Create the combined Via header value
463    let combined_via = via_values.join(", ");
464    let via = HeaderValue::from_str(&combined_via).ok()?;
465    request.headers_mut().insert(http::header::VIA, via.clone());
466    Some(via)
467}
468
469/// Process response headers according to RFC9110
470fn process_response_headers(response: &mut Response<Body>, preserve_websocket: bool) {
471    let mut headers_to_remove = HashSet::new();
472
473    // Add standard hop-by-hop headers
474    for &name in HOP_BY_HOP_HEADERS {
475        // Skip connection and upgrade headers if this is a WebSocket upgrade response
476        if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
477            continue;
478        }
479        headers_to_remove.insert(HeaderName::from_static(name));
480    }
481
482    // Get headers listed in Connection header
483    if let Some(connection) = response
484        .headers()
485        .get_all(http::header::CONNECTION)
486        .iter()
487        .next()
488        && let Ok(connection_str) = connection.to_str()
489    {
490        for header in connection_str.split(',') {
491            let header = header.trim();
492            // Skip connection and upgrade headers if this is a WebSocket upgrade response
493            if preserve_websocket
494                && WEBSOCKET_HEADERS
495                    .iter()
496                    .any(|h| header.eq_ignore_ascii_case(h))
497            {
498                continue;
499            }
500            if let Ok(header_name) = HeaderName::from_str(header)
501                && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
502            {
503                headers_to_remove.insert(header_name);
504            }
505        }
506    }
507
508    // Remove all identified headers (case-insensitive)
509    let headers_to_remove = headers_to_remove; // Make immutable
510    let headers_to_remove: Vec<_> = response
511        .headers()
512        .iter()
513        .filter(|(k, _)| {
514            headers_to_remove
515                .iter()
516                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
517        })
518        .map(|(k, _)| k.clone())
519        .collect();
520
521    for header in headers_to_remove {
522        response.headers_mut().remove(&header);
523    }
524
525    // Handle Via header in response - if in firewall mode, replace all entries with "1.1 firewall"
526    if let Some(via) = response.headers().get(http::header::VIA)
527        && let Ok(via_str) = via.to_str()
528        && via_str.contains("firewall")
529    {
530        response
531            .headers_mut()
532            .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
533    }
534}
535
536/// Check if a header is a hop-by-hop header
537fn is_hop_by_hop_header(name: &HeaderName) -> bool {
538    HOP_BY_HOP_HEADERS
539        .iter()
540        .any(|h| name.as_str().eq_ignore_ascii_case(h))
541        || name.as_str().eq_ignore_ascii_case("via")
542}
543
544/// Check if a header is a known end-to-end header
545fn is_end_to_end_header(name: &HeaderName) -> bool {
546    matches!(
547        name.as_str(),
548        "cache-control"
549            | "authorization"
550            | "content-length"
551            | "content-type"
552            | "content-encoding"
553            | "accept"
554            | "accept-encoding"
555            | "accept-language"
556            | "range"
557            | "cookie"
558            | "set-cookie"
559            | "etag"
560    )
561}
562
563/// Check if a request appears to be a WebSocket upgrade request.
564/// This is detected by the presence of sec-websocket-key and sec-websocket-version headers,
565/// which are required for all WebSocket handshakes per RFC 6455.
566fn is_websocket_upgrade_request(request: &Request<Body>) -> bool {
567    request.headers().contains_key("sec-websocket-key")
568        && request.headers().contains_key("sec-websocket-version")
569}