// rsub 0.1.0
//
// A high-performance message broker with QUIC transport and pub/sub messaging patterns.
use anyhow::{anyhow, Result};
use bytes::{Buf, BufMut, Bytes, BytesMut};

use super::frame::Frame;

/// Kind of a [`Message`], encoded as the first byte of a serialized message.
///
/// The explicit discriminants are the on-the-wire values (written with
/// `as u8`), so existing variants must never be renumbered.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MessageType {
    /// Wire value `0`.
    Auth = 0,
    /// Wire value `1`.
    Message = 1,
    /// Wire value `2`.
    Subscribe = 2,
    /// Wire value `3`.
    Response = 3,
}

impl TryFrom<u8> for MessageType {
    type Error = anyhow::Error;

    /// Decodes a wire byte back into its `MessageType`.
    ///
    /// # Errors
    ///
    /// Returns `Invalid message type: N` for any byte outside `0..=3`.
    fn try_from(byte: u8) -> Result<Self, Self::Error> {
        let decoded = match byte {
            0 => Self::Auth,
            1 => Self::Message,
            2 => Self::Subscribe,
            3 => Self::Response,
            unknown => return Err(anyhow!("Invalid message type: {}", unknown)),
        };
        Ok(decoded)
    }
}

/// A protocol message: its kind, the topics it carries, and an opaque payload.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Message {
    /// Kind of message; serialized as a single leading byte.
    pub message_type: MessageType,
    /// Topic names; each is written to the wire behind a 1-byte length prefix.
    pub topics: Vec<String>,
    /// Raw payload bytes; written to the wire behind a 4-byte length prefix.
    pub data: Vec<u8>,
}

impl Message {
    /// Builds a message from its parts.
    ///
    /// No validation happens here. The wire-format limits (at most 255 topics,
    /// each at most 255 bytes, payload at most `u32::MAX` bytes) are enforced
    /// by [`Message::serialize`].
    pub fn new(message_type: MessageType, topics: Vec<String>, data: Vec<u8>) -> Self {
        Message {
            message_type,
            topics,
            data,
        }
    }

    /// Parses a message from a received packet; an alias for [`Message::deserialize`].
    ///
    /// # Errors
    ///
    /// Same as [`Message::deserialize`].
    pub fn from_packet(packet: Bytes) -> Result<Self> {
        Self::deserialize(packet)
    }

    /// Encodes the message into a [`Frame`].
    ///
    /// Wire format (the 4-byte data length is big-endian, as written by `put_u32`):
    ///
    /// ```text
    /// +-------------------+-------------------+-------------------+-------------------+
    /// | Message Type (1B) | Topics count (1B) | Topic 1 length(1B)| Topic 1 string    |
    /// +-------------------+-------------------+-------------------+-------------------+
    /// | Topic 2 length(1B)| Topic 2 string    | ...               | Data length (4B)  |
    /// +-------------------+-------------------+-------------------+-------------------+
    /// | Data              |
    /// +-------------------+
    /// ```
    ///
    /// 1. Message Type (1 byte): `Auth`=0, `Message`=1, `Subscribe`=2, `Response`=3.
    /// 2. Topics count (1 byte): number of topics (0-255).
    /// 3. For each topic: a 1-byte length (0-255) followed by that many bytes of
    ///    UTF-8 topic text.
    /// 4. Data length (4 bytes): length of the data in bytes (0-4,294,967,295).
    /// 5. Data (variable length): arbitrary binary data.
    ///
    /// # Errors
    ///
    /// Returns an error if there are more than 255 topics, if a topic is longer
    /// than 255 bytes, or if the data is longer than `u32::MAX` bytes. These
    /// cases are rejected rather than truncated, because truncation would emit
    /// a corrupt frame. Errors from [`Frame::new`] are propagated.
    pub fn serialize(&self) -> Result<Frame> {
        // Narrow every length prefix with a checked conversion: an `as` cast
        // would wrap silently and desynchronize the reader.
        let topics_count = u8::try_from(self.topics.len()).map_err(|_| {
            anyhow!(
                "Too many topics: {} (max {})",
                self.topics.len(),
                u8::MAX
            )
        })?;
        let data_length = u32::try_from(self.data.len()).map_err(|_| {
            anyhow!(
                "Data too long: {} bytes (max {})",
                self.data.len(),
                u32::MAX
            )
        })?;

        let mut buffer = BytesMut::with_capacity(
            2 + self.topics.iter().map(|t| 1 + t.len()).sum::<usize>() + 4 + self.data.len(),
        );
        buffer.put_u8(self.message_type as u8);
        buffer.put_u8(topics_count);
        for topic in &self.topics {
            let topic_length = u8::try_from(topic.len()).map_err(|_| {
                anyhow!("Topic too long: {} bytes (max {})", topic.len(), u8::MAX)
            })?;
            buffer.put_u8(topic_length);
            buffer.extend_from_slice(topic.as_bytes());
        }
        buffer.put_u32(data_length);
        buffer.extend_from_slice(&self.data);
        Frame::new(buffer.freeze())
    }

