mctp_estack/
serial.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2/*
3 * Copyright (c) 2024 Code Construct
4 */
5
//! An MCTP serial transport binding, as specified by DMTF DSP0253.
7
8#[allow(unused)]
9use crate::fmt::{debug, error, info, trace, warn};
10use mctp::{Error, Result};
11
12use crc::Crc;
13use heapless::Vec;
14
15use embedded_io_async::{Read, Write};
16
/// Serial protocol revision byte expected right after the opening flag.
const MCTP_SERIAL_REVISION: u8 = 0x01;

/// Largest payload that fits in one frame: the 8-bit byte count field caps
/// the packet at 255 bytes, less the 4 bytes of MCTP headers.
pub const MTU_MAX: usize = u8::MAX as usize - 4;

/// Receive-buffer bytes stored alongside the packet itself, after
/// unescaping: serial revision, byte count, and the two FCS bytes.
const RXBUF_FRAMING: usize = 4;
/// Receive-buffer capacity: the largest packet plus the framing bytes.
const MAX_RX: usize = u8::MAX as usize + RXBUF_FRAMING;

/// Delimits the start and end of a frame.
const FRAMING_FLAG: u8 = 0x7e;
/// Introduces an escaped byte inside the packet data.
const FRAMING_ESCAPE: u8 = 0x7d;
/// Second byte of the escape sequence standing in for `FRAMING_FLAG`.
const FLAG_ESCAPED: u8 = 0x5e;
/// Second byte of the escape sequence standing in for `FRAMING_ESCAPE`.
const ESCAPE_ESCAPED: u8 = 0x5d;
31
/// Receive state: which field of the DSP0253 Table 1 frame layout the next
/// input byte is expected to belong to.
#[derive(Debug, PartialEq)]
enum Pos {
    /// Discarding input until a framing flag is seen.
    FrameSearch,
    /// Expecting the serial protocol revision byte.
    SerialRevision,
    /// Expecting the byte count field.
    ByteCount,
    /// Expecting (possibly escaped) packet data.
    Data,
    /// Expecting the byte that follows a `0x7d` escape.
    DataEscaped,
    /// Expecting the frame check sequence bytes.
    Check,
    /// Expecting the closing framing flag.
    FrameEnd,
}
46
47#[derive(Debug)]
48pub struct MctpSerialHandler {
49    rxpos: Pos,
50    rxbuf: Vec<u8, MAX_RX>,
51    // Last-seen byte count field
52    rxcount: usize,
53}
54
55// https://www.rfc-editor.org/rfc/rfc1662
56// checksum is complement of the output
57const CRC_FCS: Crc<u16> = Crc::<u16>::new(&crc::CRC_16_IBM_SDLC);
58
59impl MctpSerialHandler {
60    pub fn new() -> Self {
61        Self {
62            rxpos: Pos::FrameSearch,
63            rxcount: 0,
64            rxbuf: Vec::new(),
65        }
66    }
67
68    /// Read a frame.
69    ///
70    /// This is async cancel-safe.
71    pub async fn recv_async(&mut self, input: &mut impl Read) -> Result<&[u8]> {
72        // TODO: This reads one byte a time, might need a buffering wrapper
73        // for performance. Will require more thought about cancel-safety
74
75        loop {
76            let mut b = 0u8;
77            // Read from serial
78            match input.read(core::slice::from_mut(&mut b)).await {
79                Ok(1) => (),
80                Ok(0) => {
81                    trace!("Serial EOF");
82                    return Err(Error::RxFailure);
83                }
84                Ok(2..) => unreachable!(),
85                Err(_e) => {
86                    trace!("Serial read error");
87                    // TODO or do we want a RxFailure?
88                    return Err(Error::RxFailure);
89                }
90            }
91            if let Some(_p) = self.feed_frame(b) {
92                // bleh polonius
93                // return Ok(p)
94                return Ok(&self.rxbuf[2..][..self.rxcount]);
95            }
96        }
97    }
98
99    fn feed_frame(&mut self, b: u8) -> Option<&[u8]> {
100        trace!("serial read {:02x}", b);
101
102        match self.rxpos {
103            Pos::FrameSearch => {
104                if b == FRAMING_FLAG {
105                    self.rxpos = Pos::SerialRevision
106                }
107            }
108            Pos::SerialRevision => {
109                self.rxpos = match b {
110                    MCTP_SERIAL_REVISION => Pos::ByteCount,
111                    FRAMING_FLAG => Pos::SerialRevision,
112                    _ => Pos::FrameSearch,
113                };
114                self.rxbuf.clear();
115                self.rxcount = 0;
116                self.rxbuf.push(b).unwrap();
117            }
118            Pos::ByteCount => {
119                self.rxcount = b as usize;
120                self.rxbuf.push(b).unwrap();
121                self.rxpos = Pos::Data;
122            }
123            Pos::Data => {
124                match b {
125                    // Unexpected framing, restart
126                    FRAMING_FLAG => self.rxpos = Pos::SerialRevision,
127                    FRAMING_ESCAPE => self.rxpos = Pos::DataEscaped,
128                    _ => {
129                        self.rxbuf.push(b).unwrap();
130                        if self.rxbuf.len() == self.rxcount + 2 {
131                            self.rxpos = Pos::Check;
132                        }
133                    }
134                }
135            }
136            Pos::DataEscaped => {
137                match b {
138                    FLAG_ESCAPED => {
139                        self.rxbuf.push(FRAMING_FLAG).unwrap();
140                        self.rxpos = Pos::Data;
141                    }
142                    ESCAPE_ESCAPED => {
143                        self.rxbuf.push(FRAMING_ESCAPE).unwrap();
144                        self.rxpos = Pos::Data;
145                    }
146                    // Unexpected escape, restart
147                    _ => self.rxpos = Pos::FrameSearch,
148                }
149                if self.rxbuf.len() == self.rxcount + 2 {
150                    self.rxpos = Pos::Check;
151                }
152            }
153            Pos::Check => {
154                self.rxbuf.push(b).unwrap();
155                if self.rxbuf.len() == self.rxcount + RXBUF_FRAMING {
156                    self.rxpos = Pos::FrameEnd;
157                }
158            }
159            Pos::FrameEnd => {
160                if b == FRAMING_FLAG {
161                    // Ready for next frame
162                    self.rxpos = Pos::FrameSearch;
163                    // Compare checksum
164                    let (csdata, cs) = self.rxbuf.split_at(self.rxcount + 2);
165                    let cs: [u8; 2] = cs.try_into().unwrap();
166                    let cs = u16::from_be_bytes(cs);
167                    let cs_calc = !CRC_FCS.checksum(csdata);
168                    if cs_calc == cs {
169                        // Complete frame
170                        let packet = &self.rxbuf[2..][..self.rxcount];
171                        return Some(packet);
172                    } else {
173                        warn!(
174                            "Bad checksum got {:04x} calc {:04x}",
175                            cs, cs_calc
176                        );
177                    }
178                } else {
179                    // restart
180                    self.rxpos = Pos::SerialRevision;
181                }
182            }
183        }
184        // Frame is incomplete
185        None
186    }
187
188    pub async fn send_async(
189        &mut self,
190        pkt: &[u8],
191        output: &mut impl Write,
192    ) -> Result<()> {
193        Self::frame_to_serial(pkt, output)
194            .await
195            .map_err(|_e| Error::TxFailure)
196    }
197
198    async fn frame_to_serial<W>(
199        p: &[u8],
200        output: &mut W,
201    ) -> core::result::Result<(), W::Error>
202    where
203        W: Write,
204    {
205        debug_assert!(p.len() <= u8::MAX.into());
206        debug_assert!(p.len() > 4);
207
208        let start = [FRAMING_FLAG, MCTP_SERIAL_REVISION, p.len() as u8];
209        let mut cs = CRC_FCS.digest();
210        cs.update(&start[1..]);
211        cs.update(p);
212        let cs = !cs.finalize();
213
214        output.write_all(&start).await?;
215        Self::write_escaped(p, output).await?;
216        output.write_all(&cs.to_be_bytes()).await?;
217        output.write_all(&[FRAMING_FLAG]).await?;
218        Ok(())
219    }
220
221    async fn write_escaped<W>(
222        p: &[u8],
223        output: &mut W,
224    ) -> core::result::Result<(), W::Error>
225    where
226        W: Write,
227    {
228        for c in
229            p.split_inclusive(|&b| b == FRAMING_FLAG || b == FRAMING_ESCAPE)
230        {
231            let (last, rest) = c.split_last().unwrap();
232            match *last {
233                FRAMING_FLAG => {
234                    output.write_all(rest).await?;
235                    output.write_all(&[FRAMING_ESCAPE, FLAG_ESCAPED]).await?;
236                }
237                FRAMING_ESCAPE => {
238                    output.write_all(rest).await?;
239                    output.write_all(&[FRAMING_ESCAPE, ESCAPE_ESCAPED]).await?;
240                }
241                _ => output.write_all(c).await?,
242            }
243        }
244        Ok(())
245    }
246}
247
248impl Default for MctpSerialHandler {
249    fn default() -> Self {
250        Self::new()
251    }
252}
253
#[cfg(test)]
mod tests {

