blitz_ws/protocol/
message.rs

//! WebSocket message handling: the `Message` type and incremental assembly of incomplete messages.
2
3use bytes::Bytes;
4
5use crate::{
6    error::{CapacityError, Error, Result},
7    protocol::{
8        frame::{CloseFrame, Frame, Utf8Bytes},
9        message::string_lib::StringCollector,
10    },
11};
12
13mod string_lib {
14    use crate::error::{Error, Result};
15    use utf8::DecodeError;
16
17    #[derive(Debug)]
18    pub struct StringCollector {
19        data: String,
20        incomplete: Option<utf8::Incomplete>,
21    }
22
23    impl StringCollector {
24        pub fn new() -> Self {
25            StringCollector { data: String::new(), incomplete: None }
26        }
27
28        pub fn len(&self) -> usize {
29            self.data
30                .len()
31                .saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
32        }
33
34        pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> {
35            let mut input: &[u8] = tail.as_ref();
36
37            if let Some(mut incomplete) = self.incomplete.take() {
38                if let Some((result, remaining)) = incomplete.try_complete(input) {
39                    input = remaining;
40
41                    match result {
42                        Ok(s) => self.data.push_str(s),
43                        Err(result_bytes) => {
44                            return Err(Error::Utf8(String::from_utf8_lossy(result_bytes).into()))
45                        }
46                    }
47                } else {
48                    input = &[];
49                    self.incomplete = Some(incomplete);
50                }
51            }
52
53            if !input.is_empty() {
54                match utf8::decode(input) {
55                    Ok(s) => {
56                        self.data.push_str(s);
57                        Ok(())
58                    }
59                    Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
60                        self.data.push_str(valid_prefix);
61                        self.incomplete = Some(incomplete_suffix);
62
63                        Ok(())
64                    }
65                    Err(DecodeError::Invalid { valid_prefix, invalid_sequence, .. }) => {
66                        self.data.push_str(valid_prefix);
67
68                        Err(Error::Utf8(String::from_utf8_lossy(invalid_sequence).into()))
69                    }
70                }
71            } else {
72                Ok(())
73            }
74        }
75
76        pub fn into_string(self) -> Result<String> {
77            if let Some(incomplete) = self.incomplete {
78                Err(Error::Utf8(format!("Incomplete string: {:?}", incomplete)))
79            } else {
80                Ok(self.data)
81            }
82        }
83    }
84}
85
86/// A struct representing the incomplete message.
87#[derive(Debug)]
88pub struct IncompleteMessage {
89    collector: IncompleteMessageCollector,
90}
91
92#[derive(Debug)]
93enum IncompleteMessageCollector {
94    Text(StringCollector),
95    Binary(Vec<u8>),
96}
97
/// The kind of message an [`IncompleteMessage`] is assembling.
///
/// This is a plain fieldless enum, so it is `Copy` and comparable; callers can
/// pass it around by value and test it with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncompleteMessageType {
    /// The message payload is UTF-8 text.
    Text,
    /// The message payload is arbitrary binary data.
    Binary,
}
107
108impl IncompleteMessage {
109    /// Create new.
110    pub fn new(msg_type: IncompleteMessageType) -> Self {
111        IncompleteMessage {
112            collector: match msg_type {
113                IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
114                IncompleteMessageType::Text => {
115                    IncompleteMessageCollector::Text(StringCollector::new())
116                }
117            },
118        }
119    }
120
121    /// Get the current filled size of the buffer.
122    pub fn len(&self) -> usize {
123        match self.collector {
124            IncompleteMessageCollector::Binary(ref b) => b.len(),
125            IncompleteMessageCollector::Text(ref t) => t.len(),
126        }
127    }
128
129    /// Checks if the incomplete message is empty
130    pub fn is_empty(&self) -> bool {
131        self.len() == 0
132    }
133
134    /// Add more data to an existing message.
135    pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T, limit: Option<usize>) -> Result<()> {
136        let max = limit.unwrap_or(usize::MAX);
137        let size = self.len();
138        let portion = tail.as_ref().len();
139
140        if size > max || portion > max - size {
141            return Err(Error::Capacity(CapacityError::MessageTooLarge {
142                size: size + portion,
143                max,
144            }));
145        }
146
147        match self.collector {
148            IncompleteMessageCollector::Binary(ref mut b) => {
149                b.extend(tail.as_ref());
150                Ok(())
151            }
152            IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
153        }
154    }
155
156    /// Convert an incomplete message into a complete one.
157    pub fn complete(self) -> Result<Message> {
158        match self.collector {
159            IncompleteMessageCollector::Binary(b) => Ok(Message::Binary(b.into())),
160            IncompleteMessageCollector::Text(t) => {
161                let text = t.into_string()?;
162                Ok(Message::Text(text.into()))
163            }
164        }
165    }
166}
167
168/// A WebSocket message
169#[derive(Debug, Clone, PartialEq, Eq)]
170pub enum Message {
171    /// A text message
172    Text(Utf8Bytes),
173    /// A binary message
174    Binary(Bytes),
175    /// A ping (control) message
176    Ping(Bytes),
177    /// A pong (control) message
178    Pong(Bytes),
179    /// A close (control) message
180    Close(Option<CloseFrame>),
181    /// Raw frame
182    Frame(Frame),
183}
184
185impl Message {
186    /// Create a new text WebSocket message from a stringable.
187    pub fn new_text<S>(string: S) -> Message
188    where
189        S: Into<Utf8Bytes>,
190    {
191        Message::Text(string.into())
192    }
193
194    /// Create a new binary WebSocket message by converting to `Bytes`.
195    pub fn new_binary<B>(binary: B) -> Message
196    where
197        B: Into<Bytes>,
198    {
199        Message::Binary(binary.into())
200    }
201
202    /// Indicates if the Message is of control protocol (`Ping`, `Pong`, `Close`)
203    pub fn is_control(&self) -> bool {
204        matches!(self, Message::Ping(_) | Message::Pong(_) | Message::Close(_))
205    }
206
207    /// Indicates if the Message is of data protocol (`Text`, `Binary`)
208    pub fn is_data(&self) -> bool {
209        matches!(self, Message::Text(_) | Message::Binary(_))
210    }
211
212    /// Indicates if the Message is of `Text` protocol
213    pub fn is_text(&self) -> bool {
214        matches!(self, Message::Text(_))
215    }
216
217    /// Indicates if the Message is of `Binary` protocol
218    pub fn is_binary(&self) -> bool {
219        matches!(self, Message::Binary(_))
220    }
221
222    /// Get the length of the WebSocket message.
223    pub fn len(&self) -> usize {
224        match *self {
225            Message::Text(ref s) => s.len(),
226            Message::Binary(ref b) | Message::Ping(ref b) | Message::Pong(ref b) => b.len(),
227            Message::Close(ref frame) => frame.as_ref().map(|d| d.reason.len()).unwrap_or(0),
228            Message::Frame(ref frame) => frame.len(),
229        }
230    }
231
232    /// Returns true if the WebSocket message has no content.
233    /// For example, if the other side of the connection sent an empty string.
234    pub fn is_empty(&self) -> bool {
235        self.len() == 0
236    }
237
238    /// Parses the message data
239    pub fn into_data(self) -> Bytes {
240        match self {
241            Self::Text(s) => s.into(),
242            Self::Binary(b) | Self::Ping(b) | Self::Pong(b) => b,
243            Self::Close(None) => <_>::default(),
244            Self::Close(Some(frame)) => frame.reason.into(),
245            Self::Frame(frame) => frame.into_payload(),
246        }
247    }
248}
249
250impl From<String> for Message {
251    #[inline]
252    fn from(value: String) -> Self {
253        Message::new_text(value)
254    }
255}
256
257impl<'s> From<&'s str> for Message {
258    #[inline]
259    fn from(value: &'s str) -> Self {
260        Message::new_text(value)
261    }
262}
263
264impl<'b> From<&'b [u8]> for Message {
265    #[inline]
266    fn from(value: &'b [u8]) -> Self {
267        Message::new_binary(Bytes::copy_from_slice(value))
268    }
269}
270
271impl From<Bytes> for Message {
272    fn from(value: Bytes) -> Self {
273        Message::new_binary(value)
274    }
275}
276
277impl From<Vec<u8>> for Message {
278    #[inline]
279    fn from(value: Vec<u8>) -> Self {
280        Message::new_binary(value)
281    }
282}
283
284impl From<Message> for Bytes {
285    #[inline]
286    fn from(value: Message) -> Self {
287        value.into_data()
288    }
289}
290
291impl std::fmt::Display for Message {
292    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293        match self {
294            Message::Text(s) => write!(f, "Text({})", s),
295            Message::Binary(b) => write!(f, "Binary({} bytes)", b.len()),
296            Message::Ping(_) => write!(f, "Ping"),
297            Message::Pong(_) => write!(f, "Pong"),
298            Message::Close(Some(frame)) => write!(f, "Close({}, {})", frame.code, frame.reason),
299            Message::Close(None) => write!(f, "Close"),
300            _ => Ok(()),
301        }
302    }
303}