// ntex_mqtt/v3/dispatcher.rs

1use std::cell::{Cell, RefCell};
2use std::{marker::PhantomData, num::NonZeroU16, rc::Rc, task::Context};
3
4use ntex_io::DispatchItem;
5use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory, cfg::Cfg, cfg::SharedCfg};
6use ntex_util::services::buffer::{BufferService, BufferServiceError};
7use ntex_util::services::inflight::InFlightService;
8use ntex_util::{HashSet, future::join};
9
10use crate::error::{DecodeError, HandshakeError, MqttError, PayloadError, ProtocolError};
11use crate::payload::{Payload, PayloadStatus, PlSender};
12use crate::{MqttServiceConfig, types::QoS, types::packet_type};
13
14use super::codec::{Decoded, Encoded, Packet};
15use super::control::{Control, ControlAck, ControlAckKind, Subscribe, Unsubscribe};
16use super::{Session, publish::Publish, shared::Ack, shared::MqttShared};
17
18/// mqtt3 protocol dispatcher
19pub(super) fn factory<St, T, C, E>(
20    publish: T,
21    control: C,
22) -> impl ServiceFactory<
23    DispatchItem<Rc<MqttShared>>,
24    (SharedCfg, Session<St>),
25    Response = Option<Encoded>,
26    Error = MqttError<E>,
27    InitError = MqttError<E>,
28>
29where
30    St: 'static,
31    T: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
32    C: ServiceFactory<Control<E>, Session<St>, Response = ControlAck> + 'static,
33    E: From<C::Error> + From<C::InitError> + From<T::Error> + From<T::InitError> + 'static,
34{
35    let factories = Rc::new((publish, control));
36
37    ntex_service::fn_factory_with_config(
38        async move |(cfg, session): (SharedCfg, Session<St>)| {
39            // create services
40            let sink = session.sink().shared();
41            let fut = join(factories.0.create(session.clone()), factories.1.create(session));
42            let (publish, control) = fut.await;
43
44            let publish = publish.map_err(|e| MqttError::Service(e.into()))?;
45            let control = control.map_err(|e| MqttError::Service(e.into()))?;
46
47            let control = BufferService::new(
48                16,
49                // limit number of in-flight messages
50                InFlightService::new(1, control),
51            )
52            .map_err(|err| match err {
53                BufferServiceError::Service(e) => MqttError::Service(E::from(e)),
54                BufferServiceError::RequestCanceled => {
55                    MqttError::Handshake(HandshakeError::Disconnected(None))
56                }
57            });
58
59            let cfg: Cfg<MqttServiceConfig> = cfg.get();
60            Ok(Dispatcher::<_, _, E>::new(
61                sink,
62                publish,
63                control,
64                cfg.max_qos,
65                cfg.handle_qos_after_disconnect,
66            ))
67        },
68    )
69}
70
71impl crate::inflight::SizedRequest for DispatchItem<Rc<MqttShared>> {
72    fn size(&self) -> u32 {
73        match self {
74            DispatchItem::Item(Decoded::Packet(_, size))
75            | DispatchItem::Item(Decoded::Publish(_, _, size)) => *size,
76            _ => 0,
77        }
78    }
79
80    fn is_publish(&self) -> bool {
81        matches!(self, DispatchItem::Item(Decoded::Publish(..)))
82    }
83
84    fn is_chunk(&self) -> bool {
85        matches!(self, DispatchItem::Item(Decoded::PayloadChunk(..)))
86    }
87}
88
89/// Mqtt protocol dispatcher
90pub(crate) struct Dispatcher<T, C: Service<Control<E>>, E> {
91    publish: T,
92    max_qos: QoS,
93    handle_qos_after_disconnect: Option<QoS>,
94    inner: Rc<Inner<C>>,
95    _t: PhantomData<(E,)>,
96}
97
98struct Inner<C> {
99    control: Pipeline<C>,
100    sink: Rc<MqttShared>,
101    payload: Cell<Option<PlSender>>,
102    inflight: RefCell<HashSet<NonZeroU16>>,
103}
104
105impl<T, C, E> Dispatcher<T, C, E>
106where
107    E: From<T::Error>,
108    T: Service<Publish, Response = ()>,
109    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
110{
111    pub(crate) fn new(
112        sink: Rc<MqttShared>,
113        publish: T,
114        control: C,
115        max_qos: QoS,
116        handle_qos_after_disconnect: Option<QoS>,
117    ) -> Self {
118        Self {
119            publish,
120            max_qos,
121            handle_qos_after_disconnect,
122            inner: Rc::new(Inner {
123                sink,
124                payload: Cell::new(None),
125                control: Pipeline::new(control),
126                inflight: RefCell::new(HashSet::default()),
127            }),
128            _t: PhantomData,
129        }
130    }
131
132    fn tag(&self) -> &'static str {
133        self.inner.sink.tag()
134    }
135}
136
137impl<C> Inner<C> {
138    fn drop_payload<PErr>(&self, err: &PErr)
139    where
140        PErr: Clone,
141        PayloadError: From<PErr>,
142    {
143        if let Some(pl) = self.payload.take() {
144            pl.set_error(err.clone().into());
145        }
146    }
147}
148
149impl<T, C, E> Service<DispatchItem<Rc<MqttShared>>> for Dispatcher<T, C, E>
150where
151    E: From<T::Error> + 'static,
152    T: Service<Publish, Response = ()> + 'static,
153    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>> + 'static,
154{
155    type Response = Option<Encoded>;
156    type Error = MqttError<E>;
157
158    #[inline]
159    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
160        let (res1, res2) =
161            join(ctx.ready(&self.publish), ctx.ready(self.inner.control.get_ref())).await;
162        let result = if let Err(e) = res1 {
163            if res2.is_err() {
164                Err(MqttError::Service(e.into()))
165            } else {
166                match ctx
167                    .call_nowait(self.inner.control.get_ref(), Control::error(e.into()))
168                    .await
169                {
170                    Ok(_) => {
171                        self.inner.sink.close();
172                        Ok(())
173                    }
174                    Err(err) => Err(err),
175                }
176            }
177        } else {
178            res2
179        };
180
181        if result.is_ok() {
182            if let Some(pl) = self.inner.payload.take() {
183                if pl.ready().await == PayloadStatus::Ready {
184                    self.inner.payload.set(Some(pl));
185                } else {
186                    self.inner.sink.close();
187                }
188            }
189        }
190        result
191    }
192
193    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
194        if let Err(e) = self.publish.poll(cx) {
195            let inner = self.inner.clone();
196            ntex_rt::spawn(async move {
197                if inner.control.call_nowait(Control::error(e.into())).await.is_ok() {
198                    inner.sink.close();
199                }
200            });
201        }
202        self.inner.control.poll(cx)
203    }
204
205    async fn shutdown(&self) {
206        self.inner.drop_payload(&PayloadError::Disconnected);
207        self.inner.sink.close();
208        let _ = self.inner.control.call(Control::closed()).await;
209
210        self.publish.shutdown().await;
211        self.inner.control.shutdown().await;
212    }
213
214    async fn call(
215        &self,
216        req: DispatchItem<Rc<MqttShared>>,
217        ctx: ServiceCtx<'_, Self>,
218    ) -> Result<Self::Response, Self::Error> {
219        log::trace!("{}; Dispatch v3 packet: {:#?}", self.tag(), req);
220
221        match req {
222            DispatchItem::Item(Decoded::Publish(publish, payload, size)) => {
223                if publish.topic.contains(['#', '+']) {
224                    return control(
225                        Control::proto_error(
226                            ProtocolError::generic_violation(
227                                "PUBLISH packet's topic name contains wildcard character [MQTT-3.3.2-2]"
228                            )
229                        ),
230                        &self.inner,
231                        ctx,
232                    ).await;
233                }
234
235                let inner = self.inner.as_ref();
236                let packet_id = publish.packet_id;
237
238                // check for duplicated packet id
239                if let Some(pid) = packet_id {
240                    if !inner.inflight.borrow_mut().insert(pid) {
241                        log::trace!(
242                            "{}: Duplicated packet id for publish packet: {:?}",
243                            self.tag(),
244                            pid
245                        );
246                        return control(
247                            Control::proto_error(
248                                ProtocolError::generic_violation("PUBLISH received with packet id that is already in use [MQTT-2.2.1-3]")
249                            ),
250                            &self.inner,
251                            ctx,
252                        ).await;
253                    }
254                }
255
256                // check max allowed qos
257                if publish.qos > self.max_qos {
258                    log::trace!(
259                        "{}: Max allowed QoS is violated, max {:?} provided {:?}",
260                        self.tag(),
261                        self.max_qos,
262                        publish.qos
263                    );
264                    return control(
265                        Control::proto_error(ProtocolError::generic_violation(
266                            match publish.qos {
267                                QoS::AtLeastOnce => "PUBLISH with QoS 1 is not supported",
268                                QoS::ExactlyOnce => "PUBLISH with QoS 2 is not supported",
269                                QoS::AtMostOnce => unreachable!(), // max_qos cannot be lower than QoS 0
270                            },
271                        )),
272                        &self.inner,
273                        ctx,
274                    )
275                    .await;
276                }
277
278                if inner.sink.is_closed()
279                    && !self
280                        .handle_qos_after_disconnect
281                        .map(|max_qos| publish.qos <= max_qos)
282                        .unwrap_or_default()
283                {
284                    return Ok(None);
285                }
286
287                let payload = if publish.payload_size == payload.len() as u32 {
288                    Payload::from_bytes(payload)
289                } else {
290                    let (pl, sender) = Payload::from_stream(payload);
291                    self.inner.payload.set(Some(sender));
292                    pl
293                };
294
295                publish_fn(
296                    &self.publish,
297                    Publish::new(publish, payload, size),
298                    packet_id,
299                    inner,
300                    ctx,
301                )
302                .await
303            }
304            DispatchItem::Item(Decoded::PayloadChunk(buf, eof)) => {
305                if let Some(pl) = self.inner.payload.take() {
306                    pl.feed_data(buf);
307                    if eof {
308                        pl.feed_eof();
309                    } else {
310                        self.inner.payload.set(Some(pl));
311                    }
312                    Ok(None)
313                } else {
314                    control(
315                        Control::proto_error(ProtocolError::Decode(
316                            DecodeError::UnexpectedPayload,
317                        )),
318                        &self.inner,
319                        ctx,
320                    )
321                    .await
322                }
323            }
324            DispatchItem::Item(Decoded::Packet(Packet::PublishAck { packet_id }, _)) => {
325                if let Err(e) = self.inner.sink.pkt_ack(Ack::Publish(packet_id)) {
326                    control(Control::proto_error(e), &self.inner, ctx).await
327                } else {
328                    Ok(None)
329                }
330            }
331            DispatchItem::Item(Decoded::Packet(Packet::PublishReceived { packet_id }, _)) => {
332                if let Err(e) = self.inner.sink.pkt_ack(Ack::Receive(packet_id)) {
333                    control(Control::proto_error(e), &self.inner, ctx).await
334                } else {
335                    Ok(None)
336                }
337            }
338            DispatchItem::Item(Decoded::Packet(Packet::PublishRelease { packet_id }, _)) => {
339                if self.inner.inflight.borrow().contains(&packet_id) {
340                    control(Control::pubrel(packet_id), &self.inner, ctx).await
341                } else {
342                    control(
343                        Control::proto_error(ProtocolError::unexpected_packet(
344                            packet_type::PUBREL,
345                            "Unknown packet-id in PublishRelease packet",
346                        )),
347                        &self.inner,
348                        ctx,
349                    )
350                    .await
351                }
352            }
353            DispatchItem::Item(Decoded::Packet(Packet::PublishComplete { packet_id }, _)) => {
354                if let Err(e) = self.inner.sink.pkt_ack(Ack::Complete(packet_id)) {
355                    control(Control::proto_error(e), &self.inner, ctx).await
356                } else {
357                    Ok(None)
358                }
359            }
360            DispatchItem::Item(Decoded::Packet(Packet::PingRequest, _)) => {
361                control(Control::ping(), &self.inner, ctx).await
362            }
363            DispatchItem::Item(Decoded::Packet(
364                Packet::Subscribe { packet_id, topic_filters },
365                size,
366            )) => {
367                if self.inner.sink.is_closed() {
368                    return Ok(None);
369                }
370
371                if topic_filters.iter().any(|(tf, _)| !crate::topic::is_valid(tf)) {
372                    return control(
373                        Control::proto_error(ProtocolError::generic_violation(
374                            "Topic filter is malformed [MQTT-4.7.1-*]",
375                        )),
376                        &self.inner,
377                        ctx,
378                    )
379                    .await;
380                }
381
382                if !self.inner.inflight.borrow_mut().insert(packet_id) {
383                    log::trace!(
384                        "{}: Duplicated packet id for subscribe packet: {:?}",
385                        self.tag(),
386                        packet_id
387                    );
388                    return control(
389                        Control::proto_error(ProtocolError::generic_violation(
390                            "SUBSCRIBE received with packet id that is already in use [MQTT-2.2.1-3]"
391                        )),
392                        &self.inner,
393                        ctx,
394                    ).await;
395                }
396
397                control(
398                    Control::subscribe(Subscribe::new(packet_id, size, topic_filters)),
399                    &self.inner,
400                    ctx,
401                )
402                .await
403            }
404            DispatchItem::Item(Decoded::Packet(
405                Packet::Unsubscribe { packet_id, topic_filters },
406                size,
407            )) => {
408                if self.inner.sink.is_closed() {
409                    return Ok(None);
410                }
411
412                if topic_filters.iter().any(|tf| !crate::topic::is_valid(tf)) {
413                    return control(
414                        Control::proto_error(ProtocolError::generic_violation(
415                            "Topic filter is malformed [MQTT-4.7.1-*]",
416                        )),
417                        &self.inner,
418                        ctx,
419                    )
420                    .await;
421                }
422
423                if !self.inner.inflight.borrow_mut().insert(packet_id) {
424                    log::trace!(
425                        "{}: Duplicated packet id for unsubscribe packet: {:?}",
426                        self.tag(),
427                        packet_id
428                    );
429                    return control(
430                        Control::proto_error(ProtocolError::generic_violation(
431                            "UNSUBSCRIBE received with packet id that is already in use [MQTT-2.2.1-3]"
432                        )),
433                        &self.inner,
434                        ctx,
435                    ).await;
436                }
437
438                control(
439                    Control::unsubscribe(Unsubscribe::new(packet_id, size, topic_filters)),
440                    &self.inner,
441                    ctx,
442                )
443                .await
444            }
445            DispatchItem::Item(Decoded::Packet(Packet::Disconnect, _)) => {
446                control(Control::remote_disconnect(), &self.inner, ctx).await
447            }
448            DispatchItem::Item(_) => Ok(None),
449            DispatchItem::EncoderError(err) => {
450                let err = ProtocolError::Encode(err);
451                self.inner.drop_payload(&err);
452                control(Control::proto_error(err), &self.inner, ctx).await
453            }
454            DispatchItem::KeepAliveTimeout => {
455                self.inner.drop_payload(&ProtocolError::KeepAliveTimeout);
456                control(Control::proto_error(ProtocolError::KeepAliveTimeout), &self.inner, ctx)
457                    .await
458            }
459            DispatchItem::ReadTimeout => {
460                self.inner.drop_payload(&ProtocolError::ReadTimeout);
461                control(Control::proto_error(ProtocolError::ReadTimeout), &self.inner, ctx)
462                    .await
463            }
464            DispatchItem::DecoderError(err) => {
465                let err = ProtocolError::Decode(err);
466                self.inner.drop_payload(&err);
467                control(Control::proto_error(err), &self.inner, ctx).await
468            }
469            DispatchItem::Disconnect(err) => {
470                self.inner.drop_payload(&PayloadError::Disconnected);
471                control(Control::peer_gone(err), &self.inner, ctx).await
472            }
473            DispatchItem::WBackPressureEnabled => {
474                self.inner.sink.enable_wr_backpressure();
475                control(Control::wr_backpressure(true), &self.inner, ctx).await
476            }
477            DispatchItem::WBackPressureDisabled => {
478                self.inner.sink.disable_wr_backpressure();
479                control(Control::wr_backpressure(false), &self.inner, ctx).await
480            }
481        }
482    }
483}
484
485/// Publish service response future
486async fn publish_fn<'f, T, C, E>(
487    svc: &'f T,
488    pkt: Publish,
489    packet_id: Option<NonZeroU16>,
490    inner: &'f Inner<C>,
491    ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
492) -> Result<Option<Encoded>, MqttError<E>>
493where
494    E: From<T::Error>,
495    T: Service<Publish, Response = ()>,
496    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
497{
498    let qos2 = pkt.qos() == QoS::ExactlyOnce;
499    match ctx.call(svc, pkt).await {
500        Ok(_) => {
501            log::trace!(
502                "{}: Publish result for packet {:?} is ready",
503                inner.sink.tag(),
504                packet_id
505            );
506
507            if let Some(packet_id) = packet_id {
508                if qos2 {
509                    Ok(Some(Encoded::Packet(Packet::PublishReceived { packet_id })))
510                } else {
511                    inner.inflight.borrow_mut().remove(&packet_id);
512                    Ok(Some(Encoded::Packet(Packet::PublishAck { packet_id })))
513                }
514            } else {
515                Ok(None)
516            }
517        }
518        Err(e) => control(Control::error(e.into()), inner, ctx).await,
519    }
520}
521
522async fn control<'f, T, C, E>(
523    mut pkt: Control<E>,
524    inner: &'f Inner<C>,
525    ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
526) -> Result<Option<Encoded>, MqttError<E>>
527where
528    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
529{
530    let mut error = matches!(pkt, Control::Error(_) | Control::ProtocolError(_));
531
532    loop {
533        match ctx.call(inner.control.get_ref(), pkt).await {
534            Ok(item) => {
535                let packet = match item.result {
536                    ControlAckKind::Ping => Some(Encoded::Packet(Packet::PingResponse)),
537                    ControlAckKind::Subscribe(res) => {
538                        inner.inflight.borrow_mut().remove(&res.packet_id);
539                        Some(Encoded::Packet(Packet::SubscribeAck {
540                            status: res.codes,
541                            packet_id: res.packet_id,
542                        }))
543                    }
544                    ControlAckKind::Unsubscribe(res) => {
545                        inner.inflight.borrow_mut().remove(&res.packet_id);
546                        Some(Encoded::Packet(Packet::UnsubscribeAck {
547                            packet_id: res.packet_id,
548                        }))
549                    }
550                    ControlAckKind::Disconnect => {
551                        inner.drop_payload(&PayloadError::Service);
552                        inner.sink.close();
553                        None
554                    }
555                    ControlAckKind::Closed | ControlAckKind::Nothing => None,
556                    ControlAckKind::PublishRelease(packet_id) => {
557                        inner.inflight.borrow_mut().remove(&packet_id);
558                        Some(Encoded::Packet(Packet::PublishComplete { packet_id }))
559                    }
560                    ControlAckKind::PublishAck(_) => unreachable!(),
561                };
562                return Ok(packet);
563            }
564            Err(err) => {
565                inner.drop_payload(&PayloadError::Service);
566
567                // do not handle nested error
568                return if error {
569                    Err(err)
570                } else {
571                    // handle error from control service
572                    match err {
573                        MqttError::Service(err) => {
574                            error = true;
575                            pkt = Control::error(err);
576                            continue;
577                        }
578                        _ => Err(err),
579                    }
580                };
581            }
582        }
583    }
584}
585
#[cfg(test)]
mod tests {
    use std::{future::Future, pin::Pin};

