use core::mem::MaybeUninit;
use crate::crc::{CRC_INIT, crc16};
use crate::sync::Sync;
use crate::{Cmd, ReadError, Status};
/// Size in bytes of a frame's payload buffer.
///
/// `Frame::read` answers `Status::PayloadOverflow` for any frame whose `len`
/// field is larger than this.
pub const MAX_PAYLOAD: usize = 64;

/// Typed view of a payload carrying device information.
///
/// Byte-packed, so the struct is exactly 12 bytes with no padding and can be
/// overlaid on the raw payload bytes. NOTE(review): presumably the payload of
/// `Cmd::Info` exchanges — confirm against the command handlers.
#[derive(Copy, Clone)]
#[repr(C, packed)]
pub struct InfoData {
    pub capacity: u32,
    pub erase_size: u16,
    pub boot_version: u16,
    pub app_version: u16,
    pub mode: u16,
}

/// Typed view of a payload holding a single `byte_count` value.
#[derive(Copy, Clone)]
#[repr(C)]
pub struct EraseData {
    pub byte_count: u16,
}

/// Typed view of a payload holding a single 16-bit `crc` value.
#[derive(Copy, Clone)]
#[repr(C)]
pub struct VerifyData {
    pub crc: u16,
}

/// Payload storage of a frame, viewable as raw bytes or as a typed struct.
///
/// Every view is built purely from integers, so any initialized bytes form a
/// valid value for all of them. Reading a field is `unsafe` only because the
/// compiler cannot prove the bytes were initialized.
#[repr(C)]
pub union Data {
    pub raw: [u8; MAX_PAYLOAD],
    pub info: InfoData,
    pub erase: EraseData,
    pub verify: VerifyData,
}
/// One protocol frame, stored in memory in the same byte order it is sent.
///
/// Wire format: `sync | cmd | status | addr | len | payload[len] | crc`,
/// where `crc` is the little-endian CRC-16 of every byte before it. The
/// methods on `Frame` address fields by fixed byte offsets, so this layout
/// must not change (see the compile-time checks below).
#[repr(C)]
pub struct Frame {
    /// Marker at the start of every frame; stamped by the send methods.
    sync: Sync,
    /// Command carried by the frame.
    pub cmd: Cmd,
    /// `Status::Request` for requests, otherwise the response status.
    pub status: Status,
    /// Address field of the header.
    pub addr: u32,
    /// Number of meaningful payload bytes in `data`; at most `MAX_PAYLOAD`.
    pub len: u16,
    /// Payload buffer; only the first `len` bytes are sent or received.
    pub data: Data,
    /// CRC-16 (little-endian) over header and payload.
    pub crc: [u8; 2],
}

