use crate::{
abstractions::{Serializable, SerializationError, SerializationInfo, KEY_SIZE, MAC_SIZE},
codec::common::{assert_len, dyn_int},
};
use chacha20::{
cipher::{KeyIvInit, StreamCipher},
ChaCha20,
};
use chacha20poly1305::{AeadInPlace, ChaCha20Poly1305, KeyInit};
use poly1305::Poly1305;
use sha2::{Digest, Sha256};
/// Behaviour shared by every PTP packet body.
pub trait PtpBody {
    /// Returns the packet-type identifier this body represents.
    ///
    /// NOTE(review): presumably expected to agree with the header's `packet_type()`; confirm
    /// against the concrete body implementations.
    fn packet_type(&self) -> u8;
}
/// Raw access to the single byte that packs a PTP header's packet type and flag bits.
///
/// The bit-level interpretation of this byte is provided by the `PtpHeader` trait.
pub trait PtpHeaderBase {
    /// Returns the raw type-and-flags byte.
    fn get_type_and_flags(&self) -> u8;
    /// Overwrites the raw type-and-flags byte with `byte`.
    fn set_type_and_flags(&mut self, byte: u8);
}
/// Bit-level accessors for the type-and-flags byte exposed by [`PtpHeaderBase`].
///
/// Layout of the byte:
/// * bits 0-3: packet type,
/// * bit 4: MAC-present flag,
/// * bits 5-7: three general-purpose flags (returned by [`PtpHeader::flags`] in that order).
///
/// Every setter touches only its own bits and preserves the rest of the byte.
pub trait PtpHeader: PtpHeaderBase {
    /// Returns `true` if the MAC-present bit (bit 4) is set.
    fn has_mac(&self) -> bool {
        self.get_type_and_flags() & (1 << 4) != 0
    }

    /// Sets (`true`) or clears (`false`) the MAC-present bit, leaving all other bits untouched.
    fn set_mac(&mut self, has_mac: bool) {
        let mask: u8 = 1 << 4;
        let current = self.get_type_and_flags();
        self.set_type_and_flags(if has_mac { current | mask } else { current & !mask });
    }

    /// Returns the packet type stored in the low nibble (bits 0-3).
    fn packet_type(&self) -> u8 {
        self.get_type_and_flags() & 0b0000_1111
    }

    /// Stores the low nibble of `packet_type` in bits 0-3.
    ///
    /// Bits 4-7 of `packet_type` are ignored. The MAC bit and the flags already present in the
    /// byte are preserved, so this can be called after `set_mac` / `set_flags` without undoing
    /// them.
    fn set_packet_type(&mut self, packet_type: u8) {
        let high_nibble = self.get_type_and_flags() & 0b1111_0000;
        self.set_type_and_flags(high_nibble | (packet_type & 0b0000_1111));
    }

    /// Returns the three flags stored in bits 5, 6 and 7, in that order.
    fn flags(&self) -> (bool, bool, bool) {
        let byte = self.get_type_and_flags();
        (
            byte & (1 << 5) != 0,
            byte & (1 << 6) != 0,
            byte & (1 << 7) != 0,
        )
    }

