Skip to main content

mcumgr_toolkit/transport/
serial.rs

1use std::time::Duration;
2
3use base64::prelude::*;
4use miette::IntoDiagnostic;
5use ringbuf::{
6    LocalRb,
7    storage::Heap,
8    traits::{Consumer, Observer, Producer},
9};
10use serialport::SerialPort;
11
12use super::{ReceiveError, SMP_HEADER_SIZE, SMP_TRANSFER_BUFFER_SIZE, SendError, Transport};
13
/// A transport layer implementation for serial ports.
///
/// SMP messages are split into newline-terminated lines. Each line starts with a
/// two-byte marker that is followed by base64-encoded data. `T` is the underlying
/// byte stream, for example a serial port object.
pub struct SerialTransport<T> {
    /// Scratch space for one encoded line.
    ///
    /// When sending it holds the header, the base64 data and the newline. When
    /// receiving it holds only the base64 data.
    transfer_buffer: Box<[u8]>,
    /// Scratch space for the raw (un-encoded) bytes carried by one line.
    body_buffer: Box<[u8]>,
    /// The underlying byte stream that lines are written to and read from.
    serial: T,
    /// CRC algorithm used to checksum whole SMP frames.
    crc_algo: crc::Crc<u16>,
    /// Bytes already read from `serial` that have not been consumed yet.
    read_buffer: LocalRb<Heap<u8>>,
}
22
/// Copies bytes from `data_iter` into `buffer` and returns the written prefix of `buffer`.
///
/// Copying stops as soon as either `buffer` is full or `data_iter` is exhausted. The
/// iterator is never advanced past what fits into `buffer`, so any remaining bytes are
/// still available to the next call.
fn fill_buffer_with_data<'a, I: Iterator<Item = u8>>(
    buffer: &'a mut [u8],
    data_iter: &mut I,
) -> &'a [u8] {
    let mut written = 0;

    // `zip` polls the buffer slot first, so the data iterator is never pulled once
    // the buffer is full.
    for (slot, byte) in buffer.iter_mut().zip(data_iter) {
        *slot = byte;
        written += 1;
    }

    &buffer[..written]
}
37
/// Size in bytes of one serial transport line.
///
/// An outgoing line consists of the two-byte header, the base64 payload and the
/// trailing newline. [`SerialTransport::new`] sizes its encode/decode buffers from
/// this value.
///
/// See Zephyr's [`MCUMGR_SERIAL_MAX_FRAME`](https://github.com/zephyrproject-rtos/zephyr/blob/v4.2.1/include/zephyr/mgmt/mcumgr/transport/serial.h#L18).
const SERIAL_TRANSPORT_ZEPHYR_MTU: usize = 127;
40
41impl<T> SerialTransport<T>
42where
43    T: std::io::Write + std::io::Read,
44{
45    /// Create a new [`SerialTransport`].
46    ///
47    /// # Arguments
48    ///
49    /// * `serial` - A serial port object, like [`serialport::SerialPort`].
50    ///
51    pub fn new(serial: T) -> Self {
52        let mtu = SERIAL_TRANSPORT_ZEPHYR_MTU;
53        Self {
54            serial,
55            transfer_buffer: vec![0u8; mtu].into_boxed_slice(),
56            body_buffer: vec![0u8; ((mtu - 3) / 4) * 3].into_boxed_slice(),
57            crc_algo: crc::Crc::<u16>::new(&crc::CRC_16_XMODEM),
58            read_buffer: LocalRb::new(4096),
59        }
60    }
61
62    /// Take a raw data stream, split it into SMP transport frames and transmit them.
63    ///
64    /// # Arguments
65    ///
66    /// * `data_iter` - An iterator that produces the binary data of the message to send.
67    ///
68    fn send_chunked<I: Iterator<Item = u8>>(&mut self, mut data_iter: I) -> Result<(), SendError> {
69        self.transfer_buffer[0] = 6;
70        self.transfer_buffer[1] = 9;
71
72        loop {
73            let body = fill_buffer_with_data(&mut self.body_buffer, &mut data_iter);
74
75            if body.is_empty() {
76                break Ok(());
77            }
78
79            let base64_len = BASE64_STANDARD
80                .encode_slice(body, &mut self.transfer_buffer[2..])
81                .expect("Transfer buffer overflow; this is a bug. Please report.");
82
83            self.transfer_buffer[base64_len + 2] = 0x0a;
84
85            self.serial
86                .write_all(&self.transfer_buffer[..base64_len + 3])?;
87
88            log::debug!(
89                "Sent Chunk ({}, {} bytes raw, {} bytes encoded)",
90                if self.transfer_buffer[0] == 6 {
91                    "initial"
92                } else {
93                    "partial"
94                },
95                body.len(),
96                base64_len,
97            );
98
99            self.transfer_buffer[0] = 4;
100            self.transfer_buffer[1] = 20;
101        }
102    }
103
104    /// Receive an SMP transport frame and decode it.
105    ///
106    /// # Arguments
107    ///
108    /// * `first` - whether this is the first first frame of the message.
109    ///
110    /// # Return
111    ///
112    /// The received data
113    ///
114    fn recv_chunk(&mut self, first: bool) -> Result<&[u8], ReceiveError> {
115        let expected_header_0 = if first { 6 } else { 4 };
116        let expected_header_1 = if first { 9 } else { 20 };
117
118        loop {
119            while self.read_buffer.occupied_len() < 2 {
120                let num_read = self
121                    .read_buffer
122                    .read_from(&mut self.serial, None)
123                    .unwrap()?;
124
125                if num_read == 0 {
126                    return Err(ReceiveError::TransportError(std::io::Error::new(
127                        std::io::ErrorKind::UnexpectedEof,
128                        "Serial port unexpectedly returned end-of-file",
129                    )));
130                }
131            }
132
133            let current = self.read_buffer.try_pop().unwrap();
134            let next = self.read_buffer.try_peek().unwrap();
135            if current == expected_header_0 && *next == expected_header_1 {
136                self.read_buffer.try_pop().unwrap();
137                break;
138            }
139        }
140
141        let mut base64_data = None;
142        for (pos, elem) in self.transfer_buffer.iter_mut().enumerate() {
143            let data = loop {
144                if let Some(e) = self.read_buffer.try_pop() {
145                    break e;
146                } else {
147                    let num_read = self
148                        .read_buffer
149                        .read_from(&mut self.serial, None)
150                        .unwrap()?;
151
152                    if num_read == 0 {
153                        return Err(ReceiveError::TransportError(std::io::Error::new(
154                            std::io::ErrorKind::UnexpectedEof,
155                            "Serial port unexpectedly returned end-of-file",
156                        )));
157                    }
158                }
159            };
160
161            if data == 0x0a {
162                base64_data = Some(&self.transfer_buffer[..pos]);
163                break;
164            }
165
166            *elem = data;
167        }
168
169        if let Some(0x0a) = self.read_buffer.try_peek() {
170            base64_data = Some(&self.transfer_buffer);
171        }
172
173        if let Some(base64_data) = base64_data {
174            let len = BASE64_STANDARD.decode_slice(base64_data, &mut self.body_buffer)?;
175
176            log::debug!(
177                "Received Chunk ({}, {} bytes raw, {} bytes decoded)",
178                if first { "initial" } else { "partial" },
179                base64_data.len(),
180                len
181            );
182            Ok(&self.body_buffer[..len])
183        } else {
184            Err(ReceiveError::FrameTooBig)
185        }
186    }
187}
188
189impl<T> Transport for SerialTransport<T>
190where
191    T: std::io::Write + std::io::Read + ConfigurableTimeout,
192{
193    fn send_raw_frame(
194        &mut self,
195        header: [u8; SMP_HEADER_SIZE],
196        data: &[u8],
197    ) -> Result<(), SendError> {
198        log::debug!("Sending SMP Frame ({} bytes)", data.len());
199
200        let checksum = {
201            let mut digest = self.crc_algo.digest();
202            digest.update(&header);
203            digest.update(data);
204            digest.finalize().to_be_bytes()
205        };
206
207        let size = u16::try_from(header.len() + data.len() + checksum.len())
208            .map_err(|_| SendError::DataTooBig)?
209            .to_be_bytes();
210
211        self.send_chunked(
212            size.into_iter()
213                .chain(header)
214                .chain(data.iter().copied())
215                .chain(checksum),
216        )
217    }
218
219    fn recv_raw_frame<'a>(
220        &mut self,
221        buffer: &'a mut [u8; SMP_TRANSFER_BUFFER_SIZE],
222    ) -> Result<&'a [u8], ReceiveError> {
223        let first_chunk = self.recv_chunk(true)?;
224
225        let (len, first_data) =
226            if let Some((len_data, first_data)) = first_chunk.split_first_chunk::<2>() {
227                (u16::from_be_bytes(*len_data), first_data)
228            } else {
229                return Err(ReceiveError::UnexpectedResponse);
230            };
231
232        let result_buffer = buffer
233            .split_at_mut_checked(len.into())
234            .ok_or(ReceiveError::FrameTooBig)?
235            .0;
236
237        let (first_result_buffer, mut leftover_result_buffer) = result_buffer
238            .split_at_mut_checked(first_data.len())
239            .ok_or(ReceiveError::UnexpectedResponse)?;
240
241        first_result_buffer.copy_from_slice(first_data);
242
243        while !leftover_result_buffer.is_empty() {
244            let next_chunk = self.recv_chunk(false)?;
245
246            let current_result_buffer;
247            (current_result_buffer, leftover_result_buffer) = leftover_result_buffer
248                .split_at_mut_checked(next_chunk.len())
249                .ok_or(ReceiveError::UnexpectedResponse)?;
250
251            current_result_buffer.copy_from_slice(next_chunk);
252        }
253
254        let (data, checksum_data) = result_buffer
255            .split_last_chunk::<2>()
256            .ok_or(ReceiveError::UnexpectedResponse)?;
257
258        let expected_checksum = u16::from_be_bytes(*checksum_data);
259
260        let actual_checksum = self.crc_algo.checksum(data);
261
262        if expected_checksum != actual_checksum {
263            return Err(ReceiveError::UnexpectedResponse);
264        }
265
266        log::debug!("Received SMP Frame ({} bytes)", data.len());
267
268        Ok(data)
269    }
270
271    fn set_timeout(&mut self, timeout: std::time::Duration) -> Result<(), miette::Report> {
272        ConfigurableTimeout::set_timeout(&mut self.serial, timeout)
273    }
274}
275
/// Specifies that the serial transport has a configurable timeout.
///
/// [`SerialTransport`] requires this of its underlying port in order to implement
/// [`Transport::set_timeout`]. A blanket implementation exists for every type that
/// implements `AsMut<dyn SerialPort>`, such as `Box<dyn SerialPort>`.
pub trait ConfigurableTimeout {
    /// Changes the communication timeout.
    ///
    /// When the device does not respond within the set duration,
    /// an error will be returned.
    ///
    /// # Errors
    ///
    /// Returns a [`miette::Report`] if the new timeout could not be applied.
    fn set_timeout(&mut self, duration: Duration) -> Result<(), miette::Report>;
}
284
285impl<T: AsMut<dyn SerialPort> + ?Sized> ConfigurableTimeout for T {
286    fn set_timeout(&mut self, timeout: Duration) -> Result<(), miette::Report> {
287        SerialPort::set_timeout(self.as_mut(), timeout).into_diagnostic()
288    }
289}