plabble-codec 0.1.0

Plabble Transport Protocol codec
Documentation
use crate::{
    abstractions::{Serializable, SerializationError, SerializationInfo, KEY_SIZE, MAC_SIZE},
    codec::common::{assert_len, dyn_int},
};

use chacha20::{
    cipher::{KeyIvInit, StreamCipher},
    ChaCha20,
};
use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, KeyInit};
use poly1305::Poly1305;
use sha2::{Digest, Sha256};

/// Implemented by packet bodies that can be carried in a [`PtpPacket`].
///
/// [`PtpPacket`] requires its body type parameter to implement this trait
/// (together with `Serializable`).
pub trait PtpBody {
    /// Get packet type byte
    ///
    /// NOTE(review): presumably this matches the value stored in the low four
    /// bits of the header's type-and-flags byte (see [`PtpHeader::packet_type`]);
    /// confirm against the implementors.
    fn packet_type(&self) -> u8;
}

/// Raw access to the single "type and flags" byte of a packet header.
///
/// [`PtpHeader`] builds all of its provided accessors (packet type, MAC flag
/// and the three extra flags) on top of these two methods, so an implementor
/// only has to store and return the byte unchanged.
pub trait PtpHeaderBase {
    /// Returns the combined type-and-flags byte as currently stored.
    fn get_type_and_flags(&self) -> u8;
    /// Replaces the combined type-and-flags byte with `type_and_flags`.
    fn set_type_and_flags(&mut self, type_and_flags: u8);
}

/// Mask selecting the packet type (bits 0-3) of the type-and-flags byte.
const PACKET_TYPE_MASK: u8 = 0b0000_1111;
/// Bit 4 of the type-and-flags byte: the packet carries a MAC.
const MAC_FLAG: u8 = 1 << 4;
/// Bits 5, 6 and 7 of the type-and-flags byte: the packet-type specific flags.
const EXTRA_FLAG_0: u8 = 1 << 5;
const EXTRA_FLAG_1: u8 = 1 << 6;
const EXTRA_FLAG_2: u8 = 1 << 7;

/// Bit-level accessors for the combined "type and flags" byte of a header.
///
/// Layout of the byte (bit 0 is the least significant bit):
///
/// * bits 0-3: packet type
/// * bit 4: MAC flag
/// * bits 5-7: three extra flags whose meaning differs per packet type
///
/// Every setter only touches the bits it is responsible for and leaves the
/// rest of the byte untouched.
pub trait PtpHeader: PtpHeaderBase {
    /// Returns true if the packet has a MAC (according to the flags)
    fn has_mac(&self) -> bool {
        self.get_type_and_flags() & MAC_FLAG != 0
    }

    /// Sets or clears the MAC flag, keeping the packet type and extra flags.
    fn set_mac(&mut self, has_mac: bool) {
        let current = self.get_type_and_flags();
        let updated = if has_mac {
            current | MAC_FLAG
        } else {
            current & !MAC_FLAG
        };
        self.set_type_and_flags(updated);
    }

    /// Returns the packet type (the low four bits, with all flags masked out)
    fn packet_type(&self) -> u8 {
        self.get_type_and_flags() & PACKET_TYPE_MASK
    }

    /// Sets the packet type.
    ///
    /// Only the low four bits of `packet_type` are used. The MAC flag and the
    /// three extra flags are preserved, consistent with [`PtpHeader::set_mac`]
    /// and [`PtpHeader::set_flags`]; previously this call reset them to zero.
    fn set_packet_type(&mut self, packet_type: u8) {
        let flags_only = self.get_type_and_flags() & !PACKET_TYPE_MASK;
        self.set_type_and_flags(flags_only | (packet_type & PACKET_TYPE_MASK));
    }

    /// Returns the 3 extra flags which differ per packet type
    /// (bit 5, bit 6 and bit 7, in that order).
    fn flags(&self) -> (bool, bool, bool) {
        let byte = self.get_type_and_flags();
        (
            byte & EXTRA_FLAG_0 != 0,
            byte & EXTRA_FLAG_1 != 0,
            byte & EXTRA_FLAG_2 != 0,
        )
    }

    /// Sets the 3 extra flags (which differ per packet type).
    ///
    /// Each tuple element replaces the corresponding bit (5, 6, 7); the packet
    /// type and the MAC flag are left unchanged.
    fn set_flags(&mut self, flags: (bool, bool, bool)) {
        let (first, second, third) = flags;

        // Start from the byte with all three extra flags cleared, then switch
        // on exactly the requested ones.
        let mut byte =
            self.get_type_and_flags() & !(EXTRA_FLAG_0 | EXTRA_FLAG_1 | EXTRA_FLAG_2);
        if first {
            byte |= EXTRA_FLAG_0;
        }
        if second {
            byte |= EXTRA_FLAG_1;
        }
        if third {
            byte |= EXTRA_FLAG_2;
        }

        self.set_type_and_flags(byte);
    }
}

