use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
use embassy_sync::signal::Signal;
use embassy_time::{with_timeout, Duration, Timer};
use embedded_hal::digital::OutputPin;
use embedded_hal_async::spi::SpiBus;
use crate::error::{Error, Result};
/// Magic value every SPI frame header must carry (little-endian on the wire).
const SPI_HEADER_MAGIC: u16 = 0x55AA;
/// Largest payload length, in bytes, accepted for a single SPI transfer.
const MAX_SPI_XFER: usize = 2048;

/// 8-byte header that prefixes every SPI frame.
///
/// Multi-byte fields are serialized little-endian by [`SpiHeader::to_bytes`]
/// and parsed the same way by [`SpiHeader::from_bytes`].
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
pub struct SpiHeader {
    /// Must equal `SPI_HEADER_MAGIC` for [`SpiHeader::is_valid`] to accept it.
    magic: u16,
    /// Payload length in bytes that follows the header.
    len: u16,
    /// Flag bits; bit `0x04` is reported by [`SpiHeader::rx_stall`].
    flags: u8,
    /// Message type byte; `new_at_cmd` writes 0.
    msg_type: u8,
    /// Reserved word; `new_at_cmd` writes 0.
    reserved: u16,
}

impl SpiHeader {
    /// Builds a header announcing `payload_len` bytes, with the flags,
    /// message-type and reserved fields all cleared.
    pub fn new_at_cmd(payload_len: u16) -> Self {
        Self {
            magic: SPI_HEADER_MAGIC,
            len: payload_len,
            flags: 0,
            msg_type: 0,
            reserved: 0,
        }
    }

    /// Encodes the header into its 8-byte little-endian wire form.
    pub fn to_bytes(&self) -> [u8; 8] {
        // Packed fields are copied into locals first; borrowing them in place
        // would create unaligned references.
        let (magic, len, reserved) = (self.magic, self.len, self.reserved);
        let [magic_lo, magic_hi] = magic.to_le_bytes();
        let [len_lo, len_hi] = len.to_le_bytes();
        let [rsvd_lo, rsvd_hi] = reserved.to_le_bytes();
        [
            magic_lo,
            magic_hi,
            len_lo,
            len_hi,
            self.flags,
            self.msg_type,
            rsvd_lo,
            rsvd_hi,
        ]
    }

    /// Decodes a header from its 8-byte little-endian wire form.
    ///
    /// No validation happens here; call [`SpiHeader::is_valid`] afterwards.
    pub fn from_bytes(bytes: &[u8; 8]) -> Self {
        let le_word = |at: usize| u16::from_le_bytes([bytes[at], bytes[at + 1]]);
        Self {
            magic: le_word(0),
            len: le_word(2),
            flags: bytes[4],
            msg_type: bytes[5],
            reserved: le_word(6),
        }
    }

    /// Returns `true` when the magic matches `SPI_HEADER_MAGIC` and the
    /// announced length does not exceed `MAX_SPI_XFER`.
    pub fn is_valid(&self) -> bool {
        // Local copies: comparing packed fields in place would borrow them.
        let (magic, len) = (self.magic, self.len);
        magic == SPI_HEADER_MAGIC && usize::from(len) <= MAX_SPI_XFER
    }

