axum_reverse_proxy/
rfc9110.rs

1//! RFC9110 (HTTP Semantics) Compliance Layer
2//!
3//! This module implements middleware for RFC9110 compliance, focusing on:
4//! - [RFC9110 Section 7: Message Routing](https://www.rfc-editor.org/rfc/rfc9110.html#section-7)
5//! - [RFC9110 Section 7.6: Message Forwarding](https://www.rfc-editor.org/rfc/rfc9110.html#section-7.6)
6//!
7//! Key compliance points:
8//! 1. Connection header handling (Section 7.6.1)
9//!    - Remove Connection header and all headers listed within it
10//!    - Remove standard hop-by-hop headers
11//!
12//! 2. Via header handling (Section 7.6.3)
13//!    - Add Via header entries for request/response
14//!    - Support protocol version, pseudonym, and comments
15//!    - Optional combining of multiple entries
16//!
17//! 3. Max-Forwards handling (Section 7.6.2)
18//!    - Process for TRACE and OPTIONS methods
19//!    - Decrement value or respond directly if zero
20//!
21//! 4. Loop detection (Section 7.3)
22//!    - Detect loops using Via headers
23//!    - Check server names/aliases
24//!    - Return 508 Loop Detected status
25//!
26//! 5. End-to-end and Hop-by-hop Headers (Section 7.6.1)
27//!    - Preserve end-to-end headers
28//!    - Remove hop-by-hop headers
29
30use axum::{
31    body::Body,
32    http::{HeaderValue, Method, Request, Response, StatusCode, header::HeaderName},
33};
34use std::{
35    collections::{HashMap, HashSet},
36    future::Future,
37    pin::Pin,
38    str::FromStr,
39    task::{Context, Poll},
40};
41
/// Header names this module always strips before forwarding a message.
///
/// Entries are written in lowercase because they are fed to
/// `HeaderName::from_static` and compared against `HeaderName::as_str()`.
static HOP_BY_HOP_HEADERS: &[&str] = &[
    // Connection management
    "connection",
    "keep-alive",
    "proxy-connection",
    // Message framing / transfer coding negotiation
    "transfer-encoding",
    "te",
    "trailer",
    // Protocol switching
    "upgrade",
];
52use tower::{Layer, Service};
53
/// One comma-separated element of a `Via` header field.
///
/// The textual layout handled here is
/// `<protocol> <pseudonym>[:<port>] [(<comment>)]`.
#[derive(Debug, Clone)]
#[allow(dead_code)] // Fields are used for future extensibility
struct ViaEntry {
    /// First whitespace-separated token, e.g. `"1.1"`.
    protocol: String,
    /// Second token up to (not including) the first `:`, e.g. `"proxy1"`.
    pseudonym: String,
    /// Text after the first `:` of the second token, e.g. `"8080"`.
    port: Option<String>,
    /// Comment including its parentheses, e.g. `"(Proxy Software 1.0)"`.
    comment: Option<String>,
}

impl ViaEntry {
    /// Parses a single `Via` element.
    ///
    /// Returns `None` when the entry has fewer than two whitespace-separated
    /// tokens (protocol and pseudonym are both required). Malformed comments
    /// never cause a panic: a comment is only reported when a `)` follows the
    /// first `(`.
    fn parse(entry: &str) -> Option<Self> {
        let mut tokens = entry.split_whitespace();

        let protocol = tokens.next()?.to_string();

        // The second token is `pseudonym` or `pseudonym:port`.
        let received_by = tokens.next()?;
        let (pseudonym, port) = match received_by.split_once(':') {
            Some((name, port)) => (name.to_string(), Some(port.to_string())),
            None => (received_by.to_string(), None),
        };

        // Take everything from the first '(' through the last ')' that comes
        // after it. Searching only the tail guarantees `start <= end`, so input
        // such as "1.1 a ) (" cannot produce an inverted (panicking) slice.
        let comment = entry.find('(').and_then(|start| {
            entry[start..]
                .rfind(')')
                .map(|offset| entry[start..=start + offset].to_string())
        });

        Some(ViaEntry {
            protocol,
            pseudonym,
            port,
            comment,
        })
    }
}

