Skip to main content

fastapi_http/
connection.rs

1//! HTTP Connection header handling.
2//!
3//! This module provides proper parsing and handling of the HTTP `Connection` header
4//! per RFC 7230, including:
5//!
6//! - Parsing comma-separated connection tokens
7//! - Handling `close`, `keep-alive`, and `upgrade` directives
8//! - Extracting hop-by-hop header names for stripping
9//! - HTTP version-aware default behavior
10//!
11//! # Connection Header Semantics
12//!
13//! The Connection header is a comma-separated list of tokens. Each token is either:
14//! - A connection option (`close`, `keep-alive`, `upgrade`)
15//! - The name of a hop-by-hop header field to be stripped when forwarding
16//!
17//! # Example
18//!
19//! ```ignore
20//! use fastapi_http::connection::{ConnectionInfo, parse_connection_header};
21//!
22//! let info = parse_connection_header(Some(b"keep-alive, X-Custom-Header"));
23//! assert!(info.keep_alive);
24//! assert!(info.hop_by_hop_headers.contains(&"x-custom-header".to_string()));
25//! ```
26
27use fastapi_core::{HttpVersion, Request};
28
/// Lowercase names of the hop-by-hop header fields that are always removed
/// before a message is forwarded.
///
/// These fields only describe the current transport-level connection. A proxy
/// or gateway must therefore drop them whether or not the message's
/// `Connection` header happens to name them.
pub const STANDARD_HOP_BY_HOP_HEADERS: &[&str] = &[
    // Connection management.
    "connection",
    "keep-alive",
    // Proxy authentication exchange.
    "proxy-authenticate",
    "proxy-authorization",
    // Transfer codings and trailer fields.
    "te",
    "trailer",
    "transfer-encoding",
    // Protocol switching.
    "upgrade",
];
43
44/// Parsed Connection header information.
45#[derive(Debug, Clone, Default)]
46pub struct ConnectionInfo {
47    /// Whether `close` token was present.
48    pub close: bool,
49    /// Whether `keep-alive` token was present.
50    pub keep_alive: bool,
51    /// Whether `upgrade` token was present.
52    pub upgrade: bool,
53    /// Hop-by-hop header names to strip (lowercased).
54    ///
55    /// These are header field names that appeared in the Connection header
56    /// and should be removed when forwarding the message.
57    pub hop_by_hop_headers: Vec<String>,
58}
59
60impl ConnectionInfo {
61    /// Creates an empty ConnectionInfo.
62    #[must_use]
63    pub fn new() -> Self {
64        Self::default()
65    }
66
67    /// Parses Connection header value(s).
68    ///
69    /// The value should be a comma-separated list of tokens. Tokens are
70    /// case-insensitive and whitespace around commas is ignored.
71    #[must_use]
72    pub fn parse(value: &[u8]) -> Self {
73        let mut info = Self::new();
74
75        let value_str = match std::str::from_utf8(value) {
76            Ok(s) => s,
77            Err(_) => return info,
78        };
79
80        for token in value_str.split(',') {
81            let token = token.trim();
82            if token.is_empty() {
83                continue;
84            }
85
86            // Case-insensitive match without allocation for known tokens
87            if token.eq_ignore_ascii_case("close") {
88                info.close = true;
89            } else if token.eq_ignore_ascii_case("keep-alive") {
90                info.keep_alive = true;
91            } else if token.eq_ignore_ascii_case("upgrade") {
92                info.upgrade = true;
93            } else {
94                // Only allocate for custom hop-by-hop headers (rare path)
95                let lower = token.to_ascii_lowercase();
96                // Don't add standard hop-by-hop headers again
97                if !STANDARD_HOP_BY_HOP_HEADERS.contains(&lower.as_str()) {
98                    info.hop_by_hop_headers.push(lower);
99                }
100            }
101        }
102
103        info
104    }
105
106    /// Returns whether the connection should be kept alive based on HTTP version.
107    ///
108    /// - HTTP/1.1: defaults to keep-alive unless `close` is present
109    /// - HTTP/1.0: defaults to close unless `keep-alive` is present
110    #[must_use]
111    pub fn should_keep_alive(&self, version: HttpVersion) -> bool {
112        // Explicit close always wins
113        if self.close {
114            return false;
115        }
116
117        // Explicit keep-alive always wins
118        if self.keep_alive {
119            return true;
120        }
121
122        // Default behavior based on HTTP version
123        match version {
124            HttpVersion::Http11 => true,  // HTTP/1.1 defaults to keep-alive
125            HttpVersion::Http10 => false, // HTTP/1.0 defaults to close
126        }
127    }
128}
129
130/// Parses the Connection header from a request and returns connection info.
131///
132/// # Arguments
133///
134/// * `value` - The raw Connection header value, or None if header is missing
135///
136/// # Returns
137///
138/// Parsed ConnectionInfo with all directives and hop-by-hop header names.
139#[must_use]
140pub fn parse_connection_header(value: Option<&[u8]>) -> ConnectionInfo {
141    match value {
142        Some(v) => ConnectionInfo::parse(v),
143        None => ConnectionInfo::new(),
144    }
145}
146
147/// Determines if a connection should be kept alive based on request headers and version.
148///
149/// This is a convenience function that combines Connection header parsing with
150/// HTTP version-aware keep-alive logic.
151///
152/// # Arguments
153///
154/// * `request` - The HTTP request to check
155///
156/// # Returns
157///
158/// `true` if the connection should be kept alive, `false` otherwise.
159///
160/// # Behavior
161///
162/// - HTTP/1.1 defaults to keep-alive unless `Connection: close` is present
163/// - HTTP/1.0 requires explicit `Connection: keep-alive` to stay open
164/// - `Connection: close` always closes the connection
165/// - `Connection: keep-alive` always keeps the connection open
166#[must_use]
167pub fn should_keep_alive(request: &Request) -> bool {
168    let connection = request.headers().get("connection");
169    let info = parse_connection_header(connection);
170    info.should_keep_alive(request.version())
171}
172
173/// Strip hop-by-hop headers from a request.
174///
175/// Removes both standard hop-by-hop headers and any headers listed in the
176/// Connection header from the request.
177///
178/// # Arguments
179///
180/// * `request` - The request to modify
181///
182/// This is typically used when forwarding requests through a proxy or gateway.
183pub fn strip_hop_by_hop_headers(request: &mut Request) {
184    // Parse Connection header to find custom hop-by-hop headers
185    let connection = request.headers().get("connection").map(<[u8]>::to_vec);
186    let info = parse_connection_header(connection.as_deref());
187
188    // Remove standard hop-by-hop headers
189    for header in STANDARD_HOP_BY_HOP_HEADERS {
190        request.headers_mut().remove(header);
191    }
192
193    // Remove custom hop-by-hop headers listed in Connection
194    for header in &info.hop_by_hop_headers {
195        request.headers_mut().remove(header);
196    }
197}
198
199/// Check if a header name is a hop-by-hop header.
200///
201/// Returns true if the header is in the standard hop-by-hop list.
202/// Note: This doesn't check if it was listed in the Connection header.
203#[must_use]
204pub fn is_standard_hop_by_hop_header(name: &str) -> bool {
205    // Case-insensitive comparison without allocation
206    STANDARD_HOP_BY_HOP_HEADERS
207        .iter()
208        .any(|&h| name.eq_ignore_ascii_case(h))
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use fastapi_core::Method;

    #[test]
    fn connection_info_parse_close() {
        let info = ConnectionInfo::parse(b"close");
        assert!(info.close);
        assert!(!info.keep_alive);
        assert!(!info.upgrade);
        assert!(info.hop_by_hop_headers.is_empty());
    }

    #[test]
    fn connection_info_parse_keep_alive() {
        let info = ConnectionInfo::parse(b"keep-alive");
        assert!(!info.close);
        assert!(info.keep_alive);
        assert!(!info.upgrade);
    }

    #[test]
    fn connection_info_parse_upgrade() {
        let info = ConnectionInfo::parse(b"upgrade");
        assert!(!info.close);
        assert!(!info.keep_alive);
        assert!(info.upgrade);
    }

    #[test]
    fn connection_info_parse_multiple_tokens() {
        let info = ConnectionInfo::parse(b"keep-alive, upgrade");
        assert!(!info.close);
        assert!(info.keep_alive);
        assert!(info.upgrade);
    }

    #[test]
    fn connection_info_parse_with_custom_headers() {
        let info = ConnectionInfo::parse(b"keep-alive, X-Custom-Header, X-Another");
        assert!(info.keep_alive);
        assert_eq!(info.hop_by_hop_headers.len(), 2);
        assert!(
            info.hop_by_hop_headers
                .contains(&"x-custom-header".to_string())
        );
        assert!(info.hop_by_hop_headers.contains(&"x-another".to_string()));
    }

    #[test]
    fn connection_info_parse_upgrade_with_custom_header() {
        // Typical h2c upgrade request: `Connection: Upgrade, HTTP2-Settings`.
        let info = ConnectionInfo::parse(b"Upgrade, HTTP2-Settings");
        assert!(info.upgrade);
        assert!(!info.close);
        assert_eq!(info.hop_by_hop_headers, vec!["http2-settings".to_string()]);
    }

    #[test]
    fn connection_info_parse_case_insensitive() {
        let info = ConnectionInfo::parse(b"CLOSE");
        assert!(info.close);

        let info = ConnectionInfo::parse(b"Keep-Alive");
        assert!(info.keep_alive);

        let info = ConnectionInfo::parse(b"UPGRADE");
        assert!(info.upgrade);
    }

    #[test]
    fn connection_info_parse_with_whitespace() {
        let info = ConnectionInfo::parse(b"  keep-alive  ,  close  ");
        assert!(info.close);
        assert!(info.keep_alive);
    }

    #[test]
    fn connection_info_parse_skips_empty_elements() {
        // Empty list elements are allowed by the list syntax and must be ignored.
        let info = ConnectionInfo::parse(b",, close ,,");
        assert!(info.close);
        assert!(!info.keep_alive);
        assert!(info.hop_by_hop_headers.is_empty());
    }

    #[test]
    fn connection_info_parse_empty() {
        let info = ConnectionInfo::parse(b"");
        assert!(!info.close);
        assert!(!info.keep_alive);
        assert!(!info.upgrade);
        assert!(info.hop_by_hop_headers.is_empty());
    }

    #[test]
    fn connection_info_parse_invalid_utf8() {
        let info = ConnectionInfo::parse(&[0xFF, 0xFE]);
        assert!(!info.close);
        assert!(!info.keep_alive);
    }

    #[test]
    fn parse_connection_header_none_is_default() {
        let info = parse_connection_header(None);
        assert!(!info.close);
        assert!(!info.keep_alive);
        assert!(!info.upgrade);
        assert!(info.hop_by_hop_headers.is_empty());
    }

    #[test]
    fn parse_connection_header_some_delegates_to_parse() {
        let info = parse_connection_header(Some(&b"close, X-Foo"[..]));
        assert!(info.close);
        assert_eq!(info.hop_by_hop_headers, vec!["x-foo".to_string()]);
    }

    #[test]
    fn should_keep_alive_http11_default() {
        let info = ConnectionInfo::new();
        assert!(info.should_keep_alive(HttpVersion::Http11));
    }

    #[test]
    fn should_keep_alive_http10_default() {
        let info = ConnectionInfo::new();
        assert!(!info.should_keep_alive(HttpVersion::Http10));
    }

    #[test]
    fn should_keep_alive_http11_with_close() {
        let info = ConnectionInfo::parse(b"close");
        assert!(!info.should_keep_alive(HttpVersion::Http11));
    }

    #[test]
    fn should_keep_alive_http11_with_keep_alive() {
        let info = ConnectionInfo::parse(b"keep-alive");
        assert!(info.should_keep_alive(HttpVersion::Http11));
    }

    #[test]
    fn should_keep_alive_http10_with_keep_alive() {
        let info = ConnectionInfo::parse(b"keep-alive");
        assert!(info.should_keep_alive(HttpVersion::Http10));
    }

    #[test]
    fn should_keep_alive_close_overrides_keep_alive() {
        // When both are present, close wins
        let info = ConnectionInfo::parse(b"keep-alive, close");
        assert!(!info.should_keep_alive(HttpVersion::Http11));
        assert!(!info.should_keep_alive(HttpVersion::Http10));
    }

    #[test]
    fn should_keep_alive_request_http11_default() {
        let request = Request::with_version(Method::Get, "/", HttpVersion::Http11);
        assert!(should_keep_alive(&request));
    }

    #[test]
    fn should_keep_alive_request_http10_default() {
        let request = Request::with_version(Method::Get, "/", HttpVersion::Http10);
        assert!(!should_keep_alive(&request));
    }

    #[test]
    fn should_keep_alive_request_with_close_header() {
        let mut request = Request::with_version(Method::Get, "/", HttpVersion::Http11);
        request
            .headers_mut()
            .insert("connection", b"close".to_vec());
        assert!(!should_keep_alive(&request));
    }

    #[test]
    fn should_keep_alive_request_http10_with_keep_alive() {
        let mut request = Request::with_version(Method::Get, "/", HttpVersion::Http10);
        request
            .headers_mut()
            .insert("connection", b"keep-alive".to_vec());
        assert!(should_keep_alive(&request));
    }

    #[test]
    fn strip_hop_by_hop_headers_removes_standard() {
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("connection", b"close".to_vec());
        request
            .headers_mut()
            .insert("keep-alive", b"timeout=5".to_vec());
        request
            .headers_mut()
            .insert("transfer-encoding", b"chunked".to_vec());
        request
            .headers_mut()
            .insert("host", b"example.com".to_vec());

        strip_hop_by_hop_headers(&mut request);

        assert!(request.headers().get("connection").is_none());
        assert!(request.headers().get("keep-alive").is_none());
        assert!(request.headers().get("transfer-encoding").is_none());
        // Non-hop-by-hop headers should remain
        assert!(request.headers().get("host").is_some());
    }

    #[test]
    fn strip_hop_by_hop_headers_without_connection_header() {
        // Standard hop-by-hop headers are stripped even when no Connection
        // header is present.
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("keep-alive", b"timeout=5".to_vec());
        request
            .headers_mut()
            .insert("host", b"example.com".to_vec());

        strip_hop_by_hop_headers(&mut request);

        assert!(request.headers().get("keep-alive").is_none());
        assert!(request.headers().get("host").is_some());
    }

    #[test]
    fn strip_hop_by_hop_headers_removes_custom() {
        let mut request = Request::new(Method::Get, "/");
        request
            .headers_mut()
            .insert("connection", b"X-Custom-Header".to_vec());
        request
            .headers_mut()
            .insert("x-custom-header", b"value".to_vec());
        request
            .headers_mut()
            .insert("host", b"example.com".to_vec());

        strip_hop_by_hop_headers(&mut request);

        assert!(request.headers().get("x-custom-header").is_none());
        assert!(request.headers().get("host").is_some());
    }

    #[test]
    fn is_standard_hop_by_hop_header_works() {
        assert!(is_standard_hop_by_hop_header("connection"));
        assert!(is_standard_hop_by_hop_header("Connection"));
        assert!(is_standard_hop_by_hop_header("KEEP-ALIVE"));
        assert!(is_standard_hop_by_hop_header("transfer-encoding"));

        assert!(!is_standard_hop_by_hop_header("host"));
        assert!(!is_standard_hop_by_hop_header("content-type"));
        assert!(!is_standard_hop_by_hop_header("x-custom"));
    }

    #[test]
    fn standard_hop_by_hop_not_duplicated_in_custom() {
        // Standard headers listed in Connection shouldn't appear in hop_by_hop_headers
        let info = ConnectionInfo::parse(b"keep-alive, transfer-encoding, X-Custom");
        assert_eq!(info.hop_by_hop_headers.len(), 1);
        assert!(info.hop_by_hop_headers.contains(&"x-custom".to_string()));
    }
}