// plabble_codec/protocol/packet_handler.rs

1use std::collections::HashMap;
2
3use hkdf::Hkdf;
4use sha2::Sha256;
5
6use crate::{
7    abstractions::{
8        Serializable, SerializationError, SerializationInfo, KEY_SIZE, TYPE_APPEND, TYPE_PUT,
9        TYPE_REQUEST, TYPE_SUBSCRIBE, TYPE_WIPE,
10    },
11    codec::{
12        header::RequestHeader,
13        objects::BucketId,
14        ptp_packet::{PtpHeader, PtpPacket},
15        request::RequestPacket,
16        response::ResponsePacket,
17    },
18};
19
/// Options for the packet handler
///
/// * `UseEncryption` - Encrypts the packet with CHACHA20-POLY1305
/// * `UseAuthentication` - Authenticates the packet with POLY1305
/// * `None` - No encryption or authentication
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HandlerOptions {
    /// Encrypt packets (keys `0x00` and `0x01` are derived per packet)
    UseEncryption,
    /// Authenticate packets with a MAC (key `0x00` is derived per packet)
    UseAuthentication,
    /// Neither encrypt nor authenticate packets
    None,
}
30
31/// Tool that handles the serialization, encryption and authentication of packets
32pub struct PacketHandler {
33    bucket_keys: HashMap<BucketId, [u8; KEY_SIZE]>,
34    session_key: Option<[u8; KEY_SIZE]>,
35    self_counter: u16,
36    other_counter: u16,
37    options: HandlerOptions,
38}
39
40impl PacketHandler {
41    /// Creates a new PacketHandler
42    ///
43    /// # Arguments
44    ///
45    /// * `options` - Options for the handler. See `HandlerOptions` for more information
46    /// * `bucket_keys` - Keys for the buckets. If `None`, the handler will use an empty map
47    /// * `session_key` - The session key. If `None`, the handler will not be able to encrypt or authenticate packets
48    pub fn new(
49        options: HandlerOptions,
50        bucket_keys: Option<HashMap<BucketId, [u8; KEY_SIZE]>>,
51        session_key: Option<[u8; KEY_SIZE]>,
52    ) -> Self {
53        Self {
54            bucket_keys: bucket_keys.unwrap_or(HashMap::new()),
55            self_counter: 0,
56            other_counter: 0,
57            options,
58            session_key,
59        }
60    }
61
62    /// Generate key for a request
63    ///
64    /// # Arguments
65    ///
66    /// * `nr` - The byte to add to the key info. See [MAC and encryption keys](https://plabble.github.io/transport/#MAC%20and%20Encryption%20keys)
67    /// * `my_or_other_counter` - If true, the counter of the current peer will be used. If false, the counter of the other peer will be used
68    fn key(&self, nr: u8, my_or_other_counter: bool) -> [u8; KEY_SIZE] {
69        let mut okm = [0u8; KEY_SIZE]; // output key
70        let kdf = Hkdf::<Sha256>::new(
71            None,
72            self.session_key
73                .as_ref()
74                .expect("Can't generate keys without session key"),
75        );
76
77        let counter: [u8; 2] = if my_or_other_counter {
78            self.self_counter.to_be_bytes()
79        } else {
80            self.other_counter.to_be_bytes()
81        };
82
83        let mut info = counter.to_vec();
84        info.push(nr);
85
86        kdf.expand(&info, &mut okm).expect("Failed to create key");
87        okm
88    }
89
90    /// Generate serialization info for a request
91    ///
92    /// # Arguments
93    ///
94    /// * `header` - The header of the request
95    /// * `me_or_other` - Indicates if the serialization info is for the current peer or the other peer
96    ///
97    /// # Returns
98    ///
99    /// * `Ok(SerializationInfo)` - The serialization info
100    /// * `Err(SerializationError)` - The error that occured
101    fn get_info_with_bucket_key_if_needed(
102        &self,
103        header: &RequestHeader,
104        me_or_other: bool,
105    ) -> Result<SerializationInfo, SerializationError> {
106        let bucket_key = if header.has_bucket_id() {
107            let id = header.bucket_id.as_ref().unwrap();
108            let permissons = id.permissions();
109
110            // Check if permissions require bucket key
111            if match header.packet_type() {
112                // Write
113                TYPE_PUT | TYPE_WIPE => !permissons.pub_write,
114                TYPE_APPEND => !permissons.pub_append && !permissons.pub_write,
115                TYPE_REQUEST | TYPE_SUBSCRIBE => !permissons.pub_read,
116                _ => false,
117            } {
118                match self.bucket_keys.get(id) {
119                    Some(key) => Some(*key),
120                    None => {
121                        return Err(SerializationError::MissingInfo(format!(
122                            "Bucket key for requested bucket with id #{:?} not present",
123                            id
124                        )))
125                    }
126                }
127            } else {
128                None
129            }
130        } else {
131            None
132        };
133
134        Ok(match self.options {
135            HandlerOptions::UseEncryption => SerializationInfo::UseEncryption(
136                self.key(0x00, me_or_other),
137                self.key(0x01, me_or_other),
138                bucket_key,
139            ),
140            HandlerOptions::UseAuthentication => {
141                SerializationInfo::UseAuthentication(self.key(0x00, me_or_other), bucket_key)
142            }
143            _ => SerializationInfo::None,
144        })
145    }
146
147    /// Parse a serialized and maybe encrypted request packet
148    ///
149    /// # Arguments
150    ///
151    /// * `data` - The serialized packet
152    ///
153    /// # Returns
154    ///
155    /// * `Ok(RequestPacket)` - The parsed packet
156    /// * `Err(SerializationError)` - The error that occured
157    pub fn parse_request(&mut self, data: &[u8]) -> Result<RequestPacket, SerializationError> {
158        let info = match self.options {
159            HandlerOptions::UseEncryption => {
160                SerializationInfo::UseEncryption(self.key(0x00, false), self.key(0x01, false), None)
161            }
162            _ => SerializationInfo::None,
163        };
164
165        let header = RequestHeader::from_bytes(data, Some(info))?;
166        let info = self.get_info_with_bucket_key_if_needed(&header, false)?;
167
168        match self.other_counter.checked_add(1) {
169            Some(v) => {
170                self.other_counter = v;
171                RequestPacket::from_bytes(data, info)
172            }
173            None => Err(SerializationError::CounterOverflow),
174        }
175    }
176
177    /// Parse a serialized and maybe encrypted response packet
178    ///
179    /// # Arguments
180    ///
181    /// * `data` - The serialized packet
182    ///
183    /// # Returns
184    ///
185    /// * `Ok(ResponsePacket)` - The parsed response packet
186    /// * `Err(SerializationError)` - The error that occured
187    pub fn parse_response(&mut self, data: &[u8]) -> Result<ResponsePacket, SerializationError> {
188        let info = match self.options {
189            HandlerOptions::UseEncryption => {
190                SerializationInfo::UseEncryption(self.key(0x00, false), self.key(0x01, false), None)
191            }
192            HandlerOptions::UseAuthentication => {
193                SerializationInfo::UseAuthentication(self.key(0x00, false), None)
194            }
195            _ => SerializationInfo::None,
196        };
197
198        match self.other_counter.checked_add(1) {
199            Some(v) => {
200                self.other_counter = v;
201                ResponsePacket::from_bytes(data, info)
202            }
203            None => Err(SerializationError::CounterOverflow),
204        }
205    }
206
207    /// Serialize a request packet
208    ///
209    /// # Arguments
210    ///
211    /// * `packet` - The packet to serialize
212    /// * `with_len` - Indicates if the length of the packet should be included
213    ///
214    /// # Returns
215    ///
216    /// * `Ok(Vec<u8>)` - The serialized packet
217    /// * `Err(SerializationError)` - The error that occured
218    pub fn serialize_request(
219        &mut self,
220        packet: RequestPacket,
221        with_len: bool,
222    ) -> Result<Vec<u8>, SerializationError> {
223        let info = self.get_info_with_bucket_key_if_needed(packet.get_header(), true)?;
224        let mut packet = packet;
225        if let HandlerOptions::UseAuthentication = self.options {
226            packet.header.set_mac(true);
227        }
228
229        match self.self_counter.checked_add(1) {
230            Some(v) => {
231                self.self_counter = v;
232                packet.get_bytes(info, with_len)
233            }
234            None => Err(SerializationError::CounterOverflow),
235        }
236    }
237
238    /// Serialize a response packet
239    ///
240    /// # Arguments
241    ///
242    /// * `packet` - The packet to serialize
243    /// * `with_len` - Indicates if the length of the packet should be included
244    ///
245    /// # Returns
246    ///
247    /// * `Ok(Vec<u8>)` - The serialized packet
248    /// * `Err(SerializationError)` - The error that occured
249    pub fn serialize_response(
250        &mut self,
251        packet: ResponsePacket,
252        with_len: bool,
253    ) -> Result<Vec<u8>, SerializationError> {
254        let mut packet = packet;
255        let info = match self.options {
256            HandlerOptions::UseEncryption => {
257                SerializationInfo::UseEncryption(self.key(0x00, true), self.key(0x01, true), None)
258            }
259            HandlerOptions::UseAuthentication => {
260                packet.header.set_mac(true);
261                SerializationInfo::UseAuthentication(self.key(0x00, true), None)
262            }
263            _ => SerializationInfo::None,
264        };
265
266        match self.self_counter.checked_add(1) {
267            Some(v) => {
268                self.self_counter = v;
269                packet.get_bytes(info, with_len)
270            }
271            None => Err(SerializationError::CounterOverflow),
272        }
273    }
274}
275
#[cfg(test)]
mod test {
    use crate::{
        abstractions::{TYPE_CREATE, TYPE_ERROR},
        codec::{
            common::SlotRange, header::ResponseHeader, request::RequestBody, response::ResponseBody,
        },
    };

