1use std::collections::HashMap;
2
3use hkdf::Hkdf;
4use sha2::Sha256;
5
6use crate::{
7 abstractions::{
8 Serializable, SerializationError, SerializationInfo, KEY_SIZE, TYPE_APPEND, TYPE_PUT,
9 TYPE_REQUEST, TYPE_SUBSCRIBE, TYPE_WIPE,
10 },
11 codec::{
12 header::RequestHeader,
13 objects::BucketId,
14 ptp_packet::{PtpHeader, PtpPacket},
15 request::RequestPacket,
16 response::ResponsePacket,
17 },
18};
19
/// Selects how a [`PacketHandler`] protects the packets it serializes and parses.
///
/// Derives `Debug`/`Clone`/`Copy`/`PartialEq`/`Eq` so the mode can be logged, compared and
/// passed around by value like the plain tag it is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandlerOptions {
    /// Packets are encrypted; the handler derives two per-packet keys from the session key.
    UseEncryption,
    /// Packets carry a MAC; the handler derives one per-packet key from the session key.
    UseAuthentication,
    /// Packets are neither encrypted nor authenticated; no session key is consulted.
    None,
}
30
31pub struct PacketHandler {
33 bucket_keys: HashMap<BucketId, [u8; KEY_SIZE]>,
34 session_key: Option<[u8; KEY_SIZE]>,
35 self_counter: u16,
36 other_counter: u16,
37 options: HandlerOptions,
38}
39
40impl PacketHandler {
41 pub fn new(
49 options: HandlerOptions,
50 bucket_keys: Option<HashMap<BucketId, [u8; KEY_SIZE]>>,
51 session_key: Option<[u8; KEY_SIZE]>,
52 ) -> Self {
53 Self {
54 bucket_keys: bucket_keys.unwrap_or(HashMap::new()),
55 self_counter: 0,
56 other_counter: 0,
57 options,
58 session_key,
59 }
60 }
61
62 fn key(&self, nr: u8, my_or_other_counter: bool) -> [u8; KEY_SIZE] {
69 let mut okm = [0u8; KEY_SIZE]; let kdf = Hkdf::<Sha256>::new(
71 None,
72 self.session_key
73 .as_ref()
74 .expect("Can't generate keys without session key"),
75 );
76
77 let counter: [u8; 2] = if my_or_other_counter {
78 self.self_counter.to_be_bytes()
79 } else {
80 self.other_counter.to_be_bytes()
81 };
82
83 let mut info = counter.to_vec();
84 info.push(nr);
85
86 kdf.expand(&info, &mut okm).expect("Failed to create key");
87 okm
88 }
89
90 fn get_info_with_bucket_key_if_needed(
102 &self,
103 header: &RequestHeader,
104 me_or_other: bool,
105 ) -> Result<SerializationInfo, SerializationError> {
106 let bucket_key = if header.has_bucket_id() {
107 let id = header.bucket_id.as_ref().unwrap();
108 let permissons = id.permissions();
109
110 if match header.packet_type() {
112 TYPE_PUT | TYPE_WIPE => !permissons.pub_write,
114 TYPE_APPEND => !permissons.pub_append && !permissons.pub_write,
115 TYPE_REQUEST | TYPE_SUBSCRIBE => !permissons.pub_read,
116 _ => false,
117 } {
118 match self.bucket_keys.get(id) {
119 Some(key) => Some(*key),
120 None => {
121 return Err(SerializationError::MissingInfo(format!(
122 "Bucket key for requested bucket with id #{:?} not present",
123 id
124 )))
125 }
126 }
127 } else {
128 None
129 }
130 } else {
131 None
132 };
133
134 Ok(match self.options {
135 HandlerOptions::UseEncryption => SerializationInfo::UseEncryption(
136 self.key(0x00, me_or_other),
137 self.key(0x01, me_or_other),
138 bucket_key,
139 ),
140 HandlerOptions::UseAuthentication => {
141 SerializationInfo::UseAuthentication(self.key(0x00, me_or_other), bucket_key)
142 }
143 _ => SerializationInfo::None,
144 })
145 }
146
147 pub fn parse_request(&mut self, data: &[u8]) -> Result<RequestPacket, SerializationError> {
158 let info = match self.options {
159 HandlerOptions::UseEncryption => {
160 SerializationInfo::UseEncryption(self.key(0x00, false), self.key(0x01, false), None)
161 }
162 _ => SerializationInfo::None,
163 };
164
165 let header = RequestHeader::from_bytes(data, Some(info))?;
166 let info = self.get_info_with_bucket_key_if_needed(&header, false)?;
167
168 match self.other_counter.checked_add(1) {
169 Some(v) => {
170 self.other_counter = v;
171 RequestPacket::from_bytes(data, info)
172 }
173 None => Err(SerializationError::CounterOverflow),
174 }
175 }
176
177 pub fn parse_response(&mut self, data: &[u8]) -> Result<ResponsePacket, SerializationError> {
188 let info = match self.options {
189 HandlerOptions::UseEncryption => {
190 SerializationInfo::UseEncryption(self.key(0x00, false), self.key(0x01, false), None)
191 }
192 HandlerOptions::UseAuthentication => {
193 SerializationInfo::UseAuthentication(self.key(0x00, false), None)
194 }
195 _ => SerializationInfo::None,
196 };
197
198 match self.other_counter.checked_add(1) {
199 Some(v) => {
200 self.other_counter = v;
201 ResponsePacket::from_bytes(data, info)
202 }
203 None => Err(SerializationError::CounterOverflow),
204 }
205 }
206
207 pub fn serialize_request(
219 &mut self,
220 packet: RequestPacket,
221 with_len: bool,
222 ) -> Result<Vec<u8>, SerializationError> {
223 let info = self.get_info_with_bucket_key_if_needed(packet.get_header(), true)?;
224 let mut packet = packet;
225 if let HandlerOptions::UseAuthentication = self.options {
226 packet.header.set_mac(true);
227 }
228
229 match self.self_counter.checked_add(1) {
230 Some(v) => {
231 self.self_counter = v;
232 packet.get_bytes(info, with_len)
233 }
234 None => Err(SerializationError::CounterOverflow),
235 }
236 }
237
238 pub fn serialize_response(
250 &mut self,
251 packet: ResponsePacket,
252 with_len: bool,
253 ) -> Result<Vec<u8>, SerializationError> {
254 let mut packet = packet;
255 let info = match self.options {
256 HandlerOptions::UseEncryption => {
257 SerializationInfo::UseEncryption(self.key(0x00, true), self.key(0x01, true), None)
258 }
259 HandlerOptions::UseAuthentication => {
260 packet.header.set_mac(true);
261 SerializationInfo::UseAuthentication(self.key(0x00, true), None)
262 }
263 _ => SerializationInfo::None,
264 };
265
266 match self.self_counter.checked_add(1) {
267 Some(v) => {
268 self.self_counter = v;
269 packet.get_bytes(info, with_len)
270 }
271 None => Err(SerializationError::CounterOverflow),
272 }
273 }
274}
275
// Unit tests for `PacketHandler`. Most are known-answer tests: the expected byte arrays
// were produced by this implementation with a session key of 32 bytes of 0x01 and must
// stay stable across refactorings.
#[cfg(test)]
mod test {
    use crate::{
        abstractions::{TYPE_CREATE, TYPE_ERROR},
        codec::{
            common::SlotRange, header::ResponseHeader, request::RequestBody, response::ResponseBody,
        },
    };

