1use core::convert::Infallible;
2use core::{cell::RefCell, future::Future, pin::Pin};
3
4use embytes_buffer::{new_stack_buffer, Buffer, BufferReader, BufferWriter, ReadWrite};
5use embassy_futures::select::{select, select3, Either3};
6use embassy_sync::{blocking_mutex::raw::RawMutex, channel::Channel, pubsub::PubSubChannel};
7use mqttrs2::{decode_slice_with_len, LastWill, Packet, QoS};
8use crate::network::mqtt::MqttPacketError;
9use crate::network::NetworkError;
10use crate::network::{ mqtt::WriteMqttPacketMut, NetwordSendReceive, NetworkConnection };
11use crate::{client::MqttClient, state::State, time, ClientConfig, MqttError, MqttEvent, MqttPublish, MqttRequest};
12
13use crate::time::Duration;
14
/// Channel-like sink that accepts values of type `T`.
///
/// Implemented in this module for embassy `Channel` and `PubSubChannel`.
pub trait AsyncSender<T> {
    /// Sends `item`; the returned future completes once the sink has taken it
    /// (implementations may wait for free capacity).
    fn send(&self, item: T) -> impl Future<Output = ()>;
    /// Attempts to send `item` without waiting; if the sink rejects it, the
    /// item is handed back in `Err`.
    fn try_send(&self, item: T) -> Result<(), T>;
}
19
20impl <M: RawMutex, T, const N: usize> AsyncSender<T> for Channel<M, T, N> {
21 fn send(&self, item: T) -> impl Future<Output = ()> {
22 self.send(item)
23 }
24
25 fn try_send(&self, item: T) -> Result<(), T> {
26 self.try_send(item)
27 .map_err(|err|{
28 match err {
29 embassy_sync::channel::TrySendError::Full(item) => item,
30 }
31 })
32 }
33}
34
35impl <M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> AsyncSender<T> for PubSubChannel<M, T, CAP, SUBS, PUBS> {
36 async fn send(&self, item: T) {
37 self.publisher().unwrap().publish(item).await;
38 }
39
40 fn try_send(&self, item: T) -> Result<(), T> {
41 self.publisher().unwrap().try_publish(item)
42 }
43}
44
/// Channel-like source that yields values of type `T`.
pub trait AsyncReceiver<T> {
    /// Returns a future that resolves to the next available item.
    fn receive(&self) -> impl Future<Output = T>;
}
48
49impl <M: RawMutex, T, const N: usize> AsyncReceiver<T> for Channel<M, T, N> {
50 fn receive(&self) -> impl Future<Output = T> {
51 self.receive()
52 }
53}
54
/// MQTT client event loop: owns the network buffers, the protocol state and the
/// channels through which `MqttClient` handles (see `client()`) talk to it.
///
/// `B` is the capacity in bytes of each of the two network buffers.
pub struct MqttEventLoop<'l, M: RawMutex, const B: usize> {
    // Bytes read from the network that have not been decoded into packets yet.
    recv_buffer: RefCell<Buffer<[u8; B]>>,
    // Encoded packets waiting to be written to the network connection.
    send_buffer: RefCell<Buffer<[u8; B]>>,

    // Protocol state; `'l` ties it to the last-will message passed to
    // `new_with_last_will` (if any).
    state: State<'l, M>,

    // Broadcasts `MqttEvent`s to clients (queue capacity 4, up to 16
    // subscribers, up to 8 concurrent publishers).
    control_sender: PubSubChannel<M, MqttEvent, 4, 16, 8>,
    // Requests submitted by clients (publish / subscribe / unsubscribe / disconnect).
    request_receiver: Channel<M, MqttRequest, 4>,
    // Publishes received from the broker, drained by the client's receiver.
    received_publishes: Channel<M, MqttPublish, 4>
}
68
69impl <M: RawMutex, const B: usize> MqttEventLoop<'static, M, B> {
70
71 pub fn new(config: ClientConfig) -> Self {
73
74 Self {
75 recv_buffer: RefCell::new(new_stack_buffer::<B>()),
76 send_buffer: RefCell::new(new_stack_buffer::<B>()),
77
78 state: State::new(config, None),
79
80 control_sender: PubSubChannel::new(),
81 request_receiver: Channel::new(),
82 received_publishes: Channel::new()
83 }
84 }
85}
86
87impl <'l, M: RawMutex, const B: usize> MqttEventLoop<'l, M, B> {
88
89 pub fn new_with_last_will(config: ClientConfig, last_will: LastWill<'l>) -> Self {
90
91 Self {
92 recv_buffer: RefCell::new(new_stack_buffer::<B>()),
93 send_buffer: RefCell::new(new_stack_buffer::<B>()),
94
95 state: State::new(config, Some(last_will)),
96
97 control_sender: PubSubChannel::new(),
98 request_receiver: Channel::new(),
99 received_publishes: Channel::new()
100 }
101 }
102
103 pub fn client<'a>(&'a self) -> MqttClient<'a, M> {
108 MqttClient{
109 control_reveiver: &self.control_sender,
110 request_sender: self.request_receiver.sender(),
111 received_publishes: self.received_publishes.receiver()
112 }
113 }
114
115 async fn network_send_receive<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
119 let mut send_buffer = self.send_buffer.borrow_mut();
120 if send_buffer.has_remaining_len() {
123 let n = connection.send(&mut send_buffer).await
124 .map_err(|e| MqttError::ConnectionFailed(e))?;
125
126 trace!("sent {} bytes to network; send_buffer remaining: {}", n, send_buffer.remaining_len());
127 } else {
128 trace!("send buffer is empty, skipping send");
129 }
130
131 let mut recv_buffer = self.recv_buffer.borrow_mut();
132 if send_buffer.has_remaining_len() {
134 let n = connection.try_receive(&mut recv_buffer).await
135 .map_err(|e| MqttError::ConnectionFailed(e))?;
136 trace!("try_receive() {} bytes from network", n);
137 } else {
138 let n = connection.receive(&mut recv_buffer).await
139 .map_err(|e| MqttError::ConnectionFailed(e))?;
140 trace!("receive() {} bytes from network", n);
141 }
142
143 Ok(())
144 }
145
146 async fn try_package_receive(&self, send_buffer: &mut impl BufferWriter, recv_buffer: impl BufferReader) -> Result<(), MqttError> {
148 if recv_buffer.is_empty() {
149 trace!("try_package_receive(): recv_buffer is empty, cannot read packet");
150 return Ok(())
151 }
152
153 let packet_op = decode_slice_with_len(&recv_buffer[..])
154 .map_err(|e| {
155 error!("try_package_receive(): error decoding package: {}", e);
156 MqttError::CodecError
157 })?;
158
159 if let Some((len, packet)) = packet_op {
160 debug!("try_package_receive(): decoded packet from recv_buffer: len = {}, kind = {}", len, packet.get_type());
161 recv_buffer.add_bytes_read(len);
162 let events =
163 self.state.process_packet(&packet, send_buffer, &self.received_publishes).await?;
164
165 if ! events.is_empty() {
166 for event in events {
167 debug!("try_package_receive(): processing packet -> MqttEvent: {}", &event);
168 self.control_sender.publisher().unwrap().publish(event).await;
169 }
170 } else {
171 trace!("try_package_receive(): packet processed, no MqttEvent");
172 }
173
174 } else {
175 trace!("try_package_receive(): no complete packet in recv_buffer");
176 }
177
178 Ok(())
179 }
180
181 async fn work_network<N: NetworkConnection>(&self, connection: &mut N) -> Result<Infallible, MqttError> {
186 loop {
187 {
190 let mut send_buffer = self.send_buffer.borrow_mut();
191 let mut send_buffer_writer = send_buffer.create_writer();
192 self.state.send_packets(&mut send_buffer_writer, &self.control_sender)?;
193 drop(send_buffer_writer);
194 trace!("after network send: send_buffer {} / {}", send_buffer.remaining_len(), send_buffer.remaining_capacity());
195
196 }
197
198 let network_future = self.network_send_receive(connection);
203 let on_request_signal_future = self.state.on_requst_added.wait();
204 let next_ping_future = self.state.on_ping_required();
205 match select3(network_future, on_request_signal_future, next_ping_future).await {
206 Either3::First(res) => {
207 trace!("stopping pause: got something from network");
208 res
209 },
210 Either3::Second(_) => {
211 trace!("stopping pause: new request");
212 Ok(())
213 },
214 Either3::Third(()) => {
215 trace!("stopping pause: ping required");
216 let mut send_buffer = self.send_buffer.borrow_mut();
217 let mut send_buffer_writer = send_buffer.create_writer();
218 self.state.send_ping(&mut send_buffer_writer)
219 },
220 }?;
221
222 let mut send_buffer = self.send_buffer.borrow_mut();
223 let mut recv_buffer = self.recv_buffer.borrow_mut();
224 let recv_reader = recv_buffer.create_reader();
225 let mut send_buffer_writer = send_buffer.create_writer();
226
227 self.try_package_receive(&mut send_buffer_writer, recv_reader).await?;
230
231 trace!("after try packege_receive: recv_buffer: {} / {}", recv_buffer.remaining_len(), recv_buffer.capacity());
232 }
233 }
234
235
236 async fn work_request_receive(&self) -> Result<(), MqttError> {
239 loop {
240 let req = self.request_receiver.receive().await;
241 let pid = self.state.pid_source.next_pid();
242
243 match req {
244 MqttRequest::Publish(mqtt_publish, id) => {
245 self.state.publishes.push_publish(mqtt_publish, id, pid).await;
246 debug!("new publish request added to queue");
247 },
248 MqttRequest::Subscribe(topic, unique_id) => {
249 const SUBSCRIBE_QOS: QoS = QoS::AtMostOnce;
250 self.state.subscribes.push_subscribe(topic, pid, unique_id, SUBSCRIBE_QOS).await;
251 debug!("new subscribe request added to queue");
252 },
253 MqttRequest::Unsubscribe(topic, unique_id) => {
254 self.state.subscribes.push_unsubscribe(topic, pid, unique_id).await;
255 debug!("new unsubscribe request added to queue");
256 },
257 MqttRequest::Disconnect => {
258 self.state.on_requst_added.signal(0);
259 return Ok(());
260 },
261 }
262
263 self.state.on_requst_added.signal(0);
265 }
266 }
267
268 async fn connect<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
269 self.send_buffer.borrow_mut().reset();
270 self.recv_buffer.borrow_mut().reset();
271
272 let mut tries: usize = 0;
273 loop {
274
275 let result = connection.connect().await;
276
277 match result {
278 Ok(()) => {
279 info!("connect to broker success");
280 return Ok(())
281 },
282 Err(e) => {
283 tries += 1;
284 if tries < 5 {
285 warn!("{}. try to connecto to host failed", tries);
286 time::sleep(Duration::from_secs(3)).await;
287 } else {
288 error!("{} tries to connect failed", tries);
289 return Err(MqttError::ConnectionFailed(e));
290 }
291 },
292 }
293 }
294 }
295
296 async fn work<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
297 self.state.reset();
299
300 let network_future = self.work_network(connection);
303 let request_future = self.work_request_receive();
304
305 match select(network_future, request_future).await {
306 embassy_futures::select::Either::First(net_result) => {
307 let err = net_result.unwrap_err();
308 error!("network infinite job finished: {}", &err);
309 Err(err)
310 },
311 embassy_futures::select::Either::Second(req_result) => {
312 if let Err(err) = req_result {
313 error!("infinite request receive job finished: {}", &err);
314 Err(err)
315 } else {
316 info!("disconnect request received: stopping jobs");
317 Ok(())
318 }
319 },
320 }
321 }
322
323 async fn disconnect<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
324
325 let mut send_buffer = self.send_buffer.borrow_mut();
326 let mut send_buffer_writer = send_buffer.create_writer();
327
328 send_buffer_writer.write_mqtt_packet_sync(&Packet::Disconnect).map_err(|err| {
329 match err {
330 MqttPacketError::NotEnaughBufferSpace => {
331 warn!("could not write Disconnect to send buffer: full");
332 MqttError::BufferFull
333 },
334 MqttPacketError::CodecError => MqttError::CodecError,
335 MqttPacketError::IoError(_) => MqttError::ConnectionFailed(NetworkError::ConnectionFailed),
336 MqttPacketError::NetworkError(network_error) => MqttError::ConnectionFailed(network_error),
337 }
338 })?;
339 drop(send_buffer_writer);
340
341 let mut send_buffer_reader = send_buffer.create_reader();
342 connection.send_all(&mut send_buffer_reader).await
343 .map_err(|e| MqttError::ConnectionFailed(e))?;
344
345 self.recv_buffer.borrow_mut().reset();
347
348 Ok(())
349
350 }
351
352 pub async fn run<N: NetworkConnection>(&self, connection: Pin<&mut N>) -> Result<(), MqttError> {
353
354 let connection = unsafe {
355 connection.get_unchecked_mut()
356 };
357
358 self.connect(connection).await?;
359
360 loop {
361 let result = self.work(connection).await;
362 match result {
363 Ok(()) => {
364 break;
365 }
366 Err(MqttError::ConnectionFailed(e)) => {
367 warn!("reconnecting, conection faild: {}", e);
368 self.connect(connection).await?;
369 }
370 Err(err) => {
371 return Err(err);
372 }
373 }
374 }
375
376 self.disconnect(connection).await?;
377
378 Ok(())
379
380 }
381}
382
#[cfg(all(test, feature = "std"))]
mod test {
    use core::pin::Pin;

