axum-reverse-proxy 1.3.0

A flexible and efficient reverse proxy implementation for Axum web applications
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
//! RFC9110 (HTTP Semantics) Compliance Layer
//!
//! This module implements middleware for RFC9110 compliance, focusing on:
//! - [RFC9110 Section 7: Routing HTTP Messages](https://www.rfc-editor.org/rfc/rfc9110.html#section-7)
//! - [RFC9110 Section 7.6: Message Forwarding](https://www.rfc-editor.org/rfc/rfc9110.html#section-7.6)
//!
//! Key compliance points:
//! 1. Connection header handling (Section 7.6.1)
//!    - Remove Connection header and all headers listed within it
//!    - Remove standard hop-by-hop headers
//!
//! 2. Via header handling (Section 7.6.3)
//!    - Add Via header entries for request/response
//!    - Support protocol version, pseudonym, and comments
//!    - Optional combining of multiple entries
//!
//! 3. Max-Forwards handling (Section 7.6.2)
//!    - Process for TRACE and OPTIONS methods
//!    - Decrement value or respond directly if zero
//!
//! 4. Loop detection (Section 7.3)
//!    - Detect loops using Via headers
//!    - Check server names/aliases
//!    - Return 508 Loop Detected status
//!
//! 5. End-to-end and Hop-by-hop Headers (Section 7.6.1)
//!    - Preserve end-to-end headers
//!    - Remove hop-by-hop headers

use axum::{
    body::Body,
    http::{HeaderValue, Method, Request, Response, StatusCode, header::HeaderName},
};
use std::{
    collections::HashSet,
    future::Future,
    pin::Pin,
    str::FromStr,
    task::{Context, Poll},
};

/// Standard hop-by-hop headers defined by RFC 9110
///
/// Every name in this table is scheduled for removal from forwarded requests
/// (`process_connection_header`) and responses (`process_response_headers`);
/// the only exception is that `connection` and `upgrade` are kept when those
/// functions are asked to preserve WebSocket upgrade headers.
/// `is_hop_by_hop_header` also consults this table.
///
/// Entries are handed to `HeaderName::from_static`, so keep them lowercase.
/// NOTE(review): `from_static` is expected to reject non-lowercase names —
/// confirm against the `http` crate documentation before adding entries.
static HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "transfer-encoding",
    "te",
    "trailer",
    "upgrade",
];
use tower::{Layer, Service};

/// Represents a single Via header entry
///
/// The textual form understood by [`ViaEntry::parse`] and produced by the
/// `Display` impl is `protocol pseudonym[:port] [(comment)]`, for example
/// `1.1 proxy1:8080 (Proxy Software 1.0)`.
#[derive(Debug, Clone)]
#[allow(dead_code)] // Fields are used for future extensibility
struct ViaEntry {
    /// Received protocol token, e.g. `"1.1"`.
    protocol: String,
    /// Name of the intermediary, e.g. `"proxy1"`.
    pseudonym: String,
    /// Port that followed the pseudonym after a `:`, e.g. `"8080"`.
    port: Option<String>,
    /// Trailing comment including its parentheses, e.g. `"(Proxy Software 1.0)"`.
    comment: Option<String>,
}

impl ViaEntry {
    /// Parses one (already comma-separated) Via list member.
    ///
    /// Returns `None` when the entry does not contain at least two
    /// whitespace-separated tokens (protocol and pseudonym).
    ///
    /// The comment is the text from the first `(` up to the last `)` that
    /// follows it. If no `)` follows the first `(`, the entry has no comment;
    /// malformed input such as `"1.1 a ) ("` therefore never panics.
    fn parse(entry: &str) -> Option<Self> {
        let mut parts = entry.split_whitespace();

        // Get protocol version
        let protocol = parts.next()?.to_string();

        // Get pseudonym and optional port (split at the first ':')
        let received_by = parts.next()?;
        let (pseudonym, port) = match received_by.split_once(':') {
            Some((name, port)) => (name.to_string(), Some(port.to_string())),
            None => (received_by.to_string(), None),
        };

        // Get optional comment. The closing parenthesis is searched for only
        // in the text after the opening one, so the slice bounds are always
        // ordered. Both delimiters are ASCII, so the inclusive end index is a
        // valid char boundary.
        let comment = entry.find('(').and_then(|start| {
            entry[start..]
                .rfind(')')
                .map(|offset| entry[start..=start + offset].to_string())
        });

        Some(ViaEntry {
            protocol,
            pseudonym,
            port,
            comment,
        })
    }
}

