rama_http/io/response.rs

1use crate::{
2    dep::{http_body, http_body_util::BodyExt},
3    Body, Response,
4};
5use bytes::Bytes;
6use rama_core::error::BoxError;
7use rama_http_types::proto::{
8    h1::Http1HeaderMap,
9    h2::{PseudoHeader, PseudoHeaderOrder},
10};
11use tokio::io::{AsyncWrite, AsyncWriteExt};
12
13/// Write an HTTP response to a writer in std http format.
14pub async fn write_http_response<W, B>(
15    w: &mut W,
16    res: Response<B>,
17    write_headers: bool,
18    write_body: bool,
19) -> Result<Response, BoxError>
20where
21    W: AsyncWrite + Unpin + Send + Sync + 'static,
22    B: http_body::Body<Data = Bytes, Error: Into<BoxError>> + Send + Sync + 'static,
23{
24    let (mut parts, body) = res.into_parts();
25
26    if write_headers {
27        w.write_all(
28            format!(
29                "{:?} {}{}\r\n",
30                parts.version,
31                parts.status.as_u16(),
32                parts
33                    .status
34                    .canonical_reason()
35                    .map(|r| format!(" {}", r))
36                    .unwrap_or_default(),
37            )
38            .as_bytes(),
39        )
40        .await?;
41
42        if let Some(pseudo_headers) = parts.extensions.get::<PseudoHeaderOrder>() {
43            for header in pseudo_headers.iter() {
44                match header {
45                    PseudoHeader::Method
46                    | PseudoHeader::Scheme
47                    | PseudoHeader::Authority
48                    | PseudoHeader::Path
49                    | PseudoHeader::Protocol => (), // not expected in response
50                    PseudoHeader::Status => {
51                        w.write_all(
52                            format!(
53                                "[{}: {} {}]\r\n",
54                                header,
55                                parts.status.as_u16(),
56                                parts
57                                    .status
58                                    .canonical_reason()
59                                    .map(|r| format!(" {}", r))
60                                    .unwrap_or_default(),
61                            )
62                            .as_bytes(),
63                        )
64                        .await?;
65                    }
66                }
67            }
68        }
69
70        let header_map = Http1HeaderMap::new(parts.headers, Some(&mut parts.extensions));
71        // put a clone of this data back into parts as we don't really want to consume it, just trace it
72        parts.headers = header_map.clone().consume(&mut parts.extensions);
73
74        for (name, value) in header_map {
75            match parts.version {
76                http::Version::HTTP_2 | http::Version::HTTP_3 => {
77                    // write lower-case for H2/H3
78                    w.write_all(
79                        format!("{}: {}\r\n", name.header_name().as_str(), value.to_str()?)
80                            .as_bytes(),
81                    )
82                    .await?;
83                }
84                _ => {
85                    w.write_all(format!("{}: {}\r\n", name, value.to_str()?).as_bytes())
86                        .await?;
87                }
88            }
89        }
90    }
91
92    let body = if write_body {
93        let body = body.collect().await.map_err(Into::into)?.to_bytes();
94        w.write_all(b"\r\n").await?;
95        if !body.is_empty() {
96            w.write_all(body.as_ref()).await?;
97        }
98        Body::from(body)
99    } else {
100        Body::new(body)
101    };
102
103    let req = Response::from_parts(parts, body);
104    Ok(req)
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes the full response (head and body) and decodes the output as UTF-8.
    async fn render(res: Response) -> String {
        let mut out: Vec<u8> = Vec::new();
        write_http_response(&mut out, res, true, true)
            .await
            .unwrap();
        String::from_utf8(out).unwrap()
    }

    #[tokio::test]
    async fn test_write_response_ok() {
        let res = Response::builder().status(200).body(Body::empty()).unwrap();

        assert_eq!(render(res).await, "HTTP/1.1 200 OK\r\n\r\n");
    }

    #[tokio::test]
    async fn test_write_response_redirect() {
        let res = Response::builder()
            .status(301)
            .header("location", "http://example.com")
            .header("server", "test/0")
            .body(Body::empty())
            .unwrap();

        let expected = concat!(
            "HTTP/1.1 301 Moved Permanently\r\n",
            "location: http://example.com\r\n",
            "server: test/0\r\n",
            "\r\n",
        );
        assert_eq!(render(res).await, expected);
    }

    #[tokio::test]
    async fn test_write_response_with_headers_and_body() {
        let res = Response::builder()
            .status(200)
            .header("content-type", "text/plain")
            .header("server", "test/0")
            .body(Body::from("hello"))
            .unwrap();

        let expected = concat!(
            "HTTP/1.1 200 OK\r\n",
            "content-type: text/plain\r\n",
            "server: test/0\r\n",
            "\r\n",
            "hello",
        );
        assert_eq!(render(res).await, expected);
    }
}