1use std::{
2 future::Future,
3 io::{Error, ErrorKind, Result},
4};
5
6use futures::{AsyncWrite, AsyncWriteExt, TryStreamExt};
7use http::{
8 header::{InvalidHeaderValue, ToStrError, CONTENT_LENGTH, HOST, TRANSFER_ENCODING},
9 HeaderValue, Request, Response,
10};
11
12use crate::body::BodyReader;
13
14fn map_to_str_error(err: ToStrError) -> Error {
15 Error::new(ErrorKind::InvalidData, err)
16}
17
18fn map_invalid_header_value_error(err: InvalidHeaderValue) -> Error {
19 Error::new(ErrorKind::InvalidData, err)
20}
21
22pub trait HttpWriter: AsyncWrite + Unpin {
23 fn write_request(&mut self, request: Request<BodyReader>) -> impl Future<Output = Result<()>> {
24 async move {
25 let (mut parts, mut body) = request.into_parts();
26
27 self.write_all(
28 format!(
29 "{} {} {:?}\r\n",
30 parts.method,
31 parts.uri.path(),
32 parts.version
33 )
34 .as_bytes(),
35 )
36 .await?;
37
38 if parts.headers.get(HOST).is_none() {
39 if let Some(host) = parts.uri.host() {
40 parts.headers.insert(
41 HOST,
42 HeaderValue::from_str(host).map_err(map_invalid_header_value_error)?,
43 );
44 }
45 }
46
47 if parts.headers.get(CONTENT_LENGTH).is_none()
48 && parts.headers.get(TRANSFER_ENCODING).is_none()
49 {
50 if let Some(len) = body.len() {
51 parts.headers.insert(CONTENT_LENGTH, len.into());
52 } else {
53 parts
54 .headers
55 .insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
56 }
57 }
58
59 for (name, value) in &parts.headers {
60 self.write_all(
61 format!(
62 "{}: {}\r\n",
63 name,
64 value.to_str().map_err(map_to_str_error)?
65 )
66 .as_bytes(),
67 )
68 .await?;
69 }
70
71 self.write_all(b"\r\n").await?;
72
73 if let Some(len) = body.len() {
74 let body = body
75 .try_next()
76 .await?
77 .expect("Ready fixed length body error");
78
79 assert_eq!(body.len(), len);
80
81 self.write_all(&body).await?;
82 } else {
83 while let Some(chunk) = body.try_next().await? {
84 self.write_all(format!("{:x}\r\n", chunk.len()).as_bytes())
85 .await?;
86
87 self.write_all(&chunk).await?;
88 }
89
90 self.write_all(b"0\r\n").await?;
91 }
92
93 Ok(())
94 }
95 }
96
97 fn write_response(
98 &mut self,
99 response: Response<BodyReader>,
100 ) -> impl Future<Output = Result<()>> {
101 async move {
102 let (parts, mut body) = response.into_parts();
103
104 self.write_all(format!("{:?} {}\r\n", parts.version, parts.status).as_bytes())
106 .await?;
107
108 for (name, value) in &parts.headers {
109 self.write_all(
110 format!(
111 "{}: {}\r\n",
112 name,
113 value.to_str().map_err(map_to_str_error)?
114 )
115 .as_bytes(),
116 )
117 .await?;
118 }
119
120 if let Some(len) = body.len() {
121 self.write_all(format!("{}: {}\r\n", CONTENT_LENGTH, len).as_bytes())
122 .await?;
123
124 self.write_all(b"\r\n").await?;
125
126 let body = body
127 .try_next()
128 .await?
129 .expect("Ready fixed length body error");
130
131 self.write_all(&body).await?;
132 } else {
133 self.write_all(format!("{}: chunked\r\n", TRANSFER_ENCODING).as_bytes())
134 .await?;
135
136 self.write_all(b"\r\n").await?;
137
138 while let Some(chunk) = body.try_next().await? {
139 self.write_all(format!("{:x}\r\n", chunk.len()).as_bytes())
140 .await?;
141
142 self.write_all(&chunk).await?;
143 }
144
145 self.write_all(b"0\r\n\r\n").await?;
146 }
147
148 Ok(())
149 }
150 }
151}
152
153impl<T: AsyncWrite + Unpin> HttpWriter for T {}