impl std::fmt::Display for ViaEntry {
    /// Writes the entry back as `protocol pseudonym[:port][ comment]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} {}", self.protocol, self.pseudonym)?;
        if let Some(port) = &self.port {
            write!(f, ":{port}")?;
        }
        if let Some(comment) = &self.comment {
            write!(f, " {comment}")?;
        }
        Ok(())
    }
}
106
107/// Parse a Via header value into a vector of ViaEntry structs
108fn parse_via_header(header: &str) -> Vec<ViaEntry> {
109    header
110        .split(',')
111        .filter_map(|entry| ViaEntry::parse(entry.trim()))
112        .collect()
113}
114
115/// Group Via entries by protocol version
116fn group_by_protocol(entries: Vec<ViaEntry>) -> HashMap<String, Vec<ViaEntry>> {
117    let mut groups = HashMap::new();
118    for entry in entries {
119        groups
120            .entry(entry.protocol.clone())
121            .or_insert_with(Vec::new)
122            .push(entry);
123    }
124    groups
125}
126
/// Settings consumed by the RFC9110 middleware.
#[derive(Clone, Debug)]
pub struct Rfc9110Config {
    /// Host names that identify this server; a request whose URI host is in
    /// this set is reported as a loop. `None` disables the host check.
    pub server_names: Option<HashSet<String>>,
    /// Name written into `Via` entries. When `None`, `"proxy"` is used.
    pub pseudonym: Option<String>,
    /// Whether `Via` entries sharing our protocol version are collapsed into
    /// a single entry. When this is `false` and a pseudonym is configured,
    /// the `Via` header is always set to `"1.1 firewall"`.
    pub combine_via: bool,
}

