use crate::byte_queue::ByteQueue;
use core::convert::TryFrom;
use crc::{Crc, Digest, CRC_32_ISO_HDLC};
/// Message-framing queue layered over a [`ByteQueue`].
///
/// Outgoing messages are written to `byte_queue` as framed packets; incoming bytes are
/// pulled from `byte_queue` into `rx_buf`, where packets are located by `prefix` and
/// CRC-checked before being handed out.
#[derive(Debug)]
pub struct MsgQueue<'a, const LEN: usize> {
    /// Underlying byte stream used both for writing and for reading packets.
    byte_queue: ByteQueue,
    /// Byte sequence that starts every packet; used to find packet boundaries.
    prefix: &'a [u8],
    /// Receive buffer; only the first `rx_buf_len` bytes hold valid data.
    rx_buf: [u8; LEN],
    /// Number of valid bytes at the front of `rx_buf`.
    rx_buf_len: usize,
    /// Set after a read returned a message that still sits at the front of `rx_buf`;
    /// the next read drops that packet before searching for a new one.
    has_received_full_msg: bool,
}
use core::fmt;
/// Errors reported by [`MsgQueue`] operations.
///
/// All variants are field-less, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MqError {
    /// The underlying byte queue does not currently have room for the packet.
    MqFull,
    /// No complete message is available (yet).
    MqEmpty,
    /// A header or trailing CRC check failed.
    MqCrcErr,
    /// The message does not fit in the byte queue's capacity or in the receive buffer.
    MqMsgTooBig,
    /// The packet's protocol version byte is not the version this code speaks.
    MqWrongProtocolVersion,
}
impl fmt::Display for MqError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MqError::MqFull => write!(f, "Message queue is full"),
MqError::MqEmpty => write!(f, "Message queue is empty"),
MqError::MqCrcErr => write!(f, "CRC check failed"),
MqError::MqMsgTooBig => write!(f, "Message is too big"),
MqError::MqWrongProtocolVersion => {
write!(f, "Message protocol version is incompatible")
}
}
}
}
// `MqError` has no underlying cause, so the default `Error` methods are sufficient;
// this impl lets it be used with `?` into `dyn core::error::Error` and similar APIs.
impl core::error::Error for MqError {}
// Wire-format constants. A packet is laid out as
// | prefix | version (u8) | len (u32 LE) | header CRC32 (LE) | data | CRC32 (LE) |
/// Protocol version byte written after the prefix; received packets with any other
/// value are rejected.
const PROTOCOL_VERSION: u8 = 1;
/// Size of the little-endian `u32` message-length field.
const MSG_SIZE_FIELD_SIZE: usize = core::mem::size_of::<u32>();
/// Size of each little-endian CRC32 field (header CRC and trailing CRC).
const MSG_CRC_FIELD_SIZE: usize = core::mem::size_of::<u32>();
/// Size of the protocol version field.
const MSG_PROTOCOL_FIELD_SIZE: usize = core::mem::size_of::<u8>();
/// CRC-32/ISO-HDLC engine used for both the header and the trailing checksum.
const CRC32: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
/// Framing layer on top of [`ByteQueue`].
///
/// Every message is sent as one packet with the layout
///
/// ```text
/// | prefix | version (u8) | len (u32 LE) | header CRC32 (LE) | data (len bytes) | CRC32 (LE) |
/// ```
///
/// The header CRC covers `prefix..len`; the trailing CRC covers everything before it,
/// including the header CRC. The `len_p*` helpers return the offset just past the named
/// fields (p = prefix, v = version, l = length, c = CRC, d = data).
impl<'a, const LEN: usize> MsgQueue<'a, LEN> {
    /// Creates a message queue that frames messages on `byte_queue`.
    ///
    /// `prefix` marks the start of every packet and lets the reader resynchronize after
    /// corrupt or foreign bytes. `rx_buf` is the receive buffer; its length `LEN` bounds
    /// the largest packet that can be read.
    pub fn new(byte_queue: ByteQueue, prefix: &'a [u8], rx_buf: [u8; LEN]) -> Self {
        Self {
            byte_queue,
            prefix,
            rx_buf,
            rx_buf_len: 0,
            has_received_full_msg: false,
        }
    }

    /// Offset just past the prefix.
    fn len_p(&self) -> usize {
        self.prefix.len()
    }

