Skip to main content

mcumgr_toolkit/transport/
serial.rs

1use std::time::Duration;
2
3use base64::prelude::*;
4use ringbuf::{
5    LocalRb,
6    storage::Heap,
7    traits::{Consumer, Observer, Producer},
8};
9use serialport::SerialPort;
10
11use super::{ReceiveError, SMP_HEADER_SIZE, SMP_TRANSFER_BUFFER_SIZE, SendError, Transport};
12
/// A transport layer implementation for serial ports.
///
/// SMP packets are exchanged as newline-terminated, base64-encoded frames.
pub struct SerialTransport<T> {
    /// Scratch space for one encoded frame: header bytes + base64 + newline
    /// when sending, the base64 characters alone when receiving.
    transfer_buffer: Box<[u8]>,
    /// Raw (not base64-encoded) bytes of a single frame's payload.
    body_buffer: Box<[u8]>,
    /// The underlying serial port.
    serial: T,
    /// Checksum algorithm for whole SMP packets (constructed as CRC-16/XMODEM in `new`).
    crc_algo: crc::Crc<u16>,
    /// Bytes already read from `serial` that have not been consumed yet.
    read_buffer: LocalRb<Heap<u8>>,
}
21
/// Copy bytes from `data_iter` into `buffer` until either the buffer is full
/// or the iterator is exhausted, and return the part of `buffer` that was written.
///
/// The iterator is never advanced past what fits into `buffer`, so leftover
/// bytes remain available for the next call.
fn fill_buffer_with_data<'a, I: Iterator<Item = u8>>(
    buffer: &'a mut [u8],
    data_iter: &mut I,
) -> &'a [u8] {
    let mut written = 0;
    // `zip` polls the slice iterator first; once `buffer` has no slots left,
    // `data_iter` is not touched again.
    for (slot, byte) in buffer.iter_mut().zip(&mut *data_iter) {
        *slot = byte;
        written += 1;
    }
    &buffer[..written]
}
36
/// Maximum size in bytes of a single serial frame, including the two header
/// bytes and the trailing newline.
///
/// See Zephyr's [`MCUMGR_SERIAL_MAX_FRAME`](https://github.com/zephyrproject-rtos/zephyr/blob/v4.2.1/include/zephyr/mgmt/mcumgr/transport/serial.h#L18).
const SERIAL_TRANSPORT_ZEPHYR_MTU: usize = 127;
39
40impl<T> SerialTransport<T>
41where
42    T: std::io::Write + std::io::Read,
43{
44    /// Create a new [`SerialTransport`].
45    ///
46    /// # Arguments
47    ///
48    /// * `serial` - A serial port object, like [`serialport::SerialPort`].
49    ///
50    pub fn new(serial: T) -> Self {
51        let mtu = SERIAL_TRANSPORT_ZEPHYR_MTU;
52        Self {
53            serial,
54            transfer_buffer: vec![0u8; mtu].into_boxed_slice(),
55            body_buffer: vec![0u8; ((mtu - 3) / 4) * 3].into_boxed_slice(),
56            crc_algo: crc::Crc::<u16>::new(&crc::CRC_16_XMODEM),
57            read_buffer: LocalRb::new(4096),
58        }
59    }
60
61    /// Take a raw data stream, split it into SMP transport frames and transmit them.
62    ///
63    /// # Arguments
64    ///
65    /// * `data_iter` - An iterator that produces the binary data of the message to send.
66    ///
67    fn send_chunked<I: Iterator<Item = u8>>(&mut self, mut data_iter: I) -> Result<(), SendError> {
68        self.transfer_buffer[0] = 6;
69        self.transfer_buffer[1] = 9;
70
71        loop {
72            let body = fill_buffer_with_data(&mut self.body_buffer, &mut data_iter);
73
74            if body.is_empty() {
75                break Ok(());
76            }
77
78            let base64_len = BASE64_STANDARD
79                .encode_slice(body, &mut self.transfer_buffer[2..])
80                .expect("Transfer buffer overflow; this is a bug. Please report.");
81
82            self.transfer_buffer[base64_len + 2] = 0x0a;
83
84            self.serial
85                .write_all(&self.transfer_buffer[..base64_len + 3])?;
86
87            log::debug!(
88                "Sent Chunk ({}, {} bytes raw, {} bytes encoded)",
89                if self.transfer_buffer[0] == 6 {
90                    "initial"
91                } else {
92                    "partial"
93                },
94                body.len(),
95                base64_len,
96            );
97
98            self.transfer_buffer[0] = 4;
99            self.transfer_buffer[1] = 20;
100        }
101    }
102
103    /// Receive an SMP transport frame and decode it.
104    ///
105    /// # Arguments
106    ///
107    /// * `first` - whether this is the first first frame of the message.
108    ///
109    /// # Return
110    ///
111    /// The received data
112    ///
113    fn recv_chunk(&mut self, first: bool) -> Result<&[u8], ReceiveError> {
114        let expected_header_0 = if first { 6 } else { 4 };
115        let expected_header_1 = if first { 9 } else { 20 };
116
117        loop {
118            while self.read_buffer.occupied_len() < 2 {
119                let num_read = self
120                    .read_buffer
121                    .read_from(&mut self.serial, None)
122                    .unwrap()?;
123
124                if num_read == 0 {
125                    return Err(ReceiveError::TransportError(std::io::Error::new(
126                        std::io::ErrorKind::UnexpectedEof,
127                        "Serial port unexpectedly returned end-of-file",
128                    )));
129                }
130            }
131
132            let current = self.read_buffer.try_pop().unwrap();
133            let next = self.read_buffer.try_peek().unwrap();
134            if current == expected_header_0 && *next == expected_header_1 {
135                self.read_buffer.try_pop().unwrap();
136                break;
137            }
138        }
139
140        let mut base64_data = None;
141        for (pos, elem) in self.transfer_buffer.iter_mut().enumerate() {
142            let data = loop {
143                if let Some(e) = self.read_buffer.try_pop() {
144                    break e;
145                } else {
146                    let num_read = self
147                        .read_buffer
148                        .read_from(&mut self.serial, None)
149                        .unwrap()?;
150
151                    if num_read == 0 {
152                        return Err(ReceiveError::TransportError(std::io::Error::new(
153                            std::io::ErrorKind::UnexpectedEof,
154                            "Serial port unexpectedly returned end-of-file",
155                        )));
156                    }
157                }
158            };
159
160            if data == 0x0a {
161                base64_data = Some(&self.transfer_buffer[..pos]);
162                break;
163            }
164
165            *elem = data;
166        }
167
168        if let Some(0x0a) = self.read_buffer.try_peek() {
169            base64_data = Some(&self.transfer_buffer);
170        }
171
172        if let Some(base64_data) = base64_data {
173            let len = BASE64_STANDARD.decode_slice(base64_data, &mut self.body_buffer)?;
174
175            log::debug!(
176                "Received Chunk ({}, {} bytes raw, {} bytes decoded)",
177                if first { "initial" } else { "partial" },
178                base64_data.len(),
179                len
180            );
181            Ok(&self.body_buffer[..len])
182        } else {
183            Err(ReceiveError::FrameTooBig)
184        }
185    }
186}
187
188impl<T> Transport for SerialTransport<T>
189where
190    T: std::io::Write + std::io::Read + ConfigurableTimeout,
191{
192    fn send_raw_frame(
193        &mut self,
194        header: [u8; SMP_HEADER_SIZE],
195        data: &[u8],
196    ) -> Result<(), SendError> {
197        log::debug!("Sending SMP Frame ({} bytes)", data.len());
198
199        let checksum = {
200            let mut digest = self.crc_algo.digest();
201            digest.update(&header);
202            digest.update(data);
203            digest.finalize().to_be_bytes()
204        };
205
206        let size = u16::try_from(header.len() + data.len() + checksum.len())
207            .map_err(|_| SendError::DataTooBig)?
208            .to_be_bytes();
209
210        self.send_chunked(
211            size.into_iter()
212                .chain(header)
213                .chain(data.iter().copied())
214                .chain(checksum),
215        )
216    }
217
218    fn recv_raw_frame<'a>(
219        &mut self,
220        buffer: &'a mut [u8; SMP_TRANSFER_BUFFER_SIZE],
221    ) -> Result<&'a [u8], ReceiveError> {
222        let first_chunk = self.recv_chunk(true)?;
223
224        let (len, first_data) =
225            if let Some((len_data, first_data)) = first_chunk.split_first_chunk::<2>() {
226                (u16::from_be_bytes(*len_data), first_data)
227            } else {
228                return Err(ReceiveError::UnexpectedResponse);
229            };
230
231        let result_buffer = buffer
232            .split_at_mut_checked(len.into())
233            .ok_or(ReceiveError::FrameTooBig)?
234            .0;
235
236        let (first_result_buffer, mut leftover_result_buffer) = result_buffer
237            .split_at_mut_checked(first_data.len())
238            .ok_or(ReceiveError::UnexpectedResponse)?;
239
240        first_result_buffer.copy_from_slice(first_data);
241
242        while !leftover_result_buffer.is_empty() {
243            let next_chunk = self.recv_chunk(false)?;
244
245            let current_result_buffer;
246            (current_result_buffer, leftover_result_buffer) = leftover_result_buffer
247                .split_at_mut_checked(next_chunk.len())
248                .ok_or(ReceiveError::UnexpectedResponse)?;
249
250            current_result_buffer.copy_from_slice(next_chunk);
251        }
252
253        let (data, checksum_data) = result_buffer
254            .split_last_chunk::<2>()
255            .ok_or(ReceiveError::UnexpectedResponse)?;
256
257        let expected_checksum = u16::from_be_bytes(*checksum_data);
258
259        let actual_checksum = self.crc_algo.checksum(data);
260
261        if expected_checksum != actual_checksum {
262            return Err(ReceiveError::UnexpectedResponse);
263        }
264
265        log::debug!("Received SMP Frame ({} bytes)", data.len());
266
267        Ok(data)
268    }
269
270    fn set_timeout(
271        &mut self,
272        timeout: std::time::Duration,
273    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
274        ConfigurableTimeout::set_timeout(&mut self.serial, timeout)
275    }
276}
277
/// Specifies that the serial transport has a configurable timeout.
///
/// [`SerialTransport`] requires this bound for its [`Transport`] implementation,
/// which forwards [`Transport::set_timeout`] to the underlying port.
pub trait ConfigurableTimeout {
    /// Changes the communication timeout.
    ///
    /// When the device does not respond within the set duration,
    /// an error will be returned.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports while applying the timeout.
    fn set_timeout(
        &mut self,
        duration: Duration,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
}
289
290impl<T: AsMut<dyn SerialPort> + ?Sized> ConfigurableTimeout for T {
291    fn set_timeout(
292        &mut self,
293        timeout: Duration,
294    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
295        SerialPort::set_timeout(self.as_mut(), timeout).map_err(Into::into)
296    }
297}