use core::borrow::Borrow;
use crate::util::{Buffer, OutOfMemory, CRC_X25};
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
/// Tracks how many zero bytes must follow the payload to reach a multiple of four.
///
/// The stored value is kept in `0..=3` and is decremented modulo four for every
/// payload byte, so after `k` bytes it equals `(4 - k % 4) % 4`.
struct Padding(u8);

impl Padding {
    /// Creates a tracker for an empty payload, which needs no padding.
    const fn new() -> Self {
        Self(0)
    }

    /// Records one more payload byte.
    fn bump(&mut self) {
        // Adding 3 modulo 4 is the same as subtracting 1 modulo 4.
        self.0 = (self.0 + 3) % 4;
    }

    /// Number of padding bytes (`0..=3`) required after the bytes recorded so far.
    const fn get(&self) -> u8 {
        self.0 % 4
    }
}
/// Position of the streaming [`Encoder`] within the encoded output.
#[derive(Clone, Copy, Debug)]
enum EncoderState {
    /// Emitting the start sequence; the value is the number of bytes already emitted (`0..=8`).
    Init(u8),
    /// Copying payload bytes; the value counts the consecutive `0x1b` bytes just copied.
    LookingForEscape(u8),
    /// Emitting the four extra `0x1b` bytes of an escape; the value counts those already emitted.
    HandlingEscape(u8),
    /// Emitting the trailer. Negative values are padding zeros still to come; `0..=3`
    /// emit `0x1b`, `4` emits `0x1a`, `5` the padding count, `6..=7` the CRC, `8` ends.
    End(i8),
}
/// Streaming SML transport encoder, as returned by [`encode_streaming`].
///
/// Iterating it yields the encoded transmission byte by byte: start sequence,
/// escaped payload, padding, end sequence and checksum.
pub struct Encoder<I: Iterator<Item = u8>> {
    /// Current position within the encoded output.
    state: EncoderState,
    /// Running checksum over every byte emitted before the CRC itself.
    crc: crc::Digest<'static, u16>,
    /// Tracks the padding required by the payload bytes read so far.
    padding: Padding,
    /// Source of payload bytes.
    iter: I,
}
impl<I: Iterator<Item = u8>> Encoder<I> {
    /// Creates an encoder that lazily encodes the bytes produced by `iter`.
    pub fn new(iter: I) -> Self {
        let mut crc = CRC_X25.digest();
        // The `Init` states emit the start sequence without touching the CRC,
        // so it is accounted for here up front.
        crc.update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01]);
        Self {
            iter,
            padding: Padding::new(),
            crc,
            state: EncoderState::Init(0),
        }
    }

    /// Pulls the next payload byte and records it for the padding computation.
    fn read_from_iter(&mut self) -> Option<u8> {
        let byte = self.iter.next()?;
        self.padding.bump();
        Some(byte)
    }

    /// Switches to `state` without emitting anything and immediately produces
    /// the next byte from it; returns that byte and the resulting state.
    fn next_from_state(&mut self, state: EncoderState) -> (Option<u8>, EncoderState) {
        self.state = state;
        let byte = self.next();
        (byte, self.state)
    }
}
impl<I> Iterator for Encoder<I>
where
    I: Iterator<Item = u8>,
{
    type Item = u8;

    /// Emits the next byte of the encoded transmission, or `None` once the
    /// trailing checksum has been written.
    fn next(&mut self) -> Option<u8> {
        use EncoderState::*;

        let (byte, following) = match self.state {
            // Start sequence: four escape bytes followed by four `0x01` bytes.
            Init(idx @ 0..=3) => (Some(0x1b), Init(idx + 1)),
            Init(idx @ 4..=7) => (Some(0x01), Init(idx + 1)),
            Init(idx) => {
                assert_eq!(idx, 8);
                self.next_from_state(LookingForEscape(0))
            }
            // Payload: `run` counts the consecutive `0x1b` bytes just emitted.
            LookingForEscape(run @ 0..=3) => match self.read_from_iter() {
                Some(data) => {
                    self.crc.update(&[data]);
                    let run = if data == 0x1b { run + 1 } else { 0 };
                    (Some(data), LookingForEscape(run))
                }
                None => {
                    // Input exhausted: fold padding and the end escape sequence
                    // into the checksum before emitting them.
                    let padding = self.padding.get();
                    self.crc.update(&[0x00; 3][..usize::from(padding)]);
                    self.crc.update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x1a, padding]);
                    self.next_from_state(End(-(padding as i8)))
                }
            },
            LookingForEscape(run) => {
                assert_eq!(run, 4);
                // Four payload `0x1b` bytes are escaped by repeating them.
                self.crc.update(&[0x1b; 4]);
                self.next_from_state(HandlingEscape(0))
            }
            HandlingEscape(idx @ 0..=3) => (Some(0x1b), HandlingEscape(idx + 1)),
            HandlingEscape(idx) => {
                assert_eq!(idx, 4);
                self.next_from_state(LookingForEscape(0))
            }
            // Trailer: negative positions are padding zeros, followed by
            // `1b1b1b1b 1a <padding> <crc lo> <crc hi>`.
            End(pos) => {
                let byte = match pos {
                    i8::MIN..=-1 => 0x00,
                    0..=3 => 0x1b,
                    4 => 0x1a,
                    5 => self.padding.get(),
                    6 | 7 => self.crc.clone().finalize().to_le_bytes()[(pos - 6) as usize],
                    8 => return None,
                    _ => unreachable!(),
                };
                (Some(byte), End(pos + 1))
            }
        };
        self.state = following;
        byte
    }
}
/// Encodes the payload bytes from `iter` as a complete SML transport message.
///
/// The result consists of the start sequence `1b1b1b1b 01010101`, the payload
/// (every run of four `0x1b` bytes is escaped by repeating it), zero padding up
/// to a multiple of four bytes, the end sequence `1b1b1b1b 1a <padding>` and a
/// little-endian `CRC_X25` checksum over everything before it.
///
/// # Errors
///
/// Returns [`OutOfMemory`] if the encoded message does not fit into `B`.
#[cfg_attr(
    feature = "alloc",
    doc = r##"