    use super::*;

    #[test]
    fn can_deserialize_connect_request() {
        let req = &[
            0, 0x77u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
            22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
        ];

        let mut handler = PacketHandler::new(HandlerOptions::None, None, None);
        let request = handler.parse_request(req).unwrap();
        match request.get_body() {
            RequestBody::CONNECT {
                protocol_version, ..
            } => {
                assert_eq!(0x77, *protocol_version);
            }
            _ => panic!("Not a CONNECT"),
        }

        assert_eq!(0, handler.self_counter);
        assert_eq!(1, handler.other_counter);
    }

    #[test]
    fn create_never_gives_bucket_key() {
        let mut sut = PacketHandler::new(HandlerOptions::UseAuthentication, None, Some([1u8; 32]));
        let id = BucketId::new(7);
        let header = RequestHeader::new(1, Some(id.clone()));
        sut.bucket_keys.insert(id, [1u8; 32]);
        let res = sut
            .get_info_with_bucket_key_if_needed(&header, false)
            .unwrap();
        match res {
            SerializationInfo::UseAuthentication(_, bucket_key) => {
                assert_eq!(None, bucket_key);
            }
            _ => panic!("Wrong type"),
        }
    }

    #[test]
    fn put_with_public_write_does_not_give_bucket_key() {
        let mut sut = PacketHandler::new(HandlerOptions::UseAuthentication, None, Some([1u8; 32]));
        let mut id = BucketId::new(7);

        let mut permissions = id.permissions();
        permissions.pub_write = true;
        id.set_permissions(permissions);

        let header = RequestHeader::new(TYPE_PUT, Some(id.clone()));
        sut.bucket_keys.insert(id, [1u8; 32]);
        let res = sut
            .get_info_with_bucket_key_if_needed(&header, false)
            .unwrap();
        match res {
            SerializationInfo::UseAuthentication(_, bucket_key) => {
                assert_eq!(None, bucket_key);
            }
            _ => panic!("Wrong type"),
        }
    }

