use std::collections::HashMap;
use hkdf::Hkdf;
use sha2::Sha256;
use crate::{
abstractions::{
Serializable, SerializationError, SerializationInfo, KEY_SIZE, TYPE_APPEND, TYPE_PUT,
TYPE_REQUEST, TYPE_SUBSCRIBE, TYPE_WIPE,
},
codec::{
header::RequestHeader,
objects::BucketId,
ptp_packet::{PtpHeader, PtpPacket},
request::RequestPacket,
response::ResponsePacket,
},
};
/// Protection mode a `PacketHandler` applies to the packets it handles.
///
/// The mode decides which `SerializationInfo` variant the handler hands to the
/// codec when it parses or serializes a packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandlerOptions {
    /// Packets are processed with `SerializationInfo::UseEncryption`, which
    /// carries two keys derived from the session key.
    UseEncryption,
    /// Packets are processed with `SerializationInfo::UseAuthentication`, which
    /// carries one derived key; outgoing packets get their MAC flag set.
    UseAuthentication,
    /// Packets are processed with `SerializationInfo::None`; no key is derived.
    None,
}
/// Parses and serializes request/response packets, supplying the codec with
/// per-packet keys derived from a session key and two packet counters.
pub struct PacketHandler {
/// Keys of known buckets, looked up by bucket id when an operation on a
/// bucket is not covered by its public permissions.
bucket_keys: HashMap<BucketId, [u8; KEY_SIZE]>,
/// Input keying material for the HKDF key derivation. Deriving a key while
/// this is `None` panics.
session_key: Option<[u8; KEY_SIZE]>,
/// Counter for outgoing packets; incremented by every `serialize_*` call and
/// mixed into the keys derived for this side.
self_counter: u16,
/// Counter for incoming packets; incremented by every `parse_*` call and
/// mixed into the keys derived for the peer's side.
other_counter: u16,
/// Protection mode applied to every packet.
options: HandlerOptions,
}
impl PacketHandler {
    /// Creates a handler with both packet counters starting at zero.
    ///
    /// * `options` - protection mode applied to every packet.
    /// * `bucket_keys` - keys of known buckets; `None` is treated as an empty map.
    /// * `session_key` - input keying material for the per-packet keys. It must
    ///   be present whenever a key has to be derived, i.e. for every mode other
    ///   than [`HandlerOptions::None`].
    pub fn new(
        options: HandlerOptions,
        bucket_keys: Option<HashMap<BucketId, [u8; KEY_SIZE]>>,
        session_key: Option<[u8; KEY_SIZE]>,
    ) -> Self {
        Self {
            bucket_keys: bucket_keys.unwrap_or_default(),
            session_key,
            self_counter: 0,
            other_counter: 0,
            options,
        }
    }

    /// Derives a per-packet key from the session key using HKDF-SHA256.
    ///
    /// The HKDF `info` is the selected counter in big-endian byte order followed
    /// by `nr`, so every (counter, `nr`) pair yields its own key.
    ///
    /// * `nr` - key number, distinguishes several keys for the same packet.
    /// * `my_or_other_counter` - `true` uses `self_counter`, `false` uses
    ///   `other_counter`.
    ///
    /// # Panics
    ///
    /// Panics if the handler has no session key, or if the HKDF expansion fails.
    fn key(&self, nr: u8, my_or_other_counter: bool) -> [u8; KEY_SIZE] {
        let kdf = Hkdf::<Sha256>::new(
            None,
            self.session_key
                .as_ref()
                .expect("Can't generate keys without session key"),
        );
        let counter = if my_or_other_counter {
            self.self_counter
        } else {
            self.other_counter
        };
        // The info is always exactly three bytes, so a stack array replaces the
        // former heap-allocated Vec; the bytes fed to HKDF are unchanged.
        let [high, low] = counter.to_be_bytes();
        let info = [high, low, nr];
        let mut okm = [0u8; KEY_SIZE];
        kdf.expand(&info, &mut okm).expect("Failed to create key");
        okm
    }

