embedded_mqttc/
io.rs

1use core::convert::Infallible;
2use core::{cell::RefCell, future::Future, pin::Pin};
3
4use embytes_buffer::{new_stack_buffer, Buffer, BufferReader, BufferWriter, ReadWrite};
5use embassy_futures::select::{select, select3, Either3};
6use embassy_sync::{blocking_mutex::raw::RawMutex, channel::Channel, pubsub::PubSubChannel};
7use mqttrs2::{decode_slice_with_len, Packet, QoS};
8use crate::network::mqtt::MqttPacketError;
9use crate::network::NetworkError;
10use crate::network::{ mqtt::WriteMqttPacketMut, NetwordSendReceive, NetworkConnection };
11use crate::{client::MqttClient, state::State, time, ClientConfig, MqttError, MqttEvent, MqttPublish, MqttRequest};
12
13use crate::time::Duration;
14
/// Destination for values of type `T` that offers both a non-waiting and an
/// awaitable way of handing a value over.
pub trait AsyncSender<T> {
    /// Attempts to hand `item` over without waiting.
    ///
    /// If the item cannot be accepted right now it is given back in `Err`.
    fn try_send(&self, item: T) -> Result<(), T>;

    /// Returns a future that hands `item` over when polled to completion.
    fn send(&self, item: T) -> impl Future<Output = ()>;
}
19
20impl <M: RawMutex, T, const N: usize> AsyncSender<T> for Channel<M, T, N> {
21    fn send(&self, item: T) -> impl Future<Output = ()> {
22        self.send(item)
23    }
24
25    fn try_send(&self, item: T) -> Result<(), T> {
26        self.try_send(item)
27            .map_err(|err|{
28                match err {
29                    embassy_sync::channel::TrySendError::Full(item) => item,
30                }
31            })
32    }
33}
34
35impl <M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> AsyncSender<T> for PubSubChannel<M, T, CAP, SUBS, PUBS> {
36    async fn send(&self, item: T) {
37        self.publisher().unwrap().publish(item).await;
38    }
39
40    fn try_send(&self, item: T) -> Result<(), T> {
41        self.publisher().unwrap().try_publish(item)
42    }
43}
44
/// Source of values of type `T` that can be awaited one at a time.
pub trait AsyncReceiver<T> {
    /// Returns a future that resolves to the next value from this source.
    fn receive(&self) -> impl Future<Output = T>;
}
48
49impl <M: RawMutex, T, const N: usize> AsyncReceiver<T> for Channel<M, T, N> {
50    fn receive(&self) -> impl Future<Output = T> {
51        self.receive()
52    }
53}
54
55pub struct MqttEventLoop<M: RawMutex, const B: usize> {
56    recv_buffer: RefCell<Buffer<[u8; B]>>,
57    send_buffer: RefCell<Buffer<[u8; B]>>,
58
59    state: State<M>,
60
61    control_sender: PubSubChannel<M, MqttEvent, 4, 16, 8>,
62    request_receiver: Channel<M, MqttRequest, 4>,
63    received_publishes: Channel<M, MqttPublish, 4>
64}
65
66impl <M: RawMutex, const B: usize> MqttEventLoop<M, B> {
67
68    pub fn new(config: ClientConfig) -> Self {
69        
70        Self {
71            recv_buffer: RefCell::new(new_stack_buffer::<B>()),
72            send_buffer: RefCell::new(new_stack_buffer::<B>()),
73
74            state: State::new(config),
75
76            control_sender: PubSubChannel::new(),
77            request_receiver: Channel::new(),
78            received_publishes: Channel::new()
79        }
80    }
81
82    pub fn client<'a>(&'a self) -> MqttClient<'a, M> {
83        MqttClient{
84            control_reveiver: &self.control_sender,
85            request_sender: self.request_receiver.sender(),
86            received_publishes: self.received_publishes.receiver()
87        }
88    }
89
90    /// Receive / sent bytes from / to the network
91    /// Sends blocking if there are bytes to send
92    /// Receives blocking if there are no more bytes to send.
93    async fn network_send_receive<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
94        let mut send_buffer = self.send_buffer.borrow_mut();
95        // let mut connection = self.connection.borrow_mut();
96        
97        if send_buffer.has_remaining_len() {
98            let n = connection.send(&mut send_buffer).await
99                .map_err(|e| MqttError::ConnectionFailed(e))?;
100
101            trace!("sent {} bytes to network; send_buffer remaining: {}", n, send_buffer.remaining_len());
102        } else {
103            trace!("send buffer is empty, skipping send");
104        }
105
106        let mut recv_buffer = self.recv_buffer.borrow_mut();
107        // Do not block for receiving if there is still something to send
108        if send_buffer.has_remaining_len() {
109            let n = connection.try_receive(&mut recv_buffer).await
110                .map_err(|e| MqttError::ConnectionFailed(e))?;
111            trace!("try_receive() {} bytes from network", n);
112        } else {
113            let n = connection.receive(&mut recv_buffer).await
114                .map_err(|e| MqttError::ConnectionFailed(e))?;
115            trace!("receive() {} bytes from network", n);
116        }
117
118        Ok(())
119    }
120
121    /// Try to read a packet from recv buffer. 
122    async fn try_package_receive(&self, send_buffer: &mut impl BufferWriter, recv_buffer: impl BufferReader) -> Result<(), MqttError> {
123        if recv_buffer.is_empty() {
124            trace!("try_package_receive(): recv_buffer is empty, cannot read packet");
125            return Ok(())
126        }
127        
128        let packet_op = decode_slice_with_len(&recv_buffer[..])
129            .map_err(|e| {
130                error!("try_package_receive(): error decoding package: {}", e);
131                MqttError::CodecError
132            })?;
133        
134        if let Some((len, packet)) = packet_op {
135            debug!("try_package_receive(): decoded packet from recv_buffer: len = {}, kind = {}", len, packet.get_type());
136            recv_buffer.add_bytes_read(len);
137            let events = 
138                self.state.process_packet(&packet, send_buffer, &self.received_publishes).await?;
139            
140            if ! events.is_empty() {
141                for event in events {
142                    debug!("try_package_receive(): processing packet -> MqttEvent: {}", &event);
143                    self.control_sender.publisher().unwrap().publish(event).await;
144                }
145            } else {
146                trace!("try_package_receive(): packet processed, no MqttEvent");
147            }
148
149        } else {
150            trace!("try_package_receive(): no complete packet in recv_buffer");
151        }
152
153        Ok(())
154    }
155
156    /// Makes the receive / send of the network
157    /// First tries to write outgoing traffic to buffer
158    /// Then tries to read / write to / from the connection
159    /// Then read data from receive buffer
160    async fn work_network<N: NetworkConnection>(&self, connection: &mut N) -> Result<Infallible, MqttError> {
161        loop {
162            // Try to send packets first before blocking for network traffic
163            // Send packets (Ping, Connect, Publish)
164            {   
165                let mut send_buffer = self.send_buffer.borrow_mut();
166                let mut send_buffer_writer = send_buffer.create_writer();
167                self.state.send_packets(&mut send_buffer_writer, &self.control_sender)?;
168                drop(send_buffer_writer);
169                trace!("after network send: send_buffer {} / {}", send_buffer.remaining_len(), send_buffer.remaining_capacity());
170
171            }
172            
173            // Send / Receive Network traffic
174            // Interript this when ...
175            // - a new Request (e. g. Publish) is added to process it
176            // - sending a ping message is required
177            let network_future = self.network_send_receive(connection);
178            let on_request_signal_future = self.state.on_requst_added.wait();
179            let next_ping_future = self.state.on_ping_required();
180            match select3(network_future, on_request_signal_future, next_ping_future).await {
181                Either3::First(res) => {
182                    trace!("stopping pause: got something from network");
183                    res
184                },
185                Either3::Second(_) => {
186                    trace!("stopping pause: new request");
187                    Ok(())
188                },
189                Either3::Third(()) => {
190                    trace!("stopping pause: ping required");
191                    let mut send_buffer = self.send_buffer.borrow_mut();
192                    let mut send_buffer_writer = send_buffer.create_writer();
193                    self.state.send_ping(&mut send_buffer_writer)
194                },
195            }?;
196
197            let mut send_buffer = self.send_buffer.borrow_mut();
198            let mut recv_buffer = self.recv_buffer.borrow_mut();
199            let recv_reader = recv_buffer.create_reader();
200            let mut send_buffer_writer = send_buffer.create_writer();
201
202            // Try to read a package from the receive buffer and write answers (e. g. acknoledgements) 
203            // to the send buffer
204            self.try_package_receive(&mut send_buffer_writer, recv_reader).await?; 
205
206            trace!("after try packege_receive: recv_buffer: {} / {}", recv_buffer.remaining_len(), recv_buffer.capacity());
207        }
208    }
209
210
211    /// Receive requests from cleint.
212    /// Returns if the client sends a disconnect message
213    async fn work_request_receive(&self) -> Result<(), MqttError> {
214        loop {
215            let req = self.request_receiver.receive().await;
216            let pid = self.state.pid_source.next_pid();
217
218            match req {
219                MqttRequest::Publish(mqtt_publish, id) => {
220                                self.state.publishes.push_publish(mqtt_publish, id, pid).await;
221                                debug!("new publish request added to queue");
222                            },
223                MqttRequest::Subscribe(topic, unique_id) => {
224                                const SUBSCRIBE_QOS: QoS = QoS::AtMostOnce;
225                                self.state.subscribes.push_subscribe(topic, pid, unique_id, SUBSCRIBE_QOS).await;
226                                debug!("new subscribe request added to queue");
227                            },
228                MqttRequest::Unsubscribe(topic, unique_id) => {
229                                self.state.subscribes.push_unsubscribe(topic, pid, unique_id).await;
230                                debug!("new unsubscribe request added to queue");
231                            },
232                MqttRequest::Disconnect => {
233                    self.state.on_requst_added.signal(0);
234                    return Ok(());
235                },
236            }
237
238            // Signal that a new request is added
239            self.state.on_requst_added.signal(0);
240        }
241    }
242
243    async fn connect<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
244        self.send_buffer.borrow_mut().reset();
245        self.recv_buffer.borrow_mut().reset();
246        
247        let mut tries = 0;
248        loop {
249
250            let result = connection.connect().await;
251
252            match result {
253                Ok(()) => {
254                    info!("connect to broker success");
255                    return Ok(())
256                },
257                Err(e) => {
258                    tries += 1;
259                    if tries < 5 {
260                        warn!("{}. try to connecto to host failed", tries);
261                        time::sleep(Duration::from_secs(3)).await;
262                    } else {
263                        error!("{} tries to connect failed", tries);
264                        return Err(MqttError::ConnectionFailed(e));
265                    }
266                },
267            }
268        }
269    }
270
271    async fn work<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
272        // Reset state on new connection
273        self.state.reset();
274            
275        // Poll both futures
276        // Select should never befinished because both jobs are infinite
277        let network_future = self.work_network(connection);
278        let request_future = self.work_request_receive();
279        match select(network_future, request_future).await {
280            embassy_futures::select::Either::First(net_result) => {
281                let err = net_result.unwrap_err();
282                error!("network infinite job finished: {}", &err);
283                Err(err)
284            },
285            embassy_futures::select::Either::Second(req_result) => {
286                if let Err(err) = req_result {
287                    error!("infinite request receive job finished: {}", &err);
288                    Err(err)
289                } else {
290                    info!("disconnect request received: stopping jobs");
291                    Ok(())
292                }
293            },
294        }
295    }
296
297    async fn disconnect<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
298        
299        let mut send_buffer = self.send_buffer.borrow_mut();
300        let mut send_buffer_writer = send_buffer.create_writer();
301
302        send_buffer_writer.write_mqtt_packet_sync(&Packet::Disconnect).map_err(|err| {
303            match err {
304                MqttPacketError::NotEnaughBufferSpace => {
305                    warn!("could not write Disconnect to send buffer: full");
306                    MqttError::BufferFull
307                },
308                MqttPacketError::CodecError => MqttError::CodecError,
309                MqttPacketError::IoError(_) => MqttError::ConnectionFailed(NetworkError::ConnectionFailed),
310                MqttPacketError::NetworkError(network_error) => MqttError::ConnectionFailed(network_error),
311            }
312        })?;
313        drop(send_buffer_writer);
314
315        let mut send_buffer_reader = send_buffer.create_reader();
316        connection.send_all(&mut send_buffer_reader).await
317            .map_err(|e| MqttError::ConnectionFailed(e))?;
318
319        // Reset send buffer to be ready to 
320        self.recv_buffer.borrow_mut().reset();
321
322        Ok(())
323
324    }
325
326    pub async fn run<N: NetworkConnection>(&self, connection: Pin<&mut N>) -> Result<(), MqttError> {
327
328        let connection = unsafe {
329            connection.get_unchecked_mut()
330        };
331
332        self.connect(connection).await?;
333
334        loop {
335            let result = self.work(connection).await;
336            match result {
337                Ok(()) => {
338                    break;
339                }
340                Err(MqttError::ConnectionFailed(e)) => {
341                    warn!("reconnecting, conection faild: {}", e);
342                    self.connect(connection).await?;
343                }
344                Err(err) => {
345                    return Err(err);
346                }
347            }
348        }
349
350        self.disconnect(connection).await?;
351
352        Ok(())
353
354    }
355}
356
#[cfg(test)]
mod test {
    use core::pin::Pin;

