embedded_mqttc/state/
mod.rs

1use core::future::Future;
2use core::ops::Add;
3
4use embassy_futures::select::select3;
5use embassy_sync::pubsub::{DynSubscriber, PubSubChannel};
6use embedded_nal_async::{Dns, TcpConnect};
7
8use crate::client::MqttClient;
9use crate::state::connection::{ConnectionState, TcpConnectionState};
10use crate::state::pid::next_pid;
11use crate::state::publish2::Publishes;
12use crate::state::receives2::{ReceivedPublish, Receives};
13use crate::state::request::{RequestNotification, RequestState};
14use crate::state::sub2::Subs;
15use crate::time::Duration;
16
17use embassy_sync::blocking_mutex::raw::RawMutex;
18use mqttrs2::{LastWill, Packet, Publish, QoS, QosPid};
19use ping::PingState;
20
21use crate::{ClientConfig, MqttError, MqttEvent, UniqueID, time};
22
23pub(crate) const KEEP_ALIVE: usize = 60;
24
25pub(crate) mod ping;
26
27pub(crate) mod receives2;
28
29pub mod connection;
30
31/// outgoing publishes
32pub(crate) mod publish2;
33
34pub(crate) mod sub2;
35pub(crate) mod pid;
36
37pub(crate) mod request;
38
39const RECONNECT_DURATION: Duration = Duration::from_secs(5);
40
/// Outcome of a method that writes packets to the network.
///
/// Two outcomes combine with `+`: the sum is `SentAll` only when both operands are `SentAll`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SendResult {
    /// Some pending packets could not be written yet and are still waiting to be sent.
    PartiallySent,

    /// Sent all pending packets to the network, or there was nothing to send.
    SentAll,
}
49
50impl SendResult {
51
52    pub async fn next<F, Fut>(self, f: F) -> Result<Self, MqttError> 
53    where F: FnOnce() -> Fut, Fut: Future<Output = Result<Self, MqttError>> {
54        if self != Self::PartiallySent {
55            let next = self + f().await?;
56            Ok(next)
57        } else {
58            Ok(self)
59        }
60    }
61
62    pub fn next_sync<F>(self, f: F) -> Result<Self, MqttError> 
63    where F: FnOnce() -> Result<Self, MqttError> {
64        if self != Self::PartiallySent {
65            let next = self + f()?;
66            Ok(next)
67        } else {
68            Ok(self)
69        }
70    }
71
72}
73
74impl Add<Self> for SendResult {
75    type Output = Self;
76
77    fn add(self, rhs: Self) -> Self::Output {
78        match (self, rhs) {
79            (Self::SentAll, Self::SentAll) => Self::SentAll,
80            _ => Self::PartiallySent
81        }
82    }
83}
84
85pub struct State<'n, 'l, M: RawMutex, NET, DNS, const BUFFER: usize, const TOPIC: usize, const QUEUE: usize> 
86where NET: TcpConnect, DNS: Dns {
87
88    pub(crate) connection_state: TcpConnectionState<'n, 'l, M, NET, DNS, BUFFER>, 
89
90    ping: PingState<M>,
91
92    publishes: Publishes<M, QUEUE, BUFFER, TOPIC, 4>,
93    received_publishes: Receives<M, BUFFER, TOPIC, QUEUE>,
94    subscribes: Subs<M, TOPIC, QUEUE>,
95
96    // Signal is sent, when a request is added
97    on_requst_added: RequestState<M>,
98
99    events: PubSubChannel<M, MqttEvent, 8, 16, 2>
100
101}
102
103impl <'n, 'l, M: RawMutex, NET, DNS, const BUFFER: usize, const TOPIC: usize, const QUEUE: usize> State<'n, 'l, M, NET, DNS, BUFFER, TOPIC, QUEUE> 
104where NET: TcpConnect, DNS: Dns{
105
106    pub fn new(config: ClientConfig<'l>, last_will: Option<LastWill<'l>>, network: &'n NET, dns: DNS) -> Self {
107        Self {
108            connection_state: TcpConnectionState::new(network, dns, last_will, config),
109
110            ping: PingState::new(),
111
112            publishes: Publishes::new(),
113            received_publishes: Receives::new(),
114            subscribes: Subs::new(),
115
116            on_requst_added: RequestState::new(),
117            events: PubSubChannel::new(),
118        }
119    }
120
121    pub fn new_client(&self) -> MqttClient<'_, 'n, 'l, M, NET, DNS, BUFFER, TOPIC, QUEUE> {
122        MqttClient::new(self)
123    }
124
125    pub(crate) async fn publish(&self, topic: &str, payload: &[u8], qos: QoS, retain: bool, unique_id: UniqueID) -> Result<(), MqttError> {
126
127        let qospid = match qos {
128            QoS::AtMostOnce => QosPid::AtMostOnce,
129            QoS::AtLeastOnce => QosPid::AtLeastOnce(next_pid()),
130            QoS::ExactlyOnce => QosPid::ExactlyOnce(next_pid()),
131        };
132
133        debug!("state: adding publish request with qos {}", &qospid);
134
135        let publish = Publish {
136            topic_name: topic,
137            dup: false,
138            qospid,
139            retain,
140            payload
141        };
142
143        self.publishes.publish(publish, unique_id).await?;
144        self.on_requst_added.notify_new_request();
145
146        Ok(())
147    }
148
149    pub(crate) async fn subscribe(&self, topics: &[&str], qos: QoS, unique_id: UniqueID) {
150        self.subscribes.add_subscribe_request(topics, qos, unique_id).await;
151        self.on_requst_added.notify_new_request();
152    }
153
154    pub(crate) async fn unsubscribe(&self, topics: &[&str], unique_id: UniqueID) {
155        self.subscribes.add_unsubscribe_request(topics, unique_id).await;
156        self.on_requst_added.notify_new_request();
157    }
158
159    pub(crate) fn disconnect(&self) {
160        self.on_requst_added.notify_disconnect();
161    }
162
163    /// Returns true until the client disconnects
164    async fn run_once(&self) -> Result<bool, MqttError> {
165        if self.connection_state.get_state() != Some(connection::ConnectionStateValue::Connected) {
166            info!("not connected yed: starting connection");
167            self.connection_state.connect().await?;
168        }
169
170        let publisher = self.events.dyn_publisher().unwrap();
171
172        debug!("event loop: start sending packets");
173
174        let send_packet_result = self.publishes.send_packets(&self.connection_state, publisher).await?
175            .next(|| self.received_publishes.send_packets(&self.connection_state)).await?
176            .next(|| self.subscribes.send(&self.connection_state)).await?
177            .next_sync(|| self.ping.send(&self.connection_state))?;
178
179        if send_packet_result == SendResult::PartiallySent {
180            debug!("partially sent packets, run io nonblocking");
181            if let Some(packet) = self.connection_state.run_io_nonblocking().await? {
182                self.process_packet(&packet).await?;
183            }
184            // Return early to rerun the loop faster
185            return Ok(true);
186        }
187
188        debug!("sent all packets, run io blocking");
189
190        // TODO make ping and resend future
191        let ping_future = self.ping.ping_pause();
192        let io_future = self.connection_state.run_io();
193        let request_added_future = self.on_requst_added.next_notification();
194
195        match select3(request_added_future, ping_future, io_future).await {
196            embassy_futures::select::Either3::First(request) if request == RequestNotification::Disconnect => {
197                debug!("run_once: disconnect request received");
198                Ok(false)
199            },
200            embassy_futures::select::Either3::Third(packet) => {
201                debug!("run_once: received packet");
202                let packet = packet?;
203                self.process_packet(&packet).await?;
204                Ok(true)
205            },
206            _ => {
207                debug!("run_once: stop io, new event arrived");
208                Ok(true)
209            },
210        }
211    }
212
213    pub async fn run(&self) -> Result<(), MqttError> {
214        loop {
215            match self.run_once().await {
216                Ok(true) => {},
217                Ok(false) => {
218                    // Disconnect
219                    self.connection_state.disconnect().await?;
220                    info!("disconnect: exit run loop");
221                    return Ok(())
222                },
223                Err(err) => {
224                    error!("connection error: {}", &err);
225                    match err {
226                        MqttError::ConnectionFailed2(_) |
227                        MqttError::ConnackError |
228                        MqttError::CodecError(_) |
229                        MqttError::ReceivedMessageTooLong |
230                        MqttError::QueueFull(_) |
231                        MqttError::UnexpectedAck(_)  => {
232                            self.connection_state.set_error();
233                            time::sleep(RECONNECT_DURATION).await;
234                        },
235
236                        err => {
237                            error!("not recoverable error: stop loop");
238                            return Err(err)
239                        },
240                    }
241                },
242            }
243        }
244    }
245
246    /// Processes incoming packets
247    async fn process_packet(&self, p: &Packet<'_>) -> Result<(), MqttError> {
248
249        let publisher = self.events.dyn_publisher().unwrap();
250
251        match p {
252            
253            Packet::Connack(_connack) => {
254                panic!("received connack: this must be handled by the connection module");
255            },
256            
257            Packet::Publish(publish) => {
258                self.received_publishes.on_publish(publish).await?;
259                Ok(())
260            },
261
262            Packet::Puback(_) | Packet::Pubrec(_) | Packet::Pubcomp(_) => {
263                self.publishes.process_incoming_packet(p, publisher).await?;
264                Ok(())
265            },
266
267            Packet::Pubrel(pid) => {
268                self.received_publishes.on_pubrel(*pid).await?;
269                Ok(())
270            },
271
272            Packet::Suback(suback) => {
273                self.subscribes.on_suback(suback, publisher).await?;
274                Ok(())
275            },
276            
277            Packet::Unsuback(pid) => {
278                self.subscribes.on_unsuback(*pid, publisher).await;
279                Ok(())
280            },
281            
282            Packet::Pingresp => {
283                self.ping.on_ping_response();
284                Ok(())
285            },
286
287            // # These Packages cannot be send Server -> Client
288            // # And are treated as unexpected
289            // Packet::Connect(connect) => todo!(),
290            // Packet::Disconnect => todo!(),
291            // Packet::Pingreq => todo!(),
292            // Packet::Unsubscribe(unsubscribe) => todo!(),
293            // Packet::Subscribe(subscribe) => todo!(),
294
295            unexpected => {
296                error!("unexpected packet {} received from broker", unexpected.get_type());
297                Ok(())
298            }
299        }
300    }
301
302    pub(crate) fn subscribe_events(&self) -> Result<DynSubscriber<'_, MqttEvent>, MqttError> {
303        self.events.dyn_subscriber().map_err(|e| e.into())
304    }
305
306    /// Subscribe to received publishes
307    pub fn subscribe_received_publishes(&self) -> Result<DynSubscriber<'_, ReceivedPublish<BUFFER, TOPIC>>, MqttError> {
308        self.received_publishes.subscribe_publishes()
309    }
310
311}
312
313// #[cfg(all(test, feature = "std"))]
314// mod tests {
315//     use core::time::Duration;
316//     use std::time::Instant;
317
318//     use embytes_buffer::{new_stack_buffer, Buffer, BufferReader, ReadWrite};
319//     use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, channel::Channel};
320//     use heapless::{String, Vec};
321//     use mqttrs2::{decode_slice_with_len, Connack, ConnectReturnCode, LastWill, Packet, PacketType, QoS};
322
323//     use crate::{state::{ConnectionState, State, KEEP_ALIVE}, time, ClientConfig, MqttError, MqttEvent};
324
325//     use super::ping::PingState;
326
327//     struct Test<'t> {
328//         state: State<'t, CriticalSectionRawMutex>,
329//         send_buffer: Buffer<[u8; 1024]>,
330//         control_ch: Channel<CriticalSectionRawMutex, MqttEvent, 16>
331//     }
332
333//     impl <'t> Test<'t> {
334//         fn new (config: ClientConfig) -> Self {
335//             Self {
336//                 state: State::new(config, None),
337//                 send_buffer: new_stack_buffer(),
338//                 control_ch: Channel::new()
339//             }
340//         }
341
342//         fn new_with_last_will(config: ClientConfig, last_will: LastWill<'t>) -> Self {
343//             Self {
344//                 state: State::new(config, Some(last_will)),
345//                 send_buffer: new_stack_buffer(),
346//                 control_ch: Channel::new()
347//             }
348//         }
349
350//         fn expect_no_packet(&mut self) {
351//             let reader = self.send_buffer.create_reader();
352//             let op = decode_slice_with_len(&reader).unwrap();
353//             assert_eq!(op, None);
354//         }
355
356//         fn expect_packet<R, F: FnOnce(&Packet<'_>) -> R>(&mut self, operator: F) -> R {
357//             let reader = self.send_buffer.create_reader();
358//             let (n, packet) = decode_slice_with_len(&reader).unwrap().expect("there must be a packet");
359//             reader.add_bytes_read(n);
360
361//             operator(&packet)
362//         }
363
364//         async fn process_packet(&mut self, packet: &Packet<'_>) -> Result<Vec<MqttEvent, 16>, MqttError>{
365//             self.state.process_packet(
366//                 packet, 
367//                 &mut self.send_buffer.create_writer()
368//             ).await
369//         }
370//     }
371
372//     #[tokio::test]
373//     async fn test_on_ping_required() {
374//         time::test_time::set_static_now();
375
376//         let mut config = ClientConfig{
377//             client_id: String::new(),
378//             credentials: None,
379//             auto_subscribes: Vec::new()
380//         };
381
382//         config.client_id.push_str("1234567890").unwrap();
383
384//         let mut test = Test::new(config);
385//         test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
386//         assert_eq!(test.state.get_connection_state(), ConnectionState::ConnectSent);
387
388//         let ping_required = test.state.on_ping_required();
389//         tokio::pin!(ping_required);
390
391//         let wait = tokio::time::sleep(core::time::Duration::from_millis(50));
392//         tokio::pin!(wait);
393
394//         tokio::select! {
395//             _ = &mut ping_required => {
396//                 panic!("ping is not required yet!");
397//             },
398//             _ = wait => {}
399//         }
400
401//         let wait = tokio::time::sleep(core::time::Duration::from_millis(50));
402//         tokio::pin!(wait);
403
404//         time::test_time::advance_time(Duration::from_secs(KEEP_ALIVE as u64) / 2 + Duration::from_secs(1));
405
406//         tokio::select! {
407//             _ = &mut ping_required => {},
408//             _ = wait => {
409//                 panic!("ping must be now required")
410//             }
411//         }
412//     }
413
414
415//     #[tokio::test]
416//     async fn test_connect_and_connack() {
417//         time::test_time::set_default();
418
419//         let mut config = ClientConfig{
420//             client_id: String::new(),
421//             credentials: None,
422//             auto_subscribes: Vec::new()
423//         };
424
425//         config.client_id.push_str("1234567890").unwrap();
426
427//         let mut test = Test::new(config);
428
429//         assert_eq!(test.state.get_connection_state(), ConnectionState::InitialState);
430
431//         test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
432
433//         assert_eq!(test.state.get_connection_state(), ConnectionState::ConnectSent);
434
435//         test.expect_packet(|p| {
436//             if let Packet::Connect(c) = p {
437//                 assert_eq!(c.client_id, "1234567890");
438//                 assert_eq!(c.password, None);
439//                 assert_eq!(c.username, None);
440//             } else {
441//                 panic!("expected connect packet");
442//             }
443//         });
444
445//         assert_eq!(test.state.get_connection_state(), ConnectionState::ConnectSent);
446
447//         let event = test.process_packet(&Packet::Connack(Connack{
448//             session_present: false,
449//             code: ConnectReturnCode::Accepted
450//         })).await.unwrap().into_iter().next().expect("expected connected event");
451
452//         assert_eq!(MqttEvent::Connected, event);
453
454//         assert_eq!(test.state.get_connection_state(), ConnectionState::Connected);
455//     }
456
457//     #[tokio::test]
458//     async fn test_connect_and_connack_with_last_will() {
459//         time::test_time::set_default();
460
461//         let mut config = ClientConfig{
462//             client_id: String::new(),
463//             credentials: None,
464//             auto_subscribes: Vec::new()
465//         };
466
467//         config.client_id.push_str("1234567890").unwrap();
468
469//         const LAST_WILL_TOPIC: &str = "some/topic";
470//         const LAST_WILL_MESSAGE: &str = "i-am-dead";
471//         let last_will = LastWill {
472//             topic: LAST_WILL_TOPIC,
473//             message: LAST_WILL_MESSAGE.as_bytes(),
474//             qos: QoS::ExactlyOnce,
475//             retain: true
476//         };
477
478//         let mut test = Test::new_with_last_will(config, last_will);
479
480//         assert_eq!(test.state.get_connection_state(), ConnectionState::InitialState);
481
482//         test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
483
484//         assert_eq!(test.state.get_connection_state(), ConnectionState::ConnectSent);
485
486//         test.expect_packet(|p| {
487//             if let Packet::Connect(c) = p {
488//                 assert_eq!(c.client_id, "1234567890");
489//                 assert_eq!(c.password, None);
490//                 assert_eq!(c.username, None);
491                
492//                 let received_last_will = c.last_will.as_ref().unwrap();
493//                 assert_eq!(received_last_will.message, LAST_WILL_MESSAGE.as_bytes());
494//                 assert_eq!(received_last_will.topic, LAST_WILL_TOPIC);
495//                 assert_eq!(received_last_will.qos, QoS::ExactlyOnce);
496//                 assert_eq!(received_last_will.retain, true);
497//             } else {
498//                 panic!("expected connect packet");
499//             }
500//         });
501
502//         assert_eq!(test.state.get_connection_state(), ConnectionState::ConnectSent);
503
504//         let event = test.process_packet(&Packet::Connack(Connack{
505//             session_present: false,
506//             code: ConnectReturnCode::Accepted
507//         })).await.unwrap().into_iter().next().expect("expected connected event");
508
509//         assert_eq!(MqttEvent::Connected, event);
510
511//         assert_eq!(test.state.get_connection_state(), ConnectionState::Connected);
512//     }
513
514//     #[tokio::test]
515//     async fn test_ping() {
516//         let start_time = Instant::now();
517//         time::test_time::set_time(start_time);
518
519//         let config = ClientConfig{
520//             client_id: String::new(),
521//             credentials: None,
522//             auto_subscribes: Vec::new()
523//         };
524
525//         let mut test = Test::new(config);
526//         test.state.set_connection_state(ConnectionState::Connected);
527
528//         test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
529//         test.state.send_ping(&mut test.send_buffer.create_writer()).unwrap();
530//         test.expect_no_packet();
531
532//         time::test_time::advance_time(Duration::from_secs(40));
533
534//         test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
535//         test.state.send_ping(&mut test.send_buffer.create_writer()).unwrap();
536//         test.expect_packet(|p| {
537//             if Packet::Pingreq != *p {
538//                 panic!("expected Packet::Pingreq");
539//             }
540//         });
541
542//         test.state.ping.lock(|inner|{
543//             let inner = inner.borrow();
544
545//             if let PingState::AwaitingResponse { last_success, ping_request_sent } = *inner {
546//                 assert_eq!(last_success, start_time);
547//                 assert_eq!(ping_request_sent, start_time + Duration::from_secs(40));
548//             } else {
549//                 panic!("expected PingState::AwaitingResponse");
550//             }
551//         });
552//         time::test_time::advance_time(Duration::from_secs(2));
553
554//         test.process_packet(&Packet::Pingresp).await.unwrap();
555
556//         test.state.ping.lock(|inner|{
557//             let inner = inner.borrow();
558
559//             if let PingState::PingSuccess(last_ping) = *inner {
560//                 assert_eq!(last_ping, start_time + Duration::from_secs(42));
561//             } else {
562//                 panic!("expected PingState::PingSuccess");
563//             }
564//         });
565
566//     }
567
568//     #[tokio::test]
569//     async fn test_auto_subscribe() {
570
571//         let config: ClientConfig = ClientConfig::new_with_auto_subscribes(
572//             "asghfdasdhasdh", 
573//             None, 
574//             [ "test1", "test2" ].into_iter(), 
575//             QoS::AtLeastOnce
576//         );
577
578//         let mut test = Test::new(config);
579
580//         test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
581//         test.expect_packet(|p|{
582//             assert_eq!(p.get_type(), PacketType::Connect, "expected connect packet");
583//         });
584
585//         test.process_packet(&Packet::Connack(Connack { 
586//             session_present: false, 
587//             code: ConnectReturnCode::Accepted 
588//         })).await.unwrap();
589
590//         test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
591//         test.expect_packet(|p|{
592//             if let Packet::Subscribe(s) = p {
593//                 assert_eq!(1, s.topics.len());
594//                 let topic = s.topics.first().unwrap();
595//                 assert_eq!(&topic.topic_path, "test1");
596//                 assert_eq!(topic.qos, QoS::AtLeastOnce);
597//             } else {
598//                 panic!("expected subscribe packet but got {:?}", p.get_type());
599//             }
600//         });
601
602//         test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
603//         test.expect_packet(|p|{
604//             if let Packet::Subscribe(s) = p {
605//                 assert_eq!(1, s.topics.len());
606//                 let topic = s.topics.first().unwrap();
607//                 assert_eq!(&topic.topic_path, "test2");
608//                 assert_eq!(topic.qos, QoS::AtLeastOnce);
609//             } else {
610//                 panic!("expected subscribe packet but got {:?}", p.get_type());
611//             }
612//         });
613
614//         test.state.send_packets(&mut test.send_buffer.create_writer(), &test.control_ch).unwrap();
615//         test.expect_no_packet();
616//     }
617
618// }