1use core::ops::Deref;
9use serde::{Deserialize, Serialize};
10
/// Size of the buffer a [`Command`] is postcard-serialized into before encryption.
const MAX_PAYLOAD_SIZE: usize = 64;
/// Cleartext header preceding the payload: destination MAC (6) + flags (1) + nonce (5).
const HEADER_SIZE: usize = 6 + 1 + 5;
/// Length of the truncated authentication tag appended after the payload.
const TAG_SIZE: usize = 8;
15
16#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
18pub enum PacketError {
19 Authentication,
21 InvalidFormat,
23 BufferOverflow,
25 AESCounterOverflow,
27 Duplicate,
29 Corrupted,
31}
32
33#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
35pub enum Command {
36 Toggle(Component),
38 On(Component),
40 Off(Component),
42}
43
44#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
46pub enum Component {
47 Led(u8),
49}
50
/// Single-block cipher primitive used by [`AESCCM`] for both the CBC-MAC and the
/// CTR keystream.
///
/// Callers in this file pass the 16-byte input block as `a_block` and read the
/// result from `key_stream_buf`. They also reuse the input block afterwards (e.g.
/// only its last four counter bytes are rewritten between calls), so implementations
/// should leave `a_block` unchanged.
/// NOTE(review): `a_block` is taken as `&mut` even though it is treated as input —
/// confirm the contract with implementors.
/// NOTE(review): presumably AES-128 given the 16-byte key and the `AESCCM` name —
/// confirm.
pub trait Encrypt {
    /// Encrypts `a_block` under `key` and writes the 16-byte result into
    /// `key_stream_buf`.
    fn encrypt(&mut self, key_stream_buf: &mut [u8; 16], a_block: &mut [u8; 16], key: [u8; 16]);
}
76
77#[derive(Debug)]
82pub struct AESCCMPacket {
83 pub inner: heapless::Vec<u8, { HEADER_SIZE + 4 + MAX_PAYLOAD_SIZE + TAG_SIZE }>,
85}
86
87impl AESCCMPacket {
88 pub fn new() -> Self {
90 Self {
91 inner: heapless::Vec::new(),
92 }
93 }
94
95 fn extend<I>(&mut self, iter: I)
96 where
97 I: IntoIterator<Item = u8>,
98 {
99 self.inner.extend(iter);
100 }
101
102 fn extend_from_slice(&mut self, iter: &[u8]) {
103 self.inner.extend_from_slice(iter).unwrap();
104 }
105
106 fn push(&mut self, item: u8) {
107 self.inner.push(item).unwrap();
108 }
109}
110
111impl Default for AESCCMPacket {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117#[derive(Debug)]
119pub struct PacketData {
120 pub dst: MacAddr,
122 pub flags: u8,
124 pub cmd: Command,
126}
127
128impl PacketData {
129 pub fn new(dst: MacAddr, flags: u8, cmd: Command) -> Self {
131 Self { dst, flags, cmd }
132 }
133}
134
/// Six-byte hardware address, stored in the order its octets are given.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct MacAddr {
    inner: [u8; 6],
}

impl MacAddr {
    /// Builds an address from its six octets, `f1` first.
    pub fn new(f1: u8, f2: u8, f3: u8, f4: u8, f5: u8, f6: u8) -> Self {
        [f1, f2, f3, f4, f5, f6].into()
    }
}

impl Default for MacAddr {
    /// The all-`0xFF` address (`FF:FF:FF:FF:FF:FF`).
    fn default() -> Self {
        Self::from([0xFF; 6])
    }
}

impl From<[u8; 6]> for MacAddr {
    fn from(octets: [u8; 6]) -> Self {
        MacAddr { inner: octets }
    }
}

impl IntoIterator for MacAddr {
    type Item = u8;
    type IntoIter = <[u8; 6] as IntoIterator>::IntoIter;

    /// Yields the six octets in storage order.
    fn into_iter(self) -> Self::IntoIter {
        IntoIterator::into_iter(self.inner)
    }
}

impl Deref for MacAddr {
    type Target = [u8; 6];

