1use core::ops::Deref;
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9
/// Offset of the 6-byte destination address within a frame.
const MAC_OFFSET: usize = 0;
/// Index of the single flags byte, directly after the address.
const FLAGS_IDX: usize = MAC_OFFSET + 6;
/// Offset of the 5-byte big-endian nonce, directly after the flags byte.
const NONCE_OFFSET: usize = FLAGS_IDX + 1;
/// Length of the clear-text header: address (6) + flags (1) + nonce (5).
const HEADER_SIZE: usize = NONCE_OFFSET + 5;
/// The payload begins immediately after the header.
const PAYLOAD_OFFSET: usize = HEADER_SIZE;
/// Nominal maximum serialized payload size; used to size the frame buffer.
const MAX_PAYLOAD_SIZE: usize = 64;
/// Length of the authentication tag appended after the payload.
const TAG_SIZE: usize = 16;
18
/// Errors produced by frame construction/parsing and AES-CCM processing in this module.
///
/// NOTE(review): with index-based formats such as postcard, the serialized form of a
/// variant likely depends on its position — append new variants at the end rather than
/// reordering.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// NOTE(review): not returned by any code visible in this file (tag mismatches
    /// currently yield `Corrupted`); presumably reserved for authentication failures —
    /// confirm.
    Authentication,
    /// The byte slice is too short to hold header + tag (see `Parts` / `PacketView`), or
    /// the authenticated payload failed to deserialize in `AESCCM::decrypt`.
    InvalidFormat,
    /// A write into a `Frame` would exceed its fixed-size buffer.
    BufferOverflow,
    /// The 5-byte frame nonce (`Nonce::inc`) or the 32-bit CTR block counter
    /// (`AESCCM::xor_payload`) would wrap.
    AESCounterOverflow,
    /// A received frame's nonce is not greater than the last accepted one.
    Duplicate,
    /// The authentication tag of a received frame does not match.
    Corrupted,
    /// NOTE(review): not returned by any code visible in this file; the name suggests it is
    /// meant for the flag bits that `PacketData::new` masks off — confirm.
    ReservedBytesOverride,
    /// `postcard` failed to serialize the payload in `AESCCM::encrypt`.
    Postcard(postcard::Error),
}
39
/// Block-cipher backend used by `AESCCM` for both the CTR keystream and the CBC-MAC.
pub trait Encrypt {
    /// Encrypts the 16-byte `block` under `key` and writes the result into
    /// `key_stream_buf`.
    ///
    /// Every caller in this file reads the result from the FIRST argument and treats the
    /// second as the input.
    ///
    /// NOTE(review): presumably a single-block AES-128 (ECB) encryption — confirm against
    /// the implementors.
    ///
    /// NOTE(review): `AESCCM::xor_payload` reuses `block` across calls and rewrites only its
    /// last four bytes. Implementations therefore appear to be required to leave `block`
    /// unmodified, even though it is passed as `&mut` — confirm.
    fn encrypt(&mut self, key_stream_buf: &mut [u8; 16], block: &mut [u8; 16], key: [u8; 16]);
}
61
62#[derive(Debug)]
64pub struct Frame {
65 inner: [u8; HEADER_SIZE + 4 + MAX_PAYLOAD_SIZE + TAG_SIZE],
66 len: usize,
67}
68impl Default for Frame {
69 fn default() -> Self {
70 Self {
71 inner: [0_u8; HEADER_SIZE + 4 + MAX_PAYLOAD_SIZE + TAG_SIZE],
72 len: 0,
73 }
74 }
75}
76impl Frame {
77 fn new(mac: [u8; 6], flags: u8, raw_nonce: [u8; 5]) -> Result<Self, Error> {
78 let mut frame = Self::default();
79 frame.extend_from_slice(&mac)?;
80 frame.push(flags)?;
81 frame.extend_from_slice(&raw_nonce)?;
82 Ok(frame)
83 }
84
85 fn payload_mut_slice(&mut self) -> &mut [u8] {
86 &mut self.inner[HEADER_SIZE..]
87 }
88 fn finalize(&mut self, payload_len: usize, tag: [u8; 16]) -> Result<(), Error> {
89 self.len += payload_len;
90 self.extend_from_slice(&tag)
91 }
92 pub fn bytes(&self) -> &[u8] {
93 &self.inner[..self.len]
94 }
95
96 pub fn bytes_mut(&mut self) -> &mut [u8] {
97 &mut self.inner[..self.len]
98 }
99
100 fn push(&mut self, byte: u8) -> Result<(), Error> {
101 if self.len >= self.inner.len() {
102 return Err(Error::BufferOverflow);
103 }
104 self.inner[self.len] = byte;
105 self.len += 1;
106 Ok(())
107 }
108
109 fn extend_from_slice(&mut self, iter: &[u8]) -> Result<(), Error> {
110 if iter.len() + self.len > self.inner.len() {
111 return Err(Error::BufferOverflow);
112 }
113 self.inner[self.len..self.len + iter.len()].copy_from_slice(iter);
114 self.len += iter.len();
115 Ok(())
116 }
117}
118
119#[derive(Debug, PartialEq, Eq, Copy, Clone)]
121pub struct PacketData<T>
122where
123 T: Serialize + DeserializeOwned,
124{
125 pub dst: MacAddr,
127 pub flags: u8,
129 pub payload: T,
131}
132
133impl<T> PacketData<T>
134where
135 T: Serialize + DeserializeOwned,
136{
137 pub fn new(dst: MacAddr, mut flags: u8, payload: T) -> Self {
140 flags &= 0b_00_111111;
141 Self {
142 dst,
143 flags,
144 payload,
145 }
146 }
147}
148
/// A 6-byte hardware address, used as a frame's destination field.
///
/// Dereferences to the raw `[u8; 6]` and iterates over its bytes by value. The default
/// value is the all-ones address `FF:FF:FF:FF:FF:FF` (presumably broadcast).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct MacAddr {
    inner: [u8; 6],
}

