1use crate::state::{MqttState, StateError};
2use crate::MqttOptions;
3use crate::{network, Notification, Request};
4
5use async_stream::stream;
6use futures_util::sink::{Sink, SinkExt};
7use futures_util::stream::{Stream, StreamExt};
8use rumq_core::mqtt4::codec::MqttCodec;
9use rumq_core::mqtt4::{Connect, Packet, PacketIdentifier, Publish};
10use tokio::select;
11use tokio::stream::iter;
12use tokio::time::{self, Delay, Elapsed, Instant};
13use tokio_util::codec::Framed;
14
15use std::collections::VecDeque;
16use std::io;
17use std::mem;
18use std::time::Duration;
19
/// MQTT client event loop: owns the connection options, the protocol state and the
/// source of user requests.
///
/// Created with `eventloop`; every call to `MqttEventLoop::connect` opens a new
/// connection and returns the notification stream that drives it.
pub struct MqttEventLoop {
    /// Connection options (client id, broker address, keep-alive, throttle, inflight limit, ...).
    pub options: MqttOptions,
    /// MQTT protocol state. It is kept across connections so that packets left
    /// unacknowledged by one connection can be replayed by the next one.
    pub state: MqttState,
    /// User requests (publish, subscribe, ...) that are sent to the broker.
    pub requests: Box<dyn Requests>,
    /// Publishes taken from `state.outgoing_pub` on reconnect, replayed before new requests.
    pending_pub: VecDeque<Publish>,
    /// Packet identifiers taken from `state.outgoing_rel` on reconnect, replayed after
    /// `pending_pub`.
    pending_rel: VecDeque<PacketIdentifier>,
}
33
/// Errors that end a connection attempt or an established connection.
///
/// `MqttEventLoop::connect` returns them directly while connecting; once connected they
/// are delivered through the notification stream as `Notification::Abort`.
#[derive(Debug, thiserror::Error)]
pub enum EventLoopError {
    /// Error reported by an `MqttState` handler for an incoming or outgoing packet.
    #[error("Mqtt state")]
    MqttState(#[from] StateError),
    /// Connecting, sending CONNECT or waiting for CONNACK exceeded its time limit.
    #[error("Timeout")]
    Timeout(#[from] Elapsed),
    /// Error reported by the `rumq_core` codec while reading packets.
    #[error("Rumq")]
    Rumq(#[from] rumq_core::Error),
    /// Establishing the TCP or TLS connection failed.
    #[error("Network")]
    Network(#[from] network::Error),
    /// I/O error while writing packets to the connection.
    #[error("I/O")]
    Io(#[from] io::Error),
    /// The incoming packet stream ended (the broker closed the connection).
    #[error("Stream done")]
    StreamDone,
    /// The user's request stream ended.
    #[error("Requests done")]
    RequestsDone,
}
52
53pub fn eventloop(options: MqttOptions, requests: impl Requests + 'static) -> MqttEventLoop {
74 MqttEventLoop {
75 options: options,
76 state: MqttState::new(),
77 requests: Box::new(requests),
78 pending_pub: VecDeque::new(),
79 pending_rel: VecDeque::new(),
80 }
81}
82
83impl MqttEventLoop {
84 pub async fn connect<'eventloop>(&'eventloop mut self) -> Result<impl Stream<Item = Notification> + 'eventloop, EventLoopError> {
89 self.state.await_pingresp = false;
90
91 let mut network = self.network_connect().await?;
93 self.mqtt_connect(&mut network).await?;
94
95 self.populate_pending();
97
98 let stream = stream! {
100 let pending_rel = iter(self.pending_rel.drain(..)).map(Packet::Pubrec);
101 let mut pending = iter(self.pending_pub.drain(..)).map(Packet::Publish).chain(pending_rel);
102 let mut pending = time::throttle(self.options.throttle, pending);
103 let mut requests = time::throttle(self.options.throttle, &mut self.requests);
104
105 let mut timeout = time::delay_for(self.options.keep_alive);
106 let mut inout_marker = 0;
107 let mut pending_done = false;
108
109 loop {
110 let inflight_full = self.state.outgoing_pub.len() >= self.options.inflight;
111 let o = select(
112 &mut network,
113 &mut pending,
114 &mut requests,
115 &mut self.state,
116 self.options.keep_alive,
117 &mut inout_marker,
118 inflight_full,
119 &mut pending_done,
120 &mut timeout
121 ).await;
122
123 let (notification, outpacket) = match o {
124 Ok((n, p)) => (n, p),
125 Err(e) => {
126 yield Notification::Abort(e.into());
127 break
128 }
129 };
130
131 if let Some(p) = outpacket {
134 if let Err(e) = network.send(p).await {
135 yield Notification::Abort(e.into());
136 break
137 }
138 }
139
140 if let Some(n) = notification { yield n }
142 }
143 };
144
145 Ok(Box::pin(stream))
146 }
147
148 fn populate_pending(&mut self) {
149 let mut pending_pub = mem::replace(&mut self.state.outgoing_pub, VecDeque::new());
150 self.pending_pub.append(&mut pending_pub);
151
152 let mut pending_rel = mem::replace(&mut self.state.outgoing_rel, VecDeque::new());
153 self.pending_rel.append(&mut pending_rel);
154 }
155}
156
157async fn select<R: Requests, P: Packets>(
158 network: &mut Framed<Box<dyn N>, MqttCodec>,
159 mut pending: P,
160 mut requests: R,
161 state: &mut MqttState,
162 keepalive: Duration,
163 inout_marker: &mut u8,
164 inflight_full: bool,
165 pending_done: &mut bool,
166 mut timeout: &mut Delay,
167) -> Result<(Option<Notification>, Option<Packet>), EventLoopError> {
168 let ticker = &mut timeout;
170 let o = select! {
171 o = network.next() => match o {
172 Some(packet) => state.handle_incoming_packet(packet?)?,
173 None => return Err(EventLoopError::StreamDone)
174 },
175 o = requests.next(), if !inflight_full && *pending_done => match o {
176 Some(request) => state.handle_outgoing_packet(request.into())?,
177 None => return Err(EventLoopError::RequestsDone),
178 },
179 o = pending.next(), if !*pending_done => match o {
180 Some(packet) => state.handle_outgoing_packet(packet)?,
181 None => {
182 *pending_done = true;
183 (None, None)
184 }
185 },
186 _ = ticker => {
187 timeout.reset(Instant::now() + keepalive);
188 *inout_marker = 0;
189 let notification = None;
190 let packet = Packet::Pingreq;
191 state.handle_outgoing_packet(packet)?;
192 let packet = Some(Packet::Pingreq);
193 return Ok((notification, packet))
194 }
195 };
196
197 let (notification, packet) = (o.0.is_some(), o.1.is_some());
198 match (notification, packet) {
199 (true, true) => *inout_marker |= 3,
200 (true, false) => *inout_marker |= 1,
201 (false, true) => *inout_marker |= 2,
202 (false, false) => (),
203 }
204
205 if *inout_marker == 3 {
207 timeout.reset(Instant::now() + keepalive);
208 *inout_marker = 0;
209 }
210
211 Ok(o)
212}
213
214impl MqttEventLoop {
215 async fn network_connect(&self) -> Result<Framed<Box<dyn N>, MqttCodec>, EventLoopError> {
216 let network = time::timeout(Duration::from_secs(5), async {
217 let network = if self.options.ca.is_some() {
218 let o = network::tls_connect(&self.options).await?;
219 let o = Box::new(o) as Box<dyn N>;
220 Framed::new(o, MqttCodec::new(10 * 1024))
221 } else {
222 let o = network::tcp_connect(&self.options).await?;
223 let o = Box::new(o) as Box<dyn N>;
224 Framed::new(o, MqttCodec::new(10 * 1024))
225 };
226
227 Ok::<Framed<Box<dyn N>, MqttCodec>, EventLoopError>(network)
228 })
229 .await??;
230
231 Ok(network)
232 }
233
234 async fn mqtt_connect(&mut self, mut network: impl Network) -> Result<(), EventLoopError> {
235 let id = self.options.client_id();
236 let keep_alive = self.options.keep_alive().as_secs() as u16;
237 let clean_session = self.options.clean_session();
238 let last_will = self.options.last_will();
239
240 let mut connect = Connect::new(id);
241 connect.keep_alive = keep_alive;
242 connect.clean_session = clean_session;
243 connect.last_will = last_will;
244
245 if let Some((username, password)) = self.options.credentials() {
246 connect.set_username(username).set_password(password);
247 }
248
249 time::timeout(Duration::from_secs(5), async {
251 network.send(Packet::Connect(connect)).await?;
252 self.state.handle_outgoing_connect()?;
253 Ok::<_, EventLoopError>(())
254 })
255 .await??;
256
257 time::timeout(Duration::from_secs(5), async {
259 let packet = match network.next().await {
260 Some(o) => o?,
261 None => return Err(EventLoopError::StreamDone),
262 };
263 self.state.handle_incoming_connack(packet)?;
264 Ok::<_, EventLoopError>(())
265 })
266 .await??;
267
268 Ok(())
269 }
270}
271
272impl From<Request> for Packet {
273 fn from(item: Request) -> Self {
274 match item {
275 Request::Publish(publish) => Packet::Publish(publish),
276 Request::Disconnect => Packet::Disconnect,
277 Request::Subscribe(subscribe) => Packet::Subscribe(subscribe),
278 Request::Unsubscribe(unsubscribe) => Packet::Unsubscribe(unsubscribe),
279 _ => unimplemented!(),
280 }
281 }
282}
283
284use tokio::io::{AsyncRead, AsyncWrite};
285
286pub trait N: AsyncRead + AsyncWrite + Send + Unpin {}
287impl<T> N for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
288
289pub trait Network: Stream<Item = Result<Packet, rumq_core::Error>> + Sink<Packet, Error = io::Error> + Unpin + Send {}
290impl<T> Network for T where T: Stream<Item = Result<Packet, rumq_core::Error>> + Sink<Packet, Error = io::Error> + Unpin + Send {}
291
292pub trait Requests: Stream<Item = Request> + Unpin + Send + Sync {}
293impl<T> Requests for T where T: Stream<Item = Request> + Unpin + Send + Sync {}
294
295pub trait Packets: Stream<Item = Packet> + Unpin + Send + Sync {}
296impl<T> Packets for T where T: Stream<Item = Packet> + Unpin + Send + Sync {}
297
#[cfg(test)]
mod test {
    use super::broker::*;
    use crate::state::StateError;
    use crate::{EventLoopError, MqttOptions, Notification, Request};
    use futures_util::stream::StreamExt;
    use rumq_core::mqtt4::*;
    use std::time::{Duration, Instant};
    use tokio::sync::mpsc::{channel, Sender};
    use tokio::{task, time};

    // Feeds `count` publishes (payload `[i, 1, 2, 3]`) into the event loop's request
    // channel, sleeping `delay` seconds after each one. Send failures are ignored.
    async fn start_requests(count: u8, qos: QoS, delay: u64, mut requests_tx: Sender<Request>) {
        for i in 0..count {
            let topic = "hello/world".to_owned();
            let payload = vec![i, 1, 2, 3];

            let publish = Publish::new(topic, qos, payload);
            let request = Request::Publish(publish);
            let _ = requests_tx.send(request).await;
            time::delay_for(Duration::from_secs(delay)).await;
        }
    }

    // A broker that never sends CONNACK must make `connect` fail with a timeout after 5s.
    #[tokio::test]
    async fn connection_should_timeout_on_time() {
        let (_requests_tx, requests_rx) = channel(5);

        task::spawn(async move {
            let _broker = Broker::new(1880, false).await;
            time::delay_for(Duration::from_secs(10)).await;
        });

        time::delay_for(Duration::from_secs(1)).await;
        let options = MqttOptions::new("dummy", "127.0.0.1", 1880);
        let mut eventloop = super::eventloop(options, requests_rx);

        let start = Instant::now();
        let o = eventloop.connect().await;
        let elapsed = start.elapsed();

        match o {
            Err(super::EventLoopError::Timeout(_)) => (),
            Err(e) => panic!("Expecting a connection timeout. Received = {:?}", e),
            Ok(_) => panic!("Expecting a connection timeout. Connection succeeded"),
        }

        assert_eq!(elapsed.as_secs(), 5);
    }

    // Consecutive requests must reach the broker `throttle` apart.
    #[tokio::test]
    async fn throttled_requests_works_with_correct_delays_between_requests() {
        let mut options = MqttOptions::new("dummy", "127.0.0.1", 1881);
        options.set_throttle(Duration::from_secs(1));
        let options2 = options.clone();

        let (requests_tx, requests_rx) = channel(5);
        task::spawn(async move {
            start_requests(10, QoS::AtLeastOnce, 0, requests_tx).await;
        });

        task::spawn(async move {
            time::delay_for(Duration::from_secs(1)).await;
            let mut eventloop = super::eventloop(options, requests_rx);
            let mut stream = eventloop.connect().await.unwrap();

            while let Some(_) = stream.next().await {}
        });

        let mut broker = Broker::new(1881, true).await;

        for i in 0..10 {
            let start = Instant::now();
            let _ = broker.read_packet().await;
            let elapsed = start.elapsed();

            if i > 0 {
                assert_eq!(elapsed.as_secs(), options2.throttle.as_secs())
            }
        }
    }

    // With no traffic at all, the first PINGREQ must arrive after exactly one keep-alive.
    #[tokio::test]
    async fn idle_connection_triggers_pings_on_time() {
        let mut options = MqttOptions::new("dummy", "127.0.0.1", 1885);
        options.set_keep_alive(5);
        let keep_alive = options.keep_alive();

        let (_requests_tx, requests_rx) = channel(5);
        task::spawn(async move {
            time::delay_for(Duration::from_secs(1)).await;
            let mut eventloop = super::eventloop(options, requests_rx);
            let mut stream = eventloop.connect().await.unwrap();

            while let Some(_) = stream.next().await {}
        });

        let mut broker = Broker::new(1885, true).await;

        let start = Instant::now();
        let mut ping_received = false;

        for _ in 0..10 {
            let packet = broker.read_packet().await;
            let elapsed = start.elapsed();
            if packet == Packet::Pingreq {
                ping_received = true;
                assert_eq!(elapsed.as_secs(), keep_alive.as_secs());
                break;
            }
        }

        assert!(ping_received);
    }

    // Outgoing-only traffic must not postpone the PINGREQ.
    #[tokio::test]
    async fn some_outgoing_and_no_incoming_packets_should_trigger_pings_on_time() {
        let mut options = MqttOptions::new("dummy", "127.0.0.1", 1886);
        options.set_keep_alive(5);
        let keep_alive = options.keep_alive();

        let (requests_tx, requests_rx) = channel(5);
        task::spawn(async move {
            start_requests(10, QoS::AtMostOnce, 1, requests_tx).await;
        });

        task::spawn(async move {
            time::delay_for(Duration::from_secs(1)).await;
            let mut eventloop = super::eventloop(options, requests_rx);
            let mut stream = eventloop.connect().await.unwrap();

            while let Some(_) = stream.next().await {}
        });

        let mut broker = Broker::new(1886, true).await;

        let start = Instant::now();
        let mut ping_received = false;

        for _ in 0..10 {
            let packet = broker.read_packet_and_respond().await;
            let elapsed = start.elapsed();
            if packet == Packet::Pingreq {
                ping_received = true;
                assert_eq!(elapsed.as_secs(), keep_alive.as_secs());
                break;
            }
        }

        assert!(ping_received);
    }

    // Incoming-only traffic must not postpone the PINGREQ.
    #[tokio::test]
    async fn some_incoming_and_no_outgoing_packets_should_trigger_pings_on_time() {
        let mut options = MqttOptions::new("dummy", "127.0.0.1", 2000);
        options.set_keep_alive(5);

        task::spawn(async move {
            time::delay_for(Duration::from_secs(1)).await;
            let (_requests_tx, requests_rx) = channel(5);
            let mut eventloop = super::eventloop(options, requests_rx);
            let mut stream = eventloop.connect().await.unwrap();
            while let Some(_) = stream.next().await {}
        });

        let mut broker = Broker::new(2000, true).await;
        broker.start_publishes(5, QoS::AtMostOnce, Duration::from_secs(1)).await;
        let packet = broker.read_packet().await;
        assert_eq!(packet, Packet::Pingreq);
    }

    // A broker that swallows pings is detected when the second ping is due (2 x keep-alive).
    #[tokio::test]
    async fn detects_halfopen_connections_in_the_second_ping_request() {
        let mut options = MqttOptions::new("dummy", "127.0.0.1", 2001);
        options.set_keep_alive(5);

        task::spawn(async move {
            let mut broker = Broker::new(2001, true).await;
            broker.blackhole().await;
        });

        time::delay_for(Duration::from_secs(1)).await;
        let (_requests_tx, requests_rx) = channel(5);
        let mut eventloop = super::eventloop(options, requests_rx);
        let mut stream = eventloop.connect().await.unwrap();

        let start = Instant::now();
        match stream.next().await.unwrap() {
            Notification::Abort(EventLoopError::MqttState(StateError::AwaitPingResp)) => assert_eq!(start.elapsed().as_secs(), 10),
            _ => panic!("Expecting await pingresp error"),
        }
    }

    // Once `inflight` publishes are unacknowledged, further requests must be held back.
    #[tokio::test]
    async fn requests_are_blocked_after_max_inflight_queue_size() {
        let mut options = MqttOptions::new("dummy", "127.0.0.1", 1887);
        options.set_inflight(5);
        let inflight = options.inflight();

        let (requests_tx, requests_rx) = channel(5);
        task::spawn(async move {
            start_requests(10, QoS::AtLeastOnce, 1, requests_tx).await;
        });

        task::spawn(async move {
            time::delay_for(Duration::from_secs(1)).await;
            let mut eventloop = super::eventloop(options, requests_rx);
            let mut stream = eventloop.connect().await.unwrap();

            while let Some(_) = stream.next().await {}
        });

        let mut broker = Broker::new(1887, true).await;
        for i in 1..=10 {
            let packet = broker.read_publish().await;

            if i > inflight {
                assert!(packet.is_none());
            }
        }
    }

    // Each acknowledgement frees one inflight slot and lets exactly one more publish through.
    #[tokio::test]
    async fn requests_are_recovered_after_inflight_queue_size_falls_below_max() {
        let mut options = MqttOptions::new("dummy", "127.0.0.1", 1888);
        options.set_inflight(3);

        let (requests_tx, requests_rx) = channel(5);
        task::spawn(async move {
            start_requests(5, QoS::AtLeastOnce, 1, requests_tx).await;
            time::delay_for(Duration::from_secs(60)).await;
        });

        task::spawn(async move {
            time::delay_for(Duration::from_secs(1)).await;
            let mut eventloop = super::eventloop(options, requests_rx);
            let mut stream = eventloop.connect().await.unwrap();
            while let Some(_p) = stream.next().await {}
        });

        let mut broker = Broker::new(1888, true).await;

        let packet = broker.read_publish().await;
        assert!(packet.is_some());
        let packet = broker.read_publish().await;
        assert!(packet.is_some());
        let packet = broker.read_publish().await;
        assert!(packet.is_some());
        let packet = broker.read_publish().await;
        assert!(packet.is_none());
        broker.ack(PacketIdentifier(1)).await;
        let packet = broker.read_publish().await;
        assert!(packet.is_some());
        let packet = broker.read_publish().await;
        assert!(packet.is_none());
        broker.ack(PacketIdentifier(2)).await;
        let packet = broker.read_publish().await;
        assert!(packet.is_some());
    }

    // After a reconnect, packet identifiers continue where the previous connection stopped.
    #[tokio::test]
    async fn reconnection_resumes_from_the_previous_state() {
        let options = MqttOptions::new("dummy", "127.0.0.1", 1889);

        let (requests_tx, requests_rx) = channel(5);
        task::spawn(async move {
            start_requests(10, QoS::AtLeastOnce, 1, requests_tx).await;
            time::delay_for(Duration::from_secs(10)).await;
        });

        task::spawn(async move {
            time::delay_for(Duration::from_secs(1)).await;
            let mut eventloop = super::eventloop(options, requests_rx);

            loop {
                let mut stream = eventloop.connect().await.unwrap();
                while let Some(_) = stream.next().await {}
            }
        });

        {
            let mut broker = Broker::new(1889, true).await;
            for i in 1..=2 {
                let pkid = broker.read_publish().await.unwrap();
                assert_eq!(PacketIdentifier(i), pkid);
                broker.ack(pkid).await;
            }
        }

        {
            let mut broker = Broker::new(1889, true).await;
            for i in 3..=4 {
                let pkid = broker.read_publish().await.unwrap();
                assert_eq!(PacketIdentifier(i), pkid);
                broker.ack(pkid).await;
            }
        }
    }

    // Unacknowledged publishes are re-sent first, followed by the new requests.
    #[tokio::test]
    async fn reconnection_resends_unacked_packets_from_the_previous_connection_before_sending_current_connection_requests() {
        let options = MqttOptions::new("dummy", "127.0.0.1", 1890);

        let (requests_tx, requests_rx) = channel(5);
        task::spawn(async move {
            start_requests(10, QoS::AtLeastOnce, 1, requests_tx).await;
            time::delay_for(Duration::from_secs(10)).await;
        });

        task::spawn(async move {
            time::delay_for(Duration::from_secs(1)).await;
            let mut eventloop = super::eventloop(options, requests_rx);

            loop {
                let mut stream = eventloop.connect().await.unwrap();
                while let Some(_) = stream.next().await {}
            }
        });

        {
            let mut broker = Broker::new(1890, true).await;
            for i in 1..=2 {
                let packet = broker.read_publish().await;
                assert_eq!(PacketIdentifier(i), packet.unwrap());
            }
        }

        {
            let mut broker = Broker::new(1890, true).await;
            for i in 1..=6 {
                let packet = broker.read_publish().await;
                assert_eq!(PacketIdentifier(i), packet.unwrap());
            }
        }
    }
}
667
/// Minimal scripted MQTT broker used by the event loop tests.
///
/// Each `Broker` serves exactly one client connection on 127.0.0.1.
#[cfg(test)]
mod broker {
    use futures_util::sink::SinkExt;
    use rumq_core::mqtt4::*;
    use std::time::Duration;
    use tokio::net::{TcpListener, TcpStream};
    use tokio::select;
    use tokio::stream::StreamExt;
    use tokio::time;
    use tokio_util::codec::Framed;

    /// One accepted client connection, framed with the MQTT codec.
    pub struct Broker {
        framed: Framed<TcpStream, codec::MqttCodec>,
    }

    impl Broker {
        /// Binds `127.0.0.1:port` and accepts a single client connection.
        ///
        /// The first packet must be CONNECT (otherwise this panics). A CONNACK with
        /// `Accepted` is sent only when `send_connack` is true. The listener is a local and
        /// is dropped on return, so a later `Broker::new` can bind the same port again.
        pub async fn new(port: u16, send_connack: bool) -> Broker {
            let addr = format!("127.0.0.1:{}", port);
            let mut listener = TcpListener::bind(&addr).await.unwrap();
            let (stream, _) = listener.accept().await.unwrap();
            let mut framed = Framed::new(stream, codec::MqttCodec::new(1024 * 1024));

            let packet = framed.next().await.unwrap().unwrap();
            if let Packet::Connect(_) = packet {
                if send_connack {
                    let connack = Connack::new(ConnectReturnCode::Accepted, false);
                    let packet = Packet::Connack(connack);
                    framed.send(packet).await.unwrap();
                }
            } else {
                panic!("Expecting connect packet");
            }

            Broker { framed }
        }

        /// Reads the next packet and returns the identifier of the PUBLISH it must be.
        ///
        /// Returns `None` if nothing arrives within 2 seconds (and also for a PUBLISH
        /// without a packet identifier). Panics on any other packet type or on a decode error.
        pub async fn read_publish(&mut self) -> Option<PacketIdentifier> {
            let packet = time::timeout(Duration::from_secs(2), async { self.framed.next().await.unwrap() });
            match packet.await {
                Ok(Ok(Packet::Publish(publish))) => publish.pkid,
                Ok(Ok(packet)) => panic!("Expecting a publish. Received = {:?}", packet),
                Ok(Err(e)) => panic!("Error = {:?}", e),
                Err(_) => None,
            }
        }

        /// Reads the next packet of any type. Panics if none arrives within 30 seconds,
        /// if the connection closes, or on a decode error.
        pub async fn read_packet(&mut self) -> Packet {
            let packet = time::timeout(Duration::from_secs(30), async { self.framed.next().await.unwrap() });
            packet.await.unwrap().unwrap()
        }

        /// Like `read_packet`, but answers a PUBLISH carrying a packet identifier with a
        /// PUBACK before returning it.
        pub async fn read_packet_and_respond(&mut self) -> Packet {
            let packet = time::timeout(Duration::from_secs(30), async { self.framed.next().await.unwrap() });
            let packet = packet.await.unwrap().unwrap();

            match packet.clone() {
                Packet::Publish(publish) => if let Some(pkid) = publish.pkid {
                    self.framed.send(Packet::Puback(pkid)).await.unwrap();
                }
                _ => (),
            }

            packet
        }

        /// Reads and discards packets forever without ever replying (including to PINGREQ).
        /// Never returns normally; panics once the connection closes or a packet fails to decode.
        pub async fn blackhole(&mut self) -> Packet {
            loop {
                let _packet = self.framed.next().await.unwrap().unwrap();
            }
        }

        /// Sends a PUBACK for `pkid` and flushes the connection.
        pub async fn ack(&mut self, pkid: PacketIdentifier) {
            let packet = Packet::Puback(pkid);
            self.framed.send(packet).await.unwrap();
            self.framed.flush().await.unwrap();
        }

        /// Runs `count` iterations. Each iteration either sends one PUBLISH (payload
        /// `[1, 2, 3, i]`) on the next `delay` interval tick, or handles one incoming packet,
        /// replying PINGRESP to PINGREQ and ignoring anything else.
        ///
        /// NOTE(review): an iteration spent on an incoming packet sends no publish, so fewer
        /// than `count` publishes go out when the client pings meanwhile.
        pub async fn start_publishes(&mut self, count: u8, qos: QoS, delay: Duration) {
            let mut interval = time::interval(delay);
            for i in 0..count {
                select! {
                    _ = interval.next() => {
                        let topic = "hello/world".to_owned();
                        let payload = vec![1, 2, 3, i];
                        let publish = Publish::new(topic, qos, payload);
                        let packet = Packet::Publish(publish);
                        self.framed.send(packet).await.unwrap();
                    }
                    packet = self.framed.next() => match packet.unwrap().unwrap() {
                        Packet::Pingreq => self.framed.send(Packet::Pingresp).await.unwrap(),
                        _ => ()
                    }
                }
            }
        }
    }
}