Skip to main content

http_handle/
request.rs

1// SPDX-License-Identifier: AGPL-3.0-only
2// Copyright (c) 2023 - 2026 HTTP Handle
3
4// src/request.rs
5
6//! HTTP/1.x request parsing and validation.
7//!
8//! Use this module to convert raw stream input into typed request data with bounded parsing,
9//! header normalization, and explicit malformed-request errors.
10
11use crate::error::ServerError;
12use std::fmt;
13use std::io::{self, BufRead, BufReader};
14use std::net::TcpStream;
15use std::time::Duration;
16
/// Maximum accepted length of the request line, in bytes (just under 8 KiB).
///
/// `Request::from_stream` compares this against the raw line exactly as it was read,
/// so the line terminator counts toward the limit along with the method, path,
/// version, and the separators between them.
const MAX_REQUEST_LINE_LENGTH: usize = 8190;

/// Number of parts expected in a valid HTTP request line (method, path, version).
///
/// Reported in the error message produced when a request line carries extra tokens.
const REQUEST_PARTS: usize = 3;

/// Read timeout installed on the TCP stream before parsing starts (in seconds).
const TIMEOUT_SECONDS: u64 = 30;
/// Maximum number of accepted request headers.
const MAX_HEADER_COUNT: usize = 100;
/// Maximum allowed length for a single header line, in bytes.
///
/// Measured after trailing whitespace (including the line terminator) has been trimmed.
const MAX_HEADER_LINE_LENGTH: usize = 8192;
/// Maximum cumulative bytes for all headers.
///
/// Counts the raw bytes of every header line read, line terminators included.
const MAX_HEADER_BYTES: usize = 64 * 1024;
32
33fn map_timeout_error(error: io::Error) -> ServerError {
34    ServerError::invalid_request(format!(
35        "Failed to set read timeout: {}",
36        error
37    ))
38}
39
40fn map_read_error(error: io::Error) -> ServerError {
41    ServerError::invalid_request(format!(
42        "Failed to read request line: {}",
43        error
44    ))
45}
46
/// A parsed HTTP/1.x request: the three request-line components plus the headers.
///
/// Values of this type are produced by `Request::from_stream` after the request line
/// and header section have been read and validated. All fields are public, so a value
/// can also be assembled by hand (for example in tests).
///
/// # Examples
///
/// ```rust
/// use http_handle::request::Request;
///
/// let request = Request {
///     method: "GET".to_string(),
///     path: "/".to_string(),
///     version: "HTTP/1.1".to_string(),
///     headers: Vec::new(),
/// };
/// assert_eq!(request.method(), "GET");
/// ```
///
/// # Panics
///
/// Constructing this type never panics.
#[derive(Clone, Debug, PartialEq)]
#[doc(alias = "http request")]
pub struct Request {
    /// Method token from the request line, stored exactly as it was received.
    pub method: String,
    /// Request target from the request line.
    pub path: String,
    /// Protocol version token from the request line.
    pub version: String,
    /// Headers as `(name, value)` pairs in the order they were received.
    ///
    /// The parser lowercases each name and trims surrounding whitespace from both
    /// the name and the value.
    ///
    /// A `Vec<(String, String)>` is used instead of a `HashMap` on purpose:
    /// realistic requests carry well under 32 headers, so the linear scan in
    /// `Request::header` is cheaper than hashing, and there is no hash table
    /// to grow and rehash on every request.
    pub headers: Vec<(String, String)>,
}
87
88impl Request {
89    /// Parses a request line and headers from a `TcpStream`.
90    ///
91    /// This method reads the first line of an HTTP request from the given TCP stream,
92    /// parses it, and constructs a `Request` instance if the input is valid.
93    ///
94    /// # Arguments
95    ///
96    /// * `stream` - A reference to the `TcpStream` from which the request will be read.
97    ///
98    /// # Returns
99    ///
100    /// * `Ok(Request)` - If the request is valid and successfully parsed.
101    /// * `Err(ServerError)` - If the request is malformed, cannot be read, or is invalid.
102    ///
103    /// # Errors
104    ///
105    /// This function returns a `ServerError::InvalidRequest` error if:
106    /// - The request line is too long (exceeds `MAX_REQUEST_LINE_LENGTH`)
107    /// - The request line does not contain exactly three parts
108    /// - The HTTP method is not recognized
109    /// - The request path does not start with a forward slash (except `OPTIONS *`)
110    /// - The HTTP version is not supported (only HTTP/1.0 and HTTP/1.1 are accepted)
111    ///
112    /// # Examples
113    ///
114    /// ```rust,no_run
115    /// use std::net::TcpStream;
116    /// use http_handle::request::Request;
117    ///
118    /// let stream = TcpStream::connect("127.0.0.1:8080").expect("connect");
119    /// let parsed = Request::from_stream(&stream);
120    /// assert!(parsed.is_ok() || parsed.is_err());
121    /// ```
122    ///
123    /// # Panics
124    ///
125    /// This function does not intentionally panic.
126    #[doc(alias = "parse")]
127    #[doc(alias = "from tcp")]
128    pub fn from_stream(
129        stream: &TcpStream,
130    ) -> Result<Self, ServerError> {
131        stream
132            .set_read_timeout(Some(Duration::from_secs(
133                TIMEOUT_SECONDS,
134            )))
135            .map_err(map_timeout_error)?;
136
137        let mut buf_reader = BufReader::new(stream);
138        let mut request_line = String::new();
139
140        let _ = buf_reader
141            .read_line(&mut request_line)
142            .map_err(map_read_error)?;
143
144        // Trim the trailing \r\n before checking the length
145        let trimmed_request_line = request_line.trim_end();
146
147        // Check if the request line exceeds the maximum allowed length
148        if request_line.len() > MAX_REQUEST_LINE_LENGTH {
149            return Err(ServerError::invalid_request(format!(
150                "Request line too long: {} characters (max {})",
151                request_line.len(),
152                MAX_REQUEST_LINE_LENGTH
153            )));
154        }
155
156        let mut parts = trimmed_request_line.split_whitespace();
157        let Some(method_part) = parts.next() else {
158            return Err(ServerError::invalid_request(
159                "Invalid request line: missing method",
160            ));
161        };
162        let Some(path_part) = parts.next() else {
163            return Err(ServerError::invalid_request(
164                "Invalid request line: missing path",
165            ));
166        };
167        let Some(version_part) = parts.next() else {
168            return Err(ServerError::invalid_request(
169                "Invalid request line: missing HTTP version",
170            ));
171        };
172        if parts.next().is_some() {
173            return Err(ServerError::invalid_request(format!(
174                "Invalid request line: expected {} parts",
175                REQUEST_PARTS
176            )));
177        }
178
179        let method = method_part.to_string();
180        if !Self::is_valid_method(&method) {
181            return Err(ServerError::invalid_request(format!(
182                "Invalid HTTP method: {}",
183                method
184            )));
185        }
186
187        let path = path_part.to_string();
188        let is_options_asterisk =
189            method.eq_ignore_ascii_case("OPTIONS") && path == "*";
190        if !path.starts_with('/') && !is_options_asterisk {
191            return Err(ServerError::invalid_request(
192                "Invalid path: must start with '/' (or be '*' for OPTIONS)",
193            ));
194        }
195
196        let version = version_part.to_string();
197        if !Self::is_valid_version(&version) {
198            return Err(ServerError::invalid_request(format!(
199                "Invalid HTTP version: {}",
200                version
201            )));
202        }
203
204        let headers = Self::read_headers(&mut buf_reader)?;
205
206        Ok(Request {
207            method,
208            path,
209            version,
210            headers,
211        })
212    }
213
214    /// Returns the HTTP method of the request.
215    ///
216    /// # Returns
217    ///
218    /// A string slice containing the HTTP method (e.g., "GET", "POST").
219    pub fn method(&self) -> &str {
220        &self.method
221    }
222
223    /// Returns the requested path of the request.
224    ///
225    /// # Returns
226    ///
227    /// A string slice containing the requested path.
228    pub fn path(&self) -> &str {
229        &self.path
230    }
231
232    /// Returns the HTTP version of the request.
233    ///
234    /// # Returns
235    ///
236    /// A string slice containing the HTTP version (e.g., "HTTP/1.1").
237    pub fn version(&self) -> &str {
238        &self.version
239    }
240
241    /// Returns the value of a header by case-insensitive name.
242    ///
243    /// # Examples
244    ///
245    /// ```rust
246    /// use http_handle::request::Request;
247    ///
248    /// let request = Request {
249    ///     method: "GET".to_string(),
250    ///     path: "/".to_string(),
251    ///     version: "HTTP/1.1".to_string(),
252    ///     headers: vec![(
253    ///         "content-type".to_string(),
254    ///         "text/plain".to_string(),
255    ///     )],
256    /// };
257    /// assert_eq!(request.header("Content-Type"), Some("text/plain"));
258    /// ```
259    ///
260    /// # Panics
261    ///
262    /// This function does not panic.
263    #[doc(alias = "header lookup")]
264    pub fn header(&self, name: &str) -> Option<&str> {
265        // Linear scan: header counts in real traffic are O(10), so a
266        // case-insensitive equality check beats hashing the lookup key.
267        self.headers
268            .iter()
269            .find(|(k, _)| k.eq_ignore_ascii_case(name))
270            .map(|(_, v)| v.as_str())
271    }
272
273    /// Returns all parsed headers.
274    pub fn headers(&self) -> &[(String, String)] {
275        &self.headers
276    }
277
278    /// Checks if the given method is a valid HTTP method.
279    ///
280    /// # Arguments
281    ///
282    /// * `method` - A string slice containing the HTTP method to validate.
283    ///
284    /// # Returns
285    ///
286    /// `true` if the method is valid, `false` otherwise.
287    fn is_valid_method(method: &str) -> bool {
288        matches!(
289            method.to_ascii_uppercase().as_str(),
290            "GET"
291                | "POST"
292                | "PUT"
293                | "DELETE"
294                | "HEAD"
295                | "OPTIONS"
296                | "PATCH"
297        )
298    }
299
300    /// Checks if the given HTTP version is supported.
301    ///
302    /// # Arguments
303    ///
304    /// * `version` - A string slice containing the HTTP version to validate.
305    ///
306    /// # Returns
307    ///
308    /// `true` if the version is supported, `false` otherwise.
309    fn is_valid_version(version: &str) -> bool {
310        version.eq_ignore_ascii_case("HTTP/1.0")
311            || version.eq_ignore_ascii_case("HTTP/1.1")
312    }
313
314    fn read_headers<R: BufRead>(
315        reader: &mut R,
316    ) -> Result<Vec<(String, String)>, ServerError> {
317        let mut headers: Vec<(String, String)> = Vec::with_capacity(16);
318        let mut total_bytes = 0_usize;
319        // Reuse a single line buffer across iterations to avoid allocating
320        // a fresh String per header line.
321        let mut line = String::new();
322
323        loop {
324            line.clear();
325            let bytes =
326                reader.read_line(&mut line).map_err(map_read_error)?;
327            if bytes == 0 {
328                break;
329            }
330            total_bytes = total_bytes.saturating_add(bytes);
331            if total_bytes > MAX_HEADER_BYTES {
332                return Err(ServerError::invalid_request(
333                    "Header section too large",
334                ));
335            }
336
337            let trimmed = line.trim_end();
338            if trimmed.is_empty() {
339                break;
340            }
341            if trimmed.len() > MAX_HEADER_LINE_LENGTH {
342                return Err(ServerError::invalid_request(
343                    "Header line too long",
344                ));
345            }
346            // memchr finds the first ':' via SIMD (NEON on Apple
347            // Silicon, AVX2 on x86_64). For typical 12–40 byte header
348            // lines the win is small; for longer lines (cookies,
349            // user-agent) it's measurable.
350            let bytes = trimmed.as_bytes();
351            let colon =
352                memchr::memchr(b':', bytes).ok_or_else(|| {
353                    ServerError::invalid_request(
354                        "Malformed header line",
355                    )
356                })?;
357            // SAFETY: `colon` is an index returned by memchr inside
358            // `bytes`, which is the byte view of the `&str` `trimmed`.
359            // ASCII ':' is exactly one UTF-8 byte, so the split lands
360            // on a UTF-8 boundary.
361            let (name, value) = trimmed.split_at(colon);
362            let value = &value[1..];
363            if headers.len() >= MAX_HEADER_COUNT {
364                return Err(ServerError::invalid_request(
365                    "Too many request headers",
366                ));
367            }
368            headers.push((
369                name.trim().to_ascii_lowercase(),
370                value.trim().to_string(),
371            ));
372        }
373
374        Ok(headers)
375    }
376}
377
378impl fmt::Display for Request {
379    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
380        write!(f, "{} {} {}", self.method, self.path, self.version)
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::net::TcpListener;

    /// Sends `bytes` over a loopback TCP connection and parses whatever arrives
    /// with `Request::from_stream`.
    ///
    /// The sending side runs on a detached thread and closes its end of the
    /// connection once the bytes are written, so the parser sees end-of-stream
    /// after the payload.
    fn run_request_bytes(
        bytes: Vec<u8>,
    ) -> Result<Request, ServerError> {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let _ = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let _ = stream.write_all(&bytes);
        });
        let stream = TcpStream::connect(addr).unwrap();
        Request::from_stream(&stream)
    }

    /// Asserts that `bytes` is rejected with the invalid-request variant.
    fn assert_invalid_request(bytes: &[u8]) {
        let result = run_request_bytes(bytes.to_vec());
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ServerError::InvalidRequest(_)
        ));
    }

    /// Asserts that `bytes` is rejected and that the error text contains `needle`.
    fn assert_rejected_with(bytes: Vec<u8>, needle: &str) {
        let err = run_request_bytes(bytes).unwrap_err();
        assert!(
            err.to_string().contains(needle),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_valid_request() {
        let request =
            run_request_bytes(b"GET /index.html HTTP/1.1\r\n".to_vec())
                .unwrap();

        assert_eq!(request.method(), "GET");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
        assert!(request.headers().is_empty());
    }

    #[test]
    fn test_invalid_method() {
        assert_invalid_request(b"INVALID /index.html HTTP/1.1\r\n");
    }

    #[test]
    fn test_max_length_request() {
        // Account for "GET ", " HTTP/1.1", and "\r\n"
        let path_len = MAX_REQUEST_LINE_LENGTH - 16;
        let wire = format!("GET {} HTTP/1.1\r\n", "/".repeat(path_len));

        let result = run_request_bytes(wire.into_bytes());

        assert!(result.is_ok());
        assert_eq!(result.unwrap().path().len(), path_len);
    }

    #[test]
    fn test_oversized_request() {
        // 13 = len("GET  HTTP/1.1")
        let long_path = "/".repeat(MAX_REQUEST_LINE_LENGTH - 13);
        let wire = format!("GET {} HTTP/1.1\r\n", long_path);

        let result = run_request_bytes(wire.into_bytes());

        assert!(
            result.is_err(),
            "Oversized request should be invalid. Request: {:?}",
            result
        );
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Request line too long:"),
            "Unexpected error message: {}",
            msg
        );
    }

    #[test]
    fn test_invalid_path() {
        assert_invalid_request(b"GET index.html HTTP/1.1\r\n");
    }

    #[test]
    fn test_invalid_version() {
        assert_invalid_request(b"GET /index.html HTTP/2.0\r\n");
    }

    #[test]
    fn test_head_request() {
        let request =
            run_request_bytes(b"HEAD /index.html HTTP/1.1\r\n".to_vec())
                .unwrap();

        assert_eq!(request.method(), "HEAD");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_options_request() {
        let request =
            run_request_bytes(b"OPTIONS * HTTP/1.1\r\n".to_vec())
                .unwrap();

        assert_eq!(request.method(), "OPTIONS");
        assert_eq!(request.path(), "*");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_internal_error_mapping_helpers() {
        let mapped = map_timeout_error(io::Error::new(
            io::ErrorKind::TimedOut,
            "timeout",
        ));
        assert!(
            mapped.to_string().contains("Failed to set read timeout")
        );

        let mapped = map_read_error(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "eof",
        ));
        assert!(
            mapped.to_string().contains("Failed to read request line")
        );
    }

    #[test]
    fn test_parses_headers() {
        let request = run_request_bytes(
            b"GET /index.html HTTP/1.1\r\nHost: localhost\r\nRange: bytes=0-1\r\n\r\n"
                .to_vec(),
        )
        .unwrap();

        assert_eq!(request.header("host"), Some("localhost"));
        assert_eq!(request.header("range"), Some("bytes=0-1"));
    }

    #[test]
    fn test_missing_method_returns_error() {
        assert_rejected_with(b"\r\n".to_vec(), "missing method");
    }

    #[test]
    fn test_too_many_parts_returns_error() {
        let msg =
            run_request_bytes(b"GET / HTTP/1.1 extra\r\n".to_vec())
                .unwrap_err()
                .to_string();
        assert!(
            msg.contains("expected") && msg.contains("parts"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn test_malformed_header_returns_error() {
        assert_rejected_with(
            b"GET / HTTP/1.1\r\nmissing-colon-line\r\n\r\n".to_vec(),
            "Malformed header line",
        );
    }

    #[test]
    fn test_header_line_too_long_returns_error() {
        let mut req = Vec::from("GET / HTTP/1.1\r\nX: ");
        req.extend(std::iter::repeat_n(b'A', MAX_HEADER_LINE_LENGTH));
        req.extend_from_slice(b"\r\n\r\n");

        assert_rejected_with(req, "Header line too long");
    }

    #[test]
    fn test_header_section_too_large_returns_error() {
        // Many moderately sized header lines (each under MAX_HEADER_LINE_LENGTH)
        // whose cumulative byte count exceeds MAX_HEADER_BYTES before the
        // per-line or header-count guards trip.
        let filler = "A".repeat(8000);
        let mut req = Vec::from("GET / HTTP/1.1\r\n");
        // Ten ~8KiB headers = ~80 KiB > 64 KiB cap.
        for i in 0..10 {
            req.extend_from_slice(
                format!("H{i}: {filler}\r\n").as_bytes(),
            );
        }
        req.extend_from_slice(b"\r\n");

        assert_rejected_with(req, "Header section too large");
    }

    #[test]
    fn test_too_many_headers_returns_error() {
        let mut req = Vec::from("GET / HTTP/1.1\r\n");
        for i in 0..=MAX_HEADER_COUNT {
            req.extend_from_slice(format!("H{i}: v\r\n").as_bytes());
        }
        req.extend_from_slice(b"\r\n");

        assert_rejected_with(req, "Too many request headers");
    }

    #[test]
    fn test_missing_http_version_returns_error() {
        // Two-token request line: method + path, no version.
        // Triggers the third let-else branch (missing HTTP version).
        assert_rejected_with(
            b"GET /\r\n".to_vec(),
            "missing HTTP version",
        );
    }

    #[test]
    fn test_request_display_formats_method_path_version() {
        let request = Request {
            method: "GET".to_string(),
            path: "/index.html".to_string(),
            version: "HTTP/1.1".to_string(),
            headers: Vec::new(),
        };
        assert_eq!(format!("{request}"), "GET /index.html HTTP/1.1");
    }
}