    use ntex_bytes::{ByteString, Bytes};
    use ntex_io::{Io, testing::IoTest};
    use ntex_service::{cfg::SharedCfg, fn_service};
    use ntex_util::future::{Ready, lazy};
    use ntex_util::time::{Seconds, sleep};

    use super::*;
    use crate::v3::{MqttSink, codec};

    /// Decoded QoS 1 PUBLISH with packet id 1, an empty topic and an empty payload.
    fn publish_with_id_1() -> DispatchItem<Rc<MqttShared>> {
        let packet = codec::Publish {
            dup: false,
            retain: false,
            qos: QoS::AtLeastOnce,
            topic: ByteString::new(),
            packet_id: NonZeroU16::new(1),
            payload_size: 0,
        };
        DispatchItem::Item(Decoded::Publish(packet, Bytes::new(), 999))
    }

    #[ntex::test]
    async fn test_dup_packet_id() {
        let io = Io::new(IoTest::create().0, SharedCfg::new("DBG"));
        let shared = Rc::new(MqttShared::new(
            io.get_ref(),
            codec::Codec::default(),
            false,
            Default::default(),
        ));
        let saw_proto_error = Rc::new(RefCell::new(false));
        let saw_proto_error_ctl = saw_proto_error.clone();

        let disp = Pipeline::new(Dispatcher::<_, _, ()>::new(
            shared.clone(),
            // the publish handler stays busy, so packet id 1 remains in flight
            fn_service(|_| async {
                sleep(Seconds(10)).await;
                Ok(())
            }),
            fn_service(move |ctrl| {
                if let Control::ProtocolError(_) = ctrl {
                    *saw_proto_error_ctl.borrow_mut() = true;
                }
                Ready::Ok(ControlAck { result: ControlAckKind::Nothing })
            }),
            QoS::AtLeastOnce,
            None,
        ));

        // start the first publish and poll it once so it reaches the handler
        let mut first: Pin<Box<dyn Future<Output = Result<_, _>>>> =
            Box::pin(disp.call(publish_with_id_1()));
        let _ = lazy(|cx| Pin::new(&mut first).poll(cx)).await;

        // a second publish reusing the in-flight packet id is a protocol violation
        let second = disp.call(publish_with_id_1()).await;
        assert!(second.unwrap().is_none());
        assert!(*saw_proto_error.borrow());
    }