    /// Offset just past the protocol version byte.
    fn len_pv(&self) -> usize {
        self.len_p() + MSG_PROTOCOL_FIELD_SIZE
    }

    /// Offset just past the message length field.
    fn len_pvl(&self) -> usize {
        self.len_pv() + MSG_SIZE_FIELD_SIZE
    }

    /// Offset just past the header CRC, i.e. where the message data starts.
    fn len_pvlc(&self) -> usize {
        self.len_pvl() + MSG_CRC_FIELD_SIZE
    }

    /// Offset just past `msg_len` data bytes, i.e. where the trailing CRC starts.
    fn len_pvlcd(&self, msg_len: usize) -> usize {
        self.len_pvlc() + msg_len
    }

    /// Total packet size for a message carrying `msg_len` data bytes.
    fn len_pvlcdc(&self, msg_len: usize) -> usize {
        self.len_pvlcd(msg_len) + MSG_CRC_FIELD_SIZE
    }

    /// Moves as many bytes as fit from the byte queue into the free tail of `rx_buf`.
    fn read_bytes(&mut self) {
        let read_bytes_len = self
            .byte_queue
            .consume_at_most(&mut self.rx_buf[self.rx_buf_len..]);
        self.rx_buf_len += read_bytes_len;
    }

    /// Discards the first `skip` buffered bytes and shifts the remainder to the front.
    ///
    /// # Panics
    /// Panics if `skip` exceeds the number of buffered bytes.
    fn skip_in_rx_buf(&mut self, skip: usize) {
        assert!(
            skip <= self.rx_buf_len,
            "skip rx_buffer value exceeds current rx_buffer length."
        );
        self.rx_buf.copy_within(skip..self.rx_buf_len, 0);
        self.rx_buf_len -= skip;
    }

    /// Drops the packet handed out by the previous successful read (if any), so the
    /// next search starts at whatever followed it.
    fn rm_old_msg(&mut self) {
        if !self.has_received_full_msg {
            return;
        }
        self.has_received_full_msg = false;
        let msg_len = self
            .try_extract_msg_len()
            .expect("header of the previously returned message must still be buffered");
        // Compare the whole packet (header + data + trailing CRC) against the buffered
        // bytes, not just the payload length; never skip past `rx_buf_len`.
        let packet_len = self.len_pvlcdc(msg_len);
        if packet_len <= self.rx_buf_len {
            self.skip_in_rx_buf(packet_len);
        }
    }

    /// Drops the first buffered byte so the next prefix search starts past the start of
    /// the rejected packet.
    fn invalidate_current_msg(&mut self) {
        self.skip_in_rx_buf(1);
    }

    /// Moves the first occurrence of `prefix` to the front of `rx_buf`.
    ///
    /// # Errors
    /// Returns [`MqError::MqEmpty`] if no complete prefix is buffered. In that case all
    /// bytes except the trailing ones that could still start a prefix are discarded.
    fn try_advance_to_prefix(&mut self) -> Result<(), MqError> {
        let prefix = self.prefix;
        // An empty prefix trivially matches at offset 0 (and `windows(0)` would panic).
        if prefix.is_empty() {
            return Ok(());
        }
        let found = self.rx_buf[..self.rx_buf_len]
            .windows(prefix.len())
            .position(|window| window == prefix);
        if let Some(idx) = found {
            self.skip_in_rx_buf(idx);
            return Ok(());
        }
        // No complete prefix is buffered, so at most the last `prefix.len() - 1` bytes
        // can be the beginning of one; everything before them is garbage.
        let keep = prefix.len() - 1;
        if self.rx_buf_len > keep {
            self.skip_in_rx_buf(self.rx_buf_len - keep);
        }
        Err(MqError::MqEmpty)
    }

    /// Reads the little-endian message length from the header at the front of `rx_buf`.
    ///
    /// # Errors
    /// Returns [`MqError::MqEmpty`] if the length field has not been received yet.
    fn try_extract_msg_len(&self) -> Result<usize, MqError> {
        if self.rx_buf_len < self.len_pvl() {
            return Err(MqError::MqEmpty);
        }
        let start = self.len_pv();
        let end = start + MSG_SIZE_FIELD_SIZE;
        let mut array = [0u8; MSG_SIZE_FIELD_SIZE];
        array.copy_from_slice(&self.rx_buf[start..end]);
        Ok(u32::from_le_bytes(array) as usize)
    }

