fastapi_http/
connection.rs1use fastapi_core::{HttpVersion, Request};
28
/// The conventional set of hop-by-hop header names (originating from
/// RFC 2616 §13.5.1). These headers describe a single transport-level
/// connection rather than the message, so intermediaries strip them before
/// forwarding.
///
/// Every entry is lowercase, so callers comparing raw header names should
/// either lowercase first or compare ASCII case-insensitively.
pub const STANDARD_HOP_BY_HOP_HEADERS: &[&str] = &[
    // Connection management.
    "connection",
    "keep-alive",
    // Proxy authentication exchange.
    "proxy-authenticate",
    "proxy-authorization",
    // Transfer-coding negotiation and message framing.
    "te",
    "trailer",
    "transfer-encoding",
    // Protocol switching.
    "upgrade",
];
43
/// Parsed view of an HTTP `Connection` header value.
///
/// A default value means "no recognized connection options": all flags are
/// `false` and no extra hop-by-hop header names are recorded.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// Set when the header lists the `close` option.
    pub close: bool,
    /// Set when the header lists the `keep-alive` option.
    pub keep_alive: bool,
    /// Set when the header lists the `upgrade` option.
    pub upgrade: bool,
    /// Any other option names listed in the header, lowercased. These name
    /// additional headers that are hop-by-hop for this message.
    pub hop_by_hop_headers: Vec<String>,
}
59
60impl ConnectionInfo {
61 #[must_use]
63 pub fn new() -> Self {
64 Self::default()
65 }
66
67 #[must_use]
72 pub fn parse(value: &[u8]) -> Self {
73 let mut info = Self::new();
74
75 let value_str = match std::str::from_utf8(value) {
76 Ok(s) => s,
77 Err(_) => return info,
78 };
79
80 for part in value_str.split(',') {
81 let part = part.trim();
82 if part.is_empty() {
83 continue;
84 }
85
86 if part.eq_ignore_ascii_case("close") {
88 info.close = true;
89 } else if part.eq_ignore_ascii_case("keep-alive") {
90 info.keep_alive = true;
91 } else if part.eq_ignore_ascii_case("upgrade") {
92 info.upgrade = true;
93 } else {
94 let lower = part.to_ascii_lowercase();
96 if !STANDARD_HOP_BY_HOP_HEADERS.contains(&lower.as_str()) {
98 info.hop_by_hop_headers.push(lower);
99 }
100 }
101 }
102
103 info
104 }
105
106 #[must_use]
111 pub fn should_keep_alive(&self, version: HttpVersion) -> bool {
112 if self.close {
114 return false;
115 }
116
117 if self.keep_alive {
119 return true;
120 }
121
122 match version {
124 HttpVersion::Http11 => true, HttpVersion::Http10 => false, HttpVersion::Http2 => true, }
128 }
129}
130
131#[must_use]
141pub fn parse_connection_header(value: Option<&[u8]>) -> ConnectionInfo {
142 match value {
143 Some(v) => ConnectionInfo::parse(v),
144 None => ConnectionInfo::new(),
145 }
146}
147
148#[must_use]
168pub fn should_keep_alive(request: &Request) -> bool {
169 let connection = request.headers().get("connection");
170 let info = parse_connection_header(connection);
171 info.should_keep_alive(request.version())
172}
173
174pub fn strip_hop_by_hop_headers(request: &mut Request) {
185 let connection = request.headers().get("connection").map(<[u8]>::to_vec);
187 let info = parse_connection_header(connection.as_deref());
188
189 for header in STANDARD_HOP_BY_HOP_HEADERS {
191 request.headers_mut().remove(header);
192 }
193
194 for header in &info.hop_by_hop_headers {
196 request.headers_mut().remove(header);
197 }
198}
199
200#[must_use]
205pub fn is_standard_hop_by_hop_header(name: &str) -> bool {
206 STANDARD_HOP_BY_HOP_HEADERS
208 .iter()
209 .any(|&h| name.eq_ignore_ascii_case(h))
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use fastapi_core::Method;

    /// The three boolean flags as a `(close, keep_alive, upgrade)` tuple.
    fn flags(info: &ConnectionInfo) -> (bool, bool, bool) {
        (info.close, info.keep_alive, info.upgrade)
    }

    /// Inserts each `(name, value)` pair into `request`, in order, and returns it.
    fn with_headers(mut request: Request, headers: &[(&'static str, &'static str)]) -> Request {
        for &(name, value) in headers {
            request.headers_mut().insert(name, value.as_bytes().to_vec());
        }
        request
    }

    #[test]
    fn connection_info_parse_close() {
        let info = ConnectionInfo::parse(b"close");
        assert_eq!(flags(&info), (true, false, false));
        assert!(info.hop_by_hop_headers.is_empty());
    }

    #[test]
    fn connection_info_parse_keep_alive() {
        let info = ConnectionInfo::parse(b"keep-alive");
        assert_eq!(flags(&info), (false, true, false));
    }

    #[test]
    fn connection_info_parse_upgrade() {
        let info = ConnectionInfo::parse(b"upgrade");
        assert_eq!(flags(&info), (false, false, true));
    }

    #[test]
    fn connection_info_parse_multiple_tokens() {
        let info = ConnectionInfo::parse(b"keep-alive, upgrade");
        assert_eq!(flags(&info), (false, true, true));
    }

    #[test]
    fn connection_info_parse_with_custom_headers() {
        let info = ConnectionInfo::parse(b"keep-alive, X-Custom-Header, X-Another");
        assert!(info.keep_alive);
        assert_eq!(info.hop_by_hop_headers.len(), 2);
        for expected in ["x-custom-header", "x-another"] {
            assert!(
                info.hop_by_hop_headers
                    .iter()
                    .any(|name| name.as_str() == expected),
                "missing {}",
                expected
            );
        }
    }

    #[test]
    fn connection_info_parse_case_insensitive() {
        assert!(ConnectionInfo::parse(b"CLOSE").close);
        assert!(ConnectionInfo::parse(b"Keep-Alive").keep_alive);
        assert!(ConnectionInfo::parse(b"UPGRADE").upgrade);
    }

    #[test]
    fn connection_info_parse_with_whitespace() {
        let (close, keep_alive, _) = flags(&ConnectionInfo::parse(b" keep-alive , close "));
        assert!(close);
        assert!(keep_alive);
    }

    #[test]
    fn connection_info_parse_empty() {
        let info = ConnectionInfo::parse(b"");
        assert_eq!(flags(&info), (false, false, false));
        assert!(info.hop_by_hop_headers.is_empty());
    }

    #[test]
    fn connection_info_parse_invalid_utf8() {
        let (close, keep_alive, _) = flags(&ConnectionInfo::parse(&[0xFF, 0xFE]));
        assert!(!close);
        assert!(!keep_alive);
    }

    #[test]
    fn should_keep_alive_http11_default() {
        assert!(ConnectionInfo::new().should_keep_alive(HttpVersion::Http11));
    }

    #[test]
    fn should_keep_alive_http10_default() {
        assert!(!ConnectionInfo::new().should_keep_alive(HttpVersion::Http10));
    }

    #[test]
    fn should_keep_alive_http11_with_close() {
        assert!(!ConnectionInfo::parse(b"close").should_keep_alive(HttpVersion::Http11));
    }

    #[test]
    fn should_keep_alive_http10_with_keep_alive() {
        assert!(ConnectionInfo::parse(b"keep-alive").should_keep_alive(HttpVersion::Http10));
    }

    #[test]
    fn should_keep_alive_close_overrides_keep_alive() {
        // `close` must win regardless of ordering or protocol version.
        let info = ConnectionInfo::parse(b"keep-alive, close");
        for version in [HttpVersion::Http11, HttpVersion::Http10] {
            assert!(!info.should_keep_alive(version));
        }
    }

    #[test]
    fn should_keep_alive_request_http11_default() {
        let request = Request::with_version(Method::Get, "/", HttpVersion::Http11);
        assert!(should_keep_alive(&request));
    }

    #[test]
    fn should_keep_alive_request_http10_default() {
        let request = Request::with_version(Method::Get, "/", HttpVersion::Http10);
        assert!(!should_keep_alive(&request));
    }

    #[test]
    fn should_keep_alive_request_with_close_header() {
        let request = with_headers(
            Request::with_version(Method::Get, "/", HttpVersion::Http11),
            &[("connection", "close")],
        );
        assert!(!should_keep_alive(&request));
    }

    #[test]
    fn should_keep_alive_request_http10_with_keep_alive() {
        let request = with_headers(
            Request::with_version(Method::Get, "/", HttpVersion::Http10),
            &[("connection", "keep-alive")],
        );
        assert!(should_keep_alive(&request));
    }

    #[test]
    fn strip_hop_by_hop_headers_removes_standard() {
        let mut request = with_headers(
            Request::new(Method::Get, "/"),
            &[
                ("connection", "close"),
                ("keep-alive", "timeout=5"),
                ("transfer-encoding", "chunked"),
                ("host", "example.com"),
            ],
        );

        strip_hop_by_hop_headers(&mut request);

        for stripped in ["connection", "keep-alive", "transfer-encoding"] {
            assert!(
                request.headers().get(stripped).is_none(),
                "{} should have been stripped",
                stripped
            );
        }
        assert!(request.headers().get("host").is_some());
    }

    #[test]
    fn strip_hop_by_hop_headers_removes_custom() {
        let mut request = with_headers(
            Request::new(Method::Get, "/"),
            &[
                ("connection", "X-Custom-Header"),
                ("x-custom-header", "value"),
                ("host", "example.com"),
            ],
        );

        strip_hop_by_hop_headers(&mut request);

        assert!(request.headers().get("x-custom-header").is_none());
        assert!(request.headers().get("host").is_some());
    }

    #[test]
    fn is_standard_hop_by_hop_header_works() {
        for name in ["connection", "Connection", "KEEP-ALIVE", "transfer-encoding"] {
            assert!(is_standard_hop_by_hop_header(name), "{} should match", name);
        }
        for name in ["host", "content-type", "x-custom"] {
            assert!(!is_standard_hop_by_hop_header(name), "{} should not match", name);
        }
    }

    #[test]
    fn standard_hop_by_hop_not_duplicated_in_custom() {
        // `transfer-encoding` is already standard, so only `x-custom` is recorded.
        let info = ConnectionInfo::parse(b"keep-alive, transfer-encoding, X-Custom");
        assert_eq!(info.hop_by_hop_headers, vec!["x-custom".to_string()]);
    }
}