    use heapless::Vec;

    // use crate::misc::MqttPacketReader;
    use crate::state::KEEP_ALIVE;
    use crate::time::Duration;

    use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
    use heapless::String;
    use mqttrs2::{Connack, ConnectReturnCode, Packet, PacketType, QoS, Suback, SubscribeReturnCodes};
    use crate::time;

    use crate::network::{fake::{self, ConnectionRessources, ReadAtomic}, mqtt::{ReadMqttPacket, WriteMqttPacket}};
    use crate::ClientConfig;

    use super::MqttEventLoop;

    /// Returns the variant name of `p`; used in the panic message of `test_run`.
    fn print_packet(p: &Packet<'_>) -> std::string::String {
        match p {
            Packet::Connect(_) => format!("Connect"),
            Packet::Connack(_) => format!("Connack"),
            Packet::Publish(_) => format!("Publish"),
            Packet::Puback(_) => format!("Puback"),
            Packet::Pubrec(_) => format!("Pubrec"),
            Packet::Pubrel(_) => format!("Pubrel"),
            Packet::Pubcomp(_) => format!("Pubcomp"),
            Packet::Subscribe(_) => format!("Subscribe"),
            Packet::Suback(_) => format!("Suback"),
            Packet::Unsubscribe(_) => format!("Unsubscribe"),
            Packet::Unsuback(_) => format!("Unsuback"),
            Packet::Pingreq => format!("Pingreq"),
            Packet::Pingresp => format!("Pingresp"),
            Packet::Disconnect => format!("Disconnect"),
        }
    }