    /// Borrows the raw octets.
    fn deref(&self) -> &[u8; 6] {
        let MacAddr { inner } = self;
        inner
    }
}
180
181pub struct Nonce {
183 counter: u64,
184}
185
186impl Nonce {
187 pub fn inc(&mut self) -> Result<[u8; 5], PacketError> {
194 const MAX_5_BYTES: u64 = 0xFF_FF_FF_FF_FF;
195 if self.counter >= MAX_5_BYTES {
196 return Err(PacketError::AESCounterOverflow);
197 }
198 self.counter += 1;
199
200 let bytes = self.counter.to_be_bytes();
201 let mut result = [0_u8; 5];
202 result.copy_from_slice(&bytes[3..8]);
203
204 Ok(result)
205 }
206
207 pub fn set(&mut self, new_counter: u64) {
209 self.counter = new_counter;
210 }
211}
212
213struct PacketView {
215 pub mac: [u8; 6],
216 pub flags: u8,
217 pub raw_nonce: [u8; 5],
218 pub payload_len: usize,
219 pub tag: [u8; 8],
220}
221
222impl PacketView {
223 const FLAGS_IDX: usize = 6;
224 const NONCE_OFFSET: usize = 7;
225 const MAC_OFFSET: usize = 0;
226 const PAYLOAD_OFFSET: usize = HEADER_SIZE;
227
228 pub fn nonce(&self) -> u64 {
230 u64::from_be_bytes([
231 0,
232 0,
233 0,
234 self.raw_nonce[0],
235 self.raw_nonce[1],
236 self.raw_nonce[2],
237 self.raw_nonce[3],
238 self.raw_nonce[4],
239 ])
240 }
241}
242
243impl TryFrom<&[u8]> for PacketView {
244 type Error = PacketError;
245
246 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
248 if bytes.len() <= HEADER_SIZE + TAG_SIZE {
249 return Err(PacketError::InvalidFormat);
250 }
251 let mac: [u8; 6] = bytes[Self::MAC_OFFSET..Self::MAC_OFFSET + 6]
252 .try_into()
253 .unwrap();
254 let raw_nonce: [u8; 5] = bytes[Self::NONCE_OFFSET..Self::NONCE_OFFSET + 5]
255 .try_into()
256 .unwrap();
257 let payload_len = bytes.len() - TAG_SIZE - Self::PAYLOAD_OFFSET;
258
259 let tag: [u8; 8] = bytes[bytes.len() - TAG_SIZE..].try_into().unwrap();
260
261 let flags = bytes[Self::FLAGS_IDX];
262 Ok(Self {
263 mac,
264 flags,
265 raw_nonce,
266 payload_len,
267 tag,
268 })
269 }
270}
271
/// Associated data authenticated alongside the payload:
/// `dst_addr(6) | flags(1) | nonce(5)`.
pub struct AdHeader {
    inner: [u8; 12],
}

impl AdHeader {
    /// Packs destination address, flags byte and 5-byte nonce into the 12-byte header.
    pub fn new(dst_addr: &[u8; 6], flags: u8, nonce: &[u8; 5]) -> Self {
        let mut inner = [0_u8; 12];
        inner[0..6].copy_from_slice(dst_addr);
        inner[6] = flags;
        inner[7..].copy_from_slice(nonce);
        Self { inner }
    }

    /// Header length (always 12) as the big-endian `u16` prefix used when the
    /// header is padded into a 16-byte MAC block.
    pub fn u16_be_len(&self) -> [u8; 2] {
        (self.inner.len() as u16).to_be_bytes()
    }
}

impl From<[u8; 16]> for AdHeader {
    /// Recovers the header from a padded block laid out as
    /// `len(2) | header(12) | zero pad(2)`.
    ///
    /// Takes bytes `2..14`; the previous `2..13` range was only 11 bytes long and
    /// made this conversion panic unconditionally.
    fn from(value: [u8; 16]) -> Self {
        let mut inner = [0_u8; 12];
        inner.copy_from_slice(&value[2..14]);
        Self { inner }
    }
}

impl IntoIterator for AdHeader {
    type Item = u8;
    type IntoIter = <[u8; 12] as IntoIterator>::IntoIter;