    use heapless::Vec;

    use crate::state::KEEP_ALIVE;
    use crate::time::Duration;

    use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
    use heapless::String;
    use mqttrs2::{Connack, ConnectReturnCode, Packet, PacketType, QoS, Suback, SubscribeReturnCodes};
    use crate::time;

    use crate::network::{fake::{self, ConnectionRessources, ReadAtomic}, mqtt::{ReadMqttPacket, WriteMqttPacket}};
    use crate::ClientConfig;

    use super::MqttEventLoop;

    /// Event loop under test: critical-section mutex and 1024-byte network buffers.
    type TestEventLoop = MqttEventLoop<'static, CriticalSectionRawMutex, 1024>;

    /// Client configuration with an empty id, no credentials and no auto-subscriptions.
    fn empty_config() -> ClientConfig {
        ClientConfig {
            client_id: String::new(),
            credentials: None,
            auto_subscribes: Vec::new(),
        }
    }

    /// Returns the variant name of `p` for use in assertion messages.
    fn print_packet(p: &Packet<'_>) -> std::string::String {
        let name = match p {
            Packet::Connect(_) => "Connect",
            Packet::Connack(_) => "Connack",
            Packet::Publish(_) => "Publish",
            Packet::Puback(_) => "Puback",
            Packet::Pubrec(_) => "Pubrec",
            Packet::Pubrel(_) => "Pubrel",
            Packet::Pubcomp(_) => "Pubcomp",
            Packet::Subscribe(_) => "Subscribe",
            Packet::Suback(_) => "Suback",
            Packet::Unsubscribe(_) => "Unsubscribe",
            Packet::Unsuback(_) => "Unsuback",
            Packet::Pingreq => "Pingreq",
            Packet::Pingresp => "Pingresp",
            Packet::Disconnect => "Disconnect",
        };
        std::string::String::from(name)
    }

