1use std::{
5 fmt::{Debug, Formatter},
6 str::FromStr,
7 sync::Arc,
8};
9
10use bytes::Bytes;
11use http::{HeaderMap, HeaderName, HeaderValue, Request, Response, request::Parts};
12use hyper::{
13 body::{Body, Incoming},
14 client::conn::http1::{self, SendRequest},
15};
16use hyper_util::rt::tokio::WithHyperIo;
17use tokio::{
18 io::{AsyncRead, AsyncWrite},
19 sync::Mutex,
20 task::JoinSet,
21};
22
23use crate::{Client, Error};
24
25const ENCODING_CHUNKED: HeaderValue = HeaderValue::from_static("chunked");
27
28const MAX_PARSED_HEADERS: usize = 16;
30
31#[derive(Clone)]
36pub struct Http1<B> {
37 inner: Arc<Inner<B>>,
38}
39
40impl<B> Debug for Http1<B> {
41 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("Http1").finish_non_exhaustive()
43 }
44}
45
/// Shared state behind every clone of an `Http1` handle.
// NOTE(review): fields are dropped in declaration order, so the request sender is
// dropped before the connection task is aborted; keep this order unless changing
// that is intentional.
struct Inner<B> {
    /// hyper's request sender for this connection. `send` holds this lock until the
    /// response is returned, so requests through one handle (and its clones) are
    /// issued one at a time.
    client: Mutex<SendRequest<B>>,
    /// Owns the background task that drives the hyper connection future. It is never
    /// read; it exists to tie the task's lifetime to this struct, because dropping a
    /// tokio `JoinSet` aborts every task it still holds.
    _runner: JoinSet<()>,
}
50
51impl<B> Client<B> for Http1<B>
52where
53 B: Body + Send + 'static,
54 B::Data: Send,
55 B::Error: Send + Sync + 'static,
56{
57 async fn send(&self, req: Request<B>) -> Result<Response<Incoming>, Error> {
58 let mut client = self.inner.client.lock().await;
59
60 client
61 .send_request(req)
62 .await
63 .inspect_err(|e| {
64 tracing::error!(error = %e, "sending request");
65 })
66 .map_err(From::from)
67 }
68}
69
70pub async fn connect<B>(
72 lower: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
73) -> Result<Http1<B>, Error>
74where
75 B: Body + Send + 'static,
76 B::Data: Send,
77 B::Error: core::error::Error + Send + Sync + 'static,
78{
79 let (client, conn) = http1::handshake(WithHyperIo::new(lower))
80 .await
81 .inspect_err(|e| {
82 tracing::error!(error = %e, "sending request");
83 })
84 .map_err(Error::from)?;
85
86 let mut joinset = JoinSet::new();
87
88 joinset.spawn(async move {
89 if let Err(e) = conn.with_upgrades().await {
90 tracing::error!(?e, "error in http/1.1 connection; closing connection");
91 }
92 });
93
94 Ok(Http1 {
95 inner: Arc::new(Inner {
96 client: Mutex::new(client),
97 _runner: joinset,
98 }),
99 })
100}
101
102pub async fn connect_tcp<B>(url: &url::Url) -> Result<Http1<B>, Error>
104where
105 B: Body + Send + 'static,
106 B::Data: Send,
107 B::Error: core::error::Error + Send + Sync + 'static,
108{
109 let conn = crate::dial_tcp(url).await?;
110 connect(conn).await
111}
112
113pub async fn connect_tls<B>(url: &url::Url) -> Result<Http1<B>, Error>
115where
116 B: Body + Send + 'static,
117 B::Data: Send,
118 B::Error: core::error::Error + Send + Sync + 'static,
119{
120 let conn = crate::dial_tls(url, [b"http/1.1".to_vec()]).await?;
121 connect(conn).await
122}
123
124fn parse_request_parts(buf: &[u8]) -> Result<(Parts, usize), Error> {
130 let mut headers = [httparse::EMPTY_HEADER; MAX_PARSED_HEADERS];
131 let mut req = httparse::Request::new(&mut headers);
132
133 let res = req.parse(buf).map_err(|err| {
134 tracing::trace!(error = %err, "error parsing http request");
135 Error::InvalidInput
136 })?;
137 if res.is_partial() {
138 tracing::trace!(request = ?req, "incomplete http request");
139 return Err(Error::InvalidInput);
140 }
141
142 let httparse::Request {
143 method: Some(method),
144 path: Some(uri),
145 version: Some(1),
146 headers,
147 ..
148 } = req
149 else {
150 tracing::trace!("invalid http request");
151 return Err(Error::InvalidInput);
152 };
153
154 let mut builder = Request::builder()
156 .version(http::Version::HTTP_11)
157 .method(method)
158 .uri(uri);
159 for hdr in headers {
160 let name = HeaderName::from_str(hdr.name).map_err(|err| {
161 tracing::trace!(error = %err, "error parsing http header name");
162 Error::InvalidInput
163 })?;
164 let value = HeaderValue::from_bytes(hdr.value).map_err(|err| {
165 tracing::trace!(error = %err, "error parsing http header value");
166 Error::InvalidInput
167 })?;
168 builder = builder.header(name, value);
169 }
170
171 let (parts, _) = builder
172 .body(())
173 .map_err(|err| {
174 tracing::trace!(error = %err, "error building, invalid http request");
175 Error::InvalidInput
176 })?
177 .into_parts();
178 Ok((parts, res.unwrap()))
179}
180
181fn parse_body(headers: &HeaderMap, body: &[u8]) -> Result<Bytes, Error> {
188 match headers.get("transfer-encoding") {
189 None => Ok(Bytes::copy_from_slice(body)),
190 Some(encoding) if encoding == ENCODING_CHUNKED => {
191 let mut idx = 0;
192 let mut bytes = bytes::BytesMut::new();
193 while let Ok(httparse::Status::Complete((start_offset, chunk_size))) =
194 httparse::parse_chunk_size(&body[idx..])
195 {
196 let start_idx = idx + start_offset;
197 let end_idx = start_idx + chunk_size as usize;
198 let chunk = &body[start_idx..end_idx];
199 tracing::trace!(start_idx, end_idx, ?chunk, "parsed chunk");
200 bytes.extend_from_slice(chunk);
201 idx += start_offset + chunk_size as usize;
202 }
203 Ok(bytes.freeze())
204 }
205 Some(encoding) => {
206 tracing::trace!(?encoding, "unsupported transfer encoding");
207 Err(Error::InvalidInput)
208 }
209 }
210}
211
212pub fn parse_request(buf: &[u8]) -> Result<Request<String>, Error> {
218 let (parts, offset) = parse_request_parts(buf)?;
219 let bytes = parse_body(&parts.headers, &buf[offset..])?;
220 let body = String::from_utf8(bytes.to_vec()).map_err(|_| Error::InvalidInput)?;
221 Ok(Request::from_parts(parts, body))
222}