use core::ops::Deref;
use serde::{Deserialize, Serialize};
/// Size in bytes of the scratch buffer `AESCCM::encrypt` serializes a command into,
/// and therefore the upper bound on a plaintext payload produced by this file.
const MAX_PAYLOAD_SIZE: usize = 64;
/// Number of bytes that precede the payload in a packet:
/// 6 address bytes, 1 flags byte and 5 nonce bytes.
const HEADER_SIZE: usize = 12;
/// Length in bytes of the authentication tag that ends every packet.
const TAG_SIZE: usize = 8;
/// Failure modes reported by packet building, parsing, encryption and decryption in this file.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
/// Not produced by any code visible in this file.
/// NOTE(review): presumably reserved for authentication failures — confirm with callers.
Authentication,
/// The input is too short to hold a header, at least one payload byte and a tag,
/// or the decrypted payload could not be deserialized into a `Command`.
InvalidFormat,
/// Serializing a command into the fixed-size payload buffer failed
/// (every `postcard::to_slice` error is mapped to this variant).
BufferOverflow,
/// The 5-byte nonce counter is exhausted, or the per-block `u32` counter overflowed.
AESCounterOverflow,
/// The received nonce is not greater than the last accepted one.
Duplicate,
/// The tag recomputed on receipt does not match the tag carried by the packet.
Corrupted,
}
/// Command carried in a packet payload; every variant names the `Component` it targets.
///
/// This file only serializes and deserializes the value; how a receiver acts on it
/// is decided elsewhere.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Command {
/// Toggle the given component.
Toggle(Component),
/// Switch the given component on.
On(Component),
/// Switch the given component off.
Off(Component),
}
/// Target of a `Command`.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Component {
/// An LED selected by a `u8`.
/// NOTE(review): presumably an index or pin number — confirm with the receiver code.
Led(u8),
}
/// Block-cipher primitive that `AESCCM` is generic over.
pub trait Encrypt {
/// Runs one cipher operation over two 16-byte buffers with the given 16-byte key.
///
/// Call sites in this file treat the first buffer as the output and the second as
/// the input: for example `xor_payload` XORs the first buffer into the data right
/// after the call.
/// NOTE(review): both buffers are `&mut`, so confirm with the implementors which
/// of them is actually written.
fn encrypt(&mut self, key_stream_buf: &mut [u8; 16], a_block: &mut [u8; 16], key: [u8; 16]);
}
/// Fixed-capacity byte buffer holding one packet.
///
/// `AESCCM::encrypt` fills it, in order, with the destination address, the flags byte,
/// the nonce, the encrypted payload and the tag.
#[derive(Debug)]
pub struct AESCCMPacket {
/// Raw packet bytes; capacity is `HEADER_SIZE + 4 + MAX_PAYLOAD_SIZE + TAG_SIZE`.
/// NOTE(review): nothing visible in this file uses the extra 4 bytes — confirm their purpose.
pub inner: heapless::Vec<u8, { HEADER_SIZE + 4 + MAX_PAYLOAD_SIZE + TAG_SIZE }>,
}
impl AESCCMPacket {
    /// Creates a packet buffer that holds no bytes yet.
    pub fn new() -> Self {
        let inner = heapless::Vec::new();
        Self { inner }
    }

    /// Appends every byte yielded by `bytes` to the packet.
    ///
    /// NOTE(review): `heapless::Vec::extend` is documented to panic when the fixed
    /// capacity is exceeded — confirm for the heapless version in use.
    fn extend<I: IntoIterator<Item = u8>>(&mut self, bytes: I) {
        self.inner.extend(bytes);
    }

    /// Appends a whole slice; panics (via `unwrap`) if the buffer rejects it.
    fn extend_from_slice(&mut self, bytes: &[u8]) {
        let appended = self.inner.extend_from_slice(bytes);
        appended.unwrap();
    }

