ntex_mqtt/v3/dispatcher.rs

1use std::cell::{Cell, RefCell};
2use std::{marker::PhantomData, num::NonZeroU16, rc::Rc, task::Context};
3
4use ntex_io::DispatchItem;
5use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory, cfg::Cfg, cfg::SharedCfg};
6use ntex_util::services::buffer::{BufferService, BufferServiceError};
7use ntex_util::{HashSet, future::join, services::inflight::InFlightService};
8
9use crate::error::{DecodeError, HandshakeError, MqttError, PayloadError, ProtocolError};
10use crate::payload::{Payload, PayloadStatus, PlSender};
11use crate::{MqttServiceConfig, types::QoS, types::packet_type};
12
13use super::codec::{Decoded, Encoded, Packet};
14use super::control::{Control, ControlAck, ControlAckKind, Subscribe, Unsubscribe};
15use super::{Session, publish::Publish, shared::Ack, shared::MqttShared};
16
17/// mqtt3 protocol dispatcher
18pub(super) fn factory<St, T, C, E>(
19    publish: T,
20    control: C,
21) -> impl ServiceFactory<
22    DispatchItem<Rc<MqttShared>>,
23    (SharedCfg, Session<St>),
24    Response = Option<Encoded>,
25    Error = MqttError<E>,
26    InitError = MqttError<E>,
27>
28where
29    St: 'static,
30    T: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
31    C: ServiceFactory<Control<E>, Session<St>, Response = ControlAck> + 'static,
32    E: From<C::Error> + From<C::InitError> + From<T::Error> + From<T::InitError> + 'static,
33{
34    let factories = Rc::new((publish, control));
35
36    ntex_service::fn_factory_with_config(
37        async move |(cfg, session): (SharedCfg, Session<St>)| {
38            // create services
39            let sink = session.sink().shared();
40            let fut = join(factories.0.create(session.clone()), factories.1.create(session));
41            let (publish, control) = fut.await;
42
43            let publish = publish.map_err(|e| MqttError::Service(e.into()))?;
44            let control = control.map_err(|e| MqttError::Service(e.into()))?;
45
46            let control = BufferService::new(
47                16,
48                // limit number of in-flight messages
49                InFlightService::new(1, control),
50            )
51            .map_err(|err| match err {
52                BufferServiceError::Service(e) => MqttError::Service(E::from(e)),
53                BufferServiceError::RequestCanceled => {
54                    MqttError::Handshake(HandshakeError::Disconnected(None))
55                }
56            });
57
58            let cfg: Cfg<MqttServiceConfig> = cfg.get();
59            Ok(Dispatcher::<_, _, E>::new(sink, publish, control, cfg))
60        },
61    )
62}
63
64impl crate::inflight::SizedRequest for DispatchItem<Rc<MqttShared>> {
65    fn size(&self) -> u32 {
66        match self {
67            DispatchItem::Item(Decoded::Packet(_, size))
68            | DispatchItem::Item(Decoded::Publish(_, _, size)) => *size,
69            _ => 0,
70        }
71    }
72
73    fn is_publish(&self) -> bool {
74        matches!(self, DispatchItem::Item(Decoded::Publish(..)))
75    }
76
77    fn is_chunk(&self) -> bool {
78        matches!(self, DispatchItem::Item(Decoded::PayloadChunk(..)))
79    }
80}
81
82/// Mqtt protocol dispatcher
83pub(crate) struct Dispatcher<T, C: Service<Control<E>>, E> {
84    publish: T,
85    inner: Rc<Inner<C>>,
86    cfg: Cfg<MqttServiceConfig>,
87    _t: PhantomData<(E,)>,
88}
89
/// Dispatcher state held behind an `Rc`, so that it can be shared with tasks
/// spawned by the dispatcher (see `Dispatcher::poll`).
///
/// NOTE: field order is also drop order.
struct Inner<C> {
    /// Control service, wrapped in a `Pipeline`.
    control: Pipeline<C>,
    /// Shared connection state: used to close the connection, process acks,
    /// toggle write back-pressure and obtain the logging tag.
    sink: Rc<MqttShared>,
    /// Sender side of the publish payload currently being streamed, if any.
    payload: Cell<Option<PlSender>>,
    /// Packet ids of received PUBLISH (with a packet id), SUBSCRIBE and
    /// UNSUBSCRIBE packets that have not been acknowledged yet; used to detect
    /// reuse of an id that is still in use.
    inflight: RefCell<HashSet<NonZeroU16>>,
}
96
97impl<T, C, E> Dispatcher<T, C, E>
98where
99    E: From<T::Error>,
100    T: Service<Publish, Response = ()>,
101    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
102{
103    pub(crate) fn new(
104        sink: Rc<MqttShared>,
105        publish: T,
106        control: C,
107        cfg: Cfg<MqttServiceConfig>,
108    ) -> Self {
109        Self {
110            cfg,
111            publish,
112            inner: Rc::new(Inner {
113                sink,
114                payload: Cell::new(None),
115                control: Pipeline::new(control),
116                inflight: RefCell::new(HashSet::default()),
117            }),
118            _t: PhantomData,
119        }
120    }
121
122    fn tag(&self) -> &'static str {
123        self.inner.sink.tag()
124    }
125}
126
127impl<C> Inner<C> {
128    fn drop_payload<PErr>(&self, err: &PErr)
129    where
130        PErr: Clone,
131        PayloadError: From<PErr>,
132    {
133        if let Some(pl) = self.payload.take() {
134            pl.set_error(err.clone().into());
135        }
136    }
137}
138
/// Dispatches decoded MQTT v3 items.
///
/// PUBLISH packets go to the publish service. All other items (acks, pings,
/// (un)subscribe, disconnects, timeouts, codec errors, back-pressure changes)
/// are forwarded to the control service via `control()`.
impl<T, C, E> Service<DispatchItem<Rc<MqttShared>>> for Dispatcher<T, C, E>
where
    E: From<T::Error> + 'static,
    T: Service<Publish, Response = ()> + 'static,
    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>> + 'static,
{
    type Response = Option<Encoded>;
    type Error = MqttError<E>;

    /// Waits until both the publish and the control service are ready.
    ///
    /// A readiness error from the publish service is reported to the control
    /// service as `Control::error`. If that call succeeds, the sink is closed
    /// and readiness still succeeds. If both services report an error, the
    /// publish error is returned and the control error is discarded.
    ///
    /// While a payload is being streamed, this also waits for the payload
    /// consumer and closes the sink if its status is not `PayloadStatus::Ready`.
    #[inline]
    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
        let (res1, res2) = join(ctx.ready(&self.publish), self.inner.control.ready()).await;
        let result = if let Err(e) = res1 {
            if res2.is_err() {
                Err(MqttError::Service(e.into()))
            } else {
                match self.inner.control.call(Control::error(e.into())).await {
                    Ok(_) => {
                        self.inner.sink.close();
                        Ok(())
                    }
                    Err(err) => Err(err),
                }
            }
        } else {
            res2
        };

        if result.is_ok() {
            // `Cell` cannot lend out the sender: take it, put a clone back so
            // the payload stays registered, then wait on the taken handle.
            if let Some(pl) = self.inner.payload.take() {
                self.inner.payload.set(Some(pl.clone()));
                if pl.ready().await != PayloadStatus::Ready {
                    self.inner.sink.close();
                }
            }
        }
        result
    }

    /// Polls the publish service, then returns the control service's poll
    /// result.
    ///
    /// A publish-service error is reported to the control service from a
    /// spawned task (which needs its own `Rc` to the shared state). The sink
    /// is closed if the control service accepts the error.
    /// NOTE(review): a new task is spawned on every poll that returns an error.
    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
        if let Err(e) = self.publish.poll(cx) {
            let inner = self.inner.clone();
            ntex_rt::spawn(async move {
                if inner.control.call(Control::error(e.into())).await.is_ok() {
                    inner.sink.close();
                }
            });
        }
        self.inner.control.poll(cx)
    }

    /// Shuts the dispatcher down in four steps:
    /// 1. fails any streaming payload with `PayloadError::Disconnected`;
    /// 2. closes the sink;
    /// 3. sends `Control::closed()` to the control service, ignoring the result;
    /// 4. shuts down the publish service, then the control service.
    async fn shutdown(&self) {
        self.inner.drop_payload(&PayloadError::Disconnected);
        self.inner.sink.close();
        let _ = self.inner.control.call(Control::closed()).await;

        self.publish.shutdown().await;
        self.inner.control.shutdown().await;
    }

    /// Handles one decoded item and returns the packet to send back, if any.
    async fn call(
        &self,
        req: DispatchItem<Rc<MqttShared>>,
        ctx: ServiceCtx<'_, Self>,
    ) -> Result<Self::Response, Self::Error> {
        log::trace!("{}; Dispatch v3 packet: {:#?}", self.tag(), req);

        match req {
            DispatchItem::Item(Decoded::Publish(publish, payload, size)) => {
                // topic names of PUBLISH packets must not contain wildcards
                if publish.topic.contains(['#', '+']) {
                    return control(
                        Control::proto_error(
                            ProtocolError::generic_violation(
                                "PUBLISH packet's topic name contains wildcard character [MQTT-3.3.2-2]"
                            )
                        ),
                        &self.inner,
                    ).await;
                }

                let inner = self.inner.as_ref();
                let packet_id = publish.packet_id;

                // check for duplicated packet id
                // (the id is reserved here; it is released on PUBACK in
                // `publish_fn`, or on PUBREL handling in `control` for QoS 2)
                if let Some(pid) = packet_id {
                    if !inner.inflight.borrow_mut().insert(pid) {
                        log::trace!(
                            "{}: Duplicated packet id for publish packet: {:?}",
                            self.tag(),
                            pid
                        );
                        return control(
                            Control::proto_error(
                                ProtocolError::generic_violation("PUBLISH received with packet id that is already in use [MQTT-2.2.1-3]")
                            ),
                            &self.inner,
                        ).await;
                    }
                }

                // check max allowed qos
                if publish.qos > self.cfg.max_qos {
                    log::trace!(
                        "{}: Max allowed QoS is violated, max {:?} provided {:?}",
                        self.tag(),
                        self.cfg.max_qos,
                        publish.qos
                    );
                    return control(
                        Control::proto_error(ProtocolError::generic_violation(
                            match publish.qos {
                                QoS::AtLeastOnce => "PUBLISH with QoS 1 is not supported",
                                QoS::ExactlyOnce => "PUBLISH with QoS 2 is not supported",
                                QoS::AtMostOnce => unreachable!(), // max_qos cannot be lower than QoS 0
                            },
                        )),
                        &self.inner,
                    )
                    .await;
                }

                // Once the sink is closed, publishes are dropped without a
                // response unless their QoS is within
                // `handle_qos_after_disconnect`.
                if inner.sink.is_closed()
                    && !self
                        .cfg
                        .handle_qos_after_disconnect
                        .map(|max_qos| publish.qos <= max_qos)
                        .unwrap_or_default()
                {
                    return Ok(None);
                }

                // If the whole payload arrived with the packet, hand it over
                // directly. Otherwise register a sender; subsequent
                // `PayloadChunk` items feed it.
                let payload = if publish.payload_size == payload.len() as u32 {
                    Payload::from_bytes(payload)
                } else {
                    let (pl, sender) =
                        Payload::from_stream(payload, self.cfg.max_payload_buffer_size);
                    self.inner.payload.set(Some(sender));
                    pl
                };

                publish_fn(
                    &self.publish,
                    Publish::new(publish, payload, size),
                    packet_id,
                    inner,
                    ctx,
                )
                .await
            }
            DispatchItem::Item(Decoded::PayloadChunk(buf, eof)) => {
                // Feed the chunk to the active payload. The sender is dropped
                // after EOF, otherwise it is put back for the next chunk.
                if let Some(pl) = self.inner.payload.take() {
                    pl.feed_data(buf);
                    if eof {
                        pl.feed_eof();
                    } else {
                        self.inner.payload.set(Some(pl));
                    }
                    Ok(None)
                } else {
                    // a chunk without an active payload stream is a protocol error
                    control(
                        Control::proto_error(ProtocolError::Decode(
                            DecodeError::UnexpectedPayload,
                        )),
                        &self.inner,
                    )
                    .await
                }
            }
            DispatchItem::Item(Decoded::Packet(Packet::PublishAck { packet_id }, _)) => {
                // ack for one of our outgoing QoS 1 publishes
                if let Err(e) = self.inner.sink.pkt_ack(Ack::Publish(packet_id)) {
                    control(Control::proto_error(e), &self.inner).await
                } else {
                    Ok(None)
                }
            }
            DispatchItem::Item(Decoded::Packet(Packet::PublishReceived { packet_id }, _)) => {
                if let Err(e) = self.inner.sink.pkt_ack(Ack::Receive(packet_id)) {
                    control(Control::proto_error(e), &self.inner).await
                } else {
                    Ok(None)
                }
            }
            DispatchItem::Item(Decoded::Packet(Packet::PublishRelease { packet_id }, _)) => {
                // PUBREL is accepted only for a packet id still in `inflight`
                if self.inner.inflight.borrow().contains(&packet_id) {
                    control(Control::pubrel(packet_id), &self.inner).await
                } else {
                    control(
                        Control::proto_error(ProtocolError::unexpected_packet(
                            packet_type::PUBREL,
                            "Unknown packet-id in PublishRelease packet",
                        )),
                        &self.inner,
                    )
                    .await
                }
            }
            DispatchItem::Item(Decoded::Packet(Packet::PublishComplete { packet_id }, _)) => {
                if let Err(e) = self.inner.sink.pkt_ack(Ack::Complete(packet_id)) {
                    control(Control::proto_error(e), &self.inner).await
                } else {
                    Ok(None)
                }
            }
            DispatchItem::Item(Decoded::Packet(Packet::PingRequest, _)) => {
                control(Control::ping(), &self.inner).await
            }
            DispatchItem::Item(Decoded::Packet(
                Packet::Subscribe { packet_id, topic_filters },
                size,
            )) => {
                // ignored once the sink is closed
                if self.inner.sink.is_closed() {
                    return Ok(None);
                }

                if topic_filters.iter().any(|(tf, _)| !crate::topic::is_valid(tf)) {
                    return control(
                        Control::proto_error(ProtocolError::generic_violation(
                            "Topic filter is malformed [MQTT-4.7.1-*]",
                        )),
                        &self.inner,
                    )
                    .await;
                }

                // the id is released when the control service returns the SUBACK result
                if !self.inner.inflight.borrow_mut().insert(packet_id) {
                    log::trace!(
                        "{}: Duplicated packet id for subscribe packet: {:?}",
                        self.tag(),
                        packet_id
                    );
                    return control(
                        Control::proto_error(ProtocolError::generic_violation(
                            "SUBSCRIBE received with packet id that is already in use [MQTT-2.2.1-3]"
                        )),
                        &self.inner,
                    ).await;
                }

                control(
                    Control::subscribe(Subscribe::new(packet_id, size, topic_filters)),
                    &self.inner,
                )
                .await
            }
            DispatchItem::Item(Decoded::Packet(
                Packet::Unsubscribe { packet_id, topic_filters },
                size,
            )) => {
                // ignored once the sink is closed
                if self.inner.sink.is_closed() {
                    return Ok(None);
                }

                if topic_filters.iter().any(|tf| !crate::topic::is_valid(tf)) {
                    return control(
                        Control::proto_error(ProtocolError::generic_violation(
                            "Topic filter is malformed [MQTT-4.7.1-*]",
                        )),
                        &self.inner,
                    )
                    .await;
                }

                // the id is released when the control service returns the UNSUBACK result
                if !self.inner.inflight.borrow_mut().insert(packet_id) {
                    log::trace!(
                        "{}: Duplicated packet id for unsubscribe packet: {:?}",
                        self.tag(),
                        packet_id
                    );
                    return control(
                        Control::proto_error(ProtocolError::generic_violation(
                            "UNSUBSCRIBE received with packet id that is already in use [MQTT-2.2.1-3]"
                        )),
                        &self.inner,
                    ).await;
                }

                control(
                    Control::unsubscribe(Unsubscribe::new(packet_id, size, topic_filters)),
                    &self.inner,
                )
                .await
            }
            DispatchItem::Item(Decoded::Packet(Packet::Disconnect, _)) => {
                control(Control::remote_disconnect(), &self.inner).await
            }
            // any other decoded packet is ignored
            DispatchItem::Item(_) => Ok(None),
            // Connection-level failures: fail the streaming payload (if any),
            // then report the failure to the control service.
            DispatchItem::EncoderError(err) => {
                let err = ProtocolError::Encode(err);
                self.inner.drop_payload(&err);
                control(Control::proto_error(err), &self.inner).await
            }
            DispatchItem::KeepAliveTimeout => {
                self.inner.drop_payload(&ProtocolError::KeepAliveTimeout);
                control(Control::proto_error(ProtocolError::KeepAliveTimeout), &self.inner)
                    .await
            }
            DispatchItem::ReadTimeout => {
                self.inner.drop_payload(&ProtocolError::ReadTimeout);
                control(Control::proto_error(ProtocolError::ReadTimeout), &self.inner).await
            }
            DispatchItem::DecoderError(err) => {
                let err = ProtocolError::Decode(err);
                self.inner.drop_payload(&err);
                control(Control::proto_error(err), &self.inner).await
            }
            DispatchItem::Disconnect(err) => {
                self.inner.drop_payload(&PayloadError::Disconnected);
                control(Control::peer_gone(err), &self.inner).await
            }
            // Write back-pressure: update the sink first, then notify the
            // control service.
            DispatchItem::WBackPressureEnabled => {
                self.inner.sink.enable_wr_backpressure();
                control(Control::wr_backpressure(true), &self.inner).await
            }
            DispatchItem::WBackPressureDisabled => {
                self.inner.sink.disable_wr_backpressure();
                control(Control::wr_backpressure(false), &self.inner).await
            }
        }
    }
}
459
460/// Publish service response future
461async fn publish_fn<'f, T, C, E>(
462    svc: &'f T,
463    pkt: Publish,
464    packet_id: Option<NonZeroU16>,
465    inner: &'f Inner<C>,
466    ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
467) -> Result<Option<Encoded>, MqttError<E>>
468where
469    E: From<T::Error>,
470    T: Service<Publish, Response = ()>,
471    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
472{
473    let qos2 = pkt.qos() == QoS::ExactlyOnce;
474    match ctx.call(svc, pkt).await {
475        Ok(_) => {
476            log::trace!(
477                "{}: Publish result for packet {:?} is ready",
478                inner.sink.tag(),
479                packet_id
480            );
481
482            if let Some(packet_id) = packet_id {
483                if qos2 {
484                    Ok(Some(Encoded::Packet(Packet::PublishReceived { packet_id })))
485                } else {
486                    inner.inflight.borrow_mut().remove(&packet_id);
487                    Ok(Some(Encoded::Packet(Packet::PublishAck { packet_id })))
488                }
489            } else {
490                Ok(None)
491            }
492        }
493        Err(e) => control(Control::error(e.into()), inner).await,
494    }
495}
496
497async fn control<C, E>(
498    mut pkt: Control<E>,
499    inner: &Inner<C>,
500) -> Result<Option<Encoded>, MqttError<E>>
501where
502    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
503{
504    let mut error = matches!(pkt, Control::Error(_) | Control::ProtocolError(_));
505
506    loop {
507        match inner.control.call(pkt).await {
508            Ok(item) => {
509                let packet = match item.result {
510                    ControlAckKind::Ping => Some(Encoded::Packet(Packet::PingResponse)),
511                    ControlAckKind::Subscribe(res) => {
512                        inner.inflight.borrow_mut().remove(&res.packet_id);
513                        Some(Encoded::Packet(Packet::SubscribeAck {
514                            status: res.codes,
515                            packet_id: res.packet_id,
516                        }))
517                    }
518                    ControlAckKind::Unsubscribe(res) => {
519                        inner.inflight.borrow_mut().remove(&res.packet_id);
520                        Some(Encoded::Packet(Packet::UnsubscribeAck {
521                            packet_id: res.packet_id,
522                        }))
523                    }
524                    ControlAckKind::Disconnect => {
525                        inner.drop_payload(&PayloadError::Service);
526                        inner.sink.close();
527                        None
528                    }
529                    ControlAckKind::Closed | ControlAckKind::Nothing => None,
530                    ControlAckKind::PublishRelease(packet_id) => {
531                        inner.inflight.borrow_mut().remove(&packet_id);
532                        Some(Encoded::Packet(Packet::PublishComplete { packet_id }))
533                    }
534                    ControlAckKind::PublishAck(_) => unreachable!(),
535                };
536                return Ok(packet);
537            }
538            Err(err) => {
539                inner.drop_payload(&PayloadError::Service);
540
541                // do not handle nested error
542                return if error {
543                    Err(err)
544                } else {
545                    // handle error from control service
546                    match err {
547                        MqttError::Service(err) => {
548                            error = true;
549                            pkt = Control::error(err);
550                            continue;
551                        }
552                        _ => Err(err),
553                    }
554                };
555            }
556        }
557    }
558}
559
#[cfg(test)]
mod tests {
    use std::{future::Future, pin::Pin};

