use crate::error::ServerError;
use std::fmt;
use std::io::{BufRead, BufReader, Read};
use std::net::TcpStream;
use std::time::Duration;
13
/// Upper bound, in bytes, on the raw request line; the trailing line
/// terminator counts towards this limit.
const MAX_REQUEST_LINE_LENGTH: usize = 8_190;

/// A request line is made of exactly this many whitespace-separated tokens:
/// the method, the request path and the protocol version.
const REQUEST_PARTS: usize = 3;

/// Read timeout, in whole seconds, installed on the client socket before the
/// request line is read.
const TIMEOUT_SECONDS: u64 = 30;
23
/// The parsed request line of an HTTP/1.x request: `METHOD PATH VERSION`.
///
/// Values produced by [`Request::from_stream`] have already been validated
/// (supported method, path beginning with `/`, HTTP/1.0 or HTTP/1.1 version).
/// `Eq` and `Hash` are derived because every field is a `String`, which lets a
/// request be used as a key in hash maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Request {
    /// The HTTP method token, e.g. `GET`.
    pub method: String,
    /// The request target; begins with `/` when parsed by [`Request::from_stream`].
    pub path: String,
    /// The protocol version token, e.g. `HTTP/1.1`.
    pub version: String,
}
34
35impl Request {
36 pub fn from_stream(
73 stream: &TcpStream,
74 ) -> Result<Self, ServerError> {
75 stream
76 .set_read_timeout(Some(Duration::from_secs(
77 TIMEOUT_SECONDS,
78 )))
79 .map_err(|e| {
80 ServerError::invalid_request(format!(
81 "Failed to set read timeout: {}",
82 e
83 ))
84 })?;
85
86 let mut buf_reader = BufReader::new(stream);
87 let mut request_line = String::new();
88
89 let _ =
90 buf_reader.read_line(&mut request_line).map_err(|e| {
91 ServerError::invalid_request(format!(
92 "Failed to read request line: {}",
93 e
94 ))
95 })?;
96
97 let trimmed_request_line = request_line.trim_end();
99
100 if request_line.len() > MAX_REQUEST_LINE_LENGTH {
102 return Err(ServerError::invalid_request(format!(
103 "Request line too long: {} characters (max {})",
104 request_line.len(),
105 MAX_REQUEST_LINE_LENGTH
106 )));
107 }
108
109 let parts: Vec<&str> =
110 trimmed_request_line.split_whitespace().collect();
111
112 if parts.len() != REQUEST_PARTS {
113 return Err(ServerError::invalid_request(format!(
114 "Invalid request line: expected {} parts, got {}",
115 REQUEST_PARTS,
116 parts.len()
117 )));
118 }
119
120 let method = parts[0].to_string();
121 if !Self::is_valid_method(&method) {
122 return Err(ServerError::invalid_request(format!(
123 "Invalid HTTP method: {}",
124 method
125 )));
126 }
127
128 let path = parts[1].to_string();
129 if !path.starts_with('/') {
130 return Err(ServerError::invalid_request(
131 "Invalid path: must start with '/'",
132 ));
133 }
134
135 let version = parts[2].to_string();
136 if !Self::is_valid_version(&version) {
137 return Err(ServerError::invalid_request(format!(
138 "Invalid HTTP version: {}",
139 version
140 )));
141 }
142
143 Ok(Request {
144 method,
145 path,
146 version,
147 })
148 }
149
150 pub fn method(&self) -> &str {
156 &self.method
157 }
158
159 pub fn path(&self) -> &str {
165 &self.path
166 }
167
168 pub fn version(&self) -> &str {
174 &self.version
175 }
176
177 fn is_valid_method(method: &str) -> bool {
187 matches!(
188 method.to_ascii_uppercase().as_str(),
189 "GET"
190 | "POST"
191 | "PUT"
192 | "DELETE"
193 | "HEAD"
194 | "OPTIONS"
195 | "PATCH"
196 )
197 }
198
199 fn is_valid_version(version: &str) -> bool {
209 version.eq_ignore_ascii_case("HTTP/1.0")
210 || version.eq_ignore_ascii_case("HTTP/1.1")
211 }
212}
213
214impl fmt::Display for Request {
215 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216 write!(f, "{} {} {}", self.method, self.path, self.version)
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::net::TcpListener;

    /// Sends `payload` from a freshly accepted server-side socket and parses it
    /// on the client side with [`Request::from_stream`].
    ///
    /// The writer thread is joined, so a panic while accepting fails the test
    /// instead of vanishing with a detached thread. Write errors are ignored
    /// because the parser may legitimately hang up before the whole payload
    /// has been sent.
    fn parse_payload(payload: Vec<u8>) -> Result<Request, ServerError> {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let writer = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let _ = stream.write_all(&payload);
        });

        let stream = TcpStream::connect(addr).unwrap();
        let result = Request::from_stream(&stream);
        // Close our end before joining so a writer blocked on a full socket
        // buffer is released.
        drop(stream);
        writer.join().expect("writer thread panicked");
        result
    }

    /// Builds `GET <path> HTTP/1.1\r\n` whose total length, terminator
    /// included, is exactly `total_len` bytes.
    fn request_line_of_length(total_len: usize) -> String {
        // "GET " plus " HTTP/1.1\r\n" is 15 bytes; the rest is the path.
        let path = "/".repeat(total_len - 15);
        let line = format!("GET {} HTTP/1.1\r\n", path);
        assert_eq!(line.len(), total_len);
        line
    }

    /// Asserts that `result` is a `ServerError::InvalidRequest`.
    fn assert_invalid_request(result: Result<Request, ServerError>) {
        assert!(
            matches!(result, Err(ServerError::InvalidRequest(_))),
            "Expected InvalidRequest, got: {:?}",
            result
        );
    }

    #[test]
    fn test_valid_request() {
        let request =
            parse_payload(b"GET /index.html HTTP/1.1\r\n".to_vec()).unwrap();

        assert_eq!(request.method(), "GET");
        assert_eq!(request.path(), "/index.html");
        assert_eq!(request.version(), "HTTP/1.1");
    }

    #[test]
    fn test_invalid_method() {
        assert_invalid_request(parse_payload(
            b"INVALID /index.html HTTP/1.1\r\n".to_vec(),
        ));
    }

    #[test]
    fn test_max_length_request() {
        // Exactly at the limit: must be accepted.
        let line = request_line_of_length(MAX_REQUEST_LINE_LENGTH);
        let result = parse_payload(line.into_bytes());

        assert!(
            result.is_ok(),
            "Max length request should be valid. Error: {:?}",
            result.as_ref().err()
        );
        assert_eq!(
            result.unwrap().path().len(),
            MAX_REQUEST_LINE_LENGTH - 15
        );
    }

    #[test]
    fn test_oversized_request() {
        // One byte over the limit: must be rejected.
        let line = request_line_of_length(MAX_REQUEST_LINE_LENGTH + 1);
        let result = parse_payload(line.into_bytes());

        assert!(
            result.is_err(),
            "Oversized request should be invalid. Request: {:?}",
            result
        );
        match result.unwrap_err() {
            ServerError::InvalidRequest(msg) => {
                assert!(
                    msg.starts_with("Request line too long:"),
                    "Unexpected error message: {}",
                    msg
                );
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn test_empty_request() {
        // The peer closes the connection without sending anything.
        assert_invalid_request(parse_payload(Vec::new()));
    }

    #[test]
    fn test_invalid_path() {
        assert_invalid_request(parse_payload(
            b"GET index.html HTTP/1.1\r\n".to_vec(),
        ));
    }

    #[test]
    fn test_invalid_version() {
        assert_invalid_request(parse_payload(
            b"GET /index.html HTTP/2.0\r\n".to_vec(),
        ));
    }
}