fastapi_http/
connection.rs1use fastapi_core::{HttpVersion, Request};
28
/// Lowercased names of the headers that are hop-by-hop by definition
/// (the fixed set enumerated in RFC 2616 §13.5.1, using the `trailer` spelling).
///
/// These names are removed by [`strip_hop_by_hop_headers`] whatever the
/// `Connection` header lists, and they are never recorded as custom
/// hop-by-hop names by [`ConnectionInfo::parse`].
pub const STANDARD_HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
43
/// Directives extracted from a `Connection` header, typically produced by
/// [`ConnectionInfo::parse`].
///
/// Derives `PartialEq`/`Eq` so parsed results can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// `true` if the header contained the `close` token.
    pub close: bool,
    /// `true` if the header contained the `keep-alive` token.
    pub keep_alive: bool,
    /// `true` if the header contained the `upgrade` token.
    pub upgrade: bool,
    /// Extra header names (lowercased) that the `Connection` header marks as
    /// hop-by-hop. When produced by [`ConnectionInfo::parse`], names already in
    /// [`STANDARD_HOP_BY_HOP_HEADERS`] are not included.
    pub hop_by_hop_headers: Vec<String>,
}
59
60impl ConnectionInfo {
61 #[must_use]
63 pub fn new() -> Self {
64 Self::default()
65 }
66
67 #[must_use]
72 pub fn parse(value: &[u8]) -> Self {
73 let mut info = Self::new();
74
75 let value_str = match std::str::from_utf8(value) {
76 Ok(s) => s,
77 Err(_) => return info,
78 };
79
80 for token in value_str.split(',') {
81 let token = token.trim();
82 if token.is_empty() {
83 continue;
84 }
85
86 if token.eq_ignore_ascii_case("close") {
88 info.close = true;
89 } else if token.eq_ignore_ascii_case("keep-alive") {
90 info.keep_alive = true;
91 } else if token.eq_ignore_ascii_case("upgrade") {
92 info.upgrade = true;
93 } else {
94 let lower = token.to_ascii_lowercase();
96 if !STANDARD_HOP_BY_HOP_HEADERS.contains(&lower.as_str()) {
98 info.hop_by_hop_headers.push(lower);
99 }
100 }
101 }
102
103 info
104 }
105
106 #[must_use]
111 pub fn should_keep_alive(&self, version: HttpVersion) -> bool {
112 if self.close {
114 return false;
115 }
116
117 if self.keep_alive {
119 return true;
120 }
121
122 match version {
124 HttpVersion::Http11 => true, HttpVersion::Http10 => false, }
127 }
128}
129
130#[must_use]
140pub fn parse_connection_header(value: Option<&[u8]>) -> ConnectionInfo {
141 match value {
142 Some(v) => ConnectionInfo::parse(v),
143 None => ConnectionInfo::new(),
144 }
145}
146
147#[must_use]
167pub fn should_keep_alive(request: &Request) -> bool {
168 let connection = request.headers().get("connection");
169 let info = parse_connection_header(connection);
170 info.should_keep_alive(request.version())
171}
172
173pub fn strip_hop_by_hop_headers(request: &mut Request) {
184 let connection = request.headers().get("connection").map(<[u8]>::to_vec);
186 let info = parse_connection_header(connection.as_deref());
187
188 for header in STANDARD_HOP_BY_HOP_HEADERS {
190 request.headers_mut().remove(header);
191 }
192
193 for header in &info.hop_by_hop_headers {
195 request.headers_mut().remove(header);
196 }
197}
198
199#[must_use]
204pub fn is_standard_hop_by_hop_header(name: &str) -> bool {
205 STANDARD_HOP_BY_HOP_HEADERS
207 .iter()
208 .any(|&h| name.eq_ignore_ascii_case(h))
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use fastapi_core::Method;

    /// Builds a `GET /` request with an explicit HTTP version.
    fn request_for(version: HttpVersion) -> Request {
        Request::with_version(Method::Get, "/", version)
    }

    /// Sets one header on `request`, copying `value` into an owned buffer.
    fn set_header(request: &mut Request, name: &'static str, value: &[u8]) {
        request.headers_mut().insert(name, value.to_vec());
    }

    /// Returns `true` if `info` recorded `name` as a custom hop-by-hop header.
    fn lists_custom(info: &ConnectionInfo, name: &str) -> bool {
        info.hop_by_hop_headers.iter().any(|header| header == name)
    }

    #[test]
    fn connection_info_parse_close() {
        let ConnectionInfo {
            close,
            keep_alive,
            upgrade,
            hop_by_hop_headers,
            ..
        } = ConnectionInfo::parse(b"close");
        assert!(close);
        assert!(!keep_alive);
        assert!(!upgrade);
        assert!(hop_by_hop_headers.is_empty());
    }

    #[test]
    fn connection_info_parse_keep_alive() {
        let ConnectionInfo {
            close,
            keep_alive,
            upgrade,
            ..
        } = ConnectionInfo::parse(b"keep-alive");
        assert!(!close);
        assert!(keep_alive);
        assert!(!upgrade);
    }

    #[test]
    fn connection_info_parse_upgrade() {
        let ConnectionInfo {
            close,
            keep_alive,
            upgrade,
            ..
        } = ConnectionInfo::parse(b"upgrade");
        assert!(!close);
        assert!(!keep_alive);
        assert!(upgrade);
    }

    #[test]
    fn connection_info_parse_multiple_tokens() {
        let ConnectionInfo {
            close,
            keep_alive,
            upgrade,
            ..
        } = ConnectionInfo::parse(b"keep-alive, upgrade");
        assert!(!close);
        assert!(keep_alive);
        assert!(upgrade);
    }

    #[test]
    fn connection_info_parse_with_custom_headers() {
        let info = ConnectionInfo::parse(b"keep-alive, X-Custom-Header, X-Another");
        assert!(info.keep_alive);
        assert_eq!(info.hop_by_hop_headers.len(), 2);
        assert!(lists_custom(&info, "x-custom-header"));
        assert!(lists_custom(&info, "x-another"));
    }

    #[test]
    fn connection_info_parse_case_insensitive() {
        assert!(ConnectionInfo::parse(b"CLOSE").close);
        assert!(ConnectionInfo::parse(b"Keep-Alive").keep_alive);
        assert!(ConnectionInfo::parse(b"UPGRADE").upgrade);
    }

    #[test]
    fn connection_info_parse_with_whitespace() {
        let parsed = ConnectionInfo::parse(b" keep-alive , close ");
        assert!(parsed.close);
        assert!(parsed.keep_alive);
    }

    #[test]
    fn connection_info_parse_empty() {
        let ConnectionInfo {
            close,
            keep_alive,
            upgrade,
            hop_by_hop_headers,
            ..
        } = ConnectionInfo::parse(b"");
        assert!(!close);
        assert!(!keep_alive);
        assert!(!upgrade);
        assert!(hop_by_hop_headers.is_empty());
    }

    #[test]
    fn connection_info_parse_invalid_utf8() {
        let parsed = ConnectionInfo::parse(&[0xFF, 0xFE]);
        assert!(!parsed.close);
        assert!(!parsed.keep_alive);
    }

    #[test]
    fn should_keep_alive_http11_default() {
        assert!(ConnectionInfo::new().should_keep_alive(HttpVersion::Http11));
    }

    #[test]
    fn should_keep_alive_http10_default() {
        assert!(!ConnectionInfo::new().should_keep_alive(HttpVersion::Http10));
    }

    #[test]
    fn should_keep_alive_http11_with_close() {
        assert!(!ConnectionInfo::parse(b"close").should_keep_alive(HttpVersion::Http11));
    }

    #[test]
    fn should_keep_alive_http10_with_keep_alive() {
        assert!(ConnectionInfo::parse(b"keep-alive").should_keep_alive(HttpVersion::Http10));
    }

    #[test]
    fn should_keep_alive_close_overrides_keep_alive() {
        // `close` wins over `keep-alive` regardless of the protocol version.
        let parsed = ConnectionInfo::parse(b"keep-alive, close");
        for version in [HttpVersion::Http11, HttpVersion::Http10] {
            assert!(!parsed.should_keep_alive(version));
        }
    }

    #[test]
    fn should_keep_alive_request_http11_default() {
        assert!(should_keep_alive(&request_for(HttpVersion::Http11)));
    }

    #[test]
    fn should_keep_alive_request_http10_default() {
        assert!(!should_keep_alive(&request_for(HttpVersion::Http10)));
    }

    #[test]
    fn should_keep_alive_request_with_close_header() {
        let mut request = request_for(HttpVersion::Http11);
        set_header(&mut request, "connection", b"close");
        assert!(!should_keep_alive(&request));
    }

    #[test]
    fn should_keep_alive_request_http10_with_keep_alive() {
        let mut request = request_for(HttpVersion::Http10);
        set_header(&mut request, "connection", b"keep-alive");
        assert!(should_keep_alive(&request));
    }

    #[test]
    fn strip_hop_by_hop_headers_removes_standard() {
        let mut request = Request::new(Method::Get, "/");
        set_header(&mut request, "connection", b"close");
        set_header(&mut request, "keep-alive", b"timeout=5");
        set_header(&mut request, "transfer-encoding", b"chunked");
        set_header(&mut request, "host", b"example.com");

        strip_hop_by_hop_headers(&mut request);

        for removed in ["connection", "keep-alive", "transfer-encoding"] {
            assert!(request.headers().get(removed).is_none());
        }
        assert!(request.headers().get("host").is_some());
    }

    #[test]
    fn strip_hop_by_hop_headers_removes_custom() {
        let mut request = Request::new(Method::Get, "/");
        set_header(&mut request, "connection", b"X-Custom-Header");
        set_header(&mut request, "x-custom-header", b"value");
        set_header(&mut request, "host", b"example.com");

        strip_hop_by_hop_headers(&mut request);

        assert!(request.headers().get("x-custom-header").is_none());
        assert!(request.headers().get("host").is_some());
    }

    #[test]
    fn is_standard_hop_by_hop_header_works() {
        for name in ["connection", "Connection", "KEEP-ALIVE", "transfer-encoding"] {
            assert!(is_standard_hop_by_hop_header(name), "{} should be hop-by-hop", name);
        }
        for name in ["host", "content-type", "x-custom"] {
            assert!(!is_standard_hop_by_hop_header(name), "{} should not be hop-by-hop", name);
        }
    }

    #[test]
    fn standard_hop_by_hop_not_duplicated_in_custom() {
        // `keep-alive` only sets its flag and `transfer-encoding` is already a
        // standard hop-by-hop header, so only `X-Custom` is recorded.
        let parsed = ConnectionInfo::parse(b"keep-alive, transfer-encoding, X-Custom");
        assert_eq!(parsed.hop_by_hop_headers.len(), 1);
        assert!(lists_custom(&parsed, "x-custom"));
    }
}