//! `http_ws/codec.rs`: WebSocket protocol codec.

1use bytes::{Bytes, BytesMut};
2use tracing::error;
3
4use super::{
5    error::ProtocolError,
6    frame::Parser,
7    proto::{CloseReason, OpCode},
8};
9
10/// A WebSocket message.
11#[derive(Debug, Eq, PartialEq)]
12pub enum Message {
13    /// Text message.
14    Text(Bytes),
15    /// Binary message.
16    Binary(Bytes),
17    /// Continuation.
18    Continuation(Item),
19    /// Ping message.
20    Ping(Bytes),
21    /// Pong message.
22    Pong(Bytes),
23    /// Close message with optional reason.
24    Close(Option<CloseReason>),
25    /// No-op. Useful for low-level services.
26    Nop,
27}
28
29/// A WebSocket continuation item.
30#[derive(Debug, Eq, PartialEq)]
31pub enum Item {
32    FirstText(Bytes),
33    FirstBinary(Bytes),
34    Continue(Bytes),
35    Last(Bytes),
36}
37
38/// WebSocket protocol codec.
39#[derive(Debug, Copy, Clone)]
40pub struct Codec {
41    flags: Flags,
42    capacity: usize,
43    max_size: usize,
44}
45
/// Compact bit set holding the codec's mode and state bits.
#[derive(Debug, Copy, Clone)]
struct Flags(u8);

impl Flags {
    /// Codec runs in server mode (set by `Codec::new`).
    const SERVER: u8 = 0b0001;
    /// A fragmented message is in progress.
    const CONTINUATION: u8 = 0b0010;
    /// A close frame has been encoded.
    const SEND_CLOSED: u8 = 0b0100;
    /// A close frame has been decoded.
    const RECV_CLOSED: u8 = 0b1000;

    /// Returns `true` when every bit of `other` is set in `self`.
    #[inline]
    const fn contains(&self, other: u8) -> bool {
        // No bit requested in `other` may lie outside the bits currently held.
        (other & !self.0) == 0
    }

    /// Sets every bit of `other`.
    #[inline]
    fn insert(&mut self, other: u8) {
        self.0 |= other;
    }

