use anyhow::{anyhow, Result};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use super::frame::Frame;
/// Kind of a [`Message`].
///
/// The explicit discriminants are the on-wire tag: a serialized message
/// starts with `message_type as u8`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MessageType {
    /// Wire tag `0`.
    Auth = 0,
    /// Wire tag `1`.
    Message = 1,
    /// Wire tag `2`.
    Subscribe = 2,
    /// Wire tag `3`.
    Response = 3,
}
impl TryFrom<u8> for MessageType {
    type Error = anyhow::Error;

    /// Decodes a one-byte wire tag into a [`MessageType`].
    ///
    /// # Errors
    ///
    /// Returns an error naming the offending byte when `tag` is not one of
    /// the known tags `0..=3`.
    fn try_from(tag: u8) -> Result<Self, Self::Error> {
        let decoded = match tag {
            0 => MessageType::Auth,
            1 => MessageType::Message,
            2 => MessageType::Subscribe,
            3 => MessageType::Response,
            unknown => return Err(anyhow!("Invalid message type: {}", unknown)),
        };
        Ok(decoded)
    }
}
/// A protocol message: a type tag, zero or more topic names, and an opaque
/// byte payload.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Message {
    /// Kind of message; encoded as the first byte of the serialized form.
    pub message_type: MessageType,
    /// Topic names. They are encoded as a one-byte count, then each topic as
    /// a one-byte length followed by its UTF-8 bytes.
    pub topics: Vec<String>,
    /// Opaque payload, encoded after a four-byte length prefix.
    pub data: Vec<u8>,
}
impl Message {
    /// Creates a message from its parts without validating them.
    ///
    /// The limits imposed by the wire format (at most 255 topics, topic names
    /// of at most 255 bytes, a payload of at most `u32::MAX` bytes) are
    /// enforced by [`Message::serialize`].
    pub fn new(message_type: MessageType, topics: Vec<String>, data: Vec<u8>) -> Self {
        Message {
            message_type,
            topics,
            data,
        }
    }

    /// Decodes a message from a raw packet. This is an alias for
    /// [`Message::deserialize`].
    ///
    /// # Errors
    ///
    /// Same as [`Message::deserialize`].
    pub fn from_packet(packet: Bytes) -> Result<Self> {
        Self::deserialize(packet)
    }

    /// Encodes the message and wraps the bytes in a [`Frame`].
    ///
    /// Wire layout, in order:
    /// - message type: 1 byte
    /// - topic count: 1 byte
    /// - for each topic: 1-byte length, then the topic's UTF-8 bytes
    /// - data length: 4 bytes (written with `BufMut::put_u32`)
    /// - data: `data length` bytes
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - the message has more than 255 topics
    /// - a topic is longer than 255 bytes
    /// - the payload is longer than `u32::MAX` bytes
    /// - `Frame::new` rejects the encoded bytes
    pub fn serialize(&self) -> Result<Frame> {
        // Length prefixes are checked instead of cast with `as`. A truncating
        // cast would emit a frame whose headers disagree with its contents,
        // and the peer would then mis-parse it.
        let topics_count = u8::try_from(self.topics.len()).map_err(|_| {
            anyhow!(
                "Too many topics: {} (max {})",
                self.topics.len(),
                u8::MAX
            )
        })?;
        let data_length = u32::try_from(self.data.len()).map_err(|_| {
            anyhow!(
                "Data too long: {} bytes (max {})",
                self.data.len(),
                u32::MAX
            )
        })?;

        let capacity =
            2 + self.topics.iter().map(|t| 1 + t.len()).sum::<usize>() + 4 + self.data.len();
        let mut buffer = BytesMut::with_capacity(capacity);
        buffer.put_u8(self.message_type as u8);
        buffer.put_u8(topics_count);
        for topic in &self.topics {
            let topic_length = u8::try_from(topic.len()).map_err(|_| {
                anyhow!("Topic too long: {} bytes (max {})", topic.len(), u8::MAX)
            })?;
            buffer.put_u8(topic_length);
            buffer.extend_from_slice(topic.as_bytes());
        }
        buffer.put_u32(data_length);
        buffer.extend_from_slice(&self.data);
        Frame::new(buffer.freeze())
    }