### Using alloc::Vec
```
# use sml_rs::transport::encode;
# let bytes = [0x12, 0x34, 0x56, 0x78];
# let expected = [0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01, 0x12, 0x34, 0x56, 0x78, 0x1b, 0x1b, 0x1b, 0x1b, 0x1a, 0x00, 0xb8, 0x7b];
let encoded = encode::<Vec<u8>>(&bytes);
assert!(encoded.is_ok());
assert_eq!(encoded.unwrap().as_slice(), &expected);
```
"##
)]
pub fn encode<B: Buffer>(
    iter: impl IntoIterator<Item = impl Borrow<u8>>,
) -> Result<B, OutOfMemory> {
    let mut out = B::default();
    out.extend_from_slice(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01])?;

    // Length of the current run of consecutive `0x1b` payload bytes.
    let mut escape_run = 0;
    for item in iter {
        let byte = *item.borrow();
        escape_run = if byte == 0x1b { escape_run + 1 } else { 0 };
        out.push(byte)?;
        if escape_run == 4 {
            out.extend_from_slice(&[0x1b; 4])?;
            escape_run = 0;
        }
    }

    let padding = (4 - out.len() % 4) % 4;
    out.extend_from_slice(&[0x00; 3][..padding])?;
    out.extend_from_slice(&[0x1b, 0x1b, 0x1b, 0x1b, 0x1a, padding as u8])?;

    let checksum = CRC_X25.checksum(&out[..]);
    out.extend_from_slice(&checksum.to_le_bytes())?;
    Ok(out)
}
/// Lazily encodes the payload bytes from `iter` as an SML transport message.
///
/// The returned [`Encoder`] yields the encoded bytes one at a time, so no
/// output buffer is needed.
pub fn encode_streaming(
    iter: impl IntoIterator<Item = impl Borrow<u8>>,
) -> Encoder<impl Iterator<Item = u8>> {
    let payload = iter.into_iter().map(|item| *item.borrow());
    Encoder::new(payload)
}
/// Errors reported while decoding SML transport messages.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DecodeErr {
    /// This many bytes were skipped because they did not belong to a complete
    /// message: garbage before a start sequence, a message interrupted by a new
    /// start sequence, or an incomplete message at the end of the input.
    DiscardedBytes(usize),
    /// An escape sequence was followed by these four bytes, which are not a
    /// recognized escape payload.
    InvalidEsc([u8; 4]),
    /// The decoded message did not fit into the decoder's buffer.
    OutOfMemory,
    /// An end escape sequence was found, but the message failed validation
    /// (checksum mismatch, misaligned end sequence or invalid padding count).
    InvalidMessage {
        /// `(checksum read from the message, checksum calculated by the decoder)`;
        /// both values are reported even when they are equal.
        checksum_mismatch: (u16, u16),
        /// `true` if the message did not end on a four-byte boundary.
        end_esc_misaligned: bool,
        /// Padding byte count announced by the end escape sequence.
        num_padding_bytes: u8,
    },
}
/// Internal state of the [`Decoder`] state machine.
#[derive(Debug)]
enum DecodeState {
    /// Scanning for the start sequence `1b1b1b1b 01010101`.
    LookingForMessageStart {
        /// Bytes dropped so far because they cannot start a message.
        num_discarded_bytes: u16,
        /// Length of the start-sequence prefix matched so far.
        num_init_seq_bytes: u8,
    },
    /// Inside a message with no pending `0x1b` bytes.
    ParsingNormal,
    /// Inside a message after this many consecutive `0x1b` bytes (not yet buffered).
    ParsingEscChars(u8),
    /// Collecting the payload of an escape sequence; the value is the number
    /// of payload bytes collected so far.
    ParsingEscPayload(u8),
    /// A complete message is available in the buffer.
    Done,
}
/// Incremental decoder for SML transport messages, fed one byte at a time.
///
/// Decoded messages are assembled in a buffer of type `B`.
pub struct Decoder<B>
where
    B: Buffer,
{
    /// Payload bytes of the message currently being decoded.
    buf: B,
    /// Raw bytes received for the current message (or since the last reset).
    raw_msg_len: usize,
    /// Running checksum over the raw bytes preceding `buf[crc_idx..]`.
    crc: crc::Digest<'static, u16>,
    /// Index into `buf` up to which bytes have already been fed into `crc`.
    crc_idx: usize,
    /// Current state of the decoding state machine.
    state: DecodeState,
}
impl<B: Buffer> Default for Decoder<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Buffer> Decoder<B> {
    /// Creates a decoder backed by a default-constructed buffer.
    #[must_use]
    pub fn new() -> Self {
        Self::from_buf(Default::default())
    }

