axum_reverse_proxy/
rfc9110.rs

1//! RFC9110 (HTTP Semantics) Compliance Layer
2//!
3//! This module implements middleware for RFC9110 compliance, focusing on:
4//! - [RFC9110 Section 7: Message Routing](https://www.rfc-editor.org/rfc/rfc9110.html#section-7)
//! - [RFC9110 Section 7.6: Message Forwarding](https://www.rfc-editor.org/rfc/rfc9110.html#section-7.6)
6//!
7//! Key compliance points:
8//! 1. Connection header handling (Section 7.6.1)
9//!    - Remove Connection header and all headers listed within it
10//!    - Remove standard hop-by-hop headers
11//!
12//! 2. Via header handling (Section 7.6.3)
13//!    - Add Via header entries for request/response
14//!    - Support protocol version, pseudonym, and comments
15//!    - Optional combining of multiple entries
16//!
17//! 3. Max-Forwards handling (Section 7.6.2)
18//!    - Process for TRACE and OPTIONS methods
19//!    - Decrement value or respond directly if zero
20//!
21//! 4. Loop detection (Section 7.3)
22//!    - Detect loops using Via headers
23//!    - Check server names/aliases
24//!    - Return 508 Loop Detected status
25//!
26//! 5. End-to-end and Hop-by-hop Headers (Section 7.6.1)
27//!    - Preserve end-to-end headers
28//!    - Remove hop-by-hop headers
29
30use axum::{
31    body::Body,
32    http::{HeaderValue, Method, Request, Response, StatusCode, header::HeaderName},
33};
34use std::{
35    collections::{HashMap, HashSet},
36    future::Future,
37    pin::Pin,
38    str::FromStr,
39    task::{Context, Poll},
40};
41
/// Lowercase names of the hop-by-hop headers this proxy strips before forwarding.
///
/// These are the connection-specific fields called out in RFC 9110 Section 7.6.1
/// (`Connection`, `Keep-Alive`, `Proxy-Connection`, `Transfer-Encoding`, `TE`,
/// `Upgrade`) plus `Trailer`.
static HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection", "keep-alive", "proxy-connection", "transfer-encoding",
    "te", "trailer", "upgrade",
];
52use tower::{Layer, Service};
53
/// One parsed member of a `Via` header field:
/// `received-protocol received-by [comment]` (RFC 9110 Section 7.6.3).
#[derive(Debug, Clone)]
#[allow(dead_code)] // Not every field is consumed by the middleware yet; kept for extensibility.
struct ViaEntry {
    /// Received protocol token, e.g. `"1.1"`.
    protocol: String,
    /// Received-by host or pseudonym with any `:port` stripped, e.g. `"proxy1"`.
    pseudonym: String,
    /// Text after the first `:` of the received-by token, e.g. `"8080"`.
    port: Option<String>,
    /// Parenthesised comment including the parentheses, e.g. `"(Proxy Software 1.0)"`.
    comment: Option<String>,
}

impl ViaEntry {
    /// Parse a single Via member such as `"1.1 proxy1:8080 (Proxy Software 1.0)"`.
    ///
    /// Returns `None` when either the protocol or the received-by token is missing.
    /// The comment spans from the first `(` to the last `)` that follows it; if no
    /// `)` follows the first `(` (e.g. `"1.1 a) (b"`), no comment is recorded.
    fn parse(entry: &str) -> Option<Self> {
        let mut parts = entry.split_whitespace();

        let protocol = parts.next()?.to_string();

        // Received-by token, optionally carrying a port after the first ':'.
        let received_by = parts.next()?;
        let (pseudonym, port) = match received_by.split_once(':') {
            Some((name, port)) => (name.to_string(), Some(port.to_string())),
            None => (received_by.to_string(), None),
        };

        // Search for the closing parenthesis only *after* the opening one. Searching the
        // whole entry can return an index before `start`, and slicing `start..=end` would
        // then panic on header input that comes from clients or upstream servers.
        let comment = entry.find('(').and_then(|start| {
            entry[start..]
                .rfind(')')
                .map(|offset| entry[start..=start + offset].to_string())
        });

        Some(ViaEntry {
            protocol,
            pseudonym,
            port,
            comment,
        })
    }
}

