aerosocket_core/
message.rs

1//! Message handling for AeroSocket
2//!
3//! This module provides high-level message types and handling for WebSocket messages,
4//! including support for text, binary, ping, pong, and close messages.
5
6use crate::error::{CloseCode, Error, ProtocolError, Result};
7use crate::frame::{Frame, FrameKind};
8use crate::protocol::Opcode;
9use bytes::{Bytes, BytesMut};
10use std::fmt;
11
12/// Represents a complete WebSocket message
13#[derive(Debug, Clone)]
14pub enum Message {
15    /// Text message
16    Text(TextMessage),
17    /// Binary message
18    Binary(BinaryMessage),
19    /// Ping message
20    Ping(PingMessage),
21    /// Pong message
22    Pong(PongMessage),
23    /// Close message
24    Close(CloseMessage),
25}
26
27impl Message {
28    /// Create a text message
29    pub fn text(text: impl Into<String>) -> Self {
30        Self::Text(TextMessage::new(text))
31    }
32
33    /// Create a binary message
34    pub fn binary(data: impl Into<Bytes>) -> Self {
35        Self::Binary(BinaryMessage::new(data))
36    }
37
38    /// Create a ping message
39    pub fn ping(data: Option<Vec<u8>>) -> Self {
40        Self::Ping(PingMessage::new(data))
41    }
42
43    /// Create a pong message
44    pub fn pong(data: Option<Vec<u8>>) -> Self {
45        Self::Pong(PongMessage::new(data))
46    }
47
48    /// Create a close message
49    pub fn close(code: Option<u16>, reason: Option<String>) -> Self {
50        Self::Close(CloseMessage::new(code, reason))
51    }
52
53    /// Get the message kind
54    pub fn kind(&self) -> MessageKind {
55        match self {
56            Message::Text(_) => MessageKind::Text,
57            Message::Binary(_) => MessageKind::Binary,
58            Message::Ping(_) => MessageKind::Ping,
59            Message::Pong(_) => MessageKind::Pong,
60            Message::Close(_) => MessageKind::Close,
61        }
62    }
63
64    /// Check if this is a control message
65    pub fn is_control(&self) -> bool {
66        matches!(
67            self,
68            Message::Ping(_) | Message::Pong(_) | Message::Close(_)
69        )
70    }
71
72    /// Check if this is a data message
73    pub fn is_data(&self) -> bool {
74        matches!(self, Message::Text(_) | Message::Binary(_))
75    }
76
77    /// Get the message payload as text
78    pub fn as_text(&self) -> Option<&str> {
79        match self {
80            Message::Text(msg) => Some(msg.as_str()),
81            _ => None,
82        }
83    }
84
85    /// Get the message payload as bytes
86    pub fn as_bytes(&self) -> &[u8] {
87        match self {
88            Message::Text(msg) => msg.as_bytes(),
89            Message::Binary(msg) => msg.as_bytes(),
90            Message::Ping(msg) => msg.as_bytes(),
91            Message::Pong(msg) => msg.as_bytes(),
92            Message::Close(msg) => msg.as_bytes(),
93        }
94    }
95
96    /// Convert message to frames
97    pub fn to_frames(&self) -> Vec<Frame> {
98        match self {
99            Message::Text(msg) => vec![msg.to_frame()],
100            Message::Binary(msg) => vec![msg.to_frame()],
101            Message::Ping(msg) => vec![msg.to_frame()],
102            Message::Pong(msg) => vec![msg.to_frame()],
103            Message::Close(msg) => vec![msg.to_frame()],
104        }
105    }
106
107    /// Convert message to a single frame
108    pub fn to_frame(&self) -> Frame {
109        match self {
110            Message::Text(msg) => msg.to_frame(),
111            Message::Binary(msg) => msg.to_frame(),
112            Message::Ping(msg) => msg.to_frame(),
113            Message::Pong(msg) => msg.to_frame(),
114            Message::Close(msg) => msg.to_frame(),
115        }
116    }
117}
118
119impl fmt::Display for Message {
120    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121        match self {
122            Message::Text(msg) => write!(f, "Text({})", msg.as_str()),
123            Message::Binary(msg) => write!(f, "Binary({} bytes)", msg.len()),
124            Message::Ping(msg) => write!(f, "Ping({} bytes)", msg.len()),
125            Message::Pong(msg) => write!(f, "Pong({} bytes)", msg.len()),
126            Message::Close(msg) => write!(f, "Close({:?})", msg),
127        }
128    }
129}
130
/// Payload-free discriminant of a [`Message`], for cheap matching and
/// comparisons.
///
/// Derives `Hash` so kinds can be used as keys in hash-based collections
/// (e.g. per-kind counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageKind {
    /// Text message
    Text,
    /// Binary message
    Binary,
    /// Ping message
    Ping,
    /// Pong message
    Pong,
    /// Close message
    Close,
}