    /// Decodes a message from the layout produced by [`Message::serialize`].
    ///
    /// Any bytes after the payload are ignored. Truncated input is reported as
    /// an error and never panics: every fixed-size read is preceded by a
    /// length check.
    ///
    /// # Errors
    ///
    /// - `Incomplete message` if `buffer` ends before a field it announces.
    /// - `Invalid message type: N` for an unknown type byte.
    /// - A UTF-8 error if a topic is not valid UTF-8.
    pub fn deserialize(mut buffer: Bytes) -> Result<Self> {
        // The type byte is validated before the topic count is read, so that a
        // bad type byte is reported even on very short input.
        Self::ensure_remaining(&buffer, 1)?;
        let message_type = MessageType::try_from(buffer.get_u8())?;

        // `Buf::get_u8` panics on an empty buffer, so check before each read.
        Self::ensure_remaining(&buffer, 1)?;
        let topics_count = usize::from(buffer.get_u8());
        let mut topics = Vec::with_capacity(topics_count);
        for _ in 0..topics_count {
            Self::ensure_remaining(&buffer, 1)?;
            let topic_length = usize::from(buffer.get_u8());
            Self::ensure_remaining(&buffer, topic_length)?;
            let topic = buffer.split_to(topic_length);
            topics.push(String::from_utf8(topic.to_vec())?);
        }

        Self::ensure_remaining(&buffer, 4)?;
        let data_length = usize::try_from(buffer.get_u32())?;
        Self::ensure_remaining(&buffer, data_length)?;
        let data = buffer.split_to(data_length).to_vec();

        Ok(Message {
            message_type,
            topics,
            data,
        })
    }

    /// Fails with `Incomplete message` unless `buffer` still holds at least `needed` bytes.
    fn ensure_remaining(buffer: &Bytes, needed: usize) -> Result<()> {
        if buffer.remaining() < needed {
            return Err(anyhow!("Incomplete message"));
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Encodes `original`, decodes the resulting frame payload, and returns the decoded copy.
    fn round_trip(original: &Message) -> Message {
        let frame = original.serialize().unwrap();
        Message::deserialize(frame.payload).unwrap()
    }

    /// Feeds `raw` to the decoder, asserts that decoding fails, and returns the error text.
    fn decode_error(raw: Vec<u8>) -> String {
        let outcome = Message::deserialize(Bytes::from(raw));
        assert!(outcome.is_err());
        outcome.unwrap_err().to_string()
    }

    #[test]
    fn test_new_message() {
        let expected_topics = vec![String::from("topic1"), String::from("topic2")];
        let expected_data: Vec<u8> = vec![1, 2, 3, 4];

        let built = Message::new(
            MessageType::Message,
            expected_topics.clone(),
            expected_data.clone(),
        );

        assert_eq!(built.message_type, MessageType::Message);
        assert_eq!(built.topics, expected_topics);
        assert_eq!(built.data, expected_data);
    }

    #[test]
    fn test_serialize_deserialize() {
        let original = Message::new(
            MessageType::Subscribe,
            vec![String::from("topic1"), String::from("topic2")],
            vec![1, 2, 3, 4],
        );

        assert_eq!(original, round_trip(&original));
    }

    #[test]
    fn test_serialize_empty_message() {
        let empty = Message::new(MessageType::Auth, Vec::new(), Vec::new());
        let frame = empty.serialize().unwrap();

        // Type byte 0 (Auth), topic count 0, then a 4-byte zero data length.
        assert_eq!(frame.payload, Bytes::from_static(&[0u8; 6]));
    }

    #[test]
    fn test_deserialize_empty_message() {
        let decoded = Message::deserialize(Bytes::from_static(&[0u8; 6])).unwrap();

        assert_eq!(decoded.message_type, MessageType::Auth);
        assert!(decoded.topics.is_empty());
        assert!(decoded.data.is_empty());
    }

    #[test]
    fn test_serialize_large_message() {
        let many_topics: Vec<String> = (0..255).map(|_| "topic1".to_string()).collect();
        let original = Message::new(MessageType::Message, many_topics, vec![1u8; 1000]);

        assert_eq!(original, round_trip(&original));
    }

    #[test]
    fn test_deserialize_invalid_utf8() {
        let bad_topic = [0xFFu8, 0xFE, 0xFD];
        let mut raw = vec![MessageType::Message as u8, 1, bad_topic.len() as u8];
        raw.extend_from_slice(&bad_topic);
        raw.extend_from_slice(&0u32.to_be_bytes()); // Empty data

        assert!(decode_error(raw).contains("invalid utf-8"));
    }

    #[test]
    fn test_deserialize_incomplete_message() {
        // One topic "foo", but the data length and data are missing.
        let truncated = vec![1, 1, 3, b'f', b'o', b'o'];

        assert!(decode_error(truncated).contains("Incomplete message"));
    }

    #[test]
    fn test_invalid_message_type() {
        let unknown_type = vec![255, 0, 0, 0, 0, 0];

        assert!(decode_error(unknown_type).contains("Invalid message type"));
    }
}