tinymqtt/
lib.rs

1#![no_std]
2
3use num_derive::{FromPrimitive, ToPrimitive};
4use num_traits::FromPrimitive;
5
6use self::flags::Flags;
7use self::reader_writer::{MqttMessageReader, MqttMessageWriter};
8
9pub mod flags;
10pub mod reader_writer;
11
12pub type MqttResult<T> = Result<T, MqttError>;
13
14pub struct MqttClient<const N: usize> {
15    state: MqttState,
16    packet_counter: PacketIdCounter,
17
18    construct_buffer: [u8; N],
19    message_buffer: [u8; N],
20}
21
22impl<const N: usize> MqttClient<N> {
23    pub fn new() -> Self {
24        Self {
25            state: MqttState::Disconnected,
26            packet_counter: PacketIdCounter::new(),
27            construct_buffer: [0; N],
28            message_buffer: [0; N],
29        }
30    }
31
32    pub fn connect(
33        &mut self,
34        client_id: &str,
35        username_password: Option<(&str, &str)>,
36    ) -> MqttResult<&[u8]> {
37        let mut writer = MqttMessageWriter::new(&mut self.construct_buffer);
38
39        // Connect flags
40        let mut flags = Flags::zero();
41        flags.set(1); // clean start
42        if username_password.is_some() {
43            flags.set(6).set(7); // user name, password
44        }
45
46        // variable header
47        writer.write_string("MQTT");
48        writer.write_u8(0x05); // Protocol version
49        writer.write_flags(flags); // Connect flags
50        writer.write_u16(0); // Keep alive turned off
51        writer.write_u8(0); // No properties
52
53        // payload
54        writer.write_string(client_id);
55
56        if let Some((username, password)) = username_password {
57            writer.write_string(username);
58            writer.write_string(password);
59        }
60
61        self.state = MqttState::Connecting;
62
63        let len = writer.len();
64        self.write_packet(ControlPacketType::CONNECT, Flags::zero(), len)
65    }
66
67    pub fn publish(&mut self, topic: &str, payload: &[u8]) -> Result<&[u8], MqttError> {
68        self.assert_state(MqttState::Connected)?; // Check if the client is connected
69
70        let mut writer = MqttMessageWriter::new(&mut self.construct_buffer);
71
72        // variable header
73        writer.write_string(topic);
74        writer.write_u16(0); // Packet identifier
75        writer.write_u8(0); // No properties
76
77        // payload
78        writer.write_bytes_raw(payload);
79
80        let len = writer.len();
81        self.write_packet(ControlPacketType::PUBLISH, Flags::zero(), len)
82    }
83
84    pub fn subscribe(&mut self, topic_filter: &str) -> Result<&[u8], MqttError> {
85        self.assert_state(MqttState::Connected)?; // Check if the client is connected
86
87        let mut writer = MqttMessageWriter::new(&mut self.construct_buffer);
88
89        // variable header
90        writer.write_u16(self.packet_counter.next()); // Packet identifier
91        writer.write_u8(0); // No properties
92
93        // payload
94        writer.write_string(topic_filter);
95        writer.write_flags(Flags::zero()); // Subscription Options (with maximum QoS 0)
96
97        let len = writer.len();
98        self.write_packet(ControlPacketType::SUBSCRIBE, Flags::new(0b0010), len)
99    }
100
101    pub fn unsubscribe(&mut self, topic_filter: &str) -> Result<&[u8], MqttError> {
102        self.assert_state(MqttState::Connected)?; // Check if the client is connected
103
104        let mut writer = MqttMessageWriter::new(&mut self.construct_buffer);
105
106        // variable header
107        writer.write_u16(self.packet_counter.next()); // Packet identifier
108        writer.write_u8(0); // No properties
109
110        // payload
111        writer.write_string(topic_filter);
112
113        let len = writer.len();
114        self.write_packet(ControlPacketType::UNSUBSCRIBE, Flags::new(0b0010), len)
115    }
116
117    pub fn receive_packet(
118        &mut self,
119        packet: &[u8],
120        mut on_publish_rec: impl FnMut(&mut Self, &str, &[u8]) -> (),
121    ) -> Result<MqttState, MqttError> {
122        let mut reader = MqttMessageReader::new(packet);
123
124        while reader.remaining() > 0 {
125            // Parse fixed header
126            let fixed_header = reader.read_u8();
127            let ty =
128                ControlPacketType::from_u8(fixed_header >> 4).ok_or(MqttError::InvalidPacket)?;
129            let _fixed_header_flags = Flags::new(fixed_header & 0x0F);
130            let remaining_length = reader.read_variable_int() as usize;
131            reader.mark(); // Remember start of packet content so we can skip it later
132
133            match ty {
134                ControlPacketType::CONNACK => {
135                    let _connect_ack = reader.read_u8();
136                    let reason_code = reader.read_u8();
137                    if reason_code != 0 {
138                        self.state = MqttState::Disconnected;
139                        return Err(MqttError::ConnectionRefused);
140                    }
141                    self.state = MqttState::Connected;
142                }
143                ControlPacketType::SUBACK => {
144                    // Nothing to do here
145                }
146                ControlPacketType::UNSUBACK => {
147                    // Nothing to do here
148                }
149                ControlPacketType::PUBLISH => {
150                    let topic = reader.read_string();
151                    let property_length = reader.read_variable_int() as usize;
152                    reader.skip(property_length); // We don't care about properties
153                    let payload_length = remaining_length - reader.distance_from_mark();
154                    let payload = reader.read_bytes_raw(payload_length);
155                    on_publish_rec(self, topic, payload);
156                }
157                ControlPacketType::DISCONNECT => {
158                    self.state = MqttState::Disconnected;
159                    return Err(MqttError::Disconnected);
160                }
161                _ => {
162                    return Err(MqttError::InvalidPacket);
163                }
164            }
165            reader.skip_to(remaining_length);
166        }
167
168        Ok(self.state)
169    }
170
171    pub fn is_connected(&self) -> bool {
172        self.state == MqttState::Connected
173    }
174
175    #[inline(always)]
176    fn assert_state(&self, state: MqttState) -> MqttResult<()> {
177        if self.state != state {
178            Err(MqttError::InvalidState)
179        } else {
180            Ok(())
181        }
182    }
183
184    #[inline(always)]
185    fn write_packet(
186        &mut self,
187        ty: ControlPacketType,
188        flags: Flags,
189        payload_len: usize,
190    ) -> MqttResult<&[u8]> {
191        let mut writer = MqttMessageWriter::new(&mut self.message_buffer);
192        writer.write_u8((ty as u8) << 4 | flags.value);
193        writer.write_variable_int(payload_len as u32);
194        writer.write_bytes_raw(&self.construct_buffer[..payload_len]);
195        let len = writer.len();
196        Ok(&self.message_buffer[..len])
197    }
198}
199
/// Connection lifecycle of an [`MqttClient`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MqttState {
    /// No session: the initial state, and the state after a refused CONNECT or a DISCONNECT.
    Disconnected,
    /// A CONNECT packet has been built and the client is waiting for CONNACK.
    Connecting,
    /// A successful CONNACK has been received.
    Connected,
}
206
/// Errors reported by [`MqttClient`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MqttError {
    /// The operation requires a state the client is not in (e.g. publishing before CONNACK).
    InvalidState,
    /// A received packet had an unsupported or unknown type, or was malformed.
    InvalidPacket,
    /// The broker answered CONNECT with a non-zero CONNACK reason code.
    ConnectionRefused,
    /// The broker sent a DISCONNECT packet.
    Disconnected,
}