impl std::fmt::Display for ViaEntry {
    /// Writes the entry as `protocol pseudonym[:port][ comment]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} {}", self.protocol, self.pseudonym)?;
        if let Some(port) = &self.port {
            write!(f, ":{port}")?;
        }
        if let Some(comment) = &self.comment {
            write!(f, " {comment}")?;
        }
        Ok(())
    }
}

/// Splits a Via header value on commas and parses each trimmed piece with
/// `ViaEntry::parse`, keeping only the pieces that parse successfully.
///
/// The returned vector preserves the order of the entries in `header`.
#[allow(dead_code)] // Kept for future extensibility
fn parse_via_header(header: &str) -> Vec<ViaEntry> {
    let mut entries = Vec::new();
    for piece in header.split(',') {
        // Pieces that do not parse are silently skipped.
        if let Some(parsed) = ViaEntry::parse(piece.trim()) {
            entries.push(parsed);
        }
    }
    entries
}

/// Settings that control how the RFC9110 middleware processes messages.
#[derive(Clone, Debug)]
pub struct Rfc9110Config {
    /// Host names that identify this server. A request whose URI host is in
    /// this set is treated as a loop. `None` disables the host-name check.
    pub server_names: Option<HashSet<String>>,
    /// Name this proxy writes into Via headers and looks for when detecting
    /// loops. When `None`, the name `"proxy"` is used.
    pub pseudonym: Option<String>,
    /// Whether Via entries sharing the request's protocol version may be
    /// combined into a single entry. When this is `false` and a pseudonym is
    /// configured, the Via header is fixed to `1.1 firewall` instead.
    pub combine_via: bool,
    /// Whether the hop-by-hop headers needed for a WebSocket upgrade
    /// (Connection, Upgrade) are kept when a WebSocket upgrade request is
    /// detected.
    /// Default: true
    pub preserve_websocket_headers: bool,
}

impl Default for Rfc9110Config {
    /// No server names, no pseudonym, Via combining enabled and WebSocket
    /// header preservation enabled.
    fn default() -> Self {
        let combine_via = true;
        let preserve_websocket_headers = true;
        Self {
            server_names: None,
            pseudonym: None,
            combine_via,
            preserve_websocket_headers,
        }
    }
}

/// Tower layer that wraps services in the RFC9110 middleware.
///
/// The layer only stores an [`Rfc9110Config`]; each wrapped service receives
/// its own clone of that configuration.
#[derive(Clone)]
pub struct Rfc9110Layer {
    config: Rfc9110Config,
}

impl Default for Rfc9110Layer {
    /// Equivalent to [`Rfc9110Layer::new`].
    fn default() -> Self {
        Self::with_config(Rfc9110Config::default())
    }
}

impl Rfc9110Layer {
    /// Builds a layer that uses `Rfc9110Config::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a layer that uses the supplied configuration.
    pub fn with_config(config: Rfc9110Config) -> Self {
        Self { config }
    }
}

impl<S> Layer<S> for Rfc9110Layer {
    type Service = Rfc9110<S>;

    /// Wraps `inner` in an [`Rfc9110`] service carrying a clone of this
    /// layer's configuration.
    fn layer(&self, inner: S) -> Self::Service {
        let config = self.config.clone();
        Rfc9110 { inner, config }
    }
}

/// RFC9110 middleware service
///
/// Wraps an inner service `S` and applies loop detection, Max-Forwards
/// processing, hop-by-hop header removal and Via header handling around each
/// call. Instances are created by `Rfc9110Layer`'s `Layer::layer`
/// implementation.
#[derive(Clone)]
pub struct Rfc9110<S> {
    /// The wrapped service that receives the processed request.
    inner: S,
    /// Settings cloned from the layer that built this service.
    config: Rfc9110Config,
}

