plabble-codec 0.1.0

Plabble Transport Protocol codec
Documentation
use chacha20::{
    cipher::{KeyIvInit, StreamCipher},
    ChaCha20,
};

use crate::{
    abstractions::{Serializable, SerializationError, SerializationInfo, ID_SIZE},
    codec::{
        common::assert_len,
        objects::BucketId,
        ptp_packet::{PtpHeader, PtpHeaderBase},
    },
};

/// Header of a PTP request packet.
///
/// On the wire it is a single type/flags byte, optionally followed by the
/// bucket id bytes.
pub struct RequestHeader {
    /// Packet type in the low four bits, flag bits in the high four bits.
    type_and_flags: u8,
    /// Bucket this request targets; `None` when the packet carries no bucket id.
    pub bucket_id: Option<BucketId>,
}

impl PtpHeaderBase for RequestHeader {
    /// Returns the raw header byte (packet type and flags combined).
    fn get_type_and_flags(&self) -> u8 {
        let Self { type_and_flags, .. } = self;
        *type_and_flags
    }

    /// Replaces the raw header byte (packet type and flags combined).
    fn set_type_and_flags(&mut self, type_and_flags: u8) {
        let Self {
            type_and_flags: slot,
            ..
        } = self;
        *slot = type_and_flags;
    }
}

impl PtpHeader for RequestHeader {
    // Nothing to override: every `PtpHeader` item has a default implementation,
    // which works through the `PtpHeaderBase` accessors implemented above.
}

impl RequestHeader {
    /// Builds a request header for `packet_type`.
    ///
    /// Only the low nibble of `packet_type` is kept, so the header starts out
    /// with every flag bit (the high nibble) cleared.
    ///
    /// # Arguments
    ///
    /// * `packet_type` - the packet type (only bits 0..=3 are used)
    /// * `bucket_id` - the bucket this packet targets, if any
    pub fn new(packet_type: u8, bucket_id: Option<BucketId>) -> Self {
        const TYPE_MASK: u8 = 0x0F;
        let type_and_flags = packet_type & TYPE_MASK;
        Self {
            type_and_flags,
            bucket_id,
        }
    }

    /// Returns `true` when this packet type carries a bucket id, which is the
    /// case for every type except `0`.
    pub fn has_bucket_id(&self) -> bool {
        self.packet_type() != 0
    }
}

impl Serializable for RequestHeader {
    /// Serialized length: one type/flags byte, plus `ID_SIZE` bytes when a
    /// bucket id is present.
    fn size(&self) -> usize {
        1 + if self.bucket_id.is_some() { ID_SIZE } else { 0 }
    }

    /// Encodes the header as the type/flags byte followed by the bucket id
    /// bytes (if any). No encryption is applied here.
    fn get_bytes(&self) -> Vec<u8> {
        // The final length is known, so reserve it once instead of growing an empty Vec.
        let mut buff = Vec::with_capacity(self.size());
        buff.push(self.type_and_flags);
        if let Some(id) = &self.bucket_id {
            buff.extend(id.get_bytes());
        }

        buff
    }

    /// Decodes a request header from the start of `data`.
    ///
    /// Only the first `1 + ID_SIZE` bytes are read; anything after them is
    /// ignored. When `info` is `SerializationInfo::UseEncryption`, those bytes are
    /// first decrypted with ChaCha20 using its first key and an all-zero nonce.
    /// The bucket id is parsed only for packet types that carry one
    /// (see `has_bucket_id`).
    ///
    /// # Errors
    ///
    /// Propagates the error from `assert_len` when `data` is shorter than
    /// `1 + ID_SIZE` bytes, and any error from parsing the `BucketId`.
    fn from_bytes(data: &[u8], info: Option<SerializationInfo>) -> Result<Self, SerializationError>
    where
        Self: Sized,
    {
        // Every request is at least 1 + ID_SIZE bytes long: all packet types except
        // CONNECT carry a bucket id, and a CONNECT packet is longer than that anyway.
        assert_len(data, 1 + ID_SIZE)?;
        let mut head = data[..(1 + ID_SIZE)].to_vec();

        // If encryption is used, decrypt the header bytes in place.
        // NOTE(review): the nonce is always zero, which is only safe if `key0` is
        // never reused across packets — confirm against the key schedule.
        if let Some(SerializationInfo::UseEncryption(key0, _, _)) = info {
            let mut cipher = ChaCha20::new(&key0.into(), &[0u8; 12].into());
            cipher.apply_keystream(&mut head);
        };

        let mut header = Self {
            type_and_flags: head[0],
            bucket_id: None,
        };

        if header.has_bucket_id() {
            header.bucket_id = Some(BucketId::from_bytes(&head[1..(1 + ID_SIZE)], None)?);
        }

        Ok(header)
    }
}

