http_handle/
request.rs

1// src/request.rs
2
3//! HTTP request parsing module for the Http Handle.
4//!
5//! This module provides functionality to parse incoming HTTP requests from a TCP stream.
6//! It defines the `Request` struct and associated methods for creating and interacting with HTTP requests in a secure and robust manner.
7
8use crate::error::ServerError;
9use std::fmt;
10use std::io::{BufRead, BufReader};
11use std::net::TcpStream;
12use std::time::Duration;
13
/// Upper bound, in bytes, on the request line accepted by
/// `Request::from_stream`.
///
/// The parser compares this value against the length of the line exactly
/// as it was read from the stream, so the line terminator (`\r\n`) counts
/// toward the limit along with the method, path, version and separators.
const MAX_REQUEST_LINE_LENGTH: usize = 8_190;

/// How many whitespace-separated parts a request line must contain:
/// method, path and version.
const REQUEST_PARTS: usize = 3;

/// Read timeout, in seconds, installed on the TCP stream before the
/// request line is read.
const TIMEOUT_SECONDS: u64 = 30;
23
/// The parsed request line of an HTTP request: method, path and version.
///
/// All three fields are stored as owned strings, exactly as they appeared
/// in the request line. Equality is a field-by-field string comparison and
/// is a full equivalence relation, so the type derives `Eq` in addition to
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Request {
    /// HTTP method of the request (e.g. "GET", "POST").
    pub method: String,
    /// Requested path (e.g. "/index.html").
    pub path: String,
    /// HTTP version of the request (e.g. "HTTP/1.1").
    pub version: String,
}
34
35impl Request {
36    /// Attempts to create a `Request` from the provided TCP stream by reading the first line.
37    ///
38    /// This method reads the first line of an HTTP request from the given TCP stream,
39    /// parses it, and constructs a `Request` instance if the input is valid.
40    ///
41    /// # Arguments
42    ///
43    /// * `stream` - A reference to the `TcpStream` from which the request will be read.
44    ///
45    /// # Returns
46    ///
47    /// * `Ok(Request)` - If the request is valid and successfully parsed.
48    /// * `Err(ServerError)` - If the request is malformed, cannot be read, or is invalid.
49    ///
50    /// # Errors
51    ///
52    /// This function returns a `ServerError::InvalidRequest` error if:
53    /// - The request line is too long (exceeds `MAX_REQUEST_LINE_LENGTH`)
54    /// - The request line does not contain exactly three parts
55    /// - The HTTP method is not recognized
56    /// - The request path does not start with a forward slash
57    /// - The HTTP version is not supported (only HTTP/1.0 and HTTP/1.1 are accepted)
58    ///
59    /// # Examples
60    ///
61    /// ```
62    /// use std::net::TcpStream;
63    /// use http_handle::request::Request;
64    ///
65    /// fn handle_client(stream: TcpStream) {
66    ///     match Request::from_stream(&stream) {
67    ///         Ok(request) => println!("Received request: {}", request),
68    ///         Err(e) => eprintln!("Error parsing request: {}", e),
69    ///     }
70    /// }
71    /// ```
72    pub fn from_stream(
73        stream: &TcpStream,
74    ) -> Result<Self, ServerError> {
75        stream
76            .set_read_timeout(Some(Duration::from_secs(
77                TIMEOUT_SECONDS,
78            )))
79            .map_err(|e| {
80                ServerError::invalid_request(format!(
81                    "Failed to set read timeout: {}",
82                    e
83                ))
84            })?;
85
86        let mut buf_reader = BufReader::new(stream);
87        let mut request_line = String::new();
88
89        let _ =
90            buf_reader.read_line(&mut request_line).map_err(|e| {
91                ServerError::invalid_request(format!(
92                    "Failed to read request line: {}",
93                    e
94                ))
95            })?;
96
97        // Trim the trailing \r\n before checking the length
98        let trimmed_request_line = request_line.trim_end();
99
100        // Check if the request line exceeds the maximum allowed length
101        if request_line.len() > MAX_REQUEST_LINE_LENGTH {
102            return Err(ServerError::invalid_request(format!(
103                "Request line too long: {} characters (max {})",
104                request_line.len(),
105                MAX_REQUEST_LINE_LENGTH
106            )));
107        }
108
109        let parts: Vec<&str> =
110            trimmed_request_line.split_whitespace().collect();
111
112        if parts.len() != REQUEST_PARTS {
113            return Err(ServerError::invalid_request(format!(
114                "Invalid request line: expected {} parts, got {}",
115                REQUEST_PARTS,
116                parts.len()
117            )));
118        }
119
120        let method = parts[0].to_string();
121        if !Self::is_valid_method(&method) {
122            return Err(ServerError::invalid_request(format!(
123                "Invalid HTTP method: {}",
124                method
125            )));
126        }
127
128        let path = parts[1].to_string();
129        if !path.starts_with('/') {
130            return Err(ServerError::invalid_request(
131                "Invalid path: must start with '/'",
132            ));
133        }
134
135        let version = parts[2].to_string();
136        if !Self::is_valid_version(&version) {
137            return Err(ServerError::invalid_request(format!(
138                "Invalid HTTP version: {}",
139                version
140            )));
141        }
142
143        Ok(Request {
144            method,
145            path,
146            version,
147        })
148    }
149
150    /// Returns the HTTP method of the request.
151    ///
152    /// # Returns
153    ///
154    /// A string slice containing the HTTP method (e.g., "GET", "POST").
155    pub fn method(&self) -> &str {
156        &self.method
157    }
158
159    /// Returns the requested path of the request.
160    ///
161    /// # Returns
162    ///
163    /// A string slice containing the requested path.
164    pub fn path(&self) -> &str {
165        &self.path
166    }
167
168    /// Returns the HTTP version of the request.
169    ///
170    /// # Returns
171    ///
172    /// A string slice containing the HTTP version (e.g., "HTTP/1.1").
173    pub fn version(&self) -> &str {
174        &self.version
175    }
176
177    /// Checks if the given method is a valid HTTP method.
178    ///
179    /// # Arguments
180    ///
181    /// * `method` - A string slice containing the HTTP method to validate.
182    ///
183    /// # Returns
184    ///
185    /// `true` if the method is valid, `false` otherwise.
186    fn is_valid_method(method: &str) -> bool {
187        matches!(
188            method.to_ascii_uppercase().as_str(),
189            "GET"
190                | "POST"
191                | "PUT"
192                | "DELETE"
193                | "HEAD"
194                | "OPTIONS"
195                | "PATCH"
196        )
197    }
198
199    /// Checks if the given HTTP version is supported.
200    ///
201    /// # Arguments
202    ///
203    /// * `version` - A string slice containing the HTTP version to validate.
204    ///
205    /// # Returns
206    ///
207    /// `true` if the version is supported, `false` otherwise.
208    fn is_valid_version(version: &str) -> bool {
209        version.eq_ignore_ascii_case("HTTP/1.0")
210            || version.eq_ignore_ascii_case("HTTP/1.1")
211    }
212}
213
214impl fmt::Display for Request {
215    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216        write!(f, "{} {} {}", self.method, self.path, self.version)
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::net::TcpListener;

    /// Number of bytes surrounding the path in `GET <path> HTTP/1.1\r\n`.
    const FRAMING_LEN: usize =
        "GET ".len() + " HTTP/1.1".len() + "\r\n".len();

    /// Sends `payload` over a loopback TCP connection and returns what
    /// `Request::from_stream` makes of it on the receiving side.
    fn parse_payload(
        payload: &[u8],
    ) -> Result<Request, ServerError> {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let payload = payload.to_vec();
        let sender = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            stream.write_all(&payload).unwrap();
        });

        let stream = TcpStream::connect(addr).unwrap();
        let result = Request::from_stream(&stream);

        // Join while the client socket is still open so that a panic in
        // the sender thread fails the test instead of being lost.
        sender.join().expect("sender thread panicked");

        result
    }

    /// Asserts that parsing `payload` fails with `InvalidRequest`.
    fn assert_invalid_request(payload: &[u8]) {
        let result = parse_payload(payload);

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ServerError::InvalidRequest(_)
        ));
    }

    #[test]
    fn test_valid_request() {
        let request =
            parse_payload(b"GET /index.html HTTP/1.1\r\n").unwrap();

        assert_eq!(request.method(), "GET");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_invalid_method() {
        assert_invalid_request(b"INVALID /index.html HTTP/1.1\r\n");
    }

    #[test]
    fn test_max_length_request() {
        // The raw line, terminator included, is exactly at the limit.
        let path_len = MAX_REQUEST_LINE_LENGTH - FRAMING_LEN;
        let request =
            format!("GET {} HTTP/1.1\r\n", "/".repeat(path_len));
        assert_eq!(request.len(), MAX_REQUEST_LINE_LENGTH);

        let result = parse_payload(request.as_bytes());

        assert!(
            result.is_ok(),
            "Max length request should be valid. Error: {:?}",
            result.err()
        );
        assert_eq!(result.unwrap().path().len(), path_len);
    }

    #[test]
    fn test_oversized_request() {
        // One byte over the limit: the smallest line that must be rejected.
        let path_len = MAX_REQUEST_LINE_LENGTH - FRAMING_LEN + 1;
        let request =
            format!("GET {} HTTP/1.1\r\n", "/".repeat(path_len));
        assert_eq!(request.len(), MAX_REQUEST_LINE_LENGTH + 1);

        let result = parse_payload(request.as_bytes());

        assert!(
            result.is_err(),
            "Oversized request should be invalid. Request: {:?}",
            result
        );
        match result.unwrap_err() {
            ServerError::InvalidRequest(msg) => {
                assert!(
                    msg.starts_with("Request line too long:"),
                    "Unexpected error message: {}",
                    msg
                );
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn test_invalid_path() {
        assert_invalid_request(b"GET index.html HTTP/1.1\r\n");
    }

    #[test]
    fn test_invalid_version() {
        assert_invalid_request(b"GET /index.html HTTP/2.0\r\n");
    }
}