Skip to main content

mcumgr_toolkit/
connection.rs

1use std::{io::Cursor, sync::Mutex, time::Duration};
2
3use crate::{
4    DEFAULT_RETRIES,
5    commands::{ErrResponse, ErrResponseV2, McuMgrCommand},
6    smp_errors::{DeviceError, MCUmgrErr},
7    transport::{ReceiveError, SendError, Transport},
8};
9
10use miette::{Diagnostic, IntoDiagnostic};
11use polonius_the_crab::prelude::*;
12use thiserror::Error;
13
14struct Transceiver {
15    transport: Box<dyn Transport + Send>,
16    next_seqnum: u8,
17    receive_buffer: Box<[u8; u16::MAX as usize]>,
18}
19
20struct Inner {
21    transceiver: Transceiver,
22    send_buffer: Box<[u8; u16::MAX as usize]>,
23    retries: u8,
24}
25
26/// An SMP protocol layer connection to a device.
27///
28/// In most cases this struct will not be used directly by the user,
29/// but instead it is used indirectly through [`MCUmgrClient`](crate::MCUmgrClient).
30pub struct Connection {
31    inner: Mutex<Inner>,
32}
33
34/// Errors that can happen on SMP protocol level
35#[derive(Error, Debug, Diagnostic)]
36pub enum ExecuteError {
37    /// An error happened on SMP transport level while sending a request
38    #[error("Sending failed")]
39    #[diagnostic(code(mcumgr_toolkit::connection::execute::send))]
40    SendFailed(#[from] SendError),
41    /// An error happened on SMP transport level while receiving a response
42    #[error("Receiving failed")]
43    #[diagnostic(code(mcumgr_toolkit::connection::execute::receive))]
44    ReceiveFailed(#[from] ReceiveError),
45    /// An error happened while CBOR encoding the request payload
46    #[error("CBOR encoding failed")]
47    #[diagnostic(code(mcumgr_toolkit::connection::execute::encode))]
48    EncodeFailed(#[source] Box<dyn miette::Diagnostic + Send + Sync>),
49    /// An error happened while CBOR decoding the response payload
50    #[error("CBOR decoding failed")]
51    #[diagnostic(code(mcumgr_toolkit::connection::execute::decode))]
52    DecodeFailed(#[source] Box<dyn miette::Diagnostic + Send + Sync>),
53    /// The device returned an SMP error
54    #[error("Device returned error code: {0}")]
55    #[diagnostic(code(mcumgr_toolkit::connection::execute::device_error))]
56    ErrorResponse(DeviceError),
57}
58
59impl ExecuteError {
60    /// Checks if the device reported the command as unsupported
61    pub fn command_not_supported(&self) -> bool {
62        if let Self::ErrorResponse(DeviceError::V1 { rc, .. }) = self {
63            *rc == MCUmgrErr::MGMT_ERR_ENOTSUP as i32
64        } else {
65            false
66        }
67    }
68}
69
70impl Transceiver {
71    fn transceive_command(
72        &mut self,
73        write_operation: bool,
74        group_id: u16,
75        command_id: u8,
76        data: &[u8],
77    ) -> Result<&'_ [u8], ExecuteError> {
78        let sequence_num = self.next_seqnum;
79        self.next_seqnum = self.next_seqnum.wrapping_add(1);
80
81        self.transport
82            .send_frame(write_operation, sequence_num, group_id, command_id, data)?;
83
84        self.transport
85            .receive_frame(
86                &mut self.receive_buffer,
87                write_operation,
88                sequence_num,
89                group_id,
90                command_id,
91            )
92            .map_err(Into::into)
93    }
94
95    fn transceive_command_with_retries(
96        &mut self,
97        write_operation: bool,
98        group_id: u16,
99        command_id: u8,
100        data: &[u8],
101        num_retries: u8,
102    ) -> Result<&'_ [u8], ExecuteError> {
103        let mut this = self;
104
105        let mut counter = 0;
106
107        polonius_loop!(|this| -> Result<&'polonius [u8], ExecuteError> {
108            let result = this.transceive_command(write_operation, group_id, command_id, data);
109
110            if counter >= num_retries {
111                polonius_return!(result)
112            }
113            counter += 1;
114
115            match result {
116                Ok(_) => polonius_return!(result),
117                Err(e) => {
118                    let mut lowest_err: &dyn std::error::Error = &e;
119                    while let Some(lower_err) = lowest_err.source() {
120                        lowest_err = lower_err;
121                    }
122                    log::warn!("Retry transmission, error occurred: {lowest_err}");
123                }
124            }
125        })
126    }
127}
128
129impl Connection {
130    /// Creates a new SMP connection
131    pub fn new<T: Transport + Send + 'static>(transport: T) -> Self {
132        Self {
133            inner: Mutex::new(Inner {
134                transceiver: Transceiver {
135                    transport: Box::new(transport),
136                    next_seqnum: rand::random(),
137                    receive_buffer: Box::new([0; u16::MAX as usize]),
138                },
139                send_buffer: Box::new([0; u16::MAX as usize]),
140                retries: DEFAULT_RETRIES,
141            }),
142        }
143    }
144
145    /// Returns the maximum SMP frame size the underlying transport can
146    /// deliver reliably.
147    pub fn max_transport_frame_size(&self) -> usize {
148        self.inner
149            .lock()
150            .unwrap()
151            .transceiver
152            .transport
153            .max_smp_frame_size()
154    }
155
156    /// Changes the communication timeout.
157    ///
158    /// When the device does not respond to packets within the set
159    /// duration, an error will be raised.
160    pub fn set_timeout(
161        &self,
162        timeout: Duration,
163    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
164        self.inner
165            .lock()
166            .unwrap()
167            .transceiver
168            .transport
169            .set_timeout(timeout)
170    }
171
172    /// Changes the retry amount.
173    ///
174    /// When the device encounters a transport error, it will retry
175    /// this many times until giving up.
176    pub fn set_retries(&self, retries: u8) {
177        self.inner.lock().unwrap().retries = retries;
178    }
179
180    /// Executes a given CBOR based SMP command.
181    pub fn execute_command<R: McuMgrCommand>(
182        &self,
183        request: &R,
184    ) -> Result<R::Response, ExecuteError> {
185        self.execute_command_impl(request, true)
186    }
187
188    /// Executes a given CBOR based SMP command.
189    ///
190    /// Does not use retries.
191    pub fn execute_command_without_retries<R: McuMgrCommand>(
192        &self,
193        request: &R,
194    ) -> Result<R::Response, ExecuteError> {
195        self.execute_command_impl(request, false)
196    }
197
198    fn execute_command_impl<R: McuMgrCommand>(
199        &self,
200        request: &R,
201        use_retries: bool,
202    ) -> Result<R::Response, ExecuteError> {
203        let mut lock_guard = self.inner.lock().unwrap();
204        let locked_self: &mut Inner = &mut lock_guard;
205
206        let mut cursor = Cursor::new(locked_self.send_buffer.as_mut_slice());
207        ciborium::into_writer(request.data(), &mut cursor)
208            .into_diagnostic()
209            .map_err(Into::into)
210            .map_err(ExecuteError::EncodeFailed)?;
211        let data_size = cursor.position() as usize;
212        let data = &locked_self.send_buffer[..data_size];
213
214        log::debug!("TX data: {}", hex::encode(data));
215
216        let write_operation = request.is_write_operation();
217        let group_id = request.group_id();
218        let command_id = request.command_id();
219
220        let response = locked_self.transceiver.transceive_command_with_retries(
221            write_operation,
222            group_id,
223            command_id,
224            data,
225            if use_retries { locked_self.retries } else { 0 },
226        )?;
227
228        log::debug!("RX data: {}", hex::encode(response));
229
230        let err: ErrResponse = ciborium::from_reader(Cursor::new(response))
231            .into_diagnostic()
232            .map_err(Into::into)
233            .map_err(ExecuteError::DecodeFailed)?;
234
235        if let Some(ErrResponseV2 { rc, group }) = err.err {
236            return Err(ExecuteError::ErrorResponse(DeviceError::V2 { group, rc }));
237        }
238
239        if let Some(rc) = err.rc {
240            if rc != MCUmgrErr::MGMT_ERR_EOK as i32 {
241                return Err(ExecuteError::ErrorResponse(DeviceError::V1 {
242                    rc,
243                    rsn: err.rsn,
244                }));
245            }
246        }
247
248        let decoded_response: R::Response = ciborium::from_reader(Cursor::new(response))
249            .into_diagnostic()
250            .map_err(Into::into)
251            .map_err(ExecuteError::DecodeFailed)?;
252
253        Ok(decoded_response)
254    }
255
256    /// Executes a raw SMP command.
257    ///
258    /// Same as [`Connection::execute_command`], but the payload can be anything and must not
259    /// necessarily be CBOR encoded.
260    ///
261    /// Errors are also not decoded but instead will be returned as raw CBOR data.
262    ///
263    /// Read Zephyr's [SMP Protocol Specification](https://docs.zephyrproject.org/latest/services/device_mgmt/smp_protocol.html)
264    /// for more information.
265    pub fn execute_raw_command(
266        &self,
267        write_operation: bool,
268        group_id: u16,
269        command_id: u8,
270        data: &[u8],
271        use_retries: bool,
272    ) -> Result<Box<[u8]>, ExecuteError> {
273        let mut lock_guard = self.inner.lock().unwrap();
274        let locked_self: &mut Inner = &mut lock_guard;
275
276        locked_self
277            .transceiver
278            .transceive_command_with_retries(
279                write_operation,
280                group_id,
281                command_id,
282                data,
283                if use_retries { locked_self.retries } else { 0 },
284            )
285            .map(|val| val.into())
286    }
287}