    /// Creates a decoder that assembles messages in `buf`.
    ///
    /// Any content already present in `buf` is cleared; only its storage is reused.
    pub fn from_buf(mut buf: B) -> Self {
        // `_push_byte` asserts that the buffer is empty when a message starts, so
        // leftover content supplied by the caller would otherwise panic later.
        buf.clear();
        Decoder {
            buf,
            raw_msg_len: 0,
            crc: CRC_X25.digest(),
            crc_idx: 0,
            state: DecodeState::LookingForMessageStart {
                num_discarded_bytes: 0,
                num_init_seq_bytes: 0,
            },
        }
    }

    /// Feeds one raw byte into the decoder.
    ///
    /// Returns `Ok(Some(msg))` when `b` completes a valid message (`msg` borrows
    /// the decoder's buffer) and `Ok(None)` while more input is needed.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodeErr`] when bytes had to be discarded, an escape sequence
    /// or message was invalid, or the buffer ran out of space. Decoding can
    /// continue with the next byte afterwards.
    pub fn push_byte(&mut self, b: u8) -> Result<Option<&[u8]>, DecodeErr> {
        self._push_byte(b)
            .map(|complete| if complete { Some(self.borrow_buf()) } else { None })
    }

    /// Signals the end of the input and resets the decoder.
    ///
    /// Returns [`DecodeErr::DiscardedBytes`] if bytes were received that did not
    /// form a complete message, and `None` otherwise.
    pub fn finalize(&mut self) -> Option<DecodeErr> {
        use DecodeState::*;
        let res = match self.state {
            Done => None,
            // Nothing has been received since the last reset.
            LookingForMessageStart { .. } if self.raw_msg_len == 0 => None,
            // While searching for a start sequence, `raw_msg_len` counts every byte
            // since the last reset; inside a message, every byte since its start.
            // Unlike the `u16` counter in `LookingForMessageStart`, it cannot saturate.
            _ => Some(DecodeErr::DiscardedBytes(self.raw_msg_len)),
        };
        self.reset();
        res
    }