    /// Writes the three flags into bits 5, 6 and 7 (in tuple order); bits 0-4 are preserved.
    fn set_flags(&mut self, flags: (bool, bool, bool)) {
        let updates = [(flags.0, 1u8 << 5), (flags.1, 1 << 6), (flags.2, 1 << 7)];
        let byte = updates
            .iter()
            .fold(self.get_type_and_flags(), |acc, &(on, bit)| {
                if on {
                    acc | bit
                } else {
                    acc & !bit
                }
            });
        self.set_type_and_flags(byte);
    }
}
pub trait PtpPacket<HT, BT>
where
HT: Serializable + PtpHeader,
BT: Serializable + PtpBody,
{
fn get_header(&self) -> &HT;
fn get_body(&self) -> &BT;
fn get_mac(&self) -> Option<&[u8; MAC_SIZE]>;
fn new(header: HT, body: BT, mac: Option<[u8; MAC_SIZE]>) -> Self;
fn from_bytes(data: &[u8], info: SerializationInfo) -> Result<Self, SerializationError>
where
Self: Sized,
{
let header = HT::from_bytes(data, Some(info))?;
let body_size = data.len() - header.size() - if header.has_mac() { MAC_SIZE } else { 0 };
let mut body_bytes = data[header.size()..(body_size + header.size())].to_vec();
assert_eq!(body_size, body_bytes.len());
let mut expected_mac: Option<Vec<u8>> = None;
match info {
SerializationInfo::UseEncryption(_, key1, bucket_key) => {
let cipher = ChaCha20Poly1305::new(&key1.into());
let mut auth_data = Vec::new();
auth_data.append(&mut header.get_bytes());
if let Some(key) = bucket_key {
auth_data.extend_from_slice(&key);
}
match cipher.decrypt_in_place(&[0u8; 12].into(), &auth_data, &mut body_bytes) {
Ok(_) => (),
Err(_) => return Err(SerializationError::DecryptionFailed),
}
}
SerializationInfo::UseAuthentication(key, bucket_key) => {
let mut auth_data = Vec::new();
auth_data.append(&mut header.get_bytes());
auth_data.extend_from_slice(&Sha256::digest(&body_bytes));
if let Some(bucket_key) = bucket_key {
auth_data.extend_from_slice(&bucket_key);
}
let poly = Poly1305::new(&key.into());
expected_mac = Some(poly.compute_unpadded(&auth_data).to_vec());
}
_ => (),
}
let body = BT::from_bytes(
&body_bytes,
Some(SerializationInfo::PacketType(header.packet_type())),
)?;
let mac = if header.has_mac() {
if expected_mac.is_none() {
return Err(SerializationError::MissingInfo(String::from(
"Missing UseAuthentication info to verify the MAC",
)));
}
assert_len(data, header.size() + body_size + MAC_SIZE)?;
let mut mac = [0u8; MAC_SIZE];
let slice = &data[(header.size() + body_size)..];
assert_eq!(MAC_SIZE, slice.len());
mac.copy_from_slice(slice);
if expected_mac.unwrap() != mac {
return Err(SerializationError::AuthenticationFailed);
}
Some(mac)
} else {
None
};
Ok(Self::new(header, body, mac))
}
fn get_bytes(
&self,
info: SerializationInfo,
with_len: bool,
) -> Result<Vec<u8>, SerializationError> {
let mut buff = Vec::new();
let mut header_bytes = self.get_header().get_bytes();
let mut body_bytes = self.get_body().get_bytes();
let mut mac: Option<Vec<u8>> = None;
match info {
SerializationInfo::UseEncryption(key0, key1, bucket_key) => {
let nonce = [0u8; 12];
let mut auth_data = header_bytes.to_vec();
if let Some(bucket_key) = bucket_key {
auth_data.extend_from_slice(&bucket_key);
}
let mut cipher = ChaCha20::new(&key0.into(), &nonce.into());
cipher.apply_keystream(&mut header_bytes);
let cipher = ChaCha20Poly1305::new(&key1.into());
cipher
.encrypt_in_place(&nonce.into(), &auth_data, &mut body_bytes)
.expect("Encryption failed");
}
SerializationInfo::UseAuthentication(key, bucket_key) => {
let mut auth_data = header_bytes.to_vec();
auth_data.extend_from_slice(&Sha256::digest(&body_bytes));
if let Some(bucket_key) = bucket_key {
auth_data.extend_from_slice(&bucket_key);
}
let poly = Poly1305::new(&key.into());
mac = Some(poly.compute_unpadded(&auth_data).to_vec());
}
SerializationInfo::None => (),
other => {
return Err(SerializationError::MissingInfo(format!(
"Needs UseEncryption, UseAuthentication or None method. But {:?} is provided",
other
)))
}
};
if with_len {
let len =
header_bytes.len() + body_bytes.len() + if mac.is_some() { MAC_SIZE } else { 0 };
buff.append(&mut dyn_int::encode(len as u128));
}
buff.append(&mut header_bytes);
buff.append(&mut body_bytes);
if let Some(mut mac) = mac {
buff.append(&mut mac);
}
Ok(buff)
}
fn verify_mac(&self, key: &[u8; KEY_SIZE], bucket_key: Option<[u8; KEY_SIZE]>) -> bool {
if !self.get_header().has_mac() {
eprintln!("[WARN]: Verifying MAC for packet with no MAC present!");
return false;
}
if let Some(self_mac) = self.get_mac() {
let mut auth_data = self.get_header().get_bytes().to_vec();
auth_data.extend_from_slice(&Sha256::digest(self.get_body().get_bytes()));
if let Some(bucket_key) = bucket_key {
auth_data.extend_from_slice(&bucket_key);
}
let poly = Poly1305::new(key.into());
let mac = poly.compute_unpadded(&auth_data).to_vec();
mac == self_mac
} else {
false
}
}
}