1use crate::error::{CloseCode, Error, ProtocolError, Result};
7use crate::frame::{Frame, FrameKind};
8use crate::protocol::Opcode;
9use bytes::{Bytes, BytesMut};
10use std::fmt;
11
12#[derive(Debug, Clone)]
14pub enum Message {
15 Text(TextMessage),
17 Binary(BinaryMessage),
19 Ping(PingMessage),
21 Pong(PongMessage),
23 Close(CloseMessage),
25}
26
27impl Message {
28 pub fn text(text: impl Into<String>) -> Self {
30 Self::Text(TextMessage::new(text))
31 }
32
33 pub fn binary(data: impl Into<Bytes>) -> Self {
35 Self::Binary(BinaryMessage::new(data))
36 }
37
38 pub fn ping(data: Option<Vec<u8>>) -> Self {
40 Self::Ping(PingMessage::new(data))
41 }
42
43 pub fn pong(data: Option<Vec<u8>>) -> Self {
45 Self::Pong(PongMessage::new(data))
46 }
47
48 pub fn close(code: Option<u16>, reason: Option<String>) -> Self {
50 Self::Close(CloseMessage::new(code, reason))
51 }
52
53 pub fn kind(&self) -> MessageKind {
55 match self {
56 Message::Text(_) => MessageKind::Text,
57 Message::Binary(_) => MessageKind::Binary,
58 Message::Ping(_) => MessageKind::Ping,
59 Message::Pong(_) => MessageKind::Pong,
60 Message::Close(_) => MessageKind::Close,
61 }
62 }
63
64 pub fn is_control(&self) -> bool {
66 matches!(
67 self,
68 Message::Ping(_) | Message::Pong(_) | Message::Close(_)
69 )
70 }
71
72 pub fn is_data(&self) -> bool {
74 matches!(self, Message::Text(_) | Message::Binary(_))
75 }
76
77 pub fn as_text(&self) -> Option<&str> {
79 match self {
80 Message::Text(msg) => Some(msg.as_str()),
81 _ => None,
82 }
83 }
84
85 pub fn as_bytes(&self) -> &[u8] {
87 match self {
88 Message::Text(msg) => msg.as_bytes(),
89 Message::Binary(msg) => msg.as_bytes(),
90 Message::Ping(msg) => msg.as_bytes(),
91 Message::Pong(msg) => msg.as_bytes(),
92 Message::Close(msg) => msg.as_bytes(),
93 }
94 }
95
96 pub fn to_frames(&self) -> Vec<Frame> {
98 match self {
99 Message::Text(msg) => vec![msg.to_frame()],
100 Message::Binary(msg) => vec![msg.to_frame()],
101 Message::Ping(msg) => vec![msg.to_frame()],
102 Message::Pong(msg) => vec![msg.to_frame()],
103 Message::Close(msg) => vec![msg.to_frame()],
104 }
105 }
106
107 pub fn to_frame(&self) -> Frame {
109 match self {
110 Message::Text(msg) => msg.to_frame(),
111 Message::Binary(msg) => msg.to_frame(),
112 Message::Ping(msg) => msg.to_frame(),
113 Message::Pong(msg) => msg.to_frame(),
114 Message::Close(msg) => msg.to_frame(),
115 }
116 }
117}
118
119impl fmt::Display for Message {
120 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121 match self {
122 Message::Text(msg) => write!(f, "Text({})", msg.as_str()),
123 Message::Binary(msg) => write!(f, "Binary({} bytes)", msg.len()),
124 Message::Ping(msg) => write!(f, "Ping({} bytes)", msg.len()),
125 Message::Pong(msg) => write!(f, "Pong({} bytes)", msg.len()),
126 Message::Close(msg) => write!(f, "Close({:?})", msg),
127 }
128 }
129}
130
/// Payload-free discriminant of a [`Message`], handy for dispatch and comparisons.
///
/// `Hash` is derived (alongside `Eq`) so kinds can be used as keys in hash maps and
/// sets, e.g. for per-kind counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageKind {
    /// A text data message.
    Text,
    /// A binary data message.
    Binary,
    /// A ping control message.
    Ping,
    /// A pong control message.
    Pong,
    /// A close control message.
    Close,
}
145
146#[derive(Debug, Clone)]
148pub struct TextMessage {
149 text: String,
150}
151
152impl TextMessage {
153 pub fn new(text: impl Into<String>) -> Self {
155 Self { text: text.into() }
156 }
157
158 pub fn as_str(&self) -> &str {
160 &self.text
161 }
162
163 pub fn as_bytes(&self) -> &[u8] {
165 self.text.as_bytes()
166 }
167
168 pub fn len(&self) -> usize {
170 self.text.len()
171 }
172
173 pub fn is_empty(&self) -> bool {
175 self.text.is_empty()
176 }
177
178 pub fn to_frame(&self) -> Frame {
180 Frame::text(self.text.clone())
181 }
182}
183
184#[derive(Debug, Clone)]
186pub struct BinaryMessage {
187 data: Bytes,
188}
189
190impl BinaryMessage {
191 pub fn new(data: impl Into<Bytes>) -> Self {
193 Self { data: data.into() }
194 }
195
196 pub fn as_bytes(&self) -> &[u8] {
198 &self.data
199 }
200
201 pub fn len(&self) -> usize {
203 self.data.len()
204 }
205
206 pub fn is_empty(&self) -> bool {
208 self.data.is_empty()
209 }
210
211 pub fn to_frame(&self) -> Frame {
213 Frame::binary(self.data.clone())
214 }
215}
216
217#[derive(Debug, Clone)]
219pub struct PingMessage {
220 data: Bytes,
221}
222
223impl PingMessage {
224 pub fn new(data: Option<Vec<u8>>) -> Self {
226 Self {
227 data: data.map_or_else(Bytes::new, Bytes::from),
228 }
229 }
230
231 pub fn as_bytes(&self) -> &[u8] {
233 &self.data
234 }
235
236 pub fn len(&self) -> usize {
238 self.data.len()
239 }
240
241 pub fn is_empty(&self) -> bool {
243 self.data.is_empty()
244 }
245
246 pub fn to_frame(&self) -> Frame {
248 Frame::ping(self.data.clone())
249 }
250}
251
252#[derive(Debug, Clone)]
254pub struct PongMessage {
255 data: Bytes,
256}
257
258impl PongMessage {
259 pub fn new(data: Option<Vec<u8>>) -> Self {
261 Self {
262 data: data.map_or_else(Bytes::new, Bytes::from),
263 }
264 }
265
266 pub fn as_bytes(&self) -> &[u8] {
268 &self.data
269 }
270
271 pub fn len(&self) -> usize {
273 self.data.len()
274 }
275
276 pub fn is_empty(&self) -> bool {
278 self.data.is_empty()
279 }
280
281 pub fn to_frame(&self) -> Frame {
283 Frame::pong(self.data.clone())
284 }
285}
286
287#[derive(Debug, Clone)]
289pub struct CloseMessage {
290 code: Option<u16>,
291 reason: String,
292}
293
294impl CloseMessage {
295 pub fn new(code: Option<u16>, reason: Option<String>) -> Self {
297 Self {
298 code,
299 reason: reason.unwrap_or_default(),
300 }
301 }
302
303 pub fn code(&self) -> Option<u16> {
305 self.code
306 }
307
308 pub fn reason(&self) -> &str {
310 &self.reason
311 }
312
313 pub fn close_code(&self) -> Option<CloseCode> {
315 self.code.map(CloseCode::from)
316 }
317
318 pub fn as_bytes(&self) -> &[u8] {
320 self.reason.as_bytes()
321 }
322
323 pub fn len(&self) -> usize {
325 let mut len = if self.code.is_some() { 2 } else { 0 };
326 len += self.reason.len();
327 len
328 }
329
330 pub fn is_empty(&self) -> bool {
332 self.code.is_none() && self.reason.is_empty()
333 }
334
335 pub fn to_frame(&self) -> Frame {
337 Frame::close(
338 self.code,
339 if self.reason.is_empty() {
340 None
341 } else {
342 Some(&self.reason)
343 },
344 )
345 }
346}
347
348#[derive(Debug, Default)]
350pub struct MessageAssembler {
351 buffer: BytesMut,
353 opcode: Option<Opcode>,
355 assembling: bool,
357}
358
359impl MessageAssembler {
360 pub fn new() -> Self {
362 Self::default()
363 }
364
365 pub fn feed_frame(&mut self, frame: Frame) -> Result<Option<Message>> {
367 if frame.is_control() {
368 return Ok(Some(self.control_frame_to_message(frame)?));
370 }
371
372 if !frame.fin {
373 if !self.assembling {
375 self.assembling = true;
377 self.opcode = Some(frame.opcode);
378 self.buffer.extend_from_slice(&frame.payload);
379 Ok(None)
380 } else {
381 if frame.opcode != Opcode::Continuation {
383 return Err(Error::Protocol(ProtocolError::InvalidFrame(
384 "Expected continuation frame in fragmented message".to_string(),
385 )));
386 }
387 self.buffer.extend_from_slice(&frame.payload);
388 Ok(None)
389 }
390 } else {
391 if self.assembling {
393 self.buffer.extend_from_slice(&frame.payload);
395 let message = self.assemble_complete_message()?;
396 self.reset();
397 Ok(Some(message))
398 } else {
399 self.opcode = Some(frame.opcode);
401 self.buffer = BytesMut::from(&frame.payload[..]);
402 let message = self.assemble_complete_message()?;
403 self.reset();
404 Ok(Some(message))
405 }
406 }
407 }
408
409 fn control_frame_to_message(&self, frame: Frame) -> Result<Message> {
411 match frame.kind() {
412 FrameKind::Ping => Ok(Message::ping(Some(frame.payload.to_vec()))),
413 FrameKind::Pong => Ok(Message::pong(Some(frame.payload.to_vec()))),
414 FrameKind::Close => {
415 let (code, reason) = self.parse_close_payload(&frame.payload)?;
416 Ok(Message::close(code, reason))
417 }
418 _ => Err(Error::Protocol(ProtocolError::InvalidFrame(
419 "Unexpected control frame type".to_string(),
420 ))),
421 }
422 }
423
424 fn parse_close_payload(&self, payload: &[u8]) -> Result<(Option<u16>, Option<String>)> {
426 if payload.len() < 2 {
427 return Ok((None, None));
428 }
429
430 let code = u16::from_be_bytes([payload[0], payload[1]]);
431 let reason = if payload.len() > 2 {
432 String::from_utf8_lossy(&payload[2..]).to_string()
433 } else {
434 String::new()
435 };
436
437 Ok((Some(code), Some(reason)))
438 }
439
440 fn assemble_complete_message(&self) -> Result<Message> {
442 let opcode = self.opcode.ok_or_else(|| {
443 Error::Protocol(ProtocolError::InvalidFrame("No opcode set".to_string()))
444 })?;
445
446 match opcode {
447 Opcode::Text => {
448 let text =
449 String::from_utf8(self.buffer.to_vec()).map_err(|_| Error::InvalidUtf8)?;
450 Ok(Message::text(text))
451 }
452 Opcode::Binary => Ok(Message::binary(self.buffer.clone().freeze())),
453 _ => Err(Error::Protocol(ProtocolError::InvalidFrame(
454 "Unexpected opcode for data frame".to_string(),
455 ))),
456 }
457 }
458
459 fn reset(&mut self) {
461 self.buffer.clear();
462 self.opcode = None;
463 self.assembling = false;
464 }
465
466 pub fn is_assembling(&self) -> bool {
468 self.assembling
469 }
470
471 pub fn buffered_bytes(&self) -> usize {
473 self.buffer.len()
474 }
475
476 pub fn clear(&mut self) {
478 self.reset();
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::Opcode;
    use bytes::Bytes;

    #[test]
    fn test_text_message() {
        let msg = Message::text("hello");
        assert_eq!(msg.kind(), MessageKind::Text);
        assert_eq!(msg.as_text(), Some("hello"));
        assert!(msg.is_data());
        assert!(!msg.is_control());
    }

    #[test]
    fn test_binary_message() {
        let data = vec![1, 2, 3, 4];
        let msg = Message::binary(data.clone());
        assert_eq!(msg.kind(), MessageKind::Binary);
        assert_eq!(msg.as_bytes(), &data[..]);
        assert!(msg.is_data());
        assert!(!msg.is_control());
    }

    #[test]
    fn test_control_messages() {
        let ping = Message::ping(Some(vec![1, 2, 3]));
        let pong = Message::pong(Some(vec![4, 5, 6]));
        let close = Message::close(Some(1000), Some("Goodbye".to_string()));

        assert!(ping.is_control());
        assert!(pong.is_control());
        assert!(close.is_control());
    }

    #[test]
    fn test_close_message() {
        let msg = Message::close(Some(1000), Some("Goodbye".to_string()));
        if let Message::Close(close_msg) = msg {
            assert_eq!(close_msg.code(), Some(1000));
            assert_eq!(close_msg.reason(), "Goodbye");
        } else {
            panic!("Expected close message");
        }
    }

    #[test]
    fn test_close_message_len_and_empty() {
        if let Message::Close(close_msg) = Message::close(None, None) {
            assert!(close_msg.is_empty());
            assert_eq!(close_msg.len(), 0);
        } else {
            panic!("Expected close message");
        }

        if let Message::Close(close_msg) = Message::close(Some(1000), Some("bye".to_string())) {
            assert!(!close_msg.is_empty());
            // 2 bytes for the status code plus the 3-byte reason.
            assert_eq!(close_msg.len(), 5);
            assert_eq!(close_msg.as_bytes(), &b"bye"[..]);
        } else {
            panic!("Expected close message");
        }
    }

    #[test]
    fn test_message_assembler() {
        let mut assembler = MessageAssembler::new();

        let frame1 = Frame::new(Opcode::Text, "Hello, ").fin(false);
        let frame2 = Frame::new(Opcode::Continuation, "world!").fin(true);

        let msg1 = assembler.feed_frame(frame1).unwrap();
        assert!(msg1.is_none());
        assert!(assembler.is_assembling());

        let msg2 = assembler.feed_frame(frame2).unwrap();
        assert!(msg2.is_some());
        assert!(!assembler.is_assembling());

        if let Some(Message::Text(text_msg)) = msg2 {
            assert_eq!(text_msg.as_str(), "Hello, world!");
        } else {
            panic!("Expected text message");
        }
    }

    #[test]
    fn test_unfragmented_text_frame() {
        let mut assembler = MessageAssembler::new();
        let msg = assembler
            .feed_frame(Frame::new(Opcode::Text, "hi").fin(true))
            .unwrap();

        assert!(!assembler.is_assembling());
        assert_eq!(assembler.buffered_bytes(), 0);
        match msg {
            Some(Message::Text(text_msg)) => assert_eq!(text_msg.as_str(), "hi"),
            other => panic!("Expected text message, got {:?}", other),
        }
    }

    #[test]
    fn test_fragmented_binary_message() {
        let mut assembler = MessageAssembler::new();

        let first = assembler
            .feed_frame(Frame::new(Opcode::Binary, "ab").fin(false))
            .unwrap();
        assert!(first.is_none());
        assert_eq!(assembler.buffered_bytes(), 2);

        let done = assembler
            .feed_frame(Frame::new(Opcode::Continuation, "cd").fin(true))
            .unwrap();
        match done {
            Some(Message::Binary(bin)) => assert_eq!(bin.as_bytes(), &b"abcd"[..]),
            other => panic!("Expected binary message, got {:?}", other),
        }
        assert!(!assembler.is_assembling());
    }

    #[test]
    fn test_control_frame_between_fragments() {
        let mut assembler = MessageAssembler::new();
        assembler
            .feed_frame(Frame::new(Opcode::Text, "Hello, ").fin(false))
            .unwrap();

        let ping = assembler
            .feed_frame(Frame::ping(Bytes::from_static(b"p")).fin(true))
            .unwrap();
        assert!(matches!(ping, Some(Message::Ping(_))));
        // The control frame must not disturb the partially assembled message.
        assert!(assembler.is_assembling());

        let done = assembler
            .feed_frame(Frame::new(Opcode::Continuation, "world!").fin(true))
            .unwrap();
        assert_eq!(
            done.as_ref().and_then(Message::as_text),
            Some("Hello, world!")
        );
    }

    #[test]
    fn test_rejects_new_data_frame_mid_fragment() {
        let mut assembler = MessageAssembler::new();
        assembler
            .feed_frame(Frame::new(Opcode::Text, "a").fin(false))
            .unwrap();
        assert!(assembler
            .feed_frame(Frame::new(Opcode::Text, "b").fin(false))
            .is_err());
    }

    #[test]
    fn test_clear_discards_partial_message() {
        let mut assembler = MessageAssembler::new();
        assembler
            .feed_frame(Frame::new(Opcode::Text, "Hello, ").fin(false))
            .unwrap();
        assert_eq!(assembler.buffered_bytes(), 7);

        assembler.clear();
        assert!(!assembler.is_assembling());
        assert_eq!(assembler.buffered_bytes(), 0);
    }

    #[test]
    fn test_message_display() {
        let text_msg = Message::text("hello");
        let binary_msg = Message::binary(vec![1, 2, 3]);
        let ping_msg = Message::ping(Some(vec![4, 5]));
        let pong_msg = Message::pong(None);

        assert_eq!(text_msg.to_string(), "Text(hello)");
        assert_eq!(binary_msg.to_string(), "Binary(3 bytes)");
        assert_eq!(ping_msg.to_string(), "Ping(2 bytes)");
        assert_eq!(pong_msg.to_string(), "Pong(0 bytes)");
    }
}