http1_spec/
request_head_parser.rs1use std::io::{BufRead, Read as _};
2
3use http::{request::Parts as RequestParts, HeaderMap, HeaderValue, Method, Request, Uri, Version};
4
5use crate::head_parser::{HeadParseConfig, HeadParseError, HeadParseOutput, HeadParser};
6
7#[derive(Default)]
11pub struct RequestHeadParser {
12 pub method: Method,
13 pub uri: Uri,
14 pub http_version: Version,
15 pub headers: HeaderMap<HeaderValue>,
16 config: HeadParseConfig,
18 state: State,
20 buf: Vec<u8>,
21}
22
/// Progress of the parser through one request head.
///
/// Variants are declared in parse order, and the derived `PartialOrd`/`Ord`
/// compare by declaration order. `parse` relies on this: `state < State::X`
/// means stage `X` has not been finished yet, so the order must not change.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
enum State {
    /// Nothing parsed yet; also the state after a head has been completed.
    Idle,
    /// Request method parsed.
    MethodParsed,
    /// Request target parsed.
    UriParsed,
    /// HTTP version parsed; header fields come next.
    HttpVersionParsed,
    /// Header lines are being consumed; the head is not complete yet.
    HeadersParsing,
}

impl Default for State {
    /// A fresh parser starts at [`State::Idle`].
    fn default() -> State {
        State::Idle
    }
}
36
37impl RequestHeadParser {
38 pub fn to_request_parts(&self) -> RequestParts {
39 let (mut parts, _) = Request::new(()).into_parts();
40 parts.method = self.method.to_owned();
41 parts.uri = self.uri.to_owned();
42 parts.version = self.http_version;
43 parts.headers = self.headers.to_owned();
44 parts
45 }
46
47 pub fn to_request<B>(&self, body: B) -> Request<B> {
48 let parts = self.to_request_parts();
49 Request::from_parts(parts, body)
50 }
51}
52
53impl HeadParser for RequestHeadParser {
57 fn new() -> Self {
58 Self::default()
59 }
60 fn with_config(config: HeadParseConfig) -> Self {
61 let buf = Vec::with_capacity(config.buf_capacity());
62 let headers = HeaderMap::with_capacity(config.header_map_capacity());
63 RequestHeadParser {
64 config,
65 buf,
66 headers,
67 ..Default::default()
68 }
69 }
70
71 fn get_headers(&self) -> &HeaderMap<HeaderValue> {
72 &self.headers
73 }
74 fn get_version(&self) -> &Version {
75 &self.http_version
76 }
77
78 fn parse<R: BufRead>(&mut self, r: &mut R) -> Result<HeadParseOutput, HeadParseError> {
79 let mut take = r.take(0);
80 let mut parsed_num_bytes = 0_usize;
81
82 if self.state < State::MethodParsed {
83 self.buf.clear();
85 match Self::parse_method(&mut take, &mut self.buf, &self.config)? {
86 Some((method, n)) => {
87 self.state = State::MethodParsed;
88
89 self.method = method;
90 parsed_num_bytes += n;
91 }
92 None => return Ok(HeadParseOutput::Partial(parsed_num_bytes)),
93 }
94 }
95
96 if self.state < State::UriParsed {
97 self.buf.clear();
99 match Self::parse_uri(&mut take, &mut self.buf, &self.config)? {
100 Some((uri, n)) => {
101 self.state = State::UriParsed;
102
103 self.uri = uri;
104 parsed_num_bytes += n;
105 }
106 None => return Ok(HeadParseOutput::Partial(parsed_num_bytes)),
107 }
108 }
109
110 if self.state < State::HttpVersionParsed {
111 self.buf.clear();
113 match Self::parse_http_version_for_request(&mut take, &mut self.buf)? {
114 Some((http_version, n)) => {
115 self.state = State::HttpVersionParsed;
116
117 self.http_version = http_version;
118 parsed_num_bytes += n;
119 }
120 None => return Ok(HeadParseOutput::Partial(parsed_num_bytes)),
121 }
122 }
123
124 if self.state < State::HeadersParsing {
126 self.headers.clear();
127 }
128 loop {
129 if self.state <= State::HeadersParsing {
130 self.buf.clear();
131 match Self::parse_header(&mut take, &mut self.buf, &self.config, &mut self.headers)?
132 {
133 Some((is_all_completed, n)) => {
134 parsed_num_bytes += n;
135
136 if is_all_completed {
137 self.state = State::Idle;
138
139 return Ok(HeadParseOutput::Completed(parsed_num_bytes));
140 } else {
141 self.state = State::HeadersParsing;
142
143 continue;
144 }
145 }
146 None => return Ok(HeadParseOutput::Partial(parsed_num_bytes)),
147 }
148 } else {
149 unreachable!()
150 }
151 }
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// `to_request` must carry every parsed head field over, together with
    /// the supplied body.
    #[test]
    fn test_to_request() {
        let mut headers = HeaderMap::new();
        headers.insert("x-foo", HeaderValue::from_static("bar"));

        let parser = RequestHeadParser {
            method: Method::POST,
            uri: Uri::from_static("/path"),
            http_version: Version::HTTP_2,
            headers,
            ..Default::default()
        };

        let request = parser.to_request("body");
        assert_eq!(request.method(), Method::POST);
        assert_eq!(request.uri(), &Uri::from_static("/path"));
        assert_eq!(request.version(), Version::HTTP_2);
        assert_eq!(request.headers().get("x-foo").unwrap(), "bar");
        assert_eq!(request.body(), &"body");
    }
}