    /// Subscribe round trip against a fake broker.
    ///
    /// The event loop connects, the client subscribes to "test", and the fake
    /// broker answers with a Suback for the same packet id. `run` only returns
    /// after a disconnect request, which this test never sends. The test
    /// therefore ends via `select!` as soon as the client and broker sides are
    /// both done.
    #[tokio::test]
    async fn test_run() {
        time::test_time::set_default();

        let mut config = ClientConfig{
            client_id: String::new(),
            credentials: None,
            auto_subscribes: Vec::new()
        };

        config.client_id.push_str("asjdkaljs").unwrap();

        let connection_resources = ConnectionRessources::<1024>::new();

        // `client` is handed to the event loop; `server` plays the broker.
        let (mut client, server) = fake::new_connection(&connection_resources);

        let event_loop = MqttEventLoop::<CriticalSectionRawMutex, 1024>::new(config);
        let mqtt_client = event_loop.client();

        let runner_future = async {
            let client = Pin::new(&mut client);
            event_loop.run(client).await.unwrap();
        };

        let test_future = async move {
            // Client side: completes once the subscription is acknowledged.
            let client_future = async {
                mqtt_client.subscribe("test").await.unwrap();
            };
            
            // Broker side: accept the CONNECT, then acknowledge the SUBSCRIBE.
            let server_future = async {

                let connect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
                assert_eq!(connect, PacketType::Connect);

                server.write_mqtt_packet(&Packet::Connack(Connack{
                    session_present: false,
                    code: ConnectReturnCode::Accepted
                })).await.unwrap();

                let subscribe = server.read_mqtt_packet(|s| {
                    match s {
                        Packet::Subscribe(sub) => sub.clone(),
                        other => panic!("expected subscribe, got {}", print_packet(other))
                    }
                }).await.unwrap();

                let mut return_codes = heapless::Vec::new();
                return_codes.push(SubscribeReturnCodes::Success(QoS::AtLeastOnce)).unwrap();
                
                // The Suback must echo the Subscribe's packet id.
                server.write_mqtt_packet(&Packet::Suback(Suback{
                    pid: subscribe.pid,
                    return_codes 
                })).await.unwrap();
            };

            tokio::join!(client_future, server_future);
        };

        tokio::select! {
            _ = runner_future => {},
            _ = test_future => {}
        }
    }