// The byte offsets hard-coded in `impl Frame` assume a padding-free layout:
// sync(2) cmd(1) status(1) addr(4) len(2) data crc(2). Fail the build rather
// than silently sending/parsing garbage if any of these types change size.
const _: () = {
    assert!(
        core::mem::size_of::<Sync>() == 2,
        "Frame offsets assume a 2-byte sync marker"
    );
    assert!(
        core::mem::size_of::<Cmd>() == 1,
        "Frame offsets assume a 1-byte Cmd"
    );
    assert!(
        core::mem::size_of::<Status>() == 1,
        "Frame offsets assume a 1-byte Status"
    );
    assert!(
        core::mem::size_of::<Frame>() == 10 + core::mem::size_of::<Data>() + 2,
        "Frame layout must be header(10) + payload + crc(2) with no padding"
    );
};
impl Default for Frame {
    /// Builds an `Info` request with `Sync::default()` as its sync field and
    /// `addr`, `len`, the whole payload buffer and `crc` all zeroed.
    ///
    /// The frame is built field by field. Producing a `Frame` from
    /// uninitialized memory would be undefined behavior: it contains enum-like
    /// fields, and its payload bytes are later viewed as `&[u8]`.
    fn default() -> Self {
        Self {
            sync: Sync::default(),
            cmd: Cmd::Info,
            status: Status::Request,
            addr: 0,
            len: 0,
            // SAFETY: every view of the payload union is made only of integers
            // (`[u8; _]`, `u16`, `u32`), for which all-zero bytes are valid.
            // Zeroing the whole union also guarantees that byte-level views
            // of the payload never observe uninitialized memory.
            data: unsafe { core::mem::zeroed() },
            crc: [0; 2],
        }
    }
}
impl Frame {
fn as_bytes(&self, offset: usize, len: usize) -> &[u8] {
debug_assert!(offset + len <= core::mem::size_of::<Self>());
unsafe {
let ptr = (self as *const Self as *const u8).add(offset);
core::slice::from_raw_parts(ptr, len)
}
}
fn as_bytes_mut(&mut self, offset: usize, len: usize) -> &mut [u8] {
debug_assert!(offset + len <= core::mem::size_of::<Self>());
unsafe {
let ptr = (self as *mut Self as *mut u8).add(offset);
core::slice::from_raw_parts_mut(ptr, len)
}
}
pub fn send<W: embedded_io::Write>(&mut self, w: &mut W) -> Result<(), W::Error> {
self.sync = Sync::valid();
let body_len = 10 + self.len as usize;
let crc = crc16(CRC_INIT, self.as_bytes(0, body_len)).to_le_bytes();
let len = self.len as usize;
unsafe {
*self.data.raw.get_unchecked_mut(len) = crc[0];
*self.data.raw.get_unchecked_mut(len + 1) = crc[1];
}
w.write_all(self.as_bytes(0, body_len + 2))
}
pub fn read<R: embedded_io::Read>(&mut self, r: &mut R) -> Result<Status, ReadError> {
self.sync.read(r)?;
r.read_exact(self.as_bytes_mut(2, 8))
.map_err(|_| ReadError)?;
if !Cmd::is_valid(self.as_bytes(2, 1)[0]) || !Status::is_valid(self.as_bytes(3, 1)[0]) {
return Ok(Status::Unsupported);
}
let data_len = self.len as usize;
if data_len > MAX_PAYLOAD {
return Ok(Status::PayloadOverflow);
}
if data_len > 0 {
r.read_exact(unsafe { &mut self.data.raw[..data_len] })
.map_err(|_| ReadError)?;
}
r.read_exact(&mut self.crc).map_err(|_| ReadError)?;
if self.crc != crc16(CRC_INIT, self.as_bytes(0, 10 + data_len)).to_le_bytes() {
return Ok(Status::CrcMismatch);
}
Ok(Status::Ok)
}
pub async fn send_async<W: embedded_io_async::Write>(
&mut self,
w: &mut W,
) -> Result<(), W::Error> {
self.sync = Sync::valid();
let body_len = 10 + self.len as usize;
self.crc = crc16(CRC_INIT, self.as_bytes(0, body_len)).to_le_bytes();
w.write_all(self.as_bytes(0, body_len)).await?;
w.write_all(&self.crc).await
}
pub async fn read_async<R: embedded_io_async::Read>(
&mut self,
r: &mut R,
) -> Result<Status, ReadError> {
self.sync.read_async(r).await?;
r.read_exact(self.as_bytes_mut(2, 8))
.await
.map_err(|_| ReadError)?;
if !Cmd::is_valid(self.as_bytes(2, 1)[0]) || !Status::is_valid(self.as_bytes(3, 1)[0]) {
return Ok(Status::Unsupported);
}
let data_len = self.len as usize;
if data_len > MAX_PAYLOAD {
return Ok(Status::PayloadOverflow);
}
if data_len > 0 {
r.read_exact(unsafe { &mut self.data.raw[..data_len] })
.await
.map_err(|_| ReadError)?;
}
r.read_exact(&mut self.crc).await.map_err(|_| ReadError)?;
if self.crc != crc16(CRC_INIT, self.as_bytes(0, 10 + data_len)).to_le_bytes() {
return Ok(Status::CrcMismatch);
}
Ok(Status::Ok)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory reader that yields `data`, then reports end of input (`Ok(0)`).
    struct MockReader<'a> {
        data: &'a [u8],
        pos: usize,
    }