impl MacAddr {
    /// Builds an address from its six octets, in the order they are written into a frame.
    pub fn new(f1: u8, f2: u8, f3: u8, f4: u8, f5: u8, f6: u8) -> Self {
        Self::from([f1, f2, f3, f4, f5, f6])
    }
}

impl Default for MacAddr {
    /// The all-ones address.
    fn default() -> Self {
        Self::from([0xFF; 6])
    }
}

impl From<[u8; 6]> for MacAddr {
    fn from(inner: [u8; 6]) -> Self {
        MacAddr { inner }
    }
}

impl IntoIterator for MacAddr {
    type Item = u8;
    type IntoIter = core::array::IntoIter<u8, 6>;

    fn into_iter(self) -> Self::IntoIter {
        IntoIterator::into_iter(self.inner)
    }
}

impl Deref for MacAddr {
    type Target = [u8; 6];

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
194
195struct Nonce {
197 counter: u64,
198}
199
200impl Nonce {
201 fn inc(&mut self) -> Result<[u8; 5], Error> {
208 const MAX_5_BYTES: u64 = 0xFF_FF_FF_FF_FF;
209 if self.counter >= MAX_5_BYTES {
210 return Err(Error::AESCounterOverflow);
211 }
212 self.counter += 1;
213
214 let bytes = self.counter.to_be_bytes();
215 let mut result = [0_u8; 5];
216 result.copy_from_slice(&bytes[3..8]);
217
218 Ok(result)
219 }
220
221 fn set(&mut self, new_counter: u64) {
223 self.counter = new_counter;
224 }
225}
226
227pub struct PacketView<'a> {
259 bytes: &'a [u8],
260}
261
262impl<'a> PacketView<'a> {
263 pub fn new(bytes: &'a [u8]) -> Result<Self, Error> {
271 Self::try_from(bytes)
272 }
273
274 pub fn mac(&self) -> [u8; 6] {
276 self.bytes[MAC_OFFSET..MAC_OFFSET + 6].try_into().unwrap()
277 }
278
279 pub fn flags(&self) -> u8 {
281 self.bytes[FLAGS_IDX]
282 }
283
284 pub fn raw_nonce(&self) -> [u8; 5] {
286 self.bytes[NONCE_OFFSET..NONCE_OFFSET + 5]
287 .try_into()
288 .unwrap()
289 }
290
291 pub fn nonce(&self) -> u64 {
293 let raw_nonce = self.raw_nonce();
294 u64::from_be_bytes([
295 0,
296 0,
297 0,
298 raw_nonce[0],
299 raw_nonce[1],
300 raw_nonce[2],
301 raw_nonce[3],
302 raw_nonce[4],
303 ])
304 }
305}
306
307impl<'a> TryFrom<&'a [u8]> for PacketView<'a> {
308 type Error = Error;
309
310 fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
312 if bytes.len() <= HEADER_SIZE + TAG_SIZE {
313 return Err(Error::InvalidFormat);
314 }
315 Ok(Self { bytes })
316 }
317}
318
319struct Parts {
321 pub mac: [u8; 6],
322 pub flags: u8,
323 pub raw_nonce: [u8; 5],
324 pub payload_len: usize,
325 pub tag: [u8; TAG_SIZE],
326}
327
328impl Parts {
329 fn nonce(&self) -> u64 {
331 u64::from_be_bytes([
332 0,
333 0,
334 0,
335 self.raw_nonce[0],
336 self.raw_nonce[1],
337 self.raw_nonce[2],
338 self.raw_nonce[3],
339 self.raw_nonce[4],
340 ])
341 }
342}
343
344impl TryFrom<&[u8]> for Parts {
345 type Error = Error;
346
347 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
349 if bytes.len() <= HEADER_SIZE + TAG_SIZE {
350 return Err(Error::InvalidFormat);
351 }
352 let mac: [u8; 6] = bytes[MAC_OFFSET..MAC_OFFSET + 6].try_into().unwrap();
353 let raw_nonce: [u8; 5] = bytes[NONCE_OFFSET..NONCE_OFFSET + 5].try_into().unwrap();
354 let payload_len = bytes.len() - TAG_SIZE - PAYLOAD_OFFSET;
355
356 let tag: [u8; TAG_SIZE] = bytes[bytes.len() - TAG_SIZE..].try_into().unwrap();
357
358 let flags = bytes[FLAGS_IDX];
359 Ok(Self {
360 mac,
361 flags,
362 raw_nonce,
363 payload_len,
364 tag,
365 })
366 }
367}
368
/// The 12 bytes of associated data authenticated (but not encrypted) by the cipher:
/// `dst mac (6) | flags (1) | raw nonce (5)`, the same bytes as the frame header.
pub struct AdHeader {
    inner: [u8; 12],
}

