1use bytes::Bytes;
13use http_body_util::Empty;
14use hyper_util::rt::TokioIo;
15use tokio::io::duplex;
16use tokio::sync::oneshot;
17
18pub use httparse::{Header, Request};
19
20use crate::error::WireError;
21use crate::util::{is_chunked_slice, parse_chunked_body, parse_usize};
22use crate::wire::WireCapture;
23use crate::{WireDecode, WireEncode, WireEncodeAsync};
24use std::mem::MaybeUninit;
25
26impl<B> WireEncode for http::Request<B>
28where
29 B: http_body_util::BodyExt + Send + Sync + 'static,
30 B::Data: Send + Sync + 'static,
31 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
32{
33 fn encode(self) -> Result<Bytes, WireError> {
34 let rt = tokio::runtime::Builder::new_current_thread()
36 .enable_all()
37 .build()
38 .map_err(|e| WireError::Connection(Box::new(e)))?;
39
40 rt.block_on(self.encode_async())
42 }
43}
44
45impl<B> WireEncodeAsync for http::Request<B>
46where
47 B::Data: Send + Sync + 'static,
48 B: http_body_util::BodyExt + Send + Sync + 'static,
49 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
50{
51 #[inline]
52 async fn encode_async(self) -> Result<Bytes, WireError> {
53 use hyper::service::service_fn;
54 use std::convert::Infallible;
55
56 let version = self.version();
58 if version != http::Version::HTTP_11 && version != http::Version::HTTP_10 {
59 return Err(WireError::UnsupportedVersion);
60 }
61
62 let (client, server) = duplex(8192);
63 let capture_client = WireCapture::new(client);
64 let captured_ref = capture_client.captured.clone();
65
66 let (tx, rx) = oneshot::channel::<Result<(), WireError>>();
67
68 let server_handle = tokio::spawn(async move {
70 let tx = std::sync::Mutex::new(Some(tx));
71 let service = service_fn(move |_req: http::Request<hyper::body::Incoming>| {
72 if let Some(tx) = tx.lock().unwrap().take() {
74 let _ = tx.send(Ok(()));
75 }
76 async move {
77 Ok::<_, Infallible>(http::Response::new(Empty::<Bytes>::new()))
79 }
80 });
81
82 hyper::server::conn::http1::Builder::new()
83 .serve_connection(TokioIo::new(server), service)
84 .await
85 });
86
87 let client_handle = tokio::spawn(async move {
89 let client_connection = hyper::client::conn::http1::Builder::new()
90 .handshake(TokioIo::new(capture_client))
91 .await;
92
93 match client_connection {
94 Ok((mut sender, connection)) => {
95 tokio::spawn(connection);
97
98 sender
100 .send_request(self)
101 .await
102 .map(|_| ())
103 .map_err(|e| WireError::Connection(Box::new(e)))
104 }
105 Err(e) => Err(WireError::Connection(Box::new(e))),
106 }
107 });
108
109 rx.await.map_err(|_| WireError::Sync)??;
111
112 client_handle.abort();
114 server_handle.abort();
115
116 let result = captured_ref.lock().clone();
117 Ok(Bytes::from(result))
118 }
119}
120
121pub struct FullRequest<'headers, 'buf> {
129 pub head: httparse::Request<'headers, 'buf>,
130 pub body: &'buf [u8],
131}
132
133impl<'headers, 'buf> FullRequest<'headers, 'buf> {
134 fn parse_core(&mut self, buf: &'buf [u8], headers_len: usize) -> Result<usize, WireError> {
137 let mut content_len: Option<usize> = None;
138 let mut is_chunked = false;
139
140 for header in self.head.headers.iter() {
142 let name = header.name.as_bytes();
143 if name.len() == 14 && name.eq_ignore_ascii_case(b"Content-Length") {
144 content_len = parse_usize(header.value);
145 } else if name.len() == 17 && name.eq_ignore_ascii_case(b"Transfer-Encoding") {
146 is_chunked = is_chunked_slice(header.value);
147 }
148 }
149
150 if is_chunked {
152 let body_len =
153 parse_chunked_body(&buf[headers_len..]).ok_or(WireError::InvalidChunkedBody)?;
154 self.body = &buf[headers_len..headers_len + body_len];
155 Ok(headers_len + body_len)
156 } else {
157 let body_len = content_len.unwrap_or(0);
159 let total = headers_len + body_len;
160 if buf.len() >= total {
161 self.body = &buf[headers_len..total];
162 Ok(total)
163 } else {
164 Err(WireError::IncompleteBody(total - buf.len()))
165 }
166 }
167 }
168
169 pub fn parse(&mut self, buf: &'buf [u8]) -> Result<usize, WireError> {
171 match self.head.parse(buf) {
172 Ok(httparse::Status::Complete(headers_len)) => self.parse_core(buf, headers_len),
173 Ok(httparse::Status::Partial) => Err(WireError::PartialHead),
174 Err(err) => Err(err.into()),
175 }
176 }
177
178 pub fn parse_uninit(
180 &mut self,
181 buf: &'buf [u8],
182 headers: &'headers mut [MaybeUninit<Header<'buf>>],
183 ) -> Result<usize, WireError> {
184 match self.head.parse_with_uninit_headers(buf, headers) {
185 Ok(httparse::Status::Complete(headers_len)) => self.parse_core(buf, headers_len),
186 Ok(httparse::Status::Partial) => Err(WireError::PartialHead),
187 Err(err) => Err(err.into()),
188 }
189 }
190}
191
192impl<'headers, 'buf> WireDecode<'headers, 'buf> for FullRequest<'headers, 'buf> {
193 fn decode(
194 buf: &'buf [u8],
195 headers: &'headers mut [Header<'buf>],
196 ) -> Result<(Self, usize), WireError> {
197 let mut full_request = FullRequest {
198 head: httparse::Request::new(headers),
199 body: &[],
200 };
201
202 let total = full_request.parse(buf)?;
203 Ok((full_request, total))
204 }
205
206 fn decode_uninit(
207 buf: &'buf [u8],
208 headers: &'headers mut [MaybeUninit<Header<'buf>>],
209 ) -> Result<(Self, usize), WireError> {
210 let mut full_request = FullRequest {
211 head: httparse::Request::new(&mut []),
212 body: &[],
213 };
214
215 let total = full_request.parse_uninit(buf, headers)?;
216 Ok((full_request, total))
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::{Empty, Full};

    #[test]
    fn test_request_sync_no_body() {
        let request = http::Request::builder()
            .method("GET")
            .uri("/api/test")
            .header("Host", "example.com")
            .body(Empty::<Bytes>::new())
            .unwrap();

        let bytes = request.encode().unwrap();
        let output = String::from_utf8_lossy(&bytes);

        assert!(output.contains("GET /api/test HTTP/1.1"));
        assert!(output.contains("host: example.com"));
    }

    #[test]
    fn test_request_sync_with_body() {
        let body = r#"{"test":"data"}"#;
        let request = http::Request::builder()
            .method("POST")
            .uri("/api/submit")
            .header("Host", "example.com")
            .header("Content-Type", "application/json")
            .body(Full::new(Bytes::from(body)))
            .unwrap();

        let bytes = request.encode().unwrap();
        let output = String::from_utf8_lossy(&bytes);

        assert!(output.contains("POST /api/submit HTTP/1.1"));
        assert!(output.contains(body));
    }

    #[test]
    fn test_request_sync_http2_rejected() {
        let request = http::Request::builder()
            .method("GET")
            .uri("/")
            .version(http::Version::HTTP_2)
            .body(Empty::<Bytes>::new())
            .unwrap();

        let result = request.encode();
        assert!(matches!(result, Err(WireError::UnsupportedVersion)));
    }

    #[tokio::test]
    async fn test_request_to_wire() {
        let request = http::Request::builder()
            .method("GET")
            .uri("/api/test")
            .header("Host", "example.com")
            .body(Empty::<Bytes>::new())
            .unwrap();

        let bytes = request.encode_async().await.unwrap();
        let output = String::from_utf8_lossy(&bytes);

        assert!(output.contains("GET /api/test HTTP/1.1"));
        assert!(output.contains("host: example.com"));
    }

    #[tokio::test]
    async fn test_request_with_body_to_wire() {
        let body = r#"{"test":"data"}"#;
        let request = http::Request::builder()
            .method("POST")
            .uri("/api/submit")
            .header("Host", "example.com")
            .header("Content-Type", "application/json")
            .body(Full::new(Bytes::from(body)))
            .unwrap();

        let bytes = request.encode_async().await.unwrap();
        let output = String::from_utf8_lossy(&bytes);

        assert!(output.contains("POST /api/submit HTTP/1.1"));
        assert!(output.contains(body));
    }

    #[tokio::test]
    async fn test_http2_request_rejected() {
        let request = http::Request::builder()
            .method("GET")
            .uri("/")
            .version(http::Version::HTTP_2)
            .body(Empty::<Bytes>::new())
            .unwrap();

        let result = request.encode_async().await;
        assert!(matches!(result, Err(WireError::UnsupportedVersion)));
    }

    #[test]
    fn test_decode_request_no_body() {
        let raw = b"GET /api/users HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let (req, consumed) = FullRequest::decode(raw, &mut headers).unwrap();
        assert_eq!(consumed, raw.len());
        assert_eq!(req.head.method, Some("GET"));
        assert_eq!(req.head.path, Some("/api/users"));
        assert_eq!(req.head.headers.len(), 1);
        assert!(req.body.is_empty());
    }

    #[test]
    fn test_decode_request_with_content_length() {
        let raw = b"POST /api/users HTTP/1.1\r\nHost: example.com\r\nContent-Length: 14\r\n\r\n{\"name\":\"foo\"}";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let (req, consumed) = FullRequest::decode(raw, &mut headers).unwrap();
        assert_eq!(consumed, raw.len());
        assert_eq!(req.head.method, Some("POST"));
        assert_eq!(req.body, b"{\"name\":\"foo\"}");
    }

    #[test]
    fn test_decode_request_content_length_case_insensitive() {
        let raw = b"POST /api/users HTTP/1.1\r\nHost: example.com\r\nCONTENT-LENGTH: 5\r\n\r\nhello";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let (req, consumed) = FullRequest::decode(raw, &mut headers).unwrap();
        assert_eq!(consumed, raw.len());
        assert_eq!(req.body, b"hello");
    }

    #[test]
    fn test_decode_request_incomplete_body() {
        // Content-Length promises 13 bytes but only "hello" (5 bytes) is present.
        let raw =
            b"POST /api/users HTTP/1.1\r\nHost: example.com\r\nContent-Length: 13\r\n\r\nhello";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let result = FullRequest::decode(raw, &mut headers);
        assert!(matches!(result, Err(WireError::IncompleteBody(8))));
    }

    #[test]
    fn test_decode_request_incomplete_headers() {
        let raw = b"POST /api/users HTTP/1.1\r\nHost: example.com\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let result = FullRequest::decode(raw, &mut headers);
        assert!(matches!(result, Err(WireError::PartialHead)));
    }

    #[test]
    fn test_decode_request_chunked_encoding() {
        let raw = b"POST /api/data HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\n\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let (req, consumed) = FullRequest::decode(raw, &mut headers).unwrap();
        assert_eq!(req.head.method, Some("POST"));
        assert!(!req.body.is_empty());
        assert!(consumed <= raw.len());
    }

    #[test]
    fn test_decode_request_chunked_multiple_chunks() {
        let raw = b"POST /api/data HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let (req, consumed) = FullRequest::decode(raw, &mut headers).unwrap();
        assert!(!req.body.is_empty());
        assert!(consumed <= raw.len());
    }

    #[test]
    fn test_decode_request_chunked_incomplete() {
        // Missing the terminating zero-length chunk.
        let raw = b"POST /api/data HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let result = FullRequest::decode(raw, &mut headers);
        assert!(matches!(result, Err(WireError::InvalidChunkedBody)));
    }

    #[test]
    fn test_decode_request_extra_data_after() {
        // Trailing bytes after a complete request must not be claimed by it.
        let request = b"GET /api/users HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let mut raw = request.to_vec();
        raw.extend_from_slice(b"extra garbage data");
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let (req, consumed) = FullRequest::decode(&raw, &mut headers).unwrap();
        assert_eq!(consumed, request.len());
        assert!(req.body.is_empty());
    }

    #[test]
    fn test_decode_request_chunked_case_insensitive() {
        let raw = b"POST /api/data HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: CHUNKED\r\n\r\n5\r\nhello\r\n0\r\n\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let (req, consumed) = FullRequest::decode(raw, &mut headers).unwrap();
        assert!(!req.body.is_empty());
        assert!(consumed <= raw.len());
    }

    #[test]
    fn test_decode_request_uninit_no_body() {
        let raw = b"GET /api/users HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let mut headers = [const { MaybeUninit::uninit() }; 16];
        let (req, consumed) = FullRequest::decode_uninit(raw, &mut headers).unwrap();
        assert_eq!(consumed, raw.len());
        assert_eq!(req.head.method, Some("GET"));
        assert_eq!(req.head.headers.len(), 1);
        assert!(req.body.is_empty());
    }

    #[test]
    fn test_decode_request_uninit_with_body() {
        let raw = b"POST /api/users HTTP/1.1\r\nHost: example.com\r\nContent-Length: 14\r\n\r\n{\"name\":\"foo\"}";
        let mut headers = [const { MaybeUninit::uninit() }; 16];
        let result = FullRequest::decode_uninit(raw, &mut headers);
        assert!(result.is_ok());
        let (req, consumed) = result.unwrap();
        assert_eq!(consumed, raw.len());
        assert_eq!(req.body, b"{\"name\":\"foo\"}");
    }
}