1use core::convert::Infallible;
2use core::{cell::RefCell, future::Future, pin::Pin};
3
4use embytes_buffer::{new_stack_buffer, Buffer, BufferReader, BufferWriter, ReadWrite};
5use embassy_futures::select::{select, select3, Either3};
6use embassy_sync::{blocking_mutex::raw::RawMutex, channel::Channel, pubsub::PubSubChannel};
7use mqttrs2::{decode_slice_with_len, Packet, QoS};
8use crate::network::mqtt::MqttPacketError;
9use crate::network::NetworkError;
10use crate::network::{ mqtt::WriteMqttPacketMut, NetwordSendReceive, NetworkConnection };
11use crate::{client::MqttClient, state::State, time, ClientConfig, MqttError, MqttEvent, MqttPublish, MqttRequest};
12
13use crate::time::Duration;
14
/// A sink that values of type `T` can be pushed into, either by awaiting
/// (`send`) or without waiting at all (`try_send`).
pub trait AsyncSender<T> {
    /// Pushes `item`; the returned future may wait for as long as the
    /// implementation needs (e.g. until a bounded queue has room).
    fn send(&self, item: T) -> impl Future<Output = ()>;

    /// Pushes `item` without waiting.
    ///
    /// When that is not possible right now, the value is handed back to the
    /// caller as `Err(item)`.
    fn try_send(&self, item: T) -> Result<(), T>;
}
19
20impl <M: RawMutex, T, const N: usize> AsyncSender<T> for Channel<M, T, N> {
21 fn send(&self, item: T) -> impl Future<Output = ()> {
22 self.send(item)
23 }
24
25 fn try_send(&self, item: T) -> Result<(), T> {
26 self.try_send(item)
27 .map_err(|err|{
28 match err {
29 embassy_sync::channel::TrySendError::Full(item) => item,
30 }
31 })
32 }
33}
34
35impl <M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> AsyncSender<T> for PubSubChannel<M, T, CAP, SUBS, PUBS> {
36 async fn send(&self, item: T) {
37 self.publisher().unwrap().publish(item).await;
38 }
39
40 fn try_send(&self, item: T) -> Result<(), T> {
41 self.publisher().unwrap().try_publish(item)
42 }
43}
44
/// A source that values of type `T` can be awaited from.
pub trait AsyncReceiver<T> {
    /// Returns a future that resolves to the next value produced by the source.
    fn receive(&self) -> impl Future<Output = T>;
}
48
49impl <M: RawMutex, T, const N: usize> AsyncReceiver<T> for Channel<M, T, N> {
50 fn receive(&self) -> impl Future<Output = T> {
51 self.receive()
52 }
53}
54
55pub struct MqttEventLoop<M: RawMutex, const B: usize> {
56 recv_buffer: RefCell<Buffer<[u8; B]>>,
57 send_buffer: RefCell<Buffer<[u8; B]>>,
58
59 state: State<M>,
60
61 control_sender: PubSubChannel<M, MqttEvent, 4, 16, 8>,
62 request_receiver: Channel<M, MqttRequest, 4>,
63 received_publishes: Channel<M, MqttPublish, 4>
64}
65
66impl <M: RawMutex, const B: usize> MqttEventLoop<M, B> {
67
68 pub fn new(config: ClientConfig) -> Self {
69
70 Self {
71 recv_buffer: RefCell::new(new_stack_buffer::<B>()),
72 send_buffer: RefCell::new(new_stack_buffer::<B>()),
73
74 state: State::new(config),
75
76 control_sender: PubSubChannel::new(),
77 request_receiver: Channel::new(),
78 received_publishes: Channel::new()
79 }
80 }
81
82 pub fn client<'a>(&'a self) -> MqttClient<'a, M> {
83 MqttClient{
84 control_reveiver: &self.control_sender,
85 request_sender: self.request_receiver.sender(),
86 received_publishes: self.received_publishes.receiver()
87 }
88 }
89
90 async fn network_send_receive<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
94 let mut send_buffer = self.send_buffer.borrow_mut();
95 if send_buffer.has_remaining_len() {
98 let n = connection.send(&mut send_buffer).await
99 .map_err(|e| MqttError::ConnectionFailed(e))?;
100
101 trace!("sent {} bytes to network; send_buffer remaining: {}", n, send_buffer.remaining_len());
102 } else {
103 trace!("send buffer is empty, skipping send");
104 }
105
106 let mut recv_buffer = self.recv_buffer.borrow_mut();
107 if send_buffer.has_remaining_len() {
109 let n = connection.try_receive(&mut recv_buffer).await
110 .map_err(|e| MqttError::ConnectionFailed(e))?;
111 trace!("try_receive() {} bytes from network", n);
112 } else {
113 let n = connection.receive(&mut recv_buffer).await
114 .map_err(|e| MqttError::ConnectionFailed(e))?;
115 trace!("receive() {} bytes from network", n);
116 }
117
118 Ok(())
119 }
120
121 async fn try_package_receive(&self, send_buffer: &mut impl BufferWriter, recv_buffer: impl BufferReader) -> Result<(), MqttError> {
123 if recv_buffer.is_empty() {
124 trace!("try_package_receive(): recv_buffer is empty, cannot read packet");
125 return Ok(())
126 }
127
128 let packet_op = decode_slice_with_len(&recv_buffer[..])
129 .map_err(|e| {
130 error!("try_package_receive(): error decoding package: {}", e);
131 MqttError::CodecError
132 })?;
133
134 if let Some((len, packet)) = packet_op {
135 debug!("try_package_receive(): decoded packet from recv_buffer: len = {}, kind = {}", len, packet.get_type());
136 recv_buffer.add_bytes_read(len);
137 let events =
138 self.state.process_packet(&packet, send_buffer, &self.received_publishes).await?;
139
140 if ! events.is_empty() {
141 for event in events {
142 debug!("try_package_receive(): processing packet -> MqttEvent: {}", &event);
143 self.control_sender.publisher().unwrap().publish(event).await;
144 }
145 } else {
146 trace!("try_package_receive(): packet processed, no MqttEvent");
147 }
148
149 } else {
150 trace!("try_package_receive(): no complete packet in recv_buffer");
151 }
152
153 Ok(())
154 }
155
156 async fn work_network<N: NetworkConnection>(&self, connection: &mut N) -> Result<Infallible, MqttError> {
161 loop {
162 {
165 let mut send_buffer = self.send_buffer.borrow_mut();
166 let mut send_buffer_writer = send_buffer.create_writer();
167 self.state.send_packets(&mut send_buffer_writer, &self.control_sender)?;
168 drop(send_buffer_writer);
169 trace!("after network send: send_buffer {} / {}", send_buffer.remaining_len(), send_buffer.remaining_capacity());
170
171 }
172
173 let network_future = self.network_send_receive(connection);
178 let on_request_signal_future = self.state.on_requst_added.wait();
179 let next_ping_future = self.state.on_ping_required();
180 match select3(network_future, on_request_signal_future, next_ping_future).await {
181 Either3::First(res) => {
182 trace!("stopping pause: got something from network");
183 res
184 },
185 Either3::Second(_) => {
186 trace!("stopping pause: new request");
187 Ok(())
188 },
189 Either3::Third(()) => {
190 trace!("stopping pause: ping required");
191 let mut send_buffer = self.send_buffer.borrow_mut();
192 let mut send_buffer_writer = send_buffer.create_writer();
193 self.state.send_ping(&mut send_buffer_writer)
194 },
195 }?;
196
197 let mut send_buffer = self.send_buffer.borrow_mut();
198 let mut recv_buffer = self.recv_buffer.borrow_mut();
199 let recv_reader = recv_buffer.create_reader();
200 let mut send_buffer_writer = send_buffer.create_writer();
201
202 self.try_package_receive(&mut send_buffer_writer, recv_reader).await?;
205
206 trace!("after try packege_receive: recv_buffer: {} / {}", recv_buffer.remaining_len(), recv_buffer.capacity());
207 }
208 }
209
210
211 async fn work_request_receive(&self) -> Result<(), MqttError> {
214 loop {
215 let req = self.request_receiver.receive().await;
216 let pid = self.state.pid_source.next_pid();
217
218 match req {
219 MqttRequest::Publish(mqtt_publish, id) => {
220 self.state.publishes.push_publish(mqtt_publish, id, pid).await;
221 debug!("new publish request added to queue");
222 },
223 MqttRequest::Subscribe(topic, unique_id) => {
224 const SUBSCRIBE_QOS: QoS = QoS::AtMostOnce;
225 self.state.subscribes.push_subscribe(topic, pid, unique_id, SUBSCRIBE_QOS).await;
226 debug!("new subscribe request added to queue");
227 },
228 MqttRequest::Unsubscribe(topic, unique_id) => {
229 self.state.subscribes.push_unsubscribe(topic, pid, unique_id).await;
230 debug!("new unsubscribe request added to queue");
231 },
232 MqttRequest::Disconnect => {
233 self.state.on_requst_added.signal(0);
234 return Ok(());
235 },
236 }
237
238 self.state.on_requst_added.signal(0);
240 }
241 }
242
243 async fn connect<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
244 self.send_buffer.borrow_mut().reset();
245 self.recv_buffer.borrow_mut().reset();
246
247 let mut tries = 0;
248 loop {
249
250 let result = connection.connect().await;
251
252 match result {
253 Ok(()) => {
254 info!("connect to broker success");
255 return Ok(())
256 },
257 Err(e) => {
258 tries += 1;
259 if tries < 5 {
260 warn!("{}. try to connecto to host failed", tries);
261 time::sleep(Duration::from_secs(3)).await;
262 } else {
263 error!("{} tries to connect failed", tries);
264 return Err(MqttError::ConnectionFailed(e));
265 }
266 },
267 }
268 }
269 }
270
271 async fn work<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
272 self.state.reset();
274
275 let network_future = self.work_network(connection);
278 let request_future = self.work_request_receive();
279 match select(network_future, request_future).await {
280 embassy_futures::select::Either::First(net_result) => {
281 let err = net_result.unwrap_err();
282 error!("network infinite job finished: {}", &err);
283 Err(err)
284 },
285 embassy_futures::select::Either::Second(req_result) => {
286 if let Err(err) = req_result {
287 error!("infinite request receive job finished: {}", &err);
288 Err(err)
289 } else {
290 info!("disconnect request received: stopping jobs");
291 Ok(())
292 }
293 },
294 }
295 }
296
297 async fn disconnect<N: NetworkConnection>(&self, connection: &mut N) -> Result<(), MqttError> {
298
299 let mut send_buffer = self.send_buffer.borrow_mut();
300 let mut send_buffer_writer = send_buffer.create_writer();
301
302 send_buffer_writer.write_mqtt_packet_sync(&Packet::Disconnect).map_err(|err| {
303 match err {
304 MqttPacketError::NotEnaughBufferSpace => {
305 warn!("could not write Disconnect to send buffer: full");
306 MqttError::BufferFull
307 },
308 MqttPacketError::CodecError => MqttError::CodecError,
309 MqttPacketError::IoError(_) => MqttError::ConnectionFailed(NetworkError::ConnectionFailed),
310 MqttPacketError::NetworkError(network_error) => MqttError::ConnectionFailed(network_error),
311 }
312 })?;
313 drop(send_buffer_writer);
314
315 let mut send_buffer_reader = send_buffer.create_reader();
316 connection.send_all(&mut send_buffer_reader).await
317 .map_err(|e| MqttError::ConnectionFailed(e))?;
318
319 self.recv_buffer.borrow_mut().reset();
321
322 Ok(())
323
324 }
325
326 pub async fn run<N: NetworkConnection>(&self, connection: Pin<&mut N>) -> Result<(), MqttError> {
327
328 let connection = unsafe {
329 connection.get_unchecked_mut()
330 };
331
332 self.connect(connection).await?;
333
334 loop {
335 let result = self.work(connection).await;
336 match result {
337 Ok(()) => {
338 break;
339 }
340 Err(MqttError::ConnectionFailed(e)) => {
341 warn!("reconnecting, conection faild: {}", e);
342 self.connect(connection).await?;
343 }
344 Err(err) => {
345 return Err(err);
346 }
347 }
348 }
349
350 self.disconnect(connection).await?;
351
352 Ok(())
353
354 }
355}
356
#[cfg(test)]
mod test {
    use core::pin::Pin;

