1use core::ops::Deref;
8
/// Length of the cleartext frame header: 6-byte MAC, 1 flags byte, 5-byte nonce.
pub const HEADER_SIZE: usize = NONCE_OFFSET + 5;
/// Length of the authentication tag appended after the encrypted payload.
pub const TAG_SIZE: usize = 16;
/// Index of the flags byte, directly after the 6-byte MAC address.
const FLAGS_IDX: usize = MAC_OFFSET + 6;
/// Offset of the 5-byte nonce, directly after the flags byte.
const NONCE_OFFSET: usize = FLAGS_IDX + 1;
/// The destination MAC address opens the frame.
const MAC_OFFSET: usize = 0;
/// The encrypted payload starts right after the header.
const PAYLOAD_OFFSET: usize = HEADER_SIZE;
16
/// Errors produced while building, sealing or opening frames.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Error {
    /// NOTE(review): not returned by `AESCCM`, which reports tag mismatches as
    /// `Corrupted` — confirm whether this variant is still used.
    Authentication,
    /// The frame is too short to hold header, tag and at least one payload byte,
    /// or the decrypted payload failed to deserialise.
    InvalidFormat,
    /// The frame buffer has no room left for the bytes being written.
    BufferOverflow,
    /// The 5-byte frame nonce or the CTR block counter ran out of values.
    AESCounterOverflow,
    /// The frame's nonce is not newer than the last accepted one (replayed or
    /// reordered frame).
    Duplicate,
    /// The authentication tag did not match the frame contents.
    Corrupted,
    /// Serialising the payload with `postcard` failed.
    PostcardError,
}
35
/// Single-block cipher primitive that [`AESCCM`] builds CCM on top of.
///
/// NOTE(review): presumably AES-128 forward encryption of one 16-byte block
/// (the key is 16 bytes) — confirm against the implementations.
pub trait Encrypt {
    /// Encrypts one 16-byte block under `key`.
    ///
    /// Callers in this file treat `key_stream_buf` as the output and `block` as
    /// the input. They also reuse `block` after the call (the CCM counter block
    /// is encrypted, then has only its counter bytes rewritten), so
    /// implementations are expected to leave `block` unchanged — TODO confirm;
    /// nothing here relies on `block` being written.
    fn encrypt(&mut self, key_stream_buf: &mut [u8; 16], block: &mut [u8; 16], key: [u8; 16]);
}
57
58use crate::Payload;
59#[derive(Debug)]
61pub struct Frame<T: Payload> {
62 pub inner: T::FrameBuf,
63 len: usize,
64}
65impl<T: Payload> Default for Frame<T> {
66 fn default() -> Self {
67 Self {
68 inner: T::new_buf(),
69 len: 0,
70 }
71 }
72}
73impl<T: Payload> Frame<T> {
74 fn new(mac: [u8; 6], flags: u8, raw_nonce: [u8; 5]) -> Result<Self, Error> {
75 let mut frame = Self::default();
76 frame.extend_from_slice(&mac)?;
77 frame.push(flags)?;
78 frame.extend_from_slice(&raw_nonce)?;
79 Ok(frame)
80 }
81
82 fn payload_mut_slice(&mut self) -> &mut [u8] {
83 &mut self.inner.as_mut()[HEADER_SIZE..]
84 }
85 fn finalize(&mut self, payload_len: usize, tag: [u8; 16]) -> Result<(), Error> {
86 self.len += payload_len;
87 self.extend_from_slice(&tag)
88 }
89 pub fn bytes(&self) -> &[u8] {
90 &self.inner.as_ref()[..self.len]
91 }
92
93 pub fn bytes_mut(&mut self) -> &mut [u8] {
94 &mut self.inner.as_mut()[..self.len]
95 }
96
97 fn push(&mut self, byte: u8) -> Result<(), Error> {
98 if self.len >= self.inner.as_ref().len() {
99 return Err(Error::BufferOverflow);
100 }
101 self.inner.as_mut()[self.len] = byte;
102 self.len += 1;
103 Ok(())
104 }
105
106 fn extend_from_slice(&mut self, iter: &[u8]) -> Result<(), Error> {
107 if iter.len() + self.len > self.inner.as_ref().len() {
108 return Err(Error::BufferOverflow);
109 }
110 self.inner.as_mut()[self.len..self.len + iter.len()].copy_from_slice(iter);
111 self.len += iter.len();
112 Ok(())
113 }
114}
115
116#[derive(Debug, PartialEq, Eq, Copy, Clone)]
118pub struct PacketData<T>
119where
120 T: Payload,
121{
122 pub dst: MacAddr,
124 pub flags: u8,
126 pub payload: T,
128}
129
130impl<T> PacketData<T>
131where
132 T: Payload,
133{
134 pub fn new(dst: MacAddr, mut flags: u8, payload: T) -> Self {
137 flags &= 0b_00_111111;
138 Self {
139 dst,
140 flags,
141 payload,
142 }
143 }
144}
145
/// Six-byte hardware address identifying a frame's destination.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct MacAddr {
    inner: [u8; 6],
}

