1#[cfg(test)]
2use std::collections::BTreeMap;
3use std::error::Error;
4use std::fmt;
5
/// Body-size limit, in bytes, used by this module's tests: 16 MiB (16 * 2^20).
pub const TEST_HTTP_MAX_BODY_BYTES: usize = 16 << 20;
8
/// Error produced when a declared HTTP `Content-Length` is larger than the
/// permitted maximum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HttpContentLengthLimitError {
    /// The declared body length that exceeded the limit.
    pub content_length: usize,
    /// The largest body length that was allowed.
    pub max_bytes: usize,
}

impl fmt::Display for HttpContentLengthLimitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            content_length,
            max_bytes,
        } = self;
        write!(
            f,
            "HTTP Content-Length {} exceeds limit {} bytes",
            content_length, max_bytes
        )
    }
}

impl Error for HttpContentLengthLimitError {}
26
27pub fn enforce_http_content_length_limit(
28 content_length: usize,
29 max_bytes: usize,
30) -> Result<usize, HttpContentLengthLimitError> {
31 if content_length > max_bytes {
32 Err(HttpContentLengthLimitError {
33 content_length,
34 max_bytes,
35 })
36 } else {
37 Ok(content_length)
38 }
39}
40
41pub fn parse_http_content_length(
42 value: Option<&str>,
43 max_bytes: usize,
44) -> Result<usize, HttpContentLengthLimitError> {
45 let Some(value) = value else {
46 return Ok(0);
47 };
48 let Ok(content_length) = value.trim().parse::<usize>() else {
49 return Ok(0);
50 };
51 enforce_http_content_length_limit(content_length, max_bytes)
52}
53
/// Looks up the `Content-Length` entry in a header map and parses it with
/// [`parse_http_content_length`].
///
/// The lowercase key `content-length` is tried first. If it is absent, the first
/// key (in map order) equal to it ignoring ASCII case is used instead. This
/// matches the case-insensitive lookup in
/// [`http_content_length_from_header_lines`].
///
/// # Errors
///
/// Returns [`HttpContentLengthLimitError`] when the declared length exceeds
/// `max_bytes`.
#[cfg(test)]
pub fn http_content_length_from_headers(
    headers: &BTreeMap<String, String>,
    max_bytes: usize,
) -> Result<usize, HttpContentLengthLimitError> {
    // HTTP field names are case-insensitive, so a map keyed by the name as
    // received (e.g. "Content-Length") must not silently report a length of 0.
    let value = headers.get("content-length").or_else(|| {
        headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-length"))
            .map(|(_, value)| value)
    });
    parse_http_content_length(value.map(String::as_str), max_bytes)
}
61
62pub fn http_content_length_from_header_lines<'a>(
63 lines: impl IntoIterator<Item = &'a str>,
64 max_bytes: usize,
65) -> Result<usize, HttpContentLengthLimitError> {
66 let value = lines.into_iter().find_map(|line| {
67 let (name, value) = line.split_once(':')?;
68 name.eq_ignore_ascii_case("content-length")
69 .then_some(value.trim())
70 });
71 parse_http_content_length(value, max_bytes)
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn missing_or_invalid_content_length_defaults_to_zero() {
        assert_eq!(
            parse_http_content_length(None, TEST_HTTP_MAX_BODY_BYTES).expect("missing length"),
            0
        );
        assert_eq!(
            parse_http_content_length(Some("not-a-number"), TEST_HTTP_MAX_BODY_BYTES)
                .expect("invalid length"),
            0
        );
    }

    #[test]
    fn header_line_lookup_is_case_insensitive() {
        let headers = ["Host: example.test", "Content-Length: 12"];
        assert_eq!(
            http_content_length_from_header_lines(headers, TEST_HTTP_MAX_BODY_BYTES)
                .expect("content length"),
            12
        );
    }

    #[test]
    fn oversized_content_length_is_rejected() {
        let error = parse_http_content_length(
            Some(&(TEST_HTTP_MAX_BODY_BYTES + 1).to_string()),
            TEST_HTTP_MAX_BODY_BYTES,
        )
        .expect_err("oversized content length");
        assert_eq!(error.content_length, TEST_HTTP_MAX_BODY_BYTES + 1);
        assert_eq!(error.max_bytes, TEST_HTTP_MAX_BODY_BYTES);
    }

    // The limit is inclusive: a body exactly `max_bytes` long is allowed.
    #[test]
    fn content_length_equal_to_limit_is_accepted() {
        assert_eq!(
            parse_http_content_length(Some("16"), 16).expect("length at limit"),
            16
        );
        assert_eq!(
            enforce_http_content_length_limit(0, 0).expect("zero length at zero limit"),
            0
        );
    }

    #[test]
    fn header_lines_without_content_length_default_to_zero() {
        let headers = ["Host: example.test", "Accept: */*", "no-colon-line"];
        assert_eq!(
            http_content_length_from_header_lines(headers, TEST_HTTP_MAX_BODY_BYTES)
                .expect("absent content length"),
            0
        );
    }

    // Raw header lines may carry padding and a trailing CR from CRLF framing.
    #[test]
    fn header_line_value_whitespace_is_ignored() {
        let headers = ["content-length:   42  \r"];
        assert_eq!(
            http_content_length_from_header_lines(headers, TEST_HTTP_MAX_BODY_BYTES)
                .expect("padded content length"),
            42
        );
    }

    #[test]
    fn header_map_lookup_uses_lowercase_key() {
        let mut headers = BTreeMap::new();
        headers.insert("content-length".to_string(), " 7 ".to_string());
        assert_eq!(
            http_content_length_from_headers(&headers, TEST_HTTP_MAX_BODY_BYTES)
                .expect("map content length"),
            7
        );
    }

    #[test]
    fn oversized_error_message_names_both_values() {
        let error = enforce_http_content_length_limit(11, 10).expect_err("oversized");
        assert_eq!(
            error.to_string(),
            "HTTP Content-Length 11 exceeds limit 10 bytes"
        );
    }
}