1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use super::{headers::parse_header_map, version::version_from_bytes};
use crate::{
    ascii::str_from_ascii, chunks::ChunksSlice, error::Error, windows::IteratorExt,
    ParseRequestError, Parsed,
};
use safe_http::{Method, RequestHead, RequestLine, Version};
use safe_uri::{Scheme, Uri, UriRef};
use shared_bytes::SharedBytes;

/// Parses an HTTP request head: the request line followed by the header section.
///
/// `chunks` is treated as one logical byte sequence. On success the parsed
/// [`RequestHead`] is returned together with every byte that follows the
/// header section; those bytes are produced via `to_continuous_shared` and
/// stored in `remainder`.
///
/// # Errors
///
/// Any error from the request-line or header parsers is wrapped in
/// [`ParseRequestError`].
pub fn parse_request_head(
    chunks: &[SharedBytes],
) -> Result<Parsed<RequestHead>, ParseRequestError> {
    let input = ChunksSlice::new(chunks);
    let (line, after_line) = parse_request_line(input).map_err(ParseRequestError)?;
    let (headers, after_headers) = parse_header_map(after_line).map_err(ParseRequestError)?;
    let head = RequestHead {
        line,
        headers,
        ..Default::default()
    };
    Ok(Parsed {
        value: head,
        remainder: after_headers.to_continuous_shared(),
    })
}

/// Parses the request line: `method SP request-target SP version CRLF`.
///
/// Returns the parsed [`RequestLine`] and the slice that begins right after
/// the line's terminating CRLF.
fn parse_request_line(slice: ChunksSlice) -> Result<(RequestLine, ChunksSlice), Error> {
    let (method, after_method) = parse_method(slice)?;
    let (uri, after_target) = parse_request_target(after_method)?;
    let (version, after_line) = parse_version(after_target)?;
    Ok((
        RequestLine {
            method,
            uri,
            version,
            ..Default::default()
        },
        after_line,
    ))
}

/// Splits the request method off the front of `slice`.
///
/// The method is the text before the first `' '` on the request line.
/// Returns the parsed [`Method`] and the slice starting right after that
/// space (i.e. at the request target).
///
/// # Errors
///
/// - `Error::Message` if no `' '` occurs before the first `'\r'` or `'\n'`,
///   or if `next_chunks_index` yields no position after the space.
/// - Whatever `method_from_slice` returns if the method text is rejected.
fn parse_method(slice: ChunksSlice) -> Result<(Method, ChunksSlice), Error> {
    // Only search the request line itself. Without this bound, a request line
    // lacking a space (e.g. `GET\r\nHost: x`) would match a space inside the
    // header section and feed CR/LF bytes into the method parser, and a
    // buffer with no space at all would be scanned to its very end.
    let method_end = slice
        .bytes_indexed()
        .take_while(|&(_, byte)| byte != b'\r' && byte != b'\n')
        .find_map(|(index, byte)| (byte == b' ').then(|| index))
        .ok_or(Error::Message("missing method separator ' '"))?;
    let method = method_from_slice(slice.index(..method_end))?;
    let remainder_start = slice.next_chunks_index(method_end).ok_or(Error::Message(
        "no request target (URI) after separator ' '",
    ))?;
    let remainder = slice.index(remainder_start..);
    Ok((method, remainder))
}

fn method_from_slice(slice: ChunksSlice) -> Result<Method, Error> {
    let string = str_from_ascii(slice)?;
    string.try_into().map_err(Error::Method)
}