    /// Yields the 12 header bytes in order.
    fn into_iter(self) -> Self::IntoIter {
        self.inner.into_iter()
    }
}

impl Deref for AdHeader {
    type Target = [u8; 12];

    /// Borrows the raw 12 header bytes.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
315
316pub struct AESCCM<T: Encrypt> {
318 rx_nonce: Nonce,
319 tx_nonce: Nonce,
320 key: [u8; 16],
321 aes: T,
322}
323
324impl<T: Encrypt> AESCCM<T> {
325 pub fn new(aes: T, key: [u8; 16]) -> Self {
327 AESCCM {
328 rx_nonce: Nonce { counter: 0 },
329 tx_nonce: Nonce { counter: 0 },
330 key,
331 aes,
332 }
333 }
334
335 pub fn encrypt(&mut self, packet_data: PacketData) -> Result<AESCCMPacket, PacketError> {
354 let mut buf = [0_u8; MAX_PAYLOAD_SIZE];
355 let payload = postcard::to_slice(&packet_data.cmd, &mut buf)
356 .map_err(|_| PacketError::BufferOverflow)?;
357 let payload_len = payload.len();
358
359 let mut block_buf = [0_u8; 16];
360
361 let raw_nonce = self.tx_nonce.inc()?;
362 let mac_addr = packet_data.dst;
363
364 let b_block = Self::write_b_block(
365 &mut block_buf,
366 *packet_data.dst,
367 packet_data.flags,
368 raw_nonce,
369 payload_len,
370 );
371
372 let ad_header = AdHeader::new(&mac_addr, packet_data.flags, &raw_nonce);
373
374 let mut tag = self.gen_raw_tag(b_block, ad_header, payload);
375
376 let a_block = Self::write_a_block(&mut block_buf, *mac_addr, raw_nonce);
377
378 self.xor_tag(&mut tag, a_block);
379
380 self.xor_payload(payload, a_block)?;
381
382 let mut payload_vec = AESCCMPacket::new();
383 payload_vec.extend(mac_addr);
384 payload_vec.push(packet_data.flags);
385 payload_vec.extend(raw_nonce);
386 payload_vec.extend_from_slice(payload);
387 payload_vec.extend(tag);
388 Ok(payload_vec)
389 }
390
391 pub fn decrypt(&mut self, bytes: &mut [u8]) -> Result<PacketData, PacketError> {
400 let view = PacketView::try_from(&*bytes)?;
401 if view.nonce() <= self.rx_nonce.counter {
402 return Err(PacketError::Duplicate);
403 }
404
405 let mut payload =
406 &mut bytes[PacketView::PAYLOAD_OFFSET..PacketView::PAYLOAD_OFFSET + view.payload_len];
407
408 let mut block_buf = [0_u8; 16];
409 let a_block = Self::write_a_block(&mut block_buf, view.mac, view.raw_nonce);
410 let mut tag = view.tag;
411
412 self.xor_tag(&mut tag, a_block);
413
414 self.xor_payload(&mut payload, a_block)?;
415
416 let b_block = Self::write_b_block(
417 &mut block_buf,
418 view.mac,
419 view.flags,
420 view.raw_nonce,
421 view.payload_len,
422 );
423 let ad_header = AdHeader::new(&view.mac, view.flags, &view.raw_nonce);
424
425 let tag_cmp = self.gen_raw_tag(b_block, ad_header, payload);
426 if !Self::is_tag_match_const_time(&tag, &tag_cmp) {
427 return Err(PacketError::Corrupted);
428 }
429
430 let cmd =
431 postcard::from_bytes::<Command>(&payload).map_err(|_| PacketError::InvalidFormat)?;
432 let packet_data = PacketData::new(view.mac.into(), view.flags, cmd);
433 self.rx_nonce.set(view.nonce());
434 Ok(packet_data)
435 }
436
437 pub fn write_a_block<'b>(
439 buf: &'b mut [u8; 16],
440 mac: [u8; 6],
441 raw_nonce: [u8; 5],
442 ) -> &'b mut [u8; 16] {
443 const NONCE_OFFSET: usize = 7;
444 const MAC_OFFSET: usize = 1;
445 buf.fill(0);
446 buf[0] = 4;
447 buf[MAC_OFFSET..MAC_OFFSET + 6].copy_from_slice(&mac);
448 buf[NONCE_OFFSET..NONCE_OFFSET + 5].copy_from_slice(&raw_nonce);
449 buf
450 }
451
452 pub fn write_b_block<'b>(
454 buf: &'b mut [u8; 16],
455 mac: [u8; 6],
456 flags: u8,
457 raw_nonce: [u8; 5],
458 payload_len: usize,
459 ) -> &'b mut [u8; 16] {
460 buf[..6].copy_from_slice(&mac);
461 buf[6] = flags;
462 buf[7..=11].copy_from_slice(&raw_nonce);
463 buf[12..].copy_from_slice(&(payload_len as u32).to_be_bytes());
464 buf
465 }
466
467 pub fn gen_raw_tag(
469 &mut self,
470 b_block: &mut [u8; 16],
471 ad_header: AdHeader,
472 payload: &[u8],
473 ) -> [u8; 8] {
474 let mut padded_header = [0_u8; 16];
475 padded_header[0..2].copy_from_slice(&ad_header.u16_be_len());
476 padded_header[2..14].copy_from_slice(&*ad_header);
477
478 let mut key_stream_buf = [0_u8; 16];
479 self.aes.encrypt(&mut key_stream_buf, b_block, self.key);
480 key_stream_buf
481 .iter_mut()
482 .zip(&padded_header)
483 .for_each(|(b, h)| *b ^= h);
484 self.aes.encrypt(b_block, &mut key_stream_buf, self.key);
485 let (chunks, remainder) = payload.as_chunks::<16>();
486 for chunk in chunks {
487 b_block.iter_mut().zip(chunk).for_each(|(b, p)| *b ^= p);
488 self.aes.encrypt(&mut key_stream_buf, b_block, self.key);
489 }
490 key_stream_buf
491 .iter_mut()
492 .zip(remainder)
493 .for_each(|(b, r)| *b ^= r);
494 self.aes.encrypt(b_block, &mut key_stream_buf, self.key);
495
496 b_block[..8].try_into().unwrap()
497 }
498
499 pub fn xor_tag(&mut self, tag: &mut [u8; 8], a_block: &mut [u8; 16]) {
501 let mut key_stream_buf = [0_u8; 16];
502 self.aes.encrypt(&mut key_stream_buf, a_block, self.key);
503 for i in 0..8 {
504 tag[i] ^= key_stream_buf[i];
505 }
506 }
507
508 pub fn xor_payload(
514 &mut self,
515 payload: &mut [u8],
516 mut a_block: &mut [u8; 16],
517 ) -> Result<(), PacketError> {
518 let mut key_stream_buf = [0_u8; 16];
519 let mut counter = 0_u32;
520 let (chunks, remainder) = payload.as_chunks_mut::<16>();
521 for chunk in chunks {
522 counter = counter
523 .checked_add(1)
524 .ok_or(PacketError::AESCounterOverflow)?;
525 [a_block[12], a_block[13], a_block[14], a_block[15]] = counter.to_be_bytes();
526
527 self.aes
528 .encrypt(&mut key_stream_buf, &mut a_block, self.key);
529 chunk
530 .iter_mut()
531 .zip(key_stream_buf)
532 .for_each(|(c, k)| *c ^= k);
533 }
534 counter = counter
535 .checked_add(1)
536 .ok_or(PacketError::AESCounterOverflow)?;
537 [a_block[12], a_block[13], a_block[14], a_block[15]] = counter.to_be_bytes();
538 self.aes
539 .encrypt(&mut key_stream_buf, &mut a_block, self.key);
540 remainder
541 .iter_mut()
542 .zip(key_stream_buf)
543 .for_each(|(r, a)| *r ^= a);
544 Ok(())
545 }
546
547 pub fn is_tag_match_const_time(tag_a: &[u8; 8], tag_b: &[u8; 8]) -> bool {
549 let mut acc = 0;
550
551 for i in 0..8 {
552 acc |= tag_a[i] ^ tag_b[i];
553 }
554 acc == 0
555 }
556}