    /// Clears every bit of `other`, leaving all other bits untouched.
    #[inline]
    fn remove(&mut self, other: u8) {
        self.0 &= !other;
    }
}
70
71impl Codec {
72    /// Create new WebSocket frames decoder.
73    pub const fn new() -> Codec {
74        Codec {
75            max_size: 65_536,
76            capacity: 128,
77            flags: Flags(Flags::SERVER),
78        }
79    }
80
81    /// Set max frame size.
82    ///
83    /// By default max size is set to 64kB.
84    pub fn set_max_size(mut self, size: usize) -> Self {
85        self.max_size = size;
86        self
87    }
88
89    pub const fn max_size(&self) -> usize {
90        self.max_size
91    }
92
93    /// Set capacity for concurrent buffered outgoing message.
94    ///
95    /// By default capacity is set to 128.
96    pub fn set_capacity(mut self, size: usize) -> Self {
97        self.capacity = size;
98        self
99    }
100
101    pub const fn capacity(&self) -> usize {
102        self.capacity
103    }
104
105    /// Set decoder to client mode.
106    ///
107    /// By default decoder works in server mode.
108    pub fn client_mode(mut self) -> Self {
109        self.flags.remove(Flags::SERVER);
110        self.flags.remove(Flags::CONTINUATION);
111        self
112    }
113
114    #[doc(hidden)]
115    pub fn duplicate(mut self) -> Self {
116        self.flags.remove(Flags::CONTINUATION);
117        self
118    }
119
120    pub(super) fn send_closed(&self) -> bool {
121        self.flags.contains(Flags::SEND_CLOSED)
122    }
123
124    fn set_send_closed(&mut self) {
125        self.flags.insert(Flags::SEND_CLOSED);
126    }
127
128    fn recv_closed(&self) -> bool {
129        self.flags.contains(Flags::RECV_CLOSED)
130    }
131
132    fn set_recv_closed(&mut self) {
133        self.flags.insert(Flags::RECV_CLOSED);
134    }
135}
136
137impl Codec {
138    pub fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), ProtocolError> {
139        if self.send_closed() {
140            return Err(ProtocolError::SendClosed);
141        }
142
143        let mask = !self.flags.contains(Flags::SERVER);
144
145        match item {
146            Message::Text(bytes) => Parser::write_message(dst, bytes, OpCode::Text, true, mask),
147            Message::Binary(bytes) => Parser::write_message(dst, bytes, OpCode::Binary, true, mask),
148            Message::Ping(bytes) => Parser::write_message(dst, bytes, OpCode::Ping, true, mask),
149            Message::Pong(bytes) => Parser::write_message(dst, bytes, OpCode::Pong, true, mask),
150            Message::Close(reason) => {
151                Parser::write_close(dst, reason, mask);
152                self.set_send_closed();
153            }
154            Message::Continuation(cont) => match cont {
155                Item::Continue(_) | Item::Last(_) if !self.flags.contains(Flags::CONTINUATION) => {
156                    return Err(ProtocolError::ContinuationNotStarted)
157                }
158                Item::FirstText(ref data) => {
159                    self.try_start_continue()?;
160                    Parser::write_message(dst, data, OpCode::Text, false, mask);
161                }
162                Item::FirstBinary(ref data) => {
163                    self.try_start_continue()?;
164                    Parser::write_message(dst, data, OpCode::Binary, false, mask);
165                }
166                Item::Continue(ref data) => Parser::write_message(dst, data, OpCode::Continue, false, mask),
167                Item::Last(ref data) => {
168                    self.flags.remove(Flags::CONTINUATION);
169                    Parser::write_message(dst, data, OpCode::Continue, true, mask);
170                }
171            },
172            Message::Nop => {}
173        }
174
175        Ok(())
176    }
177
178    pub fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, ProtocolError> {
179        if self.recv_closed() {
180            return Err(ProtocolError::RecvClosed);
181        }
182
183        match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size)? {
184            Some((finished, opcode, payload)) => match opcode {
185                OpCode::Continue if !self.flags.contains(Flags::CONTINUATION) => {
186                    Err(ProtocolError::ContinuationNotStarted)
187                }
188                OpCode::Continue => {
189                    if finished {
190                        self.flags.remove(Flags::CONTINUATION);
191                    }
192                    Ok(Some(Message::Continuation(Item::Continue(
193                        payload.unwrap_or_else(Bytes::new),
194                    ))))
195                }
196                OpCode::Binary if !finished => {
197                    self.try_start_continue()?;
198                    Ok(Some(Message::Continuation(Item::FirstBinary(
199                        payload.unwrap_or_else(Bytes::new),
200                    ))))
201                }
202                OpCode::Text if !finished => {
203                    self.try_start_continue()?;
204                    Ok(Some(Message::Continuation(Item::FirstText(
205                        payload.unwrap_or_else(Bytes::new),
206                    ))))
207                }
208                OpCode::Close if !finished => {
209                    error!("Unfinished fragment {:?}", opcode);
210                    Err(ProtocolError::ContinuationFragment(opcode))
211                }
212                OpCode::Binary => Ok(Some(Message::Binary(payload.unwrap_or_else(Bytes::new)))),
213                OpCode::Text => Ok(Some(Message::Text(payload.unwrap_or_else(Bytes::new)))),
214                OpCode::Close => {
215                    self.set_recv_closed();
216                    Ok(Some(Message::Close(
217                        payload.as_deref().and_then(Parser::parse_close_payload),
218                    )))
219                }
220                OpCode::Ping => Ok(Some(Message::Ping(payload.unwrap_or_else(Bytes::new)))),
221                OpCode::Pong => Ok(Some(Message::Pong(payload.unwrap_or_else(Bytes::new)))),
222                OpCode::Bad => Err(ProtocolError::BadOpCode),
223            },
224            None => Ok(None),
225        }
226    }
227
228    fn try_start_continue(&mut self) -> Result<(), ProtocolError> {
229        if !self.flags.contains(Flags::CONTINUATION) {
230            self.flags.insert(Flags::CONTINUATION);
231            Ok(())
232        } else {
233            Err(ProtocolError::ContinuationStarted)
234        }
235    }
236}
237
#[cfg(test)]
mod test {
    use super::*;

    /// Toggling one flag must never disturb the others.
    #[test]
    fn flag() {
        let mut state = Flags(Flags::SERVER);
        assert!(state.contains(Flags::SERVER));
        assert!(!state.contains(Flags::SEND_CLOSED));

        // Clearing SERVER leaves no bits set.
        state.remove(Flags::SERVER);
        for &bit in &[Flags::SERVER, Flags::SEND_CLOSED] {
            assert!(!state.contains(bit));
        }

        // Setting SEND_CLOSED must not bring SERVER back.
        state.insert(Flags::SEND_CLOSED);
        assert!(state.contains(Flags::SEND_CLOSED));
        assert!(!state.contains(Flags::SERVER));

        // A second flag accumulates alongside the first.
        state.insert(Flags::RECV_CLOSED);
        for &bit in &[Flags::SEND_CLOSED, Flags::RECV_CLOSED] {
            assert!(state.contains(bit));
        }
        assert!(!state.contains(Flags::SERVER));
    }
}