1#![allow(dead_code)] use super::error::HttpError;
4use nexus_net::buf::ReadBuf;
5
/// A parsed HTTP/1.x request head that borrows from a [`RequestReader`]'s buffer.
///
/// Header names and values are kept as byte offsets into `data`, so lookups
/// slice the buffer instead of copying.
pub struct Request<'a> {
    /// The request method, e.g. `"GET"`.
    pub method: &'a str,
    /// The request target as it appeared on the request line, e.g. `"/path"`.
    pub path: &'a str,
    /// HTTP version number reported by the parser (`1` for `HTTP/1.1`).
    pub version: u8,
    /// The buffered bytes that every entry in `header_offsets` indexes into.
    data: &'a [u8],
    /// One `(name_start, name_len, value_start, value_len)` tuple per header.
    header_offsets: &'a [(usize, usize, usize, usize)],
}
17
18impl<'a> Request<'a> {
19 pub fn header(&self, name: &str) -> Option<&'a str> {
24 for &(ns, nl, vs, vl) in self.header_offsets {
25 let hname = &self.data[ns..ns + nl];
26 if hname.eq_ignore_ascii_case(name.as_bytes()) {
27 return std::str::from_utf8(&self.data[vs..vs + vl]).ok();
28 }
29 }
30 None
31 }
32
33 pub fn header_bytes(&self, name: &str) -> Option<&'a [u8]> {
37 for &(ns, nl, vs, vl) in self.header_offsets {
38 let hname = &self.data[ns..ns + nl];
39 if hname.eq_ignore_ascii_case(name.as_bytes()) {
40 return Some(&self.data[vs..vs + vl]);
41 }
42 }
43 None
44 }
45
46 pub fn headers(&self) -> impl Iterator<Item = (&'a str, &'a str)> {
51 self.header_offsets.iter().filter_map(|&(ns, nl, vs, vl)| {
52 let name = std::str::from_utf8(&self.data[ns..ns + nl]).ok()?;
53 let value = std::str::from_utf8(&self.data[vs..vs + vl]).ok()?;
54 Some((name, value))
55 })
56 }
57
58 pub fn header_count(&self) -> usize {
60 self.header_offsets.len()
61 }
62}
63
64impl std::fmt::Debug for Request<'_> {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("Request")
67 .field("method", &self.method)
68 .field("path", &self.path)
69 .field("version", &self.version)
70 .field("headers", &self.header_count())
71 .finish()
72 }
73}
74
75pub struct RequestReader {
90 buf: ReadBuf,
91 max_headers: usize,
92 max_head_size: usize,
93 head_len: Option<usize>,
94 header_offsets: Vec<(usize, usize, usize, usize)>,
96 method_end: usize,
97 path_start: usize,
98 path_end: usize,
99 version: u8,
100}
101
102impl RequestReader {
103 #[must_use]
105 pub fn new(capacity: usize) -> Self {
106 Self {
107 buf: ReadBuf::with_capacity(capacity),
108 max_headers: 64,
109 max_head_size: 8192,
110 head_len: None,
111 header_offsets: Vec::new(),
112 method_end: 0,
113 path_start: 0,
114 path_end: 0,
115 version: 1,
116 }
117 }
118
119 #[must_use]
121 pub fn max_headers(mut self, n: usize) -> Self {
122 self.max_headers = n;
123 self
124 }
125
126 #[must_use]
128 pub fn max_head_size(mut self, n: usize) -> Self {
129 self.max_head_size = n;
130 self
131 }
132
133 #[inline]
138 pub fn spare(&mut self) -> &mut [u8] {
139 self.buf.spare()
140 }
141
142 #[inline]
144 pub fn filled(&mut self, n: usize) {
145 self.buf.filled(n);
146 }
147
148 pub fn read(&mut self, src: &[u8]) -> Result<(), HttpError> {
150 let spare = self.buf.spare();
151 if src.len() > spare.len() {
152 return Err(HttpError::BufferFull {
153 needed: src.len(),
154 available: spare.len(),
155 });
156 }
157 spare[..src.len()].copy_from_slice(src);
158 self.buf.filled(src.len());
159 Ok(())
160 }
161
162 #[allow(clippy::should_implement_trait)]
167 pub fn next(&mut self) -> Result<Option<Request<'_>>, HttpError> {
168 if self.head_len.is_none() {
169 self.try_parse()?;
170 }
171
172 if self.head_len.is_none() {
173 return Ok(None);
174 }
175
176 let data = self.buf.data();
177 let method = std::str::from_utf8(&data[..self.method_end])
178 .map_err(|_| HttpError::Malformed("invalid UTF-8 in method"))?;
179 let path = std::str::from_utf8(&data[self.path_start..self.path_end])
180 .map_err(|_| HttpError::Malformed("invalid UTF-8 in path"))?;
181
182 Ok(Some(Request {
183 method,
184 path,
185 version: self.version,
186 data,
187 header_offsets: &self.header_offsets,
188 }))
189 }
190
191 pub fn remainder(&self) -> &[u8] {
193 match self.head_len {
194 Some(n) => &self.buf.data()[n..],
195 None => &[],
196 }
197 }
198
199 pub fn reset(&mut self) {
201 self.buf.clear();
202 self.head_len = None;
203 self.header_offsets.clear();
204 }
205
206 fn try_parse(&mut self) -> Result<(), HttpError> {
207 let data = self.buf.data();
208 if data.is_empty() {
209 return Ok(());
210 }
211 if data.len() > self.max_head_size {
212 return Err(HttpError::HeadTooLarge {
213 max: self.max_head_size,
214 });
215 }
216
217 let mut stack_headers = [httparse::EMPTY_HEADER; 64];
220 let mut heap_headers;
221 let headers: &mut [httparse::Header<'_>] = if self.max_headers <= 64 {
222 &mut stack_headers[..self.max_headers]
223 } else {
224 heap_headers = vec![httparse::EMPTY_HEADER; self.max_headers];
225 &mut heap_headers
226 };
227 let mut req = httparse::Request::new(headers);
228
229 match req.parse(data) {
230 Ok(httparse::Status::Complete(head_len)) => {
231 let method = req
232 .method
233 .ok_or(HttpError::Malformed("missing request method"))?;
234 let path = req
235 .path
236 .ok_or(HttpError::Malformed("missing request path"))?;
237 let version = req
238 .version
239 .ok_or(HttpError::Malformed("missing HTTP version"))?;
240
241 let data_ptr = data.as_ptr();
242 self.method_end = method.len();
243 self.path_start = unsafe { path.as_ptr().offset_from(data_ptr) } as usize;
246 self.path_end = self.path_start + path.len();
247 self.version = version;
248
249 self.header_offsets.clear();
250 for h in req.headers.iter() {
251 let ns = unsafe { h.name.as_ptr().offset_from(data_ptr) } as usize;
253 let nl = h.name.len();
254 let vs = unsafe { h.value.as_ptr().offset_from(data_ptr) } as usize;
256 let vl = h.value.len();
257 self.header_offsets.push((ns, nl, vs, vl));
258 }
259
260 self.head_len = Some(head_len);
261 Ok(())
262 }
263 Ok(httparse::Status::Partial) => Ok(()),
264 Err(httparse::Error::TooManyHeaders) => Err(HttpError::TooManyHeaders),
265 Err(_) => Err(HttpError::Malformed("httparse rejected request")),
266 }
267 }
268}
269
270impl crate::ParserSink for RequestReader {
274 #[inline]
275 fn spare(&mut self) -> &mut [u8] {
276 RequestReader::spare(self)
277 }
278
279 #[inline]
280 fn filled(&mut self, n: usize) {
281 RequestReader::filled(self, n);
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a reader with a 4096-byte buffer already holding `input`.
    fn reader_with(input: &[u8]) -> RequestReader {
        let mut reader = RequestReader::new(4096);
        reader.read(input).unwrap();
        reader
    }

    #[test]
    fn basic_get() {
        let mut reader = reader_with(b"GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n");
        let req = reader.next().unwrap().unwrap();
        assert_eq!(req.method, "GET");
        assert_eq!(req.path, "/path");
        assert_eq!(req.version, 1);
        assert_eq!(req.header("Host"), Some("example.com"));
    }

    #[test]
    fn multiple_headers() {
        let mut reader = reader_with(
            b"POST /api HTTP/1.1\r\nHost: a.com\r\nContent-Type: application/json\r\nX-Custom: value\r\n\r\n",
        );
        let req = reader.next().unwrap().unwrap();
        assert_eq!(req.method, "POST");
        assert_eq!(req.header("Content-Type"), Some("application/json"));
        // Name lookup ignores ASCII case.
        assert_eq!(req.header("x-custom"), Some("value"));
        assert_eq!(req.header_count(), 3);
    }

    #[test]
    fn partial_then_complete() {
        let mut reader = reader_with(b"GET / HTTP/1.1\r\nHost: ex");
        assert!(reader.next().unwrap().is_none());

        reader.read(b"ample.com\r\n\r\n").unwrap();
        let req = reader.next().unwrap().unwrap();
        assert_eq!(req.header("Host"), Some("example.com"));
    }

    #[test]
    fn remainder_after_head() {
        let mut reader = reader_with(b"GET / HTTP/1.1\r\nHost: a.com\r\n\r\nextra bytes");
        let _req = reader.next().unwrap().unwrap();
        assert_eq!(reader.remainder(), b"extra bytes");
    }

    #[test]
    fn head_too_large() {
        let mut reader = RequestReader::new(4096).max_head_size(32);
        reader
            .read(b"GET / HTTP/1.1\r\nHost: a-very-long-hostname.example.com\r\n\r\n")
            .unwrap();
        let outcome = reader.next();
        assert!(matches!(outcome, Err(HttpError::HeadTooLarge { .. })));
    }

    #[test]
    fn malformed_request() {
        let mut reader = reader_with(b"NOT_HTTP\r\n\r\n");
        let outcome = reader.next();
        assert!(matches!(outcome, Err(HttpError::Malformed(_))));
    }

    #[test]
    fn buffer_full() {
        let mut reader = RequestReader::new(16);
        let outcome = reader.read(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n");
        assert!(matches!(outcome, Err(HttpError::BufferFull { .. })));
    }

    #[test]
    fn ws_upgrade_request() {
        let mut reader = reader_with(
            b"GET /ws HTTP/1.1\r\n\
            Host: localhost:8080\r\n\
            Upgrade: websocket\r\n\
            Connection: Upgrade\r\n\
            Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
            Sec-WebSocket-Version: 13\r\n\
            \r\n",
        );
        let req = reader.next().unwrap().unwrap();
        assert_eq!(req.method, "GET");
        assert_eq!(req.path, "/ws");
        let expected = [
            ("Upgrade", "websocket"),
            ("Connection", "Upgrade"),
            ("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="),
            ("Sec-WebSocket-Version", "13"),
        ];
        for (name, value) in expected {
            assert_eq!(req.header(name), Some(value));
        }
    }

    #[test]
    fn reset_then_reuse() {
        let mut reader = reader_with(b"GET /a HTTP/1.1\r\nHost: a\r\n\r\n");
        {
            // Scope the borrow of `reader` held by the first request.
            let first = reader.next().unwrap().unwrap();
            assert_eq!(first.path, "/a");
        }

        reader.reset();
        reader.read(b"GET /b HTTP/1.1\r\nHost: b\r\n\r\n").unwrap();
        let second = reader.next().unwrap().unwrap();
        assert_eq!(second.path, "/b");
    }

    #[test]
    fn header_iter() {
        let mut reader = reader_with(b"GET / HTTP/1.1\r\nA: 1\r\nB: 2\r\n\r\n");
        let req = reader.next().unwrap().unwrap();
        let pairs: Vec<(&str, &str)> = req.headers().collect();
        assert_eq!(pairs, [("A", "1"), ("B", "2")]);
    }
}