    #[test]
    fn append_without_public_append_does_give_bucket_key() {
        let mut sut = PacketHandler::new(HandlerOptions::UseAuthentication, None, Some([1u8; 32]));
        let id = BucketId::new(7);

        let header = RequestHeader::new(TYPE_APPEND, Some(id.clone()));
        sut.bucket_keys.insert(id, [1u8; 32]);
        let res = sut
            .get_info_with_bucket_key_if_needed(&header, false)
            .unwrap();
        match res {
            SerializationInfo::UseAuthentication(_, bucket_key) => {
                assert_eq!(Some([1u8; 32]), bucket_key);
            }
            _ => panic!("Wrong type"),
        }
    }

    #[test]
    fn can_generate_shared_key_with_counter_1() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::None, None, Some(*session_key));
        sut.self_counter = 1;
        let key = sut.key(0x01, true);
        assert_eq!(
            key,
            [
                110, 223, 136, 196, 67, 61, 170, 231, 138, 234, 119, 93, 152, 169, 168, 18, 199,
                27, 204, 11, 103, 191, 208, 199, 202, 145, 91, 96, 88, 228, 138, 41
            ]
        );
    }

    #[test]
    fn can_generate_shared_key_with_counter_7() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::None, None, Some(*session_key));
        sut.other_counter = 7;
        let key = sut.key(0x00, false);
        assert_eq!(
            key,
            [
                86, 124, 77, 37, 137, 217, 171, 207, 121, 144, 71, 67, 148, 195, 193, 134, 219,
                223, 221, 216, 210, 66, 219, 166, 197, 113, 208, 166, 61, 206, 218, 1
            ]
        );
    }

    #[test]
    fn can_generate_encryption_info_with_bucket_key() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));
        let id = BucketId::new(5);
        sut.bucket_keys.insert(
            id.clone(),
            [
                1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
                9, 0, 1, 2,
            ],
        );
        let header = &RequestHeader::new(2, Some(id));
        let info = sut
            .get_info_with_bucket_key_if_needed(header, true)
            .unwrap();
        match info {
            SerializationInfo::UseEncryption(a, b, _) => {
                assert_eq!(
                    a,
                    [
                        115, 13, 138, 42, 229, 171, 252, 201, 236, 154, 27, 170, 98, 19, 64, 200,
                        31, 27, 219, 82, 215, 38, 186, 156, 26, 126, 19, 36, 137, 132, 170, 129
                    ]
                );

                assert_eq!(
                    b,
                    [
                        240, 166, 168, 233, 19, 0, 183, 68, 176, 91, 91, 69, 182, 111, 141, 82,
                        242, 142, 215, 82, 17, 88, 104, 210, 166, 49, 26, 152, 54, 245, 171, 80
                    ]
                );
            }
            _ => panic!("Wrong type"),
        };
    }

    #[test]
    fn can_create_request_encrypted() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));

        let bucket_id = BucketId::from_bytes(
            &[
                29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3,
            ],
            None,
        )
        .unwrap();

        let header = RequestHeader::new(1, Some(bucket_id));
        let create_req = RequestPacket::new(
            header,
            RequestBody::CREATE(SlotRange {
                from: Some(5),
                to: Some(7),
            }),
            None,
        );

        let bytes = sut.serialize_request(create_req, false).unwrap();
        assert_eq!(
            bytes,
            vec![
                53, 193, 133, 121, 169, 180, 199, 145, 54, 54, 159, 110, 145, 89, 36, 72, 33, 199,
                139, 63, 198, 247, 187, 161, 49, 165, 174, 140, 57, 179, 243, 227, 172, 38, 86, 25,
                183
            ]
        );
    }

    #[test]
    fn can_create_request_authenticated() {
        let session_key = &[1u8; 32];
        let mut sut =
            PacketHandler::new(HandlerOptions::UseAuthentication, None, Some(*session_key));

        let bucket_id = BucketId::from_bytes(
            &[
                29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3,
            ],
            None,
        )
        .unwrap();

        let header = RequestHeader::new(1, Some(bucket_id));
        let create_req = RequestPacket::new(
            header,
            RequestBody::CREATE(SlotRange {
                from: Some(5),
                to: Some(7),
            }),
            None,
        );

        let bytes = sut.serialize_request(create_req, false).unwrap();
        assert_eq!(
            bytes,
            vec![
                17, 29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3, 0, 5,
                0, 7, 116, 139, 222, 175, 34, 89, 8, 53, 185, 215, 120, 148, 218, 125, 29, 216
            ]
        );
    }

    #[test]
    fn can_deserialize_request_authenticated() {
        let session_key = &[1u8; 32];
        let mut sut =
            PacketHandler::new(HandlerOptions::UseAuthentication, None, Some(*session_key));

        let data = [
            17, 29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3, 0, 5, 0, 7,
            116, 139, 222, 175, 34, 89, 8, 53, 185, 215, 120, 148, 218, 125, 29, 216,
        ];

        // The MAC key belongs to the counter value the packet was received under, so derive it
        // before parsing: parsing advances `other_counter`.
        let mac_key = sut.key(0x00, false);
        let res = sut.parse_request(&data).unwrap();
        assert_eq!(1, sut.other_counter);
        assert!(res.verify_mac(&mac_key, None));
    }

    #[test]
    fn can_parse_encrypted_request() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));
        let data = [
            53, 193, 133, 121, 169, 180, 199, 145, 54, 54, 159, 110, 145, 89, 36, 72, 33, 199, 139,
            63, 198, 247, 187, 161, 49, 165, 174, 140, 57, 179, 243, 227, 172, 38, 86, 25, 183,
        ];

        let bucket_id = BucketId::from_bytes(
            &[
                29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3,
            ],
            None,
        )
        .unwrap();

        let packet = sut.parse_request(&data).unwrap();
        let header = packet.get_header();
        assert_eq!(Some(bucket_id), header.bucket_id);
        assert_eq!(1, header.packet_type());
        match packet.body {
            RequestBody::CREATE(r) => {
                assert_eq!(Some(5), r.from);
                assert_eq!(Some(7), r.to);
            }
            _ => panic!("Not a create"),
        }
    }

    #[test]
    fn can_create_response_authenticated() {
        let session_key = &[1u8; 32];
        let mut sut =
            PacketHandler::new(HandlerOptions::UseAuthentication, None, Some(*session_key));
        let mut response = ResponsePacket {
            header: ResponseHeader::new(TYPE_ERROR, 1),
            body: ResponseBody::ERROR(7, String::from("An error occured")),
            mac: None,
        };
        response.header.set_mac(true);

        let bytes = sut.serialize_response(response, false).unwrap();
        assert_eq!(
            bytes,
            vec![
                31, 0, 1, 7, 65, 110, 32, 101, 114, 114, 111, 114, 32, 111, 99, 99, 117, 114, 101,
                100, 145, 60, 204, 2, 3, 125, 144, 113, 250, 131, 128, 188, 121, 229, 153, 131
            ]
        );
    }

    #[test]
    fn can_create_response_no_authentication() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::None, None, Some(*session_key));
        let response = ResponsePacket {
            header: ResponseHeader::new(TYPE_ERROR, 1),
            body: ResponseBody::ERROR(7, String::from("An error occured")),
            mac: None,
        };

        let bytes = sut.serialize_response(response, false).unwrap();
        assert_eq!(
            bytes,
            vec![
                15, 0, 1, 7, 65, 110, 32, 101, 114, 114, 111, 114, 32, 111, 99, 99, 117, 114, 101,
                100
            ]
        );
    }

    #[test]
    fn can_create_encrypted_response() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));

        let response = ResponsePacket {
            header: ResponseHeader::new(TYPE_CREATE, 99),
            body: ResponseBody::CREATE,
            mac: None,
        };

        let bytes = sut.serialize_response(response, false).unwrap();
        assert_eq!(
            vec![
                53, 220, 164, 47, 248, 118, 29, 99, 223, 219, 187, 40, 126, 155, 203, 47, 226, 93,
                56
            ],
            bytes
        );
    }

    #[test]
    fn can_parse_encrypted_response() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));
        let data = &[
            53, 220, 164, 47, 248, 118, 29, 99, 223, 219, 187, 40, 126, 155, 203, 47, 226, 93, 56,
        ];
        let response = sut.parse_response(data).unwrap();
        assert_eq!(response.header.packet_type(), TYPE_CREATE);
        assert_eq!(response.header.counter(), 99);
    }
}