    /// Decodes a message laid out as described on [`Message::serialize`].
    ///
    /// Any bytes after the payload are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// - "Incomplete message" if the buffer ends before any field it announces
    /// - an error from `MessageType::try_from` for an unknown type tag
    /// - a UTF-8 error if a topic is not valid UTF-8
    ///
    /// Truncated input never panics.
    pub fn deserialize(mut buffer: Bytes) -> Result<Self> {
        // Every `get_*` call is preceded by a length check, because
        // `Buf::get_u8` and `Buf::get_u32` panic when the buffer is too short.
        Self::ensure_remaining(&buffer, 1)?;
        let message_type = MessageType::try_from(buffer.get_u8())?;

        Self::ensure_remaining(&buffer, 1)?;
        let topics_count = usize::from(buffer.get_u8());
        let mut topics = Vec::with_capacity(topics_count);
        for _ in 0..topics_count {
            Self::ensure_remaining(&buffer, 1)?;
            let topic_length = usize::from(buffer.get_u8());
            Self::ensure_remaining(&buffer, topic_length)?;
            let topic = buffer.split_to(topic_length);
            topics.push(String::from_utf8(topic.to_vec())?);
        }

        Self::ensure_remaining(&buffer, 4)?;
        let data_length = usize::try_from(buffer.get_u32())?;
        Self::ensure_remaining(&buffer, data_length)?;
        let data = buffer.split_to(data_length).to_vec();

        Ok(Message {
            message_type,
            topics,
            data,
        })
    }

    /// Returns an "Incomplete message" error unless `buffer` has at least
    /// `needed` unread bytes.
    fn ensure_remaining(buffer: &Bytes, needed: usize) -> Result<()> {
        if buffer.remaining() < needed {
            return Err(anyhow!("Incomplete message"));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Serializes `message`, then decodes the frame payload back into a
    /// `Message`. Panics if either step fails.
    fn round_trip(message: &Message) -> Message {
        let frame = message.serialize().unwrap();
        Message::deserialize(frame.payload).unwrap()
    }

    #[test]
    fn test_new_message() {
        let expected_topics = vec!["topic1".to_string(), "topic2".to_string()];
        let expected_data = vec![1, 2, 3, 4];
        let message = Message::new(
            MessageType::Message,
            expected_topics.clone(),
            expected_data.clone(),
        );
        assert_eq!(message.message_type, MessageType::Message);
        assert_eq!(message.topics, expected_topics);
        assert_eq!(message.data, expected_data);
    }

    #[test]
    fn test_serialize_deserialize() {
        let message = Message::new(
            MessageType::Subscribe,
            vec!["topic1".to_string(), "topic2".to_string()],
            vec![1, 2, 3, 4],
        );
        assert_eq!(round_trip(&message), message);
    }

    #[test]
    fn test_serialize_empty_message() {
        let frame = Message::new(MessageType::Auth, Vec::new(), Vec::new())
            .serialize()
            .unwrap();
        assert_eq!(frame.payload, Bytes::from_static(&[0; 6]));
    }

    #[test]
    fn test_deserialize_empty_message() {
        let decoded = Message::deserialize(Bytes::from_static(&[0; 6])).unwrap();
        assert_eq!(decoded.message_type, MessageType::Auth);
        assert!(decoded.topics.is_empty());
        assert!(decoded.data.is_empty());
    }

    #[test]
    fn test_serialize_large_message() {
        let message = Message::new(
            MessageType::Message,
            vec!["topic1".to_string(); 255],
            vec![1; 1000],
        );
        assert_eq!(round_trip(&message), message);
    }

    #[test]
    fn test_deserialize_invalid_utf8() {
        let bad_bytes: [u8; 3] = [0xFF, 0xFE, 0xFD];
        let mut raw = BytesMut::new();
        raw.put_u8(MessageType::Message as u8);
        raw.put_u8(1);
        raw.put_u8(bad_bytes.len() as u8);
        raw.extend_from_slice(&bad_bytes);
        raw.put_u32(0);

        let error = Message::deserialize(raw.freeze()).unwrap_err();
        assert!(error.to_string().contains("invalid utf-8"));
    }

    #[test]
    fn test_deserialize_incomplete_message() {
        let truncated = Bytes::from(vec![1u8, 1, 3, b'f', b'o', b'o']);
        let error = Message::deserialize(truncated).unwrap_err();
        assert!(error.to_string().contains("Incomplete message"));
    }

    #[test]
    fn test_invalid_message_type() {
        let unknown_tag = Bytes::from(vec![255u8, 0, 0, 0, 0, 0]);
        let error = Message::deserialize(unknown_tag).unwrap_err();
        assert!(error.to_string().contains("Invalid message type"));
    }
}