1
2use std::fmt::Display;
3use std::str::{self, Utf8Error};
4use std::net::SocketAddr;
5use std::error;
6
7use crate::cookie::{Cookie, parse_cookies};
8use crate::util::*;
9
10use debug_print::{debug_eprintln, debug_println};
11
12
13#[derive(Debug, Default)]
16pub struct Request<'buf> {
17 pub body: Option<&'buf [u8]>,
18 pub method: &'buf str,
19 pub document: &'buf str,
20 pub query_raw: &'buf str,
21 pub protocol: &'buf str,
22 pub version: &'buf str,
23 pub header_raw_lines: Vec<&'buf str>,
24
25 headers: Vec<(&'buf str, &'buf str)>,
26 get: Vec<(&'buf str, &'buf str)>,
27 post: Vec<(&'buf str, &'buf str)>,
28 cookies: Vec<Cookie<'buf>>,
29
30 host: Option<&'buf str>,
31 user_agent: Option<&'buf str>,
32 content_type: Option<&'buf str>,
33 content_length: Option<usize>,
34
35 pub peer_addr: Option<SocketAddr>,
36}
37
/// Reasons a raw request can be rejected while being parsed.
///
/// Variants that carry a byte slice borrow the offending bytes from the request buffer.
#[derive(Debug)]
pub enum RequestError<'buf> {
    /// The request line did not split into exactly three space-separated parts;
    /// carries the parts that were found.
    RequestLineMalformed(Vec<&'buf [u8]>),

    /// The request target (path) is not valid UTF-8.
    DocumentNotUtf8(Utf8Error),
    /// The request target does not start with `/`; carries the path bytes.
    DocumentMalformed(&'buf [u8]),

    /// The request method is not valid UTF-8.
    MethodNotUtf8(Utf8Error),

    /// The query string is not valid UTF-8.
    QueryNotUtf8(Utf8Error),

    /// The protocol name is not valid UTF-8.
    ProtoNotUtf8(Utf8Error),
    /// The protocol field is not of the form `NAME/VERSION`.
    ProtoMalformed(&'buf [u8]),
    /// The protocol name is not supported.
    ProtoInvalid(&'buf [u8]),

    /// The protocol version is not valid UTF-8.
    ProtoVersionNotUtf8(Utf8Error),
    /// The protocol version is not supported.
    ProtoVersionInvalid(&'buf [u8]),

    /// The header section is not valid UTF-8.
    HeadersNotUtf8(Utf8Error),

    /// Declared and actual body lengths disagree.
    ContentLengthDiscrepancy { expected: usize, got: usize },

    /// Form-encoded POST parameters could not be parsed.
    PostParamsMalformed(&'buf [u8]),
}

impl Display for RequestError<'_> {
    /// Formats a human-readable message; byte payloads are rendered as (lossy) UTF-8 text.
    /// Use `{:?}` for the structured diagnostic form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::RequestLineMalformed(parts) => write!(
                f,
                "malformed request line: expected 3 space-separated parts, got {}",
                parts.len()
            ),
            Self::DocumentNotUtf8(err) => {
                write!(f, "request target is not valid UTF-8: {err}")
            }
            Self::DocumentMalformed(doc) => write!(
                f,
                "malformed request target: {}",
                String::from_utf8_lossy(doc)
            ),
            Self::MethodNotUtf8(err) => write!(f, "request method is not valid UTF-8: {err}"),
            Self::QueryNotUtf8(err) => write!(f, "query string is not valid UTF-8: {err}"),
            Self::ProtoNotUtf8(err) => write!(f, "protocol is not valid UTF-8: {err}"),
            Self::ProtoMalformed(proto) => write!(
                f,
                "malformed protocol string: {}",
                String::from_utf8_lossy(proto)
            ),
            Self::ProtoInvalid(proto) => write!(
                f,
                "unsupported protocol: {}",
                String::from_utf8_lossy(proto)
            ),
            Self::ProtoVersionNotUtf8(err) => {
                write!(f, "protocol version is not valid UTF-8: {err}")
            }
            Self::ProtoVersionInvalid(version) => write!(
                f,
                "unsupported protocol version: {}",
                String::from_utf8_lossy(version)
            ),
            Self::HeadersNotUtf8(err) => write!(f, "request headers are not valid UTF-8: {err}"),
            Self::ContentLengthDiscrepancy { expected, got } => write!(
                f,
                "content length mismatch: expected {expected} bytes, got {got}"
            ),
            Self::PostParamsMalformed(params) => write!(
                f,
                "malformed POST parameters: {}",
                String::from_utf8_lossy(params)
            ),
        }
    }
}

impl error::Error for RequestError<'_> {
    /// Exposes the underlying [`Utf8Error`] for the `*NotUtf8` variants so error
    /// chains keep the original cause.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            Self::DocumentNotUtf8(err)
            | Self::MethodNotUtf8(err)
            | Self::QueryNotUtf8(err)
            | Self::ProtoNotUtf8(err)
            | Self::ProtoVersionNotUtf8(err)
            | Self::HeadersNotUtf8(err) => Some(err),
            Self::RequestLineMalformed(_)
            | Self::DocumentMalformed(_)
            | Self::ProtoMalformed(_)
            | Self::ProtoInvalid(_)
            | Self::ProtoVersionInvalid(_)
            | Self::ContentLengthDiscrepancy { .. }
            | Self::PostParamsMalformed(_) => None,
        }
    }
}
68
69#[allow(dead_code)]
70impl<'buf> Request<'buf> {
71 pub fn from_slice(buf: &'buf [u8]) -> Result<Self, RequestError<'buf>> {
73 Self::new(buf, None)
74 }
75
76 pub fn bad() -> Self {
78 Self::default()
79 }
80
81 pub fn new(
83 buf: &'buf [u8],
84 peer_addr: Option<&SocketAddr>
85 ) -> Result<Self, RequestError<'buf>> {
86 let (mut request_head, request_body) = request_head_body_split(buf);
87
88 loop {
90 request_head = match request_head.strip_prefix(b"\r\n") {
91 Some(head) => head,
92 None => break,
93 };
94 }
95
96 let body = request_body;
97
98 let (request_line, request_headers) = request_line_header_split(request_head);
99
100 let request_line_items: [&[u8]; 3] = request_line
101 .split(|c| *c == b' ')
102 .collect::<Vec<&[u8]>>()
103 .try_into()
104 .map_err(RequestError::RequestLineMalformed)?;
105
106 let method = str::from_utf8(&request_line_items[0])
107 .map_err(RequestError::MethodNotUtf8)?;
108
109 let (document_slice, query) = split_once(request_line_items[1], b'?');
110
111 let document = str::from_utf8(document_slice)
112 .map_err(RequestError::DocumentNotUtf8)?;
113
114 if !document.starts_with('/') {
115 debug_eprintln!("ERROR: {document} does not start with /");
116 return Err(RequestError::DocumentMalformed(document_slice));
117 }
118
119 let query = match query {
120 None => "",
121 Some(thing) => str::from_utf8(thing)
122 .map_err(RequestError::QueryNotUtf8)?
123 };
124
125 let proto_version_items: [&[u8]; 2] = match request_line_items[2]
126 .split(|c| *c == b'/')
127 .collect::<Vec<&[u8]>>()
128 .try_into() {
129 Err(_) => {
130 debug_eprintln!("ERROR: Invalid protocol string: {}",
131 str::from_utf8(request_line_items[2])
132 .unwrap_or(&format!("{:?}", request_line_items[2])));
133 return Err(RequestError::ProtoMalformed(request_line_items[2]));
134 },
135 Ok(items) => items,
136 };
137
138 let protocol = str::from_utf8(proto_version_items[0])
139 .map_err(RequestError::ProtoNotUtf8)?;
140
141 if protocol != "HTTP" {
142 debug_eprintln!("ERROR: Invalid protocol {protocol}");
143 return Err(RequestError::ProtoInvalid(request_line));
144 }
145
146 let version = str::from_utf8(proto_version_items[1])
147 .map_err(RequestError::ProtoVersionNotUtf8)?
148 .trim_end_matches(|c| ['\r', '\n', '\0'].contains(&c));
149
150 if version != "1.1" {
151 debug_eprintln!("ERROR: Invalid version {version}");
152 return Err(RequestError::ProtoVersionInvalid(request_line));
153 }
154
155 let header_raw_lines = str::from_utf8(request_headers.unwrap_or_default())
156 .map_err(RequestError::HeadersNotUtf8)?
157 .split(&"\r\n")
158 .collect::<Vec<_>>();
159
160 let headers_len = header_raw_lines.len();
161
162 Ok(Self {
164 body,
165 method,
166 document,
167 query_raw: query,
168 protocol,
169 version,
170 header_raw_lines,
171 headers: Vec::with_capacity(headers_len),
172 get: Vec::new(),
173 post: Vec::new(),
174 cookies: Vec::new(),
175 host: None,
176 user_agent: None,
177 content_type: None,
178 content_length: None,
179 peer_addr: peer_addr.copied(),
180 })
181 }
182
183 pub fn host(&mut self) -> Option<&'buf str> {
184 if let Some(host) = self.host {
185 return Some(host);
186 } else {
187 if let Some(host) = self.header("Host") {
188 self.host = Some(host);
189 return Some(host);
190 } else {
191 return None;
192 }
193 }
194 }
195
196 pub fn user_agent(&mut self) -> Option<&'buf str> {
197 if let Some(ua) = self.user_agent {
198 return Some(ua);
199 } else {
200 if let Some(ua) = self.header("User-Agent") {
201 self.user_agent = Some(ua);
202 return Some(ua);
203 } else {
204 return None;
205 }
206 }
207 }
208
209 pub fn content_type(&mut self) -> Option<&'buf str> {
210 if let Some(ct) = self.content_type {
211 return Some(ct);
212 } else {
213 if let Some(ct) = self.header("Content-Type") {
214 self.content_type = Some(ct);
215 return Some(ct);
216 } else {
217 return None;
218 }
219 }
220 }
221
222 pub fn content_length(&mut self) -> Option<usize> {
223 if let Some(cl) = self.content_length {
224 return Some(cl);
225 } else {
226 if let Some(cl) = self.header("Content-Length") {
227 let cl = cl.parse::<usize>().ok();
228 self.content_length = cl;
229 return cl;
230 } else {
231 return None;
232 }
233 }
234 }
235
236 pub fn header(&mut self, key: &str) -> Option<&'buf str> {
239 if self.header_raw_lines.is_empty() {
240 return None;
241 }
242 if let Some((_k, v)) = self.headers.iter()
243 .find(|(k, _v)| *k == key) {
244 return Some(v);
245 } else {
246 if let Some(raw) = self.header_raw_lines.iter()
247 .find(|line| line.find(": ").map(|idx| &line[..idx] == key).unwrap_or(false)) {
248 if let Some((key, value)) = parse_header(raw) {
249 self.headers.push((key, value));
250 return Some(value);
251 }
252 }
253 }
254 return None;
255 }
256
257 pub fn cookie(&mut self, key: &str) -> Option<&Cookie<'buf>> {
260 if self.header_raw_lines.is_empty() {
261 return None;
262 }
263 if self.cookies.is_empty() {
264 if let Some(cookies_raw) = self.header("Cookie") {
265 let cookies = parse_cookies(cookies_raw);
266 if cookies.is_empty() {
267 return None;
268 }
269 self.cookies = cookies;
270 } else {
271 return None;
272 }
273 }
274 return self.cookies.iter()
275 .find(|c| c.name == key);
276 }
277
278 pub fn get(&mut self, key: &str) -> Option<&str> {
281 if self.query_raw.is_empty() {
282 return None;
283 }
284 if self.get.is_empty() {
285 if let Some(get) = parse_parameters(self.query_raw).ok() {
286 if get.is_empty() {
287 return None;
288 }
289 self.get = get;
290 } else {
291 return None;
292 }
293 }
294 return self.get.iter()
295 .find(|(k, _v)| *k == key)
296 .map(|(_k, v)| *v);
297 }
298
299 pub fn post(&mut self, key: &str) -> Option<&str> {
302 if self.method != "POST" {
304 return None;
305 }
306 if self.body.is_none() {
308 return None;
309 }
310 if self.post.is_empty() {
312 if let Some(content_len) = self.content_length() {
314 if content_len == 0 {
316 return None;
317 }
318 if let Some(content_type) = self.content_type() {
320 if content_type != "application/x-www-form-urlencoded" {
322 return None;
323 }
324 if let Some(body) = self.body {
326 if let Some(body) = str::from_utf8(body.get(0..content_len)?).ok() {
328 match parse_parameters(body) {
330 Ok(params) if params.is_empty() => {
331 return None;
332 }
333 Ok(params) => {
334 self.post = params;
336 return self.post.iter()
338 .find(|(k, _v)| *k == key)
339 .map(|(_k, v)| *v);
340 },
341 Err(err) => {
342 debug_println!("ERROR: Invalid post parameters: {body}: {}", err);
343 },
344 }
345 }
346 }
347 }
348 }
349 } else {
350 return self.post.iter()
352 .find(|(k, _v)| *k == key)
353 .map(|(_k, v)| *v);
354 }
355 return None;
356 }
357}
358
/// Splits a request head into the request line and the header section.
///
/// The split happens at the first CRLF. Everything before it is the request line, and
/// everything after it (with the CRLF removed) is returned as the header bytes. When the
/// head contains no CRLF, the whole input is the request line and `None` is returned.
///
/// A CRLF at index 0 is a real split point (yielding an empty request line). The previous
/// implementation used `0` as a "not found" sentinel and treated that case as no split.
fn request_line_header_split(to_split: &[u8]) -> (&[u8], Option<&[u8]>) {
    match to_split.windows(2).position(|pair| pair == b"\r\n") {
        Some(idx) => (&to_split[..idx], Some(&to_split[idx + 2..])),
        None => (to_split, None),
    }
}
395
/// Splits a raw request into its head (request line plus headers) and its body.
///
/// The head ends at the first empty line, i.e. the first `\r\n\r\n`. The separator itself
/// belongs to neither part. Empty lines (bare CRLFs) before the request line, which
/// RFC 7230 §3.5 says to ignore, are skipped while searching so they are not mistaken for
/// the head terminator. They remain at the start of the returned head. When no terminator
/// is found, the whole input is returned as the head and the body is `None`.
fn request_head_body_split(to_split: &[u8]) -> (&[u8], Option<&[u8]>) {
    const SEPARATOR: &[u8] = b"\r\n\r\n";

    // Length of the run of empty lines that precede the request line.
    let mut leading = 0;
    while to_split[leading..].starts_with(b"\r\n") {
        leading += 2;
    }

    match to_split[leading..]
        .windows(SEPARATOR.len())
        .position(|window| window == SEPARATOR)
    {
        Some(idx) => {
            let head_end = leading + idx;
            (
                &to_split[..head_end],
                Some(&to_split[head_end + SEPARATOR.len()..]),
            )
        }
        None => (to_split, None),
    }
}
441