    use super::*;

    /// A cleartext CONNECT request decodes its protocol version (the second byte, 0x77).
    /// Parsing advances only the incoming counter.
    #[test]
    fn can_deserialize_connect_request() {
        // NOTE(review): the leading 0 is presumably the header byte and the trailing 32 bytes
        // a key or nonce field; confirm against the CONNECT codec.
        let req = &[
            0, 0x77u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
            22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
        ];

        let mut handler = PacketHandler::new(HandlerOptions::None, None, None);
        let request = handler.parse_request(req).unwrap();
        match request.get_body() {
            RequestBody::CONNECT {
                protocol_version, ..
            } => {
                assert_eq!(0x77, *protocol_version);
            }
            _ => panic!("Not a CONNECT"),
        }

        assert_eq!(0, handler.self_counter);
        assert_eq!(1, handler.other_counter);
    }

    /// Packet type 1 (used as CREATE throughout these tests, presumably `TYPE_CREATE`;
    /// confirm) never attaches a bucket key, even when one is stored for the bucket.
    #[test]
    fn create_does_never_give_bucket_key() {
        let mut sut = PacketHandler::new(HandlerOptions::UseAuthentication, None, Some([1u8; 32]));
        let id = BucketId::new(7);
        let header = RequestHeader::new(1, Some(id.clone()));
        sut.bucket_keys.insert(id, [1u8; 32]);
        let res = sut
            .get_info_with_bucket_key_if_needed(&header, false)
            .unwrap();
        match res {
            SerializationInfo::UseAuthentication(_, bucket_key) => {
                assert_eq!(None, bucket_key);
            }
            _ => panic!("Wrong type"),
        }
    }

