use chacha20::{
cipher::{KeyIvInit, StreamCipher},
ChaCha20,
};
use crate::{
abstractions::{Serializable, SerializationError, SerializationInfo},
codec::{
common::assert_len,
ptp_packet::{PtpHeader, PtpHeaderBase},
},
};
/// Fixed three-byte header of a response packet.
///
/// On the wire the header is one combined type/flags byte followed by a
/// 16-bit counter in big-endian byte order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResponseHeader {
    /// Combined byte: the low four bits carry the packet type, the upper four
    /// bits are flag bits (the unit tests in this file exercise `0x10` as the
    /// MAC flag).
    type_and_flags: u8,
    /// Counter value carried in the header; serialized as two big-endian bytes.
    counter: u16,
}
/// Raw access to the combined type/flags byte, as required by `PtpHeaderBase`.
impl PtpHeaderBase for ResponseHeader {
    /// Returns the stored type/flags byte unchanged.
    fn get_type_and_flags(&self) -> u8 {
        let Self { type_and_flags, .. } = self;
        *type_and_flags
    }

    /// Overwrites the stored type/flags byte with `type_and_flags`; the
    /// counter is left untouched.
    fn set_type_and_flags(&mut self, type_and_flags: u8) {
        self.type_and_flags = type_and_flags;
    }
}
// Empty impl: `ResponseHeader` overrides nothing and relies entirely on
// whatever `PtpHeader` provides by default.
// NOTE(review): helpers such as `packet_type()`, `has_mac()` and `set_mac()`
// used by the tests are presumably default methods of `PtpHeader` /
// `PtpHeaderBase` built on the raw type/flags accessors — confirm in
// `codec::ptp_packet`.
impl PtpHeader for ResponseHeader {}
impl ResponseHeader {
    /// Creates a header for the given packet type and counter.
    ///
    /// Only the low four bits of `packet_type` are kept; the upper four bits
    /// of the type/flags byte start out cleared.
    pub fn new(packet_type: u8, counter: u16) -> Self {
        // 0x0F == 0b0000_1111: keep the type nibble, drop anything above it.
        let type_and_flags = packet_type & 0x0F;
        Self {
            type_and_flags,
            counter,
        }
    }

    /// Returns the counter value stored in this header.
    pub fn counter(&self) -> u16 {
        let Self { counter, .. } = self;
        *counter
    }
}
/// Wire format: `[type_and_flags, counter_hi, counter_lo]` (counter big-endian).
impl Serializable for ResponseHeader {
    /// Serialized length in bytes: one type/flags byte plus a two-byte counter.
    fn size(&self) -> usize {
        std::mem::size_of::<u8>() + std::mem::size_of::<u16>()
    }

    /// Encodes the header into a freshly allocated three-byte vector.
    ///
    /// No encryption is applied here; the bytes are emitted as stored.
    fn get_bytes(&self) -> Vec<u8> {
        let [counter_hi, counter_lo] = self.counter.to_be_bytes();
        vec![self.type_and_flags, counter_hi, counter_lo]
    }

    /// Decodes a header from the first three bytes of `data`.
    ///
    /// Bytes past the third are ignored. When `info` is
    /// `SerializationInfo::UseEncryption(key0, _, _)`, a ChaCha20 keystream
    /// keyed with `key0` and an all-zero 12-byte nonce is applied to a copy of
    /// those three bytes before they are interpreted; `data` itself is never
    /// modified.
    ///
    /// # Errors
    ///
    /// Propagates whatever `assert_len(data, 3)` returns on failure
    /// (presumably when `data` is too short — verify against `assert_len`).
    fn from_bytes(data: &[u8], info: Option<SerializationInfo>) -> Result<Self, SerializationError>
    where
        Self: Sized,
    {
        assert_len(data, 3)?;

        // Work on a private copy so decryption never touches the caller's buffer.
        let mut header = [0u8; 3];
        header.copy_from_slice(&data[..3]);

        if let Some(SerializationInfo::UseEncryption(key0, _, _)) = info {
            let mut cipher = ChaCha20::new(&key0.into(), &[0u8; 12].into());
            cipher.apply_keystream(&mut header);
        }

        let [type_and_flags, counter_hi, counter_lo] = header;
        Ok(Self {
            type_and_flags,
            counter: u16::from_be_bytes([counter_hi, counter_lo]),
        })
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Wire image shared by the deserialization tests:
    /// type/flags `0b0001_0111` (flag bit `0x10` set, packet type 7) followed
    /// by the counter `0x0103` (259) in big-endian order.
    const HEADER_BYTES: [u8; 3] = [0b0001_0111, 0b0000_0001, 0b0000_0011];

    /// Checks that `header` matches what `HEADER_BYTES` encodes.
    fn assert_matches_header_bytes(header: &ResponseHeader) {
        assert_eq!(header.counter, 259);
        assert_eq!(header.packet_type(), 7);
        assert!(header.has_mac());
    }

    #[test]
    fn can_detect_mac_and_status() {
        let header = ResponseHeader {
            type_and_flags: 0b1001_0111,
            counter: 12,
        };

        assert!(header.has_mac());
        assert_eq!(header.packet_type(), 7);
    }

    #[test]
    fn can_deserialize() {
        let header = ResponseHeader::from_bytes(&HEADER_BYTES, None).unwrap();

        assert_matches_header_bytes(&header);
    }

    #[test]
    fn can_deserialize_from_longer_slice() {
        // Trailing payload bytes after the three header bytes must be ignored.
        let mut data = HEADER_BYTES.to_vec();
        data.extend_from_slice(&[1, 2, 3, 4, 5, 6, 6, 7, 8]);

        let header = ResponseHeader::from_bytes(&data, None).unwrap();

        assert_matches_header_bytes(&header);
    }

    #[test]
    fn can_serialize() {
        let mut header = ResponseHeader::new(15, 258);
        header.set_mac(true);

        // Packet type 15 with the MAC flag (16) set, then 258 as big-endian bytes.
        let expected: Vec<u8> = vec![15 + 16, 1, 2];
        assert_eq!(expected, header.get_bytes());
    }
}