impl AdHeader {
    /// Packs address, flags and nonce into the header layout.
    pub fn new(dst_addr: &[u8; 6], flags: u8, nonce: &[u8; 5]) -> Self {
        let [a0, a1, a2, a3, a4, a5] = *dst_addr;
        let [n0, n1, n2, n3, n4] = *nonce;
        Self {
            inner: [a0, a1, a2, a3, a4, a5, flags, n0, n1, n2, n3, n4],
        }
    }

    /// Header length as a big-endian `u16` (always `[0, 12]`), used as the length prefix
    /// of the padded associated-data block.
    fn u16_be_len(&self) -> [u8; 2] {
        let len = self.inner.len() as u16;
        len.to_be_bytes()
    }
}

impl From<[u8; 16]> for AdHeader {
    /// Extracts the header from a padded block laid out as
    /// `len (2) | header (12) | padding (2)`; length and padding bytes are ignored.
    fn from(padded: [u8; 16]) -> Self {
        let mut inner = [0_u8; 12];
        inner.copy_from_slice(&padded[2..14]);
        Self { inner }
    }
}

impl IntoIterator for AdHeader {
    type Item = u8;
    type IntoIter = core::array::IntoIter<u8, 12>;

    fn into_iter(self) -> Self::IntoIter {
        IntoIterator::into_iter(self.inner)
    }
}