    use crate::serial::*;
    use crate::*;
    use embedded_io_adapters::futures_03::FromFutures;
    use proptest::prelude::*;

    fn start_log() {
        let _ = env_logger::Builder::new()
            .filter(None, log::LevelFilter::Trace)
            .is_test(true)
            .try_init();
    }

    /// Frames `payload`, parses the resulting serial bytes back, and checks
    /// that the received packet equals `payload`.
    async fn do_roundtrip(payload: &[u8]) {
        let mut esc = vec![];
        let mut s = FromFutures::new(&mut esc);
        MctpSerialHandler::frame_to_serial(payload, &mut s)
            .await
            .unwrap();
        debug!("{:02x?}", payload);
        debug!("{:02x?}", esc);

        let mut h = MctpSerialHandler::new();
        let mut s = FromFutures::new(esc.as_slice());
        let packet = h.recv_async(&mut s).await.unwrap();
        // assert_eq (not debug_assert_eq) so release-mode test runs check too
        assert_eq!(payload, packet);
    }

    #[test]
    fn roundtrip_cases() {
        // Fixed testcases
        start_log();
        // Largest packet the byte count allows, every byte needing escape
        let max_flags = [0x7e_u8; 0xff];
        let cases: [&[u8]; 3] = [
            &[0x01, 0x5d, 0x0d, 0xf4, 0x01, 0x93, 0x7d, 0xcd, 0x36],
            // Flag/escape bytes: leading, back-to-back and trailing
            &[0x7e, 0x01, 0x7e, 0x7e, 0x7d, 0x7d, 0x5e, 0xff, 0x7e],
            &max_flags,
        ];
        smol::block_on(async {
            for payload in cases {
                do_roundtrip(payload).await
            }
        })
    }

    proptest! {
        #[test]
        fn roundtrip_escape(payload in proptest::collection::vec(any::<u8>(), 5..=0xff)) {
            start_log();

            smol::block_on(do_roundtrip(&payload))

        }
    }
}