// async_h1b/server/decode.rs

//! Process HTTP connections on the server.
2
3use async_dup::{Arc, Mutex};
4use futures_lite::io::{AsyncRead as Read, AsyncWrite as Write, BufReader};
5use futures_lite::prelude::*;
6use http_types::content::ContentLength;
7use http_types::headers::{EXPECT, TRANSFER_ENCODING};
8//use http_types::{ensure, ensure_eq,format_err};
9use http_types::{Body, /*Method,*/ Request, Url, Version};
10
11use crate::{Error, Result}; //
12
13use super::body_reader::BodyReader;
14use crate::read_notifier::ReadNotifier;
15use crate::{chunked::ChunkedDecoder, ServerOptions};
16use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
17
/// Line feed byte; each read of the request head stops at this delimiter.
const LF: u8 = b'\n';
/// Space byte separating the fields of the request line.
const SPACE: u8 = b' ';

/// Value of the `Expect` header that asks for an interim `100 Continue` response.
const CONTINUE_HEADER_VALUE: &str = "100-continue";
/// Raw interim response written to the client when it expects `100 Continue`.
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
23
24/// Decode an HTTP request on the server.
25
26pub async fn decode<IO>(
27    mut io: IO,
28    opts: &ServerOptions,
29) -> Result<Option<(Request, BodyReader<IO>)>>
30where
31    IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
32{
33    let mut reader = BufReader::with_capacity(MAX_HEAD_LENGTH, io.clone()); // Prevent CWE-400 DoS with large HTTP Headers but without LF char.
34    let mut buf = Vec::new();
35    let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
36    let mut httparse_req = httparse::Request::new(&mut headers);
37
38    let mut first_line = true;
39    // Keep reading bytes from the stream until we hit the end of the stream.
40    loop {
41        let bytes_read = reader.read_until(LF, &mut buf).await?;
42
43        // No more bytes are yielded from the stream.
44        if bytes_read == 0 {
45            return Ok(None);
46        }
47
48        // Prevent CWE-400 DoS with large HTTP Headers.
49        if buf.len() >= MAX_HEAD_LENGTH {
50            return Err(Error::HeadersTooLong);
51        }
52
53        if first_line {
54            first_line = false;
55
56            let mut split = buf.split(|b| { b == &SPACE });
57            let method = split.next().ok_or(Error::MissingMethod)?;
58
59            let path = split.next().ok_or(Error::RequestPathMissing)?;
60            let path = non_ascii_printable_to_percent_encoded(path);
61
62            let mut parts = vec![method, &path];
63            for part in split {
64                parts.push(part);
65            }
66
67            buf = parts.join(&SPACE);
68        }
69
70        // We've hit the end delimiter of the stream.
71        if buf.ends_with(b"\r\n\r\n") || buf.ends_with(b"\n\n") {
72            break;
73        }
74    }
75
76    // Convert our header buf into an httparse instance, and validate.
77    let status = httparse_req.parse(&buf)?;
78
79    if status.is_partial() {
80        return Err(Error::PartialHead);
81    }
82
83    // Convert httparse headers + body into a `http_types::Request` type.
84    let method = httparse_req
85        .method
86        .ok_or(Error::MissingMethod)?
87        .parse()
88        .map_err(|_| Error::UnrecognizedMethod(httparse_req.method.unwrap().to_string()))?;
89
90    let version = match (&opts.default_host, httparse_req.version) {
91        (Some(_), None) | (Some(_), Some(0)) => Version::Http1_0,
92        (_, Some(1)) => Version::Http1_1,
93        (None, Some(0)) | (None, None) => return Err(Error::HostHeaderMissing),
94        (_, Some(other_version)) => return Err(Error::UnsupportedVersion(other_version)),
95    };
96
97    let url = url_from_httparse_req(&httparse_req, opts.default_host.as_deref())?;
98
99    let mut req = Request::new(method, url);
100
101    req.set_version(Some(version));
102
103    for header in httparse_req.headers.iter() {
104        req.append_header(header.name, std::str::from_utf8(header.value)?);
105    }
106
107    let content_length =
108        ContentLength::from_headers(&req).map_err(|_| Error::MalformedHeader("content-length"))?;
109    let transfer_encoding = req.header(TRANSFER_ENCODING);
110
111    if content_length.is_some() && transfer_encoding.is_some() {
112        return Err(Error::UnexpectedHeader("content-length"));
113    }
114
115    // Establish a channel to wait for the body to be read. This
116    // allows us to avoid sending 100-continue in situations that
117    // respond without reading the body, saving clients from uploading
118    // their body.
119    let (body_read_sender, body_read_receiver) = async_channel::bounded(1);
120
121    if Some(CONTINUE_HEADER_VALUE) == req.header(EXPECT).map(|h| h.as_str()) {
122        smolscale2::spawn(async move {
123            // If the client expects a 100-continue header, spawn a
124            // task to wait for the first read attempt on the body.
125            if let Ok(()) = body_read_receiver.recv().await {
126                io.write_all(CONTINUE_RESPONSE).await.ok();
127            };
128            // Since the sender is moved into the Body, this task will
129            // finish when the client disconnects, whether or not
130            // 100-continue was sent.
131        })
132        .detach();
133    }
134
135    // Check for Transfer-Encoding
136    if transfer_encoding
137        .map(|te| te.as_str().eq_ignore_ascii_case("chunked"))
138        .unwrap_or(false)
139    {
140        let trailer_sender = req.send_trailers();
141        let reader = ChunkedDecoder::new(reader, trailer_sender);
142        let reader = Arc::new(Mutex::new(reader));
143        let reader_clone = reader.clone();
144        let reader = ReadNotifier::new(reader, body_read_sender);
145        let reader = BufReader::new(reader);
146        req.set_body(Body::from_reader(reader, None));
147        Ok(Some((req, BodyReader::Chunked(reader_clone))))
148    } else if let Some(content_length) = content_length {
149        let len = content_length.len();
150        let reader = Arc::new(Mutex::new(reader.take(len)));
151        req.set_body(Body::from_reader(
152            BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)),
153            Some(len as usize),
154        ));
155        Ok(Some((req, BodyReader::Fixed(reader))))
156    } else {
157        Ok(Some((req, BodyReader::None)))
158    }
159}
160
/// Percent-encodes every byte of `path` that is not printable ASCII.
///
/// Bytes in the range `0x20..=0x7E` (space through `~`) are copied through
/// unchanged, including `%` itself, so existing escapes are left as they are.
/// Every other byte is written as `%XX` with two uppercase hexadecimal digits.
fn non_ascii_printable_to_percent_encoded(path: &[u8]) -> Vec<u8> {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    // At least one output byte per input byte; escapes grow it as needed.
    let mut out = Vec::with_capacity(path.len());
    for &byte in path {
        if matches!(byte, 0x20..=0x7E) {
            out.push(byte);
        } else {
            out.extend_from_slice(&[
                b'%',
                HEX_DIGITS[usize::from(byte >> 4)],
                HEX_DIGITS[usize::from(byte & 0x0F)],
            ]);
        }
    }
    out
}
175
176fn url_from_httparse_req(
177    req: &httparse::Request<'_, '_>,
178    default_host: Option<&str>,
179) -> Result<Url> {
180    let path = req.path.ok_or(Error::RequestPathMissing)?;
181
182    let host = req
183        .headers
184        .iter()
185        .find(|x| x.name.eq_ignore_ascii_case("host"));
186
187    let host = match host {
188        Some(header) => std::str::from_utf8(header.value)?,
189        None => default_host.ok_or(Error::HostHeaderMissing)?,
190    };
191
192    if path.starts_with("http://") || path.starts_with("https://") {
193        Ok(Url::parse(&path)?)
194    } else if path.starts_with('/') {
195        Ok(Url::parse(&format!("http://{}{}", host, &path))?)
196    } else if req.method.unwrap().eq_ignore_ascii_case("connect") {
197        Ok(Url::parse(&format!("http://{}/", &path))?)
198    } else {
199        Err(Error::UnexpectedURIFormat)
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` as an HTTP request head with `httparse` and hands the
    /// parsed request to `check`.
    fn httparse_req(raw: &str, check: impl Fn(httparse::Request<'_, '_>)) {
        let mut header_slots = [httparse::EMPTY_HEADER; MAX_HEADERS];
        let mut parsed = httparse::Request::new(&mut header_slots[..]);
        parsed.parse(raw.as_bytes()).unwrap();
        check(parsed)
    }

    /// Asserts that `raw`, resolved without a default host, yields `expected`.
    fn assert_url(raw: &str, expected: &str) {
        httparse_req(raw, |req| {
            let url = url_from_httparse_req(&req, None).unwrap();
            assert_eq!(url.as_str(), expected);
        });
    }

    #[test]
    fn url_for_connect() {
        assert_url(
            "CONNECT server.example.com:443 HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443/",
        );
    }

    #[test]
    fn url_for_host_plus_path() {
        assert_url(
            "GET /some/resource HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443/some/resource",
        );
    }

    #[test]
    fn url_for_host_plus_absolute_url() {
        // The host header MUST be ignored according to spec.
        assert_url(
            "GET http://domain.com/some/resource HTTP/1.1\r\nHost: server.example.com\r\n",
            "http://domain.com/some/resource",
        );
    }

    #[test]
    fn url_for_conflicting_connect() {
        assert_url(
            "CONNECT server.example.com:443 HTTP/1.1\r\nHost: conflicting.host\r\n",
            "http://server.example.com:443/",
        );
    }

    #[test]
    fn url_for_malformed_resource_path() {
        httparse_req(
            "GET not-a-url HTTP/1.1\r\nHost: server.example.com\r\n",
            |req| {
                let outcome = url_from_httparse_req(&req, None);
                assert!(matches!(outcome, Err(Error::UnexpectedURIFormat)));
            },
        );
    }

    #[test]
    fn url_for_double_slash_path() {
        assert_url(
            "GET //double/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443//double/slashes",
        );
    }

    #[test]
    fn url_for_triple_slash_path() {
        assert_url(
            "GET ///triple/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443///triple/slashes",
        );
    }

    #[test]
    fn url_for_query() {
        assert_url(
            "GET /foo?bar=1 HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443/foo?bar=1",
        );
    }

    #[test]
    fn url_for_anchor() {
        assert_url(
            "GET /foo?bar=1#anchor HTTP/1.1\r\nHost: server.example.com:443\r\n",
            "http://server.example.com:443/foo?bar=1#anchor",
        );
    }
}