immortal_http/
request.rs

1
2use std::fmt::Display;
3use std::str::{self, Utf8Error};
4use std::net::SocketAddr;
5use std::error;
6
7use crate::cookie::{Cookie, parse_cookies};
8use crate::util::*;
9
10use debug_print::{debug_eprintln, debug_println};
11
12
/// A parsed view of a single HTTP request.
///
/// Every string and byte slice borrows from the request buffer (`'buf`) that was handed to
/// [`Request::new`]; nothing is copied. Headers, cookies and GET/POST parameters are parsed
/// lazily by the accessor methods, which is why those accessors take `&mut self`.
#[derive(Debug, Default)]
pub struct Request<'buf> {
    /// Bytes after the head/body separator, or `None` when no separator was found.
    pub body: Option<&'buf [u8]>,
    /// First space-separated item of the request line (e.g. `GET`).
    pub method: &'buf str,
    /// Request target up to (not including) the `?`; always starts with `/` after a successful
    /// parse.
    pub document: &'buf str,
    /// Request target after the `?`, or the empty string when there is no query.
    pub query_raw: &'buf str,
    /// Protocol name from the request line; `"HTTP"` after a successful parse.
    pub protocol: &'buf str,
    /// Protocol version from the request line; `"1.1"` after a successful parse.
    pub version: &'buf str,
    /// The header section split on CRLF, one unparsed line per entry.
    pub header_raw_lines: Vec<&'buf str>,

    // Cache of headers already parsed by `header()`, as (name, value) pairs.
    headers: Vec<(&'buf str, &'buf str)>,
    // Query-string parameters, filled in by the first successful `get()` call.
    get: Vec<(&'buf str, &'buf str)>,
    // Form-body parameters, filled in by the first successful `post()` call.
    post: Vec<(&'buf str, &'buf str)>,
    // Cookies parsed from the `Cookie` header, filled in by the first successful `cookie()` call.
    cookies: Vec<Cookie<'buf>>,

    // Cached results of the `host()`, `user_agent()`, `content_type()` and `content_length()`
    // accessors; `None` until a lookup succeeds.
    host: Option<&'buf str>,
    user_agent: Option<&'buf str>,
    content_type: Option<&'buf str>,
    content_length: Option<usize>,

    /// Copy of the peer address passed to [`Request::new`], if one was supplied.
    pub peer_addr: Option<SocketAddr>,
}
37
/// Reasons a request buffer can be rejected.
///
/// Variants that carry bytes borrow them from the request buffer, so the error cannot outlive
/// that buffer.
#[derive(Debug)]
pub enum RequestError<'buf> {
    /// The request line did not consist of exactly three space-separated items; carries the
    /// items that were found.
    RequestLineMalformed(Vec<&'buf [u8]>),

    /// The document part of the request target is not valid UTF-8.
    DocumentNotUtf8(Utf8Error),
    /// The document does not start with `/`; carries the offending bytes.
    DocumentMalformed(&'buf [u8]),

    /// The method is not valid UTF-8.
    MethodNotUtf8(Utf8Error),

    /// The query string is not valid UTF-8.
    QueryNotUtf8(Utf8Error),

    /// The protocol name is not valid UTF-8.
    ProtoNotUtf8(Utf8Error),
    /// The protocol item did not split into exactly `name/version`; carries that item.
    ProtoMalformed(&'buf [u8]),
    /// The protocol name is not `HTTP`.
    ProtoInvalid(&'buf [u8]),

    /// The protocol version is not valid UTF-8.
    ProtoVersionNotUtf8(Utf8Error),
    /// The protocol version is not `1.1`.
    ProtoVersionInvalid(&'buf [u8]),

    /// The header section is not valid UTF-8.
    HeadersNotUtf8(Utf8Error),

    /// The declared content length and the amount of body data disagree.
    ContentLengthDiscrepancy { expected: usize, got: usize },

    /// The POST body could not be parsed as parameters.
    PostParamsMalformed(&'buf [u8]),
}

/// The textual form of an error is simply its `Debug` representation.
impl Display for RequestError<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}

impl error::Error for RequestError<'_> {}
68
69#[allow(dead_code)]
70impl<'buf> Request<'buf> {
71    /// Construct a new request object using only a slice of u8
72    pub fn from_slice(buf: &'buf [u8]) -> Result<Self, RequestError<'buf>> {
73        Self::new(buf, None)
74    }
75
76    /// Create a default request object for a fail state
77    pub fn bad() -> Self {
78        Self::default()
79    }
80
81    /// Construct a new request object, parsing the request buffer
82    pub fn new(
83        buf: &'buf [u8],
84        peer_addr: Option<&SocketAddr>
85    ) -> Result<Self, RequestError<'buf>> {
86        let (mut request_head, request_body) = request_head_body_split(buf);
87
88        // ignore preceding clrf if they exist
89        loop {
90            request_head = match request_head.strip_prefix(b"\r\n") {
91                Some(head) => head,
92                None => break,
93            };
94        }
95
96        let body = request_body;
97
98        let (request_line, request_headers) = request_line_header_split(request_head);
99
100        let request_line_items: [&[u8]; 3] = request_line
101            .split(|c| *c == b' ')
102            .collect::<Vec<&[u8]>>()
103            .try_into()
104            .map_err(RequestError::RequestLineMalformed)?;
105
106        let method = str::from_utf8(&request_line_items[0])
107            .map_err(RequestError::MethodNotUtf8)?;
108
109        let (document_slice, query) = split_once(request_line_items[1], b'?');
110
111        let document = str::from_utf8(document_slice)
112            .map_err(RequestError::DocumentNotUtf8)?;
113
114        if !document.starts_with('/') {
115            debug_eprintln!("ERROR: {document} does not start with /");
116            return Err(RequestError::DocumentMalformed(document_slice));
117        }
118
119        let query = match query {
120            None => "",
121            Some(thing) => str::from_utf8(thing)
122                .map_err(RequestError::QueryNotUtf8)?
123        };
124        
125        let proto_version_items: [&[u8]; 2] = match request_line_items[2]
126            .split(|c| *c == b'/')
127            .collect::<Vec<&[u8]>>()
128            .try_into() {
129                Err(_) => {
130                    debug_eprintln!("ERROR: Invalid protocol string: {}", 
131                        str::from_utf8(request_line_items[2])
132                            .unwrap_or(&format!("{:?}", request_line_items[2])));
133                    return Err(RequestError::ProtoMalformed(request_line_items[2]));
134                },
135                Ok(items) => items,
136        };
137
138        let protocol = str::from_utf8(proto_version_items[0])
139            .map_err(RequestError::ProtoNotUtf8)?;
140
141        if protocol != "HTTP" {
142            debug_eprintln!("ERROR: Invalid protocol {protocol}");
143            return Err(RequestError::ProtoInvalid(request_line));
144        }
145
146        let version = str::from_utf8(proto_version_items[1])
147            .map_err(RequestError::ProtoVersionNotUtf8)?
148            .trim_end_matches(|c| ['\r', '\n', '\0'].contains(&c));
149
150        if version != "1.1" {
151            debug_eprintln!("ERROR: Invalid version {version}");
152            return Err(RequestError::ProtoVersionInvalid(request_line));
153        }
154
155        let header_raw_lines = str::from_utf8(request_headers.unwrap_or_default())
156            .map_err(RequestError::HeadersNotUtf8)?
157            .split(&"\r\n")
158            .collect::<Vec<_>>();
159
160        let headers_len = header_raw_lines.len();
161
162        // emit a complete Request object
163        Ok(Self {
164            body,
165            method,
166            document,
167            query_raw: query,
168            protocol,
169            version,
170            header_raw_lines,
171            headers: Vec::with_capacity(headers_len),
172            get: Vec::new(),
173            post: Vec::new(),
174            cookies: Vec::new(),
175            host: None,
176            user_agent: None,
177            content_type: None,
178            content_length: None,
179            peer_addr: peer_addr.copied(),
180        })
181    }
182    
183    pub fn host(&mut self) -> Option<&'buf str> {
184        if let Some(host) = self.host {
185            return Some(host);
186        } else {
187            if let Some(host) = self.header("Host") {
188                self.host = Some(host);
189                return Some(host);
190            } else {
191                return None;
192            }
193        }
194    }
195
196    pub fn user_agent(&mut self) -> Option<&'buf str> {
197        if let Some(ua) = self.user_agent {
198            return Some(ua);
199        } else {
200            if let Some(ua) = self.header("User-Agent") {
201                self.user_agent = Some(ua);
202                return Some(ua);
203            } else {
204                return None;
205            }
206        }
207    }
208
209    pub fn content_type(&mut self) -> Option<&'buf str> {
210        if let Some(ct) = self.content_type {
211            return Some(ct);
212        } else {
213            if let Some(ct) = self.header("Content-Type") {
214                self.content_type = Some(ct);
215                return Some(ct);
216            } else {
217                return None;
218            }
219        }
220    }
221
222    pub fn content_length(&mut self) -> Option<usize> {
223        if let Some(cl) = self.content_length {
224            return Some(cl);
225        } else {
226            if let Some(cl) = self.header("Content-Length") {
227                let cl = cl.parse::<usize>().ok();
228                self.content_length = cl;
229                return cl;
230            } else {
231                return None;
232            }
233        }
234    }
235
236    /// looks up HTTP headers and returns
237    /// headers are not parsed until they are needed
238    pub fn header(&mut self, key: &str) -> Option<&'buf str> {
239        if self.header_raw_lines.is_empty() {
240            return None;
241        }
242        if let Some((_k, v)) = self.headers.iter()
243                .find(|(k, _v)| *k == key) {
244            return Some(v);
245        } else {
246            if let Some(raw) = self.header_raw_lines.iter()
247                    .find(|line| line.find(": ").map(|idx| &line[..idx] == key).unwrap_or(false)) {
248                if let Some((key, value)) = parse_header(raw) {
249                    self.headers.push((key, value));
250                    return Some(value);
251                }
252            }
253        }
254        return None;
255    }
256
257    /// looks up cookies keys and returns its value
258    /// cookies are not parsed until they are needed, will parse headers too.
259    pub fn cookie(&mut self, key: &str) -> Option<&Cookie<'buf>> {
260        if self.header_raw_lines.is_empty() {
261            return None;
262        }
263        if self.cookies.is_empty() {
264            if let Some(cookies_raw) = self.header("Cookie") {
265                let cookies = parse_cookies(cookies_raw);
266                if cookies.is_empty() {
267                    return None;
268                }
269                self.cookies = cookies;
270            } else {
271                return None;
272            }
273        }
274        return self.cookies.iter()
275            .find(|c| c.name == key);
276    }
277
278    /// looks up get parameters and returns its value
279    /// will parse all parameters on the first call.
280    pub fn get(&mut self, key: &str) -> Option<&str> {
281        if self.query_raw.is_empty() {
282            return None;
283        }
284        if self.get.is_empty() {
285            if let Some(get) = parse_parameters(self.query_raw).ok() {
286                if get.is_empty() {
287                    return None;
288                }
289                self.get = get;
290            } else {
291                return None;
292            }
293        }
294        return self.get.iter()
295            .find(|(k, _v)| *k == key)
296            .map(|(_k, v)| *v);
297    }
298
299    /// looks up post parameters and returns its value
300    /// will parse the content_type and content_len header on the first call.
301    pub fn post(&mut self, key: &str) -> Option<&str> {
302        // method must be POST
303        if self.method != "POST" {
304            return None;
305        }
306        // body must exist
307        if self.body.is_none() {
308            return None;
309        }
310        // if post is empty, go about and parse the POST values from the request body.
311        if self.post.is_empty() {
312            // must have a content length
313            if let Some(content_len) = self.content_length() {
314                // and it must be nonzero
315                if content_len == 0 {
316                    return None;
317                }
318                // and there must be a content type
319                if let Some(content_type) = self.content_type() {
320                    // and the content type must be application/x-www-form-urlencoded
321                    if content_type != "application/x-www-form-urlencoded" {
322                        return None;
323                    }
324                    // and there must be a body
325                    if let Some(body) = self.body {
326                        // and the body, up to the content length, must be UTF-8
327                        if let Some(body) = str::from_utf8(body.get(0..content_len)?).ok() {
328                            // and the body is to be treated the same as GET query parameters
329                            match parse_parameters(body) {
330                                Ok(params) if params.is_empty() => {
331                                    return None;
332                                }
333                                Ok(params) => {
334                                    // and then, search the parsed POST values for the `key`
335                                    self.post = params;
336                                    // and return it if it exists.
337                                    return self.post.iter()
338                                        .find(|(k, _v)| *k == key)
339                                        .map(|(_k, v)| *v);
340                                    },
341                                Err(err) => {
342                                    debug_println!("ERROR: Invalid post parameters: {body}: {}", err);
343                                },
344                            }
345                        }
346                    }
347                }
348            }
349        } else {
350            // otherwise, look up the requested POST value.
351            return self.post.iter()
352                .find(|(k, _v)| *k == key)
353                .map(|(_k, v)| *v);
354        }
355        return None;
356    }
357}
358
/// Splits a request head at its first CRLF into the request line and the header section.
///
/// Returns `(line, Some(headers))`, with the CRLF itself removed, when a CRLF is found at an
/// index greater than zero. When there is no CRLF, or the first CRLF is at index 0, the input
/// is returned as the line (minus one trailing CRLF, if it ends with one) together with `None`.
fn request_line_header_split(to_split: &[u8]) -> (&[u8], Option<&[u8]>) {
    const CRLF: &[u8] = b"\r\n";

    let first_crlf = to_split
        .windows(CRLF.len())
        .position(|pair| pair == CRLF);

    match first_crlf {
        // A CRLF somewhere after the start: everything before it is the request line.
        Some(at) if at != 0 => (&to_split[..at], Some(&to_split[at + CRLF.len()..])),
        // No usable CRLF: hand back the whole input, trimmed of a trailing CRLF if present.
        _ => (to_split.strip_suffix(CRLF).unwrap_or(to_split), None),
    }
}
395
/// Splits a request buffer at its first double CRLF (`\r\n\r\n`) into head and body.
///
/// Returns `(head, Some(body))`, with the separator itself removed, when a double CRLF is
/// found at an index greater than zero. When there is no double CRLF, or the first one is at
/// index 0 (i.e. there would be no head), the whole input is returned together with `None`.
///
/// Only four *contiguous* bytes `\r\n\r\n` count as the separator; two CRLFs with stray `\r`
/// bytes between them (e.g. `\r\n\r\r\n`) do not.
fn request_head_body_split(to_split: &[u8]) -> (&[u8], Option<&[u8]>) {
    const SEPARATOR: &[u8] = b"\r\n\r\n";

    let first_separator = to_split
        .windows(SEPARATOR.len())
        .position(|window| window == SEPARATOR);

    match first_separator {
        Some(at) if at != 0 => (
            &to_split[..at],
            Some(&to_split[at + SEPARATOR.len()..]),
        ),
        // No separator, or a separator with nothing in front of it: nothing to split.
        _ => (to_split, None),
    }
}
441