cyfs_async_h1/server/
decode.rs1use std::str::FromStr;
4
5use async_dup::{Arc, Mutex};
6use async_std::io::{BufReader, Read, Write};
7use async_std::{prelude::*, task};
8use http_types::content::ContentLength;
9use http_types::headers::{EXPECT, TRANSFER_ENCODING, ToHeaderValues};
10use http_types::{ensure, ensure_eq, format_err};
11use http_types::{Body, Method, Request, Url};
12
13use super::body_reader::BodyReader;
14use crate::chunked::ChunkedDecoder;
15use crate::read_notifier::ReadNotifier;
16use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};
17
/// Line-feed byte; the request head is read from the stream one LF-terminated
/// line at a time.
const LF: u8 = b'\n';

/// Minor version number that the parsed request must report (the `1` in
/// `HTTP/1.1`); any other value is rejected as unsupported.
const HTTP_1_1_VERSION: u8 = 1;

/// Value of the `Expect` request header that asks for an interim response
/// before the client sends the body.
const CONTINUE_HEADER_VALUE: &str = "100-continue";
/// Raw interim response written back to the client for `Expect: 100-continue`.
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
25
26pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<(Request, BodyReader<IO>)>>
28where
29 IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
30{
31 let mut reader = BufReader::new(io.clone());
32 let mut buf = Vec::new();
33 let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
34 let mut httparse_req = httparse::Request::new(&mut headers);
35
36 loop {
38 let bytes_read = reader.read_until(LF, &mut buf).await?;
39 if bytes_read == 0 {
41 return Ok(None);
42 }
43
44 ensure!(
46 buf.len() < MAX_HEAD_LENGTH,
47 "Head byte length should be less than 8kb"
48 );
49
50 let idx = buf.len() - 1;
52 if idx >= 3 && &buf[idx - 3..=idx] == b"\r\n\r\n" {
53 break;
54 }
55 }
56
57 let status = httparse_req.parse(&buf)?;
59
60 ensure!(!status.is_partial(), "Malformed HTTP head");
61
62 let method = httparse_req.method;
64 let method = method.ok_or_else(|| format_err!("No method found"))?;
65
66 let version = httparse_req.version;
67 let version = version.ok_or_else(|| format_err!("No version found"))?;
68
69 ensure_eq!(
70 version,
71 HTTP_1_1_VERSION,
72 "Unsupported HTTP version 1.{}",
73 version
74 );
75
76 let url = url_from_httparse_req(&httparse_req)?;
77
78 let mut req = Request::new(Method::from_str(method)?, url);
79
80 req.set_version(Some(http_types::Version::Http1_1));
81
82 for header in httparse_req.headers.iter() {
83 let value = std::str::from_utf8(header.value)?;
84 match value.to_header_values() {
85 Ok(headers) => {
86 for item in headers {
87 req.append_header(header.name, item);
88 }
89 }
90 Err(e) => {
91 log::warn!("got non ascii header: {} -- {}, {}", header.name, value, e);
92 let value = percent_encoding::utf8_percent_encode(value, percent_encoding::NON_ALPHANUMERIC).to_string();
93 req.append_header(header.name, value);
94 }
95 }
96 }
97
98 let content_length = ContentLength::from_headers(&req)?;
99 let transfer_encoding = req.header(TRANSFER_ENCODING);
100
101 http_types::ensure_status!(
106 content_length.is_none() || transfer_encoding.is_none(),
107 400,
108 "Unexpected Content-Length header"
109 );
110
111 let (body_read_sender, body_read_receiver) = async_channel::bounded(1);
116
117 if Some(CONTINUE_HEADER_VALUE) == req.header(EXPECT).map(|h| h.as_str()) {
118 task::spawn(async move {
119 if let Ok(()) = body_read_receiver.recv().await {
122 io.write_all(CONTINUE_RESPONSE).await.ok();
123 };
124 });
128 }
129
130 if transfer_encoding
132 .map(|te| te.as_str().eq_ignore_ascii_case("chunked"))
133 .unwrap_or(false)
134 {
135 let trailer_sender = req.send_trailers();
136 let reader = ChunkedDecoder::new(reader, trailer_sender);
137 let reader = Arc::new(Mutex::new(reader));
138 let reader_clone = reader.clone();
139 let reader = ReadNotifier::new(reader, body_read_sender);
140 let reader = BufReader::new(reader);
141 req.set_body(Body::from_reader(reader, None));
142 Ok(Some((req, BodyReader::Chunked(reader_clone))))
143 } else if let Some(len) = content_length {
144 let len = len.len();
145 let reader = Arc::new(Mutex::new(reader.take(len)));
146 req.set_body(Body::from_reader(
147 BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)),
148 Some(len as usize),
149 ));
150 Ok(Some((req, BodyReader::Fixed(reader))))
151 } else {
152 Ok(Some((req, BodyReader::None)))
153 }
154}
155
156fn url_from_httparse_req(req: &httparse::Request<'_, '_>) -> http_types::Result<Url> {
157 let path = req.path.ok_or_else(|| format_err!("No uri found"))?;
158
159 let host = req
160 .headers
161 .iter()
162 .find(|x| x.name.eq_ignore_ascii_case("host"))
163 .ok_or_else(|| format_err!("Mandatory Host header missing"))?
164 .value;
165
166 let host = std::str::from_utf8(host)?;
167
168 if path.starts_with("http://") || path.starts_with("https://") {
169 Ok(Url::parse(path)?)
170 } else if path.starts_with('/') {
171 Ok(Url::parse(&format!("http://{}{}", host, path))?)
172 } else if req.method.unwrap().eq_ignore_ascii_case("connect") {
173 Ok(Url::parse(&format!("http://{}/", path))?)
174 } else {
175 Err(format_err!("unexpected uri format"))
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `head` with `httparse` and resolves the request URL from it.
    ///
    /// Panics if `httparse` rejects the head; the URL resolution result is
    /// handed back to the caller so each test can assert on success or
    /// failure.
    fn resolve(head: &str) -> http_types::Result<Url> {
        let mut header_slots = [httparse::EMPTY_HEADER; MAX_HEADERS];
        let mut parsed = httparse::Request::new(&mut header_slots[..]);
        parsed.parse(head.as_bytes()).unwrap();
        url_from_httparse_req(&parsed)
    }

    // CONNECT uses the request target as the authority.
    #[test]
    fn url_for_connect() {
        let head = "CONNECT server.example.com:443 HTTP/1.1\r\nHost: server.example.com:443\r\n";
        assert_eq!(
            resolve(head).unwrap().as_str(),
            "http://server.example.com:443/"
        );
    }

    // An origin-form target is joined onto the Host header.
    #[test]
    fn url_for_host_plus_path() {
        let head = "GET /some/resource HTTP/1.1\r\nHost: server.example.com:443\r\n";
        assert_eq!(
            resolve(head).unwrap().as_str(),
            "http://server.example.com:443/some/resource"
        );
    }

    // An absolute target wins over the Host header.
    #[test]
    fn url_for_host_plus_absolute_url() {
        let head = "GET http://domain.com/some/resource HTTP/1.1\r\nHost: server.example.com\r\n";
        assert_eq!(
            resolve(head).unwrap().as_str(),
            "http://domain.com/some/resource"
        );
    }

    // For CONNECT the target is used even when Host disagrees.
    #[test]
    fn url_for_conflicting_connect() {
        let head = "CONNECT server.example.com:443 HTTP/1.1\r\nHost: conflicting.host\r\n";
        assert_eq!(
            resolve(head).unwrap().as_str(),
            "http://server.example.com:443/"
        );
    }

    // A non-CONNECT target that is neither absolute nor rooted is an error.
    #[test]
    fn url_for_malformed_resource_path() {
        let head = "GET not-a-url HTTP/1.1\r\nHost: server.example.com\r\n";
        assert!(resolve(head).is_err());
    }

    // Leading double slashes are preserved in the path.
    #[test]
    fn url_for_double_slash_path() {
        let head = "GET //double/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n";
        assert_eq!(
            resolve(head).unwrap().as_str(),
            "http://server.example.com:443//double/slashes"
        );
    }

    // Leading triple slashes are preserved in the path.
    #[test]
    fn url_for_triple_slash_path() {
        let head = "GET ///triple/slashes HTTP/1.1\r\nHost: server.example.com:443\r\n";
        assert_eq!(
            resolve(head).unwrap().as_str(),
            "http://server.example.com:443///triple/slashes"
        );
    }

    // The query string is carried through.
    #[test]
    fn url_for_query() {
        let head = "GET /foo?bar=1 HTTP/1.1\r\nHost: server.example.com:443\r\n";
        assert_eq!(
            resolve(head).unwrap().as_str(),
            "http://server.example.com:443/foo?bar=1"
        );
    }

    // The fragment is carried through.
    #[test]
    fn url_for_anchor() {
        let head = "GET /foo?bar=1#anchor HTTP/1.1\r\nHost: server.example.com:443\r\n";
        assert_eq!(
            resolve(head).unwrap().as_str(),
            "http://server.example.com:443/foo?bar=1#anchor"
        );
    }
}