/// Splits the request target (URI) off the front of `slice`.
///
/// The target is the text before the next `' '` on the request line and is
/// parsed by `uri_from_slice`. Returns the [`Uri`] and the slice starting
/// right after that space (i.e. at the HTTP version).
///
/// # Errors
///
/// - `Error::Message` if no `' '` occurs before the first `'\r'` or `'\n'`,
///   or if `next_chunks_index` yields no position after the space.
/// - Whatever `uri_from_slice` returns if the target is rejected.
fn parse_request_target(slice: ChunksSlice) -> Result<(Uri, ChunksSlice), Error> {
    // Bound the search to the request line. A line with no version
    // (e.g. `GET /path\r\nHost: x`) must not pick up a space from the header
    // section and hand a span containing CR/LF to the URI parser.
    let uri_end = slice
        .bytes_indexed()
        .take_while(|&(_, byte)| byte != b'\r' && byte != b'\n')
        .find_map(|(index, byte)| (byte == b' ').then(|| index))
        .ok_or(Error::Message("missing uri separator ' '"))?;
    let uri = uri_from_slice(slice.index(..uri_end))?;
    let remainder_start = slice
        .next_chunks_index(uri_end)
        .ok_or(Error::Message("no HTTP version after separator ' '"))?;
    let remainder = slice.index(remainder_start..);
    Ok((uri, remainder))
}

fn uri_from_slice(slice: ChunksSlice) -> Result<Uri, Error> {
    let string = str_from_ascii(slice)?;
    let uri_ref = UriRef::parse(string).map_err(Error::Uri)?;
    Ok(uri_ref.into_uri_with_default_scheme(|| Scheme::HTTPS))
}

/// Parses the HTTP version that terminates the request line.
///
/// The version is everything before the first CRLF pair. Returns the parsed
/// [`Version`] and the slice starting right after that CRLF (i.e. at the
/// header section).
///
/// # Errors
///
/// - `Error::Message` if no CRLF is present, or if `next_chunks_index`
///   yields no position after the LF.
/// - Whatever `version_from_bytes` returns for the version text.
fn parse_version(slice: ChunksSlice) -> Result<(Version, ChunksSlice), Error> {
    // Scan adjacent byte pairs; the CR index ends the version text and the
    // LF index marks the last byte of the request line.
    let (version_end, line_end) = slice
        .bytes_indexed()
        .windows::<2>()
        .find(|[(_, cr), (_, lf)]| *cr == b'\r' && *lf == b'\n')
        .map(|[(cr_index, _), (lf_index, _)]| (cr_index, lf_index))
        .ok_or(Error::Message("missing version separator (CRLF)"))?;
    let version = version_from_bytes(&slice.index(..version_end).to_continuous_cow())?;
    let headers_start = slice
        .next_chunks_index(line_end)
        .ok_or(Error::Message("no bytes following the request line"))?;
    Ok((version, slice.index(headers_start..)))
}

#[cfg(test)]
mod tests {
    use super::*;
    use safe_http::{HeaderMap, HeaderName, HeaderValue, Version};
    use safe_uri::Path;
    use tap::Tap;

    /// A well-formed head followed by a body: the request line and headers
    /// are parsed, and the body bytes are left over as the remainder.
    #[test]
    fn good_case() {
        let request = "PUT /example HTTP/2.0\r\n\
            content-type: text/plain\r\n\
            content-length: 5\r\n\r\n\
            hello";
        let input = SharedBytes::from(request);
        let parsed = parse_request_head(&[input]).unwrap();

        // A scheme-less origin-form target is expected to default to HTTPS.
        let expected_uri = Uri::new().tap_mut(|uri| {
            uri.scheme = Scheme::HTTPS;
            uri.resource.path = Path::from_static("/example");
        });
        let expected_line = RequestLine {
            method: Method::PUT,
            uri: expected_uri,
            version: Version::HTTP_2_0,
            ..Default::default()
        };
        let expected_headers = HeaderMap::from([
            (
                HeaderName::CONTENT_TYPE,
                HeaderValue::from_static(b"text/plain"),
            ),
            (HeaderName::CONTENT_LENGTH, HeaderValue::from_static(b"5")),
        ]);
        let expected_head = RequestHead {
            line: expected_line,
            headers: expected_headers,
            ..Default::default()
        };

        assert_eq!(parsed.value, expected_head);
        assert_eq!(parsed.remainder, "hello");
    }
}