Skip to main content

pipa/http/
headers.rs

1use std::collections::HashMap;
2
/// A collection of HTTP header fields keyed by lowercased name.
///
/// Each name maps to exactly one value; storing a name again replaces the
/// previous value.
#[derive(Clone, Debug)]
pub struct Headers {
    /// Lowercased header name -> header value.
    map: HashMap<String, String>,
}
7
8impl Headers {
9    pub fn new() -> Self {
10        Headers {
11            map: HashMap::new(),
12        }
13    }
14
15    pub fn from_bytes(data: &[u8]) -> Result<(Self, usize), String> {
16        let mut map = HashMap::new();
17        let mut pos = 0;
18        let len = data.len();
19
20        loop {
21            if pos >= len {
22                return Err("unexpected end of headers".into());
23            }
24            if data[pos] == b'\r' {
25                if pos + 1 < len && data[pos + 1] == b'\n' {
26                    return Ok((Headers { map }, pos + 2));
27                }
28                return Err("malformed header terminator".into());
29            }
30
31            let line_start = pos;
32            while pos < len && data[pos] != b'\r' {
33                pos += 1;
34            }
35            if pos >= len {
36                return Err("unexpected end of headers".into());
37            }
38            let line_end = pos;
39            if pos + 1 >= len || data[pos + 1] != b'\n' {
40                return Err("malformed header line".into());
41            }
42            pos += 2;
43
44            let line = &data[line_start..line_end];
45            if line.is_empty() {
46                return Err("empty header line".into());
47            }
48
49            let colon_pos = line.iter().position(|&b| b == b':');
50            match colon_pos {
51                Some(cpos) => {
52                    let name = String::from_utf8_lossy(&line[..cpos]).trim().to_lowercase();
53                    let value = String::from_utf8_lossy(&line[cpos + 1..])
54                        .trim()
55                        .to_string();
56                    map.insert(name, value);
57                }
58                None => {
59                    return Err(format!(
60                        "malformed header (no colon): {:?}",
61                        String::from_utf8_lossy(line)
62                    ));
63                }
64            }
65        }
66    }
67
68    pub fn get(&self, name: &str) -> Option<&str> {
69        self.map.get(&name.to_lowercase()).map(|s| s.as_str())
70    }
71
72    pub fn set(&mut self, name: &str, value: &str) {
73        self.map.insert(name.to_lowercase(), value.to_string());
74    }
75
76    pub fn remove(&mut self, name: &str) {
77        self.map.remove(&name.to_lowercase());
78    }
79
80    pub fn contains(&self, name: &str) -> bool {
81        self.map.contains_key(&name.to_lowercase())
82    }
83
84    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
85        self.map.iter().map(|(k, v)| (k.as_str(), v.as_str()))
86    }
87
88    pub fn to_request_bytes(&self) -> Vec<u8> {
89        let mut buf = Vec::new();
90        for (name, value) in &self.map {
91            buf.extend_from_slice(name.as_bytes());
92            buf.extend_from_slice(b": ");
93            buf.extend_from_slice(value.as_bytes());
94            buf.extend_from_slice(b"\r\n");
95        }
96        buf
97    }
98
99    pub fn len(&self) -> usize {
100        self.map.len()
101    }
102
103    pub fn is_empty(&self) -> bool {
104        self.map.is_empty()
105    }
106}
107
108impl Default for Headers {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
#[cfg(test)]
mod tests {
    //! Unit tests for header parsing, lookup, mutation and serialization.

    use super::*;

    #[test]
    fn test_parse_headers() {
        let data = b"Content-Type: text/plain\r\nContent-Length: 42\r\n\r\n";
        let (headers, consumed) = Headers::from_bytes(data).unwrap();
        assert_eq!(consumed, data.len());
        assert_eq!(headers.get("content-type"), Some("text/plain"));
        assert_eq!(headers.get("Content-Length"), Some("42"));
    }

    #[test]
    fn test_case_insensitive() {
        let data = b"X-Custom: value\r\n\r\n";
        let (headers, _) = Headers::from_bytes(data).unwrap();
        assert_eq!(headers.get("x-custom"), Some("value"));
        assert_eq!(headers.get("X-CUSTOM"), Some("value"));
    }

    #[test]
    fn test_serialize() {
        let mut h = Headers::new();
        h.set("Host", "example.com");
        h.set("Accept", "*/*");
        let bytes = h.to_request_bytes();
        let s = String::from_utf8_lossy(&bytes);
        assert!(s.contains("host: example.com\r\n"));
        assert!(s.contains("accept: */*\r\n"));
    }

    // The consumed count must stop at the blank line so the caller can find
    // whatever follows the header section.
    #[test]
    fn test_consumed_excludes_trailing_bytes() {
        let data = b"Host: a\r\n\r\nbody";
        let (headers, consumed) = Headers::from_bytes(data).unwrap();
        assert_eq!(consumed, data.len() - b"body".len());
        assert_eq!(headers.get("host"), Some("a"));
    }

    #[test]
    fn test_empty_header_section() {
        let (headers, consumed) = Headers::from_bytes(b"\r\n").unwrap();
        assert_eq!(consumed, 2);
        assert!(headers.is_empty());
    }

    #[test]
    fn test_values_are_trimmed() {
        let (headers, _) = Headers::from_bytes(b"X-A:   spaced   \r\n\r\n").unwrap();
        assert_eq!(headers.get("x-a"), Some("spaced"));
    }

    #[test]
    fn test_incomplete_input() {
        assert_eq!(
            Headers::from_bytes(b"").unwrap_err(),
            "unexpected end of headers"
        );
        assert_eq!(
            Headers::from_bytes(b"Host: a\r\n").unwrap_err(),
            "unexpected end of headers"
        );
    }

    #[test]
    fn test_missing_colon() {
        let err = Headers::from_bytes(b"NoColon\r\n\r\n").unwrap_err();
        assert!(err.starts_with("malformed header (no colon)"), "{}", err);
    }

    #[test]
    fn test_mutation_api() {
        let mut h = Headers::default();
        assert!(h.is_empty());
        h.set("X-Key", "one");
        h.set("x-key", "two");
        assert_eq!(h.len(), 1);
        assert_eq!(h.get("X-KEY"), Some("two"));
        assert!(h.contains("x-key"));
        h.remove("X-Key");
        assert!(!h.contains("x-key"));
        assert_eq!(h.get("x-key"), None);
    }
}