    /// Builds the [`SerializationInfo`] matching the handler's mode.
    ///
    /// Keys are derived only in the modes that need them, so a handler running
    /// with [`HandlerOptions::None`] never touches the session key.
    ///
    /// * `me_or_other` - forwarded to [`Self::key`] to pick the counter.
    /// * `bucket_key` - optional bucket key to embed in the info.
    fn serialization_info(
        &self,
        me_or_other: bool,
        bucket_key: Option<[u8; KEY_SIZE]>,
    ) -> SerializationInfo {
        match self.options {
            HandlerOptions::UseEncryption => SerializationInfo::UseEncryption(
                self.key(0x00, me_or_other),
                self.key(0x01, me_or_other),
                bucket_key,
            ),
            HandlerOptions::UseAuthentication => {
                SerializationInfo::UseAuthentication(self.key(0x00, me_or_other), bucket_key)
            }
            HandlerOptions::None => SerializationInfo::None,
        }
    }

    /// Increments `counter` by one.
    ///
    /// # Errors
    ///
    /// Returns [`SerializationError::CounterOverflow`] and leaves `counter`
    /// untouched when it is already at `u16::MAX`.
    fn advance_counter(counter: &mut u16) -> Result<(), SerializationError> {
        *counter = counter
            .checked_add(1)
            .ok_or(SerializationError::CounterOverflow)?;
        Ok(())
    }

    /// Builds the [`SerializationInfo`] for a request, adding the bucket key
    /// when the operation is not covered by the bucket's public permissions.
    ///
    /// A bucket key is required for
    /// * PUT and WIPE unless the bucket is publicly writable,
    /// * APPEND unless the bucket is publicly appendable or publicly writable,
    /// * REQUEST and SUBSCRIBE unless the bucket is publicly readable.
    ///
    /// Every other packet type, and every header without a bucket id, gets no
    /// bucket key.
    ///
    /// # Errors
    ///
    /// Returns [`SerializationError::MissingInfo`] when a bucket key is required
    /// but not known to this handler, or when the header reports a bucket id
    /// without actually carrying one.
    ///
    /// # Panics
    ///
    /// Panics if keys must be derived and the handler has no session key.
    fn get_info_with_bucket_key_if_needed(
        &self,
        header: &RequestHeader,
        me_or_other: bool,
    ) -> Result<SerializationInfo, SerializationError> {
        let bucket_key = if header.has_bucket_id() {
            // An inconsistent header is reported as an error instead of
            // panicking on an unwrap.
            let id = header.bucket_id.as_ref().ok_or_else(|| {
                SerializationError::MissingInfo(String::from(
                    "Header announces a bucket id but none is present",
                ))
            })?;
            let permissions = id.permissions();
            let needs_bucket_key = match header.packet_type() {
                TYPE_PUT | TYPE_WIPE => !permissions.pub_write,
                TYPE_APPEND => !permissions.pub_append && !permissions.pub_write,
                TYPE_REQUEST | TYPE_SUBSCRIBE => !permissions.pub_read,
                _ => false,
            };
            if needs_bucket_key {
                let key = self.bucket_keys.get(id).ok_or_else(|| {
                    SerializationError::MissingInfo(format!(
                        "Bucket key for requested bucket with id #{:?} not present",
                        id
                    ))
                })?;
                Some(*key)
            } else {
                None
            }
        } else {
            None
        };
        Ok(self.serialization_info(me_or_other, bucket_key))
    }

