// rama_ws/protocol/message.rs

1use self::string_collect::StringCollector;
2use super::frame::{CloseFrame, Frame};
3use crate::protocol::error::ProtocolError;
4use crate::protocol::frame::Utf8Bytes;
5use rama_core::bytes::Bytes;
6use rama_utils::str::utf8;
7use std::{fmt, result::Result as StdResult, str};
8
9mod string_collect {
10    use rama_core::error::OpaqueError;
11
12    use super::*;
13
14    #[derive(Debug)]
15    pub(super) struct StringCollector {
16        data: String,
17        incomplete: Option<utf8::Incomplete>,
18    }
19
20    impl StringCollector {
21        pub(super) fn new() -> Self {
22            Self {
23                data: String::new(),
24                incomplete: None,
25            }
26        }
27
28        pub(super) fn len(&self) -> usize {
29            self.data
30                .len()
31                .saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
32        }
33
34        pub(super) fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<(), ProtocolError> {
35            let mut input: &[u8] = tail.as_ref();
36
37            if let Some(mut incomplete) = self.incomplete.take() {
38                if let Some((result, rest)) = incomplete.try_complete(input) {
39                    input = rest;
40                    match result {
41                        Ok(text) => self.data.push_str(text),
42                        Err(result_bytes) => {
43                            return Err(ProtocolError::Utf8(OpaqueError::from_display(
44                                String::from_utf8_lossy(result_bytes).to_string(),
45                            )));
46                        }
47                    }
48                } else {
49                    input = &[];
50                    self.incomplete = Some(incomplete);
51                }
52            }
53
54            if !input.is_empty() {
55                match utf8::decode(input) {
56                    Ok(text) => {
57                        self.data.push_str(text);
58                        Ok(())
59                    }
60                    Err(utf8::DecodeError::Incomplete {
61                        valid_prefix,
62                        incomplete_suffix,
63                    }) => {
64                        self.data.push_str(valid_prefix);
65                        self.incomplete = Some(incomplete_suffix);
66                        Ok(())
67                    }
68                    Err(utf8::DecodeError::Invalid {
69                        valid_prefix,
70                        invalid_sequence,
71                        ..
72                    }) => {
73                        self.data.push_str(valid_prefix);
74                        Err(ProtocolError::Utf8(OpaqueError::from_display(
75                            String::from_utf8_lossy(invalid_sequence).to_string(),
76                        )))
77                    }
78                }
79            } else {
80                Ok(())
81            }
82        }
83
84        pub(super) fn into_string(self) -> Result<String, ProtocolError> {
85            if let Some(incomplete) = self.incomplete {
86                Err(ProtocolError::Utf8(OpaqueError::from_display(format!(
87                    "incomplete string: {incomplete:?}",
88                ))))
89            } else {
90                Ok(self.data)
91            }
92        }
93    }
94}
95
/// A WebSocket message whose fragments are still being received.
///
/// Fragments are appended with `extend`, and `complete` turns the buffered
/// data into a [`Message`].
#[derive(Debug)]
pub(super) struct IncompleteMessage {
    collector: IncompleteMessageCollector,
}