    use heapless::Vec;

    use crate::state::KEEP_ALIVE;
    use crate::time::Duration;

    use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
    use heapless::String;
    use mqttrs2::{Connack, ConnectReturnCode, Packet, PacketType, QoS, Suback, SubscribeReturnCodes};
    use crate::time;

    use crate::network::{fake::{self, ConnectionRessources, ReadAtomic}, mqtt::{ReadMqttPacket, WriteMqttPacket}};
    use crate::ClientConfig;

    use super::MqttEventLoop;

    /// Returns the variant name of `p`; used to build readable panic messages.
    fn print_packet(p: &Packet<'_>) -> std::string::String {
        let name = match p {
            Packet::Connect(_) => "Connect",
            Packet::Connack(_) => "Connack",
            Packet::Publish(_) => "Publish",
            Packet::Puback(_) => "Puback",
            Packet::Pubrec(_) => "Pubrec",
            Packet::Pubrel(_) => "Pubrel",
            Packet::Pubcomp(_) => "Pubcomp",
            Packet::Subscribe(_) => "Subscribe",
            Packet::Suback(_) => "Suback",
            Packet::Unsubscribe(_) => "Unsubscribe",
            Packet::Unsuback(_) => "Unsuback",
            Packet::Pingreq => "Pingreq",
            Packet::Pingresp => "Pingresp",
            Packet::Disconnect => "Disconnect",
        };
        name.to_string()
    }