    /// Returns `true` when the RX-stall flag bit (`0x04`) is set in `flags`.
    pub fn rx_stall(&self) -> bool {
        const RX_STALL_BIT: u8 = 0x04;
        let flags = self.flags;
        flags & RX_STALL_BIT != 0
    }
}
/// SPI transport that paces each transaction on two externally raised signals
/// and drives chip-select manually.
///
/// `txn_ready_signal` is awaited before CS is driven high to start a transfer.
/// `hdr_ack_signal` is awaited after the bus transfer, before CS is driven low.
/// NOTE(review): both signals are presumably raised from an interrupt handler
/// watching the peer's handshake lines — confirm against the code that signals them.
pub struct SpiTransportRdy<SPI, CS>
where
    SPI: SpiBus,
    CS: OutputPin,
{
    /// SPI bus used for the raw transfers; CS is not handled by the bus.
    spi: SPI,
    /// Chip-select output pin, toggled around every transfer.
    cs: CS,
    /// Awaited before a transaction is started.
    txn_ready_signal: &'static Signal<CriticalSectionRawMutex, ()>,
    /// Awaited after a transfer, before CS is driven low again.
    hdr_ack_signal: &'static Signal<CriticalSectionRawMutex, ()>,
}
impl<SPI, CS> SpiTransportRdy<SPI, CS>
where
    SPI: SpiBus,
    CS: OutputPin,
{
    /// Creates a transport from a bus, a manually driven chip-select pin and
    /// the two signals that pace each transaction.
    pub fn new(
        spi: SPI,
        cs: CS,
        txn_ready_signal: &'static Signal<CriticalSectionRawMutex, ()>,
        hdr_ack_signal: &'static Signal<CriticalSectionRawMutex, ()>,
    ) -> Self {
        Self {
            spi,
            cs,
            txn_ready_signal,
            hdr_ack_signal,
        }
    }

    /// Sends `data` as one framed SPI transaction.
    ///
    /// The frame is an 8-byte [`SpiHeader`] followed by `data`. The payload is
    /// padded with `0x88` up to a multiple of 4 bytes, and the header announces
    /// the padded length. On success the unpadded `data.len()` is returned.
    ///
    /// Both signal waits are best-effort. If `txn_ready_signal` or
    /// `hdr_ack_signal` is not raised within 100 ms, the transfer proceeds anyway.
    ///
    /// NOTE(review): CS is driven high for the transfer and low afterwards,
    /// which makes it active-high. Conventional SPI chip-select is active-low,
    /// so confirm this matches the board wiring.
    ///
    /// # Errors
    ///
    /// - `Error::BufferTooSmall` if `data` is longer than `MAX_SPI_XFER`.
    /// - `Error::Spi` if driving CS or the bus transfer fails. On a bus
    ///   failure, CS is driven low (best-effort) before returning.
    pub async fn write(&mut self, data: &[u8]) -> Result<usize> {
        if data.len() > MAX_SPI_XFER {
            return Err(Error::BufferTooSmall);
        }
        let payload_end = 8 + data.len();
        let padded_len = (data.len() + 3) & !3;
        let total_len = 8 + padded_len;

        let mut frame = [0u8; MAX_SPI_XFER + 8];
        // `padded_len <= MAX_SPI_XFER` (2048 is already a multiple of 4), so the
        // cast to u16 cannot truncate.
        frame[..8].copy_from_slice(&SpiHeader::new_at_cmd(padded_len as u16).to_bytes());
        frame[8..payload_end].copy_from_slice(data);
        frame[payload_end..total_len].fill(0x88);

        #[cfg(feature = "defmt")]
        defmt::debug!(
            "SPI TX: len={}, padded={}, total={}",
            data.len(),
            padded_len,
            total_len
        );

        let _ = with_timeout(Duration::from_millis(100), self.txn_ready_signal.wait()).await;
        self.cs.set_high().map_err(|_| Error::Spi)?;
        Timer::after(Duration::from_micros(10)).await;
        // `SpiBus::write` may return before the last word has left the
        // peripheral. Flush so CS is never dropped mid-frame.
        if self.spi.write(&frame[..total_len]).await.is_err()
            || self.spi.flush().await.is_err()
        {
            return self.abort_transfer();
        }
        let _ = with_timeout(Duration::from_millis(100), self.hdr_ack_signal.wait()).await;
        Timer::after(Duration::from_micros(10)).await;
        self.cs.set_low().map_err(|_| Error::Spi)?;
        Ok(data.len())
    }

    /// Receives one framed SPI transaction into `buffer`.
    ///
    /// Waits up to 2 s for `txn_ready_signal` and then reads the 8-byte header.
    /// It then reads the announced number of payload bytes into the front of
    /// `buffer`. Returns the payload size, which is 0 for an empty frame.
    ///
    /// `buffer` may be any size. It only needs to hold the announced payload,
    /// which header validation caps at `MAX_SPI_XFER` bytes.
    ///
    /// NOTE(review): CS is driven high for the transfer and low afterwards,
    /// which makes it active-high. Confirm this matches the board wiring.
    ///
    /// # Errors
    ///
    /// - `Error::Timeout` if `txn_ready_signal` is not raised within 2 s.
    /// - `Error::InvalidResponse` if the header magic or length is invalid.
    /// - `Error::BufferTooSmall` if the announced payload exceeds `buffer.len()`.
    /// - `Error::Spi` if driving CS or a bus transfer fails. On a bus failure,
    ///   CS is driven low (best-effort) before returning.
    pub async fn read(&mut self, buffer: &mut [u8]) -> Result<usize> {
        with_timeout(Duration::from_millis(2000), self.txn_ready_signal.wait())
            .await
            .map_err(|_| Error::Timeout)?;
        self.cs.set_high().map_err(|_| Error::Spi)?;
        Timer::after(Duration::from_micros(10)).await;

        let mut header_bytes = [0u8; 8];
        if self.spi.read(&mut header_bytes).await.is_err() {
            return self.abort_transfer();
        }
        let header = SpiHeader::from_bytes(&header_bytes);
        let payload_len = usize::from(header.len);
        // Copied out of the packed header (packed fields cannot be borrowed).
        // Only the log statements below use it.
        #[cfg(feature = "defmt")]
        let magic = header.magic;
        #[cfg(feature = "defmt")]
        defmt::trace!("SPI RX: magic={:04x}, len={}", magic, payload_len);

        if !header.is_valid() {
            self.cs.set_low().map_err(|_| Error::Spi)?;
            #[cfg(feature = "defmt")]
            defmt::warn!("Invalid SPI header magic: {:04x}", magic);
            return Err(Error::InvalidResponse);
        }
        if payload_len == 0 {
            let _ = with_timeout(Duration::from_millis(100), self.hdr_ack_signal.wait()).await;
            self.cs.set_low().map_err(|_| Error::Spi)?;
            return Ok(0);
        }
        if payload_len > buffer.len() {
            self.cs.set_low().map_err(|_| Error::Spi)?;
            return Err(Error::BufferTooSmall);
        }
        if self.spi.read(&mut buffer[..payload_len]).await.is_err() {
            return self.abort_transfer();
        }
        let _ = with_timeout(Duration::from_millis(100), self.hdr_ack_signal.wait()).await;
        Timer::after(Duration::from_micros(10)).await;
        self.cs.set_low().map_err(|_| Error::Spi)?;
        Ok(payload_len)
    }

    /// Drives CS low after a failed bus transfer and reports `Error::Spi`.
    ///
    /// A CS pin failure here is ignored. The transfer has already failed, and
    /// `Error::Spi` is returned either way.
    fn abort_transfer<T>(&mut self) -> Result<T> {
        let _ = self.cs.set_low();
        Err(Error::Spi)
    }
}