    /// Parses a request received from the peer.
    ///
    /// The header is parsed first to find out whether a bucket key is needed;
    /// only encryption mode passes keys for that step, the other modes pass
    /// [`SerializationInfo::None`]. The whole packet is then parsed with the
    /// complete info.
    ///
    /// The peer counter is advanced before the packet body is parsed, so it
    /// moves on even when that final parsing step fails.
    ///
    /// # Errors
    ///
    /// Returns the codec's error when the header or packet cannot be parsed,
    /// [`SerializationError::MissingInfo`] when a required bucket key is
    /// unknown, and [`SerializationError::CounterOverflow`] when the peer
    /// counter cannot be advanced any further.
    ///
    /// # Panics
    ///
    /// Panics if keys must be derived and the handler has no session key.
    pub fn parse_request(&mut self, data: &[u8]) -> Result<RequestPacket, SerializationError> {
        let header_info = match self.options {
            HandlerOptions::UseEncryption => self.serialization_info(false, None),
            HandlerOptions::UseAuthentication | HandlerOptions::None => SerializationInfo::None,
        };
        let header = RequestHeader::from_bytes(data, Some(header_info))?;
        let info = self.get_info_with_bucket_key_if_needed(&header, false)?;
        Self::advance_counter(&mut self.other_counter)?;
        RequestPacket::from_bytes(data, info)
    }

    /// Parses a response received from the peer.
    ///
    /// Responses never carry a bucket key. The peer counter is advanced before
    /// the packet is parsed, so it moves on even when parsing fails.
    ///
    /// # Errors
    ///
    /// Returns [`SerializationError::CounterOverflow`] when the peer counter
    /// cannot be advanced any further, or the codec's error when parsing fails.
    ///
    /// # Panics
    ///
    /// Panics if keys must be derived and the handler has no session key.
    pub fn parse_response(&mut self, data: &[u8]) -> Result<ResponsePacket, SerializationError> {
        let info = self.serialization_info(false, None);
        Self::advance_counter(&mut self.other_counter)?;
        ResponsePacket::from_bytes(data, info)
    }

    /// Serializes a request for sending.
    ///
    /// In authentication mode the packet's MAC flag is set before serializing.
    /// The own counter is advanced before the bytes are produced, so it moves on
    /// even when serialization fails.
    ///
    /// * `packet` - request to serialize.
    /// * `with_len` - forwarded unchanged to the packet's `get_bytes`.
    ///
    /// # Errors
    ///
    /// Returns [`SerializationError::MissingInfo`] when a required bucket key is
    /// unknown, [`SerializationError::CounterOverflow`] when the own counter
    /// cannot be advanced any further, or the codec's error when serialization
    /// fails.
    ///
    /// # Panics
    ///
    /// Panics if keys must be derived and the handler has no session key.
    pub fn serialize_request(
        &mut self,
        mut packet: RequestPacket,
        with_len: bool,
    ) -> Result<Vec<u8>, SerializationError> {
        let info = self.get_info_with_bucket_key_if_needed(packet.get_header(), true)?;
        if matches!(self.options, HandlerOptions::UseAuthentication) {
            packet.header.set_mac(true);
        }
        Self::advance_counter(&mut self.self_counter)?;
        packet.get_bytes(info, with_len)
    }