    /// Appends a single byte; panics (via `unwrap`) if the buffer rejects it.
    fn push(&mut self, byte: u8) {
        let pushed = self.inner.push(byte);
        pushed.unwrap();
    }
}
impl Default for AESCCMPacket {
fn default() -> Self {
Self::new()
}
}
/// Plaintext view of a packet: what `AESCCM::encrypt` consumes and `AESCCM::decrypt` returns.
#[derive(Debug)]
pub struct PacketData {
/// Address written at the start of the packet header.
pub dst: MacAddr,
/// Flags byte copied verbatim into the packet header.
pub flags: u8,
/// Command that is serialized into the (encrypted) payload.
pub cmd: Command,
}
impl PacketData {
pub fn new(dst: MacAddr, flags: u8, cmd: Command) -> Self {
Self { dst, flags, cmd }
}
}
/// Six-byte address used as the destination field of a packet.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct MacAddr {
// The six address bytes, in the order they are written into a packet.
inner: [u8; 6],
}
impl MacAddr {
    /// Builds an address from its six bytes, given in storage order (`f1` first).
    pub fn new(f1: u8, f2: u8, f3: u8, f4: u8, f5: u8, f6: u8) -> Self {
        let inner = [f1, f2, f3, f4, f5, f6];
        Self { inner }
    }
}
impl Default for MacAddr {
fn default() -> Self {
MacAddr {
inner: [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF],
}
}
}
impl From<[u8; 6]> for MacAddr {
fn from(value: [u8; 6]) -> Self {
Self { inner: value }
}
}
impl IntoIterator for MacAddr {
type Item = u8;
type IntoIter = core::array::IntoIter<u8, 6>;
fn into_iter(self) -> Self::IntoIter {
self.inner.into_iter()
}
}
impl Deref for MacAddr {
type Target = [u8; 6];
fn deref(&self) -> &Self::Target {
&self.inner
}
}
/// Packet counter whose low five bytes are used as the on-wire nonce.
pub struct Nonce {
// Current counter value; `inc` refuses to advance once it reaches the 5-byte maximum.
counter: u64,
}
impl Nonce {
    /// Advances the counter by one and returns its low five bytes, big-endian.
    ///
    /// # Errors
    /// Returns [`PacketError::AESCounterOverflow`], leaving the counter untouched, when
    /// the counter is already at or above the largest value that fits in five bytes.
    pub fn inc(&mut self) -> Result<[u8; 5], PacketError> {
        const MAX_5_BYTES: u64 = 0xFF_FF_FF_FF_FF;

        match self.counter {
            exhausted if exhausted >= MAX_5_BYTES => Err(PacketError::AESCounterOverflow),
            current => {
                self.counter = current + 1;
                // Drop the three most significant bytes; the guard above keeps them zero.
                let [_, _, _, low @ ..] = self.counter.to_be_bytes();
                Ok(low)
            }
        }
    }

    /// Overwrites the counter with `new_counter`; no range check is applied here.
    pub fn set(&mut self, new_counter: u64) {
        self.counter = new_counter;
    }
}
// Fields copied out of a received packet by `PacketView::try_from`.
struct PacketView {
// Bytes 0..6 of the packet: the address field.
pub mac: [u8; 6],
// Byte 6 of the packet.
pub flags: u8,
// Bytes 7..12 of the packet: the nonce as transmitted (big-endian, see `nonce()`).
pub raw_nonce: [u8; 5],
// Number of bytes between the 12-byte header and the trailing tag.
pub payload_len: usize,
// The last 8 bytes of the packet.
pub tag: [u8; 8],
}
impl PacketView {
    /// Index of the flags byte within a packet.
    const FLAGS_IDX: usize = 6;
    /// Index of the first nonce byte within a packet.
    const NONCE_OFFSET: usize = 7;
    /// Index of the first address byte within a packet.
    const MAC_OFFSET: usize = 0;
    /// Index of the first payload byte within a packet.
    const PAYLOAD_OFFSET: usize = HEADER_SIZE;

    /// Interprets the five nonce bytes as a big-endian integer.
    pub fn nonce(&self) -> u64 {
        let mut widened = [0_u8; 8];
        // Right-align the five bytes so the three most significant bytes stay zero.
        widened[3..].copy_from_slice(&self.raw_nonce);
        u64::from_be_bytes(widened)
    }
}
impl TryFrom<&[u8]> for PacketView {
type Error = PacketError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
if bytes.len() <= HEADER_SIZE + TAG_SIZE {
return Err(PacketError::InvalidFormat);
}
let mac: [u8; 6] = bytes[Self::MAC_OFFSET..Self::MAC_OFFSET + 6]
.try_into()
.unwrap();
let raw_nonce: [u8; 5] = bytes[Self::NONCE_OFFSET..Self::NONCE_OFFSET + 5]
.try_into()
.unwrap();
let payload_len = bytes.len() - TAG_SIZE - Self::PAYLOAD_OFFSET;
let tag: [u8; 8] = bytes[bytes.len() - TAG_SIZE..].try_into().unwrap();
let flags = bytes[Self::FLAGS_IDX];
Ok(Self {
mac,
flags,
raw_nonce,
payload_len,
tag,
})
}
}
/// The 12 header bytes of a packet (address, flags, nonce) as fed into tag generation.
pub struct AdHeader {
// Bytes 0..6: address, byte 6: flags, bytes 7..12: nonce (see `AdHeader::new`).
inner: [u8; 12],
}
impl AdHeader {
    /// Lays out the header as six address bytes, the flags byte, then five nonce bytes.
    pub fn new(dst_addr: &[u8; 6], flags: u8, nonce: &[u8; 5]) -> Self {
        let mut inner = [0_u8; 12];
        let (addr_part, rest) = inner.split_at_mut(6);
        let (flag_part, nonce_part) = rest.split_at_mut(1);
        addr_part.copy_from_slice(dst_addr);
        flag_part[0] = flags;
        nonce_part.copy_from_slice(nonce);
        Self { inner }
    }

