li_async_h1/server/
decode.rs1use std::str::FromStr;
4
5use async_dup::{Arc, Mutex};
6use async_std::io::{BufReader, Read, Write};
7use async_std::{prelude::*, task};
8use li_http_types::content::ContentLength;
9use li_http_types::headers::{EXPECT, TRANSFER_ENCODING};
10use li_http_types::{ensure, ensure_eq, format_err};
11use li_http_types::{Body, Method, Request, Url};
12
13use super::body_reader::BodyReader;
14use crate::chunked::ChunkedDecoder;
15use crate::read_notifier::ReadNotifier;
16use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
17
/// Byte that terminates each line of the request head; `read_until` splits on it.
const LF: u8 = 0x0A;

/// Minor version httparse reports for `HTTP/1.1` — the only version `decode` accepts.
const HTTP_1_1_VERSION: u8 = 1;

/// `Expect` header value asking the server for an interim `100 Continue` response.
const CONTINUE_HEADER_VALUE: &str = "100-continue";
/// Interim response sent to clients that sent `Expect: 100-continue`.
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
25
26pub async fn decode<IO>(mut io: IO) -> li_http_types::Result<Option<(Request, BodyReader<IO>)>>
28where
29 IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
30{
31 let mut reader = BufReader::new(io.clone());
32 let mut buf = Vec::new();
33 let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
34 let mut httparse_req = httparse::Request::new(&mut headers);
35
36 loop {
38 let bytes_read = reader.read_until(LF, &mut buf).await?;
39 if bytes_read == 0 {
41 return Ok(None);
42 }
43
44 ensure!(
46 buf.len() < MAX_HEAD_LENGTH,
47 "Head byte length should be less than 8kb"
48 );
49
50 let idx = buf.len() - 1;
52 if idx >= 3 && &buf[idx - 3..=idx] == b"\r\n\r\n" {
53 break;
54 }
55 }
56
57 let status = httparse_req.parse(&buf)?;
59
60 ensure!(!status.is_partial(), "Malformed HTTP head");
61
62 let method = httparse_req.method;
64 let method = method.ok_or_else(|| format_err!("No method found"))?;
65
66 let version = httparse_req.version;
67 let version = version.ok_or_else(|| format_err!("No version found"))?;
68
69 ensure_eq!(
70 version,
71 HTTP_1_1_VERSION,
72 "Unsupported HTTP version 1.{}",
73 version
74 );
75
76 let url = url_from_httparse_req(&httparse_req)?;
77
78 let mut req = Request::new(Method::from_str(method)?, url);
79
80 req.set_version(Some(li_http_types::Version::Http1_1));
81
82 for header in httparse_req.headers.iter() {
83 req.append_header(header.name, std::str::from_utf8(header.value)?);
84 }
85
86 let content_length = ContentLength::from_headers(&req)?;
87 let transfer_encoding = req.header(TRANSFER_ENCODING);
88
89 li_http_types::ensure_status!(
94 content_length.is_none() || transfer_encoding.is_none(),
95 400,
96 "Unexpected Content-Length header"
97 );
98
99 let (body_read_sender, body_read_receiver) = async_channel::bounded(1);
104
105 if Some(CONTINUE_HEADER_VALUE) == req.header(EXPECT).map(|h| h.as_str()) {
106 task::spawn(async move {
107 if let Ok(()) = body_read_receiver.recv().await {
110 io.write_all(CONTINUE_RESPONSE).await.ok();
111 };
112 });
116 }
117
118 if transfer_encoding
120 .map(|te| te.as_str().eq_ignore_ascii_case("chunked"))
121 .unwrap_or(false)
122 {
123 let trailer_sender = req.send_trailers();
124 let reader = ChunkedDecoder::new(reader, trailer_sender);
125 let reader = Arc::new(Mutex::new(reader));
126 let reader_clone = reader.clone();
127 let reader = ReadNotifier::new(reader, body_read_sender);
128 let reader = BufReader::new(reader);
129 req.set_body(Body::from_reader(reader, None));
130 Ok(Some((req, BodyReader::Chunked(reader_clone))))
131 } else if let Some(len) = content_length {
132 let len = len.len();
133 let reader = Arc::new(Mutex::new(reader.take(len)));
134 req.set_body(Body::from_reader(
135 BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)),
136 Some(len as usize),
137 ));
138 Ok(Some((req, BodyReader::Fixed(reader))))
139 } else {
140 Ok(Some((req, BodyReader::None)))
141 }
142}
143
144fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> li_http_types::Result<Url> {
145 let path = req.path.ok_or_else(|| format_err!("No uri found"))?;
146
147 let host = req
148 .headers
149 .iter()
150 .find(|x| x.name.eq_ignore_ascii_case("host"))
151 .ok_or_else(|| format_err!("Mandatory Host header missing"))?
152 .value;
153
154 let host = std::str::from_utf8(host)?;
155
156 if path.starts_with("http://") || path.starts_with("https://") {
157 Ok(Url::parse(path)?)
158 } else if path.starts_with('/') {
159 Ok(Url::parse(&format!("http://{}{}", host, path))?)
160 } else if req.method.unwrap().eq_ignore_ascii_case("connect") {
161 Ok(Url::parse(&format!("http://{}/", path))?)
162 } else {
163 Err(format_err!("unexpected uri format"))
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `head` with httparse and pass the parsed request to `check`.
    ///
    /// The parse result is unwrapped, so only a hard parse error fails here;
    /// heads without the final blank line are accepted as partial parses.
    fn with_parsed_head(head: &str, check: impl Fn(httparse::Request<'_, '_>)) {
        let mut header_slots = [httparse::EMPTY_HEADER; MAX_HEADERS];
        let mut parsed = httparse::Request::new(&mut header_slots[..]);
        parsed.parse(head.as_bytes()).unwrap();
        check(parsed)
    }

    /// Assert that `head` resolves to a URL whose string form is `expected`.
    fn assert_url(head: &str, expected: &str) {
        with_parsed_head(head, |req| {
            let resolved = url_from_httparse_req(&req).unwrap();
            assert_eq!(resolved.as_str(), expected);
        });
    }

    #[test]
    fn url_for_connect() {
        assert_url(
            "CONNECT server.example.com:443 HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443/",
        );
    }

    #[test]
    fn url_for_host_plus_path() {
        assert_url(
            "GET /some/resource HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443/some/resource",
        );
    }

    #[test]
    fn url_for_host_plus_absolute_url() {
        assert_url(
            "GET http://domain.com/some/resource HTTP/1.1\r\nHost: server.example.com\r\n",
            "http://domain.com/some/resource",
        );
    }

    #[test]
    fn url_for_conflicting_connect() {
        assert_url(
            "CONNECT server.example.com:443 HTTP/1.1\r\nHost: conflicting.host\r\n",
            "http://server.example.com:443/",
        );
    }

    #[test]
    fn url_for_malformed_resource_path() {
        with_parsed_head(
            "GET not-a-url HTTP/1.1\r\nHost: server.example.com\r\n",
            |req| assert!(url_from_httparse_req(&req).is_err()),
        );
    }

    #[test]
    fn url_for_double_slash_path() {
        assert_url(
            "GET //double/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443//double/slashes",
        );
    }

    #[test]
    fn url_for_triple_slash_path() {
        assert_url(
            "GET ///triple/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443///triple/slashes",
        );
    }

    #[test]
    fn url_for_query() {
        assert_url(
            "GET /foo?bar=1 HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443/foo?bar=1",
        );
    }

    #[test]
    fn url_for_anchor() {
        assert_url(
            "GET /foo?bar=1#anchor HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443/foo?bar=1#anchor",
        );
    }
}