    /// Serializes a response for sending.
    ///
    /// In authentication mode the packet's MAC flag is set before serializing.
    /// Responses never carry a bucket key. The own counter is advanced before
    /// the bytes are produced, so it moves on even when serialization fails.
    ///
    /// * `packet` - response to serialize.
    /// * `with_len` - forwarded unchanged to the packet's `get_bytes`.
    ///
    /// # Errors
    ///
    /// Returns [`SerializationError::CounterOverflow`] when the own counter
    /// cannot be advanced any further, or the codec's error when serialization
    /// fails.
    ///
    /// # Panics
    ///
    /// Panics if keys must be derived and the handler has no session key.
    pub fn serialize_response(
        &mut self,
        mut packet: ResponsePacket,
        with_len: bool,
    ) -> Result<Vec<u8>, SerializationError> {
        if matches!(self.options, HandlerOptions::UseAuthentication) {
            packet.header.set_mac(true);
        }
        let info = self.serialization_info(true, None);
        Self::advance_counter(&mut self.self_counter)?;
        packet.get_bytes(info, with_len)
    }
}
// Unit tests for `PacketHandler`: key derivation, bucket-key selection, and
// serialization/parsing against fixed byte vectors.
#[cfg(test)]
mod test {
use crate::{
abstractions::{TYPE_CREATE, TYPE_ERROR},
codec::{
common::SlotRange, header::ResponseHeader, request::RequestBody, response::ResponseBody,
},
};
use super::*;
// A handler without protection parses a CONNECT request; only the peer
// counter advances.
#[test]
fn can_deserialize_connect_request() {
let req = &[
0, 0x77u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
];
let mut handler = PacketHandler::new(HandlerOptions::None, None, None);
let request = handler.parse_request(req).unwrap();
match request.get_body() {
RequestBody::CONNECT {
protocol_version, ..
} => {
assert_eq!(0x77, *protocol_version);
}
_ => panic!("Not a CONNECT"),
}
assert_eq!(0, handler.self_counter);
assert_eq!(1, handler.other_counter);
}
// Packet type 1 (presumably TYPE_CREATE - verify against `abstractions`) is
// not covered by the permission check, so no bucket key is returned even
// though one is registered for the bucket.
#[test]
fn create_does_never_give_bucket_key() {
let mut sut = PacketHandler::new(HandlerOptions::UseAuthentication, None, Some([1u8; 32]));
let id = BucketId::new(7);
let header = RequestHeader::new(1, Some(id.clone()));
sut.bucket_keys.insert(id, [1u8; 32]);
let res = sut
.get_info_with_bucket_key_if_needed(&header, false)
.unwrap();
match res {
SerializationInfo::UseAuthentication(_, bucket_key) => {
assert_eq!(None, bucket_key);
}
_ => panic!("Wrong type"),
}
}
// PUT on a bucket with `pub_write` set needs no bucket key. Despite the
// test's name, the assertion is on the bucket key, not the bucket id.
#[test]
fn put_with_public_write_does_not_give_bucket_id() {
let mut sut = PacketHandler::new(HandlerOptions::UseAuthentication, None, Some([1u8; 32]));
let mut id = BucketId::new(7);
let mut permissions = id.permissions();
permissions.pub_write = true;
id.set_permissions(permissions);
let header = RequestHeader::new(TYPE_PUT, Some(id.clone()));
sut.bucket_keys.insert(id, [1u8; 32]);
let res = sut
.get_info_with_bucket_key_if_needed(&header, false)
.unwrap();
match res {
SerializationInfo::UseAuthentication(_, bucket_key) => {
assert_eq!(None, bucket_key);
}
_ => panic!("Wrong type"),
}
}
// APPEND on a bucket whose permissions are left as `BucketId::new` created
// them (presumably all public permissions off - verify) yields the
// registered bucket key. Despite the test's name, the assertion is on the
// bucket key, not the bucket id.
#[test]
fn append_without_public_append_does_give_bucket_id() {
let mut sut = PacketHandler::new(HandlerOptions::UseAuthentication, None, Some([1u8; 32]));
let id = BucketId::new(7);
let header = RequestHeader::new(TYPE_APPEND, Some(id.clone()));
sut.bucket_keys.insert(id, [1u8; 32]);
let res = sut
.get_info_with_bucket_key_if_needed(&header, false)
.unwrap();
match res {
SerializationInfo::UseAuthentication(_, bucket_key) => {
assert_eq!(Some([1u8; 32]), bucket_key);
}
_ => panic!("Wrong type"),
}
}
// Pins the key derived for key number 0x01 from the own counter at 1.
#[test]
fn can_generate_shared_key_with_counter_1() {
let session_key = &[1u8; 32];
let mut sut = PacketHandler::new(HandlerOptions::None, None, Some(*session_key));
sut.self_counter = 1;
let key = sut.key(0x01, true);
assert_eq!(
key,
[
110, 223, 136, 196, 67, 61, 170, 231, 138, 234, 119, 93, 152, 169, 168, 18, 199,
27, 204, 11, 103, 191, 208, 199, 202, 145, 91, 96, 88, 228, 138, 41
]
);
}
// Pins the key derived for key number 0x00 from the peer counter at 7.
#[test]
fn can_generate_shared_key_with_counter_7() {
let session_key = &[1u8; 32];
let mut sut = PacketHandler::new(HandlerOptions::None, None, Some(*session_key));
sut.other_counter = 7;
let key = sut.key(0x00, false);
assert_eq!(
key,
[
86, 124, 77, 37, 137, 217, 171, 207, 121, 144, 71, 67, 148, 195, 193, 134, 219,
223, 221, 216, 210, 66, 219, 166, 197, 113, 208, 166, 61, 206, 218, 1
]
);
}
// Pins both session-derived keys of the encryption info at counter 0.
// NOTE(review): the bucket key slot is matched with `_`, so the bucket key
// itself is not asserted here despite the test's name.
#[test]
fn can_generate_encryption_info_with_bucket_key() {
let session_key = &[1u8; 32];
let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));
let id = BucketId::new(5);
sut.bucket_keys.insert(
id.clone(),
[
1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 0, 1, 2,
],
);
let header = &RequestHeader::new(2, Some(id));
let info = sut
.get_info_with_bucket_key_if_needed(header, true)
.unwrap();
match info {
SerializationInfo::UseEncryption(a, b, _) => {
assert_eq!(
a,
[
115, 13, 138, 42, 229, 171, 252, 201, 236, 154, 27, 170, 98, 19, 64, 200,
31, 27, 219, 82, 215, 38, 186, 156, 26, 126, 19, 36, 137, 132, 170, 129
]
);
assert_eq!(
b,
[
240, 166, 168, 233, 19, 0, 183, 68, 176, 91, 91, 69, 182, 111, 141, 82,
242, 142, 215, 82, 17, 88, 104, 210, 166, 49, 26, 152, 54, 245, 171, 80
]
);
}
_ => panic!("Wrong type"),
};
}
// Serializes a CREATE request in encryption mode and compares it with a
// fixed byte vector.
#[test]
fn can_create_request_encrypted() {
let session_key = &[1u8; 32];
let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));
let bucket_id = BucketId::from_bytes(
&[
29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3,
],
None,
)
.unwrap();
let header = RequestHeader::new(1, Some(bucket_id));
let create_req = RequestPacket::new(
header,
RequestBody::CREATE(SlotRange {
from: Some(5),
to: Some(7),
}),
None,
);
let bytes = sut.serialize_request(create_req, false).unwrap();
assert_eq!(
bytes,
vec![
53, 193, 133, 121, 169, 180, 199, 145, 54, 54, 159, 110, 145, 89, 36, 72, 33, 199,
139, 63, 198, 247, 187, 161, 49, 165, 174, 140, 57, 179, 243, 227, 172, 38, 86, 25,
183
]
);
}
// Serializes the same CREATE request in authentication mode and compares it
// with a fixed byte vector.
#[test]
fn can_create_request_authenticated() {
let session_key = &[1u8; 32];
let mut sut =
PacketHandler::new(HandlerOptions::UseAuthentication, None, Some(*session_key));
let bucket_id = BucketId::from_bytes(
&[
29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3,
],
None,
)
.unwrap();
let header = RequestHeader::new(1, Some(bucket_id));
let create_req = RequestPacket::new(
header,
RequestBody::CREATE(SlotRange {
from: Some(5),
to: Some(7),
}),
None,
);
let bytes = sut.serialize_request(create_req, false).unwrap();
assert_eq!(
bytes,
vec![
17, 29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3, 0, 5,
0, 7, 116, 139, 222, 175, 34, 89, 8, 53, 185, 215, 120, 148, 218, 125, 29, 216
]
);
}
// Parses the byte vector produced by `can_create_request_authenticated` and
// verifies its MAC.
#[test]
fn can_deserialize_request_authenticated() {
let session_key = &[1u8; 32];
let mut sut =
PacketHandler::new(HandlerOptions::UseAuthentication, None, Some(*session_key));
let data = [
17, 29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3, 0, 5, 0, 7,
116, 139, 222, 175, 34, 89, 8, 53, 185, 215, 120, 148, 218, 125, 29, 216,
];
let res = sut.parse_request(&data).unwrap();
// `parse_request` advanced the peer counter; step it back so `key` derives
// the key for the counter value the packet was parsed with.
sut.other_counter -= 1; assert!(res.verify_mac(&sut.key(0x00, false), None));
}
// Parses the byte vector produced by `can_create_request_encrypted` and
// checks bucket id, packet type and slot range.
#[test]
fn can_parse_encrypted_request() {
let session_key = &[1u8; 32];
let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));
let data = [
53, 193, 133, 121, 169, 180, 199, 145, 54, 54, 159, 110, 145, 89, 36, 72, 33, 199, 139,
63, 198, 247, 187, 161, 49, 165, 174, 140, 57, 179, 243, 227, 172, 38, 86, 25, 183,
];
let bucket_id = BucketId::from_bytes(
&[
29, 66, 250, 236, 114, 144, 177, 199, 69, 119, 210, 222, 85, 137, 7, 3,
],
None,
)
.unwrap();
let packet = sut.parse_request(&data).unwrap();
let header = packet.get_header();
assert_eq!(Some(bucket_id), header.bucket_id);
assert_eq!(1, header.packet_type());
match packet.body {
RequestBody::CREATE(r) => {
assert_eq!(Some(5), r.from);
assert_eq!(Some(7), r.to);
}
_ => panic!("Not a create"),
}
}
// Serializes an ERROR response in authentication mode and compares it with
// a fixed byte vector. The message text is part of the expected bytes.
#[test]
fn can_create_response_authenticated() {
let session_key = &[1u8; 32];
let mut sut =
PacketHandler::new(HandlerOptions::UseAuthentication, None, Some(*session_key));
let mut response = ResponsePacket {
header: ResponseHeader::new(TYPE_ERROR, 1),
body: ResponseBody::ERROR(7, String::from("An error occured")),
mac: None,
};
// `serialize_response` sets the MAC flag itself in authentication mode, so
// this call is redundant but harmless.
response.header.set_mac(true);
let bytes = sut.serialize_response(response, false).unwrap();
assert_eq!(
bytes,
vec![
31, 0, 1, 7, 65, 110, 32, 101, 114, 114, 111, 114, 32, 111, 99, 99, 117, 114, 101,
100, 145, 60, 204, 2, 3, 125, 144, 113, 250, 131, 128, 188, 121, 229, 153, 131
]
);
}
// Serializes the same ERROR response without protection: the output is the
// header and body only, even though a session key is configured.
#[test]
fn can_create_response_no_authentication() {
let session_key = &[1u8; 32];
let mut sut = PacketHandler::new(HandlerOptions::None, None, Some(*session_key));
let response = ResponsePacket {
header: ResponseHeader::new(TYPE_ERROR, 1),
body: ResponseBody::ERROR(7, String::from("An error occured")),
mac: None,
};
let bytes = sut.serialize_response(response, false).unwrap();
assert_eq!(
bytes,
vec![
15, 0, 1, 7, 65, 110, 32, 101, 114, 114, 111, 114, 32, 111, 99, 99, 117, 114, 101,
100
]
);
}
// Serializes a CREATE response in encryption mode and compares it with a
// fixed byte vector.
#[test]
fn can_create_encrypted_response() {
let session_key = &[1u8; 32];
let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));
let response = ResponsePacket {
header: ResponseHeader::new(TYPE_CREATE, 99),
body: ResponseBody::CREATE,
mac: None,
};
let bytes = sut.serialize_response(response, false).unwrap();
assert_eq!(
vec![
53, 220, 164, 47, 248, 118, 29, 99, 223, 219, 187, 40, 126, 155, 203, 47, 226, 93,
56
],
bytes
);
}
// Parses the byte vector produced by `can_create_encrypted_response` and
// checks the packet type and the header counter.
#[test]
fn can_parse_encrypted_response() {
let session_key = &[1u8; 32];
let mut sut = PacketHandler::new(HandlerOptions::UseEncryption, None, Some(*session_key));
let data = &[
53, 220, 164, 47, 248, 118, 29, 99, 223, 219, 187, 40, 126, 155, 203, 47, 226, 93, 56,
];
let response = sut.parse_response(data).unwrap();
assert_eq!(response.header.packet_type(), TYPE_CREATE);
assert_eq!(response.header.counter(), 99);
}
}