impl core::fmt::Display for MqttError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            MqttError::InvalidState => "operation not valid in the current client state",
            MqttError::InvalidPacket => "invalid or unsupported MQTT packet",
            MqttError::ConnectionRefused => "connection refused by broker",
            MqttError::Disconnected => "disconnected by broker",
        };
        f.write_str(msg)
    }
}
214
/// MQTT delivery guarantee levels. The discriminants match the on-wire QoS values.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MqttQoS {
    /// QoS 0: fire and forget.
    AtMostOnce = 0,
    /// QoS 1: acknowledged delivery; duplicates are possible.
    AtLeastOnce = 1,
    /// QoS 2: four-step handshake guaranteeing single delivery.
    ExactlyOnce = 2,
}
221
222#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)]
223pub enum MqttProperty {
224    PayloadFormatIndicator = 0x01,
225    MessageExpiryInterval = 0x02,
226    ContentType = 0x03,
227    ResponseTopic = 0x08,
228    CorrelationData = 0x09,
229    SubscriptionIdentifier = 0x0B,
230    ReasonString = 0x1F,
231}
232
/// Hands out MQTT packet identifiers in sequence, starting at 1 and wrapping from
/// `u16::MAX` back to 1.
struct PacketIdCounter {
    /// The identifier the next call to `next` will return.
    counter: u16,
}

impl PacketIdCounter {
    /// Creates a counter whose first identifier is 1.
    pub fn new() -> Self {
        PacketIdCounter { counter: 1 }
    }

    /// Returns the current identifier and advances the counter, skipping 0 on wrap-around.
    pub fn next(&mut self) -> u16 {
        let current = self.counter;
        self.counter = current.checked_add(1).unwrap_or(1);
        current
    }
}
248
249#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive)]
250pub enum ControlPacketType {
251    CONNECT = 1,
252    CONNACK = 2,
253    PUBLISH = 3,
254    PUBACK = 4,
255    PUBREC = 5,
256    PUBREL = 6,
257    PUBCOMP = 7,
258    SUBSCRIBE = 8,
259    SUBACK = 9,
260    UNSUBSCRIBE = 10,
261    UNSUBACK = 11,
262    PINGREQ = 12,
263    PINGRESP = 13,
264    DISCONNECT = 14,
265    AUTH = 15,
266}