impl std::fmt::Display for ViaEntry {
    /// Render the entry back into Via member syntax: `protocol pseudonym[:port][ comment]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} {}", self.protocol, self.pseudonym)?;
        if let Some(port) = &self.port {
            write!(f, ":{port}")?;
        }
        if let Some(comment) = &self.comment {
            write!(f, " {comment}")?;
        }
        Ok(())
    }
}
106
107/// Parse a Via header value into a vector of ViaEntry structs
108fn parse_via_header(header: &str) -> Vec<ViaEntry> {
109    header
110        .split(',')
111        .filter_map(|entry| ViaEntry::parse(entry.trim()))
112        .collect()
113}
114
115/// Group Via entries by protocol version
116fn group_by_protocol(entries: Vec<ViaEntry>) -> HashMap<String, Vec<ViaEntry>> {
117    let mut groups = HashMap::new();
118    for entry in entries {
119        groups
120            .entry(entry.protocol.clone())
121            .or_insert_with(Vec::new)
122            .push(entry);
123    }
124    groups
125}
126
/// Settings that control how the RFC 9110 middleware treats forwarded messages.
#[derive(Clone, Debug)]
pub struct Rfc9110Config {
    /// Host names identifying this proxy. A request whose URI host is in this set
    /// is answered with `508 Loop Detected` instead of being forwarded.
    pub server_names: Option<HashSet<String>>,
    /// Name this proxy writes into `Via` entries and looks for during loop
    /// detection; `"proxy"` is used when unset.
    pub pseudonym: Option<String>,
    /// Whether existing `Via` entries sharing our protocol version may be
    /// collapsed into our own entry.
    pub combine_via: bool,
    /// Whether to keep the `Connection` and `Upgrade` headers on WebSocket upgrade
    /// requests (and on their `101 Switching Protocols` responses) instead of
    /// stripping them as hop-by-hop. Default: true
    pub preserve_websocket_headers: bool,
}