impl MessageKind {
    /// Returns `true` for the control kinds: `Ping`, `Pong` and `Close`.
    pub const fn is_control(self) -> bool {
        matches!(self, MessageKind::Ping | MessageKind::Pong | MessageKind::Close)
    }

    /// Returns `true` for the data kinds: `Text` and `Binary`.
    pub const fn is_data(self) -> bool {
        matches!(self, MessageKind::Text | MessageKind::Binary)
    }
}
145
146/// Text message
147#[derive(Debug, Clone)]
148pub struct TextMessage {
149    text: String,
150}
151
152impl TextMessage {
153    /// Create a new text message
154    pub fn new(text: impl Into<String>) -> Self {
155        Self { text: text.into() }
156    }
157
158    /// Get the text content
159    pub fn as_str(&self) -> &str {
160        &self.text
161    }
162
163    /// Get the text as bytes
164    pub fn as_bytes(&self) -> &[u8] {
165        self.text.as_bytes()
166    }
167
168    /// Get the text length
169    pub fn len(&self) -> usize {
170        self.text.len()
171    }
172
173    /// Check if the text is empty
174    pub fn is_empty(&self) -> bool {
175        self.text.is_empty()
176    }
177
178    /// Convert to frame
179    pub fn to_frame(&self) -> Frame {
180        Frame::text(self.text.clone())
181    }
182}
183
184/// Binary message
185#[derive(Debug, Clone)]
186pub struct BinaryMessage {
187    data: Bytes,
188}
189
190impl BinaryMessage {
191    /// Create a new binary message
192    pub fn new(data: impl Into<Bytes>) -> Self {
193        Self { data: data.into() }
194    }
195
196    /// Get the binary data
197    pub fn as_bytes(&self) -> &[u8] {
198        &self.data
199    }
200
201    /// Get the data length
202    pub fn len(&self) -> usize {
203        self.data.len()
204    }
205
206    /// Check if the data is empty
207    pub fn is_empty(&self) -> bool {
208        self.data.is_empty()
209    }
210
211    /// Convert to frame
212    pub fn to_frame(&self) -> Frame {
213        Frame::binary(self.data.clone())
214    }
215}
216
217/// Ping message
218#[derive(Debug, Clone)]
219pub struct PingMessage {
220    data: Bytes,
221}
222
223impl PingMessage {
224    /// Create a new ping message
225    pub fn new(data: Option<Vec<u8>>) -> Self {
226        Self {
227            data: data.map_or_else(Bytes::new, Bytes::from),
228        }
229    }
230
231    /// Get the ping data
232    pub fn as_bytes(&self) -> &[u8] {
233        &self.data
234    }
235
236    /// Get the data length
237    pub fn len(&self) -> usize {
238        self.data.len()
239    }
240
241    /// Check if the data is empty
242    pub fn is_empty(&self) -> bool {
243        self.data.is_empty()
244    }
245
246    /// Convert to frame
247    pub fn to_frame(&self) -> Frame {
248        Frame::ping(self.data.clone())
249    }
250}
251
252/// Pong message
253#[derive(Debug, Clone)]
254pub struct PongMessage {
255    data: Bytes,
256}
257
258impl PongMessage {
259    /// Create a new pong message
260    pub fn new(data: Option<Vec<u8>>) -> Self {
261        Self {
262            data: data.map_or_else(Bytes::new, Bytes::from),
263        }
264    }
265
266    /// Get the pong data
267    pub fn as_bytes(&self) -> &[u8] {
268        &self.data
269    }
270
271    /// Get the data length
272    pub fn len(&self) -> usize {
273        self.data.len()
274    }
275
276    /// Check if the data is empty
277    pub fn is_empty(&self) -> bool {
278        self.data.is_empty()
279    }
280
281    /// Convert to frame
282    pub fn to_frame(&self) -> Frame {
283        Frame::pong(self.data.clone())
284    }
285}
286
287/// Close message
288#[derive(Debug, Clone)]
289pub struct CloseMessage {
290    code: Option<u16>,
291    reason: String,
292}
293
294impl CloseMessage {
295    /// Create a new close message
296    pub fn new(code: Option<u16>, reason: Option<String>) -> Self {
297        Self {
298            code,
299            reason: reason.unwrap_or_default(),
300        }
301    }
302
303    /// Get the close code
304    pub fn code(&self) -> Option<u16> {
305        self.code
306    }
307
308    /// Get the close reason
309    pub fn reason(&self) -> &str {
310        &self.reason
311    }
312
313    /// Get the close code as CloseCode enum
314    pub fn close_code(&self) -> Option<CloseCode> {
315        self.code.map(CloseCode::from)
316    }
317
318    /// Get the message as bytes
319    pub fn as_bytes(&self) -> &[u8] {
320        self.reason.as_bytes()
321    }
322
323    /// Get the total payload length
324    pub fn len(&self) -> usize {
325        let mut len = if self.code.is_some() { 2 } else { 0 };
326        len += self.reason.len();
327        len
328    }
329
330    /// Check if the message is empty
331    pub fn is_empty(&self) -> bool {
332        self.code.is_none() && self.reason.is_empty()
333    }
334
335    /// Convert to frame
336    pub fn to_frame(&self) -> Frame {
337        Frame::close(
338            self.code,
339            if self.reason.is_empty() {
340                None
341            } else {
342                Some(&self.reason)
343            },
344        )
345    }
346}
347
348/// Message assembler for fragmented messages
349#[derive(Debug, Default)]
350pub struct MessageAssembler {
351    /// Buffer for assembling fragmented messages
352    buffer: BytesMut,
353    /// Expected opcode for the message being assembled
354    opcode: Option<Opcode>,
355    /// Whether we're currently assembling a message
356    assembling: bool,
357}
358
359impl MessageAssembler {
360    /// Create a new message assembler
361    pub fn new() -> Self {
362        Self::default()
363    }
364
365    /// Feed a frame and try to assemble a complete message
366    pub fn feed_frame(&mut self, frame: Frame) -> Result<Option<Message>> {
367        if frame.is_control() {
368            // Control frames are never fragmented
369            return Ok(Some(self.control_frame_to_message(frame)?));
370        }
371
372        if !frame.fin {
373            // Fragmented frame
374            if !self.assembling {
375                // Start of fragmented message
376                self.assembling = true;
377                self.opcode = Some(frame.opcode);
378                self.buffer.extend_from_slice(&frame.payload);
379                Ok(None)
380            } else {
381                // Continuation of fragmented message
382                if frame.opcode != Opcode::Continuation {
383                    return Err(Error::Protocol(ProtocolError::InvalidFrame(
384                        "Expected continuation frame in fragmented message".to_string(),
385                    )));
386                }
387                self.buffer.extend_from_slice(&frame.payload);
388                Ok(None)
389            }
390        } else {
391            // Final frame
392            if self.assembling {
393                // End of fragmented message
394                self.buffer.extend_from_slice(&frame.payload);
395                let message = self.assemble_complete_message()?;
396                self.reset();
397                Ok(Some(message))
398            } else {
399                // Single unfragmented frame
400                self.opcode = Some(frame.opcode);
401                self.buffer = BytesMut::from(&frame.payload[..]);
402                let message = self.assemble_complete_message()?;
403                self.reset();
404                Ok(Some(message))
405            }
406        }
407    }
408
409    /// Convert a control frame to a message
410    fn control_frame_to_message(&self, frame: Frame) -> Result<Message> {
411        match frame.kind() {
412            FrameKind::Ping => Ok(Message::ping(Some(frame.payload.to_vec()))),
413            FrameKind::Pong => Ok(Message::pong(Some(frame.payload.to_vec()))),
414            FrameKind::Close => {
415                let (code, reason) = self.parse_close_payload(&frame.payload)?;
416                Ok(Message::close(code, reason))
417            }
418            _ => Err(Error::Protocol(ProtocolError::InvalidFrame(
419                "Unexpected control frame type".to_string(),
420            ))),
421        }
422    }
423
424    /// Parse close frame payload
425    fn parse_close_payload(&self, payload: &[u8]) -> Result<(Option<u16>, Option<String>)> {
426        if payload.len() < 2 {
427            return Ok((None, None));
428        }
429
430        let code = u16::from_be_bytes([payload[0], payload[1]]);
431        let reason = if payload.len() > 2 {
432            String::from_utf8_lossy(&payload[2..]).to_string()
433        } else {
434            String::new()
435        };
436
437        Ok((Some(code), Some(reason)))
438    }
439
440    /// Assemble a complete message from the buffer
441    fn assemble_complete_message(&self) -> Result<Message> {
442        let opcode = self.opcode.ok_or_else(|| {
443            Error::Protocol(ProtocolError::InvalidFrame("No opcode set".to_string()))
444        })?;
445
446        match opcode {
447            Opcode::Text => {
448                let text =
449                    String::from_utf8(self.buffer.to_vec()).map_err(|_| Error::InvalidUtf8)?;
450                Ok(Message::text(text))
451            }
452            Opcode::Binary => Ok(Message::binary(self.buffer.clone().freeze())),
453            _ => Err(Error::Protocol(ProtocolError::InvalidFrame(
454                "Unexpected opcode for data frame".to_string(),
455            ))),
456        }
457    }
458
459    /// Reset the assembler state
460    fn reset(&mut self) {
461        self.buffer.clear();
462        self.opcode = None;
463        self.assembling = false;
464    }
465
466    /// Check if currently assembling a message
467    pub fn is_assembling(&self) -> bool {
468        self.assembling
469    }
470
471    /// Get the number of bytes currently buffered
472    pub fn buffered_bytes(&self) -> usize {
473        self.buffer.len()
474    }
475
476    /// Clear the assembler
477    pub fn clear(&mut self) {
478        self.reset();
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::Opcode;

    #[test]
    fn test_text_message() {
        let msg = Message::text("hello");
        assert_eq!(msg.kind(), MessageKind::Text);
        assert_eq!(msg.as_text(), Some("hello"));
        assert!(msg.is_data());
        assert!(!msg.is_control());
    }

    #[test]
    fn test_binary_message() {
        let data = vec![1, 2, 3, 4];
        let msg = Message::binary(data.clone());
        assert_eq!(msg.kind(), MessageKind::Binary);
        assert_eq!(msg.as_bytes(), &data[..]);
        assert!(msg.is_data());
        assert!(!msg.is_control());
    }

    #[test]
    fn test_control_messages() {
        let ping = Message::ping(Some(vec![1, 2, 3]));
        let pong = Message::pong(Some(vec![4, 5, 6]));
        let close = Message::close(Some(1000), Some("Goodbye".to_string()));

        assert!(ping.is_control());
        assert!(pong.is_control());
        assert!(close.is_control());
    }

    #[test]
    fn test_close_message() {
        let msg = Message::close(Some(1000), Some("Goodbye".to_string()));
        if let Message::Close(close_msg) = msg {
            assert_eq!(close_msg.code(), Some(1000));
            assert_eq!(close_msg.reason(), "Goodbye");
        } else {
            panic!("Expected close message");
        }
    }

    #[test]
    fn test_message_assembler() {
        let mut assembler = MessageAssembler::new();

        // Feed fragmented text frames
        let frame1 = Frame::new(Opcode::Text, "Hello, ").fin(false);
        let frame2 = Frame::new(Opcode::Continuation, "world!").fin(true);

        let msg1 = assembler.feed_frame(frame1).unwrap();
        assert!(msg1.is_none()); // Not complete yet
        assert!(assembler.is_assembling());

        let msg2 = assembler.feed_frame(frame2).unwrap();
        assert!(msg2.is_some()); // Complete message
        assert!(!assembler.is_assembling());

        if let Some(Message::Text(text_msg)) = msg2 {
            assert_eq!(text_msg.as_str(), "Hello, world!");
        } else {
            panic!("Expected text message");
        }
    }

    #[test]
    fn test_unfragmented_binary_frame() {
        let mut assembler = MessageAssembler::new();
        let frame = Message::binary(vec![1u8, 2, 3]).to_frame();

        match assembler.feed_frame(frame).unwrap() {
            Some(Message::Binary(bin)) => assert_eq!(bin.as_bytes(), &[1u8, 2, 3][..]),
            other => panic!("Expected binary message, got {:?}", other),
        }
        assert!(!assembler.is_assembling());
        assert_eq!(assembler.buffered_bytes(), 0);
    }

    #[test]
    fn test_control_frame_interleaved_with_fragments() {
        let mut assembler = MessageAssembler::new();

        let first = Frame::new(Opcode::Text, "Hel").fin(false);
        assert!(assembler.feed_frame(first).unwrap().is_none());

        // A ping between fragments is delivered immediately...
        let ping = assembler
            .feed_frame(Message::ping(Some(vec![9])).to_frame())
            .unwrap();
        assert!(matches!(ping, Some(Message::Ping(_))));
        // ...without disturbing the partially assembled message.
        assert!(assembler.is_assembling());

        let last = Frame::new(Opcode::Continuation, "lo").fin(true);
        let msg = assembler.feed_frame(last).unwrap();
        assert_eq!(msg.as_ref().and_then(Message::as_text), Some("Hello"));
    }

    #[test]
    fn test_close_frame_roundtrip() {
        let mut assembler = MessageAssembler::new();
        let frame = Message::close(Some(1000), Some("bye".to_string())).to_frame();

        match assembler.feed_frame(frame).unwrap() {
            Some(Message::Close(close)) => {
                assert_eq!(close.code(), Some(1000));
                assert_eq!(close.reason(), "bye");
            }
            other => panic!("Expected close message, got {:?}", other),
        }
    }

    #[test]
    fn test_continuation_without_start_is_rejected() {
        let mut assembler = MessageAssembler::new();
        let frame = Frame::new(Opcode::Continuation, "orphan").fin(true);
        assert!(assembler.feed_frame(frame).is_err());
    }

    #[test]
    fn test_message_display() {
        let text_msg = Message::text("hello");
        let binary_msg = Message::binary(vec![1, 2, 3]);
        let ping_msg = Message::ping(Some(vec![4, 5]));

        assert_eq!(text_msg.to_string(), "Text(hello)");
        assert_eq!(binary_msg.to_string(), "Binary(3 bytes)");
        assert_eq!(ping_msg.to_string(), "Ping(2 bytes)");
    }
}