    /// Full happy path against a fake broker: CONNECT/CONNACK, then a client
    /// subscribe that must be answered with a matching SUBACK.
    #[tokio::test]
    async fn test_run() {
        time::test_time::set_default();

        let mut config = ClientConfig{
            client_id: String::new(),
            credentials: None,
            auto_subscribes: Vec::new()
        };

        config.client_id.push_str("asjdkaljs").unwrap();

        let connection_resources = ConnectionRessources::<1024>::new();

        let (mut client, server) = fake::new_connection(&connection_resources);

        let event_loop = MqttEventLoop::<CriticalSectionRawMutex, 1024>::new(config);
        let mqtt_client = event_loop.client();

        let runner_future = async {
            let client = Pin::new(&mut client);
            event_loop.run(client).await.unwrap();
        };

        let test_future = async move {
            let client_future = async {
                mqtt_client.subscribe("test").await.unwrap();
            };

            let server_future = async {

                let connect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
                assert_eq!(connect, PacketType::Connect);

                server.write_mqtt_packet(&Packet::Connack(Connack{
                    session_present: false,
                    code: ConnectReturnCode::Accepted
                })).await.unwrap();

                let subscribe = server.read_mqtt_packet(|s| {
                    match s {
                        Packet::Subscribe(sub) => sub.clone(),
                        other => panic!("expected subscribe, got {}", print_packet(other))
                    }
                }).await.unwrap();

                let mut return_codes = heapless::Vec::new();
                return_codes.push(SubscribeReturnCodes::Success(QoS::AtLeastOnce)).unwrap();

                server.write_mqtt_packet(&Packet::Suback(Suback{
                    pid: subscribe.pid,
                    return_codes
                })).await.unwrap();
            };

            tokio::join!(client_future, server_future);
        };

        // The runner never finishes on its own; the test ends with `test_future`.
        tokio::select! {
            _ = runner_future => {},
            _ = test_future => {}
        }
    }

