use crate::error::{ConnectReasonCode, MqttError, ProtocolError};
use crate::packet::{
self, ConnAck, Connect, DecodePacket, Disconnect, EncodePacket, MqttPacket, PingReq, Publish,
QoS, SubAck, Subscribe,
};
use crate::transport::{self, MqttTransport};
use embassy_time::{Duration, Instant};
use heapless::{String, Vec};
/// MQTT protocol revision the client encodes and decodes packets for.
///
/// The selected value is handed to every packet `encode` / `decode` call made
/// by the client.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum MqttVersion {
    /// Protocol version 3.
    V3,
    /// Protocol version 5.
    V5,
}
/// Connection settings borrowed for the lifetime `'a` of the client.
///
/// Build one with [`MqttOptions::new`] and refine it with the `with_*`
/// builder methods.
pub struct MqttOptions<'a> {
    /// Client identifier placed in the CONNECT packet.
    client_id: &'a str,
    /// Broker address string as supplied by the caller.
    broker_addr: &'a str,
    /// Broker port as supplied by the caller.
    broker_port: u16,
    /// Protocol revision used when encoding and decoding packets.
    version: MqttVersion,
    /// Keep-alive interval; its whole-second value is sent in CONNECT and it
    /// paces PINGREQ packets while polling.
    keep_alive: Duration,
}
impl<'a> MqttOptions<'a> {
    /// Creates options for the given client id and broker endpoint.
    ///
    /// Defaults: protocol [`MqttVersion::V3`] and a keep-alive of 60 seconds.
    pub fn new(client_id: &'a str, broker_addr: &'a str, broker_port: u16) -> Self {
        let version = MqttVersion::V3;
        let keep_alive = Duration::from_secs(60);
        Self {
            client_id,
            broker_addr,
            broker_port,
            version,
            keep_alive,
        }
    }

    /// Returns the options with the protocol version replaced.
    ///
    /// Only compiled when the `v5` feature is enabled.
    #[cfg(feature = "v5")]
    pub fn with_version(self, version: MqttVersion) -> Self {
        Self { version, ..self }
    }

    /// Returns the options with the keep-alive interval replaced.
    pub fn with_keep_alive(self, keep_alive: Duration) -> Self {
        Self { keep_alive, ..self }
    }
}
/// Internal connection lifecycle tracked by the client.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
enum ConnectionState {
    /// No session with the broker; the initial state of a new client.
    Disconnected,
    /// A connection attempt has been started and has not completed yet.
    Connecting,
    /// The broker accepted the connection.
    Connected,
}
/// MQTT client that runs over a transport `T` using fixed-size buffers.
///
/// `BUF_SIZE` is the length in bytes of each of the two internal buffers.
/// `MAX_TOPICS` is part of the type but is not referenced by the fields below.
pub struct MqttClient<'a, T: MqttTransport, const MAX_TOPICS: usize, const BUF_SIZE: usize> {
    /// Transport used to send and receive raw packet bytes.
    transport: T,
    /// Connection settings supplied at construction.
    options: MqttOptions<'a>,
    /// Scratch space outgoing packets are encoded into before sending.
    tx_buffer: [u8; BUF_SIZE],
    /// Storage for bytes received from the transport.
    rx_buffer: [u8; BUF_SIZE],
    /// Current connection lifecycle state.
    state: ConnectionState,
    /// Time of the most recent successful transmission, used for keep-alive.
    last_tx_time: Instant,
    /// Counter backing packet identifier allocation; starts at 1.
    next_packet_id: u16,
}
impl<'a, T, const MAX_TOPICS: usize, const BUF_SIZE: usize>
MqttClient<'a, T, MAX_TOPICS, BUF_SIZE>
where
T: MqttTransport,
{
pub fn new(transport: T, options: MqttOptions<'a>) -> Self {
Self { transport, options, tx_buffer: [0; BUF_SIZE], rx_buffer: [0; BUF_SIZE], state: ConnectionState::Disconnected, last_tx_time: Instant::now(), next_packet_id: 1 }
}
pub async fn connect(&mut self) -> Result<(), MqttError<T::Error>>
where
T::Error: transport::TransportError,
{
self.state = ConnectionState::Connecting;
let connect_packet = Connect::new(
self.options.client_id,
self.options.keep_alive.as_secs() as u16,
true,
);
let len = connect_packet
.encode(&mut self.tx_buffer, self.options.version)
.map_err(MqttError::cast_transport_error)?;
self.transport.send(&self.tx_buffer[..len]).await?;
let n = self.transport.recv(&mut self.rx_buffer).await?;
let packet = packet::decode::<T::Error>(&self.rx_buffer[..n], self.options.version)?
.ok_or(MqttError::Protocol(ProtocolError::InvalidResponse))?;
if let MqttPacket::ConnAck(connack) = packet {
if connack.reason_code == 0 {
self.state = ConnectionState::Connected;
self.last_tx_time = Instant::now();
Ok(())
} else {
self.state = ConnectionState::Disconnected;
Err(MqttError::ConnectionRefused(connack.reason_code.into()))
}
} else {
self.state = ConnectionState::Disconnected;
Err(MqttError::Protocol(ProtocolError::InvalidResponse))
}
}
pub async fn publish<'p>(
&mut self,
_topic: &'p str,
_payload: &'p [u8],
_qos: QoS,
) -> Result<(), MqttError<T::Error>>
where
T::Error: transport::TransportError,
{
Ok(())
}
async fn _send_packet<P>(&mut self, packet: P) -> Result<(), MqttError<T::Error>>
where
P: EncodePacket,
T::Error: transport::TransportError,
{
if self.state != ConnectionState::Connected {
return Err(MqttError::NotConnected);
}
let len = packet
.encode(&mut self.tx_buffer, self.options.version)
.map_err(MqttError::cast_transport_error)?;
self.transport.send(&self.tx_buffer[..len]).await?;
self.last_tx_time = Instant::now();
Ok(())
}
pub async fn poll<'p>(&'p mut self) -> Result<Option<MqttEvent<'p>>, MqttError<T::Error>>
where
T::Error: transport::TransportError,
{
if self.state != ConnectionState::Connected { return Err(MqttError::NotConnected); }
if self.last_tx_time.elapsed() >= self.options.keep_alive {
self._send_packet(PingReq).await?;
}
let n = self.transport.recv(&mut self.rx_buffer).await?;
if n > 0 {
if let Some(packet) = packet::decode::<T::Error>(&self.rx_buffer[..n], self.options.version)? {
if let MqttPacket::Publish(p) = packet {
return Ok(Some(MqttEvent::Publish(p)));
}
}
}
Ok(None)
}
fn get_next_packet_id(&mut self) -> u16 {
self.next_packet_id = self.next_packet_id.wrapping_add(1);
if self.next_packet_id == 0 { self.next_packet_id = 1; }
self.next_packet_id
}
}
/// Events returned to the caller by `MqttClient::poll`.
///
/// The lifetime `'p` is the lifetime of the wrapped [`Publish`] packet.
#[derive(Debug)]
pub enum MqttEvent<'p> {
    /// A PUBLISH packet was received.
    Publish(Publish<'p>),
}