plabble_codec/codec/
ptp_packet.rs

1use crate::{
2    abstractions::{Serializable, SerializationError, SerializationInfo, KEY_SIZE, MAC_SIZE},
3    codec::common::{assert_len, dyn_int},
4};
5
6use chacha20::{
7    cipher::{KeyIvInit, StreamCipher},
8    ChaCha20,
9};
10use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, KeyInit};
11use poly1305::Poly1305;
12use sha2::{Digest, Sha256};
13
/// Behaviour shared by every PTP packet body.
///
/// A body only needs to report which packet type it belongs to; the
/// (de)serialization itself comes from its `Serializable` implementation.
pub trait PtpBody {
    /// Returns the packet type byte this body belongs to.
    fn packet_type(&self) -> u8;
}
18
/// Raw access to the combined "packet type + flags" byte of a PTP header.
///
/// Implementors only store and return the byte; all bit-level interpretation
/// (packet type, MAC flag, extra flags) is layered on top by `PtpHeader`.
pub trait PtpHeaderBase {
    /// Returns the raw type-and-flags byte.
    fn get_type_and_flags(&self) -> u8;

    /// Overwrites the raw type-and-flags byte with `type_and_flags`.
    fn set_type_and_flags(&mut self, type_and_flags: u8);
}
23
24pub trait PtpHeader: PtpHeaderBase {
25    /// Returns true if the packet has a MAC (according to the flags)
26    fn has_mac(&self) -> bool {
27        self.get_type_and_flags() & (1 << 4) != 0
28    }
29
30    /// Sets the MAC flag
31    fn set_mac(&mut self, has_mac: bool) {
32        self.set_type_and_flags(if has_mac {
33            self.get_type_and_flags() | 1 << 4
34        } else {
35            self.get_type_and_flags() & !(1 << 4)
36        });
37    }
38
39    /// Returns the packet type
40    fn packet_type(&self) -> u8 {
41        self.get_type_and_flags() & 0b0000_1111 // reset flags to 0
42    }
43
44    /// Sets the packet type
45    fn set_packet_type(&mut self, packet_type: u8) {
46        self.set_type_and_flags(packet_type & 0b0000_1111);
47    }
48
49    /// Returns the 3 extra flags which differ per packet type
50    fn flags(&self) -> (bool, bool, bool) {
51        let type_and_flags = self.get_type_and_flags();
52        (
53            type_and_flags & (1 << 5) != 0,
54            type_and_flags & (1 << 6) != 0,
55            type_and_flags & (1 << 7) != 0,
56        )
57    }
58
59    /// Sets the 3 extra flags (which differ per packet type)
60    fn set_flags(&mut self, flags: (bool, bool, bool)) {
61        let mut set_mask = 0b0000_0000;
62        let mut reset_mask = 0b1111_1111;
63        if flags.0 {
64            set_mask += 32;
65        } else {
66            reset_mask -= 32;
67        }
68        if flags.1 {
69            set_mask += 64;
70        } else {
71            reset_mask -= 64;
72        }
73        if flags.2 {
74            set_mask += 128;
75        } else {
76            reset_mask -= 128;
77        }
78
79        let mut type_and_flags = self.get_type_and_flags();
80        type_and_flags |= set_mask;
81        type_and_flags &= reset_mask;
82        self.set_type_and_flags(type_and_flags);
83    }
84}
85
86pub trait PtpPacket<HT, BT>
87where
88    HT: Serializable + PtpHeader,
89    BT: Serializable + PtpBody,
90{
91    /// Returns the header of the packet
92    fn get_header(&self) -> &HT;
93
94    /// Returns the body of the packet
95    fn get_body(&self) -> &BT;
96
97    /// Returns the MAC of the packet (if present)
98    fn get_mac(&self) -> Option<&[u8; MAC_SIZE]>;
99
100    /// Create a new packet with the given header, body and MAC
101    fn new(header: HT, body: BT, mac: Option<[u8; MAC_SIZE]>) -> Self;
102
103    /// Deserialize packet from bytes
104    ///
105    /// # Arguments
106    ///
107    /// * `data` - The bytes to deserialize
108    /// * `info` - The info to use for deserialization
109    ///
110    /// # Errors
111    ///
112    /// * `SerializationError::DecryptionFailed` - If the packet is encrypted and the decryption fails
113    /// * `SerializationError::MissingInfo` - If the packet is encrypted or authenticated and the info is missing
114    /// * `SerializationError::AuthenticationFailed` - If the packet is authenticated and the authentication fails
115    ///
116    /// # Returns
117    ///
118    /// The deserialized packet or an error
119    fn from_bytes(data: &[u8], info: SerializationInfo) -> Result<Self, SerializationError>
120    where
121        Self: Sized,
122    {
123        // Deserialize header. If encryption is used, it will be decrypted with key0
124        let header = HT::from_bytes(data, Some(info))?;
125
126        let body_size = data.len() - header.size() - if header.has_mac() { MAC_SIZE } else { 0 };
127        let mut body_bytes = data[header.size()..(body_size + header.size())].to_vec();
128        assert_eq!(body_size, body_bytes.len());
129
130        let mut expected_mac: Option<Vec<u8>> = None;
131
132        match info {
133            SerializationInfo::UseEncryption(_, key1, bucket_key) => {
134                let cipher = ChaCha20Poly1305::new(&key1.into());
135
136                let mut auth_data = Vec::new();
137                auth_data.append(&mut header.get_bytes());
138                if let Some(key) = bucket_key {
139                    auth_data.extend_from_slice(&key);
140                }
141
142                match cipher.decrypt_in_place(&[0u8; 12].into(), &auth_data, &mut body_bytes) {
143                    Ok(_) => (),
144                    Err(_) => return Err(SerializationError::DecryptionFailed),
145                }
146            }
147            SerializationInfo::UseAuthentication(key, bucket_key) => {
148                let mut auth_data = Vec::new();
149                auth_data.append(&mut header.get_bytes());
150                auth_data.extend_from_slice(&Sha256::digest(&body_bytes));
151                if let Some(bucket_key) = bucket_key {
152                    auth_data.extend_from_slice(&bucket_key);
153                }
154
155                let poly = Poly1305::new(&key.into());
156                expected_mac = Some(poly.compute_unpadded(&auth_data).to_vec());
157            }
158            _ => (),
159        }
160
161        let body = BT::from_bytes(
162            &body_bytes,
163            Some(SerializationInfo::PacketType(header.packet_type())),
164        )?;
165
166        let mac = if header.has_mac() {
167            if expected_mac.is_none() {
168                return Err(SerializationError::MissingInfo(String::from(
169                    "Missing UseAuthentication info to verify the MAC",
170                )));
171            }
172
173            assert_len(data, header.size() + body_size + MAC_SIZE)?;
174            let mut mac = [0u8; MAC_SIZE];
175            let slice = &data[(header.size() + body_size)..];
176            assert_eq!(MAC_SIZE, slice.len());
177            mac.copy_from_slice(slice);
178
179            if expected_mac.unwrap() != mac {
180                return Err(SerializationError::AuthenticationFailed);
181            }
182
183            Some(mac)
184        } else {
185            None
186        };
187
188        Ok(Self::new(header, body, mac))
189    }
190
191    /// Serialize packet to bytes
192    ///
193    /// # Arguments
194    ///
195    /// * `info` - The info to use for serialization
196    /// * `with_len` - Whether to prepend the length of the packet (plabble dyn_int bytes)
197    ///
198    /// # Errors
199    ///
200    /// * `SerializationError::MissingInfo` - No serialization info is provided
201    fn get_bytes(
202        &self,
203        info: SerializationInfo,
204        with_len: bool,
205    ) -> Result<Vec<u8>, SerializationError> {
206        let mut buff = Vec::new();
207        let mut header_bytes = self.get_header().get_bytes();
208        let mut body_bytes = self.get_body().get_bytes();
209        let mut mac: Option<Vec<u8>> = None;
210
211        match info {
212            SerializationInfo::UseEncryption(key0, key1, bucket_key) => {
213                let nonce = [0u8; 12];
214                let mut auth_data = header_bytes.to_vec();
215                if let Some(bucket_key) = bucket_key {
216                    auth_data.extend_from_slice(&bucket_key);
217                }
218
219                // Encrypt header
220                let mut cipher = ChaCha20::new(&key0.into(), &nonce.into());
221                cipher.apply_keystream(&mut header_bytes);
222
223                // Encrypt body
224                let cipher = ChaCha20Poly1305::new(&key1.into());
225                cipher
226                    .encrypt_in_place(&nonce.into(), &auth_data, &mut body_bytes)
227                    .expect("Encryption failed");
228            }
229            SerializationInfo::UseAuthentication(key, bucket_key) => {
230                let mut auth_data = header_bytes.to_vec();
231                auth_data.extend_from_slice(&Sha256::digest(&body_bytes));
232                if let Some(bucket_key) = bucket_key {
233                    auth_data.extend_from_slice(&bucket_key);
234                }
235
236                let poly = Poly1305::new(&key.into());
237                mac = Some(poly.compute_unpadded(&auth_data).to_vec());
238            }
239            SerializationInfo::None => (),
240            other => {
241                return Err(SerializationError::MissingInfo(format!(
242                    "Needs UseEncryption, UseAuthentication or None method. But {:?} is provided",
243                    other
244                )))
245            }
246        };
247
248        if with_len {
249            let len =
250                header_bytes.len() + body_bytes.len() + if mac.is_some() { MAC_SIZE } else { 0 };
251
252            buff.append(&mut dyn_int::encode(len as u128));
253        }
254
255        buff.append(&mut header_bytes);
256        buff.append(&mut body_bytes);
257
258        if let Some(mut mac) = mac {
259            buff.append(&mut mac);
260        }
261
262        Ok(buff)
263    }
264
265    /// Verifies the MAC of the packet
266    /// This method is needed so a non-encrypted packet can be checked after deserialization if we do not want to "peek" the header, which is less efficient
267    ///
268    /// # Arguments
269    ///
270    /// * `key` - The key to verify the MAC with. Must be generated with HKDF
271    /// * `bucket_key` - The bucket key to add to the auth_data if authentication is needed. Optional
272    fn verify_mac(&self, key: &[u8; KEY_SIZE], bucket_key: Option<[u8; KEY_SIZE]>) -> bool {
273        if !self.get_header().has_mac() {
274            eprintln!("[WARN]: Verifying MAC for packet with no MAC present!");
275            return false;
276        }
277
278        if let Some(self_mac) = self.get_mac() {
279            let mut auth_data = self.get_header().get_bytes().to_vec();
280            auth_data.extend_from_slice(&Sha256::digest(self.get_body().get_bytes()));
281            if let Some(bucket_key) = bucket_key {
282                auth_data.extend_from_slice(&bucket_key);
283            }
284
285            let poly = Poly1305::new(key.into());
286            let mac = poly.compute_unpadded(&auth_data).to_vec();
287            mac == self_mac
288        } else {
289            false
290        }
291    }
292}