    /// Header length in bytes, encoded as a big-endian `u16`.
    pub fn u16_be_len(&self) -> [u8; 2] {
        let len = self.inner.len() as u16;
        len.to_be_bytes()
    }
}
impl From<[u8; 16]> for AdHeader {
    /// Extracts the 12 header bytes from a padded 16-byte block.
    ///
    /// The block layout is the one `gen_raw_tag` builds: a 2-byte length prefix,
    /// the 12 header bytes, then 2 bytes of padding. The header therefore occupies
    /// bytes `2..14`; the prefix and the padding are ignored.
    fn from(value: [u8; 16]) -> Self {
        let mut inner = [0_u8; 12];
        // `2..14` is exactly 12 bytes. The previous `2..13` range was one byte short,
        // so converting it to `[u8; 12]` failed and the `unwrap` always panicked.
        inner.copy_from_slice(&value[2..14]);
        Self { inner }
    }
}
impl IntoIterator for AdHeader {
type Item = u8;
type IntoIter = core::array::IntoIter<u8, 12>;
fn into_iter(self) -> Self::IntoIter {
self.inner.into_iter()
}
}
impl Deref for AdHeader {
type Target = [u8; 12];
fn deref(&self) -> &Self::Target {
&self.inner
}
}
/// Packet encryptor/decryptor built on a caller-supplied block cipher `T`.
pub struct AESCCM<T: Encrypt> {
// Nonce of the last packet `decrypt` accepted; lower or equal nonces are rejected.
rx_nonce: Nonce,
// Counter advanced by `encrypt` for every outgoing packet.
tx_nonce: Nonce,
// 16-byte key handed to every `Encrypt::encrypt` call.
key: [u8; 16],
// The block-cipher implementation.
aes: T,
}
impl<T: Encrypt> AESCCM<T> {
pub fn new(aes: T, key: [u8; 16]) -> Self {
AESCCM {
rx_nonce: Nonce { counter: 0 },
tx_nonce: Nonce { counter: 0 },
key,
aes,
}
}
pub fn encrypt(&mut self, packet_data: PacketData) -> Result<AESCCMPacket, PacketError> {
let mut buf = [0_u8; MAX_PAYLOAD_SIZE];
let payload = postcard::to_slice(&packet_data.cmd, &mut buf)
.map_err(|_| PacketError::BufferOverflow)?;
let payload_len = payload.len();
let mut block_buf = [0_u8; 16];
let raw_nonce = self.tx_nonce.inc()?;
let mac_addr = packet_data.dst;
let b_block = Self::write_b_block(
&mut block_buf,
*packet_data.dst,
packet_data.flags,
raw_nonce,
payload_len,
);
let ad_header = AdHeader::new(&mac_addr, packet_data.flags, &raw_nonce);
let mut tag = self.gen_raw_tag(b_block, ad_header, payload);
let a_block = Self::write_a_block(&mut block_buf, *mac_addr, raw_nonce);
self.xor_tag(&mut tag, a_block);
self.xor_payload(payload, a_block)?;
let mut payload_vec = AESCCMPacket::new();
payload_vec.extend(mac_addr);
payload_vec.push(packet_data.flags);
payload_vec.extend(raw_nonce);
payload_vec.extend_from_slice(payload);
payload_vec.extend(tag);
Ok(payload_vec)
}
/// Parses, decrypts and authenticates a received packet.
///
/// The payload region of `bytes` is decrypted in place before the tag is checked,
/// so it stays modified even when an error is returned after that point.
/// The receive nonce is only advanced once the packet has been authenticated and
/// its payload deserialized.
///
/// # Errors
/// * [`PacketError::InvalidFormat`] if `bytes` is too short or the payload is not a
///   valid `Command`.
/// * [`PacketError::Duplicate`] if the nonce is not greater than the last accepted one.
/// * [`PacketError::Corrupted`] if the recomputed tag differs from the received tag.
/// * [`PacketError::AESCounterOverflow`] if the block counter overflows.
pub fn decrypt(&mut self, bytes: &mut [u8]) -> Result<PacketData, PacketError> {
    let view = PacketView::try_from(&*bytes)?;
    let incoming_nonce = view.nonce();
    if incoming_nonce <= self.rx_nonce.counter {
        return Err(PacketError::Duplicate);
    }

    let body_start = PacketView::PAYLOAD_OFFSET;
    let body = &mut bytes[body_start..body_start + view.payload_len];

    // Undo the key-stream masking of the tag and the payload.
    let mut scratch = [0_u8; 16];
    let mut received_tag = view.tag;
    let counter_block = Self::write_a_block(&mut scratch, view.mac, view.raw_nonce);
    self.xor_tag(&mut received_tag, counter_block);
    self.xor_payload(body, counter_block)?;

    // Recompute the tag over the now-plaintext payload.
    let first_block = Self::write_b_block(
        &mut scratch,
        view.mac,
        view.flags,
        view.raw_nonce,
        view.payload_len,
    );
    let header = AdHeader::new(&view.mac, view.flags, &view.raw_nonce);
    let expected_tag = self.gen_raw_tag(first_block, header, &*body);
    if !Self::is_tag_match_const_time(&received_tag, &expected_tag) {
        return Err(PacketError::Corrupted);
    }

    let cmd = postcard::from_bytes::<Command>(&*body).map_err(|_| PacketError::InvalidFormat)?;
    let packet_data = PacketData::new(view.mac.into(), view.flags, cmd);
    self.rx_nonce.set(incoming_nonce);
    Ok(packet_data)
}
/// Fills `buf` with the counter block for a packet and returns it.
///
/// Layout: byte 0 is the constant `4`, bytes 1..7 hold `mac`, bytes 7..12 hold
/// `raw_nonce`, and bytes 12..16 are zero (`xor_payload` writes its block counter there).
/// Any previous content of `buf` is overwritten.
pub fn write_a_block<'b>(
    buf: &'b mut [u8; 16],
    mac: [u8; 6],
    raw_nonce: [u8; 5],
) -> &'b mut [u8; 16] {
    let mut block = [0_u8; 16];
    block[0] = 4;
    block[1..7].copy_from_slice(&mac);
    block[7..12].copy_from_slice(&raw_nonce);
    *buf = block;
    buf
}
/// Fills `buf` with the first block fed into tag generation and returns it.
///
/// Layout: bytes 0..6 hold `mac`, byte 6 holds `flags`, bytes 7..12 hold `raw_nonce`
/// and bytes 12..16 hold `payload_len` as a big-endian `u32` (the `as` cast keeps only
/// the low 32 bits). All sixteen bytes are overwritten.
pub fn write_b_block<'b>(
    buf: &'b mut [u8; 16],
    mac: [u8; 6],
    flags: u8,
    raw_nonce: [u8; 5],
    payload_len: usize,
) -> &'b mut [u8; 16] {
    let length_field = (payload_len as u32).to_be_bytes();
    {
        let (head, tail) = buf.split_at_mut(12);
        let (addr_part, rest) = head.split_at_mut(6);
        addr_part.copy_from_slice(&mac);
        rest[0] = flags;
        rest[1..].copy_from_slice(&raw_nonce);
        tail.copy_from_slice(&length_field);
    }
    buf
}
/// Computes the unmasked 8-byte authentication tag as a CBC-MAC over the first
/// block, the header and the plaintext payload.
///
/// * `b_block` — on entry, the first block (see `write_b_block`); on return it holds
///   the final 16-byte MAC state, whose first eight bytes are the tag.
/// * `ad_header` — the 12 header bytes; they are prefixed with their length as a
///   big-endian `u16` and zero-padded to one 16-byte block.
/// * `payload` — plaintext payload; a trailing partial block is zero-padded.
///
/// Every block is XORed into the *output* of the previous cipher call before being
/// encrypted. The previous implementation lost that chaining unless the payload had
/// exactly one full 16-byte block. For shorter payloads the encrypted header state was
/// thrown away. For two or more full blocks the blocks were XORed together without
/// passing through the cipher. Tags for 16..=31-byte payloads are unchanged; tags for
/// any other length differ from the old ones, so both peers must use the same version.
///
/// As before, one final cipher call is always made, even when the payload length is a
/// multiple of 16 and the trailing partial block is empty.
///
/// Assumes `Encrypt::encrypt` writes the encryption of its second argument into its
/// first, which is what `xor_payload` relies on as well.
pub fn gen_raw_tag(
    &mut self,
    b_block: &mut [u8; 16],
    ad_header: AdHeader,
    payload: &[u8],
) -> [u8; 8] {
    let mut padded_header = [0_u8; 16];
    padded_header[0..2].copy_from_slice(&ad_header.u16_be_len());
    padded_header[2..14].copy_from_slice(&*ad_header);

    // `cbc_input` always holds the next cipher input; `b_block` always receives the
    // cipher output, i.e. the running MAC state after the first call pair below.
    let mut cbc_input = [0_u8; 16];
    self.aes.encrypt(&mut cbc_input, b_block, self.key);
    cbc_input
        .iter_mut()
        .zip(&padded_header)
        .for_each(|(s, h)| *s ^= h);
    self.aes.encrypt(b_block, &mut cbc_input, self.key);

    let (chunks, remainder) = payload.as_chunks::<16>();
    for chunk in chunks {
        cbc_input = *b_block;
        cbc_input.iter_mut().zip(chunk).for_each(|(s, p)| *s ^= p);
        self.aes.encrypt(b_block, &mut cbc_input, self.key);
    }

    // `zip` stops at the end of `remainder`, which is equivalent to zero padding.
    cbc_input = *b_block;
    cbc_input
        .iter_mut()
        .zip(remainder)
        .for_each(|(s, r)| *s ^= r);
    self.aes.encrypt(b_block, &mut cbc_input, self.key);

    b_block[..8].try_into().unwrap()
}
/// Masks (or unmasks) `tag` in place.
///
/// Runs the cipher once with `a_block` as given and XORs the first eight bytes of the
/// resulting buffer into `tag`. Applying it twice with the same block restores `tag`.
pub fn xor_tag(&mut self, tag: &mut [u8; 8], a_block: &mut [u8; 16]) {
    let mut mask = [0_u8; 16];
    self.aes.encrypt(&mut mask, a_block, self.key);
    // `zip` ends after the eight tag bytes; the rest of `mask` is unused.
    tag.iter_mut().zip(mask).for_each(|(t, m)| *t ^= m);
}
/// Encrypts or decrypts `payload` in place by XORing it with a counter-mode key stream.
///
/// The payload is processed as full 16-byte blocks followed by the (possibly empty)
/// trailing partial block. Before each of them a `u32` counter, starting at 1, is
/// written big-endian into bytes 12..16 of `a_block` and the cipher is run once.
/// The cipher is therefore also run once for an empty trailing block.
/// `a_block` is left holding the last counter value.
///
/// # Errors
/// Returns [`PacketError::AESCounterOverflow`] if the block counter would exceed `u32::MAX`.
pub fn xor_payload(
    &mut self,
    payload: &mut [u8],
    a_block: &mut [u8; 16],
) -> Result<(), PacketError> {
    let mut key_stream = [0_u8; 16];
    let mut block_counter = 0_u32;

    let (full_blocks, tail) = payload.as_chunks_mut::<16>();
    let segments = full_blocks
        .iter_mut()
        .map(|block| &mut block[..])
        .chain(core::iter::once(tail));

    for segment in segments {
        block_counter = block_counter
            .checked_add(1)
            .ok_or(PacketError::AESCounterOverflow)?;
        a_block[12..].copy_from_slice(&block_counter.to_be_bytes());
        self.aes.encrypt(&mut key_stream, a_block, self.key);
        segment
            .iter_mut()
            .zip(key_stream)
            .for_each(|(byte, k)| *byte ^= k);
    }
    Ok(())
}
/// Returns `true` when both tags are byte-for-byte equal.
///
/// All eight byte pairs are always combined (no early exit on the first difference);
/// the differences are OR-ed together and only the final result is tested.
pub fn is_tag_match_const_time(tag_a: &[u8; 8], tag_b: &[u8; 8]) -> bool {
    let difference = tag_a
        .iter()
        .zip(tag_b)
        .fold(0_u8, |acc, (a, b)| acc | (a ^ b));
    difference == 0
}
}