    /// Advances the state machine by one byte.
    ///
    /// Returns `Ok(true)` when a complete message is now available in the buffer.
    fn _push_byte(&mut self, b: u8) -> Result<bool, DecodeErr> {
        use DecodeState::*;
        self.raw_msg_len += 1;
        match self.state {
            LookingForMessageStart {
                ref mut num_discarded_bytes,
                ref mut num_init_seq_bytes,
            } => {
                let matched = *num_init_seq_bytes;
                // `dropped`: bytes that can no longer belong to a start sequence;
                // `kept`: length of the start-sequence prefix still being matched.
                let (dropped, kept) = match (b, matched) {
                    (0x1b, 0..=3) | (0x01, 4..=7) => (0, matched + 1),
                    // More than four consecutive 0x1b: only the oldest one is
                    // garbage, the latest four may still open a start sequence.
                    (0x1b, 4) => (1, 4),
                    // `1b1b1b1b 01..` interrupted by 0x1b: this byte may open a new
                    // start sequence, so it is kept.
                    (0x1b, _) => (matched, 1),
                    _ => (matched + 1, 0),
                };
                // Saturate rather than overflow on long runs of garbage; the exact
                // count reported to callers is derived from `raw_msg_len`.
                *num_discarded_bytes = num_discarded_bytes.saturating_add(u16::from(dropped));
                *num_init_seq_bytes = kept;
                if kept == 8 {
                    // `raw_msg_len` covers all bytes since the last reset: the
                    // discarded ones plus the eight start-sequence bytes.
                    let num_discarded_bytes = self.raw_msg_len - 8;
                    self.state = ParsingNormal;
                    self.raw_msg_len = 8;
                    assert_eq!(self.buf.len(), 0);
                    assert_eq!(self.crc_idx, 0);
                    self.crc = CRC_X25.digest();
                    self.crc
                        .update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01]);
                    if num_discarded_bytes > 0 {
                        return Err(DecodeErr::DiscardedBytes(num_discarded_bytes));
                    }
                }
            }
            ParsingNormal => {
                if b == 0x1b {
                    self.state = ParsingEscChars(1);
                } else {
                    self.push(b)?;
                }
            }
            ParsingEscChars(n) => {
                if b != 0x1b {
                    // Fewer than four 0x1b in a row: they were plain payload bytes.
                    for _ in 0..n {
                        self.push(0x1b)?;
                    }
                    self.push(b)?;
                    self.state = ParsingNormal;
                } else if n == 3 {
                    // Fourth 0x1b: an escape sequence. Flush pending payload bytes
                    // and the escape itself into the checksum.
                    self.crc.update(&self.buf[self.crc_idx..self.buf.len()]);
                    self.crc.update(&[0x1b, 0x1b, 0x1b, 0x1b]);
                    self.crc_idx = self.buf.len();
                    self.state = ParsingEscPayload(0);
                } else {
                    self.state = ParsingEscChars(n + 1);
                }
            }
            ParsingEscPayload(n) => {
                self.push(b)?;
                if n < 3 {
                    self.state = ParsingEscPayload(n + 1);
                } else {
                    let payload = &self.buf[self.buf.len() - 4..self.buf.len()];
                    if payload == [0x1b, 0x1b, 0x1b, 0x1b] {
                        // Escaped payload bytes: keep them in the buffer.
                        self.state = ParsingNormal;
                    } else if payload == [0x01, 0x01, 0x01, 0x01] {
                        // A new start sequence aborts the current message.
                        let ignored_bytes = self.raw_msg_len - 8;
                        self.raw_msg_len = 8;
                        self.buf.clear();
                        self.crc = CRC_X25.digest();
                        self.crc
                            .update(&[0x1b, 0x1b, 0x1b, 0x1b, 0x01, 0x01, 0x01, 0x01]);
                        self.crc_idx = 0;
                        self.state = ParsingNormal;
                        return Err(DecodeErr::DiscardedBytes(ignored_bytes));
                    } else if payload[0] == 0x1a {
                        // End sequence: `1a <padding> <crc lo> <crc hi>`.
                        let num_padding_bytes = payload[1];
                        let read_crc = u16::from_le_bytes([payload[2], payload[3]]);
                        self.crc
                            .update(&self.buf[self.crc_idx..(self.buf.len() - 2)]);
                        let calculated_crc = {
                            let mut crc = CRC_X25.digest();
                            core::mem::swap(&mut crc, &mut self.crc);
                            crc.finalize()
                        };
                        let misaligned = self.buf.len() % 4 != 0;
                        let padding_too_large = num_padding_bytes > 3
                            || (num_padding_bytes as usize + 4) > self.buf.len();
                        if read_crc != calculated_crc || misaligned || padding_too_large {
                            self.reset();
                            return Err(DecodeErr::InvalidMessage {
                                checksum_mismatch: (read_crc, calculated_crc),
                                end_esc_misaligned: misaligned,
                                num_padding_bytes,
                            });
                        }
                        // Strip the padding and the four end-sequence payload bytes.
                        self.buf
                            .truncate(self.buf.len() - num_padding_bytes as usize - 4);
                        self.set_done();
                        return Ok(true);
                    } else {
                        // The payload may have ended with 1..=3 0x1b bytes directly
                        // before the end escape, which shifts the escape detection.
                        // Re-align and keep reading the end sequence.
                        let bytes_until_alignment = (4 - (self.buf.len() % 4)) % 4;
                        if bytes_until_alignment > 0
                            && payload[..bytes_until_alignment].iter().all(|x| *x == 0x1b)
                            && payload[bytes_until_alignment] == 0x1a
                        {
                            self.state = ParsingEscPayload(4 - bytes_until_alignment as u8);
                            return Ok(false);
                        }
                        let esc_bytes: [u8; 4] = payload.try_into().unwrap();
                        self.reset();
                        return Err(DecodeErr::InvalidEsc(esc_bytes));
                    }
                }
            }
            Done => {
                self.reset();
                return self._push_byte(b);
            }
        }
        Ok(false)
    }

    /// Returns the completed message.
    ///
    /// # Panics
    ///
    /// Panics unless the decoder is in the `Done` state.
    fn borrow_buf(&self) -> &[u8] {
        if !matches!(self.state, DecodeState::Done) {
            panic!("Reading from the internal buffer is only allowed when a complete message is present (DecodeState::Done). Found state {:?}.", self.state);
        }
        &self.buf[..self.buf.len()]
    }

    /// Marks the buffered message as complete.
    fn set_done(&mut self) {
        self.state = DecodeState::Done;
    }

    /// Discards all progress and starts searching for a new message.
    fn reset(&mut self) {
        self.state = DecodeState::LookingForMessageStart {
            num_discarded_bytes: 0,
            num_init_seq_bytes: 0,
        };
        self.buf.clear();
        self.crc_idx = 0;
        self.raw_msg_len = 0;
    }

    /// Appends `b` to the buffer, resetting the decoder if it is full.
    fn push(&mut self, b: u8) -> Result<(), DecodeErr> {
        if self.buf.push(b).is_err() {
            self.reset();
            return Err(DecodeErr::OutOfMemory);
        }
        Ok(())
    }
}
/// Decodes every SML transport message contained in `iter`.
///
/// Returns one entry per decoded message or decoding error, in input order. If
/// the input ends with bytes that do not form a complete message, the last
/// entry is a [`DecodeErr::DiscardedBytes`] error.
#[cfg(feature = "alloc")]
#[must_use]
pub fn decode(
    iter: impl IntoIterator<Item = impl Borrow<u8>>,
) -> Vec<Result<Vec<u8>, DecodeErr>> {
    let mut decoder = Decoder::<Vec<u8>>::new();
    let mut results: Vec<_> = iter
        .into_iter()
        .filter_map(|item| match decoder.push_byte(*item.borrow()) {
            Ok(None) => None,
            Ok(Some(msg)) => Some(Ok(msg.to_vec())),
            Err(err) => Some(Err(err)),
        })
        .collect();
    results.extend(decoder.finalize().map(Err));
    results
}
/// Streaming decoder returned by [`decode_streaming`].
///
/// Call [`DecodeIterator::next`] repeatedly to obtain decoded messages and errors.
pub struct DecodeIterator<B, I>
where
    B: Buffer,
    I: Iterator<Item = u8>,
{
    /// Decoder assembling the current message.
    decoder: Decoder<B>,
    /// Remaining raw input bytes.
    bytes: I,
    /// Set once `bytes` is exhausted and the decoder has been finalized.
    done: bool,
}
impl<B, I> DecodeIterator<B, I>
where
    B: Buffer,
    I: Iterator<Item = u8>,
{
    fn new(bytes: I) -> Self {
        Self {
            decoder: Decoder::new(),
            bytes,
            done: false,
        }
    }

    /// Returns the next decoded message or decoding error, or `None` once the
    /// input is exhausted.
    ///
    /// Messages are borrowed from the decoder's buffer, which is why this type
    /// provides an inherent `next` instead of implementing `Iterator`.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&mut self) -> Option<Result<&[u8], DecodeErr>> {
        if self.done {
            return None;
        }
        for byte in self.bytes.by_ref() {
            match self.decoder._push_byte(byte) {
                Ok(false) => {}
                Ok(true) => return Some(Ok(self.decoder.borrow_buf())),
                Err(err) => return Some(Err(err)),
            }
        }
        // Input exhausted: report any incomplete trailing data exactly once.
        self.done = true;
        self.decoder.finalize().map(Err)
    }
}
/// Lazily decodes SML transport messages from `iter`.
///
/// Decoded messages are borrowed from the decoder's internal buffer of type `B`
/// rather than copied into new vectors.
pub fn decode_streaming<B: Buffer>(
    iter: impl IntoIterator<Item = impl Borrow<u8>>,
) -> DecodeIterator<B, impl Iterator<Item = u8>> {
    let raw = iter.into_iter().map(|item| *item.borrow());
    DecodeIterator::new(raw)
}
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;

    /// `assert_eq!` variant that prints both operands as hex (`{:02x?}`) on failure.
    macro_rules! assert_eq_hex {
        ($left:expr, $right:expr $(,)?) => {{
            let (lhs, rhs) = (&$left, &$right);
            if *lhs != *rhs {
                panic!(
                    "assertion failed: `(left == right)`\n left: `{:02x?}`,\n right: `{:02x?}`",
                    lhs, rhs
                )
            }
        }};
    }

    /// Encodes `bytes` with both encoders and compares against `exp_encoded_bytes`;
    /// with `alloc`, also checks that decoding the expected bytes round-trips.
    fn test_encoding<const N: usize>(bytes: &[u8], exp_encoded_bytes: &[u8; N]) {
        let buffered = encode::<crate::util::ArrayBuf<N>>(bytes).expect("ran out of memory");
        compare_encoded_bytes(exp_encoded_bytes, &buffered);

        let streamed: crate::util::ArrayBuf<N> = encode_streaming(bytes).collect();
        compare_encoded_bytes(exp_encoded_bytes, &streamed);

        #[cfg(feature = "alloc")]
        assert_eq_hex!(alloc::vec![Ok(bytes.to_vec())], decode(exp_encoded_bytes));
    }

    /// Compares two encodings, reporting a mismatch in hex.
    fn compare_encoded_bytes(expected: &[u8], actual: &[u8]) {
        assert_eq_hex!(expected, actual);
    }

    #[test]
    fn basic() {
        test_encoding(
            &hex!("12345678"),
            &hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b"),
        );
    }

    #[test]
    fn empty() {
        test_encoding(&hex!(""), &hex!("1b1b1b1b 01010101 1b1b1b1b 1a00c6e5"));
    }

    #[test]
    fn padding() {
        test_encoding(
            &hex!("123456"),
            &hex!("1b1b1b1b 01010101 12345600 1b1b1b1b 1a0191a5"),
        );
    }

    #[test]
    fn escape_in_user_data() {
        test_encoding(
            &hex!("121b1b1b1b"),
            &hex!("1b1b1b1b 01010101 12 1b1b1b1b 1b1b1b1b 000000 1b1b1b1b 1a03be25"),
        );
    }

    #[test]
    fn almost_escape_in_user_data() {
        test_encoding(
            &hex!("121b1b1bFF"),
            &hex!("1b1b1b1b 01010101 12 1b1b1bFF 000000 1b1b1b1b 1a0324d9"),
        );
    }

    #[test]
    fn ending_with_1b_no_padding() {
        test_encoding(
            &hex!("12345678 12341b1b"),
            &hex!("1b1b1b1b 01010101 12345678 12341b1b 1b1b1b1b 1a001ac5"),
        );
    }
}
#[cfg(test)]
mod decode_tests {
use super::*;
use crate::util::ArrayBuf;
use hex_literal::hex;
use DecodeErr::*;
fn test_parse_input<B: Buffer>(bytes: &[u8], exp: &[Result<&[u8], DecodeErr>]) {
let mut exp_iter = exp.iter();
let mut streaming_decoder = DecodeIterator::<B, _>::new(bytes.iter().cloned());
while let Some(res) = streaming_decoder.next() {
match exp_iter.next() {
Some(exp) => {
assert_eq!(res, *exp);
}
None => {
panic!("Additional decoded item: {:?}", res);
}
}
}
assert_eq!(exp_iter.next(), None);
let mut decoder = Decoder::<B>::new();
let mut streaming_decoder = DecodeIterator::<B, _>::new(bytes.iter().cloned());
for b in bytes {
let res = decoder.push_byte(*b);
if let Ok(None) = res {
continue;
}
let res2 = streaming_decoder.next();
match (res, res2) {
(Ok(Some(a)), Some(Ok(b))) => assert_eq!(a, b),
(Err(a), Some(Err(b))) => assert_eq!(a, b),
(a, b) => panic!(
"Mismatch between decoder and streaming_decoder: {:?} vs. {:?}",
a, b
),
}
}
match (decoder.finalize(), streaming_decoder.next()) {
(None, None) => {},
(Some(a), Some(Err(b))) => assert_eq!(a, b),
(a, b) => panic!("Mismatch between decoder and streaming_decoder on the final element: {:?} vs. {:?}", a, b),
}
}
#[test]
fn basic() {
let bytes = hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b");
let exp = &[Ok(hex!("12345678").as_slice())];
test_parse_input::<ArrayBuf<8>>(&bytes, exp);
}
#[test]
fn out_of_memory() {
let bytes = hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b");
let exp = &[Err(DecodeErr::OutOfMemory)];
test_parse_input::<ArrayBuf<7>>(&bytes, exp);
}
#[test]
fn invalid_crc() {
let bytes = hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b8FF");
let exp = &[Err(InvalidMessage {
checksum_mismatch: (0xFFb8, 0x7bb8),
end_esc_misaligned: false,
num_padding_bytes: 0,
})];
test_parse_input::<ArrayBuf<8>>(&bytes, exp);
}
#[test]
fn msg_end_misaligned() {
let bytes = hex!("1b1b1b1b 01010101 12345678 FF 1b1b1b1b 1a0013b6");
let exp = &[Err(InvalidMessage {
checksum_mismatch: (0xb613, 0xb613),
end_esc_misaligned: true,
num_padding_bytes: 0,
})];
test_parse_input::<ArrayBuf<16>>(&bytes, exp);
}
#[test]
fn padding_too_large() {
let bytes = hex!("1b1b1b1b 01010101 12345678 12345678 1b1b1b1b 1a04f950");
let exp = &[Err(InvalidMessage {
checksum_mismatch: (0x50f9, 0x50f9),
end_esc_misaligned: false,
num_padding_bytes: 4,
})];
test_parse_input::<ArrayBuf<16>>(&bytes, exp);
}
#[test]
fn empty_msg_with_padding() {
let bytes = hex!("1b1b1b1b 01010101 1b1b1b1b 1a014FF4");
let exp = &[Err(InvalidMessage {
checksum_mismatch: (0xf44f, 0xf44f),
end_esc_misaligned: false,
num_padding_bytes: 1,
})];
test_parse_input::<ArrayBuf<16>>(&bytes, exp);
}
#[test]
fn additional_bytes() {
let bytes = hex!("000102 1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b 1234");
let exp = &[
Err(DiscardedBytes(3)),
Ok(hex!("12345678").as_slice()),
Err(DiscardedBytes(2)),
];
test_parse_input::<ArrayBuf<128>>(&bytes, exp);
}
#[test]
fn incomplete_message() {
let bytes = hex!("1b1b1b1b 01010101 123456");
let exp = &[Err(DiscardedBytes(11))];
test_parse_input::<ArrayBuf<128>>(&bytes, exp);
}
#[test]
fn invalid_esc_sequence() {
let bytes = hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1c000000 12345678 1b1b1b1b 1a03be25");
let exp = &[
Err(InvalidEsc([0x1c, 0x0, 0x0, 0x0])),
Err(DiscardedBytes(12)),
];
test_parse_input::<ArrayBuf<128>>(&bytes, exp);
}
#[test]
fn incomplete_esc_sequence() {
let bytes = hex!("1b1b1b1b 01010101 12345678 1b1b1b00 12345678 1b1b1b1b 1a030A07");
let exp = &[Ok(hex!("12345678 1b1b1b00 12").as_slice())];
test_parse_input::<ArrayBuf<128>>(&bytes, exp);
}
#[test]
fn double_msg_start() {
let bytes =
hex!("1b1b1b1b 01010101 09 87654321 1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b");
let exp = &[Err(DiscardedBytes(13)), Ok(hex!("12345678").as_slice())];
test_parse_input::<ArrayBuf<128>>(&bytes, exp);
}
#[test]
fn padding() {
let bytes = hex!("1b1b1b1b 01010101 12345600 1b1b1b1b 1a0191a5");
let exp_bytes = hex!("123456");
let exp = &[Ok(exp_bytes.as_slice())];
test_parse_input::<ArrayBuf<128>>(&bytes, exp);
}
#[test]
fn escape_in_user_data() {
let bytes = hex!("1b1b1b1b 01010101 12 1b1b1b1b 1b1b1b1b 000000 1b1b1b1b 1a03be25");
let exp = &[Ok(hex!("121b1b1b1b").as_slice())];
test_parse_input::<ArrayBuf<128>>(&bytes, exp);
}
#[test]
fn ending_with_1b_no_padding_1() {
let bytes = hex!("1b1b1b1b 01010101 12345678 1234561b 1b1b1b1b 1a00361a");
let exp_bytes = hex!("12345678 1234561b");
let exp = &[Ok(exp_bytes.as_slice())];
test_parse_input::<ArrayBuf<128>>(&bytes, exp);
}
#[test]
fn ending_with_1b_no_padding_2() {
let bytes = hex!("1b1b1b1b 01010101 12345678 12341b1b 1b1b1b1b 1a001ac5");
let exp_bytes = hex!("12345678 12341b1b");
let exp = &[Ok(exp_bytes.as_slice())];
test_parse_input::<ArrayBuf<128>>(&bytes, exp);
}
#[test]
fn ending_with_1b_no_padding_3() {
let bytes = hex!("1b1b1b1b 01010101 12345678 121b1b1b 1b1b1b1b 1a000ba4");
let exp_bytes = hex!("12345678 121b1b1b");
let exp = &[Ok(exp_bytes.as_slice())];
test_parse_input::<ArrayBuf<128>>(&bytes, exp);
}
#[cfg(feature = "alloc")]
#[test]
fn alloc_basic() {
let bytes = hex!("1b1b1b1b 01010101 12345678 1b1b1b1b 1a00b87b");
let exp = &[Ok(hex!("12345678").as_slice())];
test_parse_input::<Vec<u8>>(&bytes, exp);
}
}