/// Type-specific accumulator backing an [`IncompleteMessage`].
#[derive(Debug)]
enum IncompleteMessageCollector {
    /// Text payload, decoded incrementally as UTF-8 by a `StringCollector`.
    Text(StringCollector),
    /// Binary payload, stored as raw bytes.
    Binary(Vec<u8>),
}
107
108impl IncompleteMessage {
109    /// Create new.
110    pub(super) fn new(message_type: IncompleteMessageType) -> Self {
111        Self {
112            collector: match message_type {
113                IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
114                IncompleteMessageType::Text => {
115                    IncompleteMessageCollector::Text(StringCollector::new())
116                }
117            },
118        }
119    }
120
121    /// Get the current filled size of the buffer.
122    pub(super) fn len(&self) -> usize {
123        match self.collector {
124            IncompleteMessageCollector::Text(ref t) => t.len(),
125            IncompleteMessageCollector::Binary(ref b) => b.len(),
126        }
127    }
128
129    /// Add more data to an existing message.
130    pub(super) fn extend<T: AsRef<[u8]>>(
131        &mut self,
132        tail: T,
133        size_limit: Option<usize>,
134    ) -> Result<(), ProtocolError> {
135        // Always have a max size. This ensures an error in case of concatenating two buffers
136        // of more than `usize::max_value()` bytes in total.
137        let max_size = size_limit.unwrap_or_else(usize::max_value);
138        let my_size = self.len();
139        let portion_size = tail.as_ref().len();
140        // Be careful about integer overflows here.
141        if my_size > max_size || portion_size > max_size - my_size {
142            return Err(ProtocolError::MessageTooLong {
143                size: my_size + portion_size,
144                max_size,
145            });
146        }
147
148        match self.collector {
149            IncompleteMessageCollector::Binary(ref mut v) => {
150                v.extend(tail.as_ref());
151                Ok(())
152            }
153            IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
154        }
155    }
156
157    /// Convert an incomplete message into a complete one.
158    pub(super) fn complete(self) -> Result<Message, ProtocolError> {
159        match self.collector {
160            IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v.into())),
161            IncompleteMessageCollector::Text(t) => {
162                let text = t.into_string()?;
163                Ok(Message::text(text))
164            }
165        }
166    }
167}
168
/// The kind of payload an [`IncompleteMessage`] accumulates.
#[derive(Debug, Clone, Copy)]
pub(super) enum IncompleteMessageType {
    /// UTF-8 text payload.
    Text,
    /// Raw binary payload.
    Binary,
}
175
176/// An enum representing the various forms of a WebSocket message.
177#[derive(Debug, Eq, PartialEq, Clone)]
178pub enum Message {
179    /// A text WebSocket message
180    Text(Utf8Bytes),
181    /// A binary WebSocket message
182    Binary(Bytes),
183    /// A ping message with the specified payload
184    ///
185    /// The payload here must have a length less than 125 bytes
186    Ping(Bytes),
187    /// A pong message with the specified payload
188    ///
189    /// The payload here must have a length less than 125 bytes
190    Pong(Bytes),
191    /// A close message with the optional close frame.
192    Close(Option<CloseFrame>),
193    /// Raw frame. Note, that you're not going to get this value while reading the message.
194    Frame(Frame),
195}
196
197impl Message {
198    /// Create a new text WebSocket message from a stringable.
199    pub fn text<S>(string: S) -> Self
200    where
201        S: Into<Utf8Bytes>,
202    {
203        Self::Text(string.into())
204    }
205
206    /// Create a new binary WebSocket message by converting to `Bytes`.
207    pub fn binary<B>(bin: B) -> Self
208    where
209        B: Into<Bytes>,
210    {
211        Self::Binary(bin.into())
212    }
213
214    /// Indicates whether a message is a text message.
215    pub fn is_text(&self) -> bool {
216        matches!(*self, Self::Text(_))
217    }
218
219    /// Indicates whether a message is a binary message.
220    pub fn is_binary(&self) -> bool {
221        matches!(*self, Self::Binary(_))
222    }
223
224    /// Indicates whether a message is a ping message.
225    pub fn is_ping(&self) -> bool {
226        matches!(*self, Self::Ping(_))
227    }
228
229    /// Indicates whether a message is a pong message.
230    pub fn is_pong(&self) -> bool {
231        matches!(*self, Self::Pong(_))
232    }
233
234    /// Indicates whether a message is a close message.
235    pub fn is_close(&self) -> bool {
236        matches!(*self, Self::Close(_))
237    }
238
239    /// Get the length of the WebSocket message.
240    pub fn len(&self) -> usize {
241        match *self {
242            Self::Text(ref string) => string.len(),
243            Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => data.len(),
244            Self::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
245            Self::Frame(ref frame) => frame.len(),
246        }
247    }
248
249    /// Returns true if the WebSocket message has no content.
250    /// For example, if the other side of the connection sent an empty string.
251    pub fn is_empty(&self) -> bool {
252        self.len() == 0
253    }
254
255    /// Consume the WebSocket and return it as binary data.
256    pub fn into_data(self) -> Bytes {
257        match self {
258            Self::Text(utf8) => utf8.into(),
259            Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
260            Self::Close(None) => <_>::default(),
261            Self::Close(Some(frame)) => frame.reason.into(),
262            Self::Frame(frame) => frame.into_payload(),
263        }
264    }
265
266    /// Attempt to consume the WebSocket message and convert it to a String.
267    pub fn into_text(self) -> Result<Utf8Bytes, ProtocolError> {
268        match self {
269            Self::Text(txt) => Ok(txt),
270            Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(data.try_into()?),
271            Self::Close(None) => Ok(<_>::default()),
272            Self::Close(Some(frame)) => Ok(frame.reason),
273            Self::Frame(frame) => Ok(frame.into_text()?),
274        }
275    }
276
277    /// Attempt to get a &str from the WebSocket message,
278    /// this will try to convert binary data to utf8.
279    pub fn to_text(&self) -> Result<&str, ProtocolError> {
280        match *self {
281            Self::Text(ref string) => Ok(string.as_str()),
282            Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
283                Ok(str::from_utf8(data)?)
284            }
285            Self::Close(None) => Ok(""),
286            Self::Close(Some(ref frame)) => Ok(&frame.reason),
287            Self::Frame(ref frame) => Ok(frame.to_text()?),
288        }
289    }
290}
291
292impl From<String> for Message {
293    #[inline]
294    fn from(string: String) -> Self {
295        Self::text(string)
296    }
297}
298
299impl<'s> From<&'s str> for Message {
300    #[inline]
301    fn from(string: &'s str) -> Self {
302        Self::text(string)
303    }
304}
305
306impl<'b> From<&'b [u8]> for Message {
307    #[inline]
308    fn from(data: &'b [u8]) -> Self {
309        Self::binary(Bytes::copy_from_slice(data))
310    }
311}
312
313impl From<Bytes> for Message {
314    fn from(data: Bytes) -> Self {
315        Self::binary(data)
316    }
317}
318
319impl From<Vec<u8>> for Message {
320    #[inline]
321    fn from(data: Vec<u8>) -> Self {
322        Self::binary(data)
323    }
324}
325
326impl From<Message> for Bytes {
327    #[inline]
328    fn from(message: Message) -> Self {
329        message.into_data()
330    }
331}
332
333impl fmt::Display for Message {
334    fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> {
335        match self {
336            Self::Text(utf8_bytes) => write!(f, "Message::Text({utf8_bytes})"),
337            Self::Binary(bytes) => write!(f, "Message::Binary({bytes:x})"),
338            Self::Ping(bytes) => write!(f, "Message::Ping({bytes:x})"),
339            Self::Pong(bytes) => write!(f, "Message::Pong({bytes:x})"),
340            Self::Close(_) => write!(f, "Message::Close<length={}>", self.len()),
341            Self::Frame(_) => write!(f, "Message::Frame<length={}>", self.len()),
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display() {
        let t = Message::text("test".to_owned());
        assert_eq!(t.to_string(), "Message::Text(test)".to_owned());

        let bin = Message::binary(vec![0, 1, 3, 4, 241]);
        assert_eq!(bin.to_string(), "Message::Binary(00010304f1)".to_owned());
    }

    #[test]
    fn binary_convert() {
        let bin = [6u8, 7, 8, 9, 10, 241];
        let msg = Message::from(&bin[..]);
        assert!(msg.is_binary());
        assert!(msg.into_text().is_err());
    }

    #[test]
    fn binary_convert_bytes() {
        let bin = Bytes::from_iter([6u8, 7, 8, 9, 10, 241]);
        let msg = Message::from(bin);
        assert!(msg.is_binary());
        assert!(msg.into_text().is_err());
    }

    #[test]
    fn binary_convert_vec() {
        let bin = vec![6u8, 7, 8, 9, 10, 241];
        let msg = Message::from(bin);
        assert!(msg.is_binary());
        assert!(msg.into_text().is_err());
    }

    #[test]
    fn binary_convert_into_bytes() {
        let bin = vec![6u8, 7, 8, 9, 10, 241];
        let bin_copy = bin.clone();
        let msg = Message::from(bin);
        let serialized: Bytes = msg.into();
        assert_eq!(bin_copy, serialized);
    }

    #[test]
    fn text_convert() {
        let s = "kiwotsukete";
        let msg = Message::from(s);
        assert!(msg.is_text());
    }

    #[test]
    fn incomplete_text_reassembles_split_code_point() {
        // "é" is encoded as the two bytes 0xC3 0xA9; split it across fragments.
        let bytes = "héllo".as_bytes();
        let mut msg = IncompleteMessage::new(IncompleteMessageType::Text);
        assert!(msg.extend(&bytes[..2], None).is_ok());
        // "h" plus the buffered first byte of "é".
        assert_eq!(msg.len(), 2);
        assert!(msg.extend(&bytes[2..], None).is_ok());
        assert_eq!(msg.len(), bytes.len());
        assert_eq!(msg.complete().ok(), Some(Message::text("héllo")));
    }

    #[test]
    fn incomplete_text_rejects_invalid_utf8() {
        let mut msg = IncompleteMessage::new(IncompleteMessageType::Text);
        assert!(msg.extend([0xffu8], None).is_err());
    }

    #[test]
    fn incomplete_text_rejects_truncated_code_point() {
        let mut msg = IncompleteMessage::new(IncompleteMessageType::Text);
        assert!(msg.extend([0xc3u8], None).is_ok());
        assert_eq!(msg.len(), 1);
        assert!(msg.complete().is_err());
    }

    #[test]
    fn incomplete_binary_enforces_size_limit() {
        let mut msg = IncompleteMessage::new(IncompleteMessageType::Binary);
        assert!(msg.extend([1u8, 2, 3], Some(4)).is_ok());
        assert!(matches!(
            msg.extend([4u8, 5], Some(4)),
            Err(ProtocolError::MessageTooLong {
                size: 5,
                max_size: 4
            })
        ));
        // The rejected fragment must not have been appended.
        assert_eq!(msg.len(), 3);
        assert_eq!(msg.complete().ok(), Some(Message::binary(vec![1u8, 2, 3])));
    }
}