    /// Connect, then subscribe through the client while a fake broker answers.
    #[tokio::test]
    async fn test_run() {
        time::test_time::set_default();

        let mut config = empty_config();
        config.client_id.push_str("asjdkaljs").unwrap();

        let resources = ConnectionRessources::<1024>::new();
        let (mut client, server) = fake::new_connection(&resources);

        let event_loop = TestEventLoop::new(config);
        let mqtt_client = event_loop.client();

        let runner = async {
            event_loop.run(Pin::new(&mut client)).await.unwrap();
        };

        let scenario = async move {
            let subscriber = async {
                mqtt_client.subscribe("test").await.unwrap();
            };

            let broker = async {
                let first = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
                assert_eq!(first, PacketType::Connect);

                server.write_mqtt_packet(&Packet::Connack(Connack {
                    session_present: false,
                    code: ConnectReturnCode::Accepted,
                })).await.unwrap();

                let request = server.read_mqtt_packet(|packet| match packet {
                    Packet::Subscribe(sub) => sub.clone(),
                    other => panic!("expected subscribe, got {}", print_packet(other)),
                }).await.unwrap();

                let mut granted = heapless::Vec::new();
                granted.push(SubscribeReturnCodes::Success(QoS::AtLeastOnce)).unwrap();

                server.write_mqtt_packet(&Packet::Suback(Suback {
                    pid: request.pid,
                    return_codes: granted,
                })).await.unwrap();
            };

            tokio::join!(subscriber, broker);
        };

        tokio::select! {
            _ = runner => {},
            _ = scenario => {}
        }
    }

