async_h1/server/
decode.rs1use std::str::FromStr;
4
5use async_dup::{Arc, Mutex};
6use futures_lite::io::{AsyncRead as Read, AsyncWrite as Write, BufReader};
7use futures_lite::prelude::*;
8use http_types::content::ContentLength;
9use http_types::headers::{EXPECT, TRANSFER_ENCODING};
10use http_types::{ensure, ensure_eq, format_err};
11use http_types::{Body, Method, Request, Url};
12
13use super::body_reader::BodyReader;
14use crate::chunked::ChunkedDecoder;
15use crate::read_notifier::ReadNotifier;
16use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
17
/// Line feed; the request head is read one LF-terminated line at a time.
const LF: u8 = b'\n';

/// Minor version number (`1.x`) reported by `httparse`; only HTTP/1.1 is accepted.
const HTTP_1_1_VERSION: u8 = 1;

/// `Expect` header value by which a client asks for an interim `100 Continue`.
const CONTINUE_HEADER_VALUE: &str = "100-continue";

/// Interim response written back to a client that sent `Expect: 100-continue`.
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
25
26pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<(Request, BodyReader<IO>)>>
28where
29 IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
30{
31 let mut reader = BufReader::new(io.clone());
32 let mut buf = Vec::new();
33 let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
34 let mut httparse_req = httparse::Request::new(&mut headers);
35
36 loop {
38 let bytes_read = reader.read_until(LF, &mut buf).await?;
39 if bytes_read == 0 {
41 return Ok(None);
42 }
43
44 ensure!(
46 buf.len() < MAX_HEAD_LENGTH,
47 "Head byte length should be less than 8kb"
48 );
49
50 let idx = buf.len() - 1;
52 if idx >= 3 && &buf[idx - 3..=idx] == b"\r\n\r\n" {
53 break;
54 }
55 }
56
57 let status = httparse_req.parse(&buf)?;
59
60 ensure!(!status.is_partial(), "Malformed HTTP head");
61
62 let method = httparse_req.method;
64 let method = method.ok_or_else(|| format_err!("No method found"))?;
65
66 let version = httparse_req.version;
67 let version = version.ok_or_else(|| format_err!("No version found"))?;
68
69 ensure_eq!(
70 version,
71 HTTP_1_1_VERSION,
72 "Unsupported HTTP version 1.{}",
73 version
74 );
75
76 let url = url_from_httparse_req(&httparse_req)?;
77
78 let mut req = Request::new(Method::from_str(method)?, url);
79
80 req.set_version(Some(http_types::Version::Http1_1));
81
82 for header in httparse_req.headers.iter() {
83 req.append_header(header.name, std::str::from_utf8(header.value)?);
84 }
85
86 let content_length = ContentLength::from_headers(&req)?;
87 let transfer_encoding = req.header(TRANSFER_ENCODING);
88
89 http_types::ensure_status!(
94 content_length.is_none() || transfer_encoding.is_none(),
95 400,
96 "Unexpected Content-Length header"
97 );
98
99 let (body_read_sender, body_read_receiver) = async_channel::bounded(1);
104
105 if Some(CONTINUE_HEADER_VALUE) == req.header(EXPECT).map(|h| h.as_str()) {
106 async_global_executor::spawn(async move {
107 if let Ok(()) = body_read_receiver.recv().await {
110 io.write_all(CONTINUE_RESPONSE).await.ok();
111 };
112 })
116 .detach();
117 }
118
119 if transfer_encoding
121 .map(|te| te.as_str().eq_ignore_ascii_case("chunked"))
122 .unwrap_or(false)
123 {
124 let trailer_sender = req.send_trailers();
125 let reader = ChunkedDecoder::new(reader, trailer_sender);
126 let reader = Arc::new(Mutex::new(reader));
127 let reader_clone = reader.clone();
128 let reader = ReadNotifier::new(reader, body_read_sender);
129 let reader = BufReader::new(reader);
130 req.set_body(Body::from_reader(reader, None));
131 Ok(Some((req, BodyReader::Chunked(reader_clone))))
132 } else if let Some(len) = content_length {
133 let len = len.len();
134 let reader = Arc::new(Mutex::new(reader.take(len)));
135 req.set_body(Body::from_reader(
136 BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)),
137 Some(len as usize),
138 ));
139 Ok(Some((req, BodyReader::Fixed(reader))))
140 } else {
141 Ok(Some((req, BodyReader::None)))
142 }
143}
144
145fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<Url> {
146 let path = req.path.ok_or_else(|| format_err!("No uri found"))?;
147
148 let host = req
149 .headers
150 .iter()
151 .find(|x| x.name.eq_ignore_ascii_case("host"))
152 .ok_or_else(|| format_err!("Mandatory Host header missing"))?
153 .value;
154
155 let host = std::str::from_utf8(host)?;
156
157 if path.starts_with("http://") || path.starts_with("https://") {
158 Ok(Url::parse(path)?)
159 } else if path.starts_with('/') {
160 Ok(Url::parse(&format!("http://{}{}", host, path))?)
161 } else if req.method.unwrap().eq_ignore_ascii_case("connect") {
162 Ok(Url::parse(&format!("http://{}/", path))?)
163 } else {
164 Err(format_err!("unexpected uri format"))
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `buf` as a request head and pass the parsed request to `f`.
    ///
    /// Panics unless `buf` is a complete, well-formed head, so a test can
    /// never silently run against a partially parsed request.
    fn httparse_req(buf: &str, f: impl Fn(httparse::Request<'_, '_>)) {
        let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
        let mut res = httparse::Request::new(&mut headers[..]);
        let status = res.parse(buf.as_bytes()).unwrap();
        assert!(
            status.is_complete(),
            "test input must be a complete request head"
        );
        f(res)
    }

    #[test]
    fn url_for_connect() {
        httparse_req(
            "CONNECT server.example.com:443 HTTP/1.1\r\nHost: server.example.com:443\r\n\r\n",
            |req| {
                let url = url_from_httparse_req(&req).unwrap();
                assert_eq!(url.as_str(), "http://server.example.com:443/");
            },
        );
    }

    #[test]
    fn url_for_host_plus_path() {
        httparse_req(
            "GET /some/resource HTTP/1.1\r\nHost: server.example.com:443\r\n\r\n",
            |req| {
                let url = url_from_httparse_req(&req).unwrap();
                assert_eq!(url.as_str(), "http://server.example.com:443/some/resource");
            },
        );
    }

    #[test]
    fn url_for_host_plus_absolute_url() {
        httparse_req(
            "GET http://domain.com/some/resource HTTP/1.1\r\nHost: server.example.com\r\n\r\n",
            |req| {
                let url = url_from_httparse_req(&req).unwrap();
                assert_eq!(url.as_str(), "http://domain.com/some/resource");
            },
        );
    }

    #[test]
    fn url_for_conflicting_connect() {
        httparse_req(
            "CONNECT server.example.com:443 HTTP/1.1\r\nHost: conflicting.host\r\n\r\n",
            |req| {
                let url = url_from_httparse_req(&req).unwrap();
                assert_eq!(url.as_str(), "http://server.example.com:443/");
            },
        );
    }

    #[test]
    fn url_for_malformed_resource_path() {
        httparse_req(
            "GET not-a-url HTTP/1.1\r\nHost: server.example.com\r\n\r\n",
            |req| {
                assert!(url_from_httparse_req(&req).is_err());
            },
        );
    }

    #[test]
    fn url_for_double_slash_path() {
        httparse_req(
            "GET //double/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n\r\n",
            |req| {
                let url = url_from_httparse_req(&req).unwrap();
                assert_eq!(
                    url.as_str(),
                    "http://server.example.com:443//double/slashes"
                );
            },
        );
    }

    #[test]
    fn url_for_triple_slash_path() {
        httparse_req(
            "GET ///triple/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n\r\n",
            |req| {
                let url = url_from_httparse_req(&req).unwrap();
                assert_eq!(
                    url.as_str(),
                    "http://server.example.com:443///triple/slashes"
                );
            },
        );
    }

    #[test]
    fn url_for_query() {
        httparse_req(
            "GET /foo?bar=1 HTTP/1.1\r\nHost: server.example.com:443\r\n\r\n",
            |req| {
                let url = url_from_httparse_req(&req).unwrap();
                assert_eq!(url.as_str(), "http://server.example.com:443/foo?bar=1");
            },
        );
    }

    #[test]
    fn url_for_anchor() {
        httparse_req(
            "GET /foo?bar=1#anchor HTTP/1.1\r\nHost: server.example.com:443\r\n\r\n",
            |req| {
                let url = url_from_httparse_req(&req).unwrap();
                assert_eq!(
                    url.as_str(),
                    "http://server.example.com:443/foo?bar=1#anchor"
                );
            },
        );
    }
}