    /// With fake time: no PINGREQ shortly after connecting, one PINGREQ once
    /// half the keep-alive interval has passed, and none right after the
    /// PINGRESP.
    #[tokio::test]
    async fn test_idle_connection() {
        let config = ClientConfig{
            client_id: String::new(),
            credentials: None,
            auto_subscribes: Vec::new()
        };

        time::test_time::set_static_now();

        let connection_resources = ConnectionRessources::<1024>::new();
        let (mut client, server) = fake::new_connection(&connection_resources);

        let event_loop = MqttEventLoop::<CriticalSectionRawMutex, 1024>::new(config);

        let runner_future = async {
            let client = Pin::new(&mut client);
            event_loop.run(client).await.unwrap();
        };

        let server_future = async {

            let connect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
            assert_eq!(connect, PacketType::Connect);

            server.write_mqtt_packet(&Packet::Connack(Connack{
                session_present: false,
                code: ConnectReturnCode::Accepted
            })).await.unwrap();

            // Advance fake time, then yield to the runtime so the event loop
            // can react before the server side inspects the connection.
            time::test_time::advance_time(Duration::from_secs(2));
            tokio::time::sleep(core::time::Duration::from_millis(100)).await;

            let pingreq = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
            assert_eq!(pingreq, None);

            time::test_time::advance_time(Duration::from_secs(KEEP_ALIVE as u64) / 2);
            tokio::time::sleep(core::time::Duration::from_millis(100)).await;

            let pingreq = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
            assert_eq!(pingreq, Some(PacketType::Pingreq));

            server.write_mqtt_packet(&Packet::Pingresp).await.unwrap();

            time::test_time::advance_time(Duration::from_secs(2));
            tokio::time::sleep(core::time::Duration::from_millis(100)).await;

            let pingreq = server.with_reader(|reader| reader.read_packet().unwrap().map(|p| p.get_type()));
            assert_eq!(pingreq, None);
        };

        tokio::select! {
            _ = runner_future => {},
            _ = server_future => {}
        }
    }

    /// A client disconnect request must make `run` return after sending a
    /// DISCONNECT packet to the broker.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn test_disconnect() {
        let config = ClientConfig{
            client_id: String::new(),
            credentials: None,
            auto_subscribes: Vec::new()
        };

        let connection_resources = ConnectionRessources::<1024>::new();
        let (mut client, server) = fake::new_connection(&connection_resources);

        let event_loop = MqttEventLoop::<CriticalSectionRawMutex, 1024>::new(config);
        let mqtt_client = event_loop.client();

        let runner_future = async {
            let client = Pin::new(&mut client);
            event_loop.run(client).await.unwrap();
        };

        let client_future = async {
            mqtt_client.disconnect().await;
        };

        let server_future = async {

            let connect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
            assert_eq!(connect, PacketType::Connect);

            server.write_mqtt_packet(&Packet::Connack(Connack{
                session_present: false,
                code: ConnectReturnCode::Accepted
            })).await.unwrap();

            let disconnect = server.read_mqtt_packet(|p| p.get_type()).await.unwrap();
            assert_eq!(disconnect, PacketType::Disconnect);
        };

        // All three futures must complete: `run` returns after the disconnect.
        tokio::join! {
            runner_future,
            server_future,
            client_future
        };
    }
}
562