    /// Rejects packets whose total size could never fit in `rx_buf`.
    ///
    /// # Errors
    /// Returns [`MqError::MqMsgTooBig`] (after dropping the packet's first byte) when
    /// the packet is larger than the receive buffer.
    fn verify_msg_packet_len(&mut self, msg_len: usize) -> Result<(), MqError> {
        // `msg_len` comes off the wire: compare against the remaining room instead of
        // computing `len_pvlcdc(msg_len)`, which can overflow `usize` on 32-bit targets.
        let fits = self
            .rx_buf
            .len()
            .checked_sub(self.len_pvlcdc(0))
            .is_some_and(|max_msg_len| msg_len <= max_msg_len);
        if !fits {
            self.invalidate_current_msg();
            return Err(MqError::MqMsgTooBig);
        }
        Ok(())
    }

    /// Checks the protocol version byte that follows the prefix.
    ///
    /// # Errors
    /// [`MqError::MqEmpty`] if the byte has not arrived yet;
    /// [`MqError::MqWrongProtocolVersion`] (after dropping the packet's first byte) on a
    /// mismatch.
    fn verify_protocol_version(&mut self) -> Result<(), MqError> {
        if self.rx_buf_len < self.len_pv() {
            return Err(MqError::MqEmpty);
        }
        if self.rx_buf[self.len_p()] != PROTOCOL_VERSION {
            self.invalidate_current_msg();
            return Err(MqError::MqWrongProtocolVersion);
        }
        Ok(())
    }

    /// Checks the little-endian CRC32 stored at `crc_start` against the CRC32 of
    /// `rx_buf[..crc_start]`.
    ///
    /// # Errors
    /// [`MqError::MqEmpty`] if the CRC field has not arrived yet; [`MqError::MqCrcErr`]
    /// (after dropping the packet's first byte) on a mismatch.
    fn verify_crc(&mut self, crc_start: usize) -> Result<(), MqError> {
        let crc_end = crc_start + MSG_CRC_FIELD_SIZE;
        // Check availability first so no checksum is computed over bytes not yet received.
        if self.rx_buf_len < crc_end {
            return Err(MqError::MqEmpty);
        }
        let mut crc_array = [0u8; MSG_CRC_FIELD_SIZE];
        crc_array.copy_from_slice(&self.rx_buf[crc_start..crc_end]);
        let received_crc = u32::from_le_bytes(crc_array);
        let calculated_crc = CRC32.checksum(&self.rx_buf[..crc_start]);
        if received_crc != calculated_crc {
            self.invalidate_current_msg();
            return Err(MqError::MqCrcErr);
        }
        Ok(())
    }

    /// Validates the packet at the front of `rx_buf` and returns its data length.
    ///
    /// Fields are checked in wire order. The header CRC is verified before the length
    /// is trusted for anything else.
    fn verify_full_msg(&mut self) -> Result<usize, MqError> {
        self.verify_protocol_version()?;
        let msg_len = self.try_extract_msg_len()?;
        self.verify_crc(self.len_pvl())?;
        self.verify_msg_packet_len(msg_len)?;
        self.verify_crc(self.len_pvlcd(msg_len))?;
        Ok(msg_len)
    }

    /// Drops the previously returned packet, pulls in new bytes, and locates and
    /// validates the next packet. Returns the data range within `rx_buf`.
    fn find_next_msg(&mut self) -> Result<(usize, usize), MqError> {
        self.rm_old_msg();
        self.read_bytes();
        self.try_advance_to_prefix()?;
        let msg_len = self.verify_full_msg()?;
        self.has_received_full_msg = true;
        let start = self.len_pvlc();
        Ok((start, start + msg_len))
    }

    /// Returns the next complete, verified message without blocking.
    ///
    /// The returned slice borrows the receive buffer; its packet is discarded at the
    /// start of the next read.
    ///
    /// # Errors
    /// - [`MqError::MqEmpty`] if no complete message is available yet.
    /// - [`MqError::MqCrcErr`], [`MqError::MqWrongProtocolVersion`] or
    ///   [`MqError::MqMsgTooBig`] if the packet at the front was rejected. Its first byte
    ///   is dropped, so a later call resynchronizes on the next prefix.
    pub fn read_or_fail(&mut self) -> Result<&[u8], MqError> {
        let (start, end) = self.find_next_msg()?;
        Ok(&self.rx_buf[start..end])
    }

