// cellos_supervisor/sni_proxy/http.rs
use std::fmt;
22
/// Upper bound, in bytes, on how much of the input `extract_http_host` scans
/// for the CRLFCRLF sequence that terminates the header section.
pub const MAX_HEADER_SECTION_LEN: usize = 8192;

/// The 24-byte HTTP/2 client connection preface. Input that starts with these
/// bytes is not parsed as an HTTP/1.x request.
const HTTP2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
30
/// Reasons `extract_http_host` can reject the leading bytes of a connection.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum HttpParseError {
    /// No CRLFCRLF terminator was found within the scanned window and the
    /// input is already at least `MAX_HEADER_SECTION_LEN` bytes long.
    HeaderSectionOverflow,
    /// The request line is empty or missing, or the input (shorter than the
    /// limit) does not yet contain the CRLFCRLF that ends the header section.
    MalformedRequestLine,
    /// The request line contains a byte outside printable ASCII
    /// (`0x20..=0x7E`), e.g. NUL or a bare CR/LF.
    InvalidByte,
    /// The request line contains fewer than two spaces, so it carries no
    /// HTTP version token.
    Http09NotSupported,
}
44
45impl fmt::Display for HttpParseError {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 match self {
48 HttpParseError::HeaderSectionOverflow => write!(
49 f,
50 "HTTP header section exceeded {MAX_HEADER_SECTION_LEN} bytes without CRLFCRLF"
51 ),
52 HttpParseError::MalformedRequestLine => write!(f, "malformed HTTP/1.x request line"),
53 HttpParseError::InvalidByte => write!(f, "invalid byte in request line"),
54 HttpParseError::Http09NotSupported => write!(f, "HTTP/0.9 has no Host semantics"),
55 }
56 }
57}
58
59impl std::error::Error for HttpParseError {}
60
61pub fn extract_http_host(first_bytes: &[u8]) -> Result<Option<String>, HttpParseError> {
64 if first_bytes.len() >= HTTP2_PREFACE.len()
65 && &first_bytes[..HTTP2_PREFACE.len()] == HTTP2_PREFACE
66 {
67 return Ok(None);
68 }
69
70 let scan_end = first_bytes.len().min(MAX_HEADER_SECTION_LEN);
71 let header_end = match find_crlf_crlf(&first_bytes[..scan_end]) {
72 Some(idx) => idx,
73 None => {
74 if first_bytes.len() >= MAX_HEADER_SECTION_LEN {
75 return Err(HttpParseError::HeaderSectionOverflow);
76 }
77 return Err(HttpParseError::MalformedRequestLine);
78 }
79 };
80 let header_section_end = header_end + 2;
88 let header_section = &first_bytes[..header_section_end];
89
90 let line_end = find_crlf(header_section).ok_or(HttpParseError::MalformedRequestLine)?;
91 let request_line = &header_section[..line_end];
92 validate_request_line(request_line)?;
93 let space_count = request_line.iter().filter(|&&b| b == b' ').count();
94 if space_count < 2 {
95 return Err(HttpParseError::Http09NotSupported);
96 }
97
98 let mut idx = line_end + 2;
99 while idx < header_section.len() {
100 let rel_end = match find_crlf(&header_section[idx..]) {
101 Some(e) => e,
102 None => break,
103 };
104 let line = &header_section[idx..idx + rel_end];
105 if line.is_empty() {
106 break;
107 }
108 if let Some(colon) = line.iter().position(|&b| b == b':') {
109 let name = &line[..colon];
110 if eq_ignore_ascii(name, b"Host") {
111 let value = &line[colon + 1..];
112 return Ok(Some(normalise_host_value(value)));
113 }
114 }
115 idx += rel_end + 2;
116 }
117 Ok(None)
118}
119
/// Returns the offset of the first `\r\n` in `buf`, or `None` when there is
/// none (which includes any `buf` shorter than two bytes).
fn find_crlf(buf: &[u8]) -> Option<usize> {
    buf.windows(2).position(|pair| pair == b"\r\n")
}
123
/// Returns the offset of the first `\r\n\r\n` in `buf`, or `None` when there
/// is none (which includes any `buf` shorter than four bytes).
fn find_crlf_crlf(buf: &[u8]) -> Option<usize> {
    buf.windows(4).position(|quad| quad == b"\r\n\r\n")
}
127
128fn validate_request_line(line: &[u8]) -> Result<(), HttpParseError> {
129 if line.is_empty() {
130 return Err(HttpParseError::MalformedRequestLine);
131 }
132 for &b in line {
133 if b == 0 || b == b'\n' || b == b'\r' {
134 return Err(HttpParseError::InvalidByte);
135 }
136 if !(0x20..=0x7E).contains(&b) {
137 return Err(HttpParseError::InvalidByte);
138 }
139 }
140 Ok(())
141}
142
/// Returns `true` when `a` and `b` have the same length and are equal byte for
/// byte, ignoring ASCII letter case. Non-ASCII bytes must match exactly.
fn eq_ignore_ascii(a: &[u8], b: &[u8]) -> bool {
    a.eq_ignore_ascii_case(b)
}
151
/// Normalises a raw `Host` header value (the bytes after the colon).
///
/// Steps, in order:
/// 1. strip leading and trailing spaces and tabs;
/// 2. if the value starts with `[`, keep everything up to and including the
///    first `]` (or the whole value if there is no `]`); otherwise drop
///    everything from the first `:` onward;
/// 3. decode as UTF-8, replacing invalid sequences with U+FFFD;
/// 4. lower-case ASCII letters;
/// 5. remove a single trailing `.`.
fn normalise_host_value(raw: &[u8]) -> String {
    let is_ows = |b: &u8| matches!(*b, b' ' | b'\t');
    // If the value is all whitespace, both bounds collapse to `raw.len()`.
    let start = raw.iter().position(|b| !is_ows(b)).unwrap_or(raw.len());
    let end = raw
        .iter()
        .rposition(|b| !is_ows(b))
        .map_or(start, |last| last + 1);
    let value = &raw[start..end];

    let host_bytes = if value.first() == Some(&b'[') {
        // Bracketed literal: colons inside the brackets are part of the host.
        match value.iter().position(|&b| b == b']') {
            Some(close) => &value[..=close],
            None => value,
        }
    } else {
        match value.iter().position(|&b| b == b':') {
            Some(colon) => &value[..colon],
            None => value,
        }
    };

    let mut host = String::from_utf8_lossy(host_bytes).into_owned();
    host.make_ascii_lowercase();
    if host.ends_with('.') {
        host.pop();
    }
    host
}
180
#[cfg(test)]
mod tests {
    use super::*;

