1use crate::error::{ConnectReasonCode, MqttError, ProtocolError};
7use crate::packet::{
8 self, ConnAck, Connect, DecodePacket, Disconnect, EncodePacket, MqttPacket, PingReq, Publish,
9 QoS, SubAck, Subscribe,
10};
11use crate::transport::{self, MqttTransport};
12use embassy_time::{Duration, Instant};
13use heapless::{String, Vec};
14
/// MQTT protocol revision used when encoding and decoding packets.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MqttVersion {
    /// MQTT v3 framing; the default chosen by `MqttOptions::new`.
    V3,
    /// MQTT v5 framing; selectable via `MqttOptions::with_version` when the `v5` feature is on.
    V5,
}
22
23pub struct MqttOptions<'a> {
25 client_id: &'a str,
26 broker_addr: &'a str,
27 broker_port: u16,
28 version: MqttVersion,
29 keep_alive: Duration,
30}
31
32impl<'a> MqttOptions<'a> {
33 pub fn new(client_id: &'a str, broker_addr: &'a str, broker_port: u16) -> Self {
34 Self { client_id, broker_addr, broker_port, version: MqttVersion::V3, keep_alive: Duration::from_secs(60) }
35 }
36 #[cfg(feature = "v5")]
37 pub fn with_version(mut self, version: MqttVersion) -> Self { self.version = version; self }
38 pub fn with_keep_alive(mut self, keep_alive: Duration) -> Self { self.keep_alive = keep_alive; self }
39}
40
/// Lifecycle of the broker session as tracked by `MqttClient`.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ConnectionState {
    /// No active session; the state a freshly constructed client starts in.
    Disconnected,
    /// `connect` has started the CONNECT/CONNACK handshake.
    Connecting,
    /// The broker accepted the session; packets may be sent.
    Connected,
}
49
50pub struct MqttClient<'a, T, const MAX_TOPICS: usize, const BUF_SIZE: usize>
52where
53 T: MqttTransport,
54{
55 transport: T,
56 options: MqttOptions<'a>,
57 tx_buffer: [u8; BUF_SIZE],
58 rx_buffer: [u8; BUF_SIZE],
59 state: ConnectionState,
60 last_tx_time: Instant,
61 next_packet_id: u16,
62}
63
64impl<'a, T, const MAX_TOPICS: usize, const BUF_SIZE: usize>
65MqttClient<'a, T, MAX_TOPICS, BUF_SIZE>
66where
67 T: MqttTransport,
68{
69 pub fn new(transport: T, options: MqttOptions<'a>) -> Self {
70 Self { transport, options, tx_buffer: [0; BUF_SIZE], rx_buffer: [0; BUF_SIZE], state: ConnectionState::Disconnected, last_tx_time: Instant::now(), next_packet_id: 1 }
71 }
72
73 pub async fn connect(&mut self) -> Result<(), MqttError<T::Error>>
75 where
76 T::Error: transport::TransportError,
77 {
78 self.state = ConnectionState::Connecting;
79 let connect_packet = Connect::new(
80 self.options.client_id,
81 self.options.keep_alive.as_secs() as u16,
82 true,
83 );
84 let len = connect_packet
85 .encode(&mut self.tx_buffer, self.options.version)
86 .map_err(MqttError::cast_transport_error)?;
87 self.transport.send(&self.tx_buffer[..len]).await?;
88 let n = self.transport.recv(&mut self.rx_buffer).await?;
89 let packet = packet::decode::<T::Error>(&self.rx_buffer[..n], self.options.version)?
90 .ok_or(MqttError::Protocol(ProtocolError::InvalidResponse))?;
91 if let MqttPacket::ConnAck(connack) = packet {
92 if connack.reason_code == 0 {
93 self.state = ConnectionState::Connected;
94 self.last_tx_time = Instant::now();
95 Ok(())
96 } else {
97 self.state = ConnectionState::Disconnected;
98 Err(MqttError::ConnectionRefused(connack.reason_code.into()))
99 }
100 } else {
101 self.state = ConnectionState::Disconnected;
102 Err(MqttError::Protocol(ProtocolError::InvalidResponse))
103 }
104 }
105
106 pub async fn publish<'p>(
108 &mut self,
109 _topic: &'p str,
110 _payload: &'p [u8],
111 _qos: QoS,
112 ) -> Result<(), MqttError<T::Error>>
113 where
114 T::Error: transport::TransportError,
115 {
116 Ok(())
117 }
118
119 async fn _send_packet<P>(&mut self, packet: P) -> Result<(), MqttError<T::Error>>
121 where
122 P: EncodePacket,
123 T::Error: transport::TransportError,
124 {
125 if self.state != ConnectionState::Connected {
126 return Err(MqttError::NotConnected);
127 }
128 let len = packet
129 .encode(&mut self.tx_buffer, self.options.version)
130 .map_err(MqttError::cast_transport_error)?;
131 self.transport.send(&self.tx_buffer[..len]).await?;
132 self.last_tx_time = Instant::now();
133 Ok(())
134 }
135
136 pub async fn poll<'p>(&'p mut self) -> Result<Option<MqttEvent<'p>>, MqttError<T::Error>>
141 where
142 T::Error: transport::TransportError,
143 {
144 if self.state != ConnectionState::Connected { return Err(MqttError::NotConnected); }
145
146 if self.last_tx_time.elapsed() >= self.options.keep_alive {
147 self._send_packet(PingReq).await?;
148 }
149
150 let n = self.transport.recv(&mut self.rx_buffer).await?;
151 if n > 0 {
152 if let Some(packet) = packet::decode::<T::Error>(&self.rx_buffer[..n], self.options.version)? {
153 if let MqttPacket::Publish(p) = packet {
154 return Ok(Some(MqttEvent::Publish(p)));
156 }
157 }
158 }
159 Ok(None)
160 }
161
162 fn get_next_packet_id(&mut self) -> u16 {
163 self.next_packet_id = self.next_packet_id.wrapping_add(1);
164 if self.next_packet_id == 0 { self.next_packet_id = 1; }
165 self.next_packet_id
166 }
167}
168
169#[derive(Debug)]
173pub enum MqttEvent<'p> {
174 Publish(Publish<'p>),
175}
176