impl Deref for AdHeader {
    type Target = [u8; 12];

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
412
413pub struct AESCCM<E>
415where
416 E: Encrypt,
417{
418 rx_nonce: Nonce,
419 tx_nonce: Nonce,
420 key: [u8; 16],
421 aes: E,
422}
423impl<E> AESCCM<E>
424where
425 E: Encrypt,
426{
427 pub fn new(aes: E, key: [u8; 16]) -> Self {
429 AESCCM {
430 rx_nonce: Nonce { counter: 0 },
431 tx_nonce: Nonce { counter: 0 },
432 key,
433 aes,
434 }
435 }
436
437 pub fn encrypt<T>(&mut self, packet_data: &PacketData<T>) -> Result<Frame, Error>
456 where
457 T: Serialize + DeserializeOwned,
458 {
459 let mac = *packet_data.dst;
460 let raw_nonce = self.tx_nonce.inc()?;
461 let mut frame = Frame::new(mac, packet_data.flags, raw_nonce)?;
462
463 let mut payload = postcard::to_slice(&packet_data.payload, frame.payload_mut_slice())
464 .map_err(|e| Error::Postcard(e))?;
465
466 let payload_len = payload.len();
467
468 let mut block_buf = [0_u8; 16];
469
470 let b_block = Self::write_b_block(&mut block_buf, mac, raw_nonce, payload_len);
471
472 let ad_header = AdHeader::new(&mac, packet_data.flags, &raw_nonce);
473
474 let mut tag = self.gen_raw_tag(b_block, ad_header, payload);
475
476 let a_block = Self::write_a_block(&mut block_buf, mac, raw_nonce);
477
478 self.xor_tag(&mut tag, a_block);
479
480 self.xor_payload(&mut payload, a_block)?;
481
482 frame.finalize(payload_len, tag)?;
483
484 Ok(frame)
485 }
486
487 pub fn decrypt<T>(&mut self, bytes: &mut [u8]) -> Result<PacketData<T>, Error>
496 where
497 T: Serialize + DeserializeOwned,
498 {
499 let parts = Parts::try_from(&*bytes)?;
500 if parts.nonce() <= self.rx_nonce.counter {
501 return Err(Error::Duplicate);
502 }
503
504 let mut payload = &mut bytes[PAYLOAD_OFFSET..PAYLOAD_OFFSET + parts.payload_len];
505
506 let mut block_buf = [0_u8; 16];
507 let a_block = Self::write_a_block(&mut block_buf, parts.mac, parts.raw_nonce);
508 let mut tag = parts.tag;
509
510 self.xor_tag(&mut tag, a_block);
511
512 self.xor_payload(&mut payload, a_block)?;
513
514 let b_block = Self::write_b_block(
515 &mut block_buf,
516 parts.mac,
517 parts.raw_nonce,
518 parts.payload_len,
519 );
520 let ad_header = AdHeader::new(&parts.mac, parts.flags, &parts.raw_nonce);
521
522 let tag_cmp = self.gen_raw_tag(b_block, ad_header, payload);
523 if !Self::is_tag_match_const_time(&tag, &tag_cmp) {
524 return Err(Error::Corrupted);
525 }
526
527 let serialized_payload =
528 postcard::from_bytes::<T>(&payload).map_err(|_| Error::InvalidFormat)?;
529 let packet_data = PacketData::new(parts.mac.into(), parts.flags, serialized_payload);
530 self.rx_nonce.set(parts.nonce());
531 Ok(packet_data)
532 }
533
534 fn write_a_block<'b>(
536 buf: &'b mut [u8; 16],
537 mac: [u8; 6],
538 raw_nonce: [u8; 5],
539 ) -> &'b mut [u8; 16] {
540 const A_NONCE_OFFSET: usize = 7;
541 const A_MAC_OFFSET: usize = 1;
542 buf.fill(0);
543 buf[0] = 4;
544 buf[A_MAC_OFFSET..A_MAC_OFFSET + 6].copy_from_slice(&mac);
545 buf[A_NONCE_OFFSET..A_NONCE_OFFSET + 5].copy_from_slice(&raw_nonce);
546 buf
547 }
548
549 fn write_b_block<'b>(
551 buf: &'b mut [u8; 16],
552 mac: [u8; 6],
553 raw_nonce: [u8; 5],
554 payload_len: usize,
555 ) -> &'b mut [u8; 16] {
556 const B0_FLAGS: u8 = 0b0_1_111_011;
557 buf[..6].copy_from_slice(&mac);
558 buf[6] = B0_FLAGS;
559 buf[7..=11].copy_from_slice(&raw_nonce);
560 buf[12..].copy_from_slice(&(payload_len as u32).to_be_bytes());
561 buf
562 }
563
564 fn gen_raw_tag(
566 &mut self,
567 b_block: &mut [u8; 16],
568 ad_header: AdHeader,
569 payload: &[u8],
570 ) -> [u8; TAG_SIZE] {
571 let mut padded_header = [0_u8; 16];
572 padded_header[0..2].copy_from_slice(&ad_header.u16_be_len());
573 padded_header[2..14].copy_from_slice(&*ad_header);
574
575 let mut key_stream_buf = [0_u8; 16];
576 self.aes.encrypt(&mut key_stream_buf, b_block, self.key);
577 key_stream_buf
578 .iter_mut()
579 .zip(&padded_header)
580 .for_each(|(b, h)| *b ^= h);
581 self.aes.encrypt(b_block, &mut key_stream_buf, self.key);
582 let (chunks, remainder) = payload.as_chunks::<16>();
583 for chunk in chunks {
584 b_block.iter_mut().zip(chunk).for_each(|(b, p)| *b ^= p);
585 self.aes.encrypt(&mut key_stream_buf, b_block, self.key);
586 }
587 key_stream_buf
588 .iter_mut()
589 .zip(remainder)
590 .for_each(|(b, r)| *b ^= r);
591 self.aes.encrypt(b_block, &mut key_stream_buf, self.key);
592
593 b_block[..TAG_SIZE].try_into().unwrap()
594 }
595
596 fn xor_tag(&mut self, tag: &mut [u8; TAG_SIZE], a_block: &mut [u8; 16]) {
598 let mut key_stream_buf = [0_u8; 16];
599 self.aes.encrypt(&mut key_stream_buf, a_block, self.key);
600 for i in 0..TAG_SIZE {
601 tag[i] ^= key_stream_buf[i];
602 }
603 }
604
605 fn xor_payload(&mut self, payload: &mut [u8], mut a_block: &mut [u8; 16]) -> Result<(), Error> {
611 let mut key_stream_buf = [0_u8; 16];
612 let mut counter = 0_u32;
613 let (chunks, remainder) = payload.as_chunks_mut::<16>();
614 for chunk in chunks {
615 counter = counter.checked_add(1).ok_or(Error::AESCounterOverflow)?;
616 [a_block[12], a_block[13], a_block[14], a_block[15]] = counter.to_be_bytes();
617
618 self.aes
619 .encrypt(&mut key_stream_buf, &mut a_block, self.key);
620 chunk
621 .iter_mut()
622 .zip(key_stream_buf)
623 .for_each(|(c, k)| *c ^= k);
624 }
625 counter = counter.checked_add(1).ok_or(Error::AESCounterOverflow)?;
626 [a_block[12], a_block[13], a_block[14], a_block[15]] = counter.to_be_bytes();
627 self.aes
628 .encrypt(&mut key_stream_buf, &mut a_block, self.key);
629 remainder
630 .iter_mut()
631 .zip(key_stream_buf)
632 .for_each(|(r, a)| *r ^= a);
633 Ok(())
634 }
635
636 fn is_tag_match_const_time(tag_a: &[u8; TAG_SIZE], tag_b: &[u8; TAG_SIZE]) -> bool {
638 let mut acc = 0;
639
640 for i in 0..TAG_SIZE {
641 acc |= tag_a[i] ^ tag_b[i];
642 }
643 acc == 0
644 }
645}