use std::{io::Cursor, sync::Mutex, time::Duration};
use crate::{
DEFAULT_RETRIES,
commands::{ErrResponse, ErrResponseV2, McuMgrCommand},
smp_errors::{DeviceError, MCUmgrErr},
transport::{ReceiveError, SendError, Transport},
};
use miette::{Diagnostic, IntoDiagnostic};
use polonius_the_crab::prelude::*;
use thiserror::Error;
/// Request/response engine underneath a connection.
///
/// Bundles the transport, the sequence-number counter stamped on outgoing
/// requests, and the scratch buffer that the transport receives frames into.
struct Transceiver {
    /// Sequence number for the next outgoing request; advanced with wrapping
    /// arithmetic after each use.
    next_seqnum: u8,
    /// The link that request frames are sent over and responses read from.
    transport: Box<dyn Transport + Send>,
    /// `u16::MAX`-byte buffer handed to the transport for each incoming frame.
    receive_buffer: Box<[u8; u16::MAX as usize]>,
}
/// All mutable connection state; lives behind the `Connection`'s mutex.
struct Inner {
    /// Number of extra attempts made for a failed exchange when the caller
    /// opts into retries.
    retries: u8,
    /// `u16::MAX`-byte buffer that outgoing requests are CBOR-encoded into.
    send_buffer: Box<[u8; u16::MAX as usize]>,
    /// Transport, sequence numbering and receive buffer.
    transceiver: Transceiver,
}
/// Handle for executing MCUmgr commands over a `Transport`.
///
/// All state (transport, sequence counter, send/receive buffers, retry count)
/// sits behind a single `Mutex`. Command execution holds that lock for the whole
/// request/response exchange, so calls on a shared `Connection` are serialized.
pub struct Connection {
    // The one lock guarding every piece of connection state.
    inner: Mutex<Inner>,
}
/// Everything that can go wrong while executing a command on a `Connection`.
///
/// Variants are listed in the order the failures can occur during a command:
/// encoding, sending, receiving, decoding, and finally an error reported by the
/// device itself.
#[derive(Error, Debug, Diagnostic)]
pub enum ExecuteError {
    /// The request payload could not be serialized to CBOR.
    #[diagnostic(code(mcumgr_toolkit::connection::execute::encode))]
    #[error("CBOR encoding failed")]
    EncodeFailed(#[source] Box<dyn miette::Diagnostic + Send + Sync>),

    /// The transport failed to send the request frame.
    #[diagnostic(code(mcumgr_toolkit::connection::execute::send))]
    #[error("Sending failed")]
    SendFailed(#[from] SendError),

    /// The transport failed to receive the response frame.
    #[diagnostic(code(mcumgr_toolkit::connection::execute::receive))]
    #[error("Receiving failed")]
    ReceiveFailed(#[from] ReceiveError),

    /// The response bytes could not be deserialized from CBOR.
    #[diagnostic(code(mcumgr_toolkit::connection::execute::decode))]
    #[error("CBOR decoding failed")]
    DecodeFailed(#[source] Box<dyn miette::Diagnostic + Send + Sync>),

    /// The device answered, but its response carried an error code.
    #[diagnostic(code(mcumgr_toolkit::connection::execute::device_error))]
    #[error("Device returned error code: {0}")]
    ErrorResponse(DeviceError),
}
impl ExecuteError {
    /// Reports whether the device rejected the command as unsupported.
    ///
    /// Returns `true` only for an [`ExecuteError::ErrorResponse`] holding a
    /// `DeviceError::V1` whose `rc` equals `MCUmgrErr::MGMT_ERR_ENOTSUP`.
    /// Every other error returns `false`, including `DeviceError::V2` responses.
    pub fn command_not_supported(&self) -> bool {
        matches!(
            self,
            Self::ErrorResponse(DeviceError::V1 { rc, .. })
                if *rc == MCUmgrErr::MGMT_ERR_ENOTSUP as i32
        )
    }
}
impl Transceiver {
    /// Performs a single request/response exchange over the transport.
    ///
    /// The current `next_seqnum` is claimed for this exchange. The counter is
    /// advanced (wrapping at `u8::MAX`) before anything is sent, so a failed
    /// attempt still consumes its sequence number.
    ///
    /// The returned slice is whatever `receive_frame` produced from
    /// `self.receive_buffer`. It borrows `self`, so it must be consumed before
    /// the next exchange.
    ///
    /// # Errors
    ///
    /// - [`ExecuteError::SendFailed`] if `send_frame` fails. No receive is
    ///   attempted in that case.
    /// - [`ExecuteError::ReceiveFailed`] if `receive_frame` fails.
    fn transceive_command(
        &mut self,
        write_operation: bool,
        group_id: u16,
        command_id: u8,
        data: &[u8],
    ) -> Result<&'_ [u8], ExecuteError> {
        let sequence_num = self.next_seqnum;
        self.next_seqnum = self.next_seqnum.wrapping_add(1);
        self.transport
            .send_frame(write_operation, sequence_num, group_id, command_id, data)?;
        // NOTE(review): the request's header fields are handed to receive_frame,
        // presumably so it can match the reply to this request — confirm in the
        // Transport implementation.
        self.transport
            .receive_frame(
                &mut self.receive_buffer,
                write_operation,
                sequence_num,
                group_id,
                command_id,
            )
            .map_err(Into::into)
    }
    /// Like [`Transceiver::transceive_command`], but repeats the exchange on failure.
    ///
    /// At most `num_retries + 1` attempts are made. The first successful
    /// response is returned; if every attempt fails, the error from the final
    /// attempt is returned. Each failed attempt that will be retried logs a
    /// warning with the innermost cause from the error's `source()` chain.
    /// Every attempt uses a fresh sequence number.
    ///
    /// `polonius_loop!` / `polonius_return!` (from `polonius_the_crab`) are used
    /// because the current borrow checker rejects conditionally returning a
    /// borrow of `self` from one loop iteration while re-borrowing it in the next.
    fn transceive_command_with_retries(
        &mut self,
        write_operation: bool,
        group_id: u16,
        command_id: u8,
        data: &[u8],
        num_retries: u8,
    ) -> Result<&'_ [u8], ExecuteError> {
        let mut this = self;
        // Number of retries performed so far. It never exceeds `num_retries`,
        // so the `u8` cannot overflow.
        let mut counter = 0;
        polonius_loop!(|this| -> Result<&'polonius [u8], ExecuteError> {
            let result = this.transceive_command(write_operation, group_id, command_id, data);
            // Retry budget used up: hand back this attempt's outcome, success or not.
            if counter >= num_retries {
                polonius_return!(result)
            }
            counter += 1;
            match result {
                Ok(_) => polonius_return!(result),
                Err(e) => {
                    // Walk to the end of the `source()` chain so the log shows the
                    // root cause instead of the generic "Sending/Receiving failed".
                    let mut lowest_err: &dyn std::error::Error = &e;
                    while let Some(lower_err) = lowest_err.source() {
                        lowest_err = lower_err;
                    }
                    log::warn!("Retry transmission, error occurred: {lowest_err}");
                }
            }
        })
    }
}
impl Connection {
    /// Creates a connection that talks over `transport`.
    ///
    /// The first request uses a randomly chosen sequence number. The retry count
    /// starts at `DEFAULT_RETRIES`; change it with [`Connection::set_retries`].
    pub fn new<T: Transport + Send + 'static>(transport: T) -> Self {
        Self {
            inner: Mutex::new(Inner {
                transceiver: Transceiver {
                    transport: Box::new(transport),
                    next_seqnum: rand::random(),
                    receive_buffer: Self::new_frame_buffer(),
                },
                send_buffer: Self::new_frame_buffer(),
                retries: DEFAULT_RETRIES,
            }),
        }
    }

