1use crate::{
2 dep::{http_body, http_body_util::BodyExt},
3 Body, Response,
4};
5use bytes::Bytes;
6use rama_core::error::BoxError;
7use rama_http_types::proto::{
8 h1::Http1HeaderMap,
9 h2::{PseudoHeader, PseudoHeaderOrder},
10};
11use tokio::io::{AsyncWrite, AsyncWriteExt};
12
13pub async fn write_http_response<W, B>(
15 w: &mut W,
16 res: Response<B>,
17 write_headers: bool,
18 write_body: bool,
19) -> Result<Response, BoxError>
20where
21 W: AsyncWrite + Unpin + Send + Sync + 'static,
22 B: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
23{
24 let (mut parts, body) = res.into_parts();
25
26 if write_headers {
27 w.write_all(
28 format!(
29 "{:?} {}{}\r\n",
30 parts.version,
31 parts.status.as_u16(),
32 parts
33 .status
34 .canonical_reason()
35 .map(|r| format!(" {}", r))
36 .unwrap_or_default(),
37 )
38 .as_bytes(),
39 )
40 .await?;
41
42 if let Some(pseudo_headers) = parts.extensions.get::<PseudoHeaderOrder>() {
43 for header in pseudo_headers.iter() {
44 match header {
45 PseudoHeader::Method
46 | PseudoHeader::Scheme
47 | PseudoHeader::Authority
48 | PseudoHeader::Path
49 | PseudoHeader::Protocol => (), PseudoHeader::Status => {
51 w.write_all(
52 format!(
53 "[{}: {} {}]\r\n",
54 header,
55 parts.status.as_u16(),
56 parts
57 .status
58 .canonical_reason()
59 .map(|r| format!(" {}", r))
60 .unwrap_or_default(),
61 )
62 .as_bytes(),
63 )
64 .await?;
65 }
66 }
67 }
68 }
69
70 let header_map = Http1HeaderMap::new(parts.headers, Some(&mut parts.extensions));
71 parts.headers = header_map.clone().consume(&mut parts.extensions);
73
74 for (name, value) in header_map {
75 match parts.version {
76 http::Version::HTTP_2 | http::Version::HTTP_3 => {
77 w.write_all(
79 format!("{}: {}\r\n", name.header_name().as_str(), value.to_str()?)
80 .as_bytes(),
81 )
82 .await?;
83 }
84 _ => {
85 w.write_all(format!("{}: {}\r\n", name, value.to_str()?).as_bytes())
86 .await?;
87 }
88 }
89 }
90 }
91
92 let body = if write_body {
93 let body = body.collect().await.map_err(Into::into)?.to_bytes();
94 w.write_all(b"\r\n").await?;
95 if !body.is_empty() {
96 w.write_all(body.as_ref()).await?;
97 }
98 Body::from(body)
99 } else {
100 Body::new(body)
101 };
102
103 let req = Response::from_parts(parts, body);
104 Ok(req)
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs [`write_http_response`] into an in-memory buffer and returns the
    /// written text together with the response handed back by the writer.
    async fn render(res: Response, write_headers: bool, write_body: bool) -> (String, Response) {
        let mut buf = Vec::new();
        let res = write_http_response(&mut buf, res, write_headers, write_body)
            .await
            .unwrap();
        (String::from_utf8(buf).unwrap(), res)
    }

    #[tokio::test]
    async fn test_write_response_ok() {
        let res = Response::builder().status(200).body(Body::empty()).unwrap();

        let (out, _) = render(res, true, true).await;

        assert_eq!(out, "HTTP/1.1 200 OK\r\n\r\n");
    }

    #[tokio::test]
    async fn test_write_response_redirect() {
        let res = Response::builder()
            .status(301)
            .header("location", "http://example.com")
            .header("server", "test/0")
            .body(Body::empty())
            .unwrap();

        let (out, _) = render(res, true, true).await;

        assert_eq!(
            out,
            "HTTP/1.1 301 Moved Permanently\r\nlocation: http://example.com\r\nserver: test/0\r\n\r\n"
        );
    }

    #[tokio::test]
    async fn test_write_response_with_headers_and_body() {
        let res = Response::builder()
            .status(200)
            .header("content-type", "text/plain")
            .header("server", "test/0")
            .body(Body::from("hello"))
            .unwrap();

        let (out, _) = render(res, true, true).await;

        assert_eq!(
            out,
            "HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\nserver: test/0\r\n\r\nhello"
        );
    }

    /// A status without a canonical reason phrase must not leave a trailing space.
    #[tokio::test]
    async fn test_write_response_status_without_reason() {
        let res = Response::builder().status(599).body(Body::empty()).unwrap();

        let (out, _) = render(res, true, true).await;

        assert_eq!(out, "HTTP/1.1 599\r\n\r\n");
    }

    /// Headers-only mode writes no body and hands the original body back untouched.
    #[tokio::test]
    async fn test_write_response_headers_only_keeps_body() {
        let res = Response::builder()
            .status(200)
            .header("content-type", "text/plain")
            .body(Body::from("hello"))
            .unwrap();

        let (out, res) = render(res, true, false).await;

        assert_eq!(out, "HTTP/1.1 200 OK\r\ncontent-type: text/plain\r\n");
        let body = res.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, Bytes::from_static(b"hello"));
    }
}