xsens_mti/
decoder.rs

//! A basic, byte-at-a-time MT protocol frame decoder.
2
3use crate::message::{Frame, FrameError, PayloadLength};
4use core::mem;
5
6#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, thiserror::Error)]
7pub enum Error {
8    #[error("Not enough bytes in the decoder buffer to store the frame")]
9    InsufficientBufferSize,
10
11    #[error("Encountered a framing error")]
12    FrameError(#[from] FrameError),
13}
14
/// Position of the decoder within the frame currently being received.
///
/// Variants are listed in wire order; the derived `Ord` follows that order.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum State {
    /// Waiting for the preamble byte (initial state).
    #[default]
    Preamble,
    /// Expecting the bus identifier byte.
    BusId,
    /// Expecting the message identifier byte.
    MsgId,
    /// Expecting the standard length byte.
    Len,
    /// Expecting the most significant byte of an extended length.
    ExtLenMsb,
    /// Expecting the least significant byte of an extended length.
    ExtLenLsb,
    /// Receiving payload bytes.
    Payload,
    /// Expecting the checksum byte.
    Checksum,
}
27
28#[derive(Debug)]
29pub struct Decoder<B: AsRef<[u8]> + AsMut<[u8]>> {
30    state: State,
31    count: usize,
32    invalid_count: usize,
33    accumulated_checksum: u16,
34    raw_payload_len: u16,
35    expected_frame_size: usize,
36    bytes_read: usize,
37    buffer: B,
38}
39
40impl<B: AsRef<[u8]> + AsMut<[u8]>> Decoder<B> {
41    pub fn new(buffer: B) -> Result<Self, Error> {
42        Self::check_buffer(&buffer)?;
43        Ok(Decoder {
44            state: State::default(),
45            count: 0,
46            invalid_count: 0,
47            accumulated_checksum: 0,
48            raw_payload_len: 0,
49            expected_frame_size: 0,
50            bytes_read: 0,
51            buffer,
52        })
53    }
54
55    pub fn reset(&mut self) {
56        self.state = State::default();
57        self.accumulated_checksum = 0;
58        self.raw_payload_len = 0;
59        self.expected_frame_size = 0;
60        self.bytes_read = 0;
61    }
62
63    pub fn count(&self) -> usize {
64        self.count
65    }
66
67    pub fn invalid_count(&self) -> usize {
68        self.invalid_count
69    }
70
71    pub fn swap_buffer(&mut self, new_buffer: B) -> Result<B, Error> {
72        Self::check_buffer(&new_buffer)?;
73        self.reset();
74        Ok(mem::replace(&mut self.buffer, new_buffer))
75    }
76
77    fn check_buffer(buffer: &B) -> Result<(), Error> {
78        if buffer.as_ref().len() < Frame::<&[u8]>::HEADER_SIZE + (PayloadLength::MAX_STD as usize) {
79            Err(Error::InsufficientBufferSize)
80        } else {
81            Ok(())
82        }
83    }
84
85    pub fn decode_frameless(&mut self, byte: u8) -> Result<Option<usize>, Error> {
86        match self.decode(byte)? {
87            None => Ok(None),
88            Some(f) => {
89                let buf = f.into_inner();
90                Ok(buf.len().into())
91            }
92        }
93    }
94
95    pub fn decode(&mut self, byte: u8) -> Result<Option<Frame<&[u8]>>, Error> {
96        match self.decode_inner(byte)? {
97            None => Ok(None),
98            Some(frame_size) => match Frame::new(&self.buffer.as_ref()[..frame_size]) {
99                Ok(f) => {
100                    self.count = self.count.saturating_add(1); // inc_count()
101                    Ok(Some(f))
102                }
103                Err(e) => {
104                    self.invalid_count = self.invalid_count.saturating_add(1); // inc_invalid_count()
105                    Err(e.into())
106                }
107            },
108        }
109    }
110
111    fn decode_inner(&mut self, byte: u8) -> Result<Option<usize>, Error> {
112        match self.state {
113            State::Preamble => {
114                if byte == Frame::<&[u8]>::PREAMBLE {
115                    self.feed(byte)?;
116                    // Checksum doesn't include preamble
117                    self.accumulated_checksum = 0;
118                    self.state = State::BusId;
119                } else {
120                    self.reset();
121                }
122            }
123            State::BusId => {
124                self.feed(byte)?;
125                self.state = State::MsgId;
126            }
127            State::MsgId => {
128                self.feed(byte)?;
129                self.state = State::Len;
130            }
131            State::Len => {
132                self.feed(byte)?;
133                if byte == 0 {
134                    // Message with no payload
135                    self.raw_payload_len = 0;
136                    self.expected_frame_size =
137                        Frame::<&[u8]>::HEADER_SIZE + Frame::<&[u8]>::CHECKSUM_SIZE;
138                    self.state = State::Checksum;
139                } else if byte == Frame::<&[u8]>::STD_LEN_IS_EXT {
140                    // Message with extended payload
141                    self.state = State::ExtLenMsb;
142                } else {
143                    // Message with standard payload
144                    self.raw_payload_len = byte as u16;
145                    self.expected_frame_size = Frame::<&[u8]>::HEADER_SIZE
146                        + Frame::<&[u8]>::CHECKSUM_SIZE
147                        + (byte as usize);
148                    self.state = State::Payload;
149                }
150            }
151            State::ExtLenMsb => {
152                self.feed(byte)?;
153                self.raw_payload_len = byte as u16;
154                self.state = State::ExtLenLsb;
155            }
156            State::ExtLenLsb => {
157                self.feed(byte)?;
158                // Msb stored in self.raw_payload_len in State::ExtLenMsb
159                self.raw_payload_len = u16::from_be_bytes([self.raw_payload_len as u8, byte]);
160                if self.raw_payload_len > PayloadLength::MAX_EXT {
161                    self.reset();
162                    self.inc_invalid_count();
163                } else {
164                    self.expected_frame_size = Frame::<&[u8]>::EXT_HEADER_SIZE
165                        + Frame::<&[u8]>::CHECKSUM_SIZE
166                        + (self.raw_payload_len as usize);
167                }
168                self.state = State::Payload;
169            }
170            State::Payload => {
171                self.feed(byte)?;
172                if self.bytes_read.saturating_add(1) >= self.expected_frame_size {
173                    self.state = State::Checksum;
174                }
175            }
176            State::Checksum => {
177                self.feed(byte)?;
178                let accumulated_checksum = self.accumulated_checksum;
179                let bytes_read = self.bytes_read;
180                self.reset();
181                if accumulated_checksum.trailing_zeros() >= 8 {
182                    return Ok(Some(bytes_read));
183                } else {
184                    self.inc_invalid_count();
185                }
186            }
187        }
188        Ok(None)
189    }
190
191    #[inline]
192    fn feed(&mut self, byte: u8) -> Result<(), Error> {
193        if self.bytes_read >= self.buffer.as_ref().len() {
194            Err(Error::InsufficientBufferSize)
195        } else {
196            self.accumulated_checksum = self.accumulated_checksum.wrapping_add(byte as u16);
197            self.buffer.as_mut()[self.bytes_read] = byte;
198            self.bytes_read = self.bytes_read.saturating_add(1);
199            Ok(())
200        }
201    }
202
203    #[inline]
204    fn inc_invalid_count(&mut self) {
205        self.invalid_count = self.invalid_count.saturating_add(1);
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// A standard frame: preamble, bus id, message id, length (3), three payload bytes
    /// and a checksum byte.
    static STD_MSG: [u8; 8] = [0xFA, 0xFF, 0x00, 0x03, 0x0A, 0x0B, 0x0C, 0xDD];

    #[test]
    fn basic_decoding() {
        let mut storage = [0_u8; 512];
        let mut decoder = Decoder::new(&mut storage[..]).unwrap();
        let (leading, last) = STD_MSG.split_at(STD_MSG.len() - 1);

        for _round in 0..4 {
            // Every byte but the last leaves the frame incomplete.
            for &b in leading {
                assert!(decoder.decode(b).unwrap().is_none());
            }
            // The final (checksum) byte completes the frame.
            assert!(decoder.decode(last[0]).unwrap().is_some());
        }

        assert_eq!(decoder.count, 4);
        assert_eq!(decoder.invalid_count, 0);
    }

    #[test]
    fn owned_buffer_swap() {
        let mut decoder = Decoder::new([0_u8; 512]).unwrap();
        let replacement = [0_u8; 512];

        assert_eq!(STD_MSG.len(), 8);
        let (leading, last) = STD_MSG.split_at(7);
        for &b in leading {
            assert_eq!(decoder.decode_frameless(b).unwrap(), None);
        }

        let frame_size = decoder.decode_frameless(last[0]).unwrap().unwrap();
        assert_eq!(frame_size, 8);

        // The buffer handed back still holds the decoded frame.
        let previous = decoder.swap_buffer(replacement).unwrap();
        assert!(Frame::new(&previous[..frame_size]).is_ok());

        assert_eq!(decoder.count, 1);
        assert_eq!(decoder.invalid_count, 0);
    }
}