1use crate::{
2 Body, Request,
3 dep::{http_body, http_body_util::BodyExt},
4};
5use bytes::Bytes;
6use rama_core::error::BoxError;
7use rama_http_types::proto::{
8 h1::Http1HeaderMap,
9 h2::{PseudoHeader, PseudoHeaderOrder},
10};
11use tokio::io::{AsyncWrite, AsyncWriteExt};
12
13pub async fn write_http_request<W, B>(
15 w: &mut W,
16 req: Request<B>,
17 write_headers: bool,
18 write_body: bool,
19) -> Result<Request, BoxError>
20where
21 W: AsyncWrite + Unpin + Send + Sync + 'static,
22 B: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
23{
24 let (mut parts, body) = req.into_parts();
25
26 if write_headers {
27 w.write_all(
28 format!(
29 "{} {}{} {:?}\r\n",
30 parts.method,
31 parts.uri.path(),
32 parts
33 .uri
34 .query()
35 .map(|q| format!("?{}", q))
36 .unwrap_or_default(),
37 parts.version
38 )
39 .as_bytes(),
40 )
41 .await?;
42
43 if let Some(pseudo_headers) = parts.extensions.get::<PseudoHeaderOrder>() {
44 for header in pseudo_headers.iter() {
45 match header {
46 PseudoHeader::Method => {
47 w.write_all(format!("[{}: {}]\r\n", header, parts.method).as_bytes())
48 .await?;
49 }
50 PseudoHeader::Scheme => {
51 w.write_all(
52 format!(
53 "[{}: {}]\r\n",
54 header,
55 parts.uri.scheme_str().unwrap_or("?")
56 )
57 .as_bytes(),
58 )
59 .await?;
60 }
61 PseudoHeader::Authority => {
62 w.write_all(
63 format!(
64 "[{}: {}]\r\n",
65 header,
66 parts.uri.authority().map(|a| a.as_str()).unwrap_or("?")
67 )
68 .as_bytes(),
69 )
70 .await?;
71 }
72 PseudoHeader::Path => {
73 w.write_all(format!("[{}: {}]\r\n", header, parts.uri.path()).as_bytes())
74 .await?;
75 }
76 PseudoHeader::Protocol => (), PseudoHeader::Status => (), }
79 }
80 }
81
82 let header_map = Http1HeaderMap::new(parts.headers, Some(&mut parts.extensions));
83 parts.headers = header_map.clone().consume(&mut parts.extensions);
85
86 for (name, value) in header_map {
87 match parts.version {
88 rama_http_types::Version::HTTP_2 | rama_http_types::Version::HTTP_3 => {
89 w.write_all(
91 format!("{}: {}\r\n", name.header_name().as_str(), value.to_str()?)
92 .as_bytes(),
93 )
94 .await?;
95 }
96 _ => {
97 w.write_all(format!("{}: {}\r\n", name, value.to_str()?).as_bytes())
98 .await?;
99 }
100 }
101 }
102 }
103
104 let body = if write_body {
105 let body = body.collect().await.map_err(Into::into)?.to_bytes();
106 w.write_all(b"\r\n").await?;
107 if !body.is_empty() {
108 w.write_all(body.as_ref()).await?;
109 }
110 Body::from(body)
111 } else {
112 Body::new(body)
113 };
114
115 let req = Request::from_parts(parts, body);
116 Ok(req)
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `req` into an in-memory buffer with both headers and body
    /// enabled and returns the written bytes as a `String`.
    async fn render(req: Request<Body>) -> String {
        let mut out = Vec::new();
        write_http_request(&mut out, req, true, true).await.unwrap();
        String::from_utf8(out).unwrap()
    }

    #[tokio::test]
    async fn test_write_http_request_get() {
        let request = Request::builder()
            .method("GET")
            .uri("http://example.com")
            .body(Body::empty())
            .unwrap();

        let written = render(request).await;

        assert_eq!(written, "GET / HTTP/1.1\r\n\r\n");
    }

    #[tokio::test]
    async fn test_write_http_request_get_with_headers() {
        let request = Request::builder()
            .method("GET")
            .uri("http://example.com")
            .header("content-type", "text/plain")
            .header("user-agent", "test/0")
            .body(Body::empty())
            .unwrap();

        let written = render(request).await;

        assert_eq!(
            written,
            "GET / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\n"
        );
    }

    #[tokio::test]
    async fn test_write_http_request_get_with_headers_and_query() {
        let request = Request::builder()
            .method("GET")
            .uri("http://example.com?foo=bar")
            .header("content-type", "text/plain")
            .header("user-agent", "test/0")
            .body(Body::empty())
            .unwrap();

        let written = render(request).await;

        assert_eq!(
            written,
            "GET /?foo=bar HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\n"
        );
    }

    #[tokio::test]
    async fn test_write_http_request_post_with_headers_and_body() {
        let request = Request::builder()
            .method("POST")
            .uri("http://example.com")
            .header("content-type", "text/plain")
            .header("user-agent", "test/0")
            .body(Body::from("hello"))
            .unwrap();

        let written = render(request).await;

        assert_eq!(
            written,
            "POST / HTTP/1.1\r\ncontent-type: text/plain\r\nuser-agent: test/0\r\n\r\nhello"
        );
    }
}