// plabble_codec/codec/header/request_header.rs

1use chacha20::{
2    cipher::{KeyIvInit, StreamCipher},
3    ChaCha20,
4};
5
6use crate::{
7    abstractions::{Serializable, SerializationError, SerializationInfo, ID_SIZE},
8    codec::{
9        common::assert_len,
10        objects::BucketId,
11        ptp_packet::{PtpHeader, PtpHeaderBase},
12    },
13};
14
15/// The header of a request packet
16///
17/// # Fields
18///
19/// * `type_and_flags` - the type of the packet and the flags
20/// * `bucket_id` - the id of the bucket the packet is for (optional)
21pub struct RequestHeader {
22    type_and_flags: u8,
23    pub bucket_id: Option<BucketId>,
24}
25
26impl PtpHeaderBase for RequestHeader {
27    fn get_type_and_flags(&self) -> u8 {
28        self.type_and_flags
29    }
30
31    fn set_type_and_flags(&mut self, type_and_flags: u8) {
32        self.type_and_flags = type_and_flags;
33    }
34}
35
36impl PtpHeader for RequestHeader {}
37
38impl RequestHeader {
39    /// Create new request header
40    ///
41    /// # Arguments
42    ///
43    /// * `packet_type` - the type of the packet
44    /// * `bucket_id` - the id of the bucket the packet is for (optional)
45    pub fn new(packet_type: u8, bucket_id: Option<BucketId>) -> Self {
46        Self {
47            type_and_flags: packet_type & 0b0000_1111,
48            bucket_id,
49        }
50    }
51
52    /// Indicates if this packet type needs a bucket id
53    pub fn has_bucket_id(&self) -> bool {
54        !matches!(self.packet_type(), 0)
55    }
56}
57
58impl Serializable for RequestHeader {
59    fn size(&self) -> usize {
60        1 + if self.bucket_id.is_some() { ID_SIZE } else { 0 }
61    }
62
63    fn get_bytes(&self) -> Vec<u8> {
64        let mut buff = Vec::new();
65        buff.push(self.type_and_flags);
66        if let Some(id) = &self.bucket_id {
67            buff.append(&mut id.get_bytes());
68        }
69
70        buff
71    }
72
73    fn from_bytes(data: &[u8], info: Option<SerializationInfo>) -> Result<Self, SerializationError>
74    where
75        Self: Sized,
76    {
77        // Because no packet is less than 17 bytes (connect has no id, but is greater. all other types have an id)
78        assert_len(data, 1 + ID_SIZE)?;
79        let mut data = data[..(1 + ID_SIZE)].to_vec();
80
81        // If encryption is used, decrypt it
82        if let Some(SerializationInfo::UseEncryption(key0, _, _)) = info {
83            let mut cipher = ChaCha20::new(&key0.into(), &[0u8; 12].into());
84            cipher.apply_keystream(&mut data);
85        };
86
87        let mut header = Self {
88            type_and_flags: data[0],
89            bucket_id: None,
90        };
91
92        if header.has_bucket_id() {
93            header.bucket_id = Some(BucketId::from_bytes(&data[1..(1 + ID_SIZE)], None)?);
94        }
95
96        Ok(header)
97    }
98}
99
#[cfg(test)]
mod test {
    use crate::codec::objects::BucketPermissions;

    use super::*;

    #[test]
    fn can_detect_mac() {
        // Bits are counted right-to-left; the MAC bit is 0b0001_0000, the
        // first bit left of the underscore.
        let h = RequestHeader {
            type_and_flags: 0b0101_0110,
            bucket_id: None,
        };

        assert!(h.has_mac());
    }

    #[test]
    fn can_detect_type() {
        // Set every high (flag/MAC) bit so only the low nibble encodes the type.
        for i in 0..16u8 {
            let h = RequestHeader {
                type_and_flags: 0xF0 | i,
                bucket_id: None,
            };

            assert_eq!(h.packet_type(), i);
        }
    }

    #[test]
    fn can_serialize_without_bucket_id() {
        let mut header = RequestHeader::new(0, None);
        header.set_flags((true, false, false));
        header.set_mac(true);

        assert_eq!(header.get_bytes(), vec![0b0011_0000]);
    }

    #[test]
    fn can_serialize_with_bucket_id() {
        let mut id = BucketId::from_bytes(&[0u8; 16], None).unwrap();
        id.set_lifetime(123);
        id.set_permissions(BucketPermissions {
            pub_read: true,
            pub_write: false,
            pub_append: false,
            priv_write: true,
            priv_append: false,
            delete_bucket: true,
        });
        let mut header = RequestHeader::new(7, Some(id));
        header.set_flags((true, true, false));

        // Type 7 with the first two flags set, then 16 id bytes that are all
        // zero except the lifetime and the permission byte.
        let mut expected = vec![0u8; 17];
        expected[0] = 0b0110_0111;
        expected[15] = 123;
        expected[16] = 0b1010_0100;

        assert_eq!(expected, header.get_bytes());
    }

    #[test]
    fn can_deserialize_from_longer_slice() {
        // 1 header byte + 16 id bytes, then 3 trailing bytes that must be ignored.
        let bytes = [
            1u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 123, 123, 123,
        ];
        let header = RequestHeader::from_bytes(&bytes, None).unwrap();

        let expected_id: Vec<u8> = (1..=15).chain(std::iter::once(0)).collect();
        assert_eq!(header.bucket_id.as_ref().unwrap().get_bytes(), expected_id);
        assert_eq!(header.flags(), (false, false, false));
        assert!(!header.has_mac());
    }

    #[test]
    fn can_deserialize_with_id() {
        let bytes: Vec<u8> = std::iter::once(0b0011_0001)
            .chain(1..=15)
            .chain(std::iter::once(0b0101_1000))
            .collect();
        let header = RequestHeader::from_bytes(&bytes, None).unwrap();

        assert_eq!(header.size(), 17);
        assert_eq!(1, header.packet_type());
        assert!(header.has_mac());
        assert_eq!(header.flags(), (true, false, false));

        let id = header.bucket_id.unwrap();
        assert_eq!(id.lifetime(), 15);
        assert_eq!(
            id.permissions(),
            BucketPermissions {
                pub_read: false,
                pub_write: true,
                pub_append: true,
                priv_write: false,
                priv_append: true,
                delete_bucket: false
            }
        );
    }
}