impl<S> Service<Request<Body>> for Rfc9110<S>
where
    S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut request: Request<Body>) -> Self::Future {
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);
        let config = self.config.clone();

        Box::pin(async move {
            // 1. Check for loops
            if let Some(response) = detect_loop(&request, &config) {
                return Ok(response);
            }

            // Save original Max-Forwards value for non-TRACE/OPTIONS methods
            let original_max_forwards =
                if request.method() != Method::TRACE && request.method() != Method::OPTIONS {
                    request.headers().get(http::header::MAX_FORWARDS).cloned()
                } else {
                    None
                };

            // 2. Process Max-Forwards
            if let Some(response) = process_max_forwards(&mut request) {
                return Ok(response);
            }

            // Save the Max-Forwards value after processing
            let max_forwards = request.headers().get(http::header::MAX_FORWARDS).cloned();

            // Detect WebSocket upgrade request to preserve necessary headers
            let is_websocket =
                config.preserve_websocket_headers && is_websocket_upgrade_request(&request);

            // 3. Process Connection header and remove hop-by-hop headers
            process_connection_header(&mut request, is_websocket);

            // 4. Add Via header and save it for the response
            let via_header = add_via_header(&mut request, &config);

            // 5. Forward the request
            let mut response = inner.call(request).await?;

            // 6. Process response headers (preserve WebSocket headers for 101 responses too)
            let is_websocket_response =
                is_websocket && response.status() == StatusCode::SWITCHING_PROTOCOLS;
            process_response_headers(&mut response, is_websocket_response);

            // 7. Add Via header to response (use the same one we set in the request)
            if let Some(via) = via_header {
                // In firewall mode, always use "1.1 firewall"
                if config.pseudonym.is_some() && !config.combine_via {
                    response
                        .headers_mut()
                        .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
                } else {
                    response.headers_mut().insert(http::header::VIA, via);
                }
            }

            // 8. Restore Max-Forwards header
            if let Some(max_forwards) = original_max_forwards {
                // For non-TRACE/OPTIONS methods, restore original value
                response
                    .headers_mut()
                    .insert(http::header::MAX_FORWARDS, max_forwards);
            } else if let Some(max_forwards) = max_forwards {
                // For TRACE/OPTIONS, copy the decremented value to the response
                response
                    .headers_mut()
                    .insert(http::header::MAX_FORWARDS, max_forwards);
            }

            Ok(response)
        })
    }
}

/// Detect request loops based on Via headers and server names
///
/// Returns `Some` response with status 508 (Loop Detected) and an empty body
/// when either of the following holds, and `None` otherwise:
///
/// 1. `config.server_names` is set and contains the host of the request URI.
/// 2. Any entry of any `Via` header field line names this proxy's pseudonym
///    (`config.pseudonym`, or `"proxy"` when unset) as its second
///    whitespace-separated token.
///
/// All `Via` field lines are examined, not just the first one: a sender may
/// split a list-valued field over several lines, and a marker in a later line
/// is still evidence of a loop. Values that are not visible ASCII are skipped.
fn detect_loop(request: &Request<Body>, config: &Rfc9110Config) -> Option<Response<Body>> {
    // Builds the 508 response used by both checks.
    let loop_detected = || {
        let mut response = Response::new(Body::empty());
        *response.status_mut() = StatusCode::LOOP_DETECTED;
        response
    };

    // 1. Check if the target host matches any of our server names
    if let Some(server_names) = &config.server_names
        && let Some(host) = request.uri().host()
        && server_names.contains(host)
    {
        return Some(loop_detected());
    }

    // 2. Check for our pseudonym in every Via entry of every Via field line.
    //    An entry looks like "<protocol> <received-by> ...", so the name to
    //    compare is the second whitespace-separated token.
    let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
    let already_visited = request
        .headers()
        .get_all(http::header::VIA)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|entry| entry.split_whitespace().nth(1) == Some(pseudonym));

    if already_visited {
        Some(loop_detected())
    } else {
        None
    }
}

/// Applies Max-Forwards handling to TRACE and OPTIONS requests.
///
/// Returns `None` (request should be forwarded) when the method is neither
/// TRACE nor OPTIONS, when there is no Max-Forwards header, or when its value
/// is not valid text or does not parse as a `u32`. In those cases the header
/// is left untouched.
///
/// For TRACE/OPTIONS with a parsable value:
/// - a value of `0` produces a `200 OK` response that must be returned
///   directly: TRACE gets the debug rendering of the request as its body,
///   OPTIONS gets an `Allow: GET, HEAD, OPTIONS, TRACE` header;
/// - any other value is decremented by one in place and `None` is returned.
fn process_max_forwards(request: &mut Request<Body>) -> Option<Response<Body>> {
    let is_trace = request.method() == Method::TRACE;
    if !is_trace && request.method() != Method::OPTIONS {
        // Other methods keep whatever Max-Forwards header they carry.
        return None;
    }

    // Missing header, non-text value and non-numeric value all mean
    // "nothing to do".
    let remaining: u32 = request
        .headers()
        .get(http::header::MAX_FORWARDS)?
        .to_str()
        .ok()?
        .parse()
        .ok()?;

    if remaining > 0 {
        // Still allowed to travel: forward with one hop less.
        request
            .headers_mut()
            .insert(http::header::MAX_FORWARDS, HeaderValue::from(remaining - 1));
        return None;
    }

    // Hop budget exhausted: this proxy must answer instead of forwarding.
    let mut response = Response::new(Body::empty());
    if is_trace {
        *response.body_mut() = Body::from(format!("{request:?}"));
    } else {
        response.headers_mut().insert(
            http::header::ALLOW,
            HeaderValue::from_static("GET, HEAD, OPTIONS, TRACE"),
        );
    }
    *response.status_mut() = StatusCode::OK;
    Some(response)
}

