rama_http/io/request.rs

1use crate::{
2    Body, Request,
3    dep::{http_body, http_body_util::BodyExt},
4};
5use bytes::Bytes;
6use rama_core::error::BoxError;
7use rama_http_types::proto::{
8    h1::Http1HeaderMap,
9    h2::{PseudoHeader, PseudoHeaderOrder},
10};
11use tokio::io::{AsyncWrite, AsyncWriteExt};
12
13/// Write an HTTP request to a writer in std http format.
14pub async fn write_http_request<W, B>(
15    w: &mut W,
16    req: Request<B>,
17    write_headers: bool,
18    write_body: bool,
19) -> Result<Request, BoxError>
20where
21    W: AsyncWrite + Unpin + Send + Sync + 'static,
22    B: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
23{
24    let (mut parts, body) = req.into_parts();
25
26    if write_headers {
27        w.write_all(
28            format!(
29                "{} {}{} {:?}\r\n",
30                parts.method,
31                parts.uri.path(),
32                parts
33                    .uri
34                    .query()
35                    .map(|q| format!("?{}", q))
36                    .unwrap_or_default(),
37                parts.version
38            )
39            .as_bytes(),
40        )
41        .await?;
42
43        if let Some(pseudo_headers) = parts.extensions.get::<PseudoHeaderOrder>() {
44            for header in pseudo_headers.iter() {
45                match header {
46                    PseudoHeader::Method => {
47                        w.write_all(format!("[{}: {}]\r\n", header, parts.method).as_bytes())
48                            .await?;
49                    }
50                    PseudoHeader::Scheme => {
51                        w.write_all(
52                            format!(
53                                "[{}: {}]\r\n",
54                                header,
55                                parts.uri.scheme_str().unwrap_or("?")
56                            )
57                            .as_bytes(),
58                        )
59                        .await?;
60                    }
61                    PseudoHeader::Authority => {
62                        w.write_all(
63                            format!(
64                                "[{}: {}]\r\n",
65                                header,
66                                parts.uri.authority().map(|a| a.as_str()).unwrap_or("?")
67                            )
68                            .as_bytes(),
69                        )
70                        .await?;
71                    }
72                    PseudoHeader::Path => {
73                        w.write_all(format!("[{}: {}]\r\n", header, parts.uri.path()).as_bytes())
74                            .await?;
75                    }
76                    PseudoHeader::Protocol => (), // TODO: move ext h2 protocol out of h2 proto core once we need this info
77                    PseudoHeader::Status => (),   // not expected in request
78                }
79            }
80        }
81
82        let header_map = Http1HeaderMap::new(parts.headers, Some(&mut parts.extensions));
83        // put a clone of this data back into parts as we don't really want to consume it, just trace it
84        parts.headers = header_map.clone().consume(&mut parts.extensions);
85
86        for (name, value) in header_map {
87            match parts.version {
88                rama_http_types::Version::HTTP_2 | rama_http_types::Version::HTTP_3 => {
89                    // write lower-case for H2/H3
90                    w.write_all(
91                        format!("{}: {}\r\n", name.header_name().as_str(), value.to_str()?)
92                            .as_bytes(),
93                    )
94                    .await?;
95                }
96                _ => {
97                    w.write_all(format!("{}: {}\r\n", name, value.to_str()?).as_bytes())
98                        .await?;
99                }
100            }
101        }
102    }
103
104    let body = if write_body {
105        let body = body.collect().await.map_err(Into::into)?.to_bytes();
106        w.write_all(b"\r\n").await?;
107        if !body.is_empty() {
108            w.write_all(body.as_ref()).await?;
109        }
110        Body::from(body)
111    } else {
112        Body::new(body)
113    };
114
115    let req = Request::from_parts(parts, body);
116    Ok(req)
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Run [`write_http_request`] into an in-memory buffer.
    ///
    /// Returns the written text together with the request handed back.
    async fn dump(req: Request, write_headers: bool, write_body: bool) -> (String, Request) {
        let mut buf = Vec::new();
        let req = write_http_request(&mut buf, req, write_headers, write_body)
            .await
            .unwrap();
        (String::from_utf8(buf).unwrap(), req)
    }

    /// Collect the body of a request returned by [`dump`].
    async fn body_bytes(req: Request) -> Bytes {
        req.into_body().collect().await.unwrap().to_bytes()
    }

    /// A small POST request with one header and a `hello` body.
    fn post_hello() -> Request {
        Request::builder()
            .method("POST")
            .uri("http://example.com")
            .header("content-type", "text/plain")
            .body(Body::from("hello"))
            .unwrap()
    }

    #[tokio::test]
    async fn test_write_http_request_get() {
        let req = Request::builder()
            .method("GET")
            .uri("http://example.com")
            .body(Body::empty())
            .unwrap();

        let (out, _) = dump(req, true, true).await;
        assert_eq!(out, "GET / HTTP/1.1\r\n\r\n");
    }

    #[tokio::test]
    async fn test_write_http_request_get_with_headers() {
        let req = Request::builder()
            .method("GET")
            .uri("http://example.com")
            .header("content-type", "text/plain")
            .header("user-agent", "test/0")
            .body(Body::empty())
            .unwrap();

        let (out, _) = dump(req, true, true).await;
        assert_eq!(
            out,
            "GET / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\n"
        );
    }

    #[tokio::test]
    async fn test_write_http_request_get_with_headers_and_query() {
        let req = Request::builder()
            .method("GET")
            .uri("http://example.com?foo=bar")
            .header("content-type", "text/plain")
            .header("user-agent", "test/0")
            .body(Body::empty())
            .unwrap();

        let (out, _) = dump(req, true, true).await;
        assert_eq!(
            out,
            "GET /?foo=bar HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\n"
        );
    }

    #[tokio::test]
    async fn test_write_http_request_post_with_headers_and_body() {
        let req = Request::builder()
            .method("POST")
            .uri("http://example.com")
            .header("content-type", "text/plain")
            .header("user-agent", "test/0")
            .body(Body::from("hello"))
            .unwrap();

        let (out, req) = dump(req, true, true).await;
        assert_eq!(
            out,
            "POST / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\nhello"
        );
        // The buffered body is handed back intact.
        assert_eq!(&body_bytes(req).await[..], b"hello");
    }

    #[tokio::test]
    async fn test_write_http_request_headers_only() {
        // Without the body, the blank line terminating the head is not written.
        let (out, req) = dump(post_hello(), true, false).await;
        assert_eq!(out, "POST / HTTP/1.1\r\ncontent-type: text/plain\r\n");
        // The body was not consumed.
        assert_eq!(&body_bytes(req).await[..], b"hello");
    }

    #[tokio::test]
    async fn test_write_http_request_body_only() {
        let (out, req) = dump(post_hello(), false, true).await;
        assert_eq!(out, "\r\nhello");
        assert_eq!(&body_bytes(req).await[..], b"hello");
    }

    #[tokio::test]
    async fn test_write_http_request_nothing() {
        let (out, req) = dump(post_hello(), false, false).await;
        assert_eq!(out, "");
        assert_eq!(&body_bytes(req).await[..], b"hello");
    }
}