1use std::{fmt, io};
2
3use actori_codec::{Decoder, Encoder};
4use bitflags::bitflags;
5use bytes::BytesMut;
6use http::{Method, Version};
7
8use super::decoder::{PayloadDecoder, PayloadItem, PayloadType};
9use super::{decoder, encoder};
10use super::{Message, MessageType};
11use crate::body::BodySize;
12use crate::config::ServiceConfig;
13use crate::error::ParseError;
14use crate::message::ConnectionType;
15use crate::request::Request;
16use crate::response::Response;
17
18bitflags! {
19 struct Flags: u8 {
20 const HEAD = 0b0000_0001;
21 const KEEPALIVE_ENABLED = 0b0000_0010;
22 const STREAM = 0b0000_0100;
23 }
24}
25
26pub struct Codec {
28 config: ServiceConfig,
29 decoder: decoder::MessageDecoder<Request>,
30 payload: Option<PayloadDecoder>,
31 version: Version,
32 ctype: ConnectionType,
33
34 flags: Flags,
36 encoder: encoder::MessageEncoder<Response<()>>,
37}
38
39impl Default for Codec {
40 fn default() -> Self {
41 Codec::new(ServiceConfig::default())
42 }
43}
44
45impl fmt::Debug for Codec {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 write!(f, "h1::Codec({:?})", self.flags)
48 }
49}
50
51impl Codec {
52 pub fn new(config: ServiceConfig) -> Self {
56 let flags = if config.keep_alive_enabled() {
57 Flags::KEEPALIVE_ENABLED
58 } else {
59 Flags::empty()
60 };
61 Codec {
62 config,
63 flags,
64 decoder: decoder::MessageDecoder::default(),
65 payload: None,
66 version: Version::HTTP_11,
67 ctype: ConnectionType::Close,
68 encoder: encoder::MessageEncoder::default(),
69 }
70 }
71
72 #[inline]
73 pub fn upgrade(&self) -> bool {
75 self.ctype == ConnectionType::Upgrade
76 }
77
78 #[inline]
79 pub fn keepalive(&self) -> bool {
81 self.ctype == ConnectionType::KeepAlive
82 }
83
84 #[inline]
85 pub fn keepalive_enabled(&self) -> bool {
87 self.flags.contains(Flags::KEEPALIVE_ENABLED)
88 }
89
90 #[inline]
91 pub fn message_type(&self) -> MessageType {
93 if self.flags.contains(Flags::STREAM) {
94 MessageType::Stream
95 } else if self.payload.is_none() {
96 MessageType::None
97 } else {
98 MessageType::Payload
99 }
100 }
101
102 #[inline]
103 pub fn config(&self) -> &ServiceConfig {
104 &self.config
105 }
106}
107
108impl Decoder for Codec {
109 type Item = Message<Request>;
110 type Error = ParseError;
111
112 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
113 if self.payload.is_some() {
114 Ok(match self.payload.as_mut().unwrap().decode(src)? {
115 Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))),
116 Some(PayloadItem::Eof) => {
117 self.payload.take();
118 Some(Message::Chunk(None))
119 }
120 None => None,
121 })
122 } else if let Some((req, payload)) = self.decoder.decode(src)? {
123 let head = req.head();
124 self.flags.set(Flags::HEAD, head.method == Method::HEAD);
125 self.version = head.version;
126 self.ctype = head.connection_type();
127 if self.ctype == ConnectionType::KeepAlive
128 && !self.flags.contains(Flags::KEEPALIVE_ENABLED)
129 {
130 self.ctype = ConnectionType::Close
131 }
132 match payload {
133 PayloadType::None => self.payload = None,
134 PayloadType::Payload(pl) => self.payload = Some(pl),
135 PayloadType::Stream(pl) => {
136 self.payload = Some(pl);
137 self.flags.insert(Flags::STREAM);
138 }
139 }
140 Ok(Some(Message::Item(req)))
141 } else {
142 Ok(None)
143 }
144 }
145}
146
147impl Encoder for Codec {
148 type Item = Message<(Response<()>, BodySize)>;
149 type Error = io::Error;
150
151 fn encode(
152 &mut self,
153 item: Self::Item,
154 dst: &mut BytesMut,
155 ) -> Result<(), Self::Error> {
156 match item {
157 Message::Item((mut res, length)) => {
158 res.head_mut().version = self.version;
160
161 self.ctype = if let Some(ct) = res.head().ctype() {
163 if ct == ConnectionType::KeepAlive {
164 self.ctype
165 } else {
166 ct
167 }
168 } else {
169 self.ctype
170 };
171
172 self.encoder.encode(
174 dst,
175 &mut res,
176 self.flags.contains(Flags::HEAD),
177 self.flags.contains(Flags::STREAM),
178 self.version,
179 length,
180 self.ctype,
181 &self.config,
182 )?;
183 }
185 Message::Chunk(Some(bytes)) => {
186 self.encoder.encode_chunk(bytes.as_ref(), dst)?;
187 }
188 Message::Chunk(None) => {
189 self.encoder.encode_eof(dst)?;
190 }
191 }
192 Ok(())
193 }
194}
195
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use http::Method;

    use super::*;
    use crate::httpmessage::HttpMessage;

    /// A chunked body followed by a pipelined second request must decode as:
    /// two chunks, an end-of-body marker, then the second request head.
    #[test]
    fn test_http_request_chunked_payload_and_next_message() {
        let mut codec = Codec::default();

        // Only the head of the first request is available at first.
        let mut buf = BytesMut::from(
            "GET /test HTTP/1.1\r\n\
             transfer-encoding: chunked\r\n\r\n",
        );
        let first = codec.decode(&mut buf).unwrap().unwrap();
        let first_req = first.message();
        assert_eq!(*first_req.method(), Method::GET);
        assert!(first_req.chunked().unwrap());

        // Body of the first request plus the head of a pipelined second one.
        buf.extend_from_slice(
            b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n\
              POST /test2 HTTP/1.1\r\n\
              transfer-encoding: chunked\r\n\r\n",
        );

        let chunk_one = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(chunk_one.chunk().as_ref(), b"data");

        let chunk_two = codec.decode(&mut buf).unwrap().unwrap();
        assert_eq!(chunk_two.chunk().as_ref(), b"line");

        let end_of_body = codec.decode(&mut buf).unwrap().unwrap();
        assert!(end_of_body.eof());

        let second = codec.decode(&mut buf).unwrap().unwrap();
        let second_req = second.message();
        assert_eq!(*second_req.method(), Method::POST);
        assert!(second_req.chunked().unwrap());
    }
}