mqtt_async_embedded/
transport.rs

//! # MQTT Transport Abstraction
//!
//! This module defines the `MqttTransport` trait, which abstracts the underlying
//! communication channel (such as TCP or UART), allowing the MQTT client to be
//! hardware- and network-stack-agnostic.
//!
//! `async fn` in traits has been stable since Rust 1.75, so this trait uses native
//! `async fn` methods and does not need the `#[async_trait]` macro.
9
10
/// A placeholder error type used in contexts where the actual transport error is not known,
/// such as in the `EncodePacket` trait.
///
/// It carries no data, so it is freely copyable, comparable and hashable, and
/// `ErrorPlaceHolder::default()` is the same value as `ErrorPlaceHolder`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ErrorPlaceHolder;
16
/// A trait representing a transport for MQTT packets.
///
/// This trait abstracts over any reliable, ordered, stream-based communication channel.
// `async fn` in traits has been stable since Rust 1.75 (not tied to the 2024 edition),
// so `#[async_trait]` is not needed. rustc's `async_fn_in_trait` lint warns that callers
// of a public trait cannot require the returned futures to be `Send`. The lint is
// allowed here on the assumption that transports are driven by single-threaded embedded
// executors — NOTE(review): confirm no caller needs `Send` futures.
#[allow(async_fn_in_trait)]
pub trait MqttTransport {
    /// The error type returned by the transport.
    type Error: core::fmt::Debug;

    /// Sends a buffer of data over the transport.
    ///
    /// No byte count is returned, so on `Ok(())` callers may assume the whole of
    /// `buf` was handed to the channel.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the underlying channel fails.
    async fn send(&mut self, buf: &[u8]) -> Result<(), Self::Error>;

    /// Receives data from the transport into `buf`.
    ///
    /// Returns the number of bytes written into the front of `buf`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the underlying channel fails.
    async fn recv(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
}
33
34// Allow the placeholder to be treated as a transport error for generic contexts.
35impl TransportError for ErrorPlaceHolder {}
36
37/// A marker trait for transport-related errors.
38pub trait TransportError: core::fmt::Debug {}
39
// NOTE(review): this file has no `use` declarations, yet the items below refer to
// `Duration`, `Timer`, `MqttError` and the `write_all` method unqualified. They presumably
// come from `embassy_time`, the crate's `error` module and `embedded_io_async::Write`,
// and must be imported before the `transport-smoltcp` feature compiles — TODO confirm paths.

/// An example TCP transport implementation using `embassy-net`.
///
/// Each `recv` is bounded by a timeout; `send` is not time-limited.
#[cfg(feature = "transport-smoltcp")]
pub struct TcpTransport<'a> {
    /// The `embassy-net` TCP socket carrying the MQTT byte stream
    /// (presumably already connected by the caller).
    socket: embassy_net::tcp::TcpSocket<'a>,
    /// How long a single `recv` waits for data before failing with `MqttError::Timeout`.
    timeout: Duration,
}

#[cfg(feature = "transport-smoltcp")]
impl<'a> TcpTransport<'a> {
    /// Creates a new `TcpTransport` with the given socket and timeout.
    ///
    /// `timeout` bounds every call to `recv`; `send` is not affected by it.
    pub fn new(socket: embassy_net::tcp::TcpSocket<'a>, timeout: Duration) -> Self {
        Self { socket, timeout }
    }

    /// Reads into `buf`, racing the socket read against a `self.timeout` timer.
    ///
    /// Returns the number of bytes read. The count is non-zero whenever `buf` is
    /// non-empty. An empty `buf` returns `Ok(0)` without touching the socket.
    ///
    /// # Errors
    ///
    /// - `MqttError::Timeout` if the timer fires before any data arrives.
    /// - `MqttError::Protocol(ProtocolError::InvalidResponse)` if the read yields 0 bytes,
    ///   which signals that the peer closed the connection.
    /// - `MqttError::Transport` for any error reported by the socket.
    async fn read_with_timeout(
        &mut self,
        buf: &mut [u8],
    ) -> Result<usize, MqttError<embassy_net::tcp::Error>> {
        // With an empty buffer a 0-byte read is indistinguishable from the peer closing
        // the connection, so answer directly instead of reporting a bogus disconnect.
        if buf.is_empty() {
            return Ok(0);
        }

        let timeout = self.timeout;
        // `futures::future::select` only accepts `Unpin` futures. The socket's read
        // future is not `Unpin`, so both futures are pinned on the stack first.
        let read = core::pin::pin!(self.socket.read(buf));
        let timer = core::pin::pin!(Timer::after(timeout));

        match futures::future::select(read, timer).await {
            futures::future::Either::Left((Ok(0), _)) => Err(MqttError::Protocol(
                super::error::ProtocolError::InvalidResponse,
            )),
            futures::future::Either::Left((Ok(n), _)) => Ok(n),
            futures::future::Either::Left((Err(e), _)) => Err(MqttError::Transport(e)),
            futures::future::Either::Right(_) => Err(MqttError::Timeout),
        }
    }
}

#[cfg(feature = "transport-smoltcp")]
impl MqttTransport for TcpTransport<'_> {
    type Error = MqttError<embassy_net::tcp::Error>;

    // Writes the entire buffer; socket errors are wrapped in `MqttError::Transport`.
    async fn send(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
        self.socket.write_all(buf).await.map_err(MqttError::Transport)
    }

    // Delegates to the timeout-bounded read helper.
    async fn recv(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        self.read_with_timeout(buf).await
    }
}
96