// http_ws/codec.rs

1use bytes::{Bytes, BytesMut};
2use tracing::error;
3
4use super::{
5    error::ProtocolError,
6    frame::Parser,
7    proto::{CloseReason, OpCode},
8};
9
10/// A WebSocket message.
11#[derive(Debug, Eq, PartialEq)]
12pub enum Message {
13    /// Text message.
14    Text(Bytes),
15    /// Binary message.
16    Binary(Bytes),
17    /// Continuation.
18    Continuation(Item),
19    /// Ping message.
20    Ping(Bytes),
21    /// Pong message.
22    Pong(Bytes),
23    /// Close message with optional reason.
24    Close(Option<CloseReason>),
25    /// No-op. Useful for low-level services.
26    Nop,
27}
28
29/// A WebSocket continuation item.
30#[derive(Debug, Eq, PartialEq)]
31pub enum Item {
32    FirstText(Bytes),
33    FirstBinary(Bytes),
34    Continue(Bytes),
35    Last(Bytes),
36}
37
/// WebSocket protocol codec.
///
/// Holds the configuration (max frame size, outgoing buffer capacity, client/server mode)
/// plus the small amount of state (fragmented message in progress, Close already encoded)
/// needed to encode and decode WebSocket frames.
#[derive(Debug, Copy, Clone)]
pub struct Codec {
    // Mode and state bits; see `Flags`.
    flags: Flags,
    // Capacity for concurrently buffered outgoing messages (stored for callers).
    capacity: usize,
    // Maximum frame size handed to the frame parser when decoding.
    max_size: usize,
}

/// Compact bit set holding the codec's mode and state flags.
#[derive(Debug, Copy, Clone)]
struct Flags(u8);

impl Flags {
    /// Codec runs in server mode: outgoing frames are not masked.
    const SERVER: u8 = 0b0001;
    /// A fragmented message has been started and not yet finished.
    const CONTINUATION: u8 = 0b0010;
    /// A Close message has been encoded; further encoding is rejected.
    const CLOSED: u8 = 0b0100;

    /// Clears every bit that is set in `other`.
    #[inline(always)]
    fn remove(&mut self, other: u8) {
        self.0 &= !other;
    }

    /// Sets every bit that is set in `other`.
    #[inline(always)]
    fn insert(&mut self, other: u8) {
        self.0 |= other;
    }

    /// Returns `true` only if every bit of `other` is set.
    #[inline(always)]
    const fn contains(&self, other: u8) -> bool {
        (self.0 & other) == other
    }
}

impl Default for Codec {
    /// Same as [`Codec::new`]: server mode, 64kB max frame size, capacity 128.
    fn default() -> Self {
        Self::new()
    }
}

impl Codec {
    /// Create new WebSocket frames decoder.
    ///
    /// Defaults: server mode, max frame size 65 536 bytes (64kB), capacity 128.
    #[must_use]
    pub const fn new() -> Codec {
        Codec {
            max_size: 65_536,
            capacity: 128,
            flags: Flags(Flags::SERVER),
        }
    }

    /// Set max frame size.
    ///
    /// By default max size is set to 64kB.
    #[must_use]
    pub fn set_max_size(mut self, size: usize) -> Self {
        self.max_size = size;
        self
    }

    /// Returns the maximum frame size passed to the frame parser when decoding.
    pub const fn max_size(&self) -> usize {
        self.max_size
    }

    /// Set capacity for concurrent buffered outgoing message.
    ///
    /// By default capacity is set to 128.
    #[must_use]
    pub fn set_capacity(mut self, size: usize) -> Self {
        self.capacity = size;
        self
    }

    /// Returns the configured capacity for buffered outgoing messages.
    ///
    /// `encode`/`decode` do not consult this value; it is stored for callers.
    pub const fn capacity(&self) -> usize {
        self.capacity
    }

    /// Set decoder to client mode.
    ///
    /// By default decoder works in server mode. In client mode outgoing frames are masked.
    /// Any in-progress fragmentation state is cleared.
    #[must_use]
    pub fn client_mode(mut self) -> Self {
        self.flags.remove(Flags::SERVER);
        self.flags.remove(Flags::CONTINUATION);
        self
    }

