futures_http/
body.rs

1use std::{fmt::Debug, task::Poll};
2
3use futures::{
4    io::{BufReader, Lines},
5    stream::{once, BoxStream},
6    AsyncBufReadExt, AsyncRead, AsyncReadExt, Stream, StreamExt,
7};
8use http::{
9    header::{CONTENT_LENGTH, TRANSFER_ENCODING},
10    HeaderMap,
11};
12
13#[derive(Debug, thiserror::Error)]
14pub enum BodyReaderError {
15    #[error("Parse CONTENT_LENGTH header with error: {0}")]
16    ParseContentLength(String),
17
18    #[error("Parse TRANSFER_ENCODING header with error: {0}")]
19    ParseTransferEncoding(String),
20
21    #[error("CONTENT_LENGTH or TRANSFER_ENCODING not found.")]
22    UnsporTransferEncoding,
23
24    #[error(transparent)]
25    Io(#[from] std::io::Error),
26}
27
28pub type BodyReaderResult<T> = Result<T, BodyReaderError>;
29
30/// The sender to send http body data to peer.
31pub struct BodyReader {
32    length: Option<usize>,
33    stream: BoxStream<'static, std::io::Result<Vec<u8>>>,
34}
35
36impl Debug for BodyReader {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "BodyReader, length={:?}", self.length)
39    }
40}
41
42impl From<Vec<u8>> for BodyReader {
43    fn from(value: Vec<u8>) -> Self {
44        Self {
45            length: Some(value.len()),
46            stream: Box::pin(once(async move { Ok(value) })),
47        }
48    }
49}
50
51impl From<&[u8]> for BodyReader {
52    fn from(value: &[u8]) -> Self {
53        value.to_owned().into()
54    }
55}
56
57impl From<&str> for BodyReader {
58    fn from(value: &str) -> Self {
59        value.as_bytes().into()
60    }
61}
62
63impl From<String> for BodyReader {
64    fn from(value: String) -> Self {
65        value.as_bytes().into()
66    }
67}
68
69impl BodyReader {
70    pub fn empty() -> Self {
71        BodyReader::from(vec![])
72    }
73    /// Create a new `BodySender` instance from `stream`
74    pub fn from_stream<S>(stream: S) -> Self
75    where
76        S: Stream<Item = std::io::Result<Vec<u8>>> + Send + Unpin + 'static,
77    {
78        Self {
79            length: None,
80            stream: Box::pin(stream),
81        }
82    }
83
84    /// Return true if the underlying data is a stream
85    pub fn len(&self) -> Option<usize> {
86        self.length
87    }
88
89    /// Parse headers and generate property `BodyReader`.
90    pub async fn parse<R>(headers: &HeaderMap, mut read: R) -> BodyReaderResult<Self>
91    where
92        R: AsyncRead + Unpin + Send + 'static,
93    {
94        // TRANSFER_ENCODING has higher priority
95        if let Some(transfer_encoding) = headers.get(TRANSFER_ENCODING) {
96            let transfer_encoding = transfer_encoding
97                .to_str()
98                .map_err(|err| BodyReaderError::ParseTransferEncoding(err.to_string()))?;
99
100            if transfer_encoding != "chunked" {
101                return Err(BodyReaderError::ParseTransferEncoding(format!(
102                    "Unsupport TRANSFER_ENCODING: {}",
103                    transfer_encoding
104                )));
105            }
106
107            return Ok(Self::from_stream(ChunkedBodyStream::from(read)));
108        }
109
110        if let Some(content_length) = headers.get(CONTENT_LENGTH) {
111            let content_length = content_length
112                .to_str()
113                .map_err(|err| BodyReaderError::ParseContentLength(err.to_string()))?;
114
115            let content_length = usize::from_str_radix(content_length, 10)
116                .map_err(|err| BodyReaderError::ParseContentLength(err.to_string()))?;
117
118            let mut buf = vec![0u8; content_length];
119
120            read.read_exact(&mut buf).await?;
121
122            return Ok(buf.into());
123        }
124
125        Ok(Self::from(vec![]))
126    }
127}
128
129impl Stream for BodyReader {
130    type Item = std::io::Result<Vec<u8>>;
131
132    fn poll_next(
133        mut self: std::pin::Pin<&mut Self>,
134        cx: &mut std::task::Context<'_>,
135    ) -> std::task::Poll<Option<Self::Item>> {
136        self.stream.poll_next_unpin(cx)
137    }
138}
139
140struct ChunkedBodyStream<R> {
141    lines: Lines<BufReader<R>>,
142    chunk_len: Option<usize>,
143}
144
145impl<R> From<R> for ChunkedBodyStream<R>
146where
147    R: AsyncRead + Unpin,
148{
149    fn from(value: R) -> Self {
150        Self {
151            lines: BufReader::new(value).lines(),
152            chunk_len: None,
153        }
154    }
155}
156
157impl<R> Stream for ChunkedBodyStream<R>
158where
159    R: AsyncRead + Unpin,
160{
161    type Item = std::io::Result<Vec<u8>>;
162
163    fn poll_next(
164        mut self: std::pin::Pin<&mut Self>,
165        cx: &mut std::task::Context<'_>,
166    ) -> std::task::Poll<Option<Self::Item>> {
167        loop {
168            if let Some(mut len) = self.chunk_len {
169                match self.lines.poll_next_unpin(cx) {
170                    Poll::Ready(Some(Ok(buf))) => {
171                        if buf.len() > len {
172                            return Poll::Ready(Some(Err(std::io::Error::new(
173                                std::io::ErrorKind::InvalidData,
174                                "chunck data overflow",
175                            ))));
176                        }
177
178                        len -= buf.len();
179
180                        if len == 0 {
181                            self.chunk_len.take();
182                        } else {
183                            self.chunk_len = Some(len);
184                        }
185
186                        return Poll::Ready(Some(Ok(buf.into_bytes())));
187                    }
188                    poll => return poll.map_ok(|s| s.into_bytes()),
189                }
190            } else {
191                match self.lines.poll_next_unpin(cx) {
192                    Poll::Ready(Some(Ok(line))) => match usize::from_str_radix(&line, 16) {
193                        Ok(len) => {
194                            // body last chunk.
195                            if len == 0 {
196                                return Poll::Ready(None);
197                            }
198
199                            self.chunk_len = Some(len);
200                            continue;
201                        }
202                        Err(err) => {
203                            return Poll::Ready(Some(Err(std::io::Error::new(
204                                std::io::ErrorKind::InvalidData,
205                                format!("Parse chunck length with error: {}", err),
206                            ))))
207                        }
208                    },
209                    poll => return poll.map_ok(|s| s.into_bytes()),
210                }
211            }
212        }
213    }
214}