use macro_bits::{bit, check_bit};
use scroll::{
ctx::{MeasureWith, TryFromCtx, TryIntoCtx},
Pread, Pwrite,
};
/// Header that precedes the payload inside a [`CryptoWrapper`].
///
/// It carries a 48-bit packet number and a 2-bit key ID. The exact 8-byte wire
/// layout is defined by the `TryFromCtx` / `TryIntoCtx` implementations.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CryptoHeader {
    /// Packet number, stored as its 6 low-order bytes in little-endian order.
    packet_number: [u8; 6],
    /// Key ID; only values up to [`CryptoHeader::MAX_KEY_ID`] fit on the wire.
    key_id: u8,
}
impl CryptoHeader {
    /// Largest packet number representable in the 48-bit field.
    pub const MAX_PN: u64 = 2u64.pow(48) - 1;
    /// Largest key ID representable in the 2-bit field.
    pub const MAX_KEY_ID: u8 = 2u8.pow(2) - 1;

    /// Creates a header from a packet number and a key ID.
    ///
    /// Returns `None` if `packet_number` exceeds [`Self::MAX_PN`] or `key_id`
    /// exceeds [`Self::MAX_KEY_ID`], because such values cannot be encoded
    /// without truncation.
    pub fn new(packet_number: u64, key_id: u8) -> Option<Self> {
        Self::pn_and_key_id_valid(packet_number, key_id).then(|| {
            // Keep the 6 low-order little-endian bytes; validation guarantees
            // the two dropped high bytes are zero.
            let [pn_bytes @ .., _, _] = packet_number.to_le_bytes();
            Self {
                packet_number: pn_bytes,
                key_id,
            }
        })
    }
    /// Both values must fit their fields for the header to be encodable.
    const fn pn_and_key_id_valid(packet_number: u64, key_id: u8) -> bool {
        packet_number <= Self::MAX_PN && key_id <= Self::MAX_KEY_ID
    }
    /// Returns the 48-bit packet number widened to a `u64`.
    pub fn packet_number(&self) -> u64 {
        let mut extended_packet_number = [0u8; 8];
        extended_packet_number[..6].copy_from_slice(self.packet_number.as_slice());
        u64::from_le_bytes(extended_packet_number)
    }
    /// Returns the key ID.
    pub fn key_id(&self) -> u8 {
        self.key_id
    }
}
impl<'a> TryFromCtx<'a> for CryptoHeader {
    type Error = scroll::Error;
    /// Parses the 8-byte header.
    ///
    /// Layout: bytes 0-1 are the low packet-number bytes, byte 2 is ignored
    /// (the writer emits 0), byte 3 holds the Ext IV flag in bit 5 and the key
    /// ID in bits 6-7, and bytes 4-7 are the remaining packet-number bytes.
    ///
    /// # Errors
    /// Fails if fewer than 8 bytes are available, or with `BadInput` if the
    /// Ext IV bit is not set.
    fn try_from_ctx(from: &'a [u8], _ctx: ()) -> Result<(Self, usize), Self::Error> {
        let mut offset = 0;
        let header = from.gread::<[u8; 8]>(&mut offset)?;
        let mut packet_number = [0u8; 6];
        packet_number[..2].copy_from_slice(&header[..2]);
        packet_number[2..].copy_from_slice(&header[4..]);
        if !check_bit!(header[3], bit!(5)) {
            return Err(scroll::Error::BadInput {
                size: offset,
                msg: "Ext IV bit not set.",
            });
        }
        // The key ID occupies the top two bits (the writer encodes it as
        // `key_id << 6`), so shift it down into the range 0..=3.
        let key_id = header[3] >> 6;
        Ok((
            Self {
                packet_number,
                key_id,
            },
            offset,
        ))
    }
}
impl TryIntoCtx<()> for CryptoHeader {
    type Error = scroll::Error;
    /// Serializes the header as `PN0 PN1 0x00 <key-ID octet> PN2 PN3 PN4 PN5`
    /// and returns the number of bytes written.
    fn try_into_ctx(self, buf: &mut [u8], _ctx: ()) -> Result<usize, Self::Error> {
        let (pn_low, pn_high) = self.packet_number.split_at(2);
        // Ext IV (bit 5) is always set; the key ID sits in the two top bits.
        let key_id_octet = bit!(5) | (self.key_id << 6);
        let mut cursor = 0;
        buf.gwrite(pn_low, &mut cursor)?;
        buf.gwrite(0u8, &mut cursor)?;
        buf.gwrite(key_id_octet, &mut cursor)?;
        buf.gwrite(pn_high, &mut cursor)?;
        Ok(cursor)
    }
}
impl MeasureWith<()> for CryptoHeader {
    /// The header always serializes to a fixed 8 bytes: six packet-number
    /// bytes, one zero byte, and the key-ID octet.
    fn measure_with(&self, _ctx: &()) -> usize {
        const WIRE_LEN: usize = 8;
        WIRE_LEN
    }
}
/// Whether a MIC trails the payload and, if so, which length it has.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum MicState {
    /// No MIC bytes follow the payload.
    NotPresent,
    /// An 8-byte MIC follows the payload.
    Short,
    /// A 16-byte MIC follows the payload.
    Long,
}
impl MicState {
    /// Number of MIC bytes implied by this state: 0, 8 or 16.
    pub const fn mic_length(&self) -> usize {
        match *self {
            MicState::Long => 16,
            MicState::Short => 8,
            MicState::NotPresent => 0,
        }
    }
}
/// A payload framed by a [`CryptoHeader`] in front and an optional MIC behind.
///
/// Serialization writes the header, the payload, and then `mic_length()` zero
/// bytes as the MIC. Parsing skips the MIC bytes and does not store them.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CryptoWrapper<P> {
    /// Header written before the payload.
    pub crypto_header: CryptoHeader,
    /// The wrapped payload.
    pub payload: P,
    /// Whether, and how many, MIC bytes trail the payload.
    pub mic_state: MicState,
}
impl<'a, P: TryFromCtx<'a, PayloadCtx, Error = scroll::Error>, PayloadCtx: Copy>
TryFromCtx<'a, (MicState, PayloadCtx)> for CryptoWrapper<P>
{
type Error = scroll::Error;
fn try_from_ctx(
from: &'a [u8],
(mic_state, payload_ctx): (MicState, PayloadCtx),
) -> Result<(Self, usize), Self::Error> {
let mut offset = 0;
let crypto_header = from.gread(&mut offset)?;
let mic_length = mic_state.mic_length();
let payload =
from[offset..][..from.len() - offset - mic_length].pread_with(0, payload_ctx)?;
Ok((
Self {
crypto_header,
payload,
mic_state,
},
from.len(),
))
}
}
impl<P: TryIntoCtx<(), Error = scroll::Error>> TryIntoCtx<()> for CryptoWrapper<P> {
type Error = scroll::Error;
fn try_into_ctx(self, buf: &mut [u8], _ctx: ()) -> Result<usize, Self::Error> {
let mut offset = 0;
let mic_length = self.mic_state.mic_length();
buf.gwrite(self.crypto_header, &mut offset)?;
buf.gwrite(self.payload, &mut offset)?;
buf[offset..][..mic_length].fill(0);
offset += mic_length;
Ok(offset)
}
}
impl<P: MeasureWith<()>> MeasureWith<()> for CryptoWrapper<P> {
    /// Serialized size: header bytes, then payload bytes, then MIC bytes.
    fn measure_with(&self, ctx: &()) -> usize {
        let header_len = self.crypto_header.measure_with(ctx);
        let payload_len = self.payload.measure_with(ctx);
        let mic_len = self.mic_state.mic_length();
        header_len + payload_len + mic_len
    }
}