// plabble_codec/codec/header/response_header.rs

use chacha20::{
    cipher::{KeyIvInit, StreamCipher},
    ChaCha20,
};

use crate::{
    abstractions::{Serializable, SerializationError, SerializationInfo},
    codec::{
        common::assert_len,
        ptp_packet::{PtpHeader, PtpHeaderBase},
    },
};
13
/// The header of a response packet.
///
/// On the wire the header occupies three bytes: the type/flags byte followed
/// by the counter in big-endian order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResponseHeader {
    /// The packet type and the flag bits, packed into a single byte.
    type_and_flags: u8,
    /// The counter of the request (client counter) the response is for.
    counter: u16,
}
24
25impl PtpHeaderBase for ResponseHeader {
26    fn get_type_and_flags(&self) -> u8 {
27        self.type_and_flags
28    }
29
30    fn set_type_and_flags(&mut self, type_and_flags: u8) {
31        self.type_and_flags = type_and_flags;
32    }
33}
34
35impl PtpHeader for ResponseHeader {}
36
37impl ResponseHeader {
38    /// Create new response header
39    ///
40    /// # Arguments
41    ///
42    /// * `packet_type` - the type of the packet
43    /// * `counter` - the counter of the request (client counter) the response is for
44    pub fn new(packet_type: u8, counter: u16) -> Self {
45        Self {
46            type_and_flags: packet_type & 0b0000_1111,
47            counter,
48        }
49    }
50
51    /// Get the response counter
52    ///
53    /// # Returns
54    ///
55    /// The counter of the request (client counter) the response is for
56    pub fn counter(&self) -> u16 {
57        self.counter
58    }
59}
60
61impl Serializable for ResponseHeader {
62    fn size(&self) -> usize {
63        1 + 2
64    }
65
66    fn get_bytes(&self) -> Vec<u8> {
67        let mut buff = Vec::with_capacity(3);
68        buff.push(self.type_and_flags);
69        buff.extend_from_slice(&self.counter.to_be_bytes());
70        buff
71    }
72
73    fn from_bytes(data: &[u8], info: Option<SerializationInfo>) -> Result<Self, SerializationError>
74    where
75        Self: Sized,
76    {
77        assert_len(data, 3)?;
78        let mut data = data[..3].to_vec();
79
80        // If encryption is used, decrypt it
81        if let Some(SerializationInfo::UseEncryption(key0, _, _)) = info {
82            let mut cipher = ChaCha20::new(&key0.into(), &[0u8; 12].into());
83            cipher.apply_keystream(&mut data);
84        };
85
86        let type_and_flags = data[0];
87        let mut counter = [0u8; 2];
88        counter.copy_from_slice(&data[1..3]);
89        let counter = u16::from_be_bytes(counter);
90        Ok(Self {
91            type_and_flags,
92            counter,
93        })
94    }
95}
96
#[cfg(test)]
mod test {
    use super::*;

    /// The flag bit and the packet type are both readable from one byte.
    #[test]
    fn can_detect_mac_and_status() {
        let r = ResponseHeader {
            type_and_flags: 0b1001_0111,
            counter: 12,
        };

        assert!(r.has_mac());
        assert_eq!(r.packet_type(), 7);
    }

    /// The constructor keeps only the low four bits of the packet type.
    #[test]
    fn new_masks_flag_bits() {
        let header = ResponseHeader::new(0b1111_0111, 1);
        assert_eq!(header.type_and_flags, 0b0000_0111);
        assert_eq!(header.counter(), 1);
    }

    /// An exactly three-byte slice deserializes into the expected fields.
    #[test]
    fn can_deserialize() {
        let data = &[0b0001_0111, 0b0000_0001, 0b0000_0011];
        let header = ResponseHeader::from_bytes(data, None).unwrap();
        assert_eq!(header.counter, 259);
        assert_eq!(header.packet_type(), 7);
        assert!(header.has_mac());
    }

    /// Bytes after the three header bytes are ignored.
    #[test]
    fn can_deserialize_from_longer_slice() {
        let data = &[
            0b0001_0111,
            0b0000_0001,
            0b0000_0011,
            1,
            2,
            3,
            4,
            5,
            6,
            6,
            7,
            8,
        ];
        let header = ResponseHeader::from_bytes(data, None).unwrap();
        assert_eq!(header.counter, 259);
        assert_eq!(header.packet_type(), 7);
        assert!(header.has_mac());
    }

    /// Serialization emits the type/flags byte followed by the big-endian counter.
    #[test]
    fn can_serialize() {
        let mut header = ResponseHeader::new(15, 258);
        header.set_mac(true);

        let bytes = header.get_bytes();
        assert_eq!(vec![15 + 16, 1, 2], bytes);
    }

    /// Serializing and then deserializing (without encryption) preserves the header.
    #[test]
    fn round_trips_through_bytes() {
        let mut header = ResponseHeader::new(7, 259);
        header.set_mac(true);

        let bytes = header.get_bytes();
        assert_eq!(bytes.len(), header.size());

        let parsed = ResponseHeader::from_bytes(&bytes, None).unwrap();
        assert_eq!(parsed.counter(), 259);
        assert_eq!(parsed.packet_type(), 7);
        assert!(parsed.has_mac());
    }
}