    /// Returns a copy of this codec with the fragmentation (continuation) state cleared.
    ///
    /// Configuration and the closed state are kept as-is.
    #[doc(hidden)]
    #[must_use]
    pub fn duplicate(mut self) -> Self {
        self.flags.remove(Flags::CONTINUATION);
        self
    }
}
125
126impl Codec {
127    pub fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), ProtocolError> {
128        if self.flags.contains(Flags::CLOSED) {
129            return Err(ProtocolError::Closed);
130        }
131
132        let mask = !self.flags.contains(Flags::SERVER);
133        match item {
134            Message::Text(bytes) => Parser::write_message(dst, bytes, OpCode::Text, true, mask),
135            Message::Binary(bytes) => Parser::write_message(dst, bytes, OpCode::Binary, true, mask),
136            Message::Ping(bytes) => Parser::write_message(dst, bytes, OpCode::Ping, true, mask),
137            Message::Pong(bytes) => Parser::write_message(dst, bytes, OpCode::Pong, true, mask),
138            Message::Close(reason) => {
139                Parser::write_close(dst, reason, mask);
140                self.flags.insert(Flags::CLOSED);
141            }
142            Message::Continuation(cont) => match cont {
143                Item::Continue(_) | Item::Last(_) if !self.flags.contains(Flags::CONTINUATION) => {
144                    return Err(ProtocolError::ContinuationNotStarted)
145                }
146                Item::FirstText(ref data) => {
147                    self.try_start_continue()?;
148                    Parser::write_message(dst, data, OpCode::Text, false, mask);
149                }
150                Item::FirstBinary(ref data) => {
151                    self.try_start_continue()?;
152                    Parser::write_message(dst, data, OpCode::Binary, false, mask);
153                }
154                Item::Continue(ref data) => Parser::write_message(dst, data, OpCode::Continue, false, mask),
155                Item::Last(ref data) => {
156                    self.flags.remove(Flags::CONTINUATION);
157                    Parser::write_message(dst, data, OpCode::Continue, true, mask);
158                }
159            },
160            Message::Nop => {}
161        }
162
163        Ok(())
164    }
165
166    pub fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, ProtocolError> {
167        match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size)? {
168            Some((finished, opcode, payload)) => match opcode {
169                OpCode::Continue if !self.flags.contains(Flags::CONTINUATION) => {
170                    Err(ProtocolError::ContinuationNotStarted)
171                }
172                OpCode::Continue => {
173                    if finished {
174                        self.flags.remove(Flags::CONTINUATION);
175                    }
176                    Ok(Some(Message::Continuation(Item::Continue(
177                        payload.unwrap_or_else(Bytes::new),
178                    ))))
179                }
180                OpCode::Binary if !finished => {
181                    self.try_start_continue()?;
182                    Ok(Some(Message::Continuation(Item::FirstBinary(
183                        payload.unwrap_or_else(Bytes::new),
184                    ))))
185                }
186                OpCode::Text if !finished => {
187                    self.try_start_continue()?;
188                    Ok(Some(Message::Continuation(Item::FirstText(
189                        payload.unwrap_or_else(Bytes::new),
190                    ))))
191                }
192                OpCode::Close if !finished => {
193                    error!("Unfinished fragment {:?}", opcode);
194                    Err(ProtocolError::ContinuationFragment(opcode))
195                }
196                OpCode::Binary => Ok(Some(Message::Binary(payload.unwrap_or_else(Bytes::new)))),
197                OpCode::Text => Ok(Some(Message::Text(payload.unwrap_or_else(Bytes::new)))),
198                OpCode::Close => Ok(Some(Message::Close(
199                    payload.as_deref().and_then(Parser::parse_close_payload),
200                ))),
201                OpCode::Ping => Ok(Some(Message::Ping(payload.unwrap_or_else(Bytes::new)))),
202                OpCode::Pong => Ok(Some(Message::Pong(payload.unwrap_or_else(Bytes::new)))),
203                OpCode::Bad => Err(ProtocolError::BadOpCode),
204            },
205            None => Ok(None),
206        }
207    }
208
209    fn try_start_continue(&mut self) -> Result<(), ProtocolError> {
210        if !self.flags.contains(Flags::CONTINUATION) {
211            self.flags.insert(Flags::CONTINUATION);
212            Ok(())
213        } else {
214            Err(ProtocolError::ContinuationStarted)
215        }
216    }
217}
218
#[cfg(test)]
mod test {
    use super::*;

    /// Snapshot of the `(SERVER, CONTINUATION)` bits of `state`.
    fn bits(state: &Flags) -> (bool, bool) {
        (
            state.contains(Flags::SERVER),
            state.contains(Flags::CONTINUATION),
        )
    }

    #[test]
    fn flag() {
        let mut state = Flags(Flags::SERVER);
        assert_eq!(bits(&state), (true, false));

        state.remove(Flags::SERVER);
        assert_eq!(bits(&state), (false, false));

        state.insert(Flags::CONTINUATION);
        assert_eq!(bits(&state), (false, true));
    }
}