impl MacAddr {
    /// Builds an address from its six octets, `f1` being the first byte.
    pub fn new(f1: u8, f2: u8, f3: u8, f4: u8, f5: u8, f6: u8) -> Self {
        Self::from([f1, f2, f3, f4, f5, f6])
    }
}

impl Default for MacAddr {
    /// The all-ones address `FF:FF:FF:FF:FF:FF` (the conventional broadcast address).
    fn default() -> Self {
        Self::from([0xFF; 6])
    }
}

impl From<[u8; 6]> for MacAddr {
    fn from(octets: [u8; 6]) -> Self {
        MacAddr { inner: octets }
    }
}

impl IntoIterator for MacAddr {
    type Item = u8;
    type IntoIter = core::array::IntoIter<u8, 6>;

    /// Yields the six octets in order, by value.
    fn into_iter(self) -> Self::IntoIter {
        IntoIterator::into_iter(self.inner)
    }
}

impl Deref for MacAddr {
    type Target = [u8; 6];

    fn deref(&self) -> &[u8; 6] {
        &self.inner
    }
}
191
192struct Nonce {
194 counter: u64,
195}
196
197impl Nonce {
198 fn inc(&mut self) -> Result<[u8; 5], Error> {
205 const MAX_5_BYTES: u64 = 0xFF_FF_FF_FF_FF;
206 if self.counter >= MAX_5_BYTES {
207 return Err(Error::AESCounterOverflow);
208 }
209 self.counter += 1;
210
211 let bytes = self.counter.to_be_bytes();
212 let mut result = [0_u8; 5];
213 result.copy_from_slice(&bytes[3..8]);
214
215 Ok(result)
216 }
217
218 fn set(&mut self, new_counter: u64) {
220 self.counter = new_counter;
221 }
222}
223
224pub struct PacketView<'a> {
256 bytes: &'a [u8],
257}
258
259impl<'a> PacketView<'a> {
260 pub fn new(bytes: &'a [u8]) -> Result<Self, Error> {
268 Self::try_from(bytes)
269 }
270
271 pub fn mac(&self) -> [u8; 6] {
273 self.bytes[MAC_OFFSET..MAC_OFFSET + 6].try_into().unwrap()
274 }
275
276 pub fn flags(&self) -> u8 {
278 self.bytes[FLAGS_IDX]
279 }
280
281 pub fn raw_nonce(&self) -> [u8; 5] {
283 self.bytes[NONCE_OFFSET..NONCE_OFFSET + 5]
284 .try_into()
285 .unwrap()
286 }
287
288 pub fn nonce(&self) -> u64 {
290 let raw_nonce = self.raw_nonce();
291 u64::from_be_bytes([
292 0,
293 0,
294 0,
295 raw_nonce[0],
296 raw_nonce[1],
297 raw_nonce[2],
298 raw_nonce[3],
299 raw_nonce[4],
300 ])
301 }
302}
303
304impl<'a> TryFrom<&'a [u8]> for PacketView<'a> {
305 type Error = Error;
306
307 fn try_from(bytes: &'a [u8]) -> Result<Self, Self::Error> {
309 if bytes.len() <= HEADER_SIZE + TAG_SIZE {
310 return Err(Error::InvalidFormat);
311 }
312 Ok(Self { bytes })
313 }
314}
315
316struct Parts {
318 pub mac: [u8; 6],
319 pub flags: u8,
320 pub raw_nonce: [u8; 5],
321 pub payload_len: usize,
322 pub tag: [u8; TAG_SIZE],
323}
324
325impl Parts {
326 fn nonce(&self) -> u64 {
328 u64::from_be_bytes([
329 0,
330 0,
331 0,
332 self.raw_nonce[0],
333 self.raw_nonce[1],
334 self.raw_nonce[2],
335 self.raw_nonce[3],
336 self.raw_nonce[4],
337 ])
338 }
339}
340
341impl TryFrom<&[u8]> for Parts {
342 type Error = Error;
343
344 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
346 if bytes.len() <= HEADER_SIZE + TAG_SIZE {
347 return Err(Error::InvalidFormat);
348 }
349 let mac: [u8; 6] = bytes[MAC_OFFSET..MAC_OFFSET + 6].try_into().unwrap();
350 let raw_nonce: [u8; 5] = bytes[NONCE_OFFSET..NONCE_OFFSET + 5].try_into().unwrap();
351 let payload_len = bytes.len() - TAG_SIZE - PAYLOAD_OFFSET;
352
353 let tag: [u8; TAG_SIZE] = bytes[bytes.len() - TAG_SIZE..].try_into().unwrap();
354
355 let flags = bytes[FLAGS_IDX];
356 Ok(Self {
357 mac,
358 flags,
359 raw_nonce,
360 payload_len,
361 tag,
362 })
363 }
364}
365
/// The 12 cleartext header bytes (`mac | flags | nonce`) that CCM authenticates
/// as associated data.
pub struct AdHeader {
    inner: [u8; 12],
}

