1use std::{fmt::Debug, task::Poll};
2
3use futures::{
4 io::{BufReader, Lines},
5 stream::{once, BoxStream},
6 AsyncBufReadExt, AsyncRead, AsyncReadExt, Stream, StreamExt,
7};
8use http::{
9 header::{CONTENT_LENGTH, TRANSFER_ENCODING},
10 HeaderMap,
11};
12
13#[derive(Debug, thiserror::Error)]
14pub enum BodyReaderError {
15 #[error("Parse CONTENT_LENGTH header with error: {0}")]
16 ParseContentLength(String),
17
18 #[error("Parse TRANSFER_ENCODING header with error: {0}")]
19 ParseTransferEncoding(String),
20
21 #[error("CONTENT_LENGTH or TRANSFER_ENCODING not found.")]
22 UnsporTransferEncoding,
23
24 #[error(transparent)]
25 Io(#[from] std::io::Error),
26}
27
28pub type BodyReaderResult<T> = Result<T, BodyReaderError>;
29
30pub struct BodyReader {
32 length: Option<usize>,
33 stream: BoxStream<'static, std::io::Result<Vec<u8>>>,
34}
35
36impl Debug for BodyReader {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 write!(f, "BodyReader, length={:?}", self.length)
39 }
40}
41
42impl From<Vec<u8>> for BodyReader {
43 fn from(value: Vec<u8>) -> Self {
44 Self {
45 length: Some(value.len()),
46 stream: Box::pin(once(async move { Ok(value) })),
47 }
48 }
49}
50
51impl From<&[u8]> for BodyReader {
52 fn from(value: &[u8]) -> Self {
53 value.to_owned().into()
54 }
55}
56
57impl From<&str> for BodyReader {
58 fn from(value: &str) -> Self {
59 value.as_bytes().into()
60 }
61}
62
63impl From<String> for BodyReader {
64 fn from(value: String) -> Self {
65 value.as_bytes().into()
66 }
67}
68
69impl BodyReader {
70 pub fn empty() -> Self {
71 BodyReader::from(vec![])
72 }
73 pub fn from_stream<S>(stream: S) -> Self
75 where
76 S: Stream<Item = std::io::Result<Vec<u8>>> + Send + Unpin + 'static,
77 {
78 Self {
79 length: None,
80 stream: Box::pin(stream),
81 }
82 }
83
84 pub fn len(&self) -> Option<usize> {
86 self.length
87 }
88
89 pub async fn parse<R>(headers: &HeaderMap, mut read: R) -> BodyReaderResult<Self>
91 where
92 R: AsyncRead + Unpin + Send + 'static,
93 {
94 if let Some(transfer_encoding) = headers.get(TRANSFER_ENCODING) {
96 let transfer_encoding = transfer_encoding
97 .to_str()
98 .map_err(|err| BodyReaderError::ParseTransferEncoding(err.to_string()))?;
99
100 if transfer_encoding != "chunked" {
101 return Err(BodyReaderError::ParseTransferEncoding(format!(
102 "Unsupport TRANSFER_ENCODING: {}",
103 transfer_encoding
104 )));
105 }
106
107 return Ok(Self::from_stream(ChunkedBodyStream::from(read)));
108 }
109
110 if let Some(content_length) = headers.get(CONTENT_LENGTH) {
111 let content_length = content_length
112 .to_str()
113 .map_err(|err| BodyReaderError::ParseContentLength(err.to_string()))?;
114
115 let content_length = usize::from_str_radix(content_length, 10)
116 .map_err(|err| BodyReaderError::ParseContentLength(err.to_string()))?;
117
118 let mut buf = vec![0u8; content_length];
119
120 read.read_exact(&mut buf).await?;
121
122 return Ok(buf.into());
123 }
124
125 Ok(Self::from(vec![]))
126 }
127}
128
129impl Stream for BodyReader {
130 type Item = std::io::Result<Vec<u8>>;
131
132 fn poll_next(
133 mut self: std::pin::Pin<&mut Self>,
134 cx: &mut std::task::Context<'_>,
135 ) -> std::task::Poll<Option<Self::Item>> {
136 self.stream.poll_next_unpin(cx)
137 }
138}
139
140struct ChunkedBodyStream<R> {
141 lines: Lines<BufReader<R>>,
142 chunk_len: Option<usize>,
143}
144
145impl<R> From<R> for ChunkedBodyStream<R>
146where
147 R: AsyncRead + Unpin,
148{
149 fn from(value: R) -> Self {
150 Self {
151 lines: BufReader::new(value).lines(),
152 chunk_len: None,
153 }
154 }
155}
156
157impl<R> Stream for ChunkedBodyStream<R>
158where
159 R: AsyncRead + Unpin,
160{
161 type Item = std::io::Result<Vec<u8>>;
162
163 fn poll_next(
164 mut self: std::pin::Pin<&mut Self>,
165 cx: &mut std::task::Context<'_>,
166 ) -> std::task::Poll<Option<Self::Item>> {
167 loop {
168 if let Some(mut len) = self.chunk_len {
169 match self.lines.poll_next_unpin(cx) {
170 Poll::Ready(Some(Ok(buf))) => {
171 if buf.len() > len {
172 return Poll::Ready(Some(Err(std::io::Error::new(
173 std::io::ErrorKind::InvalidData,
174 "chunck data overflow",
175 ))));
176 }
177
178 len -= buf.len();
179
180 if len == 0 {
181 self.chunk_len.take();
182 } else {
183 self.chunk_len = Some(len);
184 }
185
186 return Poll::Ready(Some(Ok(buf.into_bytes())));
187 }
188 poll => return poll.map_ok(|s| s.into_bytes()),
189 }
190 } else {
191 match self.lines.poll_next_unpin(cx) {
192 Poll::Ready(Some(Ok(line))) => match usize::from_str_radix(&line, 16) {
193 Ok(len) => {
194 if len == 0 {
196 return Poll::Ready(None);
197 }
198
199 self.chunk_len = Some(len);
200 continue;
201 }
202 Err(err) => {
203 return Poll::Ready(Some(Err(std::io::Error::new(
204 std::io::ErrorKind::InvalidData,
205 format!("Parse chunck length with error: {}", err),
206 ))))
207 }
208 },
209 poll => return poll.map_ok(|s| s.into_bytes()),
210 }
211 }
212 }
213 }
214}