    /// Keep-alive on an idle connection.
    ///
    /// Uses the test clock (`set_static_now`) and moves it forward explicitly
    /// with `advance_time`. After each step a short real-time tokio sleep lets
    /// the event loop react, and the broker side polls for a packet without
    /// waiting. Expectations:
    /// - no PINGREQ 2 s after connecting,
    /// - a PINGREQ after a further `KEEP_ALIVE / 2` s,
    /// - no further PINGREQ 2 s after the PINGRESP.
    #[tokio::test]
    async fn test_idle_connection() {
        let config = ClientConfig{
            client_id: String::new(),
            credentials: None,
            auto_subscribes: Vec::new()
        };

        time::test_time::set_static_now();

        let connection_resources = ConnectionRessources::<1024>::new();
        let (mut client, server) = fake::new_connection(&connection_resources);

        let event_loop = MqttEventLoop::<CriticalSectionRawMutex, 1024>::new(config);

        let runner_future = async {
            let client = Pin::new(&mut client);
            event_loop.run(client).await.unwrap();
        };
            
        let server_future = async {

            let connect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
            assert_eq!(connect, PacketType::Connect);

            server.write_mqtt_packet(&Packet::Connack(Connack{
                session_present: false,
                code: ConnectReturnCode::Accepted
            })).await.unwrap();

            // Too early for a ping.
            time::test_time::advance_time(Duration::from_secs(2));
            tokio::time::sleep(core::time::Duration::from_millis(100)).await;

            let pingreq = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
            assert_eq!(pingreq, None);

            // Past the ping interval: a PINGREQ must have been sent.
            time::test_time::advance_time(Duration::from_secs(KEEP_ALIVE as u64) / 2);
            tokio::time::sleep(core::time::Duration::from_millis(100)).await;

            let pingreq = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
            assert_eq!(pingreq, Some(PacketType::Pingreq));

            server.write_mqtt_packet(&Packet::Pingresp).await.unwrap();

            // Shortly after the PINGRESP no new ping is expected.
            time::test_time::advance_time(Duration::from_secs(2));
            tokio::time::sleep(core::time::Duration::from_millis(100)).await;

            let pingreq = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
            assert_eq!(pingreq, None);
        };

        tokio::select! {
            _ = runner_future => {},
            _ = server_future => {}
        }
    }