/// Headers that should be preserved for WebSocket upgrades
///
/// `process_connection_header` and `process_response_headers` skip these names
/// (both as standard hop-by-hop headers and as `Connection` options) when
/// they are called with `preserve_websocket` set to `true`.
static WEBSOCKET_HEADERS: &[&str] = &["connection", "upgrade"];

/// Process Connection header and remove hop-by-hop headers
fn process_connection_header(request: &mut Request<Body>, preserve_websocket: bool) {
    let mut headers_to_remove = HashSet::new();

    // Add standard hop-by-hop headers
    for &name in HOP_BY_HOP_HEADERS {
        // Skip connection and upgrade headers if this is a WebSocket upgrade
        if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
            continue;
        }
        headers_to_remove.insert(HeaderName::from_static(name));
    }

    // Get headers listed in Connection header
    if let Some(connection) = request
        .headers()
        .get_all(http::header::CONNECTION)
        .iter()
        .next()
        && let Ok(connection_str) = connection.to_str()
    {
        for header in connection_str.split(',') {
            let header = header.trim();
            // Skip connection and upgrade headers if this is a WebSocket upgrade
            if preserve_websocket
                && WEBSOCKET_HEADERS
                    .iter()
                    .any(|h| header.eq_ignore_ascii_case(h))
            {
                continue;
            }
            if let Ok(header_name) = HeaderName::from_str(header)
                && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
            {
                headers_to_remove.insert(header_name);
            }
        }
    }

    // Remove all identified headers (case-insensitive)
    let headers_to_remove = headers_to_remove; // Make immutable
    let headers_to_remove: Vec<_> = request
        .headers()
        .iter()
        .filter(|(k, _)| {
            headers_to_remove
                .iter()
                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
        })
        .map(|(k, _)| k.clone())
        .collect();

    for header in headers_to_remove {
        request.headers_mut().remove(&header);
    }
}

/// Add Via header to the request
///
/// Rewrites the request's `Via` header and returns the value that was set, or
/// `None` (leaving the request untouched) if the resulting text is not a valid
/// header value.
///
/// Behaviour, in order:
/// 1. Firewall mode (`config.pseudonym` set and `config.combine_via` false):
///    the header becomes exactly `1.1 firewall`.
/// 2. Otherwise the entries of **all** existing `Via` field lines are gathered
///    in order (empty list members and non-ASCII values are skipped).
///    - If a pseudonym is configured, combining is enabled, and there is at
///      least one existing entry and every one starts with the request's
///      protocol version, they are replaced by the single entry
///      `<version> <pseudonym>`.
///    - Otherwise `<version> <pseudonym>` (pseudonym defaulting to `"proxy"`)
///      is appended and the entries are joined with `", "`.
///
/// Gathering every field line matters because the header is written back with
/// `insert`, which replaces all existing `Via` lines; reading only the first
/// line would silently drop the entries carried by the others.
fn add_via_header(request: &mut Request<Body>, config: &Rfc9110Config) -> Option<HeaderValue> {
    // Get the protocol version from the request
    let protocol_version = match request.version() {
        http::Version::HTTP_09 => "0.9",
        http::Version::HTTP_10 => "1.0",
        http::Version::HTTP_11 => "1.1",
        http::Version::HTTP_2 => "2.0",
        http::Version::HTTP_3 => "3.0",
        _ => "1.1", // Default to HTTP/1.1 for unknown versions
    };

    // If we're in firewall mode, always use "1.1 firewall"
    if config.pseudonym.is_some() && !config.combine_via {
        let via = HeaderValue::from_static("1.1 firewall");
        request.headers_mut().insert(http::header::VIA, via.clone());
        return Some(via);
    }

    // The entry describing this hop.
    let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
    let own_entry = format!("{protocol_version} {pseudonym}");

    // Gather the entries of every existing Via field line, in order.
    let mut via_values: Vec<String> = request
        .headers()
        .get_all(http::header::VIA)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .map(str::to_string)
        .collect();

    // With a configured pseudonym and combining enabled, a chain whose
    // entries all share our protocol version collapses into our single entry.
    let collapse = config.combine_via
        && config.pseudonym.is_some()
        && !via_values.is_empty()
        && via_values
            .iter()
            .all(|entry| entry.starts_with(protocol_version));

    let combined_via = if collapse {
        own_entry
    } else {
        via_values.push(own_entry);
        via_values.join(", ")
    };

    let via = HeaderValue::from_str(&combined_via).ok()?;
    request.headers_mut().insert(http::header::VIA, via.clone());
    Some(via)
}

