embedded_mqttc/
io.rs

1use core::convert::Infallible;
2use core::{cell::RefCell, future::Future, pin::Pin};
3
4use embytes_buffer::{new_stack_buffer, Buffer, BufferReader, BufferWriter, ReadWrite};
5use embassy_futures::select::{select, select3, Either3};
6use embassy_sync::{blocking_mutex::raw::RawMutex, channel::Channel, pubsub::PubSubChannel};
7use mqttrs2::{decode_slice_with_len, LastWill, Packet, QoS};
8use crate::network::mqtt::MqttPacketError;
9use crate::network::NetworkError;
10use crate::network::{ mqtt::WriteMqttPacketMut, NetwordSendReceive, NetworkConnection };
11use crate::{client::MqttClient, state::State, time, ClientConfig, MqttError, MqttEvent, MqttPublish, MqttRequest};
12
13use crate::time::Duration;
14
/// A sink that accepts items of type `T`, either by awaiting or by a
/// non-awaiting attempt.
pub trait AsyncSender<T> {
    /// Returns a future that completes once `item` has been handed over.
    fn send(&self, item: T) -> impl Future<Output = ()>;
    /// Tries to hand over `item` without awaiting.
    ///
    /// On failure the item is given back to the caller in `Err`.
    fn try_send(&self, item: T) -> Result<(), T>;
}
19
20impl <M: RawMutex, T, const N: usize> AsyncSender<T> for Channel<M, T, N> {
21    fn send(&self, item: T) -> impl Future<Output = ()> {
22        self.send(item)
23    }
24
25    fn try_send(&self, item: T) -> Result<(), T> {
26        self.try_send(item)
27            .map_err(|err|{
28                match err {
29                    embassy_sync::channel::TrySendError::Full(item) => item,
30                }
31            })
32    }
33}
34
35impl <M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> AsyncSender<T> for PubSubChannel<M, T, CAP, SUBS, PUBS> {
36    async fn send(&self, item: T) {
37        self.publisher().unwrap().publish(item).await;
38    }
39
40    fn try_send(&self, item: T) -> Result<(), T> {
41        self.publisher().unwrap().try_publish(item)
42    }
43}
44
/// A source from which items of type `T` can be awaited.
pub trait AsyncReceiver<T> {
    /// Returns a future that resolves to the next item.
    fn receive(&self) -> impl Future<Output = T>;
}
48
49impl <M: RawMutex, T, const N: usize> AsyncReceiver<T> for Channel<M, T, N> {
50    fn receive(&self) -> impl Future<Output = T> {
51        self.receive()
52    }
53}
54
/// The main eventloop of the MQTT client
///
/// This struct holds all state and buffers for the mqtt client
pub struct MqttEventLoop<'l, M: RawMutex, const B: usize> {
    // Bytes read from the network that still have to be decoded into packets.
    recv_buffer: RefCell<Buffer<[u8; B]>>,
    // Encoded packets waiting to be written to the network.
    send_buffer: RefCell<Buffer<[u8; B]>>,

    // Protocol state, built from the `ClientConfig` and the optional last will.
    state: State<'l, M>,

    // Publishes `MqttEvent`s; clients receive a reference to it via `client()`.
    control_sender: PubSubChannel<M, MqttEvent, 4, 16, 8>,
    // Requests sent by clients, consumed by the event loop.
    request_receiver: Channel<M, MqttRequest, 4>,
    // Handed to the state while processing packets; clients hold its receiver.
    received_publishes: Channel<M, MqttPublish, 4>
}
68
69impl <M: RawMutex, const B: usize> MqttEventLoop<'static, M, B> {
70
71    /// Create a new event loop
72    pub fn new(config: ClientConfig) -> Self {
73        
74        Self {
75            recv_buffer: RefCell::new(new_stack_buffer::<B>()),
76            send_buffer: RefCell::new(new_stack_buffer::<B>()),
77
78            state: State::new(config, None),
79
80            control_sender: PubSubChannel::new(),
81            request_receiver: Channel::new(),
82            received_publishes: Channel::new()
83        }
84    }
85}
86
87impl <'l, M: RawMutex, const B: usize> MqttEventLoop<'l, M, B> {
88
89    pub fn new_with_last_will(config: ClientConfig, last_will: LastWill<'l>) -> Self {
90        
91        Self {
92            recv_buffer: RefCell::new(new_stack_buffer::<B>()),
93            send_buffer: RefCell::new(new_stack_buffer::<B>()),
94
95            state: State::new(config, Some(last_will)),
96
97            control_sender: PubSubChannel::new(),
98            request_receiver: Channel::new(),
99            received_publishes: Channel::new()
100        }
101    }
102
103    /// Create a client for the event loop. The Client can be used to publish and receive messages, ...
104    /// 
105    /// There can be multiple [`MqttClient`] for one [`MqttEventLoop`].
106    /// But concurrent receives from multiple clients result in not all clients to receive all publishes
107    pub fn client<'a>(&'a self) -> MqttClient<'a, M> {
108        MqttClient{
109            control_reveiver: &self.control_sender,
110            request_sender: self.request_receiver.sender(),
111            received_publishes: self.received_publishes.receiver()
112        }
113    }
114
115    /// Receive / sent bytes from / to the network
116    /// Sends blocking if there are bytes to send
117    /// Receives blocking if there are no more bytes to send.
118    async fn network_send_receive<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
119        let mut send_buffer = self.send_buffer.borrow_mut();
120        // let mut connection = self.connection.borrow_mut();
121        
122        if send_buffer.has_remaining_len() {
123            let n = connection.send(&mut send_buffer).await
124                .map_err(|e| MqttError::ConnectionFailed(e))?;
125
126            trace!("sent {} bytes to network; send_buffer remaining: {}", n, send_buffer.remaining_len());
127        } else {
128            trace!("send buffer is empty, skipping send");
129        }
130
131        let mut recv_buffer = self.recv_buffer.borrow_mut();
132        // Do not block for receiving if there is still something to send
133        if send_buffer.has_remaining_len() {
134            let n = connection.try_receive(&mut recv_buffer).await
135                .map_err(|e| MqttError::ConnectionFailed(e))?;
136            trace!("try_receive() {} bytes from network", n);
137        } else {
138            let n = connection.receive(&mut recv_buffer).await
139                .map_err(|e| MqttError::ConnectionFailed(e))?;
140            trace!("receive() {} bytes from network", n);
141        }
142
143        Ok(())
144    }
145
146    /// Try to read a packet from recv buffer. 
147    async fn try_package_receive(&self, send_buffer: &mut impl BufferWriter, recv_buffer: impl BufferReader) -> Result<(), MqttError> {
148        if recv_buffer.is_empty() {
149            trace!("try_package_receive(): recv_buffer is empty, cannot read packet");
150            return Ok(())
151        }
152        
153        let packet_op = decode_slice_with_len(&recv_buffer[..])
154            .map_err(|e| {
155                error!("try_package_receive(): error decoding package: {}", e);
156                MqttError::CodecError
157            })?;
158        
159        if let Some((len, packet)) = packet_op {
160            debug!("try_package_receive(): decoded packet from recv_buffer: len = {}, kind = {}", len, packet.get_type());
161            recv_buffer.add_bytes_read(len);
162            let events = 
163                self.state.process_packet(&packet, send_buffer, &self.received_publishes).await?;
164            
165            if ! events.is_empty() {
166                for event in events {
167                    debug!("try_package_receive(): processing packet -> MqttEvent: {}", &event);
168                    self.control_sender.publisher().unwrap().publish(event).await;
169                }
170            } else {
171                trace!("try_package_receive(): packet processed, no MqttEvent");
172            }
173
174        } else {
175            trace!("try_package_receive(): no complete packet in recv_buffer");
176        }
177
178        Ok(())
179    }
180
181    /// Makes the receive / send of the network
182    /// First tries to write outgoing traffic to buffer
183    /// Then tries to read / write to / from the connection
184    /// Then read data from receive buffer
185    async fn work_network<N: NetworkConnection>(&self, connection: &mut N) -> Result<Infallible, MqttError> {
186        loop {
187            // Try to send packets first before blocking for network traffic
188            // Send packets (Ping, Connect, Publish)
189            {   
190                let mut send_buffer = self.send_buffer.borrow_mut();
191                let mut send_buffer_writer = send_buffer.create_writer();
192                self.state.send_packets(&mut send_buffer_writer, &self.control_sender)?;
193                drop(send_buffer_writer);
194                trace!("after network send: send_buffer {} / {}", send_buffer.remaining_len(), send_buffer.remaining_capacity());
195
196            }
197            
198            // Send / Receive Network traffic
199            // Interript this when ...
200            // - a new Request (e. g. Publish) is added to process it
201            // - sending a ping message is required
202            let network_future = self.network_send_receive(connection);
203            let on_request_signal_future = self.state.on_requst_added.wait();
204            let next_ping_future = self.state.on_ping_required();
205            match select3(network_future, on_request_signal_future, next_ping_future).await {
206                Either3::First(res) => {
207                    trace!("stopping pause: got something from network");
208                    res
209                },
210                Either3::Second(_) => {
211                    trace!("stopping pause: new request");
212                    Ok(())
213                },
214                Either3::Third(()) => {
215                    trace!("stopping pause: ping required");
216                    let mut send_buffer = self.send_buffer.borrow_mut();
217                    let mut send_buffer_writer = send_buffer.create_writer();
218                    self.state.send_ping(&mut send_buffer_writer)
219                },
220            }?;
221
222            let mut send_buffer = self.send_buffer.borrow_mut();
223            let mut recv_buffer = self.recv_buffer.borrow_mut();
224            let recv_reader = recv_buffer.create_reader();
225            let mut send_buffer_writer = send_buffer.create_writer();
226
227            // Try to read a package from the receive buffer and write answers (e. g. acknoledgements) 
228            // to the send buffer
229            self.try_package_receive(&mut send_buffer_writer, recv_reader).await?; 
230
231            trace!("after try packege_receive: recv_buffer: {} / {}", recv_buffer.remaining_len(), recv_buffer.capacity());
232        }
233    }
234
235
236    /// Receive requests from client.
237    /// Returns if the client sends a disconnect message
238    async fn work_request_receive(&self) -> Result<(), MqttError> {
239        loop {
240            let req = self.request_receiver.receive().await;
241            let pid = self.state.pid_source.next_pid();
242
243            match req {
244                MqttRequest::Publish(mqtt_publish, id) => {
245                                self.state.publishes.push_publish(mqtt_publish, id, pid).await;
246                                debug!("new publish request added to queue");
247                            },
248                MqttRequest::Subscribe(topic, unique_id) => {
249                                const SUBSCRIBE_QOS: QoS = QoS::AtMostOnce;
250                                self.state.subscribes.push_subscribe(topic, pid, unique_id, SUBSCRIBE_QOS).await;
251                                debug!("new subscribe request added to queue");
252                            },
253                MqttRequest::Unsubscribe(topic, unique_id) => {
254                                self.state.subscribes.push_unsubscribe(topic, pid, unique_id).await;
255                                debug!("new unsubscribe request added to queue");
256                            },
257                MqttRequest::Disconnect => {
258                    self.state.on_requst_added.signal(0);
259                    return Ok(());
260                },
261            }
262
263            // Signal that a new request is added
264            self.state.on_requst_added.signal(0);
265        }
266    }
267
268    async fn connect<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
269        self.send_buffer.borrow_mut().reset();
270        self.recv_buffer.borrow_mut().reset();
271        
272        let mut tries: usize = 0;
273        loop {
274
275            let result = connection.connect().await;
276
277            match result {
278                Ok(()) => {
279                    info!("connect to broker success");
280                    return Ok(())
281                },
282                Err(e) => {
283                    tries += 1;
284                    if tries < 5 {
285                        warn!("{}. try to connecto to host failed", tries);
286                        time::sleep(Duration::from_secs(3)).await;
287                    } else {
288                        error!("{} tries to connect failed", tries);
289                        return Err(MqttError::ConnectionFailed(e));
290                    }
291                },
292            }
293        }
294    }
295
296    async fn work<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
297        // Reset state on new connection
298        self.state.reset();
299            
300        // Poll both futures
301        // Select should never befinished because both jobs are infinite
302        let network_future = self.work_network(connection);
303        let request_future = self.work_request_receive();
304
305        match select(network_future, request_future).await {
306            embassy_futures::select::Either::First(net_result) => {
307                let err = net_result.unwrap_err();
308                error!("network infinite job finished: {}", &err);
309                Err(err)
310            },
311            embassy_futures::select::Either::Second(req_result) => {
312                if let Err(err) = req_result {
313                    error!("infinite request receive job finished: {}", &err);
314                    Err(err)
315                } else {
316                    info!("disconnect request received: stopping jobs");
317                    Ok(())
318                }
319            },
320        }
321    }
322
323    async fn disconnect<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
324        
325        let mut send_buffer = self.send_buffer.borrow_mut();
326        let mut send_buffer_writer = send_buffer.create_writer();
327
328        send_buffer_writer.write_mqtt_packet_sync(&Packet::Disconnect).map_err(|err| {
329            match err {
330                MqttPacketError::NotEnaughBufferSpace => {
331                    warn!("could not write Disconnect to send buffer: full");
332                    MqttError::BufferFull
333                },
334                MqttPacketError::CodecError => MqttError::CodecError,
335                MqttPacketError::IoError(_) => MqttError::ConnectionFailed(NetworkError::ConnectionFailed),
336                MqttPacketError::NetworkError(network_error) => MqttError::ConnectionFailed(network_error),
337            }
338        })?;
339        drop(send_buffer_writer);
340
341        let mut send_buffer_reader = send_buffer.create_reader();
342        connection.send_all(&mut send_buffer_reader).await
343            .map_err(|e| MqttError::ConnectionFailed(e))?;
344
345        // Reset send buffer to be ready to 
346        self.recv_buffer.borrow_mut().reset();
347
348        Ok(())
349
350    }
351
352    pub async fn run<N: NetworkConnection>(&self, connection: Pin<&mut N>) -> Result<(), MqttError> {
353
354        let connection = unsafe {
355            connection.get_unchecked_mut()
356        };
357
358        self.connect(connection).await?;
359
360        loop {
361            let result = self.work(connection).await;
362            match result {
363                Ok(()) => {
364                    break;
365                }
366                Err(MqttError::ConnectionFailed(e)) => {
367                    warn!("reconnecting, conection faild: {}", e);
368                    self.connect(connection).await?;
369                }
370                Err(err) => {
371                    return Err(err);
372                }
373            }
374        }
375
376        self.disconnect(connection).await?;
377
378        Ok(())
379
380    }
381}
382
383#[cfg(all(test, feature = "std"))]
384mod test {
385    use core::pin::Pin;
386
387    use heapless::Vec;
388
389    use crate::state::KEEP_ALIVE;
390    use crate::time::Duration;
391
392    use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
393    use heapless::String;
394    use mqttrs2::{Connack, ConnectReturnCode, Packet, PacketType, QoS, Suback, SubscribeReturnCodes};
395    use crate::time;
396
397    use crate::network::{fake::{self, ConnectionRessources, ReadAtomic}, mqtt::{ReadMqttPacket, WriteMqttPacket}};
398    use crate::ClientConfig;
399
400    use super::MqttEventLoop;
401
402    fn print_packet(p: &Packet<'_>) -> std::string::String {
403        match p {
404            Packet::Connect(_) => format!("Connect"),
405            Packet::Connack(_) => format!("Connack"),
406            Packet::Publish(_) => format!("Publish"),
407            Packet::Puback(_) => format!("Puback"),
408            Packet::Pubrec(_) => format!("Pubrec"),
409            Packet::Pubrel(_) => format!("Pubrel"),
410            Packet::Pubcomp(_) => format!("Pubcomp"),
411            Packet::Subscribe(_) => format!("Subscribe"),
412            Packet::Suback(_) => format!("Suback"),
413            Packet::Unsubscribe(_) => format!("Unsubscribe"),
414            Packet::Unsuback(_) => format!("Unsuback"),
415            Packet::Pingreq => format!("Pingreq"),
416            Packet::Pingresp => format!("Pingresp"),
417            Packet::Disconnect => format!("Disconnect"),
418        }
419    }
420
421    #[tokio::test]
422    async fn test_run() {
423        time::test_time::set_default();
424
425        let mut config = ClientConfig{
426            client_id: String::new(),
427            credentials: None,
428            auto_subscribes: Vec::new()
429        };
430
431        config.client_id.push_str("asjdkaljs").unwrap();
432
433        let connection_resources = ConnectionRessources::<1024>::new();
434
435        let (mut client, server) = fake::new_connection(&connection_resources);
436
437        let event_loop = MqttEventLoop::<CriticalSectionRawMutex, 1024>::new(config);
438        let mqtt_client = event_loop.client();
439
440        let runner_future = async {
441            let client = Pin::new(&mut client);
442            event_loop.run(client).await.unwrap();
443        };
444
445        let test_future = async move {
446            let client_future = async {
447                mqtt_client.subscribe("test").await.unwrap();
448            };
449            
450            let server_future = async {
451
452                let connect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
453                assert_eq!(connect, PacketType::Connect);
454
455                server.write_mqtt_packet(&Packet::Connack(Connack{
456                    session_present: false,
457                    code: ConnectReturnCode::Accepted
458                })).await.unwrap();
459
460                let subscribe = server.read_mqtt_packet(|s| {
461                    match s {
462                        Packet::Subscribe(sub) => sub.clone(),
463                        other => panic!("expected subscribe, got {}", print_packet(other))
464                    }
465                }).await.unwrap();
466
467                let mut return_codes = heapless::Vec::new();
468                return_codes.push(SubscribeReturnCodes::Success(QoS::AtLeastOnce)).unwrap();
469                
470                server.write_mqtt_packet(&Packet::Suback(Suback{
471                    pid: subscribe.pid,
472                    return_codes 
473                })).await.unwrap();
474            };
475
476            tokio::join!(client_future, server_future);
477        };
478
479        tokio::select! {
480            _ = runner_future => {},
481            _ = test_future => {}
482        }
483    }
484
485    #[tokio::test]
486    async fn test_idle_connection() {
487        let config = ClientConfig{
488            client_id: String::new(),
489            credentials: None,
490            auto_subscribes: Vec::new()
491        };
492
493        time::test_time::set_static_now();
494
495        let connection_resources = ConnectionRessources::<1024>::new();
496        let (mut client, server) = fake::new_connection(&connection_resources);
497
498        let event_loop = MqttEventLoop::<CriticalSectionRawMutex, 1024>::new(config);
499
500        let runner_future = async {
501            let client = Pin::new(&mut client);
502            event_loop.run(client).await.unwrap();
503        };
504            
505        let server_future = async {
506
507            let connect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
508            assert_eq!(connect, PacketType::Connect);
509
510            server.write_mqtt_packet(&Packet::Connack(Connack{
511                session_present: false,
512                code: ConnectReturnCode::Accepted
513            })).await.unwrap();
514
515            time::test_time::advance_time(Duration::from_secs(2));
516            tokio::time::sleep(core::time::Duration::from_millis(100)).await;
517
518            let pingreq = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
519            assert_eq!(pingreq, None);
520
521            time::test_time::advance_time(Duration::from_secs(KEEP_ALIVE as u64) / 2);
522            tokio::time::sleep(core::time::Duration::from_millis(100)).await;
523
524            let pingreq = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
525            assert_eq!(pingreq, Some(PacketType::Pingreq));
526
527            server.write_mqtt_packet(&Packet::Pingresp).await.unwrap();
528
529            time::test_time::advance_time(Duration::from_secs(2));
530            tokio::time::sleep(core::time::Duration::from_millis(100)).await;
531
532            let pingreq = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
533            assert_eq!(pingreq, None);
534        };
535
536        tokio::select! {
537            _ = runner_future => {},
538            _ = server_future => {}
539        }
540    }
541
542    #[tokio::test]
543    #[ntest::timeout(1000)]
544    async fn test_disconnect() {
545        let config = ClientConfig{
546            client_id: String::new(),
547            credentials: None,
548            auto_subscribes: Vec::new()
549        };
550
551        let connection_resources = ConnectionRessources::<1024>::new();
552        let (mut client, server) = fake::new_connection(&connection_resources);
553
554        let event_loop = MqttEventLoop::<CriticalSectionRawMutex, 1024>::new(config);
555        let mqtt_client = event_loop.client();
556
557        let runner_future = async {
558            let client = Pin::new(&mut client);
559            event_loop.run(client).await.unwrap();
560        };
561
562        let client_future = async {
563            mqtt_client.disconnect().await;
564        };
565            
566        let server_future = async {
567
568            let connect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
569            assert_eq!(connect, PacketType::Connect);
570
571            server.write_mqtt_packet(&Packet::Connack(Connack{
572                session_present: false,
573                code: ConnectReturnCode::Accepted
574            })).await.unwrap();
575
576            let disconnect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
577            assert_eq!(disconnect, PacketType::Disconnect);
578        };
579
580        tokio::join! {
581            runner_future,
582            server_future,
583            client_future
584        };
585    }
586}
587