// requiem_http/ws/codec.rs

1use requiem_codec::{Decoder, Encoder};
2use bytes::{Bytes, BytesMut};
3
4use super::frame::Parser;
5use super::proto::{CloseReason, OpCode};
6use super::ProtocolError;
7
/// `WebSocket` Message
///
/// The item type accepted by this codec's `Encoder` implementation.
#[derive(Debug, PartialEq)]
pub enum Message {
    /// Text message
    Text(String),
    /// Binary message
    Binary(Bytes),
    /// One piece of a fragmented message; see [`Item`] for the required order.
    Continuation(Item),
    /// Ping message
    Ping(Bytes),
    /// Pong message
    Pong(Bytes),
    /// Close message with optional reason
    Close(Option<CloseReason>),
    /// No-op: encoding it writes nothing. Useful for actix-net services
    Nop,
}
26
/// `WebSocket` frame
///
/// The item type produced by this codec's `Decoder` implementation.
#[derive(Debug, PartialEq)]
pub enum Frame {
    /// Text frame, codec does not verify utf8 encoding
    Text(Bytes),
    /// Binary frame
    Binary(Bytes),
    /// One piece of a fragmented message; see [`Item`].
    Continuation(Item),
    /// Ping message
    Ping(Bytes),
    /// Pong message
    Pong(Bytes),
    /// Close message with the reason parsed from its payload, if any
    Close(Option<CloseReason>),
}
43
/// `WebSocket` continuation item
///
/// One part of a fragmented message. A fragmented message consists of one
/// `FirstText` or `FirstBinary` item, zero or more `Continue` items, and a
/// final `Last` item. The codec rejects items that arrive out of this order.
#[derive(Debug, PartialEq)]
pub enum Item {
    /// First fragment of a fragmented text message.
    FirstText(Bytes),
    /// First fragment of a fragmented binary message.
    FirstBinary(Bytes),
    /// An intermediate fragment.
    Continue(Bytes),
    /// The final fragment; completes the fragmented message.
    Last(Bytes),
}
52
#[derive(Debug, Copy, Clone)]
/// WebSockets protocol codec
///
/// Encodes [`Message`] values and decodes [`Frame`] values. Construct it with
/// `Codec::new` and adjust it with the `max_size` / `client_mode` builders.
pub struct Codec {
    // Server/client mode plus read/write fragmentation state (see `Flags`).
    flags: Flags,
    // Maximum frame size, passed to the frame parser when decoding.
    max_size: usize,
}
59
bitflags::bitflags! {
    // Internal state bits of `Codec`.
    struct Flags: u8 {
        // Set: codec runs in server mode. Set by `Codec::new`, cleared by `client_mode`.
        const SERVER         = 0b0000_0001;
        // Set while the decoder is in the middle of receiving a fragmented message.
        const CONTINUATION   = 0b0000_0010;
        // Set while the encoder is in the middle of writing a fragmented message.
        const W_CONTINUATION = 0b0000_0100;
    }
}
67
68impl Codec {
69    /// Create new websocket frames decoder
70    pub fn new() -> Codec {
71        Codec {
72            max_size: 65_536,
73            flags: Flags::SERVER,
74        }
75    }
76
77    /// Set max frame size
78    ///
79    /// By default max size is set to 64kb
80    pub fn max_size(mut self, size: usize) -> Self {
81        self.max_size = size;
82        self
83    }
84
85    /// Set decoder to client mode.
86    ///
87    /// By default decoder works in server mode.
88    pub fn client_mode(mut self) -> Self {
89        self.flags.remove(Flags::SERVER);
90        self
91    }
92}
93
94impl Encoder for Codec {
95    type Item = Message;
96    type Error = ProtocolError;
97
98    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
99        match item {
100            Message::Text(txt) => Parser::write_message(
101                dst,
102                txt,
103                OpCode::Text,
104                true,
105                !self.flags.contains(Flags::SERVER),
106            ),
107            Message::Binary(bin) => Parser::write_message(
108                dst,
109                bin,
110                OpCode::Binary,
111                true,
112                !self.flags.contains(Flags::SERVER),
113            ),
114            Message::Ping(txt) => Parser::write_message(
115                dst,
116                txt,
117                OpCode::Ping,
118                true,
119                !self.flags.contains(Flags::SERVER),
120            ),
121            Message::Pong(txt) => Parser::write_message(
122                dst,
123                txt,
124                OpCode::Pong,
125                true,
126                !self.flags.contains(Flags::SERVER),
127            ),
128            Message::Close(reason) => {
129                Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER))
130            }
131            Message::Continuation(cont) => match cont {
132                Item::FirstText(data) => {
133                    if self.flags.contains(Flags::W_CONTINUATION) {
134                        return Err(ProtocolError::ContinuationStarted);
135                    } else {
136                        self.flags.insert(Flags::W_CONTINUATION);
137                        Parser::write_message(
138                            dst,
139                            &data[..],
140                            OpCode::Binary,
141                            false,
142                            !self.flags.contains(Flags::SERVER),
143                        )
144                    }
145                }
146                Item::FirstBinary(data) => {
147                    if self.flags.contains(Flags::W_CONTINUATION) {
148                        return Err(ProtocolError::ContinuationStarted);
149                    } else {
150                        self.flags.insert(Flags::W_CONTINUATION);
151                        Parser::write_message(
152                            dst,
153                            &data[..],
154                            OpCode::Text,
155                            false,
156                            !self.flags.contains(Flags::SERVER),
157                        )
158                    }
159                }
160                Item::Continue(data) => {
161                    if self.flags.contains(Flags::W_CONTINUATION) {
162                        Parser::write_message(
163                            dst,
164                            &data[..],
165                            OpCode::Continue,
166                            false,
167                            !self.flags.contains(Flags::SERVER),
168                        )
169                    } else {
170                        return Err(ProtocolError::ContinuationNotStarted);
171                    }
172                }
173                Item::Last(data) => {
174                    if self.flags.contains(Flags::W_CONTINUATION) {
175                        self.flags.remove(Flags::W_CONTINUATION);
176                        Parser::write_message(
177                            dst,
178                            &data[..],
179                            OpCode::Continue,
180                            true,
181                            !self.flags.contains(Flags::SERVER),
182                        )
183                    } else {
184                        return Err(ProtocolError::ContinuationNotStarted);
185                    }
186                }
187            },
188            Message::Nop => (),
189        }
190        Ok(())
191    }
192}
193
194impl Decoder for Codec {
195    type Item = Frame;
196    type Error = ProtocolError;
197
198    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
199        match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) {
200            Ok(Some((finished, opcode, payload))) => {
201                // continuation is not supported
202                if !finished {
203                    return match opcode {
204                        OpCode::Continue => {
205                            if self.flags.contains(Flags::CONTINUATION) {
206                                Ok(Some(Frame::Continuation(Item::Continue(
207                                    payload
208                                        .map(|pl| pl.freeze())
209                                        .unwrap_or_else(Bytes::new),
210                                ))))
211                            } else {
212                                Err(ProtocolError::ContinuationNotStarted)
213                            }
214                        }
215                        OpCode::Binary => {
216                            if !self.flags.contains(Flags::CONTINUATION) {
217                                self.flags.insert(Flags::CONTINUATION);
218                                Ok(Some(Frame::Continuation(Item::FirstBinary(
219                                    payload
220                                        .map(|pl| pl.freeze())
221                                        .unwrap_or_else(Bytes::new),
222                                ))))
223                            } else {
224                                Err(ProtocolError::ContinuationStarted)
225                            }
226                        }
227                        OpCode::Text => {
228                            if !self.flags.contains(Flags::CONTINUATION) {
229                                self.flags.insert(Flags::CONTINUATION);
230                                Ok(Some(Frame::Continuation(Item::FirstText(
231                                    payload
232                                        .map(|pl| pl.freeze())
233                                        .unwrap_or_else(Bytes::new),
234                                ))))
235                            } else {
236                                Err(ProtocolError::ContinuationStarted)
237                            }
238                        }
239                        _ => {
240                            error!("Unfinished fragment {:?}", opcode);
241                            Err(ProtocolError::ContinuationFragment(opcode))
242                        }
243                    };
244                }
245
246                match opcode {
247                    OpCode::Continue => {
248                        if self.flags.contains(Flags::CONTINUATION) {
249                            self.flags.remove(Flags::CONTINUATION);
250                            Ok(Some(Frame::Continuation(Item::Last(
251                                payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
252                            ))))
253                        } else {
254                            Err(ProtocolError::ContinuationNotStarted)
255                        }
256                    }
257                    OpCode::Bad => Err(ProtocolError::BadOpCode),
258                    OpCode::Close => {
259                        if let Some(ref pl) = payload {
260                            let close_reason = Parser::parse_close_payload(pl);
261                            Ok(Some(Frame::Close(close_reason)))
262                        } else {
263                            Ok(Some(Frame::Close(None)))
264                        }
265                    }
266                    OpCode::Ping => Ok(Some(Frame::Ping(
267                        payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
268                    ))),
269                    OpCode::Pong => Ok(Some(Frame::Pong(
270                        payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
271                    ))),
272                    OpCode::Binary => Ok(Some(Frame::Binary(
273                        payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
274                    ))),
275                    OpCode::Text => Ok(Some(Frame::Text(
276                        payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
277                    ))),
278                }
279            }
280            Ok(None) => Ok(None),
281            Err(e) => Err(e),
282        }
283    }
284}