#[cfg(test)]
mod test {
    use crate::codec::objects::BucketPermissions;

    use super::*;

    #[test]
    fn can_detect_mac() {
        let h = RequestHeader {
            // Counting bits from 1 on the right, bit 5 (the MAC flag) is the first bit
            // left of the underscore.
            type_and_flags: 0b0101_0110,
            bucket_id: None,
        };

        assert!(h.has_mac());
    }

    #[test]
    fn can_detect_type() {
        for i in 0..16 {
            // Set every high-nibble flag bit so the type must come from the low nibble only.
            let h = RequestHeader {
                type_and_flags: i | 0b1111_0000,
                bucket_id: None,
            };

            assert_eq!(h.packet_type(), i);
        }
    }

    #[test]
    fn can_serialize_without_bucket_id() {
        let mut header = RequestHeader::new(0, None);
        header.set_flags((true, false, false));
        header.set_mac(true);
        let serialized = header.get_bytes();
        assert_eq!(serialized.len(), 1);
        assert_eq!("00110000", &format!("{:08b}", serialized[0]));
    }

    #[test]
    fn can_serialize_with_bucket_id() {
        let mut id = BucketId::from_bytes(&[0u8; 16], None).unwrap();
        id.set_lifetime(123);
        id.set_permissions(BucketPermissions {
            pub_read: true,
            pub_write: false,
            pub_append: false,
            priv_write: true,
            priv_append: false,
            delete_bucket: true,
        });
        let mut header = RequestHeader::new(7, Some(id));
        header.set_flags((true, true, false));
        let data = header.get_bytes();
        assert_eq!(
            vec![
                0b0110_0111,
                0,
                0,
                0,
                0,
                0,
                0,
                0,
                0,
                0,
                0,
                0,
                0,
                0,
                0,
                123,
                0b1010_0100
            ],
            data
        );
    }

    #[test]
    fn can_deserialize_from_longer_slice() {
        let bytes = &[
            1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 123, 123, 123,
        ];
        let header = RequestHeader::from_bytes(bytes, None).unwrap();
        assert_eq!(
            header.bucket_id.as_ref().unwrap().get_bytes(),
            &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0]
        );
        assert_eq!(header.flags(), (false, false, false));
        assert!(!header.has_mac());
    }

    #[test]
    fn can_deserialize_with_id() {
        let bytes = &[
            0b0011_0001,
            1,
            2,
            3,
            4,
            5,
            6,
            7,
            8,
            9,
            10,
            11,
            12,
            13,
            14,
            15,
            0b0101_1000,
        ];
        let header = RequestHeader::from_bytes(bytes, None).unwrap();
        assert_eq!(header.size(), 17);
        assert_eq!(1, header.packet_type());
        assert!(header.has_mac());
        assert_eq!(header.flags(), (true, false, false));

        let id = header.bucket_id.unwrap();
        assert_eq!(id.lifetime(), 15);
        assert_eq!(
            id.permissions(),
            BucketPermissions {
                pub_read: false,
                pub_write: true,
                pub_append: true,
                priv_write: false,
                priv_append: true,
                delete_bucket: false
            }
        );
    }

    #[test]
    fn rejects_input_shorter_than_header() {
        assert!(RequestHeader::from_bytes(&[], None).is_err());
        assert!(RequestHeader::from_bytes(&[1u8; 16], None).is_err());
    }

    #[test]
    fn connect_packet_has_no_bucket_id() {
        // Type 0 (connect): the bytes after the header byte must not be parsed as an id.
        let mut bytes = [0u8; 17];
        bytes[1] = 42;
        let header = RequestHeader::from_bytes(&bytes, None).unwrap();
        assert!(!header.has_bucket_id());
        assert!(header.bucket_id.is_none());
        assert_eq!(header.size(), 1);
        assert_eq!(header.get_bytes(), vec![0]);
    }

    #[test]
    fn round_trips_through_bytes() {
        let mut bytes = vec![0b0011_0001];
        bytes.extend(1..=15u8);
        bytes.push(0);
        let header = RequestHeader::from_bytes(&bytes, None).unwrap();
        assert_eq!(header.size(), bytes.len());
        assert_eq!(header.get_bytes(), bytes);
    }
}