// mqtt_async_embedded/client.rs

//! # The Asynchronous MQTT Client
//!
//! This module contains the primary [`MqttClient`] struct, which manages the connection
//! state of, and all communication with, an MQTT broker.
5
use crate::error::{ConnectReasonCode, MqttError, ProtocolError};
use crate::packet::{
    self, ConnAck, Connect, DecodePacket, Disconnect, EncodePacket, MqttPacket, PingReq, Publish,
    QoS, SubAck, Subscribe,
};
use crate::transport::{self, MqttTransport};
use embassy_time::{with_timeout, Duration, Instant};
use heapless::{String, Vec};
14
/// MQTT protocol revision spoken by the client.
///
/// The selected revision is handed to packet encoding and decoding so that the
/// bytes on the wire match what the broker expects.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum MqttVersion {
    /// MQTT 3.x — the revision chosen by `MqttOptions::new`.
    V3,
    /// MQTT 5 — selectable through `MqttOptions::with_version` (requires the `v5` feature).
    V5,
}
22
23/// Configuration options for the `MqttClient`.
24pub struct MqttOptions<'a> {
25    client_id: &'a str,
26    broker_addr: &'a str,
27    broker_port: u16,
28    version: MqttVersion,
29    keep_alive: Duration,
30}
31
32impl<'a> MqttOptions<'a> {
33    pub fn new(client_id: &'a str, broker_addr: &'a str, broker_port: u16) -> Self {
34        Self { client_id, broker_addr, broker_port, version: MqttVersion::V3, keep_alive: Duration::from_secs(60) }
35    }
36    #[cfg(feature = "v5")]
37    pub fn with_version(mut self, version: MqttVersion) -> Self { self.version = version; self }
38    pub fn with_keep_alive(mut self, keep_alive: Duration) -> Self { self.keep_alive = keep_alive; self }
39}
40
/// Lifecycle of the client's session with the broker.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum ConnectionState {
    /// No session: the initial state, and the state after a failed connection attempt.
    Disconnected,
    /// The CONNECT/CONNACK handshake is in progress.
    Connecting,
    /// The broker accepted the connection; packets may be exchanged.
    Connected,
}
49
50/// The asynchronous MQTT client.
51pub struct MqttClient<'a, T, const MAX_TOPICS: usize, const BUF_SIZE: usize>
52where
53    T: MqttTransport,
54{
55    transport: T,
56    options: MqttOptions<'a>,
57    tx_buffer: [u8; BUF_SIZE],
58    rx_buffer: [u8; BUF_SIZE],
59    state: ConnectionState,
60    last_tx_time: Instant,
61    next_packet_id: u16,
62}
63
64impl<'a, T, const MAX_TOPICS: usize, const BUF_SIZE: usize>
65MqttClient<'a, T, MAX_TOPICS, BUF_SIZE>
66where
67    T: MqttTransport,
68{
69    pub fn new(transport: T, options: MqttOptions<'a>) -> Self {
70        Self { transport, options, tx_buffer: [0; BUF_SIZE], rx_buffer: [0; BUF_SIZE], state: ConnectionState::Disconnected, last_tx_time: Instant::now(), next_packet_id: 1 }
71    }
72
73    /// Attempts to connect to the MQTT broker.
74    pub async fn connect(&mut self) -> Result<(), MqttError<T::Error>>
75    where
76        T::Error: transport::TransportError,
77    {
78        self.state = ConnectionState::Connecting;
79        let connect_packet = Connect::new(
80            self.options.client_id,
81            self.options.keep_alive.as_secs() as u16,
82            true,
83        );
84        let len = connect_packet
85            .encode(&mut self.tx_buffer, self.options.version)
86            .map_err(MqttError::cast_transport_error)?;
87        self.transport.send(&self.tx_buffer[..len]).await?;
88        let n = self.transport.recv(&mut self.rx_buffer).await?;
89        let packet = packet::decode::<T::Error>(&self.rx_buffer[..n], self.options.version)?
90            .ok_or(MqttError::Protocol(ProtocolError::InvalidResponse))?;
91        if let MqttPacket::ConnAck(connack) = packet {
92            if connack.reason_code == 0 {
93                self.state = ConnectionState::Connected;
94                self.last_tx_time = Instant::now();
95                Ok(())
96            } else {
97                self.state = ConnectionState::Disconnected;
98                Err(MqttError::ConnectionRefused(connack.reason_code.into()))
99            }
100        } else {
101            self.state = ConnectionState::Disconnected;
102            Err(MqttError::Protocol(ProtocolError::InvalidResponse))
103        }
104    }
105
106    /// Publishes a message to a topic.
107    pub async fn publish<'p>(
108        &mut self,
109        _topic: &'p str,
110        _payload: &'p [u8],
111        _qos: QoS,
112    ) -> Result<(), MqttError<T::Error>>
113    where
114        T::Error: transport::TransportError,
115    {
116        Ok(())
117    }
118
119    /// Sends a pre-constructed packet over the transport.
120    async fn _send_packet<P>(&mut self, packet: P) -> Result<(), MqttError<T::Error>>
121    where
122        P: EncodePacket,
123        T::Error: transport::TransportError,
124    {
125        if self.state != ConnectionState::Connected {
126            return Err(MqttError::NotConnected);
127        }
128        let len = packet
129            .encode(&mut self.tx_buffer, self.options.version)
130            .map_err(MqttError::cast_transport_error)?;
131        self.transport.send(&self.tx_buffer[..len]).await?;
132        self.last_tx_time = Instant::now();
133        Ok(())
134    }
135
136    /// Polls the connection for incoming packets and handles keep-alives.
137    ///
138    /// The returned `MqttEvent` contains references to the client's internal receive
139    /// buffer. These references are only valid until the next call to `poll`.
140    pub async fn poll<'p>(&'p mut self) -> Result<Option<MqttEvent<'p>>, MqttError<T::Error>>
141    where
142        T::Error: transport::TransportError,
143    {
144        if self.state != ConnectionState::Connected { return Err(MqttError::NotConnected); }
145
146        if self.last_tx_time.elapsed() >= self.options.keep_alive {
147            self._send_packet(PingReq).await?;
148        }
149
150        let n = self.transport.recv(&mut self.rx_buffer).await?;
151        if n > 0 {
152            if let Some(packet) = packet::decode::<T::Error>(&self.rx_buffer[..n], self.options.version)? {
153                if let MqttPacket::Publish(p) = packet {
154                    // The event is valid for the lifetime 'p of the poll borrow
155                    return Ok(Some(MqttEvent::Publish(p)));
156                }
157            }
158        }
159        Ok(None)
160    }
161
162    fn get_next_packet_id(&mut self) -> u16 {
163        self.next_packet_id = self.next_packet_id.wrapping_add(1);
164        if self.next_packet_id == 0 { self.next_packet_id = 1; }
165        self.next_packet_id
166    }
167}
168
169/// Represents an event received from the MQTT broker.
170/// The lifetime `'p` indicates that the event borrows data from the client's
171/// buffer and is only valid for the duration of the `poll` call.
172#[derive(Debug)]
173pub enum MqttEvent<'p> {
174    Publish(Publish<'p>),
175}
176