impl Default for Rfc9110Config {
    /// No server names, no custom pseudonym, `Via` combining enabled.
    fn default() -> Self {
        Rfc9110Config {
            combine_via: true,
            pseudonym: None,
            server_names: None,
        }
    }
}
147
148/// Layer that applies RFC9110 middleware
149#[derive(Clone)]
150pub struct Rfc9110Layer {
151    config: Rfc9110Config,
152}
153
154impl Default for Rfc9110Layer {
155    fn default() -> Self {
156        Self::new()
157    }
158}
159
160impl Rfc9110Layer {
161    /// Create a new RFC9110 layer with default configuration
162    pub fn new() -> Self {
163        Self {
164            config: Rfc9110Config::default(),
165        }
166    }
167
168    /// Create a new RFC9110 layer with custom configuration
169    pub fn with_config(config: Rfc9110Config) -> Self {
170        Self { config }
171    }
172}
173
174impl<S> Layer<S> for Rfc9110Layer {
175    type Service = Rfc9110<S>;
176
177    fn layer(&self, inner: S) -> Self::Service {
178        Rfc9110 {
179            inner,
180            config: self.config.clone(),
181        }
182    }
183}
184
185/// RFC9110 middleware service
186#[derive(Clone)]
187pub struct Rfc9110<S> {
188    inner: S,
189    config: Rfc9110Config,
190}
191
192impl<S> Service<Request<Body>> for Rfc9110<S>
193where
194    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
195    S::Future: Send + 'static,
196{
197    type Response = S::Response;
198    type Error = S::Error;
199    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
200
201    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
202        self.inner.poll_ready(cx)
203    }
204
205    fn call(&mut self, mut request: Request<Body>) -> Self::Future {
206        let mut inner = self.inner.clone();
207        let config = self.config.clone();
208
209        Box::pin(async move {
210            // 1. Check for loops
211            if let Some(response) = detect_loop(&request, &config) {
212                return Ok(response);
213            }
214
215            // Save original Max-Forwards value for non-TRACE/OPTIONS methods
216            let original_max_forwards =
217                if request.method() != Method::TRACE && request.method() != Method::OPTIONS {
218                    request.headers().get(http::header::MAX_FORWARDS).cloned()
219                } else {
220                    None
221                };
222
223            // 2. Process Max-Forwards
224            if let Some(response) = process_max_forwards(&mut request) {
225                return Ok(response);
226            }
227
228            // Save the Max-Forwards value after processing
229            let max_forwards = request.headers().get(http::header::MAX_FORWARDS).cloned();
230
231            // 3. Process Connection header and remove hop-by-hop headers
232            process_connection_header(&mut request);
233
234            // Save end-to-end headers after processing Connection header
235            let preserved_headers = request.headers().clone();
236
237            // 4. Add Via header and save it for the response
238            let via_header = add_via_header(&mut request, &config);
239
240            // 5. Forward the request
241            let mut response = inner.call(request).await?;
242
243            // 6. Process response headers
244            process_response_headers(&mut response);
245
246            // 7. Add Via header to response (use the same one we set in the request)
247            if let Some(via) = via_header {
248                // In firewall mode, always use "1.1 firewall"
249                if config.pseudonym.is_some() && !config.combine_via {
250                    response
251                        .headers_mut()
252                        .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
253                } else {
254                    response.headers_mut().insert(http::header::VIA, via);
255                }
256            }
257
258            // 8. Restore Max-Forwards header
259            if let Some(max_forwards) = original_max_forwards {
260                // For non-TRACE/OPTIONS methods, restore original value
261                response
262                    .headers_mut()
263                    .insert(http::header::MAX_FORWARDS, max_forwards);
264            } else if let Some(max_forwards) = max_forwards {
265                // For TRACE/OPTIONS, copy the decremented value to the response
266                response
267                    .headers_mut()
268                    .insert(http::header::MAX_FORWARDS, max_forwards);
269            }
270
271            // 9. Restore preserved end-to-end headers
272            for (name, value) in preserved_headers.iter() {
273                if !is_hop_by_hop_header(name) {
274                    response.headers_mut().insert(name, value.clone());
275                }
276            }
277
278            Ok(response)
279        })
280    }
281}
282
283/// Detect request loops based on Via headers and server names
284fn detect_loop(request: &Request<Body>, config: &Rfc9110Config) -> Option<Response<Body>> {
285    // 1. Check if the target host matches any of our server names
286    if let Some(server_names) = &config.server_names {
287        if let Some(host) = request.uri().host() {
288            if server_names.contains(host) {
289                let mut response = Response::new(Body::empty());
290                *response.status_mut() = StatusCode::LOOP_DETECTED;
291                return Some(response);
292            }
293        }
294    }
295
296    // 2. Check for loops in Via headers
297    if let Some(via) = request.headers().get(http::header::VIA) {
298        if let Ok(via_str) = via.to_str() {
299            let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
300            let via_entries: Vec<&str> = via_str.split(',').map(str::trim).collect();
301
302            // Check if our pseudonym appears in any Via header
303            for entry in via_entries {
304                let parts: Vec<&str> = entry.split_whitespace().collect();
305                if parts.len() >= 2 && parts[1] == pseudonym {
306                    let mut response = Response::new(Body::empty());
307                    *response.status_mut() = StatusCode::LOOP_DETECTED;
308                    return Some(response);
309                }
310            }
311        }
312    }
313
314    None
315}
316
317/// Process Max-Forwards header for TRACE and OPTIONS methods
318fn process_max_forwards(request: &mut Request<Body>) -> Option<Response<Body>> {
319    let method = request.method();
320
321    // Only process Max-Forwards for TRACE and OPTIONS
322    if let Some(max_forwards) = request.headers().get(http::header::MAX_FORWARDS) {
323        if *method != Method::TRACE && *method != Method::OPTIONS {
324            // For other methods, just preserve the header
325            return None;
326        }
327
328        if let Ok(value_str) = max_forwards.to_str() {
329            if let Ok(value) = value_str.parse::<u32>() {
330                if value == 0 {
331                    let mut response = Response::new(Body::empty());
332                    if *method == Method::TRACE {
333                        *response.body_mut() = Body::from(format!("{request:?}"));
334                    } else {
335                        // For OPTIONS, return 200 OK with Allow header
336                        response.headers_mut().insert(
337                            http::header::ALLOW,
338                            HeaderValue::from_static("GET, HEAD, OPTIONS, TRACE"),
339                        );
340                    }
341                    *response.status_mut() = StatusCode::OK;
342                    Some(response)
343                } else {
344                    // Decrement Max-Forwards
345                    let new_value = value - 1;
346                    request.headers_mut().insert(
347                        http::header::MAX_FORWARDS,
348                        HeaderValue::from_str(&new_value.to_string()).unwrap(),
349                    );
350                    None
351                }
352            } else {
353                None // Invalid number format
354            }
355        } else {
356            None // Invalid header value format
357        }
358    } else {
359        None // No Max-Forwards header
360    }
361}
362
363/// Process Connection header and remove hop-by-hop headers
364fn process_connection_header(request: &mut Request<Body>) {
365    let mut headers_to_remove = HashSet::new();
366
367    // Add standard hop-by-hop headers
368    for &name in HOP_BY_HOP_HEADERS {
369        headers_to_remove.insert(HeaderName::from_static(name));
370    }
371
372    // Get headers listed in Connection header
373    if let Some(connection) = request
374        .headers()
375        .get_all(http::header::CONNECTION)
376        .iter()
377        .next()
378    {
379        if let Ok(connection_str) = connection.to_str() {
380            for header in connection_str.split(',') {
381                let header = header.trim();
382                if let Ok(header_name) = HeaderName::from_str(header) {
383                    if is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name) {
384                        headers_to_remove.insert(header_name);
385                    }
386                }
387            }
388        }
389    }
390
391    // Remove all identified headers (case-insensitive)
392    let headers_to_remove = headers_to_remove; // Make immutable
393    let headers_to_remove: Vec<_> = request
394        .headers()
395        .iter()
396        .filter(|(k, _)| {
397            headers_to_remove
398                .iter()
399                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
400        })
401        .map(|(k, _)| k.clone())
402        .collect();
403
404    for header in headers_to_remove {
405        request.headers_mut().remove(&header);
406    }
407}
408
409/// Add Via header to the request
410fn add_via_header(request: &mut Request<Body>, config: &Rfc9110Config) -> Option<HeaderValue> {
411    // Get the protocol version from the request
412    let protocol_version = match request.version() {
413        http::Version::HTTP_09 => "0.9",
414        http::Version::HTTP_10 => "1.0",
415        http::Version::HTTP_11 => "1.1",
416        http::Version::HTTP_2 => "2.0",
417        http::Version::HTTP_3 => "3.0",
418        _ => "1.1", // Default to HTTP/1.1 for unknown versions
419    };
420
421    // Get the pseudonym from the config or use the default
422    let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
423
424    // If we're in firewall mode, always use "1.1 firewall"
425    if config.pseudonym.is_some() && !config.combine_via {
426        let via = HeaderValue::from_static("1.1 firewall");
427        request.headers_mut().insert(http::header::VIA, via.clone());
428        return Some(via);
429    }
430
431    // Get any existing Via headers
432    let mut via_values = Vec::new();
433    if let Some(existing_via) = request.headers().get(http::header::VIA) {
434        if let Ok(existing_via_str) = existing_via.to_str() {
435            // If we're combining Via headers and have a pseudonym, replace all entries with our protocol version
436            if config.combine_via && config.pseudonym.is_some() {
437                let entries: Vec<_> = existing_via_str.split(',').map(|s| s.trim()).collect();
438                let all_same_protocol = entries.iter().all(|s| s.starts_with(protocol_version));
439                if all_same_protocol {
440                    let via = HeaderValue::from_str(&format!(
441                        "{} {}",
442                        protocol_version,
443                        config.pseudonym.as_ref().unwrap()
444                    ))
445                    .ok()?;
446                    request.headers_mut().insert(http::header::VIA, via.clone());
447                    return Some(via);
448                }
449            }
450            via_values.extend(existing_via_str.split(',').map(|s| s.trim().to_string()));
451        }
452    }
453
454    // Add our new Via header value
455    let new_value = format!("{protocol_version} {pseudonym}");
456    via_values.push(new_value);
457
458    // Create the combined Via header value
459    let combined_via = via_values.join(", ");
460    let via = HeaderValue::from_str(&combined_via).ok()?;
461    request.headers_mut().insert(http::header::VIA, via.clone());
462    Some(via)
463}
464
465/// Process response headers according to RFC9110
466fn process_response_headers(response: &mut Response<Body>) {
467    let mut headers_to_remove = HashSet::new();
468
469    // Add standard hop-by-hop headers
470    for &name in HOP_BY_HOP_HEADERS {
471        headers_to_remove.insert(HeaderName::from_static(name));
472    }
473
474    // Get headers listed in Connection header
475    if let Some(connection) = response
476        .headers()
477        .get_all(http::header::CONNECTION)
478        .iter()
479        .next()
480    {
481        if let Ok(connection_str) = connection.to_str() {
482            for header in connection_str.split(',') {
483                let header = header.trim();
484                if let Ok(header_name) = HeaderName::from_str(header) {
485                    if is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name) {
486                        headers_to_remove.insert(header_name);
487                    }
488                }
489            }
490        }
491    }
492
493    // Remove all identified headers (case-insensitive)
494    let headers_to_remove = headers_to_remove; // Make immutable
495    let headers_to_remove: Vec<_> = response
496        .headers()
497        .iter()
498        .filter(|(k, _)| {
499            headers_to_remove
500                .iter()
501                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
502        })
503        .map(|(k, _)| k.clone())
504        .collect();
505
506    for header in headers_to_remove {
507        response.headers_mut().remove(&header);
508    }
509
510    // Handle Via header in response
511    if let Some(via) = response.headers().get(http::header::VIA) {
512        if let Ok(via_str) = via.to_str() {
513            let entries = parse_via_header(via_str);
514            let _groups = group_by_protocol(entries);
515
516            // If in firewall mode, replace all entries with "1.1 firewall"
517            if let Some(via_header) = response.headers().get(http::header::VIA) {
518                if let Ok(via_str) = via_header.to_str() {
519                    if via_str.contains("firewall") {
520                        response
521                            .headers_mut()
522                            .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
523                    }
524                }
525            }
526        }
527    }
528}
529
530/// Check if a header is a hop-by-hop header
531fn is_hop_by_hop_header(name: &HeaderName) -> bool {
532    HOP_BY_HOP_HEADERS
533        .iter()
534        .any(|h| name.as_str().eq_ignore_ascii_case(h))
535        || name.as_str().eq_ignore_ascii_case("via")
536}
537
538/// Check if a header is a known end-to-end header
539fn is_end_to_end_header(name: &HeaderName) -> bool {
540    matches!(
541        name.as_str(),
542        "cache-control"
543            | "authorization"
544            | "content-length"
545            | "content-type"
546            | "content-encoding"
547            | "accept"
548            | "accept-encoding"
549            | "accept-language"
550            | "range"
551            | "cookie"
552            | "set-cookie"
553            | "etag"
554    )
555}