    /// Busy-waits until a complete, verified message is available and returns it.
    ///
    /// # Errors
    /// Returns the same rejection errors as [`Self::read_or_fail`]; "empty" and "full"
    /// conditions are retried instead of returned.
    pub fn read_blocking(&mut self) -> Result<&[u8], MqError> {
        loop {
            match self.find_next_msg() {
                Ok((start, end)) => return Ok(&self.rx_buf[start..end]),
                Err(MqError::MqFull | MqError::MqEmpty) => core::hint::spin_loop(),
                Err(err) => return Err(err),
            }
        }
    }

    /// Writes `data` to the byte queue and feeds it into `digest`.
    ///
    /// Callers must have checked that the byte queue has room for the whole packet.
    fn wacc(&mut self, digest: &mut Digest<u32>, data: &[u8]) {
        self.byte_queue
            .write_or_fail(data)
            .expect("byte queue space was checked before writing the packet");
        digest.update(data);
    }

    /// Serializes one packet for `msg_data` into the byte queue.
    ///
    /// # Errors
    /// Returns [`MqError::MqMsgTooBig`] if the length does not fit the `u32` length
    /// field. This is checked before anything is written.
    fn write_msg(&mut self, msg_data: &[u8]) -> Result<(), MqError> {
        // Validate up front so a rejected message never leaves a partial header behind.
        let msg_len_u32 = u32::try_from(msg_data.len()).map_err(|_| MqError::MqMsgTooBig)?;
        let mut header_crc = CRC32.digest();
        self.wacc(&mut header_crc, self.prefix);
        self.wacc(&mut header_crc, &PROTOCOL_VERSION.to_le_bytes());
        self.wacc(&mut header_crc, &msg_len_u32.to_le_bytes());
        // The trailing CRC continues from the header digest, so it covers every byte
        // that precedes it in the packet.
        let mut total_crc = header_crc.clone();
        let header_crc_bytes = header_crc.finalize().to_le_bytes();
        self.wacc(&mut total_crc, &header_crc_bytes);
        self.wacc(&mut total_crc, msg_data);
        let total_crc_bytes = total_crc.finalize().to_le_bytes();
        self.byte_queue
            .write_or_fail(&total_crc_bytes)
            .expect("byte queue space was checked before writing the packet");
        Ok(())
    }

    /// Writes `msg_data` as one packet if there is room right now.
    ///
    /// # Errors
    /// [`MqError::MqMsgTooBig`] if the packet can never fit in the byte queue;
    /// [`MqError::MqFull`] if it does not fit at the moment.
    pub fn write_or_fail(&mut self, msg_data: &[u8]) -> Result<(), MqError> {
        let packet_len = self.len_pvlcdc(msg_data.len());
        if self.byte_queue.capacity() < packet_len {
            return Err(MqError::MqMsgTooBig);
        }
        if self.byte_queue.space() < packet_len {
            return Err(MqError::MqFull);
        }
        self.write_msg(msg_data)
    }

