// actori_http/h1/codec.rs

1use std::{fmt, io};
2
3use actori_codec::{Decoder, Encoder};
4use bitflags::bitflags;
5use bytes::BytesMut;
6use http::{Method, Version};
7
8use super::decoder::{PayloadDecoder, PayloadItem, PayloadType};
9use super::{decoder, encoder};
10use super::{Message, MessageType};
11use crate::body::BodySize;
12use crate::config::ServiceConfig;
13use crate::error::ParseError;
14use crate::message::ConnectionType;
15use crate::request::Request;
16use crate::response::Response;
17
18bitflags! {
19    struct Flags: u8 {
20        const HEAD              = 0b0000_0001;
21        const KEEPALIVE_ENABLED = 0b0000_0010;
22        const STREAM            = 0b0000_0100;
23    }
24}
25
26/// HTTP/1 Codec
27pub struct Codec {
28    config: ServiceConfig,
29    decoder: decoder::MessageDecoder<Request>,
30    payload: Option<PayloadDecoder>,
31    version: Version,
32    ctype: ConnectionType,
33
34    // encoder part
35    flags: Flags,
36    encoder: encoder::MessageEncoder<Response<()>>,
37}
38
39impl Default for Codec {
40    fn default() -> Self {
41        Codec::new(ServiceConfig::default())
42    }
43}
44
45impl fmt::Debug for Codec {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        write!(f, "h1::Codec({:?})", self.flags)
48    }
49}
50
51impl Codec {
52    /// Create HTTP/1 codec.
53    ///
54    /// `keepalive_enabled` how response `connection` header get generated.
55    pub fn new(config: ServiceConfig) -> Self {
56        let flags = if config.keep_alive_enabled() {
57            Flags::KEEPALIVE_ENABLED
58        } else {
59            Flags::empty()
60        };
61        Codec {
62            config,
63            flags,
64            decoder: decoder::MessageDecoder::default(),
65            payload: None,
66            version: Version::HTTP_11,
67            ctype: ConnectionType::Close,
68            encoder: encoder::MessageEncoder::default(),
69        }
70    }
71
72    #[inline]
73    /// Check if request is upgrade
74    pub fn upgrade(&self) -> bool {
75        self.ctype == ConnectionType::Upgrade
76    }
77
78    #[inline]
79    /// Check if last response is keep-alive
80    pub fn keepalive(&self) -> bool {
81        self.ctype == ConnectionType::KeepAlive
82    }
83
84    #[inline]
85    /// Check if keep-alive enabled on server level
86    pub fn keepalive_enabled(&self) -> bool {
87        self.flags.contains(Flags::KEEPALIVE_ENABLED)
88    }
89
90    #[inline]
91    /// Check last request's message type
92    pub fn message_type(&self) -> MessageType {
93        if self.flags.contains(Flags::STREAM) {
94            MessageType::Stream
95        } else if self.payload.is_none() {
96            MessageType::None
97        } else {
98            MessageType::Payload
99        }
100    }
101
102    #[inline]
103    pub fn config(&self) -> &ServiceConfig {
104        &self.config
105    }
106}
107
108impl Decoder for Codec {
109    type Item = Message<Request>;
110    type Error = ParseError;
111
112    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
113        if self.payload.is_some() {
114            Ok(match self.payload.as_mut().unwrap().decode(src)? {
115                Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))),
116                Some(PayloadItem::Eof) => {
117                    self.payload.take();
118                    Some(Message::Chunk(None))
119                }
120                None => None,
121            })
122        } else if let Some((req, payload)) = self.decoder.decode(src)? {
123            let head = req.head();
124            self.flags.set(Flags::HEAD, head.method == Method::HEAD);
125            self.version = head.version;
126            self.ctype = head.connection_type();
127            if self.ctype == ConnectionType::KeepAlive
128                && !self.flags.contains(Flags::KEEPALIVE_ENABLED)
129            {
130                self.ctype = ConnectionType::Close
131            }
132            match payload {
133                PayloadType::None => self.payload = None,
134                PayloadType::Payload(pl) => self.payload = Some(pl),
135                PayloadType::Stream(pl) => {
136                    self.payload = Some(pl);
137                    self.flags.insert(Flags::STREAM);
138                }
139            }
140            Ok(Some(Message::Item(req)))
141        } else {
142            Ok(None)
143        }
144    }
145}
146
147impl Encoder for Codec {
148    type Item = Message<(Response<()>, BodySize)>;
149    type Error = io::Error;
150
151    fn encode(
152        &mut self,
153        item: Self::Item,
154        dst: &mut BytesMut,
155    ) -> Result<(), Self::Error> {
156        match item {
157            Message::Item((mut res, length)) => {
158                // set response version
159                res.head_mut().version = self.version;
160
161                // connection status
162                self.ctype = if let Some(ct) = res.head().ctype() {
163                    if ct == ConnectionType::KeepAlive {
164                        self.ctype
165                    } else {
166                        ct
167                    }
168                } else {
169                    self.ctype
170                };
171
172                // encode message
173                self.encoder.encode(
174                    dst,
175                    &mut res,
176                    self.flags.contains(Flags::HEAD),
177                    self.flags.contains(Flags::STREAM),
178                    self.version,
179                    length,
180                    self.ctype,
181                    &self.config,
182                )?;
183                // self.headers_size = (dst.len() - len) as u32;
184            }
185            Message::Chunk(Some(bytes)) => {
186                self.encoder.encode_chunk(bytes.as_ref(), dst)?;
187            }
188            Message::Chunk(None) => {
189                self.encoder.encode_eof(dst)?;
190            }
191        }
192        Ok(())
193    }
194}
195
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use http::Method;

    use super::*;
    use crate::httpmessage::HttpMessage;

    /// A chunked request body followed by a pipelined request. The codec must
    /// yield both body chunks, the end-of-payload marker, then the next head.
    #[test]
    fn test_http_request_chunked_payload_and_next_message() {
        let mut codec = Codec::default();

        let mut buf =
            BytesMut::from("GET /test HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n");
        let first = codec.decode(&mut buf).unwrap().unwrap();
        let first_req = first.message();

        assert_eq!(first_req.method(), Method::GET);
        assert!(first_req.chunked().unwrap());

        // Two body chunks, the terminating zero-length chunk, and a second
        // pipelined request head, all arriving in one read.
        let pipelined: &[u8] = b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\
                                 POST /test2 HTTP/1.1\r\n\
                                 transfer-encoding: chunked\r\n\r\n";
        buf.extend(pipelined.iter());

        let chunk1 = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(chunk1.chunk().as_ref(), b"data");

        let chunk2 = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(chunk2.chunk().as_ref(), b"line");

        let end = codec.decode(&mut buf).unwrap().unwrap();
        assert!(end.eof());

        // the pipelined request head comes next
        let second = codec.decode(&mut buf).unwrap().unwrap();
        let second_req = second.message();
        assert_eq!(*second_req.method(), Method::POST);
        assert!(second_req.chunked().unwrap());
    }
}