    /// An idle connection pings only once the keep-alive deadline approaches.
    #[tokio::test]
    async fn test_idle_connection() {
        let config = empty_config();

        time::test_time::set_static_now();

        let resources = ConnectionRessources::<1024>::new();
        let (mut client, server) = fake::new_connection(&resources);

        let event_loop = TestEventLoop::new(config);

        let runner = async {
            event_loop.run(Pin::new(&mut client)).await.unwrap();
        };

        let broker = async {
            let first = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
            assert_eq!(first, PacketType::Connect);

            server.write_mqtt_packet(&Packet::Connack(Connack {
                session_present: false,
                code: ConnectReturnCode::Accepted,
            })).await.unwrap();

            // Shortly after connecting nothing should have been sent.
            time::test_time::advance_time(Duration::from_secs(2));
            tokio::time::sleep(core::time::Duration::from_millis(100)).await;

            let pending = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
            assert_eq!(pending, None);

            // After a further KEEP_ALIVE / 2 seconds a ping is expected.
            time::test_time::advance_time(Duration::from_secs(KEEP_ALIVE as u64) / 2);
            tokio::time::sleep(core::time::Duration::from_millis(100)).await;

            let pending = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
            assert_eq!(pending, Some(PacketType::Pingreq));

            server.write_mqtt_packet(&Packet::Pingresp).await.unwrap();

            // Having just pinged, the client stays quiet again.
            time::test_time::advance_time(Duration::from_secs(2));
            tokio::time::sleep(core::time::Duration::from_millis(100)).await;

            let pending = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
            assert_eq!(pending, None);
        };

        tokio::select! {
            _ = runner => {},
            _ = broker => {}
        }
    }

    /// A client-side disconnect makes the event loop send DISCONNECT and return.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn test_disconnect() {
        let config = empty_config();

        let resources = ConnectionRessources::<1024>::new();
        let (mut client, server) = fake::new_connection(&resources);

        let event_loop = TestEventLoop::new(config);
        let mqtt_client = event_loop.client();

        let runner = async {
            event_loop.run(Pin::new(&mut client)).await.unwrap();
        };

        let requester = async {
            mqtt_client.disconnect().await;
        };

        let broker = async {
            let first = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
            assert_eq!(first, PacketType::Connect);

            server.write_mqtt_packet(&Packet::Connack(Connack {
                session_present: false,
                code: ConnectReturnCode::Accepted,
            })).await.unwrap();

            let last = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
            assert_eq!(last, PacketType::Disconnect);
        };

        tokio::join!(runner, broker, requester);
    }
}
587