async_http1_lite/
encoder.rs

1use core::{
2    marker::PhantomData,
3    ops::{Deref, DerefMut},
4    time::Duration,
5};
6use std::io::{Error as IoError, ErrorKind as IoErrorKind};
7
8use async_sleep::{rw::AsyncWriteWithTimeoutExt as _, Sleepble};
9use async_trait::async_trait;
10use futures_io::AsyncWrite;
11use http::{
12    header::{CONTENT_LENGTH, TRANSFER_ENCODING},
13    request::Parts as RequestParts,
14    response::Parts as ResponseParts,
15    HeaderMap, HeaderValue, Request, Response, Version,
16};
17use http1_spec::{
18    body_framing::BodyFraming,
19    head_renderer::{Head, HeadRenderer},
20    request_head_renderer::RequestHeadRenderer,
21    response_head_renderer::ResponseHeadRenderer,
22    ReasonPhrase, CHUNKED,
23};
24
25use crate::{body::EncoderBody, stream::Http1StreamEncoder};
26
27//
28//
29//
30pub struct Http1Encoder<H, HR>
31where
32    H: Head,
33    HR: HeadRenderer<H>,
34{
35    head_renderer: HR,
36    buf: Vec<u8>,
37    write_timeout: Duration,
38    state: State,
39    phantom: PhantomData<H>,
40}
41#[derive(Debug, PartialEq, Eq)]
42enum State {
43    Idle,
44    WriteBody(BodyFraming),
45}
46impl Default for State {
47    fn default() -> Self {
48        Self::Idle
49    }
50}
51impl<H, HR> Http1Encoder<H, HR>
52where
53    H: Head,
54    HR: HeadRenderer<H>,
55{
56    //
57    fn new(buf_capacity: usize) -> Self {
58        Self {
59            head_renderer: HR::new(),
60            buf: Vec::with_capacity(buf_capacity),
61            write_timeout: Duration::from_secs(5),
62            state: Default::default(),
63            phantom: PhantomData,
64        }
65    }
66
67    //
68    fn set_write_timeout(&mut self, dur: Duration) {
69        self.write_timeout = dur;
70    }
71
72    //
73    fn update_headers(
74        &self,
75        headers: &mut HeaderMap<HeaderValue>,
76        version: &Version,
77        body_framing: &BodyFraming,
78    ) -> Result<(), IoError> {
79        match body_framing {
80            BodyFraming::Neither => {
81                headers.remove(CONTENT_LENGTH);
82                headers.remove(TRANSFER_ENCODING);
83            }
84            BodyFraming::ContentLength(n) => {
85                if n == &0 {
86                    headers.remove(CONTENT_LENGTH);
87                    headers.remove(TRANSFER_ENCODING);
88                } else {
89                    headers.insert(
90                        CONTENT_LENGTH,
91                        HeaderValue::from_str(&format!("{n}"))
92                            .map_err(|err| IoError::new(IoErrorKind::Other, err))?,
93                    );
94                    if version == &Version::HTTP_11 {
95                        if let Some(header_value) = headers.get(&TRANSFER_ENCODING) {
96                            if header_value == CHUNKED {
97                                headers.remove(TRANSFER_ENCODING);
98                            }
99                        }
100                    }
101                }
102            }
103            BodyFraming::Chunked => {
104                if version != &Version::HTTP_11 {
105                    return Err(IoError::new(IoErrorKind::InvalidInput, "unimplemented now"));
106                }
107                headers.remove(CONTENT_LENGTH);
108                headers.insert(
109                    TRANSFER_ENCODING,
110                    HeaderValue::from_str(CHUNKED)
111                        .map_err(|err| IoError::new(IoErrorKind::Other, err))?,
112                );
113            }
114        }
115
116        Ok(())
117    }
118
119    fn encode_head(&mut self, head: H) -> Result<(), IoError> {
120        self.head_renderer.render(head, &mut self.buf)
121    }
122
123    async fn write_head0<S: AsyncWrite + Unpin, SLEEP: Sleepble>(
124        &self,
125        stream: &mut S,
126    ) -> Result<(), IoError> {
127        let mut n_write = 0;
128        while !self.buf[n_write..].is_empty() {
129            let n = stream
130                .write_with_timeout::<SLEEP>(&self.buf[n_write..], self.write_timeout)
131                .await?;
132            n_write += n;
133
134            if n == 0 {
135                return Err(IoErrorKind::WriteZero.into());
136            }
137        }
138        Ok(())
139    }
140
141    async fn write_body0<S: AsyncWrite + Unpin, SLEEP: Sleepble>(
142        &mut self,
143        stream: &mut S,
144        body: EncoderBody,
145    ) -> Result<(), IoError> {
146        match &mut self.state {
147            State::Idle => {
148                return Err(IoError::new(
149                    IoErrorKind::Other,
150                    "state should is WriteBody",
151                ));
152            }
153            State::WriteBody(body_framing) => match body_framing.clone() {
154                BodyFraming::Neither => {}
155                BodyFraming::ContentLength(content_length) => {
156                    if content_length == 0 {
157                        return Ok(());
158                    }
159
160                    let bytes = match &body {
161                        EncoderBody::Completed(bytes) => {
162                            if bytes.len() != content_length {
163                                return Err(IoError::new(
164                                    IoErrorKind::InvalidInput,
165                                    "bytes len mismatch",
166                                ));
167                            }
168                            bytes
169                        }
170                        EncoderBody::Partial(bytes) => {
171                            if bytes.len() >= content_length {
172                                return Err(IoError::new(
173                                    IoErrorKind::InvalidInput,
174                                    "bytes len mismatch",
175                                ));
176                            }
177                            bytes
178                        }
179                    };
180                    let bytes_len = bytes.len();
181
182                    let mut n_write = 0;
183                    while !bytes[n_write..].is_empty() {
184                        let n = stream
185                            .write_with_timeout::<SLEEP>(&bytes[n_write..], self.write_timeout)
186                            .await?;
187                        n_write += n;
188
189                        if n == 0 {
190                            return Err(IoErrorKind::WriteZero.into());
191                        }
192                    }
193
194                    match &body {
195                        EncoderBody::Completed(_) => {
196                            self.state = State::Idle;
197                        }
198                        EncoderBody::Partial(_) => {
199                            body_framing.update_content_length_value(content_length - bytes_len)?;
200                        }
201                    };
202                }
203                BodyFraming::Chunked => {
204                    return Err(IoError::new(IoErrorKind::InvalidInput, "unimplemented now"))
205                }
206            },
207        }
208
209        Ok(())
210    }
211}
212
213//
214//
215//
216pub type Http1RequestEncoderInner = Http1Encoder<RequestParts, RequestHeadRenderer>;
217pub struct Http1RequestEncoder {
218    inner: Http1RequestEncoderInner,
219}
220impl Deref for Http1RequestEncoder {
221    type Target = Http1RequestEncoderInner;
222
223    fn deref(&self) -> &Http1RequestEncoderInner {
224        &self.inner
225    }
226}
227impl DerefMut for Http1RequestEncoder {
228    fn deref_mut(&mut self) -> &mut Http1RequestEncoderInner {
229        &mut self.inner
230    }
231}
232impl Http1RequestEncoder {
233    pub fn new(buf_capacity: usize) -> Self {
234        Self {
235            inner: Http1RequestEncoderInner::new(buf_capacity),
236        }
237    }
238}
239
240#[async_trait]
241impl<S, SLEEP> Http1StreamEncoder<S, SLEEP, Request<()>> for Http1RequestEncoder
242where
243    S: AsyncWrite + Unpin + Send,
244    SLEEP: Sleepble,
245{
246    async fn write_head(
247        &mut self,
248        stream: &mut S,
249        head: Request<()>,
250        body_framing: BodyFraming,
251    ) -> Result<(), IoError> {
252        if self.state != State::Idle {
253            return Err(IoError::new(IoErrorKind::Other, "state should is Idle"));
254        }
255
256        self.buf.clear();
257
258        let (mut parts, _) = head.into_parts();
259
260        self.update_headers(&mut parts.headers, &parts.version, &body_framing)?;
261
262        self.encode_head(parts)?;
263
264        self.write_head0::<_, SLEEP>(stream).await?;
265
266        match body_framing {
267            BodyFraming::Neither => {
268                self.state = State::Idle;
269            }
270            BodyFraming::ContentLength(n) if n == 0 => {
271                self.state = State::Idle;
272            }
273            _ => {
274                self.state = State::WriteBody(body_framing);
275            }
276        }
277
278        Ok(())
279    }
280    async fn write_body(&mut self, stream: &mut S, body: EncoderBody) -> Result<(), IoError> {
281        self.write_body0::<_, SLEEP>(stream, body).await
282    }
283
284    fn set_write_timeout(&mut self, dur: Duration) {
285        self.inner.set_write_timeout(dur)
286    }
287}
288
289//
290//
291//
292pub type Http1ResponseEncoderInner =
293    Http1Encoder<(ResponseParts, ReasonPhrase), ResponseHeadRenderer>;
294pub struct Http1ResponseEncoder {
295    inner: Http1ResponseEncoderInner,
296}
297impl Deref for Http1ResponseEncoder {
298    type Target = Http1ResponseEncoderInner;
299
300    fn deref(&self) -> &Http1ResponseEncoderInner {
301        &self.inner
302    }
303}
304impl DerefMut for Http1ResponseEncoder {
305    fn deref_mut(&mut self) -> &mut Http1ResponseEncoderInner {
306        &mut self.inner
307    }
308}
309impl Http1ResponseEncoder {
310    pub fn new(buf_capacity: usize) -> Self {
311        Self {
312            inner: Http1ResponseEncoderInner::new(buf_capacity),
313        }
314    }
315}
316
317#[async_trait]
318impl<S, SLEEP> Http1StreamEncoder<S, SLEEP, (Response<()>, ReasonPhrase)> for Http1ResponseEncoder
319where
320    S: AsyncWrite + Unpin + Send,
321    SLEEP: Sleepble,
322{
323    async fn write_head(
324        &mut self,
325        stream: &mut S,
326        head: (Response<()>, ReasonPhrase),
327        body_framing: BodyFraming,
328    ) -> Result<(), IoError> {
329        if self.state != State::Idle {
330            return Err(IoError::new(IoErrorKind::Other, "state should is Idle"));
331        }
332
333        self.buf.clear();
334
335        let (head, reason_phrase) = head;
336        let (mut parts, _) = head.into_parts();
337
338        self.update_headers(&mut parts.headers, &parts.version, &body_framing)?;
339
340        self.encode_head((parts, reason_phrase))?;
341
342        self.write_head0::<_, SLEEP>(stream).await?;
343
344        match body_framing {
345            BodyFraming::Neither => {
346                self.state = State::Idle;
347            }
348            BodyFraming::ContentLength(n) if n == 0 => {
349                self.state = State::Idle;
350            }
351            _ => {
352                self.state = State::WriteBody(body_framing);
353            }
354        }
355
356        Ok(())
357    }
358    async fn write_body(&mut self, stream: &mut S, body: EncoderBody) -> Result<(), IoError> {
359        self.write_body0::<_, SLEEP>(stream, body).await
360    }
361
362    fn set_write_timeout(&mut self, dur: Duration) {
363        self.inner.set_write_timeout(dur)
364    }
365}