    impl<'a> MockReader<'a> {
        fn new(data: &'a [u8]) -> Self {
            Self { data, pos: 0 }
        }
    }

    impl embedded_io::ErrorType for MockReader<'_> {
        type Error = core::convert::Infallible;
    }

    impl embedded_io::Read for MockReader<'_> {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
            let n = buf.len().min(self.data.len() - self.pos);
            buf[..n].copy_from_slice(&self.data[self.pos..self.pos + n]);
            self.pos += n;
            Ok(n)
        }
    }

    /// Fixed-capacity writer that records everything written to it.
    struct Sink {
        buf: [u8; 512],
        pos: usize,
    }

    impl Sink {
        fn new() -> Self {
            Self {
                buf: [0; 512],
                pos: 0,
            }
        }

        fn written(&self) -> &[u8] {
            &self.buf[..self.pos]
        }
    }

    impl embedded_io::ErrorType for Sink {
        type Error = core::convert::Infallible;
    }

    impl embedded_io::Write for Sink {
        fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            let n = buf.len().min(self.buf.len() - self.pos);
            self.buf[self.pos..self.pos + n].copy_from_slice(&buf[..n]);
            self.pos += n;
            Ok(n)
        }

        fn flush(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Builds a frame with the given header fields and payload.
    fn frame(cmd: Cmd, status: Status, addr: u32, data: &[u8]) -> Frame {
        let mut f = Frame {
            cmd,
            status,
            addr,
            len: data.len() as u16,
            ..Default::default()
        };
        unsafe { f.data.raw[..data.len()].copy_from_slice(data) };
        f
    }

    #[test]
    fn request_round_trip() {
        let mut frame = frame(Cmd::Write, Status::Request, 0x0800, &[0xDE, 0xAD]);
        let mut sink = Sink::new();
        frame.send(&mut sink).unwrap();
        let mut frame2 = Frame::default();
        frame2.read(&mut MockReader::new(sink.written())).unwrap();
        assert_eq!(frame2.cmd, Cmd::Write);
        assert_eq!(frame2.len, 2);
        assert_eq!(frame2.addr, 0x0800);
        assert_eq!(frame2.status, Status::Request);
        assert_eq!(unsafe { &frame2.data.raw[..2] }, &[0xDE, 0xAD]);
    }

    #[test]
    fn response_round_trip() {
        let mut frame = frame(Cmd::Verify, Status::Ok, 0, &[0x12, 0x34]);
        let mut sink = Sink::new();
        frame.send(&mut sink).unwrap();
        let mut frame2 = Frame::default();
        frame2.read(&mut MockReader::new(sink.written())).unwrap();
        assert_eq!(frame2.cmd, Cmd::Verify);
        assert_eq!(frame2.status, Status::Ok);
        assert_eq!(unsafe { &frame2.data.raw[..2] }, &[0x12, 0x34]);
    }

    #[test]
    fn request_no_data() {
        let mut frame = frame(Cmd::Erase, Status::Request, 0, &[]);
        let mut sink = Sink::new();
        frame.send(&mut sink).unwrap();
        assert_eq!(sink.written().len(), 12);
        let mut frame2 = Frame::default();
        frame2.read(&mut MockReader::new(sink.written())).unwrap();
        assert_eq!(frame2.cmd, Cmd::Erase);
        assert_eq!(frame2.len, 0);
    }

    #[test]
    fn large_addr_round_trip() {
        let mut frame = frame(Cmd::Write, Status::Request, 0x0001_0800, &[0xAB]);
        let mut sink = Sink::new();
        frame.send(&mut sink).unwrap();
        let mut frame2 = Frame::default();
        frame2.read(&mut MockReader::new(sink.written())).unwrap();
        assert_eq!(frame2.addr, 0x0001_0800);
    }

    #[test]
    fn cmd_addr_carry_over() {
        let mut frame = frame(Cmd::Write, Status::Request, 0x0400, &[0xAB, 0xCD]);
        let mut sink = Sink::new();
        frame.send(&mut sink).unwrap();
        let mut dev = Frame::default();
        dev.read(&mut MockReader::new(sink.written())).unwrap();
        dev.status = Status::Ok;
        dev.len = 0;
        let mut resp_sink = Sink::new();
        dev.send(&mut resp_sink).unwrap();
        let mut host = Frame::default();
        host.read(&mut MockReader::new(resp_sink.written()))
            .unwrap();
        assert_eq!(host.cmd, Cmd::Write);
        assert_eq!(host.addr, 0x0400);
        assert_eq!(host.status, Status::Ok);
    }

    #[test]
    fn read_bad_cmd() {
        let mut frame = frame(Cmd::Info, Status::Request, 0, &[]);
        let mut sink = Sink::new();
        frame.send(&mut sink).unwrap();
        sink.buf[2] ^= 0xFF;
        let mut frame2 = Frame::default();
        assert_eq!(
            frame2.read(&mut MockReader::new(sink.written())),
            Ok(Status::Unsupported)
        );
    }

    #[test]
    fn read_after_garbage() {
        let mut frame = frame(Cmd::Verify, Status::Request, 0, &[]);
        let mut sink = Sink::new();
        frame.send(&mut sink).unwrap();
        let frame_len = sink.pos;
        let mut input = [0u8; 4 + 512];
        input[..4].copy_from_slice(&[0xFF, 0x00, 0xAA, 0x42]);
        input[4..4 + frame_len].copy_from_slice(&sink.buf[..frame_len]);
        let mut frame2 = Frame::default();
        assert_eq!(
            frame2.read(&mut MockReader::new(&input[..4 + frame_len])),
            Ok(Status::Ok)
        );
        assert_eq!(frame2.cmd, Cmd::Verify);
    }

    #[test]
    fn read_overflow() {
        let mut f = frame(Cmd::Write, Status::Request, 0, &[]);
        let mut sink = Sink::new();
        f.send(&mut sink).unwrap();
        let overflow_len = (MAX_PAYLOAD as u16 + 1).to_le_bytes();
        sink.buf[8] = overflow_len[0];
        sink.buf[9] = overflow_len[1];
        let mut frame2 = Frame::default();
        assert_eq!(
            frame2.read(&mut MockReader::new(sink.written())),
            Ok(Status::PayloadOverflow)
        );
    }

    #[test]
    fn read_crc_mismatch() {
        let mut f = frame(Cmd::Write, Status::Request, 0x0800, &[0x01, 0x02, 0x03]);
        let mut sink = Sink::new();
        f.send(&mut sink).unwrap();
        // Flip one bit of the first payload byte (it follows the 10-byte header).
        sink.buf[10] ^= 0x01;
        let mut frame2 = Frame::default();
        assert_eq!(
            frame2.read(&mut MockReader::new(sink.written())),
            Ok(Status::CrcMismatch)
        );
    }

    #[test]
    fn read_truncated_payload() {
        let mut f = frame(Cmd::Write, Status::Request, 0x0800, &[0x01, 0x02, 0x03]);
        let mut sink = Sink::new();
        f.send(&mut sink).unwrap();
        // Header plus a single payload byte: the stream ends mid-frame.
        let mut frame2 = Frame::default();
        assert_eq!(
            frame2.read(&mut MockReader::new(&sink.written()[..11])),
            Err(ReadError)
        );
    }
}