1use crate::error::ServerError;
12use std::fmt;
13use std::io::{self, BufRead, BufReader};
14use std::net::TcpStream;
15use std::time::Duration;
16
/// Maximum request-line length in bytes, line terminator included.
const MAX_REQUEST_LINE_LENGTH: usize = 8_190;

/// Number of whitespace-separated parts in a request line (method, path,
/// version); quoted in the error for request lines with extra parts.
const REQUEST_PARTS: usize = 3;

/// Read timeout, in seconds, installed on the client socket before parsing.
const TIMEOUT_SECONDS: u64 = 30;

/// Maximum number of header fields accepted in a single request.
const MAX_HEADER_COUNT: usize = 100;

/// Maximum length in bytes of one header line after trailing whitespace has
/// been trimmed.
const MAX_HEADER_LINE_LENGTH: usize = 8_192;

/// Maximum total size in bytes of the header section: every header line read,
/// terminators and the closing blank line included.
const MAX_HEADER_BYTES: usize = 64 * 1_024;
32
33fn map_timeout_error(error: io::Error) -> ServerError {
34 ServerError::invalid_request(format!(
35 "Failed to set read timeout: {}",
36 error
37 ))
38}
39
40fn map_read_error(error: io::Error) -> ServerError {
41 ServerError::invalid_request(format!(
42 "Failed to read request line: {}",
43 error
44 ))
45}
46
#[doc(alias = "http request")]
/// An HTTP/1.x request head: the request line plus its header fields.
///
/// Values produced by `Request::from_stream` have a validated method, path,
/// and version, and lowercased header names. The fields are public, so a
/// value built directly (as the unit tests do) carries no such guarantee.
#[derive(Debug, Clone, PartialEq)]
pub struct Request {
    /// Request method as sent by the client, case preserved (e.g. `GET`).
    pub method: String,
    /// Request target; when parsed by `from_stream` it starts with `/`, or is
    /// `*` for an `OPTIONS` request.
    pub path: String,
    /// Protocol version token; `HTTP/1.0` or `HTTP/1.1` (compared
    /// case-insensitively) when parsed by `from_stream`.
    pub version: String,
    /// Header fields as `(name, value)` pairs in arrival order; `from_stream`
    /// lowercases names and trims surrounding whitespace from values.
    pub headers: Vec<(String, String)>,
}
87
88impl Request {
89 #[doc(alias = "parse")]
127 #[doc(alias = "from tcp")]
128 pub fn from_stream(
129 stream: &TcpStream,
130 ) -> Result<Self, ServerError> {
131 stream
132 .set_read_timeout(Some(Duration::from_secs(
133 TIMEOUT_SECONDS,
134 )))
135 .map_err(map_timeout_error)?;
136
137 let mut buf_reader = BufReader::new(stream);
138 let mut request_line = String::new();
139
140 let _ = buf_reader
141 .read_line(&mut request_line)
142 .map_err(map_read_error)?;
143
144 let trimmed_request_line = request_line.trim_end();
146
147 if request_line.len() > MAX_REQUEST_LINE_LENGTH {
149 return Err(ServerError::invalid_request(format!(
150 "Request line too long: {} characters (max {})",
151 request_line.len(),
152 MAX_REQUEST_LINE_LENGTH
153 )));
154 }
155
156 let mut parts = trimmed_request_line.split_whitespace();
157 let Some(method_part) = parts.next() else {
158 return Err(ServerError::invalid_request(
159 "Invalid request line: missing method",
160 ));
161 };
162 let Some(path_part) = parts.next() else {
163 return Err(ServerError::invalid_request(
164 "Invalid request line: missing path",
165 ));
166 };
167 let Some(version_part) = parts.next() else {
168 return Err(ServerError::invalid_request(
169 "Invalid request line: missing HTTP version",
170 ));
171 };
172 if parts.next().is_some() {
173 return Err(ServerError::invalid_request(format!(
174 "Invalid request line: expected {} parts",
175 REQUEST_PARTS
176 )));
177 }
178
179 let method = method_part.to_string();
180 if !Self::is_valid_method(&method) {
181 return Err(ServerError::invalid_request(format!(
182 "Invalid HTTP method: {}",
183 method
184 )));
185 }
186
187 let path = path_part.to_string();
188 let is_options_asterisk =
189 method.eq_ignore_ascii_case("OPTIONS") && path == "*";
190 if !path.starts_with('/') && !is_options_asterisk {
191 return Err(ServerError::invalid_request(
192 "Invalid path: must start with '/' (or be '*' for OPTIONS)",
193 ));
194 }
195
196 let version = version_part.to_string();
197 if !Self::is_valid_version(&version) {
198 return Err(ServerError::invalid_request(format!(
199 "Invalid HTTP version: {}",
200 version
201 )));
202 }
203
204 let headers = Self::read_headers(&mut buf_reader)?;
205
206 Ok(Request {
207 method,
208 path,
209 version,
210 headers,
211 })
212 }
213
214 pub fn method(&self) -> &str {
220 &self.method
221 }
222
223 pub fn path(&self) -> &str {
229 &self.path
230 }
231
232 pub fn version(&self) -> &str {
238 &self.version
239 }
240
241 #[doc(alias = "header lookup")]
264 pub fn header(&self, name: &str) -> Option<&str> {
265 self.headers
268 .iter()
269 .find(|(k, _)| k.eq_ignore_ascii_case(name))
270 .map(|(_, v)| v.as_str())
271 }
272
273 pub fn headers(&self) -> &[(String, String)] {
275 &self.headers
276 }
277
278 fn is_valid_method(method: &str) -> bool {
288 matches!(
289 method.to_ascii_uppercase().as_str(),
290 "GET"
291 | "POST"
292 | "PUT"
293 | "DELETE"
294 | "HEAD"
295 | "OPTIONS"
296 | "PATCH"
297 )
298 }
299
300 fn is_valid_version(version: &str) -> bool {
310 version.eq_ignore_ascii_case("HTTP/1.0")
311 || version.eq_ignore_ascii_case("HTTP/1.1")
312 }
313
314 fn read_headers<R: BufRead>(
315 reader: &mut R,
316 ) -> Result<Vec<(String, String)>, ServerError> {
317 let mut headers: Vec<(String, String)> = Vec::with_capacity(16);
318 let mut total_bytes = 0_usize;
319 let mut line = String::new();
322
323 loop {
324 line.clear();
325 let bytes =
326 reader.read_line(&mut line).map_err(map_read_error)?;
327 if bytes == 0 {
328 break;
329 }
330 total_bytes = total_bytes.saturating_add(bytes);
331 if total_bytes > MAX_HEADER_BYTES {
332 return Err(ServerError::invalid_request(
333 "Header section too large",
334 ));
335 }
336
337 let trimmed = line.trim_end();
338 if trimmed.is_empty() {
339 break;
340 }
341 if trimmed.len() > MAX_HEADER_LINE_LENGTH {
342 return Err(ServerError::invalid_request(
343 "Header line too long",
344 ));
345 }
346 let bytes = trimmed.as_bytes();
351 let colon =
352 memchr::memchr(b':', bytes).ok_or_else(|| {
353 ServerError::invalid_request(
354 "Malformed header line",
355 )
356 })?;
357 let (name, value) = trimmed.split_at(colon);
362 let value = &value[1..];
363 if headers.len() >= MAX_HEADER_COUNT {
364 return Err(ServerError::invalid_request(
365 "Too many request headers",
366 ));
367 }
368 headers.push((
369 name.trim().to_ascii_lowercase(),
370 value.trim().to_string(),
371 ));
372 }
373
374 Ok(headers)
375 }
376}
377
378impl fmt::Display for Request {
379 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
380 write!(f, "{} {} {}", self.method, self.path, self.version)
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::net::TcpListener;

    /// Serves `bytes` from a loopback peer thread and parses them with
    /// `Request::from_stream`.
    ///
    /// The peer ignores write errors on purpose: the parser may reject the
    /// request and close its end before the peer has finished writing.
    fn run_request_bytes(
        bytes: Vec<u8>,
    ) -> Result<Request, ServerError> {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let _ = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let _ = stream.write_all(&bytes);
        });
        let stream = TcpStream::connect(addr).unwrap();
        Request::from_stream(&stream)
    }

    /// Convenience wrapper around `run_request_bytes` for textual requests.
    fn run_request(text: &str) -> Result<Request, ServerError> {
        run_request_bytes(text.as_bytes().to_vec())
    }

    /// Runs `Request::read_headers` over an in-memory buffer (no sockets).
    fn parse_headers(
        bytes: &[u8],
    ) -> Result<Vec<(String, String)>, ServerError> {
        Request::read_headers(&mut io::Cursor::new(bytes))
    }

    /// Builds an owned `(name, value)` header pair for comparisons.
    fn pair(name: &str, value: &str) -> (String, String) {
        (name.to_string(), value.to_string())
    }

    #[test]
    fn test_valid_request() {
        let request = run_request("GET /index.html HTTP/1.1\r\n").unwrap();

        assert_eq!(request.method(), "GET");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
        assert!(request.headers().is_empty());
    }

    #[test]
    fn test_invalid_method() {
        let result = run_request("INVALID /index.html HTTP/1.1\r\n");

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ServerError::InvalidRequest(_)
        ));
    }

    #[test]
    fn test_max_length_request() {
        // "GET " (4) + path + " HTTP/1.1\r\n" (11) is one byte under the limit.
        let long_path = "/".repeat(MAX_REQUEST_LINE_LENGTH - 16);
        let result = run_request(&format!("GET {long_path} HTTP/1.1\r\n"));

        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().path().len(),
            MAX_REQUEST_LINE_LENGTH - 16
        );
    }

    #[test]
    fn test_oversized_request() {
        // Two bytes over the limit once the CRLF is counted.
        let long_path = "/".repeat(MAX_REQUEST_LINE_LENGTH - 13);
        let result = run_request(&format!("GET {long_path} HTTP/1.1\r\n"));

        assert!(
            result.is_err(),
            "Oversized request should be invalid. Request: {:?}",
            result
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Request line too long:"),
            "Unexpected error message: {}",
            msg
        );
    }

    #[test]
    fn test_request_line_length_boundary() {
        // "GET " (4) + path + " HTTP/1.1\r\n" (11) == MAX_REQUEST_LINE_LENGTH.
        let at_limit = format!(
            "GET {} HTTP/1.1\r\n",
            "/".repeat(MAX_REQUEST_LINE_LENGTH - 15)
        );
        assert_eq!(at_limit.len(), MAX_REQUEST_LINE_LENGTH);
        assert!(run_request(&at_limit).is_ok());

        let over_limit = format!(
            "GET {} HTTP/1.1\r\n",
            "/".repeat(MAX_REQUEST_LINE_LENGTH - 14)
        );
        assert_eq!(over_limit.len(), MAX_REQUEST_LINE_LENGTH + 1);
        assert!(run_request(&over_limit).is_err());
    }

    #[test]
    fn test_invalid_path() {
        let result = run_request("GET index.html HTTP/1.1\r\n");

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ServerError::InvalidRequest(_)
        ));
    }

    #[test]
    fn test_invalid_version() {
        let result = run_request("GET /index.html HTTP/2.0\r\n");

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ServerError::InvalidRequest(_)
        ));
    }

    #[test]
    fn test_head_request() {
        let request = run_request("HEAD /index.html HTTP/1.1\r\n").unwrap();

        assert_eq!(request.method(), "HEAD");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_options_request() {
        let request = run_request("OPTIONS * HTTP/1.1\r\n").unwrap();

        assert_eq!(request.method(), "OPTIONS");
        assert_eq!(request.path(), "*");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_internal_error_mapping_helpers() {
        let timeout_err =
            io::Error::new(io::ErrorKind::TimedOut, "timeout");
        let mapped = map_timeout_error(timeout_err);
        assert!(
            mapped.to_string().contains("Failed to set read timeout")
        );

        let read_err =
            io::Error::new(io::ErrorKind::UnexpectedEof, "eof");
        let mapped = map_read_error(read_err);
        assert!(
            mapped.to_string().contains("Failed to read request line")
        );
    }

    #[test]
    fn test_parses_headers() {
        let request = run_request(
            "GET /index.html HTTP/1.1\r\nHost: localhost\r\nRange: bytes=0-1\r\n\r\n",
        )
        .unwrap();
        assert_eq!(request.header("host"), Some("localhost"));
        assert_eq!(request.header("range"), Some("bytes=0-1"));
        // Lookup is case-insensitive.
        assert_eq!(request.header("HOST"), Some("localhost"));
        assert_eq!(request.header("missing"), None);
    }

    #[test]
    fn test_missing_method_returns_error() {
        let err = run_request("\r\n").unwrap_err();
        assert!(
            err.to_string().contains("missing method"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_too_many_parts_returns_error() {
        let err = run_request("GET / HTTP/1.1 extra\r\n").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("expected") && msg.contains("parts"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn test_malformed_header_returns_error() {
        let err = run_request("GET / HTTP/1.1\r\nmissing-colon-line\r\n\r\n")
            .unwrap_err();
        assert!(
            err.to_string().contains("Malformed header line"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_header_line_too_long_returns_error() {
        let mut req = Vec::from("GET / HTTP/1.1\r\nX: ");
        req.extend(std::iter::repeat_n(b'A', MAX_HEADER_LINE_LENGTH));
        req.extend_from_slice(b"\r\n\r\n");
        let err = run_request_bytes(req).unwrap_err();
        assert!(
            err.to_string().contains("Header line too long"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_header_section_too_large_returns_error() {
        // Each ~8 KB line stays under MAX_HEADER_LINE_LENGTH, but ten of
        // them exceed the MAX_HEADER_BYTES budget for the whole section.
        let mut req = Vec::from("GET / HTTP/1.1\r\n");
        let filler: String = "A".repeat(8000);
        for i in 0..10 {
            req.extend_from_slice(
                format!("H{i}: {filler}\r\n").as_bytes(),
            );
        }
        req.extend_from_slice(b"\r\n");
        let err = run_request_bytes(req).unwrap_err();
        assert!(
            err.to_string().contains("Header section too large"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_too_many_headers_returns_error() {
        let mut req = Vec::from("GET / HTTP/1.1\r\n");
        for i in 0..=MAX_HEADER_COUNT {
            req.extend_from_slice(format!("H{i}: v\r\n").as_bytes());
        }
        req.extend_from_slice(b"\r\n");
        let err = run_request_bytes(req).unwrap_err();
        assert!(
            err.to_string().contains("Too many request headers"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_missing_http_version_returns_error() {
        let err = run_request("GET /\r\n").unwrap_err();
        assert!(
            err.to_string().contains("missing HTTP version"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_request_display_formats_method_path_version() {
        let request = Request {
            method: "GET".to_string(),
            path: "/index.html".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: Vec::new(),
        };
        assert_eq!(format!("{request}"), "GET /index.html HTTP/1.1");
    }

    #[test]
    fn test_read_headers_stops_at_blank_line() {
        let head: &[u8] = b"Host: a\r\n\r\n";
        let mut input = head.to_vec();
        input.extend_from_slice(b"BODY");
        let mut cursor = io::Cursor::new(input);

        let headers = Request::read_headers(&mut cursor).unwrap();

        assert_eq!(headers, vec![pair("host", "a")]);
        // Nothing past the blank line may be consumed.
        assert_eq!(cursor.position(), u64::try_from(head.len()).unwrap());
    }

    #[test]
    fn test_read_headers_accepts_eof_without_blank_line() {
        let headers = parse_headers(b"A: 1\r\n").unwrap();
        assert_eq!(headers, vec![pair("a", "1")]);
    }

    #[test]
    fn test_read_headers_normalizes_names_and_values() {
        let headers = parse_headers(
            b"X-Thing:   spaced value  \r\nHost: example.com:8080\r\n\r\n",
        )
        .unwrap();
        assert_eq!(
            headers,
            vec![
                pair("x-thing", "spaced value"),
                pair("host", "example.com:8080"),
            ]
        );
    }

    #[test]
    fn test_header_line_at_limit_is_accepted() {
        let mut input = b"X: ".to_vec();
        input.extend(std::iter::repeat_n(b'A', MAX_HEADER_LINE_LENGTH - 3));
        input.extend_from_slice(b"\r\n\r\n");
        let headers = parse_headers(&input).unwrap();
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].1.len(), MAX_HEADER_LINE_LENGTH - 3);
    }
}