    use ntex_bytes::{ByteString, Bytes};
    use ntex_io::{Io, testing::IoTest};
    use ntex_service::{cfg::SharedCfg, fn_service};
    use ntex_util::future::{Ready, lazy};
    use ntex_util::time::{Seconds, sleep};

    use super::*;
    use crate::v3::{MqttSink, codec};

    /// A second PUBLISH reusing a packet id that is still in flight is
    /// reported to the control service as a protocol error, and the call
    /// produces no response.
    #[ntex::test]
    async fn test_dup_packet_id() {
        let cfg: SharedCfg = SharedCfg::new("DBG")
            .add(MqttServiceConfig::new().set_max_qos(QoS::AtLeastOnce))
            .into();

        let io = Io::new(IoTest::create().0, cfg);
        let codec = codec::Codec::default();
        let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, Default::default()));
        // set by the control service when it observes a protocol error
        let err = Rc::new(RefCell::new(false));
        let err2 = err.clone();

        let disp = Pipeline::new(Dispatcher::<_, _, ()>::new(
            shared.clone(),
            // slow publish service: keeps the first packet id in flight
            fn_service(|_| async {
                sleep(Seconds(10)).await;
                Ok(())
            }),
            fn_service(move |ctrl| {
                if let Control::ProtocolError(_) = ctrl {
                    *err2.borrow_mut() = true;
                }
                Ready::Ok(ControlAck { result: ControlAckKind::Nothing })
            }),
            cfg.get(),
        ));

        // Poll the first publish once, so that packet id 1 is registered.
        // The future is kept alive in `f` (shadowing below does not drop it).
        let mut f: Pin<Box<dyn Future<Output = Result<_, _>>>> =
            Box::pin(disp.call(DispatchItem::Item(Decoded::Publish(
                codec::Publish {
                    dup: false,
                    retain: false,
                    qos: QoS::AtLeastOnce,
                    topic: ByteString::new(),
                    packet_id: NonZeroU16::new(1),
                    payload_size: 0,
                },
                Bytes::new(),
                999,
            ))));
        let _ = lazy(|cx| Pin::new(&mut f).poll(cx)).await;

        // same packet id while the first publish is still in flight
        let f = Box::pin(disp.call(DispatchItem::Item(Decoded::Publish(
            codec::Publish {
                dup: false,
                retain: false,
                qos: QoS::AtLeastOnce,
                topic: ByteString::new(),
                packet_id: NonZeroU16::new(1),
                payload_size: 0,
            },
            Bytes::new(),
            999,
        ))));
        assert!(f.await.unwrap().is_none());
        assert!(*err.borrow());
    }

    /// Write back-pressure items toggle sink readiness and resolve the
    /// readiness waiter, as checked by the assertions below.
    #[ntex::test]
    async fn test_wr_backpressure() {
        let cfg: SharedCfg = SharedCfg::new("DBG")
            .add(MqttServiceConfig::new().set_max_qos(QoS::AtLeastOnce))
            .into();

        let io = Io::new(IoTest::create().0, cfg);
        let codec = codec::Codec::default();
        let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, Default::default()));

        let disp = Pipeline::new(Dispatcher::<_, _, ()>::new(
            shared.clone(),
            fn_service(|_| Ready::Ok(())),
            fn_service(|_| Ready::Ok(ControlAck { result: ControlAckKind::Nothing })),
            cfg.get(),
        ));

        // the sink is not ready until a capacity is set
        let sink = MqttSink::new(shared.clone());
        assert!(!sink.is_ready());
        shared.set_cap(1);
        assert!(sink.is_ready());
        assert!(shared.wait_readiness().is_none());

        // enabling back-pressure makes the sink not ready and yields waiters
        disp.call(DispatchItem::WBackPressureEnabled).await.unwrap();
        assert!(!sink.is_ready());
        let rx = shared.wait_readiness();
        let rx2 = shared.wait_readiness().unwrap();
        assert!(rx.is_some());

        // After disabling, only the first waiter is expected to be resolved.
        let rx = rx.unwrap();
        disp.call(DispatchItem::WBackPressureDisabled).await.unwrap();
        assert!(lazy(|cx| rx.poll_recv(cx).is_ready()).await);
        assert!(!lazy(|cx| rx2.poll_recv(cx).is_ready()).await);
    }
}