Skip to main content

http_wire/
request.rs

1//! HTTP request encoding and decoding.
2//!
3//! This module handles the serialization of `http::Request` objects into wire-format bytes
4//! and the parsing of raw bytes to determine request boundaries.
5//!
6//! # Request Encoding
7//!
8//! The [`WireEncode`] and [`WireEncodeAsync`] traits are implemented for [`http::Request`],
9//! allowing you to serialize requests to bytes.
10//!
11
12use bytes::Bytes;
13use http_body_util::Empty;
14use hyper_util::rt::TokioIo;
15use tokio::io::duplex;
16use tokio::sync::oneshot;
17
18pub use httparse::{Header, Request};
19
20use crate::error::WireError;
21use crate::util::{is_chunked_slice, parse_chunked_body, parse_usize};
22use crate::wire::WireCapture;
23use crate::{WireDecode, WireEncode, WireEncodeAsync};
24use std::mem::MaybeUninit;
25
26// Implementation of WireEncode for Request
27impl<B> WireEncode for http::Request<B>
28where
29    B: http_body_util::BodyExt + Send + Sync + 'static,
30    B::Data: Send + Sync + 'static,
31    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
32{
33    fn encode(self) -> Result<Bytes, WireError> {
34        // Create a minimal single-threaded runtime
35        let rt = tokio::runtime::Builder::new_current_thread()
36            .enable_all()
37            .build()
38            .map_err(|e| WireError::Connection(Box::new(e)))?;
39
40        // Block on the async encode method
41        rt.block_on(self.encode_async())
42    }
43}
44
45impl<B> WireEncodeAsync for http::Request<B>
46where
47    B::Data: Send + Sync + 'static,
48    B: http_body_util::BodyExt + Send + Sync + 'static,
49    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
50{
51    #[inline]
52    async fn encode_async(self) -> Result<Bytes, WireError> {
53        use hyper::service::service_fn;
54        use std::convert::Infallible;
55
56        // Check HTTP version - only HTTP/1.1 and HTTP/1.0 are supported
57        let version = self.version();
58        if version != http::Version::HTTP_11 && version != http::Version::HTTP_10 {
59            return Err(WireError::UnsupportedVersion);
60        }
61
62        let (client, server) = duplex(8192);
63        let capture_client = WireCapture::new(client);
64        let captured_ref = capture_client.captured.clone();
65
66        let (tx, rx) = oneshot::channel::<Result<(), WireError>>();
67
68        // Spawn a mock server that will accept the connection and read the request
69        let server_handle = tokio::spawn(async move {
70            let tx = std::sync::Mutex::new(Some(tx));
71            let service = service_fn(move |_req: http::Request<hyper::body::Incoming>| {
72                // Signal that the request has been received
73                if let Some(tx) = tx.lock().unwrap().take() {
74                    let _ = tx.send(Ok(()));
75                }
76                async move {
77                    // Return a minimal response
78                    Ok::<_, Infallible>(http::Response::new(Empty::<Bytes>::new()))
79                }
80            });
81
82            hyper::server::conn::http1::Builder::new()
83                .serve_connection(TokioIo::new(server), service)
84                .await
85        });
86
87        // Send the request through the client side and capture what's written
88        let client_handle = tokio::spawn(async move {
89            let client_connection = hyper::client::conn::http1::Builder::new()
90                .handshake(TokioIo::new(capture_client))
91                .await;
92
93            match client_connection {
94                Ok((mut sender, connection)) => {
95                    // Spawn the connection driver
96                    tokio::spawn(connection);
97
98                    // Send the request
99                    sender
100                        .send_request(self)
101                        .await
102                        .map(|_| ())
103                        .map_err(|e| WireError::Connection(Box::new(e)))
104                }
105                Err(e) => Err(WireError::Connection(Box::new(e))),
106            }
107        });
108
109        // Wait for the server to receive the request
110        rx.await.map_err(|_| WireError::Sync)??;
111
112        // Cleanup
113        client_handle.abort();
114        server_handle.abort();
115
116        let result = captured_ref.lock().clone();
117        Ok(Bytes::from(result))
118    }
119}
120
121/// Decoder for determining HTTP request message length.
122///
123/// Returns the total length in bytes of a complete HTTP request (headers + body),
124/// or `None` if the request is incomplete or malformed.
125///
126/// Supports `Content-Length`, `Transfer-Encoding: chunked`, and body-less requests.
127///
128pub struct FullRequest<'headers, 'buf> {
129    pub head: httparse::Request<'headers, 'buf>,
130    pub body: &'buf [u8],
131}
132
133impl<'headers, 'buf> FullRequest<'headers, 'buf> {
134    /// Core parsing logic shared between parse and parse_uninit.
135    /// Assumes headers have already been parsed and self.head.headers is populated.
136    fn parse_core(&mut self, buf: &'buf [u8], headers_len: usize) -> Result<usize, WireError> {
137        let mut content_len: Option<usize> = None;
138        let mut is_chunked = false;
139
140        // Scan headers for Content-Length or Transfer-Encoding
141        for header in self.head.headers.iter() {
142            let name = header.name.as_bytes();
143            if name.len() == 14 && name.eq_ignore_ascii_case(b"Content-Length") {
144                content_len = parse_usize(header.value);
145            } else if name.len() == 17 && name.eq_ignore_ascii_case(b"Transfer-Encoding") {
146                is_chunked = is_chunked_slice(header.value);
147            }
148        }
149
150        // Calculate body length
151        if is_chunked {
152            let body_len =
153                parse_chunked_body(&buf[headers_len..]).ok_or(WireError::InvalidChunkedBody)?;
154            self.body = &buf[headers_len..headers_len + body_len];
155            Ok(headers_len + body_len)
156        } else {
157            // If content-length is missing, length is 0
158            let body_len = content_len.unwrap_or(0);
159            let total = headers_len + body_len;
160            if buf.len() >= total {
161                self.body = &buf[headers_len..total];
162                Ok(total)
163            } else {
164                Err(WireError::IncompleteBody(total - buf.len()))
165            }
166        }
167    }
168
169    /// Parse using initialized headers (compatible with httparse::Request::parse).
170    pub fn parse(&mut self, buf: &'buf [u8]) -> Result<usize, WireError> {
171        match self.head.parse(buf) {
172            Ok(httparse::Status::Complete(headers_len)) => self.parse_core(buf, headers_len),
173            Ok(httparse::Status::Partial) => Err(WireError::PartialHead),
174            Err(err) => Err(err.into()),
175        }
176    }
177
178    /// Parse using uninitialized headers (optimized, uses parse_with_uninit_headers).
179    pub fn parse_uninit(
180        &mut self,
181        buf: &'buf [u8],
182        headers: &'headers mut [MaybeUninit<Header<'buf>>],
183    ) -> Result<usize, WireError> {
184        match self.head.parse_with_uninit_headers(buf, headers) {
185            Ok(httparse::Status::Complete(headers_len)) => self.parse_core(buf, headers_len),
186            Ok(httparse::Status::Partial) => Err(WireError::PartialHead),
187            Err(err) => Err(err.into()),
188        }
189    }
190}
191
192impl<'headers, 'buf> WireDecode<'headers, 'buf> for FullRequest<'headers, 'buf> {
193    fn decode(
194        buf: &'buf [u8],
195        headers: &'headers mut [Header<'buf>],
196    ) -> Result<(Self, usize), WireError> {
197        let mut full_request = FullRequest {
198            head: httparse::Request::new(headers),
199            body: &[],
200        };
201
202        let total = full_request.parse(buf)?;
203        Ok((full_request, total))
204    }
205
206    fn decode_uninit(
207        buf: &'buf [u8],
208        headers: &'headers mut [MaybeUninit<Header<'buf>>],
209    ) -> Result<(Self, usize), WireError> {
210        let mut full_request = FullRequest {
211            head: httparse::Request::new(&mut []),
212            body: &[],
213        };
214
215        let total = full_request.parse_uninit(buf, headers)?;
216        Ok((full_request, total))
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::{Empty, Full};

    #[test]
    fn test_request_sync_no_body() {
        let request = http::Request::builder()
            .method("GET")
            .uri("/api/test")
            .header("Host", "example.com")
            .body(Empty::<Bytes>::new())
            .unwrap();

        let bytes = request.encode().unwrap();
        let output = String::from_utf8_lossy(&bytes);

        assert!(output.contains("GET /api/test HTTP/1.1"));
        assert!(output.contains("host: example.com"));
    }

    #[test]
    fn test_request_sync_with_body() {
        let body = r#"{"test":"data"}"#;
        let request = http::Request::builder()
            .method("POST")
            .uri("/api/submit")
            .header("Host", "example.com")
            .header("Content-Type", "application/json")
            .body(Full::new(Bytes::from(body)))
            .unwrap();

        let bytes = request.encode().unwrap();
        let output = String::from_utf8_lossy(&bytes);

        assert!(output.contains("POST /api/submit HTTP/1.1"));
        assert!(output.contains(body));
    }

    #[test]
    fn test_request_sync_http2_rejected() {
        let request = http::Request::builder()
            .method("GET")
            .uri("/")
            .version(http::Version::HTTP_2)
            .body(Empty::<Bytes>::new())
            .unwrap();

        let result = request.encode();
        assert!(matches!(result, Err(WireError::UnsupportedVersion)));
    }

    #[tokio::test]
    async fn test_request_to_wire() {
        let request = http::Request::builder()
            .method("GET")
            .uri("/api/test")
            .header("Host", "example.com")
            .body(Empty::<Bytes>::new())
            .unwrap();

        let bytes = request.encode_async().await.unwrap();
        let output = String::from_utf8_lossy(&bytes);

        assert!(output.contains("GET /api/test HTTP/1.1"));
        assert!(output.contains("host: example.com"));
    }

    #[tokio::test]
    async fn test_request_with_body_to_wire() {
        let body = r#"{"test":"data"}"#;
        let request = http::Request::builder()
            .method("POST")
            .uri("/api/submit")
            .header("Host", "example.com")
            .header("Content-Type", "application/json")
            .body(Full::new(Bytes::from(body)))
            .unwrap();

        let bytes = request.encode_async().await.unwrap();
        let output = String::from_utf8_lossy(&bytes);

        assert!(output.contains("POST /api/submit HTTP/1.1"));
        assert!(output.contains(body));
    }

    #[tokio::test]
    async fn test_http2_request_rejected() {
        let request = http::Request::builder()
            .method("GET")
            .uri("/")
            .version(http::Version::HTTP_2)
            .body(Empty::<Bytes>::new())
            .unwrap();

        let result = request.encode_async().await;
        assert!(matches!(result, Err(WireError::UnsupportedVersion)));
    }

    #[test]
    fn test_decode_request_no_body() {
        let raw = b"GET /api/users HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let (req, total) = FullRequest::decode(raw, &mut headers).unwrap();
        assert_eq!(total, raw.len());
        assert_eq!(req.head.method, Some("GET"));
        assert_eq!(req.head.path, Some("/api/users"));
        assert!(req.body.is_empty());
    }

    #[test]
    fn test_decode_request_with_content_length() {
        let raw = b"POST /api/users HTTP/1.1\r\nHost: example.com\r\nContent-Length: 14\r\n\r\n{\"name\":\"foo\"}";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let (req, total) = FullRequest::decode(raw, &mut headers).unwrap();
        assert_eq!(total, raw.len());
        assert_eq!(req.body, b"{\"name\":\"foo\"}");
    }

    #[test]
    fn test_decode_request_content_length_with_trailing_data() {
        // A second (pipelined) request follows; only the first one must be consumed.
        let request = b"POST /a HTTP/1.1\r\ncontent-length: 5\r\n\r\nhello";
        let mut raw = request.to_vec();
        raw.extend_from_slice(b"GET /b HTTP/1.1\r\n\r\n");
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let (req, total) = FullRequest::decode(&raw, &mut headers).unwrap();
        assert_eq!(total, request.len());
        assert_eq!(req.body, b"hello");
    }

    #[test]
    fn test_decode_request_incomplete_body() {
        // Content-Length says 13, but body is only 5 bytes: 8 bytes are missing
        let raw =
            b"POST /api/users HTTP/1.1\r\nHost: example.com\r\nContent-Length: 13\r\n\r\nhello";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let result = FullRequest::decode(raw, &mut headers);
        assert!(matches!(result, Err(WireError::IncompleteBody(8))));
    }

    #[test]
    fn test_decode_request_incomplete_headers() {
        let raw = b"POST /api/users HTTP/1.1\r\nHost: example.com\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let result = FullRequest::decode(raw, &mut headers);
        assert!(matches!(result, Err(WireError::PartialHead)));
    }

    #[test]
    fn test_decode_request_chunked_encoding() {
        let raw = b"POST /api/data HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\n\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let result = FullRequest::decode(raw, &mut headers);
        assert!(result.is_ok());
    }

    #[test]
    fn test_decode_request_chunked_multiple_chunks() {
        let raw = b"POST /api/data HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let result = FullRequest::decode(raw, &mut headers);
        assert!(result.is_ok());
    }

    #[test]
    fn test_decode_request_chunked_incomplete() {
        // Missing final 0\r\n\r\n
        let raw = b"POST /api/data HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let result = FullRequest::decode(raw, &mut headers);
        assert!(matches!(result, Err(WireError::InvalidChunkedBody)));
    }

    #[test]
    fn test_decode_request_extra_data_after() {
        // Buffer has extra data after the request - should return correct length
        let request = b"GET /api/users HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let mut raw = request.to_vec();
        raw.extend_from_slice(b"extra garbage data");
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let (req, total) = FullRequest::decode(&raw, &mut headers).unwrap();
        assert_eq!(total, request.len());
        assert!(req.body.is_empty());
    }

    #[test]
    fn test_decode_request_chunked_case_insensitive() {
        let raw = b"POST /api/data HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: CHUNKED\r\n\r\n5\r\nhello\r\n0\r\n\r\n";
        let mut headers = [httparse::EMPTY_HEADER; 16];
        let result = FullRequest::decode(raw, &mut headers);
        assert!(result.is_ok());
    }

    #[test]
    fn test_decode_request_uninit_no_body() {
        let raw = b"GET /api/users HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let mut headers = [const { MaybeUninit::uninit() }; 16];
        let (req, total) = FullRequest::decode_uninit(raw, &mut headers).unwrap();
        assert_eq!(total, raw.len());
        assert_eq!(req.head.method, Some("GET"));
        assert!(req.body.is_empty());
    }

    #[test]
    fn test_decode_request_uninit_with_body() {
        let raw = b"POST /api/users HTTP/1.1\r\nHost: example.com\r\nContent-Length: 14\r\n\r\n{\"name\":\"foo\"}";
        let mut headers = [const { MaybeUninit::uninit() }; 16];
        let result = FullRequest::decode_uninit(raw, &mut headers);
        assert!(result.is_ok());
        let (req, total) = result.unwrap();
        assert_eq!(total, raw.len());
        assert_eq!(req.body, b"{\"name\":\"foo\"}");
    }
}