    /// A PUT on a bucket whose `pub_write` permission is set needs no bucket key.
    /// (The test name says "bucket id", but the assertion is about the bucket key.)
    #[test]
    fn put_with_public_write_does_not_give_bucket_id() {
        let mut sut = PacketHandler::new(HandlerOptions::UseAuthentication, None, Some([1u8; 32]));
        let mut id = BucketId::new(7);

        let mut permissions = id.permissions();
        permissions.pub_write = true;
        id.set_permissions(permissions);

        let header = RequestHeader::new(TYPE_PUT, Some(id.clone()));
        sut.bucket_keys.insert(id, [1u8; 32]);
        let res = sut
            .get_info_with_bucket_key_if_needed(&header, false)
            .unwrap();
        match res {
            SerializationInfo::UseAuthentication(_, bucket_key) => {
                assert_eq!(None, bucket_key);
            }
            _ => panic!("Wrong type"),
        }
    }

    /// An APPEND on a bucket that allows neither public append nor public write receives the
    /// stored bucket key. This relies on `BucketId::new` setting neither permission.
    #[test]
    fn append_without_public_append_does_give_bucket_id() {
        let mut sut = PacketHandler::new(HandlerOptions::UseAuthentication, None, Some([1u8; 32]));
        let id = BucketId::new(7);

        let header = RequestHeader::new(TYPE_APPEND, Some(id.clone()));
        sut.bucket_keys.insert(id, [1u8; 32]);
        let res = sut
            .get_info_with_bucket_key_if_needed(&header, false)
            .unwrap();
        match res {
            SerializationInfo::UseAuthentication(_, bucket_key) => {
                assert_eq!(Some([1u8; 32]), bucket_key);
            }
            _ => panic!("Wrong type"),
        }
    }