impl Default for Rfc9110Config {
    /// No server names, no pseudonym, Via combining on, WebSocket headers preserved.
    fn default() -> Self {
        Rfc9110Config {
            combine_via: true,
            preserve_websocket_headers: true,
            server_names: None,
            pseudonym: None,
        }
    }
}
153
154/// Layer that applies RFC9110 middleware
155#[derive(Clone)]
156pub struct Rfc9110Layer {
157    config: Rfc9110Config,
158}
159
160impl Default for Rfc9110Layer {
161    fn default() -> Self {
162        Self::new()
163    }
164}
165
166impl Rfc9110Layer {
167    /// Create a new RFC9110 layer with default configuration
168    pub fn new() -> Self {
169        Self {
170            config: Rfc9110Config::default(),
171        }
172    }
173
174    /// Create a new RFC9110 layer with custom configuration
175    pub fn with_config(config: Rfc9110Config) -> Self {
176        Self { config }
177    }
178}
179
180impl<S> Layer<S> for Rfc9110Layer {
181    type Service = Rfc9110<S>;
182
183    fn layer(&self, inner: S) -> Self::Service {
184        Rfc9110 {
185            inner,
186            config: self.config.clone(),
187        }
188    }
189}
190
/// RFC9110 middleware service
///
/// Built by [`Rfc9110Layer`]. Each call runs loop detection, Max-Forwards handling,
/// hop-by-hop header stripping and Via handling around the wrapped service.
#[derive(Clone)]
pub struct Rfc9110<S> {
    // The wrapped service that requests are forwarded to.
    inner: S,
    // Configuration copied from the layer that built this service.
    config: Rfc9110Config,
}
197
198impl<S> Service<Request<Body>> for Rfc9110<S>
199where
200    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
201    S::Future: Send + 'static,
202{
203    type Response = S::Response;
204    type Error = S::Error;
205    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
206
207    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
208        self.inner.poll_ready(cx)
209    }
210
211    fn call(&mut self, mut request: Request<Body>) -> Self::Future {
212        let mut inner = self.inner.clone();
213        let config = self.config.clone();
214
215        Box::pin(async move {
216            // 1. Check for loops
217            if let Some(response) = detect_loop(&request, &config) {
218                return Ok(response);
219            }
220
221            // Save original Max-Forwards value for non-TRACE/OPTIONS methods
222            let original_max_forwards =
223                if request.method() != Method::TRACE && request.method() != Method::OPTIONS {
224                    request.headers().get(http::header::MAX_FORWARDS).cloned()
225                } else {
226                    None
227                };
228
229            // 2. Process Max-Forwards
230            if let Some(response) = process_max_forwards(&mut request) {
231                return Ok(response);
232            }
233
234            // Save the Max-Forwards value after processing
235            let max_forwards = request.headers().get(http::header::MAX_FORWARDS).cloned();
236
237            // Detect WebSocket upgrade request to preserve necessary headers
238            let is_websocket =
239                config.preserve_websocket_headers && is_websocket_upgrade_request(&request);
240
241            // 3. Process Connection header and remove hop-by-hop headers
242            process_connection_header(&mut request, is_websocket);
243
244            // Save end-to-end headers after processing Connection header
245            let preserved_headers = request.headers().clone();
246
247            // 4. Add Via header and save it for the response
248            let via_header = add_via_header(&mut request, &config);
249
250            // 5. Forward the request
251            let mut response = inner.call(request).await?;
252
253            // 6. Process response headers (preserve WebSocket headers for 101 responses too)
254            let is_websocket_response =
255                is_websocket && response.status() == StatusCode::SWITCHING_PROTOCOLS;
256            process_response_headers(&mut response, is_websocket_response);
257
258            // 7. Add Via header to response (use the same one we set in the request)
259            if let Some(via) = via_header {
260                // In firewall mode, always use "1.1 firewall"
261                if config.pseudonym.is_some() && !config.combine_via {
262                    response
263                        .headers_mut()
264                        .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
265                } else {
266                    response.headers_mut().insert(http::header::VIA, via);
267                }
268            }
269
270            // 8. Restore Max-Forwards header
271            if let Some(max_forwards) = original_max_forwards {
272                // For non-TRACE/OPTIONS methods, restore original value
273                response
274                    .headers_mut()
275                    .insert(http::header::MAX_FORWARDS, max_forwards);
276            } else if let Some(max_forwards) = max_forwards {
277                // For TRACE/OPTIONS, copy the decremented value to the response
278                response
279                    .headers_mut()
280                    .insert(http::header::MAX_FORWARDS, max_forwards);
281            }
282
283            // 9. Restore preserved end-to-end headers
284            for (name, value) in preserved_headers.iter() {
285                if !is_hop_by_hop_header(name) {
286                    response.headers_mut().insert(name, value.clone());
287                }
288            }
289
290            Ok(response)
291        })
292    }
293}
294
295/// Detect request loops based on Via headers and server names
296fn detect_loop(request: &Request<Body>, config: &Rfc9110Config) -> Option<Response<Body>> {
297    // 1. Check if the target host matches any of our server names
298    if let Some(server_names) = &config.server_names
299        && let Some(host) = request.uri().host()
300        && server_names.contains(host)
301    {
302        let mut response = Response::new(Body::empty());
303        *response.status_mut() = StatusCode::LOOP_DETECTED;
304        return Some(response);
305    }
306
307    // 2. Check for loops in Via headers
308    if let Some(via) = request.headers().get(http::header::VIA)
309        && let Ok(via_str) = via.to_str()
310    {
311        let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
312        let via_entries: Vec<&str> = via_str.split(',').map(str::trim).collect();
313
314        // Check if our pseudonym appears in any Via header
315        for entry in via_entries {
316            let parts: Vec<&str> = entry.split_whitespace().collect();
317            if parts.len() >= 2 && parts[1] == pseudonym {
318                let mut response = Response::new(Body::empty());
319                *response.status_mut() = StatusCode::LOOP_DETECTED;
320                return Some(response);
321            }
322        }
323    }
324
325    None
326}
327
328/// Process Max-Forwards header for TRACE and OPTIONS methods
329fn process_max_forwards(request: &mut Request<Body>) -> Option<Response<Body>> {
330    let method = request.method();
331
332    // Only process Max-Forwards for TRACE and OPTIONS
333    if let Some(max_forwards) = request.headers().get(http::header::MAX_FORWARDS) {
334        if *method != Method::TRACE && *method != Method::OPTIONS {
335            // For other methods, just preserve the header
336            return None;
337        }
338
339        if let Ok(value_str) = max_forwards.to_str() {
340            if let Ok(value) = value_str.parse::<u32>() {
341                if value == 0 {
342                    let mut response = Response::new(Body::empty());
343                    if *method == Method::TRACE {
344                        *response.body_mut() = Body::from(format!("{request:?}"));
345                    } else {
346                        // For OPTIONS, return 200 OK with Allow header
347                        response.headers_mut().insert(
348                            http::header::ALLOW,
349                            HeaderValue::from_static("GET, HEAD, OPTIONS, TRACE"),
350                        );
351                    }
352                    *response.status_mut() = StatusCode::OK;
353                    Some(response)
354                } else {
355                    // Decrement Max-Forwards
356                    let new_value = value - 1;
357                    request.headers_mut().insert(
358                        http::header::MAX_FORWARDS,
359                        HeaderValue::from_str(&new_value.to_string()).unwrap(),
360                    );
361                    None
362                }
363            } else {
364                None // Invalid number format
365            }
366        } else {
367            None // Invalid header value format
368        }
369    } else {
370        None // No Max-Forwards header
371    }
372}
373
/// Hop-by-hop header names (lowercase) that must survive stripping during a
/// WebSocket upgrade, since the handshake relies on `Connection` and `Upgrade`.
static WEBSOCKET_HEADERS: &[&str] = &[
    "connection",
    "upgrade",
];
376
377/// Process Connection header and remove hop-by-hop headers
378fn process_connection_header(request: &mut Request<Body>, preserve_websocket: bool) {
379    let mut headers_to_remove = HashSet::new();
380
381    // Add standard hop-by-hop headers
382    for &name in HOP_BY_HOP_HEADERS {
383        // Skip connection and upgrade headers if this is a WebSocket upgrade
384        if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
385            continue;
386        }
387        headers_to_remove.insert(HeaderName::from_static(name));
388    }
389
390    // Get headers listed in Connection header
391    if let Some(connection) = request
392        .headers()
393        .get_all(http::header::CONNECTION)
394        .iter()
395        .next()
396        && let Ok(connection_str) = connection.to_str()
397    {
398        for header in connection_str.split(',') {
399            let header = header.trim();
400            // Skip connection and upgrade headers if this is a WebSocket upgrade
401            if preserve_websocket
402                && WEBSOCKET_HEADERS
403                    .iter()
404                    .any(|h| header.eq_ignore_ascii_case(h))
405            {
406                continue;
407            }
408            if let Ok(header_name) = HeaderName::from_str(header)
409                && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
410            {
411                headers_to_remove.insert(header_name);
412            }
413        }
414    }
415
416    // Remove all identified headers (case-insensitive)
417    let headers_to_remove = headers_to_remove; // Make immutable
418    let headers_to_remove: Vec<_> = request
419        .headers()
420        .iter()
421        .filter(|(k, _)| {
422            headers_to_remove
423                .iter()
424                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
425        })
426        .map(|(k, _)| k.clone())
427        .collect();
428
429    for header in headers_to_remove {
430        request.headers_mut().remove(&header);
431    }
432}
433
434/// Add Via header to the request
435fn add_via_header(request: &mut Request<Body>, config: &Rfc9110Config) -> Option<HeaderValue> {
436    // Get the protocol version from the request
437    let protocol_version = match request.version() {
438        http::Version::HTTP_09 => "0.9",
439        http::Version::HTTP_10 => "1.0",
440        http::Version::HTTP_11 => "1.1",
441        http::Version::HTTP_2 => "2.0",
442        http::Version::HTTP_3 => "3.0",
443        _ => "1.1", // Default to HTTP/1.1 for unknown versions
444    };
445
446    // Get the pseudonym from the config or use the default
447    let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
448
449    // If we're in firewall mode, always use "1.1 firewall"
450    if config.pseudonym.is_some() && !config.combine_via {
451        let via = HeaderValue::from_static("1.1 firewall");
452        request.headers_mut().insert(http::header::VIA, via.clone());
453        return Some(via);
454    }
455
456    // Get any existing Via headers
457    let mut via_values = Vec::new();
458    if let Some(existing_via) = request.headers().get(http::header::VIA)
459        && let Ok(existing_via_str) = existing_via.to_str()
460    {
461        // If we're combining Via headers and have a pseudonym, replace all entries with our protocol version
462        if config.combine_via && config.pseudonym.is_some() {
463            let entries: Vec<_> = existing_via_str.split(',').map(|s| s.trim()).collect();
464            let all_same_protocol = entries.iter().all(|s| s.starts_with(protocol_version));
465            if all_same_protocol {
466                let via = HeaderValue::from_str(&format!(
467                    "{} {}",
468                    protocol_version,
469                    config.pseudonym.as_ref().unwrap()
470                ))
471                .ok()?;
472                request.headers_mut().insert(http::header::VIA, via.clone());
473                return Some(via);
474            }
475        }
476        via_values.extend(existing_via_str.split(',').map(|s| s.trim().to_string()));
477    }
478
479    // Add our new Via header value
480    let new_value = format!("{protocol_version} {pseudonym}");
481    via_values.push(new_value);
482
483    // Create the combined Via header value
484    let combined_via = via_values.join(", ");
485    let via = HeaderValue::from_str(&combined_via).ok()?;
486    request.headers_mut().insert(http::header::VIA, via.clone());
487    Some(via)
488}
489
490/// Process response headers according to RFC9110
491fn process_response_headers(response: &mut Response<Body>, preserve_websocket: bool) {
492    let mut headers_to_remove = HashSet::new();
493
494    // Add standard hop-by-hop headers
495    for &name in HOP_BY_HOP_HEADERS {
496        // Skip connection and upgrade headers if this is a WebSocket upgrade response
497        if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
498            continue;
499        }
500        headers_to_remove.insert(HeaderName::from_static(name));
501    }
502
503    // Get headers listed in Connection header
504    if let Some(connection) = response
505        .headers()
506        .get_all(http::header::CONNECTION)
507        .iter()
508        .next()
509        && let Ok(connection_str) = connection.to_str()
510    {
511        for header in connection_str.split(',') {
512            let header = header.trim();
513            // Skip connection and upgrade headers if this is a WebSocket upgrade response
514            if preserve_websocket
515                && WEBSOCKET_HEADERS
516                    .iter()
517                    .any(|h| header.eq_ignore_ascii_case(h))
518            {
519                continue;
520            }
521            if let Ok(header_name) = HeaderName::from_str(header)
522                && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
523            {
524                headers_to_remove.insert(header_name);
525            }
526        }
527    }
528
529    // Remove all identified headers (case-insensitive)
530    let headers_to_remove = headers_to_remove; // Make immutable
531    let headers_to_remove: Vec<_> = response
532        .headers()
533        .iter()
534        .filter(|(k, _)| {
535            headers_to_remove
536                .iter()
537                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
538        })
539        .map(|(k, _)| k.clone())
540        .collect();
541
542    for header in headers_to_remove {
543        response.headers_mut().remove(&header);
544    }
545
546    // Handle Via header in response
547    if let Some(via) = response.headers().get(http::header::VIA)
548        && let Ok(via_str) = via.to_str()
549    {
550        let entries = parse_via_header(via_str);
551        let _groups = group_by_protocol(entries);
552
553        // If in firewall mode, replace all entries with "1.1 firewall"
554        if via_str.contains("firewall") {
555            response
556                .headers_mut()
557                .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
558        }
559    }
560}
561
562/// Check if a header is a hop-by-hop header
563fn is_hop_by_hop_header(name: &HeaderName) -> bool {
564    HOP_BY_HOP_HEADERS
565        .iter()
566        .any(|h| name.as_str().eq_ignore_ascii_case(h))
567        || name.as_str().eq_ignore_ascii_case("via")
568}
569
570/// Check if a header is a known end-to-end header
571fn is_end_to_end_header(name: &HeaderName) -> bool {
572    matches!(
573        name.as_str(),
574        "cache-control"
575            | "authorization"
576            | "content-length"
577            | "content-type"
578            | "content-encoding"
579            | "accept"
580            | "accept-encoding"
581            | "accept-language"
582            | "range"
583            | "cookie"
584            | "set-cookie"
585            | "etag"
586    )
587}
588
589/// Check if a request appears to be a WebSocket upgrade request.
590/// This is detected by the presence of sec-websocket-key and sec-websocket-version headers,
591/// which are required for all WebSocket handshakes per RFC 6455.
592fn is_websocket_upgrade_request(request: &Request<Body>) -> bool {
593    request.headers().contains_key("sec-websocket-key")
594        && request.headers().contains_key("sec-websocket-version")
595}