    /// Client-initiated disconnect.
    ///
    /// The client requests a disconnect immediately. The fake broker must see
    /// CONNECT followed by DISCONNECT, and `run` must return `Ok` (it is
    /// unwrapped). All three futures are joined, so the test only passes if
    /// each of them completes within the 1000 ms timeout.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn test_disconnect() {
        let config = ClientConfig{
            client_id: String::new(),
            credentials: None,
            auto_subscribes: Vec::new()
        };

        let connection_resources = ConnectionRessources::<1024>::new();
        let (mut client, server) = fake::new_connection(&connection_resources);

        let event_loop = MqttEventLoop::<CriticalSectionRawMutex, 1024>::new(config);
        let mqtt_client = event_loop.client();

        let runner_future = async {
            let client = Pin::new(&mut client);
            event_loop.run(client).await.unwrap();
        };

        let client_future = async {
            mqtt_client.disconnect().await;
        };
            
        let server_future = async {

            let connect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
            assert_eq!(connect, PacketType::Connect);

            server.write_mqtt_packet(&Packet::Connack(Connack{
                session_present: false,
                code: ConnectReturnCode::Accepted
            })).await.unwrap();

            let disconnect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
            assert_eq!(disconnect, PacketType::Disconnect);
        };

        tokio::join! {
            runner_future,
            server_future,
            client_future
        };
    }
}
562