actori_http/h1/client.rs

1use std::io;
2
3use actori_codec::{Decoder, Encoder};
4use bitflags::bitflags;
5use bytes::{Bytes, BytesMut};
6use http::{Method, Version};
7
8use super::decoder::{PayloadDecoder, PayloadItem, PayloadType};
9use super::{decoder, encoder, reserve_readbuf};
10use super::{Message, MessageType};
11use crate::body::BodySize;
12use crate::config::ServiceConfig;
13use crate::error::{ParseError, PayloadError};
14use crate::message::{ConnectionType, RequestHeadType, ResponseHead};
15
16bitflags! {
17    struct Flags: u8 {
18        const HEAD              = 0b0000_0001;
19        const KEEPALIVE_ENABLED = 0b0000_1000;
20        const STREAM            = 0b0001_0000;
21    }
22}
23
24/// HTTP/1 Codec
25pub struct ClientCodec {
26    inner: ClientCodecInner,
27}
28
29/// HTTP/1 Payload Codec
30pub struct ClientPayloadCodec {
31    inner: ClientCodecInner,
32}
33
34struct ClientCodecInner {
35    config: ServiceConfig,
36    decoder: decoder::MessageDecoder<ResponseHead>,
37    payload: Option<PayloadDecoder>,
38    version: Version,
39    ctype: ConnectionType,
40
41    // encoder part
42    flags: Flags,
43    encoder: encoder::MessageEncoder<RequestHeadType>,
44}
45
46impl Default for ClientCodec {
47    fn default() -> Self {
48        ClientCodec::new(ServiceConfig::default())
49    }
50}
51
52impl ClientCodec {
53    /// Create HTTP/1 codec.
54    ///
55    /// `keepalive_enabled` how response `connection` header get generated.
56    pub fn new(config: ServiceConfig) -> Self {
57        let flags = if config.keep_alive_enabled() {
58            Flags::KEEPALIVE_ENABLED
59        } else {
60            Flags::empty()
61        };
62        ClientCodec {
63            inner: ClientCodecInner {
64                config,
65                decoder: decoder::MessageDecoder::default(),
66                payload: None,
67                version: Version::HTTP_11,
68                ctype: ConnectionType::Close,
69
70                flags,
71                encoder: encoder::MessageEncoder::default(),
72            },
73        }
74    }
75
76    /// Check if request is upgrade
77    pub fn upgrade(&self) -> bool {
78        self.inner.ctype == ConnectionType::Upgrade
79    }
80
81    /// Check if last response is keep-alive
82    pub fn keepalive(&self) -> bool {
83        self.inner.ctype == ConnectionType::KeepAlive
84    }
85
86    /// Check last request's message type
87    pub fn message_type(&self) -> MessageType {
88        if self.inner.flags.contains(Flags::STREAM) {
89            MessageType::Stream
90        } else if self.inner.payload.is_none() {
91            MessageType::None
92        } else {
93            MessageType::Payload
94        }
95    }
96
97    /// Convert message codec to a payload codec
98    pub fn into_payload_codec(self) -> ClientPayloadCodec {
99        ClientPayloadCodec { inner: self.inner }
100    }
101}
102
103impl ClientPayloadCodec {
104    /// Check if last response is keep-alive
105    pub fn keepalive(&self) -> bool {
106        self.inner.ctype == ConnectionType::KeepAlive
107    }
108
109    /// Transform payload codec to a message codec
110    pub fn into_message_codec(self) -> ClientCodec {
111        ClientCodec { inner: self.inner }
112    }
113}
114
115impl Decoder for ClientCodec {
116    type Item = ResponseHead;
117    type Error = ParseError;
118
119    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
120        debug_assert!(!self.inner.payload.is_some(), "Payload decoder is set");
121
122        if let Some((req, payload)) = self.inner.decoder.decode(src)? {
123            if let Some(ctype) = req.ctype() {
124                // do not use peer's keep-alive
125                self.inner.ctype = if ctype == ConnectionType::KeepAlive {
126                    self.inner.ctype
127                } else {
128                    ctype
129                };
130            }
131
132            if !self.inner.flags.contains(Flags::HEAD) {
133                match payload {
134                    PayloadType::None => self.inner.payload = None,
135                    PayloadType::Payload(pl) => self.inner.payload = Some(pl),
136                    PayloadType::Stream(pl) => {
137                        self.inner.payload = Some(pl);
138                        self.inner.flags.insert(Flags::STREAM);
139                    }
140                }
141            } else {
142                self.inner.payload = None;
143            }
144            reserve_readbuf(src);
145            Ok(Some(req))
146        } else {
147            Ok(None)
148        }
149    }
150}
151
152impl Decoder for ClientPayloadCodec {
153    type Item = Option<Bytes>;
154    type Error = PayloadError;
155
156    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
157        debug_assert!(
158            self.inner.payload.is_some(),
159            "Payload decoder is not specified"
160        );
161
162        Ok(match self.inner.payload.as_mut().unwrap().decode(src)? {
163            Some(PayloadItem::Chunk(chunk)) => {
164                reserve_readbuf(src);
165                Some(Some(chunk))
166            }
167            Some(PayloadItem::Eof) => {
168                self.inner.payload.take();
169                Some(None)
170            }
171            None => None,
172        })
173    }
174}
175
176impl Encoder for ClientCodec {
177    type Item = Message<(RequestHeadType, BodySize)>;
178    type Error = io::Error;
179
180    fn encode(
181        &mut self,
182        item: Self::Item,
183        dst: &mut BytesMut,
184    ) -> Result<(), Self::Error> {
185        match item {
186            Message::Item((mut head, length)) => {
187                let inner = &mut self.inner;
188                inner.version = head.as_ref().version;
189                inner
190                    .flags
191                    .set(Flags::HEAD, head.as_ref().method == Method::HEAD);
192
193                // connection status
194                inner.ctype = match head.as_ref().connection_type() {
195                    ConnectionType::KeepAlive => {
196                        if inner.flags.contains(Flags::KEEPALIVE_ENABLED) {
197                            ConnectionType::KeepAlive
198                        } else {
199                            ConnectionType::Close
200                        }
201                    }
202                    ConnectionType::Upgrade => ConnectionType::Upgrade,
203                    ConnectionType::Close => ConnectionType::Close,
204                };
205
206                inner.encoder.encode(
207                    dst,
208                    &mut head,
209                    false,
210                    false,
211                    inner.version,
212                    length,
213                    inner.ctype,
214                    &inner.config,
215                )?;
216            }
217            Message::Chunk(Some(bytes)) => {
218                self.inner.encoder.encode_chunk(bytes.as_ref(), dst)?;
219            }
220            Message::Chunk(None) => {
221                self.inner.encoder.encode_eof(dst)?;
222            }
223        }
224        Ok(())
225    }
226}
227
228pub struct Writer<'a>(pub &'a mut BytesMut);
229
230impl<'a> io::Write for Writer<'a> {
231    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
232        self.0.extend_from_slice(buf);
233        Ok(buf.len())
234    }
235    fn flush(&mut self) -> io::Result<()> {
236        Ok(())
237    }
238}