/// A complete PTP packet: header, body and an optional trailing MAC.
///
/// Wire layout produced by [`PtpPacket::get_bytes`] and consumed by
/// [`PtpPacket::from_bytes`]:
///
/// ```text
/// [dyn_int length (optional)] [header] [body] [MAC (optional, MAC_SIZE bytes)]
/// ```
///
/// The length prefix is only written by `get_bytes`; `from_bytes` expects
/// `data` to start at the header.
pub trait PtpPacket<HT, BT>
where
    HT: Serializable + PtpHeader,
    BT: Serializable + PtpBody,
{
    /// Returns the header of the packet
    fn get_header(&self) -> &HT;

    /// Returns the body of the packet
    fn get_body(&self) -> &BT;

    /// Returns the MAC of the packet (if present)
    fn get_mac(&self) -> Option<&[u8; MAC_SIZE]>;

    /// Create a new packet with the given header, body and MAC
    fn new(header: HT, body: BT, mac: Option<[u8; MAC_SIZE]>) -> Self;

    /// Deserialize packet from bytes
    ///
    /// # Arguments
    ///
    /// * `data` - The bytes to deserialize (header, body and, if flagged, the MAC)
    /// * `info` - The info to use for deserialization
    ///
    /// # Errors
    ///
    /// * `SerializationError::DecryptionFailed` - If the packet is encrypted and the decryption fails
    /// * `SerializationError::MissingInfo` - If the header flags a MAC but `info` is not `UseAuthentication`
    /// * `SerializationError::AuthenticationFailed` - If the packet is authenticated and the authentication fails
    /// * A length error if `data` is too short to hold the header plus the flagged MAC
    /// * Any error returned by the header or body deserializer
    ///
    /// # Returns
    ///
    /// The deserialized packet or an error
    fn from_bytes(data: &[u8], info: SerializationInfo) -> Result<Self, SerializationError>
    where
        Self: Sized,
    {
        // Deserialize header. If encryption is used, it will be decrypted with key0
        let header = HT::from_bytes(data, Some(info))?;

        let header_size = header.size();
        let mac_size = if header.has_mac() { MAC_SIZE } else { 0 };

        // `data` is untrusted: a truncated packet with the MAC flag set must
        // produce an error instead of an arithmetic underflow / slice panic.
        let body_size = match data.len().checked_sub(header_size + mac_size) {
            Some(size) => size,
            None => {
                // Let the shared length check build the error for too-short input.
                assert_len(data, header_size + mac_size)?;
                // Defensive fallback; not expected to be reached because the
                // check above fails whenever `data` is shorter than required.
                return Err(SerializationError::MissingInfo(String::from(
                    "Packet is too short to contain its header and MAC",
                )));
            }
        };
        let body_end = header_size + body_size;
        let mut body_bytes = data[header_size..body_end].to_vec();

        let mut expected_mac: Option<Vec<u8>> = None;

        match info {
            SerializationInfo::UseEncryption(_, key1, bucket_key) => {
                let cipher = ChaCha20Poly1305::new(&key1.into());

                // Associated data: plaintext header followed by the optional bucket key.
                let mut auth_data = header.get_bytes();
                if let Some(key) = bucket_key {
                    auth_data.extend_from_slice(&key);
                }

                // NOTE(review): the nonce is fixed to all zeros (mirrors `get_bytes`);
                // this presumably relies on the keys never being reused — confirm.
                cipher
                    .decrypt_in_place(&[0u8; 12].into(), &auth_data, &mut body_bytes)
                    .map_err(|_| SerializationError::DecryptionFailed)?;
            }
            SerializationInfo::UseAuthentication(key, bucket_key) => {
                let auth_data = mac_auth_data(
                    &header.get_bytes(),
                    &body_bytes,
                    bucket_key.as_ref().map(|k| &k[..]),
                );

                let poly = Poly1305::new(&key.into());
                expected_mac = Some(poly.compute_unpadded(&auth_data).to_vec());
            }
            _ => (),
        }

        let body = BT::from_bytes(
            &body_bytes,
            Some(SerializationInfo::PacketType(header.packet_type())),
        )?;

        // NOTE(review): when `info` is `UseAuthentication` but the header does not
        // flag a MAC, the packet is accepted without any authentication. Callers
        // that require authentication must check `get_mac()` / `verify_mac()`.
        let mac = if header.has_mac() {
            let expected = match expected_mac {
                Some(expected) => expected,
                None => {
                    return Err(SerializationError::MissingInfo(String::from(
                        "Missing UseAuthentication info to verify the MAC",
                    )))
                }
            };

            // By construction `data[body_end..]` is exactly MAC_SIZE bytes long.
            let mut mac = [0u8; MAC_SIZE];
            mac.copy_from_slice(&data[body_end..]);

            if !macs_match(&expected, &mac) {
                return Err(SerializationError::AuthenticationFailed);
            }

            Some(mac)
        } else {
            None
        };

        Ok(Self::new(header, body, mac))
    }

    /// Serialize packet to bytes
    ///
    /// # Arguments
    ///
    /// * `info` - The info to use for serialization
    /// * `with_len` - Whether to prepend the length of the packet (plabble dyn_int bytes)
    ///
    /// # Errors
    ///
    /// * `SerializationError::MissingInfo` - `info` is not `UseEncryption`,
    ///   `UseAuthentication` or `None`
    ///
    /// # Panics
    ///
    /// Panics with "Encryption failed" if the AEAD cipher reports an error
    /// while encrypting the body.
    fn get_bytes(
        &self,
        info: SerializationInfo,
        with_len: bool,
    ) -> Result<Vec<u8>, SerializationError> {
        let mut header_bytes = self.get_header().get_bytes();
        let mut body_bytes = self.get_body().get_bytes();
        let mut mac: Option<Vec<u8>> = None;

        match info {
            SerializationInfo::UseEncryption(key0, key1, bucket_key) => {
                // NOTE(review): fixed all-zero nonce for both ciphers; presumably
                // safe only because the keys are never reused — confirm.
                let nonce = [0u8; 12];

                // The associated data uses the *plaintext* header, so it has to
                // be captured before the header is encrypted below.
                let mut auth_data = header_bytes.to_vec();
                if let Some(bucket_key) = bucket_key {
                    auth_data.extend_from_slice(&bucket_key);
                }

                // Encrypt header
                let mut cipher = ChaCha20::new(&key0.into(), &nonce.into());
                cipher.apply_keystream(&mut header_bytes);

                // Encrypt body (the authentication tag is appended to `body_bytes`)
                let cipher = ChaCha20Poly1305::new(&key1.into());
                cipher
                    .encrypt_in_place(&nonce.into(), &auth_data, &mut body_bytes)
                    .expect("Encryption failed");
            }
            SerializationInfo::UseAuthentication(key, bucket_key) => {
                let auth_data = mac_auth_data(
                    &header_bytes,
                    &body_bytes,
                    bucket_key.as_ref().map(|k| &k[..]),
                );

                let poly = Poly1305::new(&key.into());
                mac = Some(poly.compute_unpadded(&auth_data).to_vec());
            }
            SerializationInfo::None => (),
            other => {
                return Err(SerializationError::MissingInfo(format!(
                    "Needs UseEncryption, UseAuthentication or None method. But {:?} is provided",
                    other
                )))
            }
        };

        let packet_len =
            header_bytes.len() + body_bytes.len() + if mac.is_some() { MAC_SIZE } else { 0 };

        let mut buff = Vec::with_capacity(packet_len);
        if with_len {
            buff.append(&mut dyn_int::encode(packet_len as u128));
        }

        buff.append(&mut header_bytes);
        buff.append(&mut body_bytes);

        if let Some(mut mac) = mac {
            buff.append(&mut mac);
        }

        Ok(buff)
    }

    /// Verifies the MAC of the packet
    /// This method is needed so a non-encrypted packet can be checked after deserialization if we do not want to "peek" the header, which is less efficient
    ///
    /// # Arguments
    ///
    /// * `key` - The key to verify the MAC with. Must be generated with HKDF
    /// * `bucket_key` - The bucket key to add to the auth_data if authentication is needed. Optional
    ///
    /// # Returns
    ///
    /// `true` only if the header flags a MAC, the packet stores one, and it
    /// equals the MAC recomputed from the header, body and `bucket_key`.
    fn verify_mac(&self, key: &[u8; KEY_SIZE], bucket_key: Option<[u8; KEY_SIZE]>) -> bool {
        let header = self.get_header();
        if !header.has_mac() {
            eprintln!("[WARN]: Verifying MAC for packet with no MAC present!");
            return false;
        }

        let self_mac = match self.get_mac() {
            Some(mac) => mac,
            None => return false,
        };

        let auth_data = mac_auth_data(
            &header.get_bytes(),
            &self.get_body().get_bytes(),
            bucket_key.as_ref().map(|k| &k[..]),
        );

        let poly = Poly1305::new(key.into());
        let mac = poly.compute_unpadded(&auth_data).to_vec();
        macs_match(&mac, self_mac)
    }
}

/// Builds the data that is authenticated by the packet MAC:
/// header bytes, then the SHA-256 digest of the body, then the optional bucket key.
fn mac_auth_data(header_bytes: &[u8], body_bytes: &[u8], bucket_key: Option<&[u8]>) -> Vec<u8> {
    let mut auth_data = header_bytes.to_vec();
    auth_data.extend_from_slice(&Sha256::digest(body_bytes));
    if let Some(bucket_key) = bucket_key {
        auth_data.extend_from_slice(bucket_key);
    }
    auth_data
}

/// Compares two MACs without short-circuiting on the first differing byte, so
/// the comparison time does not reveal how many leading bytes matched.
fn macs_match(expected: &[u8], actual: &[u8]) -> bool {
    if expected.len() != actual.len() {
        return false;
    }
    expected
        .iter()
        .zip(actual)
        .fold(0u8, |diff, (a, b)| diff | (a ^ b))
        == 0
}