1#![deny(missing_docs)]
15#![forbid(unsafe_code)]
16
17extern crate bytes;
18pub extern crate http;
19extern crate httparse;
20
21#[cfg(feature = "basicauth")]
22extern crate base64;
23#[cfg(feature = "basicauth")]
24extern crate percent_encoding;
25
/// Request head type produced and consumed by this crate: an `http::Request` whose
/// body is `()`, i.e. only the request line and headers are represented.
pub type Request = http::request::Request<()>;
/// Response head type produced and consumed by this crate: an `http::Response` whose
/// body is `()`, i.e. only the status line and headers are represented.
pub type Response = http::response::Response<()>;
30
31pub use httparse::EMPTY_HEADER;
32
/// Error returned by the header parsing functions of this crate.
///
/// Every variant wraps the error reported by the lower-level `httparse`, `http`
/// or (with the `basicauth` feature) `base64` crate.
#[derive(Debug)]
#[allow(missing_docs)]
pub enum Error {
    /// `httparse` rejected the input, or the request method is not a valid token.
    Parse(httparse::Error),
    /// The request target could not be converted into an `http` path-and-query.
    Path(http::uri::InvalidUri),
    /// A header name was rejected by `http::header::HeaderName`.
    HeaderName(http::header::InvalidHeaderName),
    /// A header value was rejected by `http::header::HeaderValue`.
    HeaderValue(http::header::InvalidHeaderValue),
    /// The response status code was rejected by `http::StatusCode`.
    StatusCode(http::status::InvalidStatusCode),
    /// The `Host` header (possibly prefixed with Basic-auth userinfo) is not a
    /// valid URI authority.
    InvalidAuthority(http::uri::InvalidUriBytes),
    /// The payload of a Basic `Authorization` header is not valid base64.
    #[cfg(feature = "basicauth")]
    BasicAuth(base64::DecodeError),
}
51
52impl std::fmt::Display for Error {
53 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
54 match self {
55 Error::Parse(x) => x.fmt(f),
56 Error::Path(x) => x.fmt(f),
57 Error::HeaderName(x) => x.fmt(f),
58 Error::HeaderValue(x) => x.fmt(f),
59 Error::StatusCode(x) => x.fmt(f),
60 Error::InvalidAuthority(x) => x.fmt(f),
61 #[cfg(feature = "basicauth")]
62 Error::BasicAuth(x) => x.fmt(f),
63 }
64 }
65}
66impl std::error::Error for Error {
67 fn cause(&self) -> Option<&dyn std::error::Error> {
68 Some(match self {
69 Error::Parse(x) => x,
70 Error::Path(x) => x,
71 Error::HeaderName(x) => x,
72 Error::HeaderValue(x) => x,
73 Error::StatusCode(x) => x,
74 Error::InvalidAuthority(x) => x,
75 #[cfg(feature = "basicauth")]
76 Error::BasicAuth(x) => x,
77 })
78 }
79}
80
81use http::header::{HeaderName, HeaderValue, HOST};
82use http::uri::{Authority, Parts as UriParts, PathAndQuery};
83use http::{Method, StatusCode};
84use std::str::FromStr;
85
86pub fn parse_request_header_easy(buf: &[u8]) -> Result<Option<(Request, &[u8])>, Error> {
96 let mut h = [httparse::EMPTY_HEADER; 50];
97 parse_request_header(buf, h.as_mut(), None)
98}
99
100pub fn parse_response_header_easy(buf: &[u8]) -> Result<Option<(Response, &[u8])>, Error> {
108 let mut h = [httparse::EMPTY_HEADER; 50];
109 parse_response_header(buf, h.as_mut())
110}
111
112pub fn parse_request_header<'a, 'b>(
148 buf: &'a [u8],
149 headers_buffer: &'b mut [httparse::Header<'a>],
150 scheme: Option<http::uri::Scheme>,
151) -> Result<Option<(Request, &'a [u8])>, Error> {
152 let mut parsed_request = httparse::Request::new(headers_buffer);
153 let header_size = match parsed_request.parse(buf).map_err(Error::Parse)? {
154 httparse::Status::Partial => return Ok(None),
155 httparse::Status::Complete(size) => size,
156 };
157 let trailer = &buf[header_size..];
158 let mut request = Request::new(());
159 *request.method_mut() = Method::from_str(parsed_request.method.unwrap())
160 .map_err(|_| Error::Parse(httparse::Error::Token))?;
161 *request.version_mut() = http::Version::HTTP_11; let mut up: UriParts = Default::default();
163 up.path_and_query =
164 Some(PathAndQuery::from_str(parsed_request.path.unwrap()).map_err(Error::Path)?);
165
166 for header in parsed_request.headers {
167 let n = HeaderName::from_str(header.name).map_err(Error::HeaderName)?;
168 let v = HeaderValue::from_bytes(header.value).map_err(Error::HeaderValue)?;
169 request.headers_mut().append(n, v);
170 }
171 if scheme.is_some() {
172 if let Some(hv) = request.headers().get(HOST) {
173 up.scheme = scheme;
174 let authority_buf = bytes::Bytes::from(hv.as_bytes());
175 #[allow(unused_mut)]
176 let mut authority_buf = authority_buf;
177 #[cfg(feature = "basicauth")]
178 {
179 use percent_encoding::{percent_encode, USERINFO_ENCODE_SET};
180 use std::io::Write;
181
182 #[derive(Clone, Copy)]
183 struct CorrectedUserinfoEncodeSet;
184 impl percent_encoding::EncodeSet for CorrectedUserinfoEncodeSet {
185 fn contains(&self, byte: u8) -> bool {
186 if byte == b'%' {
187 return true;
188 }
189 USERINFO_ENCODE_SET.contains(byte)
190 }
191 }
192
193 if let Some(u) = request.headers().get(http::header::AUTHORIZATION) {
194 let u = u.as_bytes();
195 let mut b = false;
196 b |= u.starts_with(b"Basic ");
197 b |= u.starts_with(b"basic ");
198 b |= u.starts_with(b"BASIC ");
199 if b && u.len() > 8 {
200 let u = &u[6..];
201 let u = base64::decode(u).map_err(Error::BasicAuth)?;
202
203 let u = u[..]
205 .split(|v| *v == b':')
206 .map(|v| percent_encode(v, CorrectedUserinfoEncodeSet).to_string())
207 .collect::<Vec<_>>()
208 .join(":")
209 .into_bytes();
210 let l = u.len();
213 let mut u = std::io::Cursor::new(u);
214 u.set_position(l as u64);
215 u.write_all(b"@").unwrap();
216 u.write_all(authority_buf.as_ref()).unwrap();
217 authority_buf = bytes::Bytes::from(u.into_inner());
218 }
219 }
220 }
221 let a = Authority::from_shared(authority_buf).map_err(Error::InvalidAuthority)?;
222 up.authority = Some(a);
223 }
224 }
225 *request.uri_mut() = http::Uri::from_parts(up).unwrap();
226 Ok(Some((request, trailer)))
227}
228
229pub fn parse_response_header<'a, 'b>(
252 buf: &'a [u8],
253 headers_buffer: &'b mut [httparse::Header<'a>],
254) -> Result<Option<(Response, &'a [u8])>, Error> {
255 let mut x = httparse::Response::new(headers_buffer);
256 let n = match x.parse(buf).map_err(Error::Parse)? {
257 httparse::Status::Partial => return Ok(None),
258 httparse::Status::Complete(size) => size,
259 };
260 let trailer = &buf[n..];
261 let mut r = Response::new(());
262 *r.status_mut() = StatusCode::from_u16(x.code.unwrap()).map_err(Error::StatusCode)?;
263 *r.version_mut() = http::Version::HTTP_11; for h in x.headers {
266 let n = HeaderName::from_str(h.name).map_err(Error::HeaderName)?;
267 let v = HeaderValue::from_bytes(h.value).map_err(Error::HeaderValue)?;
268 r.headers_mut().append(n, v);
269 }
270 Ok(Some((r, trailer)))
271}
272
/// Builds an `std::io::Error` of kind `Other` whose message is `msg`.
fn io_other_error(msg: &'static str) -> std::io::Error {
    // `&str` converts into `Box<dyn Error + Send + Sync>` inside `io::Error::new`.
    std::io::Error::new(std::io::ErrorKind::Other, msg)
}
277
278pub fn write_request_header<T>(
290 r: &http::Request<T>,
291 mut io: impl std::io::Write,
292) -> std::io::Result<usize> {
293 let mut len = 0;
294 let verb = r.method().as_str();
295 let path = r
296 .uri()
297 .path_and_query()
298 .ok_or_else(|| io_other_error("Invalid URI"))?;
299
300 let need_to_insert_host = r.uri().host().is_some() && !r.headers().contains_key(HOST);
301
302 macro_rules! w {
303 ($x:expr) => {
304 io.write_all($x)?;
305 len += $x.len();
306 };
307 }
308 w!(verb.as_bytes());
309 w!(b" ");
310 w!(path.as_str().as_bytes());
311 w!(b" HTTP/1.1\r\n");
312
313 if need_to_insert_host {
314 w!(b"Host: ");
315 let host = r.uri().host().unwrap();
316 w!(host.as_bytes());
317 if let Some(p) = r.uri().port_part() {
318 w!(b":");
319 w!(p.as_str().as_bytes());
320 }
321 w!(b"\r\n");
322 }
323 #[cfg(feature = "basicauth")]
324 {
325 let already_present = r.headers().get(http::header::AUTHORIZATION).is_some();
326 let at_sign = r
327 .uri()
328 .authority_part()
329 .map_or(false, |x| x.as_str().contains('@'));
330 if !already_present && at_sign {
331 w!(b"Authorization: Basic ");
332 let a = r.uri().authority_part().unwrap().as_str();
333 let a = &a[0..(a.find('@').unwrap())];
334 let a = a
335 .as_bytes()
336 .split(|v| *v == b':')
337 .map(|v| percent_encoding::percent_decode(v).collect::<Vec<u8>>())
338 .collect::<Vec<Vec<u8>>>()
339 .join(&b':');
340 let a = base64::encode(&a);
341 w!(a.as_bytes());
342 w!(b"\r\n");
343 }
344 }
345
346 for (hn, hv) in r.headers() {
347 w!(hn.as_str().as_bytes());
348 w!(b": ");
349 w!(hv.as_bytes());
350 w!(b"\r\n");
351 }
352
353 w!(b"\r\n");
354
355 Ok(len)
356}
357
358pub fn write_response_header<T>(
364 r: &http::Response<T>,
365 mut io: impl std::io::Write,
366) -> std::io::Result<usize> {
367 let mut len = 0;
368 macro_rules! w {
369 ($x:expr) => {
370 io.write_all($x)?;
371 len += $x.len();
372 };
373 }
374
375 let status = r.status();
376 let code = status.as_str();
377 let reason = status.canonical_reason().unwrap_or("Unknown");
378 let headers = r.headers();
379
380 w!(b"HTTP/1.1 ");
381 w!(code.as_bytes());
382 w!(b" ");
383 w!(reason.as_bytes());
384 w!(b"\r\n");
385
386 for (hn, hv) in headers {
387 w!(hn.as_str().as_bytes());
388 w!(b": ");
389 w!(hv.as_bytes());
390 w!(b"\r\n");
391 }
392
393 w!(b"\r\n");
394 Ok(len)
395}
396
397pub fn request_header_to_vec<T>(r: &http::Request<T>) -> Vec<u8> {
401 let v = Vec::with_capacity(120);
402 let mut c = std::io::Cursor::new(v);
403 write_request_header(r, &mut c).unwrap();
404 c.into_inner()
405}
406
407pub fn response_header_to_vec<T>(r: &http::Response<T>) -> Vec<u8> {
411 let v = Vec::with_capacity(120);
412 let mut c = std::io::Cursor::new(v);
413 write_response_header(r, &mut c).unwrap();
414 c.into_inner()
415}
416
417#[cfg(test)]
418mod fuzztest;
419#[cfg(test)]
420mod test;