    /// Busy-waits until the byte queue has room, then writes `msg_data` as one packet.
    ///
    /// # Errors
    /// [`MqError::MqMsgTooBig`] if the packet can never fit in the byte queue.
    pub fn write_blocking(&mut self, msg_data: &[u8]) -> Result<(), MqError> {
        let packet_len = self.len_pvlcdc(msg_data.len());
        if self.byte_queue.capacity() < packet_len {
            return Err(MqError::MqMsgTooBig);
        }
        while self.byte_queue.space() < packet_len {
            core::hint::spin_loop();
        }
        self.write_msg(msg_data)
    }
}
// Unit tests for `MsgQueue`.
//
// NOTE(review): several tests poke directly at `bq_buf` using offsets such as `2 + …`
// (u32 words) and `8 + …` (bytes). This assumes `ByteQueue` keeps two u32 bookkeeping
// words at the start of its buffer and stores data from byte 8 on. That assumption is
// consistent with the capacity formula in `test_saturate_queue`
// (`len * 4 - 2 * size_of::<u32>() - 1`); confirm against `byte_queue`.
// NOTE(review): writing to `bq_buf` while the `ByteQueue` holds a raw pointer into it
// may be flagged by Miri's aliasing model; consider going through the raw pointer
// instead.
#[cfg(test)]
mod tests {
    use crate::byte_queue::ByteQueue;
    use crate::msg_queue::{
        MqError, MsgQueue, MSG_CRC_FIELD_SIZE, MSG_PROTOCOL_FIELD_SIZE, MSG_SIZE_FIELD_SIZE,
    };
    // 16-byte packet prefix shared by all tests.
    const DEFAULT_PREFIX: &'static [u8] = b"DEFAULT_PREFIX: ";
    // `skip_in_rx_buf` must drop exactly `skip` leading bytes for every skip from 0
    // up to the full buffered length.
    #[test]
    fn test_skip_in_rx_buf() {
        let mut bq_buf = [0u32; 64];
        // SAFETY(review): `ByteQueue::create` is unsafe; presumably `bq_buf` must outlive
        // the queue — confirm the exact contract in byte_queue.
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        let s = b"abcde";
        for skip in 0..=s.len() {
            msg_queue.rx_buf[..s.len()].copy_from_slice(s);
            msg_queue.rx_buf_len = s.len();
            msg_queue.skip_in_rx_buf(skip);
            assert_eq!(&msg_queue.rx_buf[..msg_queue.rx_buf_len], &s[skip..]);
            assert_eq!(msg_queue.rx_buf_len, s.len() - skip);
        }
    }
    // A packet one byte larger than the byte queue's capacity is rejected as too big;
    // one that exactly matches the capacity is accepted. The 7-byte rx buffer is never
    // read from here.
    #[test]
    fn test_invalid_msg_size() {
        let mut bq_buf = [0u32; 10];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 7],
            )
        };
        let data = b"abcd";
        let msg_size = DEFAULT_PREFIX.len() + MSG_PROTOCOL_FIELD_SIZE + MSG_SIZE_FIELD_SIZE + MSG_CRC_FIELD_SIZE + data.len() + MSG_CRC_FIELD_SIZE;
        assert!(msg_queue.byte_queue.capacity() < msg_size);
        assert_eq!(msg_queue.write_or_fail(data), Err(MqError::MqMsgTooBig));
        let data = b"ab";
        let msg_size = DEFAULT_PREFIX.len() + MSG_PROTOCOL_FIELD_SIZE + MSG_SIZE_FIELD_SIZE + MSG_CRC_FIELD_SIZE + data.len() + MSG_CRC_FIELD_SIZE;
        assert!(msg_queue.byte_queue.capacity() == msg_size);
        assert_eq!(msg_queue.write_or_fail(&[1, 2]), Ok(()));
    }
    // Reading from a freshly created, empty queue reports `MqEmpty`.
    #[test]
    fn test_read_empty_queue() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        let result = msg_queue.read_or_fail();
        assert_eq!(result, Err(MqError::MqEmpty));
    }
    // Round trip: a written message is read back unchanged.
    #[test]
    fn test_write_and_read_msg() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        let msg = b"Hello, World!";
        let result = msg_queue.write_or_fail(msg);
        assert!(result.is_ok());
        let read_msg = msg_queue.read_or_fail().unwrap();
        assert_eq!(read_msg, msg);
    }
    // Corrupting the header CRC, and later the trailing CRC, yields `MqCrcErr`. A
    // following good message is still read correctly.
    #[test]
    fn test_crc_error() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        let msg = b"xxxxyyyy";
        msg_queue.write_or_fail(msg).unwrap();
        // Zero everything from roughly the header-CRC region onward. Assumes the packet
        // starts at u32 word 2 (see the module note).
        bq_buf[2 + msg_queue.len_pvl() / 4..].fill(0);
        assert_eq!(msg_queue.read_or_fail(), Err(MqError::MqCrcErr));
        msg_queue.write_or_fail(msg).unwrap();
        assert_eq!(msg_queue.read_or_fail().unwrap(), msg);
        // Recreate the queue on the same buffer — presumably resetting it to empty —
        // so the next packet again starts at the beginning of the data area.
        msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        msg_queue.write_or_fail(msg).unwrap();
        // Zero from the end of the data onward, which includes the trailing CRC.
        bq_buf[2 + msg_queue.len_pvlcd(msg.len()) / 4..].fill(0);
        assert_eq!(msg_queue.read_or_fail(), Err(MqError::MqCrcErr));
        msg_queue.write_blocking(msg).unwrap();
        assert_eq!(msg_queue.read_or_fail().unwrap(), msg);
    }
    // Filling the byte queue with as many packets as fit leaves exactly the expected
    // free space; one more write reports `MqFull`.
    #[test]
    fn test_saturate_queue() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        let data = b"abcd";
        let msg_size = DEFAULT_PREFIX.len()
            + MSG_PROTOCOL_FIELD_SIZE
            + MSG_SIZE_FIELD_SIZE
            + MSG_CRC_FIELD_SIZE
            + data.len()
            + MSG_CRC_FIELD_SIZE;
        // Usable capacity, as assumed here: buffer bytes minus two u32 words minus one.
        let repeat = (bq_buf.len() * 4 - 2 * core::mem::size_of::<u32>() - 1) / msg_size;
        assert_eq!(repeat, 7);
        for _ in 0..repeat {
            let result = msg_queue.write_or_fail(data);
            assert_eq!(result, Ok(()));
        }
        assert_eq!(
            msg_queue.byte_queue.space(),
            (bq_buf.len() * 4 - 2 * core::mem::size_of::<u32>() - 1 - repeat * msg_size)
        );
        let result = msg_queue.write_or_fail(data);
        assert_eq!(result, Err(MqError::MqFull));
    }
    // Dropping the first byte of the only buffered packet destroys its prefix, so no
    // message can be found.
    #[test]
    fn test_read_after_invalid_msg() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        let msg = b"valid msg";
        msg_queue.write_or_fail(msg).unwrap();
        msg_queue.read_bytes();
        msg_queue.invalidate_current_msg();
        let result = msg_queue.read_or_fail();
        assert_eq!(result, Err(MqError::MqEmpty));
    }
    // After the first packet is invalidated, the reader resynchronizes on the second
    // packet's prefix.
    #[test]
    fn test_read_write_after_invalid_msg() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        let msg = b"valid msg";
        msg_queue.write_or_fail(msg).unwrap();
        msg_queue.write_or_fail(msg).unwrap();
        msg_queue.read_bytes();
        msg_queue.invalidate_current_msg();
        let result = msg_queue.read_or_fail().unwrap();
        assert_eq!(result, msg);
    }
    // Blocking write followed by blocking read returns the message.
    #[test]
    fn test_blocking_read_msg() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        let msg = b"Blocking Msg";
        msg_queue.write_blocking(msg).unwrap();
        let read_msg = msg_queue.read_blocking().unwrap();
        assert_eq!(read_msg, msg);
    }
    // With a 64-byte rx buffer, 44..63 bytes of garbage followed by a 38-byte packet
    // never fit in one read. The first read must report `MqEmpty` without losing the
    // partially received packet, and the second read must return it.
    #[test]
    fn test_read_part_of_next_msg() {
        let mut bq_buf = [0u32; 128];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64],
            )
        };
        let msg = b"valid msg";
        let garbage = [0xff; 128];
        for garbage_len in 64 - 20..64 {
            msg_queue
                .byte_queue
                .write_or_fail(&garbage[..garbage_len])
                .unwrap();
            msg_queue.write_or_fail(msg).unwrap();
            assert_eq!(msg_queue.read_or_fail(), Err(MqError::MqEmpty));
            assert_eq!(msg_queue.read_or_fail().unwrap(), msg);
        }
    }
    // Overwriting the protocol version byte yields `MqWrongProtocolVersion`, and the
    // next good packet is still read.
    #[test]
    fn test_incompatible_protocol_version() {
        let mut bq_buf = [0u32; 128];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        let msg = b"xxxxyyyy";
        msg_queue.write_blocking(msg).unwrap();
        // Byte view of the whole queue buffer (128 words * 4 bytes).
        let u8_slice: &mut [u8] =
            unsafe { std::slice::from_raw_parts_mut(bq_buf.as_mut_ptr() as *mut u8, 128 * 4) };
        // Assumes packet data starts at byte 8 (see the module note); this targets the
        // version byte that follows the prefix.
        u8_slice[8 + msg_queue.len_p()] = 2;
        assert_eq!(
            msg_queue.read_or_fail(),
            Err(MqError::MqWrongProtocolVersion)
        );
        msg_queue.write_blocking(msg).unwrap();
        assert_eq!(msg_queue.read_or_fail().unwrap(), msg);
    }
}