    #[ntex::test]
    async fn test_wr_backpressure() {
        let io = Io::new(IoTest::create().0, SharedCfg::new("DBG"));
        let shared = Rc::new(MqttShared::new(
            io.get_ref(),
            codec::Codec::default(),
            false,
            Default::default(),
        ));

        let disp = Pipeline::new(Dispatcher::<_, _, ()>::new(
            shared.clone(),
            fn_service(|_| Ready::Ok(())),
            fn_service(|_| Ready::Ok(ControlAck { result: ControlAckKind::Nothing })),
            QoS::AtLeastOnce,
            None,
        ));

        let sink = MqttSink::new(shared.clone());
        assert!(!sink.is_ready());
        shared.set_cap(1);
        assert!(sink.is_ready());
        assert!(shared.wait_readiness().is_none());

        // back-pressure on: sink is not ready and readiness can be awaited
        disp.call(DispatchItem::WBackPressureEnabled).await.unwrap();
        assert!(!sink.is_ready());
        let first = shared.wait_readiness();
        let second = shared.wait_readiness().unwrap();
        assert!(first.is_some());

        // back-pressure off: only the first waiter is woken
        let first = first.unwrap();
        disp.call(DispatchItem::WBackPressureDisabled).await.unwrap();
        assert!(lazy(|cx| first.poll_recv(cx).is_ready()).await);
        assert!(!lazy(|cx| second.poll_recv(cx).is_ready()).await);
    }
}