    // --- Successful extraction -------------------------------------------

    #[test]
    fn extracts_get_host() {
        let req = b"GET / HTTP/1.1\r\nHost: api.example.com\r\nUser-Agent: x\r\n\r\n";
        assert_eq!(
            extract_http_host(req).unwrap().as_deref(),
            Some("api.example.com")
        );
    }

    #[test]
    fn extracts_post_host() {
        let req = b"POST /v1/x HTTP/1.1\r\nHost: api.example.com\r\nContent-Length: 0\r\n\r\n";
        assert_eq!(
            extract_http_host(req).unwrap().as_deref(),
            Some("api.example.com")
        );
    }

    #[test]
    fn missing_host_returns_ok_none() {
        let req = b"GET / HTTP/1.1\r\nAccept: */*\r\n\r\n";
        assert_eq!(extract_http_host(req).unwrap(), None);
    }

    #[test]
    fn case_insensitive_host_header() {
        let req = b"GET / HTTP/1.1\r\nhost: api.example.com\r\n\r\n";
        assert_eq!(
            extract_http_host(req).unwrap().as_deref(),
            Some("api.example.com")
        );
        let req2 = b"GET / HTTP/1.1\r\nHOST: api.example.com\r\n\r\n";
        assert_eq!(
            extract_http_host(req2).unwrap().as_deref(),
            Some("api.example.com")
        );
    }

    // --- Host value normalisation ----------------------------------------

    #[test]
    fn host_with_port_strips_port() {
        let req = b"GET / HTTP/1.1\r\nHost: api.example.com:8443\r\n\r\n";
        assert_eq!(
            extract_http_host(req).unwrap().as_deref(),
            Some("api.example.com")
        );
    }

    #[test]
    fn ipv6_host_with_port_keeps_brackets() {
        let req = b"GET / HTTP/1.1\r\nHost: [::1]:443\r\n\r\n";
        assert_eq!(extract_http_host(req).unwrap().as_deref(), Some("[::1]"));
    }

    // --- Rejections ------------------------------------------------------

    #[test]
    fn malformed_request_line_no_version() {
        // A request line with a single space has no version token.
        let req = b"GET /\r\nHost: x.example.com\r\n\r\n";
        assert!(matches!(
            extract_http_host(req),
            Err(HttpParseError::Http09NotSupported)
        ));
    }

    #[test]
    fn oversized_header_section_rejected() {
        // Pad one header past the limit without ever sending CRLFCRLF.
        let mut buf: Vec<u8> = Vec::with_capacity(MAX_HEADER_SECTION_LEN + 16);
        buf.extend_from_slice(b"GET / HTTP/1.1\r\n");
        buf.extend_from_slice(b"X-Pad: ");
        while buf.len() < MAX_HEADER_SECTION_LEN + 8 {
            buf.push(b'a');
        }
        assert_eq!(
            extract_http_host(&buf),
            Err(HttpParseError::HeaderSectionOverflow)
        );
    }

    #[test]
    fn http2_preface_returns_ok_none() {
        let bytes = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n\x00\x00";
        assert_eq!(extract_http_host(bytes).unwrap(), None);
    }

    #[test]
    fn crlf_injection_in_request_line_rejected() {
        // A bare CR inside the request line must not be accepted.
        let req = b"GET /\rinjected HTTP/1.1\r\nHost: api.example.com\r\n\r\n";
        assert!(matches!(
            extract_http_host(req),
            Err(HttpParseError::InvalidByte) | Err(HttpParseError::MalformedRequestLine)
        ));
    }

    #[test]
    fn nul_byte_in_request_line_rejected() {
        let req = b"GET /\x00 HTTP/1.1\r\nHost: api.example.com\r\n\r\n";
        assert!(matches!(
            extract_http_host(req),
            Err(HttpParseError::InvalidByte)
        ));
    }

    #[test]
    fn trailing_dot_in_host_stripped() {
        let req = b"GET / HTTP/1.1\r\nHost: api.example.com.\r\n\r\n";
        assert_eq!(
            extract_http_host(req).unwrap().as_deref(),
            Some("api.example.com")
        );
    }

    #[test]
    fn host_lowercased() {
        let req = b"GET / HTTP/1.1\r\nHost: API.Example.COM\r\n\r\n";
        assert_eq!(
            extract_http_host(req).unwrap().as_deref(),
            Some("api.example.com")
        );
    }

    #[test]
    fn incomplete_request_no_crlfcrlf_is_malformed() {
        // Headers end with a single CRLF: the blank line has not arrived yet.
        let req = b"GET / HTTP/1.1\r\nHost: api.example.com\r\n";
        assert_eq!(
            extract_http_host(req),
            Err(HttpParseError::MalformedRequestLine)
        );
    }
}