1use self::string_collect::StringCollector;
2use super::frame::{CloseFrame, Frame};
3use crate::protocol::error::ProtocolError;
4use crate::protocol::frame::Utf8Bytes;
5use rama_core::bytes::Bytes;
6use rama_utils::str::utf8;
7use std::{fmt, result::Result as StdResult, str};
8
9mod string_collect {
10 use rama_core::error::OpaqueError;
11
12 use super::*;
13
14 #[derive(Debug)]
15 pub(super) struct StringCollector {
16 data: String,
17 incomplete: Option<utf8::Incomplete>,
18 }
19
20 impl StringCollector {
21 pub(super) fn new() -> Self {
22 Self {
23 data: String::new(),
24 incomplete: None,
25 }
26 }
27
28 pub(super) fn len(&self) -> usize {
29 self.data
30 .len()
31 .saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
32 }
33
34 pub(super) fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<(), ProtocolError> {
35 let mut input: &[u8] = tail.as_ref();
36
37 if let Some(mut incomplete) = self.incomplete.take() {
38 if let Some((result, rest)) = incomplete.try_complete(input) {
39 input = rest;
40 match result {
41 Ok(text) => self.data.push_str(text),
42 Err(result_bytes) => {
43 return Err(ProtocolError::Utf8(OpaqueError::from_display(
44 String::from_utf8_lossy(result_bytes).to_string(),
45 )));
46 }
47 }
48 } else {
49 input = &[];
50 self.incomplete = Some(incomplete);
51 }
52 }
53
54 if !input.is_empty() {
55 match utf8::decode(input) {
56 Ok(text) => {
57 self.data.push_str(text);
58 Ok(())
59 }
60 Err(utf8::DecodeError::Incomplete {
61 valid_prefix,
62 incomplete_suffix,
63 }) => {
64 self.data.push_str(valid_prefix);
65 self.incomplete = Some(incomplete_suffix);
66 Ok(())
67 }
68 Err(utf8::DecodeError::Invalid {
69 valid_prefix,
70 invalid_sequence,
71 ..
72 }) => {
73 self.data.push_str(valid_prefix);
74 Err(ProtocolError::Utf8(OpaqueError::from_display(
75 String::from_utf8_lossy(invalid_sequence).to_string(),
76 )))
77 }
78 }
79 } else {
80 Ok(())
81 }
82 }
83
84 pub(super) fn into_string(self) -> Result<String, ProtocolError> {
85 if let Some(incomplete) = self.incomplete {
86 Err(ProtocolError::Utf8(OpaqueError::from_display(format!(
87 "incomplete string: {incomplete:?}",
88 ))))
89 } else {
90 Ok(self.data)
91 }
92 }
93 }
94}
95
96#[derive(Debug)]
98pub(super) struct IncompleteMessage {
99 collector: IncompleteMessageCollector,
100}
101
102#[derive(Debug)]
103enum IncompleteMessageCollector {
104 Text(StringCollector),
105 Binary(Vec<u8>),
106}
107
108impl IncompleteMessage {
109 pub(super) fn new(message_type: IncompleteMessageType) -> Self {
111 Self {
112 collector: match message_type {
113 IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
114 IncompleteMessageType::Text => {
115 IncompleteMessageCollector::Text(StringCollector::new())
116 }
117 },
118 }
119 }
120
121 pub(super) fn len(&self) -> usize {
123 match self.collector {
124 IncompleteMessageCollector::Text(ref t) => t.len(),
125 IncompleteMessageCollector::Binary(ref b) => b.len(),
126 }
127 }
128
129 pub(super) fn extend<T: AsRef<[u8]>>(
131 &mut self,
132 tail: T,
133 size_limit: Option<usize>,
134 ) -> Result<(), ProtocolError> {
135 let max_size = size_limit.unwrap_or_else(usize::max_value);
138 let my_size = self.len();
139 let portion_size = tail.as_ref().len();
140 if my_size > max_size || portion_size > max_size - my_size {
142 return Err(ProtocolError::MessageTooLong {
143 size: my_size + portion_size,
144 max_size,
145 });
146 }
147
148 match self.collector {
149 IncompleteMessageCollector::Binary(ref mut v) => {
150 v.extend(tail.as_ref());
151 Ok(())
152 }
153 IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
154 }
155 }
156
157 pub(super) fn complete(self) -> Result<Message, ProtocolError> {
159 match self.collector {
160 IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v.into())),
161 IncompleteMessageCollector::Text(t) => {
162 let text = t.into_string()?;
163 Ok(Message::text(text))
164 }
165 }
166 }
167}
168
169#[derive(Debug, Clone, Copy)]
171pub(super) enum IncompleteMessageType {
172 Text,
173 Binary,
174}
175
176#[derive(Debug, Eq, PartialEq, Clone)]
178pub enum Message {
179 Text(Utf8Bytes),
181 Binary(Bytes),
183 Ping(Bytes),
187 Pong(Bytes),
191 Close(Option<CloseFrame>),
193 Frame(Frame),
195}
196
197impl Message {
198 pub fn text<S>(string: S) -> Self
200 where
201 S: Into<Utf8Bytes>,
202 {
203 Self::Text(string.into())
204 }
205
206 pub fn binary<B>(bin: B) -> Self
208 where
209 B: Into<Bytes>,
210 {
211 Self::Binary(bin.into())
212 }
213
214 pub fn is_text(&self) -> bool {
216 matches!(*self, Self::Text(_))
217 }
218
219 pub fn is_binary(&self) -> bool {
221 matches!(*self, Self::Binary(_))
222 }
223
224 pub fn is_ping(&self) -> bool {
226 matches!(*self, Self::Ping(_))
227 }
228
229 pub fn is_pong(&self) -> bool {
231 matches!(*self, Self::Pong(_))
232 }
233
234 pub fn is_close(&self) -> bool {
236 matches!(*self, Self::Close(_))
237 }
238
239 pub fn len(&self) -> usize {
241 match *self {
242 Self::Text(ref string) => string.len(),
243 Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => data.len(),
244 Self::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
245 Self::Frame(ref frame) => frame.len(),
246 }
247 }
248
249 pub fn is_empty(&self) -> bool {
252 self.len() == 0
253 }
254
255 pub fn into_data(self) -> Bytes {
257 match self {
258 Self::Text(utf8) => utf8.into(),
259 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
260 Self::Close(None) => <_>::default(),
261 Self::Close(Some(frame)) => frame.reason.into(),
262 Self::Frame(frame) => frame.into_payload(),
263 }
264 }
265
266 pub fn into_text(self) -> Result<Utf8Bytes, ProtocolError> {
268 match self {
269 Self::Text(txt) => Ok(txt),
270 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(data.try_into()?),
271 Self::Close(None) => Ok(<_>::default()),
272 Self::Close(Some(frame)) => Ok(frame.reason),
273 Self::Frame(frame) => Ok(frame.into_text()?),
274 }
275 }
276
277 pub fn to_text(&self) -> Result<&str, ProtocolError> {
280 match *self {
281 Self::Text(ref string) => Ok(string.as_str()),
282 Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
283 Ok(str::from_utf8(data)?)
284 }
285 Self::Close(None) => Ok(""),
286 Self::Close(Some(ref frame)) => Ok(&frame.reason),
287 Self::Frame(ref frame) => Ok(frame.to_text()?),
288 }
289 }
290}
291
292impl From<String> for Message {
293 #[inline]
294 fn from(string: String) -> Self {
295 Self::text(string)
296 }
297}
298
299impl<'s> From<&'s str> for Message {
300 #[inline]
301 fn from(string: &'s str) -> Self {
302 Self::text(string)
303 }
304}
305
306impl<'b> From<&'b [u8]> for Message {
307 #[inline]
308 fn from(data: &'b [u8]) -> Self {
309 Self::binary(Bytes::copy_from_slice(data))
310 }
311}
312
313impl From<Bytes> for Message {
314 fn from(data: Bytes) -> Self {
315 Self::binary(data)
316 }
317}
318
319impl From<Vec<u8>> for Message {
320 #[inline]
321 fn from(data: Vec<u8>) -> Self {
322 Self::binary(data)
323 }
324}
325
326impl From<Message> for Bytes {
327 #[inline]
328 fn from(message: Message) -> Self {
329 message.into_data()
330 }
331}
332
333impl fmt::Display for Message {
334 fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> {
335 match self {
336 Self::Text(utf8_bytes) => write!(f, "Message::Text({utf8_bytes})"),
337 Self::Binary(bytes) => write!(f, "Message::Binary({bytes:x})"),
338 Self::Ping(bytes) => write!(f, "Message::Ping({bytes:x})"),
339 Self::Pong(bytes) => write!(f, "Message::Pong({bytes:x})"),
340 Self::Close(_) => write!(f, "Message::Close<length={}>", self.len()),
341 Self::Frame(_) => write!(f, "Message::Frame<length={}>", self.len()),
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display() {
        let t = Message::text("test".to_owned());
        assert_eq!(t.to_string(), "Message::Text(test)".to_owned());

        let bin = Message::binary(vec![0, 1, 3, 4, 241]);
        assert_eq!(bin.to_string(), "Message::Binary(00010304f1)".to_owned());
    }

    #[test]
    fn binary_convert() {
        let bin = [6u8, 7, 8, 9, 10, 241];
        let msg = Message::from(&bin[..]);
        assert!(msg.is_binary());
        assert!(msg.into_text().is_err());
    }

    #[test]
    fn binary_convert_bytes() {
        let bin = Bytes::from_iter([6u8, 7, 8, 9, 10, 241]);
        let msg = Message::from(bin);
        assert!(msg.is_binary());
        assert!(msg.into_text().is_err());
    }

    #[test]
    fn binary_convert_vec() {
        let bin = vec![6u8, 7, 8, 9, 10, 241];
        let msg = Message::from(bin);
        assert!(msg.is_binary());
        assert!(msg.into_text().is_err());
    }

    #[test]
    fn binary_convert_into_bytes() {
        let bin = vec![6u8, 7, 8, 9, 10, 241];
        let bin_copy = bin.clone();
        let msg = Message::from(bin);
        let serialized: Bytes = msg.into();
        assert_eq!(bin_copy, serialized);
    }

    #[test]
    fn text_convert() {
        let s = "kiwotsukete";
        let msg = Message::from(s);
        assert!(msg.is_text());
    }

    #[test]
    fn text_accessors() {
        let msg = Message::from("kiwotsukete");
        assert_eq!(msg.len(), 11);
        assert_eq!(msg.to_text().unwrap(), "kiwotsukete");
        assert_eq!(msg.into_data(), Bytes::from_static(b"kiwotsukete"));
    }

    #[test]
    fn close_without_frame_is_empty() {
        let msg = Message::Close(None);
        assert!(msg.is_close());
        assert!(msg.is_empty());
        assert_eq!(msg.to_string(), "Message::Close<length=0>");
        assert_eq!(msg.to_text().unwrap(), "");
        assert!(msg.into_data().is_empty());
    }

    #[test]
    fn incomplete_text_across_fragments() {
        // "ü" is encoded as 0xC3 0xBC; split it across two fragments.
        let mut msg = IncompleteMessage::new(IncompleteMessageType::Text);
        msg.extend(b"a\xC3", None).unwrap();
        // The buffered partial code point counts towards the length.
        assert_eq!(msg.len(), 2);
        msg.extend(b"\xBCb", None).unwrap();
        assert_eq!(msg.complete().unwrap(), Message::text("a\u{fc}b"));
    }

    #[test]
    fn incomplete_text_truncated_is_error() {
        // First two bytes of the three-byte encoding of "€" (0xE2 0x82 0xAC).
        let mut msg = IncompleteMessage::new(IncompleteMessageType::Text);
        msg.extend([0xE2u8, 0x82], None).unwrap();
        assert!(msg.complete().is_err());
    }

    #[test]
    fn incomplete_text_invalid_is_error() {
        let mut msg = IncompleteMessage::new(IncompleteMessageType::Text);
        let err = msg.extend([b'o', b'k', 0xFF], None).unwrap_err();
        assert!(matches!(err, ProtocolError::Utf8(_)));
    }

    #[test]
    fn incomplete_message_size_limit() {
        let mut msg = IncompleteMessage::new(IncompleteMessageType::Binary);
        msg.extend([1u8, 2, 3], Some(4)).unwrap();

        let err = msg.extend([4u8, 5], Some(4)).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::MessageTooLong {
                size: 5,
                max_size: 4
            }
        ));
        // The rejected portion must not have been appended.
        assert_eq!(msg.len(), 3);

        // Filling up to exactly the limit is allowed.
        msg.extend([4u8], Some(4)).unwrap();
        assert_eq!(msg.complete().unwrap(), Message::binary(vec![1u8, 2, 3, 4]));
    }
}