1use crate::_abnf::{METHOD, REASON_PHRASE, REQUEST_TARGET};
2use crate::{_headers::Headers, _util::ProtocolError};
3use lazy_static::lazy_static;
4use regex::bytes::Regex;
5use std::fmt::{self, Formatter};
6
7lazy_static! {
8 static ref HTTP_VERSION_RE: Regex = Regex::new(r"^[0-9]\.[0-9]$").unwrap();
9 static ref METHOD_RE: Regex = Regex::new(&format!(r"^{}$", *METHOD)).unwrap();
10 static ref REASON_RE: Regex = Regex::new(&format!(r"^{}$", *REASON_PHRASE)).unwrap();
11 static ref REQUEST_TARGET_RE: Regex = Regex::new(&format!(r"^{}$", *REQUEST_TARGET)).unwrap();
12}
13
14#[derive(Clone, PartialEq, Eq, Default)]
20pub struct Request {
21 pub method: Vec<u8>,
23 pub headers: Headers,
25 pub target: Vec<u8>,
27 pub http_version: Vec<u8>,
29}
30
31impl Request {
32 pub fn new<M, T, V>(
34 method: M,
35 headers: Headers,
36 target: T,
37 http_version: V,
38 ) -> Result<Self, ProtocolError>
39 where
40 M: AsRef<[u8]>,
41 T: AsRef<[u8]>,
42 V: AsRef<[u8]>,
43 {
44 let request = Self {
45 method: method.as_ref().to_vec(),
46 headers,
47 target: target.as_ref().to_vec(),
48 http_version: http_version.as_ref().to_vec(),
49 };
50 request.validate()?;
51 Ok(request)
52 }
53
54 pub fn new_http11<M, T>(method: M, headers: Headers, target: T) -> Result<Self, ProtocolError>
56 where
57 M: AsRef<[u8]>,
58 T: AsRef<[u8]>,
59 {
60 Self::new(method, headers, target, b"1.1")
61 }
62
63 pub fn validate(&self) -> Result<(), ProtocolError> {
65 let mut host_count = 0;
66 for (name, _) in self.headers.iter() {
67 if name == b"host" {
68 host_count += 1;
69 }
70 }
71 if !HTTP_VERSION_RE.is_match(&self.http_version) {
72 return Err(ProtocolError::LocalProtocolError(
73 ("Illegal HTTP version".to_string(), 400).into(),
74 ));
75 }
76 if self.http_version == b"1.1" && host_count == 0 {
77 return Err(ProtocolError::LocalProtocolError(
78 ("Missing mandatory Host: header".to_string(), 400).into(),
79 ));
80 }
81 if host_count > 1 {
82 return Err(ProtocolError::LocalProtocolError(
83 ("Found multiple Host: headers".to_string(), 400).into(),
84 ));
85 }
86
87 if !METHOD_RE.is_match(&self.method) {
88 return Err(ProtocolError::LocalProtocolError(
89 ("Illegal method characters".to_string(), 400).into(),
90 ));
91 }
92 if !REQUEST_TARGET_RE.is_match(&self.target) {
93 return Err(ProtocolError::LocalProtocolError(
94 ("Illegal target characters".to_string(), 400).into(),
95 ));
96 }
97
98 Ok(())
99 }
100}
101
102impl std::fmt::Debug for Request {
103 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
104 f.debug_struct("Request")
105 .field("method", &String::from_utf8_lossy(&self.method))
106 .field("headers", &self.headers)
107 .field("target", &String::from_utf8_lossy(&self.target))
108 .field("http_version", &String::from_utf8_lossy(&self.http_version))
109 .finish()
110 }
111}
112
113#[derive(Debug, Clone, PartialEq, Eq, Default)]
119pub struct Response {
120 pub headers: Headers,
122 pub http_version: Vec<u8>,
124 pub reason: Vec<u8>,
126 pub status_code: u16,
128}
129
130impl Response {
131 pub fn new<R, V>(
133 status_code: u16,
134 headers: Headers,
135 reason: R,
136 http_version: V,
137 ) -> Result<Self, ProtocolError>
138 where
139 R: AsRef<[u8]>,
140 V: AsRef<[u8]>,
141 {
142 let response = Self {
143 headers,
144 http_version: http_version.as_ref().to_vec(),
145 reason: reason.as_ref().to_vec(),
146 status_code,
147 };
148 response.validate()?;
149 Ok(response)
150 }
151
152 pub fn new_http11<R>(
154 status_code: u16,
155 headers: Headers,
156 reason: R,
157 ) -> Result<Self, ProtocolError>
158 where
159 R: AsRef<[u8]>,
160 {
161 Self::new(status_code, headers, reason, b"1.1")
162 }
163
164 pub fn new_informational<R, V>(
168 status_code: u16,
169 headers: Headers,
170 reason: R,
171 http_version: V,
172 ) -> Result<Self, ProtocolError>
173 where
174 R: AsRef<[u8]>,
175 V: AsRef<[u8]>,
176 {
177 let response = Self::new(status_code, headers, reason, http_version)?;
178 if !(100..=199).contains(&response.status_code) {
179 return Err(ProtocolError::LocalProtocolError(
180 (
181 "Informational responses must use status codes in the range 100..=199",
182 400,
183 )
184 .into(),
185 ));
186 }
187 Ok(response)
188 }
189
190 pub fn new_informational_http11<R>(
194 status_code: u16,
195 headers: Headers,
196 reason: R,
197 ) -> Result<Self, ProtocolError>
198 where
199 R: AsRef<[u8]>,
200 {
201 Self::new_informational(status_code, headers, reason, b"1.1")
202 }
203
204 pub fn new_final<R, V>(
208 status_code: u16,
209 headers: Headers,
210 reason: R,
211 http_version: V,
212 ) -> Result<Self, ProtocolError>
213 where
214 R: AsRef<[u8]>,
215 V: AsRef<[u8]>,
216 {
217 let response = Self::new(status_code, headers, reason, http_version)?;
218 if response.status_code < 200 {
219 return Err(ProtocolError::LocalProtocolError(
220 ("Final responses must use status codes >= 200", 400).into(),
221 ));
222 }
223 Ok(response)
224 }
225
226 pub fn new_final_http11<R>(
230 status_code: u16,
231 headers: Headers,
232 reason: R,
233 ) -> Result<Self, ProtocolError>
234 where
235 R: AsRef<[u8]>,
236 {
237 Self::new_final(status_code, headers, reason, b"1.1")
238 }
239
240 pub fn validate(&self) -> Result<(), ProtocolError> {
242 if !(100..=999).contains(&self.status_code) {
243 return Err(ProtocolError::LocalProtocolError(
244 ("Illegal status code".to_string(), 400).into(),
245 ));
246 }
247 if !HTTP_VERSION_RE.is_match(&self.http_version) {
248 return Err(ProtocolError::LocalProtocolError(
249 ("Illegal HTTP version".to_string(), 400).into(),
250 ));
251 }
252 if !REASON_RE.is_match(&self.reason) {
253 return Err(ProtocolError::LocalProtocolError(
254 ("Illegal reason phrase".to_string(), 400).into(),
255 ));
256 }
257 Ok(())
258 }
259}
260
/// A piece of message body.
#[derive(Default, Eq, PartialEq, Clone, Debug)]
pub struct Data {
    /// The body bytes carried by this event.
    pub data: Vec<u8>,
    /// Presumably set when `data` begins a chunk of a chunked body — TODO confirm.
    pub chunk_start: bool,
    /// Presumably set when `data` ends a chunk of a chunked body — TODO confirm.
    pub chunk_end: bool,
}
271
272#[derive(Debug, Clone, PartialEq, Eq, Default)]
274pub struct EndOfMessage {
275 pub headers: Headers,
277}
278
/// Event for a closed connection; it carries no fields.
#[derive(Default, Eq, PartialEq, Clone, Debug)]
pub struct ConnectionClosed {}
282
283#[derive(Debug, Clone, PartialEq, Eq)]
285pub enum Event {
286 Request(Request),
288 NormalResponse(Response),
290 InformationalResponse(Response),
292 Data(Data),
294 EndOfMessage(EndOfMessage),
296 ConnectionClosed(ConnectionClosed),
298 NeedData(),
300 Paused(),
302}
303
304impl From<Request> for Event {
305 fn from(request: Request) -> Self {
306 Self::Request(request)
307 }
308}
309
310impl From<Response> for Event {
311 fn from(response: Response) -> Self {
312 match response.status_code {
313 100..=199 => Self::InformationalResponse(response),
314 _ => Self::NormalResponse(response),
315 }
316 }
317}
318
319impl Event {
320 pub fn informational_response(response: Response) -> Result<Self, ProtocolError> {
322 if !(100..=199).contains(&response.status_code) {
323 return Err(ProtocolError::LocalProtocolError(
324 (
325 "Informational responses must use status codes in the range 100..=199",
326 400,
327 )
328 .into(),
329 ));
330 }
331 response.validate()?;
332 Ok(Self::InformationalResponse(response))
333 }
334
335 pub fn normal_response(response: Response) -> Result<Self, ProtocolError> {
337 if response.status_code < 200 {
338 return Err(ProtocolError::LocalProtocolError(
339 ("Normal responses must use status codes >= 200", 400).into(),
340 ));
341 }
342 response.validate()?;
343 Ok(Self::NormalResponse(response))
344 }
345}
346
347impl From<Data> for Event {
348 fn from(data: Data) -> Self {
349 Self::Data(data)
350 }
351}
352
353impl From<EndOfMessage> for Event {
354 fn from(end_of_message: EndOfMessage) -> Self {
355 Self::EndOfMessage(end_of_message)
356 }
357}
358
359impl From<ConnectionClosed> for Event {
360 fn from(connection_closed: ConnectionClosed) -> Self {
361 Self::ConnectionClosed(connection_closed)
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_response_new_rejects_invalid_input() {
        assert!(Response::new(99, Headers::default(), b"OK".to_vec(), b"1.1".to_vec()).is_err());
        assert!(Response::new(1000, Headers::default(), b"OK".to_vec(), b"1.1".to_vec()).is_err());
        assert!(Response::new(
            200,
            Headers::default(),
            b"OK".to_vec(),
            b"HTTP/1.1".to_vec()
        )
        .is_err());
        assert!(Response::new(
            200,
            Headers::default(),
            b"bad\nreason".to_vec(),
            b"1.1".to_vec()
        )
        .is_err());
    }

    #[test]
    fn test_request_new_rejects_invalid_http_version() {
        assert!(Request::new(
            b"GET".to_vec(),
            Headers::new(vec![(b"Host".to_vec(), b"example.com".to_vec())]).unwrap(),
            b"/".to_vec(),
            b"HTTP/1.1".to_vec(),
        )
        .is_err());
    }

    #[test]
    fn test_request_host_header_rules() {
        // HTTP/1.1 requests must carry a Host header...
        assert!(Request::new_http11("GET", Headers::default(), "/").is_err());
        // ...while HTTP/1.0 requests may omit it.
        assert!(Request::new("GET", Headers::default(), "/", "1.0").is_ok());
        // Duplicate Host headers must be rejected, whether by `Headers::new`
        // itself or by request validation.
        let duplicated = Headers::new([("Host", "a.example"), ("Host", "b.example")]);
        assert!(duplicated.map_or(true, |headers| Request::new_http11("GET", headers, "/").is_err()));
    }

    #[test]
    fn test_request_new_accepts_borrowed_inputs_and_http11_default() {
        let request =
            Request::new_http11("GET", Headers::new([("Host", "example.com")]).unwrap(), "/")
                .unwrap();

        assert_eq!(request.method, b"GET");
        assert_eq!(request.target, b"/");
        assert_eq!(request.http_version, b"1.1");
    }

    #[test]
    fn test_response_new_accepts_borrowed_inputs_and_http11_default() {
        let response =
            Response::new_http11(200, Headers::new([("Content-Length", "0")]).unwrap(), "OK")
                .unwrap();

        assert_eq!(response.status_code, 200);
        assert_eq!(response.reason, b"OK");
        assert_eq!(response.http_version, b"1.1");
    }

    #[test]
    fn test_response_range_checked_constructors() {
        let informational =
            Response::new_informational_http11(100, Headers::default(), "Continue").unwrap();
        assert_eq!(informational.status_code, 100);

        let final_response = Response::new_final_http11(200, Headers::default(), "OK").unwrap();
        assert_eq!(final_response.status_code, 200);

        assert!(Response::new_informational_http11(200, Headers::default(), "OK").is_err());
        assert!(Response::new_final_http11(199, Headers::default(), "Early").is_err());
    }

    #[test]
    fn test_event_from_response_classifies_by_status() {
        let informational =
            Response::new_http11(101, Headers::default(), "Switching Protocols").unwrap();
        assert!(matches!(
            Event::from(informational),
            Event::InformationalResponse(_)
        ));

        let normal = Response::new_http11(404, Headers::default(), "Not Found").unwrap();
        assert!(matches!(Event::from(normal), Event::NormalResponse(_)));
    }

    #[test]
    fn test_event_response_constructors_validate_status_ranges() {
        let informational =
            Response::new_informational_http11(100, Headers::default(), "Continue").unwrap();
        assert!(matches!(
            Event::informational_response(informational).unwrap(),
            Event::InformationalResponse(_)
        ));

        let final_response =
            Response::new_final_http11(204, Headers::default(), "No Content").unwrap();
        assert!(matches!(
            Event::normal_response(final_response).unwrap(),
            Event::NormalResponse(_)
        ));

        let informational =
            Response::new_informational_http11(101, Headers::default(), "Switching Protocols")
                .unwrap();
        assert!(Event::normal_response(informational).is_err());

        let final_response = Response::new_final_http11(200, Headers::default(), "OK").unwrap();
        assert!(Event::informational_response(final_response).is_err());
    }
}