1use async_dup::{Arc, Mutex};
4use futures_lite::io::{AsyncRead as Read, AsyncWrite as Write, BufReader};
5use futures_lite::prelude::*;
6use http_types::content::ContentLength;
7use http_types::headers::{EXPECT, TRANSFER_ENCODING};
8use http_types::{Body, Request, Url, Version};
10
11use crate::{Error, Result}; use super::body_reader::BodyReader;
14use crate::read_notifier::ReadNotifier;
15use crate::{chunked::ChunkedDecoder, ServerOptions};
16use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
17
/// Line feed: the request head is read one `LF`-terminated line at a time.
const LF: u8 = b'\n';
/// Separator between the tokens (method, request-target, version) of the request line.
const SPACE: u8 = b' ';

/// `Expect` header value that makes `decode` schedule an interim `100 Continue` response.
const CONTINUE_HEADER_VALUE: &str = "100-continue";
/// Interim response that `decode` writes to the connection for a `100-continue` request.
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
23
24pub async fn decode<IO>(
27 mut io: IO,
28 opts: &ServerOptions,
29) -> Result<Option<(Request, BodyReader<IO>)>>
30where
31 IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
32{
33 let mut reader = BufReader::with_capacity(MAX_HEAD_LENGTH, io.clone()); let mut buf = Vec::new();
35 let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
36 let mut httparse_req = httparse::Request::new(&mut headers);
37
38 let mut first_line = true;
39 loop {
41 let bytes_read = reader.read_until(LF, &mut buf).await?;
42
43 if bytes_read == 0 {
45 return Ok(None);
46 }
47
48 if buf.len() >= MAX_HEAD_LENGTH {
50 return Err(Error::HeadersTooLong);
51 }
52
53 if first_line {
54 first_line = false;
55
56 let mut split = buf.split(|b| { b == &SPACE });
57 let method = split.next().ok_or(Error::MissingMethod)?;
58
59 let path = split.next().ok_or(Error::RequestPathMissing)?;
60 let path = non_ascii_printable_to_percent_encoded(path);
61
62 let mut parts = vec![method, &path];
63 for part in split {
64 parts.push(part);
65 }
66
67 buf = parts.join(&SPACE);
68 }
69
70 if buf.ends_with(b"\r\n\r\n") || buf.ends_with(b"\n\n") {
72 break;
73 }
74 }
75
76 let status = httparse_req.parse(&buf)?;
78
79 if status.is_partial() {
80 return Err(Error::PartialHead);
81 }
82
83 let method = httparse_req
85 .method
86 .ok_or(Error::MissingMethod)?
87 .parse()
88 .map_err(|_| Error::UnrecognizedMethod(httparse_req.method.unwrap().to_string()))?;
89
90 let version = match (&opts.default_host, httparse_req.version) {
91 (Some(_), None) | (Some(_), Some(0)) => Version::Http1_0,
92 (_, Some(1)) => Version::Http1_1,
93 (None, Some(0)) | (None, None) => return Err(Error::HostHeaderMissing),
94 (_, Some(other_version)) => return Err(Error::UnsupportedVersion(other_version)),
95 };
96
97 let url = url_from_httparse_req(&httparse_req, opts.default_host.as_deref())?;
98
99 let mut req = Request::new(method, url);
100
101 req.set_version(Some(version));
102
103 for header in httparse_req.headers.iter() {
104 req.append_header(header.name, std::str::from_utf8(header.value)?);
105 }
106
107 let content_length =
108 ContentLength::from_headers(&req).map_err(|_| Error::MalformedHeader("content-length"))?;
109 let transfer_encoding = req.header(TRANSFER_ENCODING);
110
111 if content_length.is_some() && transfer_encoding.is_some() {
112 return Err(Error::UnexpectedHeader("content-length"));
113 }
114
115 let (body_read_sender, body_read_receiver) = async_channel::bounded(1);
120
121 if Some(CONTINUE_HEADER_VALUE) == req.header(EXPECT).map(|h| h.as_str()) {
122 smolscale2::spawn(async move {
123 if let Ok(()) = body_read_receiver.recv().await {
126 io.write_all(CONTINUE_RESPONSE).await.ok();
127 };
128 })
132 .detach();
133 }
134
135 if transfer_encoding
137 .map(|te| te.as_str().eq_ignore_ascii_case("chunked"))
138 .unwrap_or(false)
139 {
140 let trailer_sender = req.send_trailers();
141 let reader = ChunkedDecoder::new(reader, trailer_sender);
142 let reader = Arc::new(Mutex::new(reader));
143 let reader_clone = reader.clone();
144 let reader = ReadNotifier::new(reader, body_read_sender);
145 let reader = BufReader::new(reader);
146 req.set_body(Body::from_reader(reader, None));
147 Ok(Some((req, BodyReader::Chunked(reader_clone))))
148 } else if let Some(content_length) = content_length {
149 let len = content_length.len();
150 let reader = Arc::new(Mutex::new(reader.take(len)));
151 req.set_body(Body::from_reader(
152 BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)),
153 Some(len as usize),
154 ));
155 Ok(Some((req, BodyReader::Fixed(reader))))
156 } else {
157 Ok(Some((req, BodyReader::None)))
158 }
159}
160
/// Percent-encode every byte of `path` that lies outside printable ASCII (`0x20..=0x7E`).
///
/// Printable bytes are copied through unchanged. This includes `%`, so escapes
/// that are already present survive. Every other byte becomes `%XX` with two
/// uppercase hex digits: control characters, DEL, and the bytes of non-ASCII
/// UTF-8 sequences.
fn non_ascii_printable_to_percent_encoded(path: &[u8]) -> Vec<u8> {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    path.iter()
        .fold(Vec::with_capacity(path.len()), |mut encoded, &byte| {
            match byte {
                b' '..=b'~' => encoded.push(byte),
                _ => encoded.extend_from_slice(&[
                    b'%',
                    HEX_DIGITS[usize::from(byte >> 4)],
                    HEX_DIGITS[usize::from(byte & 0x0F)],
                ]),
            }
            encoded
        })
}
175
176fn url_from_httparse_req(
177 req: &httparse::Request<'_, '_>,
178 default_host: Option<&str>,
179) -> Result<Url> {
180 let path = req.path.ok_or(Error::RequestPathMissing)?;
181
182 let host = req
183 .headers
184 .iter()
185 .find(|x| x.name.eq_ignore_ascii_case("host"));
186
187 let host = match host {
188 Some(header) => std::str::from_utf8(header.value)?,
189 None => default_host.ok_or(Error::HostHeaderMissing)?,
190 };
191
192 if path.starts_with("http://") || path.starts_with("https://") {
193 Ok(Url::parse(&path)?)
194 } else if path.starts_with('/') {
195 Ok(Url::parse(&format!("http://{}{}", host, &path))?)
196 } else if req.method.unwrap().eq_ignore_ascii_case("connect") {
197 Ok(Url::parse(&format!("http://{}/", &path))?)
198 } else {
199 Err(Error::UnexpectedURIFormat)
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `buf` with httparse and hand the request to `f`.
    ///
    /// The inputs below omit the terminating blank line, so `parse` may report
    /// a partial head. Only a parse *error* fails the test.
    fn httparse_req(buf: &str, f: impl Fn(httparse::Request<'_, '_>)) {
        let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
        let mut res = httparse::Request::new(&mut headers[..]);
        res.parse(buf.as_bytes()).unwrap();
        f(res)
    }

    #[test]
    fn url_for_connect() {
        httparse_req(
            "CONNECT server.example.com:443 HTTP/1.1\r\nHost: server.example.com:443\r\n",
            |req| {
                let url = url_from_httparse_req(&req, None).unwrap();
                assert_eq!(url.as_str(), "http://server.example.com:443/");
            },
        );
    }

    #[test]
    fn url_for_host_plus_path() {
        httparse_req(
            "GET /some/resource HTTP/1.1\r\nHost: server.example.com:443\r\n",
            |req| {
                let url = url_from_httparse_req(&req, None).unwrap();
                assert_eq!(url.as_str(), "http://server.example.com:443/some/resource");
            },
        )
    }

    #[test]
    fn url_for_host_plus_absolute_url() {
        httparse_req(
            "GET http://domain.com/some/resource HTTP/1.1\r\nHost: server.example.com\r\n",
            |req| {
                let url = url_from_httparse_req(&req, None).unwrap();
                assert_eq!(url.as_str(), "http://domain.com/some/resource");
            },
        )
    }

    #[test]
    fn url_for_conflicting_connect() {
        httparse_req(
            "CONNECT server.example.com:443 HTTP/1.1\r\nHost: conflicting.host\r\n",
            |req| {
                let url = url_from_httparse_req(&req, None).unwrap();
                assert_eq!(url.as_str(), "http://server.example.com:443/");
            },
        )
    }

    #[test]
    fn url_for_malformed_resource_path() {
        httparse_req(
            "GET not-a-url HTTP/1.1\r\nHost: server.example.com\r\n",
            |req| {
                assert!(matches!(
                    url_from_httparse_req(&req, None),
                    Err(Error::UnexpectedURIFormat)
                ));
            },
        )
    }

    #[test]
    fn url_for_double_slash_path() {
        httparse_req(
            "GET //double/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n",
            |req| {
                let url = url_from_httparse_req(&req, None).unwrap();
                assert_eq!(
                    url.as_str(),
                    "http://server.example.com:443//double/slashes"
                );
            },
        )
    }

    #[test]
    fn url_for_triple_slash_path() {
        httparse_req(
            "GET ///triple/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n",
            |req| {
                let url = url_from_httparse_req(&req, None).unwrap();
                assert_eq!(
                    url.as_str(),
                    "http://server.example.com:443///triple/slashes"
                );
            },
        )
    }

    #[test]
    fn url_for_query() {
        httparse_req(
            "GET /foo?bar=1 HTTP/1.1\r\nHost: server.example.com:443\r\n",
            |req| {
                let url = url_from_httparse_req(&req, None).unwrap();
                assert_eq!(url.as_str(), "http://server.example.com:443/foo?bar=1");
            },
        )
    }

    #[test]
    fn url_for_anchor() {
        httparse_req(
            "GET /foo?bar=1#anchor HTTP/1.1\r\nHost: server.example.com:443\r\n",
            |req| {
                let url = url_from_httparse_req(&req, None).unwrap();
                assert_eq!(
                    url.as_str(),
                    "http://server.example.com:443/foo?bar=1#anchor"
                );
            },
        )
    }

    #[test]
    fn url_falls_back_to_default_host() {
        httparse_req("GET /some/resource HTTP/1.1\r\n", |req| {
            let url = url_from_httparse_req(&req, Some("default.example.com")).unwrap();
            assert_eq!(url.as_str(), "http://default.example.com/some/resource");
        })
    }

    #[test]
    fn url_without_any_host_is_rejected() {
        httparse_req("GET /some/resource HTTP/1.1\r\n", |req| {
            assert!(matches!(
                url_from_httparse_req(&req, None),
                Err(Error::HostHeaderMissing)
            ));
        })
    }

    #[test]
    fn percent_encoding_keeps_printable_ascii() {
        assert_eq!(
            non_ascii_printable_to_percent_encoded(b"/foo?bar=%20#x"),
            b"/foo?bar=%20#x".to_vec()
        );
        assert_eq!(non_ascii_printable_to_percent_encoded(b""), Vec::<u8>::new());
    }

    #[test]
    fn percent_encoding_escapes_control_and_non_ascii_bytes() {
        assert_eq!(
            non_ascii_printable_to_percent_encoded(b"/\xC3\xA9"),
            b"/%C3%A9".to_vec()
        );
        assert_eq!(
            non_ascii_printable_to_percent_encoded(b"\x00\x1f\x7f\xff"),
            b"%00%1F%7F%FF".to_vec()
        );
    }
}