    /// Allocates a zero-filled `u16::MAX`-byte frame buffer directly on the heap.
    ///
    /// `Box::new([0; N])` first builds the array as a stack temporary and then
    /// moves it into the allocation. That copy is not guaranteed to be optimized
    /// out (and typically is not in unoptimized builds), so each 64 KiB buffer
    /// would briefly occupy the caller's stack. `vec!` zero-allocates on the
    /// heap instead.
    fn new_frame_buffer() -> Box<[u8; u16::MAX as usize]> {
        let buffer = vec![0u8; u16::MAX as usize].into_boxed_slice();
        let buffer: Result<Box<[u8; u16::MAX as usize]>, Box<[u8]>> =
            std::convert::TryFrom::try_from(buffer);
        // The slice was created with exactly `u16::MAX` elements, so the length
        // check performed by `try_from` cannot fail.
        buffer.unwrap_or_else(|_| unreachable!("frame buffer has the wrong length"))
    }

    /// Locks the shared connection state.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned, i.e. another thread panicked while
    /// holding it.
    fn lock_inner(&self) -> std::sync::MutexGuard<'_, Inner> {
        self.inner.lock().unwrap()
    }

    /// Returns the maximum SMP frame size reported by the underlying transport.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn max_transport_frame_size(&self) -> usize {
        self.lock_inner().transceiver.transport.max_smp_frame_size()
    }

    /// Forwards `timeout` to the underlying transport.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the transport's `set_timeout`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn set_timeout(
        &self,
        timeout: Duration,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        self.lock_inner().transceiver.transport.set_timeout(timeout)
    }

    /// Sets how many times a failed transmission is retried.
    ///
    /// Applies to calls that opt into retries, which make up to `retries + 1`
    /// attempts.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn set_retries(&self, retries: u8) {
        self.lock_inner().retries = retries;
    }

    /// Executes a typed command, retrying failed transmissions as configured
    /// via [`Connection::set_retries`].
    ///
    /// # Errors
    ///
    /// - [`ExecuteError::EncodeFailed`] if the request cannot be CBOR-encoded
    ///   into the send buffer.
    /// - [`ExecuteError::SendFailed`] / [`ExecuteError::ReceiveFailed`] if the
    ///   last transmission attempt fails.
    /// - [`ExecuteError::ErrorResponse`] if the device reports an error.
    /// - [`ExecuteError::DecodeFailed`] if the reply cannot be decoded.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn execute_command<R: McuMgrCommand>(
        &self,
        request: &R,
    ) -> Result<R::Response, ExecuteError> {
        self.execute_command_impl(request, true)
    }

    /// Same as [`Connection::execute_command`], but makes exactly one
    /// transmission attempt regardless of the configured retry count.
    ///
    /// # Errors
    ///
    /// Same as [`Connection::execute_command`].
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn execute_command_without_retries<R: McuMgrCommand>(
        &self,
        request: &R,
    ) -> Result<R::Response, ExecuteError> {
        self.execute_command_impl(request, false)
    }

    /// Shared implementation of [`Connection::execute_command`] and
    /// [`Connection::execute_command_without_retries`].
    ///
    /// Holds the connection lock for the whole exchange. It encodes
    /// `request.data()` as CBOR into the send buffer and transceives it, using
    /// the configured retries when `use_retries` is set. It then checks the
    /// reply for an error report and finally decodes it as `R::Response`.
    fn execute_command_impl<R: McuMgrCommand>(
        &self,
        request: &R,
        use_retries: bool,
    ) -> Result<R::Response, ExecuteError> {
        let mut lock_guard = self.lock_inner();
        let locked_self: &mut Inner = &mut lock_guard;

        // The cursor cannot grow past the fixed-size send buffer, so an
        // oversized request makes encoding fail (reported as `EncodeFailed`).
        let mut cursor = Cursor::new(locked_self.send_buffer.as_mut_slice());
        ciborium::into_writer(request.data(), &mut cursor)
            .into_diagnostic()
            .map_err(Into::into)
            .map_err(ExecuteError::EncodeFailed)?;
        let data_size = cursor.position() as usize;
        let data = &locked_self.send_buffer[..data_size];
        log::debug!("TX data: {}", hex::encode(data));

        let write_operation = request.is_write_operation();
        let group_id = request.group_id();
        let command_id = request.command_id();
        let num_retries = if use_retries { locked_self.retries } else { 0 };
        let response = locked_self.transceiver.transceive_command_with_retries(
            write_operation,
            group_id,
            command_id,
            data,
            num_retries,
        )?;
        log::debug!("RX data: {}", hex::encode(response));

        // Decode only the error fields first, so a device-side error is reported
        // as `ErrorResponse` rather than as a failure to decode `R::Response`.
        let err: ErrResponse = ciborium::from_reader(Cursor::new(response))
            .into_diagnostic()
            .map_err(Into::into)
            .map_err(ExecuteError::DecodeFailed)?;
        // A V2 `err` entry is checked before the V1 `rc` field.
        if let Some(ErrResponseV2 { rc, group }) = err.err {
            return Err(ExecuteError::ErrorResponse(DeviceError::V2 { group, rc }));
        }
        if let Some(rc) = err.rc {
            if rc != MCUmgrErr::MGMT_ERR_EOK as i32 {
                return Err(ExecuteError::ErrorResponse(DeviceError::V1 {
                    rc,
                    rsn: err.rsn,
                }));
            }
        }

        let decoded_response: R::Response = ciborium::from_reader(Cursor::new(response))
            .into_diagnostic()
            .map_err(Into::into)
            .map_err(ExecuteError::DecodeFailed)?;
        Ok(decoded_response)
    }

    /// Sends pre-encoded request bytes and returns the raw response bytes.
    ///
    /// `data` is sent as-is. The response is neither decoded nor checked for
    /// device error codes; it is copied out of the internal receive buffer into
    /// a new allocation. When `use_retries` is set, failed transmissions are
    /// retried as configured via [`Connection::set_retries`]; otherwise exactly
    /// one attempt is made.
    ///
    /// # Errors
    ///
    /// [`ExecuteError::SendFailed`] / [`ExecuteError::ReceiveFailed`] if the
    /// last transmission attempt fails.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn execute_raw_command(
        &self,
        write_operation: bool,
        group_id: u16,
        command_id: u8,
        data: &[u8],
        use_retries: bool,
    ) -> Result<Box<[u8]>, ExecuteError> {
        let mut lock_guard = self.lock_inner();
        let locked_self: &mut Inner = &mut lock_guard;
        let num_retries = if use_retries { locked_self.retries } else { 0 };

        // Log the exchange the same way typed commands are logged.
        log::debug!("TX data: {}", hex::encode(data));
        let response = locked_self.transceiver.transceive_command_with_retries(
            write_operation,
            group_id,
            command_id,
            data,
            num_retries,
        )?;
        log::debug!("RX data: {}", hex::encode(response));
        Ok(Box::from(response))
    }
}