embedded_mqttc/state/mod.rs
1use core::future::Future;
2use core::ops::Add;
3
4use embassy_futures::select::select3;
5use embassy_sync::pubsub::{DynSubscriber, PubSubChannel};
6use embedded_nal_async::{Dns, TcpConnect};
7
8use crate::client::MqttClient;
9use crate::state::connection::{ConnectionState, TcpConnectionState};
10use crate::state::pid::next_pid;
11use crate::state::publish2::Publishes;
12use crate::state::receives2::{ReceivedPublish, Receives};
13use crate::state::request::{RequestNotification, RequestState};
14use crate::state::sub2::Subs;
15use crate::time::Duration;
16
17use embassy_sync::blocking_mutex::raw::RawMutex;
18use mqttrs2::{LastWill, Packet, Publish, QoS, QosPid};
19use ping::PingState;
20
21use crate::{ClientConfig, MqttError, MqttEvent, UniqueID, time};
22
23pub(crate) const KEEP_ALIVE: usize = 60;
24
25pub(crate) mod ping;
26
27pub(crate) mod receives2;
28
29pub mod connection;
30
31/// outgoing publishes
32pub(crate) mod publish2;
33
34pub(crate) mod sub2;
35pub(crate) mod pid;
36
37pub(crate) mod request;
38
39const RECONNECT_DURATION: Duration = Duration::from_secs(5);
40
/// Outcome of a method that writes pending packets to the network.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendResult {
    /// Only part of the pending packets could be written; the rest is still queued.
    PartiallySent,

    /// Every pending packet was written, or there was nothing to write.
    SentAll,
}
49
50impl SendResult {
51
52 pub async fn next<F, Fut>(self, f: F) -> Result<Self, MqttError>
53 where F: FnOnce() -> Fut, Fut: Future<Output = Result<Self, MqttError>> {
54 if self != Self::PartiallySent {
55 let next = self + f().await?;
56 Ok(next)
57 } else {
58 Ok(self)
59 }
60 }
61
62 pub fn next_sync<F>(self, f: F) -> Result<Self, MqttError>
63 where F: FnOnce() -> Result<Self, MqttError> {
64 if self != Self::PartiallySent {
65 let next = self + f()?;
66 Ok(next)
67 } else {
68 Ok(self)
69 }
70 }
71
72}
73
74impl Add<Self> for SendResult {
75 type Output = Self;
76
77 fn add(self, rhs: Self) -> Self::Output {
78 match (self, rhs) {
79 (Self::SentAll, Self::SentAll) => Self::SentAll,
80 _ => Self::PartiallySent
81 }
82 }
83}
84
85pub struct State<'n, 'l, M: RawMutex, NET, DNS, const BUFFER: usize, const TOPIC: usize, const QUEUE: usize>
86where NET: TcpConnect, DNS: Dns {
87
88 pub(crate) connection_state: TcpConnectionState<'n, 'l, M, NET, DNS, BUFFER>,
89
90 ping: PingState<M>,
91
92 publishes: Publishes<M, QUEUE, BUFFER, TOPIC, 4>,
93 received_publishes: Receives<M, BUFFER, TOPIC, QUEUE>,
94 subscribes: Subs<M, TOPIC, QUEUE>,
95
96 // Signal is sent, when a request is added
97 on_requst_added: RequestState<M>,
98
99 events: PubSubChannel<M, MqttEvent, 8, 16, 2>
100
101}
102
103impl <'n, 'l, M: RawMutex, NET, DNS, const BUFFER: usize, const TOPIC: usize, const QUEUE: usize> State<'n, 'l, M, NET, DNS, BUFFER, TOPIC, QUEUE>
104where NET: TcpConnect, DNS: Dns{
105
106 pub fn new(config: ClientConfig<'l>, last_will: Option<LastWill<'l>>, network: &'n NET, dns: DNS) -> Self {
107 Self {
108 connection_state: TcpConnectionState::new(network, dns, last_will, config),
109
110 ping: PingState::new(),
111
112 publishes: Publishes::new(),
113 received_publishes: Receives::new(),
114 subscribes: Subs::new(),
115
116 on_requst_added: RequestState::new(),
117 events: PubSubChannel::new(),
118 }
119 }
120
121 pub fn new_client(&self) -> MqttClient<'_, 'n, 'l, M, NET, DNS, BUFFER, TOPIC, QUEUE> {
122 MqttClient::new(self)
123 }
124
125 pub(crate) async fn publish(&self, topic: &str, payload: &[u8], qos: QoS, retain: bool, unique_id: UniqueID) -> Result<(), MqttError> {
126
127 let qospid = match qos {
128 QoS::AtMostOnce => QosPid::AtMostOnce,
129 QoS::AtLeastOnce => QosPid::AtLeastOnce(next_pid()),
130 QoS::ExactlyOnce => QosPid::ExactlyOnce(next_pid()),
131 };
132
133 debug!("state: adding publish request with qos {}", &qospid);
134
135 let publish = Publish {
136 topic_name: topic,
137 dup: false,
138 qospid,
139 retain,
140 payload
141 };
142
143 self.publishes.publish(publish, unique_id).await?;
144 self.on_requst_added.notify_new_request();
145
146 Ok(())
147 }
148
149 pub(crate) async fn subscribe(&self, topics: &[&str], qos: QoS, unique_id: UniqueID) {
150 self.subscribes.add_subscribe_request(topics, qos, unique_id).await;
151 self.on_requst_added.notify_new_request();
152 }
153
154 pub(crate) async fn unsubscribe(&self, topics: &[&str], unique_id: UniqueID) {
155 self.subscribes.add_unsubscribe_request(topics, unique_id).await;
156 self.on_requst_added.notify_new_request();
157 }
158
159 pub(crate) fn disconnect(&self) {
160 self.on_requst_added.notify_disconnect();
161 }
162
163 /// Returns true until the client disconnects
164 async fn run_once(&self) -> Result<bool, MqttError> {
165 if self.connection_state.get_state() != Some(connection::ConnectionStateValue::Connected) {
166 info!("not connected yed: starting connection");
167 self.connection_state.connect().await?;
168 }
169
170 let publisher = self.events.dyn_publisher().unwrap();
171
172 debug!("event loop: start sending packets");
173
174 let send_packet_result = self.publishes.send_packets(&self.connection_state, publisher).await?
175 .next(|| self.received_publishes.send_packets(&self.connection_state)).await?
176 .next(|| self.subscribes.send(&self.connection_state)).await?
177 .next_sync(|| self.ping.send(&self.connection_state))?;
178
179 if send_packet_result == SendResult::PartiallySent {
180 debug!("partially sent packets, run io nonblocking");
181 if let Some(packet) = self.connection_state.run_io_nonblocking().await? {
182 self.process_packet(&packet).await?;
183 }
184 // Return early to rerun the loop faster
185 return Ok(true);
186 }
187
188 debug!("sent all packets, run io blocking");
189
190 // TODO make ping and resend future
191 let ping_future = self.ping.ping_pause();
192 let io_future = self.connection_state.run_io();
193 let request_added_future = self.on_requst_added.next_notification();
194
195 match select3(request_added_future, ping_future, io_future).await {
196 embassy_futures::select::Either3::First(request) if request == RequestNotification::Disconnect => {
197 debug!("run_once: disconnect request received");
198 Ok(false)
199 },
200 embassy_futures::select::Either3::Third(packet) => {
201 debug!("run_once: received packet");
202 let packet = packet?;
203 self.process_packet(&packet).await?;
204 Ok(true)
205 },
206 _ => {
207 debug!("run_once: stop io, new event arrived");
208 Ok(true)
209 },
210 }
211 }
212
213 pub async fn run(&self) -> Result<(), MqttError> {
214 loop {
215 match self.run_once().await {
216 Ok(true) => {},
217 Ok(false) => {
218 // Disconnect
219 self.connection_state.disconnect().await?;
220 info!("disconnect: exit run loop");
221 return Ok(())
222 },
223 Err(err) => {
224 error!("connection error: {}", &err);
225 match err {
226 MqttError::ConnectionFailed2(_) |
227 MqttError::ConnackError |
228 MqttError::CodecError(_) |
229 MqttError::ReceivedMessageTooLong |
230 MqttError::QueueFull(_) |
231 MqttError::UnexpectedAck(_) => {
232 self.connection_state.set_error();
233 time::sleep(RECONNECT_DURATION).await;
234 },
235
236 err => {
237 error!("not recoverable error: stop loop");
238 return Err(err)
239 },
240 }
241 },
242 }
243 }
244 }
245
246 /// Processes incoming packets
247 async fn process_packet(&self, p: &Packet<'_>) -> Result<(), MqttError> {
248
249 let publisher = self.events.dyn_publisher().unwrap();
250
251 match p {
252
253 Packet::Connack(_connack) => {
254 panic!("received connack: this must be handled by the connection module");
255 },
256
257 Packet::Publish(publish) => {
258 self.received_publishes.on_publish(publish).await?;
259 Ok(())
260 },
261
262 Packet::Puback(_) | Packet::Pubrec(_) | Packet::Pubcomp(_) => {
263 self.publishes.process_incoming_packet(p, publisher).await?;
264 Ok(())
265 },
266
267 Packet::Pubrel(pid) => {
268 self.received_publishes.on_pubrel(*pid).await?;
269 Ok(())
270 },
271
272 Packet::Suback(suback) => {
273 self.subscribes.on_suback(suback, publisher).await?;
274 Ok(())
275 },
276
277 Packet::Unsuback(pid) => {
278 self.subscribes.on_unsuback(*pid, publisher).await;
279 Ok(())
280 },
281
282 Packet::Pingresp => {
283 self.ping.on_ping_response();
284 Ok(())
285 },
286
287 // # These Packages cannot be send Server -> Client
288 // # And are treated as unexpected
289 // Packet::Connect(connect) => todo!(),
290 // Packet::Disconnect => todo!(),
291 // Packet::Pingreq => todo!(),
292 // Packet::Unsubscribe(unsubscribe) => todo!(),
293 // Packet::Subscribe(subscribe) => todo!(),
294
295 unexpected => {
296 error!("unexpected packet {} received from broker", unexpected.get_type());
297 Ok(())
298 }
299 }
300 }
301
302 pub(crate) fn subscribe_events(&self) -> Result<DynSubscriber<'_, MqttEvent>, MqttError> {
303 self.events.dyn_subscriber().map_err(|e| e.into())
304 }
305
306 /// Subscribe to received publishes
307 pub fn subscribe_received_publishes(&self) -> Result<DynSubscriber<'_, ReceivedPublish<BUFFER, TOPIC>>, MqttError> {
308 self.received_publishes.subscribe_publishes()
309 }
310
311}
312
313// #[cfg(all(test, feature = "std"))]
314// mod tests {
315// use core::time::Duration;
316// use std::time::Instant;
317
318// use embytes_buffer::{new_stack_buffer, Buffer, BufferReader, ReadWrite};
319// use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, channel::Channel};
320// use heapless::{String, Vec};
321// use mqttrs2::{decode_slice_with_len, Connack, ConnectReturnCode, LastWill, Packet, PacketType, QoS};
322
323// use crate::{state::{ConnectionState, State, KEEP_ALIVE}, time, ClientConfig, MqttError, MqttEvent};
324
325// use super::ping::PingState;
326
327// struct Test<'t> {
328// state: State<'t, CriticalSectionRawMutex>,
329// send_buffer: Buffer<[u8; 1024]>,
330// control_ch: Channel<CriticalSectionRawMutex, MqttEvent, 16>
331// }
332
333// impl <'t> Test<'t> {
334// fn new (config: ClientConfig) -> Self {
335// Self {
336// state: State::new(config, None),
337// send_buffer: new_stack_buffer(),
338// control_ch: Channel::new()
339// }
340// }
341
342// fn new_with_last_will(config: ClientConfig, last_will: LastWill<'t>) -> Self {
343// Self {
344// state: State::new(config, Some(last_will)),
345// send_buffer: new_stack_buffer(),
346// control_ch: Channel::new()
347// }
348// }
349
350// fn expect_no_packet(&mut self) {
351// let reader = self.send_buffer.create_reader();
352// let op = decode_slice_with_len(&reader).unwrap();
353// assert_eq!(op, None);
354// }
355
356// fn expect_packet<R, F: FnOnce(&Packet<'_>) -> R>(&mut self, operator: F) -> R {
357// let reader = self.send_buffer.create_reader();
358// let (n, packet) = decode_slice_with_len(&reader).unwrap().expect("there must be a packet");
359// reader.add_bytes_read(n);
360
361// operator(&packet)
362// }
363
364// async fn process_packet(&mut self, packet: &Packet<'_>) -> Result<Vec<MqttEvent, 16>, MqttError>{
365// self.state.process_packet(
366// packet,
367// &mut self.send_buffer.create_writer()
368// ).await
369// }
370// }
371
372// #[tokio::test]
373// async fn test_on_ping_required() {
374// time::test_time::set_static_now();
375
376// let mut config = ClientConfig{
377// client_id: String::new(),
378// credentials: None,
379// auto_subscribes: Vec::new()
380// };
381
382// config.client_id.push_str("1234567890").unwrap();
383
384// let mut test = Test::new(config);
385// test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
386// assert_eq!(test.state.get_connection_state(), ConnectionState::ConnectSent);
387
388// let ping_required = test.state.on_ping_required();
389// tokio::pin!(ping_required);
390
391// let wait = tokio::time::sleep(core::time::Duration::from_millis(50));
392// tokio::pin!(wait);
393
394// tokio::select! {
395// _ = &mut ping_required => {
396// panic!("ping is not required yet!");
397// },
398// _ = wait => {}
399// }
400
401// let wait = tokio::time::sleep(core::time::Duration::from_millis(50));
402// tokio::pin!(wait);
403
404// time::test_time::advance_time(Duration::from_secs(KEEP_ALIVE as u64) / 2 + Duration::from_secs(1));
405
406// tokio::select! {
407// _ = &mut ping_required => {},
408// _ = wait => {
409// panic!("ping must be now required")
410// }
411// }
412// }
413
414
415// #[tokio::test]
416// async fn test_connect_and_connack() {
417// time::test_time::set_default();
418
419// let mut config = ClientConfig{
420// client_id: String::new(),
421// credentials: None,
422// auto_subscribes: Vec::new()
423// };
424
425// config.client_id.push_str("1234567890").unwrap();
426
427// let mut test = Test::new(config);
428
429// assert_eq!(test.state.get_connection_state(), ConnectionState::InitialState);
430
431// test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
432
433// assert_eq!(test.state.get_connection_state(), ConnectionState::ConnectSent);
434
435// test.expect_packet(|p| {
436// if let Packet::Connect(c) = p {
437// assert_eq!(c.client_id, "1234567890");
438// assert_eq!(c.password, None);
439// assert_eq!(c.username, None);
440// } else {
441// panic!("expected connect packet");
442// }
443// });
444
445// assert_eq!(test.state.get_connection_state(), ConnectionState::ConnectSent);
446
447// let event = test.process_packet(&Packet::Connack(Connack{
448// session_present: false,
449// code: ConnectReturnCode::Accepted
450// })).await.unwrap().into_iter().next().expect("expected connected event");
451
452// assert_eq!(MqttEvent::Connected, event);
453
454// assert_eq!(test.state.get_connection_state(), ConnectionState::Connected);
455// }
456
457// #[tokio::test]
458// async fn test_connect_and_connack_with_last_will() {
459// time::test_time::set_default();
460
461// let mut config = ClientConfig{
462// client_id: String::new(),
463// credentials: None,
464// auto_subscribes: Vec::new()
465// };
466
467// config.client_id.push_str("1234567890").unwrap();
468
469// const LAST_WILL_TOPIC: &str = "some/topic";
470// const LAST_WILL_MESSAGE: &str = "i-am-dead";
471// let last_will = LastWill {
472// topic: LAST_WILL_TOPIC,
473// message: LAST_WILL_MESSAGE.as_bytes(),
474// qos: QoS::ExactlyOnce,
475// retain: true
476// };
477
478// let mut test = Test::new_with_last_will(config, last_will);
479
480// assert_eq!(test.state.get_connection_state(), ConnectionState::InitialState);
481
482// test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
483
484// assert_eq!(test.state.get_connection_state(), ConnectionState::ConnectSent);
485
486// test.expect_packet(|p| {
487// if let Packet::Connect(c) = p {
488// assert_eq!(c.client_id, "1234567890");
489// assert_eq!(c.password, None);
490// assert_eq!(c.username, None);
491
492// let received_last_will = c.last_will.as_ref().unwrap();
493// assert_eq!(received_last_will.message, LAST_WILL_MESSAGE.as_bytes());
494// assert_eq!(received_last_will.topic, LAST_WILL_TOPIC);
495// assert_eq!(received_last_will.qos, QoS::ExactlyOnce);
496// assert_eq!(received_last_will.retain, true);
497// } else {
498// panic!("expected connect packet");
499// }
500// });
501
502// assert_eq!(test.state.get_connection_state(), ConnectionState::ConnectSent);
503
504// let event = test.process_packet(&Packet::Connack(Connack{
505// session_present: false,
506// code: ConnectReturnCode::Accepted
507// })).await.unwrap().into_iter().next().expect("expected connected event");
508
509// assert_eq!(MqttEvent::Connected, event);
510
511// assert_eq!(test.state.get_connection_state(), ConnectionState::Connected);
512// }
513
514// #[tokio::test]
515// async fn test_ping() {
516// let start_time = Instant::now();
517// time::test_time::set_time(start_time);
518
519// let config = ClientConfig{
520// client_id: String::new(),
521// credentials: None,
522// auto_subscribes: Vec::new()
523// };
524
525// let mut test = Test::new(config);
526// test.state.set_connection_state(ConnectionState::Connected);
527
528// test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
529// test.state.send_ping(&mut test.send_buffer.create_writer()).unwrap();
530// test.expect_no_packet();
531
532// time::test_time::advance_time(Duration::from_secs(40));
533
534// test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
535// test.state.send_ping(&mut test.send_buffer.create_writer()).unwrap();
536// test.expect_packet(|p| {
537// if Packet::Pingreq != *p {
538// panic!("expected Packet::Pingreq");
539// }
540// });
541
542// test.state.ping.lock(|inner|{
543// let inner = inner.borrow();
544
545// if let PingState::AwaitingResponse { last_success, ping_request_sent } = *inner {
546// assert_eq!(last_success, start_time);
547// assert_eq!(ping_request_sent, start_time + Duration::from_secs(40));
548// } else {
549// panic!("expected PingState::AwaitingResponse");
550// }
551// });
552// time::test_time::advance_time(Duration::from_secs(2));
553
554// test.process_packet(&Packet::Pingresp).await.unwrap();
555
556// test.state.ping.lock(|inner|{
557// let inner = inner.borrow();
558
559// if let PingState::PingSuccess(last_ping) = *inner {
560// assert_eq!(last_ping, start_time + Duration::from_secs(42));
561// } else {
562// panic!("expected PingState::PingSuccess");
563// }
564// });
565
566// }
567
568// #[tokio::test]
569// async fn test_auto_subscribe() {
570
571// let config: ClientConfig = ClientConfig::new_with_auto_subscribes(
572// "asghfdasdhasdh",
573// None,
574// [ "test1", "test2" ].into_iter(),
575// QoS::AtLeastOnce
576// );
577
578// let mut test = Test::new(config);
579
580// test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
581// test.expect_packet(|p|{
582// assert_eq!(p.get_type(), PacketType::Connect, "expected connect packet");
583// });
584
585// test.process_packet(&Packet::Connack(Connack {
586// session_present: false,
587// code: ConnectReturnCode::Accepted
588// })).await.unwrap();
589
590// test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
591// test.expect_packet(|p|{
592// if let Packet::Subscribe(s) = p {
593// assert_eq!(1, s.topics.len());
594// let topic = s.topics.first().unwrap();
595// assert_eq!(&topic.topic_path, "test1");
596// assert_eq!(topic.qos, QoS::AtLeastOnce);
597// } else {
598// panic!("expected subscribe packet but got {:?}", p.get_type());
599// }
600// });
601
602// test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
603// test.expect_packet(|p|{
604// if let Packet::Subscribe(s) = p {
605// assert_eq!(1, s.topics.len());
606// let topic = s.topics.first().unwrap();
607// assert_eq!(&topic.topic_path, "test2");
608// assert_eq!(topic.qos, QoS::AtLeastOnce);
609// } else {
610// panic!("expected subscribe packet but got {:?}", p.get_type());
611// }
612// });
613
614// test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
615// test.expect_no_packet();
616// }
617
618// }