impl AdHeader {
    /// Lays out `dst_addr`, `flags` and `nonce` back to back.
    pub fn new(dst_addr: &[u8; 6], flags: u8, nonce: &[u8; 5]) -> Self {
        let mut inner = [0_u8; 12];
        let (addr_part, rest) = inner.split_at_mut(6);
        addr_part.copy_from_slice(dst_addr);
        let (flag_part, nonce_part) = rest.split_at_mut(1);
        flag_part[0] = flags;
        nonce_part.copy_from_slice(nonce);
        Self { inner }
    }

    /// Length of the associated data as a big-endian `u16`.
    fn u16_be_len(&self) -> [u8; 2] {
        let len = self.inner.len() as u16;
        len.to_be_bytes()
    }
}

impl IntoIterator for AdHeader {
    type Item = u8;
    type IntoIter = core::array::IntoIter<u8, 12>;

    fn into_iter(self) -> Self::IntoIter {
        IntoIterator::into_iter(self.inner)
    }
}

impl Deref for AdHeader {
    type Target = [u8; 12];

    fn deref(&self) -> &[u8; 12] {
        &self.inner
    }
}
401
402pub struct AESCCM<E>
404where
405 E: Encrypt,
406{
407 rx_nonce: Nonce,
408 tx_nonce: Nonce,
409 key: [u8; 16],
410 aes: E,
411}
412impl<E> AESCCM<E>
413where
414 E: Encrypt,
415{
416 pub fn new(aes: E, key: [u8; 16]) -> Self {
418 AESCCM {
419 rx_nonce: Nonce { counter: 0 },
420 tx_nonce: Nonce { counter: 0 },
421 key,
422 aes,
423 }
424 }
425
426 pub fn encrypt<T>(&mut self, packet_data: &PacketData<T>) -> Result<Frame<T>, Error>
445 where
446 T: Payload,
447 {
448 let mac = *packet_data.dst;
449 let raw_nonce = self.tx_nonce.inc()?;
450 let mut frame = Frame::new(mac, packet_data.flags, raw_nonce)?;
451
452 let mut payload = postcard::to_slice(&packet_data.payload, frame.payload_mut_slice())
453 .map_err(|_| Error::PostcardError)?;
454
455 let payload_len = payload.len();
456
457 let mut block_buf = [0_u8; 16];
458
459 let b_block = Self::write_b_block(&mut block_buf, mac, raw_nonce, payload_len);
460
461 let ad_header = AdHeader::new(&mac, packet_data.flags, &raw_nonce);
462
463 let mut tag = self.gen_raw_tag(b_block, ad_header, payload);
464
465 let a_block = Self::write_a_block(&mut block_buf, mac, raw_nonce);
466
467 self.xor_tag(&mut tag, a_block);
468
469 self.xor_payload(&mut payload, a_block)?;
470
471 frame.finalize(payload_len, tag)?;
472
473 Ok(frame)
474 }
475
476 pub fn decrypt<T>(&mut self, bytes: &mut [u8]) -> Result<PacketData<T>, Error>
485 where
486 T: Payload,
487 {
488 let parts = Parts::try_from(&*bytes)?;
489 if parts.nonce() <= self.rx_nonce.counter {
490 return Err(Error::Duplicate);
491 }
492
493 let mut payload = &mut bytes[PAYLOAD_OFFSET..PAYLOAD_OFFSET + parts.payload_len];
494
495 let mut block_buf = [0_u8; 16];
496 let a_block = Self::write_a_block(&mut block_buf, parts.mac, parts.raw_nonce);
497 let mut tag = parts.tag;
498
499 self.xor_tag(&mut tag, a_block);
500
501 self.xor_payload(&mut payload, a_block)?;
502
503 let b_block = Self::write_b_block(
504 &mut block_buf,
505 parts.mac,
506 parts.raw_nonce,
507 parts.payload_len,
508 );
509 let ad_header = AdHeader::new(&parts.mac, parts.flags, &parts.raw_nonce);
510
511 let tag_cmp = self.gen_raw_tag(b_block, ad_header, payload);
512 if !Self::is_tag_match_const_time(&tag, &tag_cmp) {
513 return Err(Error::Corrupted);
514 }
515
516 let serialized_payload =
517 postcard::from_bytes::<T>(&payload).map_err(|_| Error::InvalidFormat)?;
518 let packet_data = PacketData::new(parts.mac.into(), parts.flags, serialized_payload);
519 self.rx_nonce.set(parts.nonce());
520 Ok(packet_data)
521 }
522
523 fn write_a_block<'b>(
525 buf: &'b mut [u8; 16],
526 mac: [u8; 6],
527 raw_nonce: [u8; 5],
528 ) -> &'b mut [u8; 16] {
529 const A_NONCE_OFFSET: usize = 7;
530 const A_MAC_OFFSET: usize = 1;
531 buf.fill(0);
532 buf[0] = 4;
533 buf[A_MAC_OFFSET..A_MAC_OFFSET + 6].copy_from_slice(&mac);
534 buf[A_NONCE_OFFSET..A_NONCE_OFFSET + 5].copy_from_slice(&raw_nonce);
535 buf
536 }
537
538 fn write_b_block<'b>(
540 buf: &'b mut [u8; 16],
541 mac: [u8; 6],
542 raw_nonce: [u8; 5],
543 payload_len: usize,
544 ) -> &'b mut [u8; 16] {
545 const B0_FLAGS: u8 = 0b0_1_111_011;
546 buf[..6].copy_from_slice(&mac);
547 buf[6] = B0_FLAGS;
548 buf[7..=11].copy_from_slice(&raw_nonce);
549 buf[12..].copy_from_slice(&(payload_len as u32).to_be_bytes());
550 buf
551 }
552
553 fn gen_raw_tag(
555 &mut self,
556 b_block: &mut [u8; 16],
557 ad_header: AdHeader,
558 payload: &[u8],
559 ) -> [u8; TAG_SIZE] {
560 let mut padded_header = [0_u8; 16];
561 padded_header[0..2].copy_from_slice(&ad_header.u16_be_len());
562 padded_header[2..14].copy_from_slice(&*ad_header);
563
564 let mut key_stream_buf = [0_u8; 16];
565 self.aes.encrypt(&mut key_stream_buf, b_block, self.key);
566 key_stream_buf
567 .iter_mut()
568 .zip(&padded_header)
569 .for_each(|(b, h)| *b ^= h);
570 self.aes.encrypt(b_block, &mut key_stream_buf, self.key);
571 let (chunks, remainder) = payload.as_chunks::<16>();
572 for chunk in chunks {
573 b_block.iter_mut().zip(chunk).for_each(|(b, p)| *b ^= p);
574 self.aes.encrypt(&mut key_stream_buf, b_block, self.key);
575 }
576 key_stream_buf
577 .iter_mut()
578 .zip(remainder)
579 .for_each(|(b, r)| *b ^= r);
580 self.aes.encrypt(b_block, &mut key_stream_buf, self.key);
581
582 b_block[..TAG_SIZE].try_into().unwrap()
583 }
584
585 fn xor_tag(&mut self, tag: &mut [u8; TAG_SIZE], a_block: &mut [u8; 16]) {
587 let mut key_stream_buf = [0_u8; 16];
588 self.aes.encrypt(&mut key_stream_buf, a_block, self.key);
589 for i in 0..TAG_SIZE {
590 tag[i] ^= key_stream_buf[i];
591 }
592 }
593
594 fn xor_payload(&mut self, payload: &mut [u8], mut a_block: &mut [u8; 16]) -> Result<(), Error> {
600 let mut key_stream_buf = [0_u8; 16];
601 let mut counter = 0_u32;
602 let (chunks, remainder) = payload.as_chunks_mut::<16>();
603 for chunk in chunks {
604 counter = counter.checked_add(1).ok_or(Error::AESCounterOverflow)?;
605 [a_block[12], a_block[13], a_block[14], a_block[15]] = counter.to_be_bytes();
606
607 self.aes
608 .encrypt(&mut key_stream_buf, &mut a_block, self.key);
609 chunk
610 .iter_mut()
611 .zip(key_stream_buf)
612 .for_each(|(c, k)| *c ^= k);
613 }
614 counter = counter.checked_add(1).ok_or(Error::AESCounterOverflow)?;
615 [a_block[12], a_block[13], a_block[14], a_block[15]] = counter.to_be_bytes();
616 self.aes
617 .encrypt(&mut key_stream_buf, &mut a_block, self.key);
618 remainder
619 .iter_mut()
620 .zip(key_stream_buf)
621 .for_each(|(r, a)| *r ^= a);
622 Ok(())
623 }
624
625 fn is_tag_match_const_time(tag_a: &[u8; TAG_SIZE], tag_b: &[u8; TAG_SIZE]) -> bool {
627 let mut acc = 0;
628
629 for i in 0..TAG_SIZE {
630 acc |= tag_a[i] ^ tag_b[i];
631 }
632 acc == 0
633 }
634}