    /// Known-answer test for `key`: `self_counter` = 1 with `nr` 0x01.
    #[test]
    fn can_generate_shared_key_with_counter_1() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::None, None, Some(*session_key));
        sut.self_counter = 1;
        let key = sut.key(0x01, true);
        assert_eq!(
            key,
            [
                110, 223, 136, 196, 67, 61, 170, 231, 138, 234, 119, 93, 152, 169, 168, 18, 199,
                27, 204, 11, 103, 191, 208, 199, 202, 145, 91, 96, 88, 228, 138, 41
            ]
        );
    }

    /// Known-answer test for `key`: `other_counter` = 7 with `nr` 0x00.
    #[test]
    fn can_generate_shared_key_with_counter_7() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::None, None, Some(*session_key));
        sut.other_counter = 7;
        let key = sut.key(0x00, false);
        assert_eq!(
            key,
            [
                86, 124, 77, 37, 137, 217, 171, 207, 121, 144, 71, 67, 148, 195, 193, 134, 219,
                223, 221, 216, 210, 66, 219, 166, 197, 113, 208, 166, 61, 206, 218, 1
            ]
        );
    }

    /// Checks both derived encryption keys at counter 0 for a request with a bucket id.
    /// The bucket-key component (`_`) is not asserted. Whether packet type 2 requires it
    /// depends on constant values not visible in this file.
    #[test]
    fn can_generate_encryption_info_with_bucket_key() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));
        let id = BucketId::new(5);
        sut.bucket_keys.insert(
            id.clone(),
            [
                1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
                9, 0, 1, 2,
            ],
        );
        let header = &RequestHeader::new(2, Some(id));
        let info = sut
            .get_info_with_bucket_key_if_needed(header, true)
            .unwrap();
        match info {
            SerializationInfo::UseEncryption(a, b, _) => {
                assert_eq!(
                    a,
                    [
                        115, 13, 138, 42, 229, 171, 252, 201, 236, 154, 27, 170, 98, 19, 64, 200,
                        31, 27, 219, 82, 215, 38, 186, 156, 26, 126, 19, 36, 137, 132, 170, 129
                    ]
                );

                assert_eq!(
                    b,
                    [
                        240, 166, 168, 233, 19, 0, 183, 68, 176, 91, 91, 69, 182, 111, 141, 82,
                        242, 142, 215, 82, 17, 88, 104, 210, 166, 49, 26, 152, 54, 245, 171, 80
                    ]
                );
            }
            _ => panic!("Wrong type"),
        };
    }

    /// Known-answer test for an encrypted CREATE request (slots 5..7) on a fixed bucket id.
    /// The expected bytes are opaque ciphertext; `can_parse_encrypted_request` decodes them.
    #[test]
    fn can_create_request_encrypted() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));

        let bucket_id = BucketId::from_bytes(
            &[
                29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3,
            ],
            None,
        )
        .unwrap();

        let header = RequestHeader::new(1, Some(bucket_id));
        let create_req = RequestPacket::new(
            header,
            RequestBody::CREATE(SlotRange {
                from: Some(5),
                to: Some(7),
            }),
            None,
        );

        let bytes = sut.serialize_request(create_req, false).unwrap();
        assert_eq!(
            bytes,
            vec![
                53, 193, 133, 121, 169, 180, 199, 145, 54, 54, 159, 110, 145, 89, 36, 72, 33, 199,
                139, 63, 198, 247, 187, 161, 49, 165, 174, 140, 57, 179, 243, 227, 172, 38, 86, 25,
                183
            ]
        );
    }

    /// Known-answer test for an authenticated CREATE request. The output shows the 16-byte
    /// bucket id and the slot bytes `0, 5, 0, 7` in the clear, after a leading byte (17) and
    /// followed by 16 further bytes (presumably the MAC; confirm the layout).
    #[test]
    fn can_create_request_authenticated() {
        let session_key = &[1u8; 32];
        let mut sut =
            PacketHandler::new(HandlerOptions::UseAuthentication, None, Some(*session_key));

        let bucket_id = BucketId::from_bytes(
            &[
                29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3,
            ],
            None,
        )
        .unwrap();

        let header = RequestHeader::new(1, Some(bucket_id));
        let create_req = RequestPacket::new(
            header,
            RequestBody::CREATE(SlotRange {
                from: Some(5),
                to: Some(7),
            }),
            None,
        );

        let bytes = sut.serialize_request(create_req, false).unwrap();
        assert_eq!(
            bytes,
            vec![
                17, 29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3, 0, 5,
                0, 7, 116, 139, 222, 175, 34, 89, 8, 53, 185, 215, 120, 148, 218, 125, 29, 216
            ]
        );
    }

    /// Parses the bytes produced by `can_create_request_authenticated` and verifies their MAC.
    #[test]
    fn can_deserialize_request_authenticated() {
        let session_key = &[1u8; 32];
        let mut sut =
            PacketHandler::new(HandlerOptions::UseAuthentication, None, Some(*session_key));

        let data = [
            17, 29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3, 0, 5, 0, 7,
            116, 139, 222, 175, 34, 89, 8, 53, 185, 215, 120, 148, 218, 125, 29, 216,
        ];

        let res = sut.parse_request(&data).unwrap();
        // `parse_request` advanced `other_counter`; step back so `key` derives the same
        // key that was in effect for this packet.
        sut.other_counter -= 1;
        assert!(res.verify_mac(&sut.key(0x00, false), None));
    }

    /// Decodes the ciphertext from `can_create_request_encrypted` back into the original
    /// bucket id, packet type and slot range.
    #[test]
    fn can_parse_encrypted_request() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));
        let data = [
            53, 193, 133, 121, 169, 180, 199, 145, 54, 54, 159, 110, 145, 89, 36, 72, 33, 199, 139,
            63, 198, 247, 187, 161, 49, 165, 174, 140, 57, 179, 243, 227, 172, 38, 86, 25, 183,
        ];

        let bucket_id = BucketId::from_bytes(
            &[
                29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3,
            ],
            None,
        )
        .unwrap();

        let packet = sut.parse_request(&data).unwrap();
        let header = packet.get_header();
        assert_eq!(Some(bucket_id), header.bucket_id);
        assert_eq!(1, header.packet_type());
        match packet.body {
            RequestBody::CREATE(r) => {
                assert_eq!(Some(5), r.from);
                assert_eq!(Some(7), r.to);
            }
            _ => panic!("Not a create"),
        }
    }

    /// Known-answer test for an authenticated ERROR response. The message bytes (65..100)
    /// spell the literal in clear text, followed by 16 further bytes (presumably the MAC).
    #[test]
    fn can_create_response_authenticated() {
        let session_key = &[1u8; 32];
        let mut sut =
            PacketHandler::new(HandlerOptions::UseAuthentication, None, Some(*session_key));
        let mut response = ResponsePacket {
            header: ResponseHeader::new(TYPE_ERROR, 1),
            body: ResponseBody::ERROR(7, String::from("An error occured")),
            mac: None,
        };
        // Redundant: `serialize_response` also sets the MAC flag in this mode.
        response.header.set_mac(true);

        let bytes = sut.serialize_response(response, false).unwrap();
        assert_eq!(
            bytes,
            vec![
                31, 0, 1, 7, 65, 110, 32, 101, 114, 114, 111, 114, 32, 111, 99, 99, 117, 114, 101,
                100, 145, 60, 204, 2, 3, 125, 144, 113, 250, 131, 128, 188, 121, 229, 153, 131
            ]
        );
    }

    /// Without a protection mode the ERROR response is serialized entirely in clear text.
    /// A session key is supplied but never consulted.
    #[test]
    fn can_create_response_no_authentication() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::None, None, Some(*session_key));
        let response = ResponsePacket {
            header: ResponseHeader::new(TYPE_ERROR, 1),
            body: ResponseBody::ERROR(7, String::from("An error occured")),
            mac: None,
        };

        let bytes = sut.serialize_response(response, false).unwrap();
        assert_eq!(
            bytes,
            vec![
                15, 0, 1, 7, 65, 110, 32, 101, 114, 114, 111, 114, 32, 111, 99, 99, 117, 114, 101,
                100
            ]
        );
    }

    /// Known-answer test for an encrypted CREATE response with header counter 99.
    /// `can_parse_encrypted_response` decodes the result.
    #[test]
    fn can_create_encrypted_response() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));

        let response = ResponsePacket {
            header: ResponseHeader::new(TYPE_CREATE, 99),
            body: ResponseBody::CREATE,
            mac: None,
        };

        let bytes = sut.serialize_response(response, false).unwrap();
        assert_eq!(
            vec![
                53, 220, 164, 47, 248, 118, 29, 99, 223, 219, 187, 40, 126, 155, 203, 47, 226, 93,
                56
            ],
            bytes
        );
    }

    /// Decodes the ciphertext from `can_create_encrypted_response`. It works because both
    /// handlers start at counter 0, so the serializer's `self_counter` key matches the
    /// parser's `other_counter` key.
    #[test]
    fn can_parse_encrypted_response() {
        let session_key = &[1u8; 32];
        let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));
        let data = &[
            53, 220, 164, 47, 248, 118, 29, 99, 223, 219, 187, 40, 126, 155, 203, 47, 226, 93, 56,
        ];
        let response = sut.parse_response(data).unwrap();
        assert_eq!(response.header.packet_type(), TYPE_CREATE);
        assert_eq!(response.header.counter(), 99);
    }
}