/// Process response headers according to RFC9110
fn process_response_headers(response: &mut Response<Body>, preserve_websocket: bool) {
    let mut headers_to_remove = HashSet::new();

    // Add standard hop-by-hop headers
    for &name in HOP_BY_HOP_HEADERS {
        // Skip connection and upgrade headers if this is a WebSocket upgrade response
        if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
            continue;
        }
        headers_to_remove.insert(HeaderName::from_static(name));
    }

    // Get headers listed in Connection header
    if let Some(connection) = response
        .headers()
        .get_all(http::header::CONNECTION)
        .iter()
        .next()
        && let Ok(connection_str) = connection.to_str()
    {
        for header in connection_str.split(',') {
            let header = header.trim();
            // Skip connection and upgrade headers if this is a WebSocket upgrade response
            if preserve_websocket
                && WEBSOCKET_HEADERS
                    .iter()
                    .any(|h| header.eq_ignore_ascii_case(h))
            {
                continue;
            }
            if let Ok(header_name) = HeaderName::from_str(header)
                && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
            {
                headers_to_remove.insert(header_name);
            }
        }
    }

    // Remove all identified headers (case-insensitive)
    let headers_to_remove = headers_to_remove; // Make immutable
    let headers_to_remove: Vec<_> = response
        .headers()
        .iter()
        .filter(|(k, _)| {
            headers_to_remove
                .iter()
                .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
        })
        .map(|(k, _)| k.clone())
        .collect();

    for header in headers_to_remove {
        response.headers_mut().remove(&header);
    }

    // Handle Via header in response - if in firewall mode, replace all entries with "1.1 firewall"
    if let Some(via) = response.headers().get(http::header::VIA)
        && let Ok(via_str) = via.to_str()
        && via_str.contains("firewall")
    {
        response
            .headers_mut()
            .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
    }
}

/// Reports whether `name` is treated as hop-by-hop by this module.
///
/// That is the case for `via` and for every entry of `HOP_BY_HOP_HEADERS`;
/// the comparison ignores ASCII case.
fn is_hop_by_hop_header(name: &HeaderName) -> bool {
    let candidate = name.as_str();
    if candidate.eq_ignore_ascii_case("via") {
        return true;
    }
    HOP_BY_HOP_HEADERS
        .iter()
        .any(|known| candidate.eq_ignore_ascii_case(known))
}

/// Reports whether `name` is one of the end-to-end headers this module knows
/// about and therefore refuses to strip on behalf of a `Connection` option.
///
/// The comparison is an exact match against the lowercase names listed below.
fn is_end_to_end_header(name: &HeaderName) -> bool {
    const KNOWN_END_TO_END: &[&str] = &[
        "cache-control",
        "authorization",
        "content-length",
        "content-type",
        "content-encoding",
        "accept",
        "accept-encoding",
        "accept-language",
        "range",
        "cookie",
        "set-cookie",
        "etag",
    ];

    let candidate = name.as_str();
    KNOWN_END_TO_END.iter().any(|known| *known == candidate)
}

/// Heuristically decides whether `request` is a WebSocket upgrade handshake.
///
/// A request qualifies when it carries both a `sec-websocket-key` and a
/// `sec-websocket-version` header; their values are not inspected. RFC 6455
/// requires both headers on every WebSocket handshake.
fn is_websocket_upgrade_request(request: &Request<Body>) -> bool {
    let headers = request.headers();
    ["sec-websocket-key", "sec-websocket-version"]
        .into_iter()
        .all(|required| headers.contains_key(required))
}