/// Field-less marker type that can be used wherever a transport error type is required.
///
/// It carries no data, so it is zero-sized and trivially `Copy`. With the `defmt` feature
/// enabled it additionally derives `defmt::Format`.
///
/// NOTE(review): judging by the name this is a stand-in for a real error type — confirm with
/// the call sites before relying on it in new code.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ErrorPlaceHolder;
/// Asynchronous, byte-oriented transport interface: one method to send a buffer and one to
/// receive into a buffer.
///
/// Each implementor picks its own error type through the associated `Error` type.
pub trait MqttTransport {
/// Error returned by `send` and `recv`; the only requirement is `core::fmt::Debug`.
type Error: core::fmt::Debug;
/// Sends the bytes in `buf`, returning `Ok(())` on success or the implementor's error.
async fn send(&mut self, buf: &[u8]) -> Result<(), Self::Error>;
/// Receives data into `buf`, returning a `usize` on success or the implementor's error.
///
/// NOTE(review): the `usize` is presumably the number of bytes written into `buf`; the trait
/// itself does not spell this out — confirm against the implementations.
async fn recv(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error>;
}
// Marks `ErrorPlaceHolder` as a transport error. `TransportError` declares no methods, so the
// impl body is empty.
impl TransportError for ErrorPlaceHolder {}
/// Marker trait for types that represent a transport-level error.
///
/// It declares no methods. Its only requirement is that the implementing type is
/// `core::fmt::Debug`, so any `T: TransportError` can be formatted with `{:?}`.
pub trait TransportError
where
    Self: core::fmt::Debug,
{
}
/// Transport backed by an `embassy_net` TCP socket, paired with a timeout value.
///
/// Only compiled when the `transport-smoltcp` feature is enabled.
#[cfg(feature = "transport-smoltcp")]
pub struct TcpTransport<'a> {
/// Socket that all reads and writes go through.
socket: embassy_net::tcp::TcpSocket<'a>,
/// Upper bound on how long a single read may wait (see `read_with_timeout`); writes do not
/// use it.
timeout: Duration,
}
#[cfg(feature = "transport-smoltcp")]
impl<'a> TcpTransport<'a> {
    /// Creates a transport from an `embassy_net` TCP socket and the timeout applied to each
    /// read.
    ///
    /// The arguments are stored as given; this constructor performs no I/O.
    pub fn new(socket: embassy_net::tcp::TcpSocket<'a>, timeout: Duration) -> Self {
        Self { socket, timeout }
    }

    /// Reads from the socket into `buf`, giving up once `self.timeout` has elapsed.
    ///
    /// The read is raced against a timer; whichever completes first decides the outcome:
    ///
    /// * `Ok(Ok(n))` — the read finished first and produced `n > 0` bytes.
    /// * `Ok(Err(MqttError::Transport(e)))` — the read finished first and failed with `e`.
    /// * `Err(MqttError::Protocol(ProtocolError::InvalidResponse))` — the read finished first
    ///   but returned zero bytes.
    /// * `Err(MqttError::Timeout)` — the timer fired before the read completed.
    ///
    /// The nested `Result` shape is kept because callers in this file match on it.
    async fn read_with_timeout<'b>(
        &'b mut self,
        buf: &'b mut [u8],
    ) -> Result<Result<usize, MqttError<embassy_net::tcp::Error>>, MqttError<embassy_net::tcp::Error>>
    {
        use futures::future::{select, Either};

        let deadline = Timer::after(self.timeout);
        // `select` only accepts `Unpin` futures. The future returned by the socket's `read`
        // is not `Unpin`, so it is pinned to the stack; `Pin<&mut F>` is `Unpin`.
        let read_fut = core::pin::pin!(self.socket.read(buf));

        // Every combination is matched explicitly, so no `unreachable!()` arm is needed.
        match select(read_fut, deadline).await {
            // NOTE(review): a zero-length read presumably means the peer closed the
            // connection; it is reported as a protocol error — confirm this is intended.
            Either::Left((Ok(0), _)) => Err(MqttError::Protocol(
                super::error::ProtocolError::InvalidResponse,
            )),
            Either::Left((Ok(n), _)) => Ok(Ok(n)),
            Either::Left((Err(e), _)) => Ok(Err(MqttError::Transport(e))),
            Either::Right(_) => Err(MqttError::Timeout),
        }
    }
}
#[cfg(feature = "transport-smoltcp")]
impl<'a> MqttTransport for TcpTransport<'a> {
    /// `MqttError` parameterised over the socket's own error type.
    type Error = MqttError<embassy_net::tcp::Error>;

    /// Writes `buf` to the socket via `write_all`.
    ///
    /// A socket failure is wrapped in `MqttError::Transport`. No timeout is applied here.
    async fn send(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
        let written = self.socket.write_all(buf).await;
        written.map_err(MqttError::Transport)
    }

    /// Reads into `buf` with the configured timeout and returns the value carried by a
    /// successful read.
    ///
    /// Both error layers of `read_with_timeout` are collapsed into the single `Err` of this
    /// method.
    async fn recv(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        // `?` forwards the outer error unchanged; what remains is the inner `Result`, which
        // already has this method's return type.
        self.read_with_timeout(buf).await?
    }
}