1use std::cell::{Cell, RefCell};
2use std::{marker::PhantomData, num::NonZeroU16, rc::Rc, task::Context};
3
4use ntex_io::DispatchItem;
5use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory, cfg::Cfg, cfg::SharedCfg};
6use ntex_util::services::buffer::{BufferService, BufferServiceError};
7use ntex_util::services::inflight::InFlightService;
8use ntex_util::{HashSet, future::join};
9
10use crate::error::{DecodeError, HandshakeError, MqttError, PayloadError, ProtocolError};
11use crate::payload::{Payload, PayloadStatus, PlSender};
12use crate::{MqttServiceConfig, types::QoS, types::packet_type};
13
14use super::codec::{Decoded, Encoded, Packet};
15use super::control::{Control, ControlAck, ControlAckKind, Subscribe, Unsubscribe};
16use super::{Session, publish::Publish, shared::Ack, shared::MqttShared};
17
18pub(super) fn factory<St, T, C, E>(
20 publish: T,
21 control: C,
22) -> impl ServiceFactory<
23 DispatchItem<Rc<MqttShared>>,
24 (SharedCfg, Session<St>),
25 Response = Option<Encoded>,
26 Error = MqttError<E>,
27 InitError = MqttError<E>,
28>
29where
30 St: 'static,
31 T: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
32 C: ServiceFactory<Control<E>, Session<St>, Response = ControlAck> + 'static,
33 E: From<C::Error> + From<C::InitError> + From<T::Error> + From<T::InitError> + 'static,
34{
35 let factories = Rc::new((publish, control));
36
37 ntex_service::fn_factory_with_config(
38 async move |(cfg, session): (SharedCfg, Session<St>)| {
39 let sink = session.sink().shared();
41 let fut = join(factories.0.create(session.clone()), factories.1.create(session));
42 let (publish, control) = fut.await;
43
44 let publish = publish.map_err(|e| MqttError::Service(e.into()))?;
45 let control = control.map_err(|e| MqttError::Service(e.into()))?;
46
47 let control = BufferService::new(
48 16,
49 InFlightService::new(1, control),
51 )
52 .map_err(|err| match err {
53 BufferServiceError::Service(e) => MqttError::Service(E::from(e)),
54 BufferServiceError::RequestCanceled => {
55 MqttError::Handshake(HandshakeError::Disconnected(None))
56 }
57 });
58
59 let cfg: Cfg<MqttServiceConfig> = cfg.get();
60 Ok(Dispatcher::<_, _, E>::new(
61 sink,
62 publish,
63 control,
64 cfg.max_qos,
65 cfg.handle_qos_after_disconnect,
66 ))
67 },
68 )
69}
70
71impl crate::inflight::SizedRequest for DispatchItem<Rc<MqttShared>> {
72 fn size(&self) -> u32 {
73 match self {
74 DispatchItem::Item(Decoded::Packet(_, size))
75 | DispatchItem::Item(Decoded::Publish(_, _, size)) => *size,
76 _ => 0,
77 }
78 }
79
80 fn is_publish(&self) -> bool {
81 matches!(self, DispatchItem::Item(Decoded::Publish(..)))
82 }
83
84 fn is_chunk(&self) -> bool {
85 matches!(self, DispatchItem::Item(Decoded::PayloadChunk(..)))
86 }
87}
88
89pub(crate) struct Dispatcher<T, C: Service<Control<E>>, E> {
91 publish: T,
92 max_qos: QoS,
93 handle_qos_after_disconnect: Option<QoS>,
94 inner: Rc<Inner<C>>,
95 _t: PhantomData<(E,)>,
96}
97
98struct Inner<C> {
99 control: Pipeline<C>,
100 sink: Rc<MqttShared>,
101 payload: Cell<Option<PlSender>>,
102 inflight: RefCell<HashSet<NonZeroU16>>,
103}
104
105impl<T, C, E> Dispatcher<T, C, E>
106where
107 E: From<T::Error>,
108 T: Service<Publish, Response = ()>,
109 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
110{
111 pub(crate) fn new(
112 sink: Rc<MqttShared>,
113 publish: T,
114 control: C,
115 max_qos: QoS,
116 handle_qos_after_disconnect: Option<QoS>,
117 ) -> Self {
118 Self {
119 publish,
120 max_qos,
121 handle_qos_after_disconnect,
122 inner: Rc::new(Inner {
123 sink,
124 payload: Cell::new(None),
125 control: Pipeline::new(control),
126 inflight: RefCell::new(HashSet::default()),
127 }),
128 _t: PhantomData,
129 }
130 }
131
132 fn tag(&self) -> &'static str {
133 self.inner.sink.tag()
134 }
135}
136
137impl<C> Inner<C> {
138 fn drop_payload<PErr>(&self, err: &PErr)
139 where
140 PErr: Clone,
141 PayloadError: From<PErr>,
142 {
143 if let Some(pl) = self.payload.take() {
144 pl.set_error(err.clone().into());
145 }
146 }
147}
148
149impl<T, C, E> Service<DispatchItem<Rc<MqttShared>>> for Dispatcher<T, C, E>
150where
151 E: From<T::Error> + 'static,
152 T: Service<Publish, Response = ()> + 'static,
153 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>> + 'static,
154{
155 type Response = Option<Encoded>;
156 type Error = MqttError<E>;
157
158 #[inline]
159 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
160 let (res1, res2) =
161 join(ctx.ready(&self.publish), ctx.ready(self.inner.control.get_ref())).await;
162 let result = if let Err(e) = res1 {
163 if res2.is_err() {
164 Err(MqttError::Service(e.into()))
165 } else {
166 match ctx
167 .call_nowait(self.inner.control.get_ref(), Control::error(e.into()))
168 .await
169 {
170 Ok(_) => {
171 self.inner.sink.close();
172 Ok(())
173 }
174 Err(err) => Err(err),
175 }
176 }
177 } else {
178 res2
179 };
180
181 if result.is_ok() {
182 if let Some(pl) = self.inner.payload.take() {
183 if pl.ready().await == PayloadStatus::Ready {
184 self.inner.payload.set(Some(pl));
185 } else {
186 self.inner.sink.close();
187 }
188 }
189 }
190 result
191 }
192
193 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
194 if let Err(e) = self.publish.poll(cx) {
195 let inner = self.inner.clone();
196 ntex_rt::spawn(async move {
197 if inner.control.call_nowait(Control::error(e.into())).await.is_ok() {
198 inner.sink.close();
199 }
200 });
201 }
202 self.inner.control.poll(cx)
203 }
204
205 async fn shutdown(&self) {
206 self.inner.drop_payload(&PayloadError::Disconnected);
207 self.inner.sink.close();
208 let _ = self.inner.control.call(Control::closed()).await;
209
210 self.publish.shutdown().await;
211 self.inner.control.shutdown().await;
212 }
213
214 async fn call(
215 &self,
216 req: DispatchItem<Rc<MqttShared>>,
217 ctx: ServiceCtx<'_, Self>,
218 ) -> Result<Self::Response, Self::Error> {
219 log::trace!("{}; Dispatch v3 packet: {:#?}", self.tag(), req);
220
221 match req {
222 DispatchItem::Item(Decoded::Publish(publish, payload, size)) => {
223 if publish.topic.contains(['#', '+']) {
224 return control(
225 Control::proto_error(
226 ProtocolError::generic_violation(
227 "PUBLISH packet's topic name contains wildcard character [MQTT-3.3.2-2]"
228 )
229 ),
230 &self.inner,
231 ctx,
232 ).await;
233 }
234
235 let inner = self.inner.as_ref();
236 let packet_id = publish.packet_id;
237
238 if let Some(pid) = packet_id {
240 if !inner.inflight.borrow_mut().insert(pid) {
241 log::trace!(
242 "{}: Duplicated packet id for publish packet: {:?}",
243 self.tag(),
244 pid
245 );
246 return control(
247 Control::proto_error(
248 ProtocolError::generic_violation("PUBLISH received with packet id that is already in use [MQTT-2.2.1-3]")
249 ),
250 &self.inner,
251 ctx,
252 ).await;
253 }
254 }
255
256 if publish.qos > self.max_qos {
258 log::trace!(
259 "{}: Max allowed QoS is violated, max {:?} provided {:?}",
260 self.tag(),
261 self.max_qos,
262 publish.qos
263 );
264 return control(
265 Control::proto_error(ProtocolError::generic_violation(
266 match publish.qos {
267 QoS::AtLeastOnce => "PUBLISH with QoS 1 is not supported",
268 QoS::ExactlyOnce => "PUBLISH with QoS 2 is not supported",
269 QoS::AtMostOnce => unreachable!(), },
271 )),
272 &self.inner,
273 ctx,
274 )
275 .await;
276 }
277
278 if inner.sink.is_closed()
279 && !self
280 .handle_qos_after_disconnect
281 .map(|max_qos| publish.qos <= max_qos)
282 .unwrap_or_default()
283 {
284 return Ok(None);
285 }
286
287 let payload = if publish.payload_size == payload.len() as u32 {
288 Payload::from_bytes(payload)
289 } else {
290 let (pl, sender) = Payload::from_stream(payload);
291 self.inner.payload.set(Some(sender));
292 pl
293 };
294
295 publish_fn(
296 &self.publish,
297 Publish::new(publish, payload, size),
298 packet_id,
299 inner,
300 ctx,
301 )
302 .await
303 }
304 DispatchItem::Item(Decoded::PayloadChunk(buf, eof)) => {
305 if let Some(pl) = self.inner.payload.take() {
306 pl.feed_data(buf);
307 if eof {
308 pl.feed_eof();
309 } else {
310 self.inner.payload.set(Some(pl));
311 }
312 Ok(None)
313 } else {
314 control(
315 Control::proto_error(ProtocolError::Decode(
316 DecodeError::UnexpectedPayload,
317 )),
318 &self.inner,
319 ctx,
320 )
321 .await
322 }
323 }
324 DispatchItem::Item(Decoded::Packet(Packet::PublishAck { packet_id }, _)) => {
325 if let Err(e) = self.inner.sink.pkt_ack(Ack::Publish(packet_id)) {
326 control(Control::proto_error(e), &self.inner, ctx).await
327 } else {
328 Ok(None)
329 }
330 }
331 DispatchItem::Item(Decoded::Packet(Packet::PublishReceived { packet_id }, _)) => {
332 if let Err(e) = self.inner.sink.pkt_ack(Ack::Receive(packet_id)) {
333 control(Control::proto_error(e), &self.inner, ctx).await
334 } else {
335 Ok(None)
336 }
337 }
338 DispatchItem::Item(Decoded::Packet(Packet::PublishRelease { packet_id }, _)) => {
339 if self.inner.inflight.borrow().contains(&packet_id) {
340 control(Control::pubrel(packet_id), &self.inner, ctx).await
341 } else {
342 control(
343 Control::proto_error(ProtocolError::unexpected_packet(
344 packet_type::PUBREL,
345 "Unknown packet-id in PublishRelease packet",
346 )),
347 &self.inner,
348 ctx,
349 )
350 .await
351 }
352 }
353 DispatchItem::Item(Decoded::Packet(Packet::PublishComplete { packet_id }, _)) => {
354 if let Err(e) = self.inner.sink.pkt_ack(Ack::Complete(packet_id)) {
355 control(Control::proto_error(e), &self.inner, ctx).await
356 } else {
357 Ok(None)
358 }
359 }
360 DispatchItem::Item(Decoded::Packet(Packet::PingRequest, _)) => {
361 control(Control::ping(), &self.inner, ctx).await
362 }
363 DispatchItem::Item(Decoded::Packet(
364 Packet::Subscribe { packet_id, topic_filters },
365 size,
366 )) => {
367 if self.inner.sink.is_closed() {
368 return Ok(None);
369 }
370
371 if topic_filters.iter().any(|(tf, _)| !crate::topic::is_valid(tf)) {
372 return control(
373 Control::proto_error(ProtocolError::generic_violation(
374 "Topic filter is malformed [MQTT-4.7.1-*]",
375 )),
376 &self.inner,
377 ctx,
378 )
379 .await;
380 }
381
382 if !self.inner.inflight.borrow_mut().insert(packet_id) {
383 log::trace!(
384 "{}: Duplicated packet id for subscribe packet: {:?}",
385 self.tag(),
386 packet_id
387 );
388 return control(
389 Control::proto_error(ProtocolError::generic_violation(
390 "SUBSCRIBE received with packet id that is already in use [MQTT-2.2.1-3]"
391 )),
392 &self.inner,
393 ctx,
394 ).await;
395 }
396
397 control(
398 Control::subscribe(Subscribe::new(packet_id, size, topic_filters)),
399 &self.inner,
400 ctx,
401 )
402 .await
403 }
404 DispatchItem::Item(Decoded::Packet(
405 Packet::Unsubscribe { packet_id, topic_filters },
406 size,
407 )) => {
408 if self.inner.sink.is_closed() {
409 return Ok(None);
410 }
411
412 if topic_filters.iter().any(|tf| !crate::topic::is_valid(tf)) {
413 return control(
414 Control::proto_error(ProtocolError::generic_violation(
415 "Topic filter is malformed [MQTT-4.7.1-*]",
416 )),
417 &self.inner,
418 ctx,
419 )
420 .await;
421 }
422
423 if !self.inner.inflight.borrow_mut().insert(packet_id) {
424 log::trace!(
425 "{}: Duplicated packet id for unsubscribe packet: {:?}",
426 self.tag(),
427 packet_id
428 );
429 return control(
430 Control::proto_error(ProtocolError::generic_violation(
431 "UNSUBSCRIBE received with packet id that is already in use [MQTT-2.2.1-3]"
432 )),
433 &self.inner,
434 ctx,
435 ).await;
436 }
437
438 control(
439 Control::unsubscribe(Unsubscribe::new(packet_id, size, topic_filters)),
440 &self.inner,
441 ctx,
442 )
443 .await
444 }
445 DispatchItem::Item(Decoded::Packet(Packet::Disconnect, _)) => {
446 control(Control::remote_disconnect(), &self.inner, ctx).await
447 }
448 DispatchItem::Item(_) => Ok(None),
449 DispatchItem::EncoderError(err) => {
450 let err = ProtocolError::Encode(err);
451 self.inner.drop_payload(&err);
452 control(Control::proto_error(err), &self.inner, ctx).await
453 }
454 DispatchItem::KeepAliveTimeout => {
455 self.inner.drop_payload(&ProtocolError::KeepAliveTimeout);
456 control(Control::proto_error(ProtocolError::KeepAliveTimeout), &self.inner, ctx)
457 .await
458 }
459 DispatchItem::ReadTimeout => {
460 self.inner.drop_payload(&ProtocolError::ReadTimeout);
461 control(Control::proto_error(ProtocolError::ReadTimeout), &self.inner, ctx)
462 .await
463 }
464 DispatchItem::DecoderError(err) => {
465 let err = ProtocolError::Decode(err);
466 self.inner.drop_payload(&err);
467 control(Control::proto_error(err), &self.inner, ctx).await
468 }
469 DispatchItem::Disconnect(err) => {
470 self.inner.drop_payload(&PayloadError::Disconnected);
471 control(Control::peer_gone(err), &self.inner, ctx).await
472 }
473 DispatchItem::WBackPressureEnabled => {
474 self.inner.sink.enable_wr_backpressure();
475 control(Control::wr_backpressure(true), &self.inner, ctx).await
476 }
477 DispatchItem::WBackPressureDisabled => {
478 self.inner.sink.disable_wr_backpressure();
479 control(Control::wr_backpressure(false), &self.inner, ctx).await
480 }
481 }
482 }
483}
484
485async fn publish_fn<'f, T, C, E>(
487 svc: &'f T,
488 pkt: Publish,
489 packet_id: Option<NonZeroU16>,
490 inner: &'f Inner<C>,
491 ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
492) -> Result<Option<Encoded>, MqttError<E>>
493where
494 E: From<T::Error>,
495 T: Service<Publish, Response = ()>,
496 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
497{
498 let qos2 = pkt.qos() == QoS::ExactlyOnce;
499 match ctx.call(svc, pkt).await {
500 Ok(_) => {
501 log::trace!(
502 "{}: Publish result for packet {:?} is ready",
503 inner.sink.tag(),
504 packet_id
505 );
506
507 if let Some(packet_id) = packet_id {
508 if qos2 {
509 Ok(Some(Encoded::Packet(Packet::PublishReceived { packet_id })))
510 } else {
511 inner.inflight.borrow_mut().remove(&packet_id);
512 Ok(Some(Encoded::Packet(Packet::PublishAck { packet_id })))
513 }
514 } else {
515 Ok(None)
516 }
517 }
518 Err(e) => control(Control::error(e.into()), inner, ctx).await,
519 }
520}
521
522async fn control<'f, T, C, E>(
523 mut pkt: Control<E>,
524 inner: &'f Inner<C>,
525 ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
526) -> Result<Option<Encoded>, MqttError<E>>
527where
528 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
529{
530 let mut error = matches!(pkt, Control::Error(_) | Control::ProtocolError(_));
531
532 loop {
533 match ctx.call(inner.control.get_ref(), pkt).await {
534 Ok(item) => {
535 let packet = match item.result {
536 ControlAckKind::Ping => Some(Encoded::Packet(Packet::PingResponse)),
537 ControlAckKind::Subscribe(res) => {
538 inner.inflight.borrow_mut().remove(&res.packet_id);
539 Some(Encoded::Packet(Packet::SubscribeAck {
540 status: res.codes,
541 packet_id: res.packet_id,
542 }))
543 }
544 ControlAckKind::Unsubscribe(res) => {
545 inner.inflight.borrow_mut().remove(&res.packet_id);
546 Some(Encoded::Packet(Packet::UnsubscribeAck {
547 packet_id: res.packet_id,
548 }))
549 }
550 ControlAckKind::Disconnect => {
551 inner.drop_payload(&PayloadError::Service);
552 inner.sink.close();
553 None
554 }
555 ControlAckKind::Closed | ControlAckKind::Nothing => None,
556 ControlAckKind::PublishRelease(packet_id) => {
557 inner.inflight.borrow_mut().remove(&packet_id);
558 Some(Encoded::Packet(Packet::PublishComplete { packet_id }))
559 }
560 ControlAckKind::PublishAck(_) => unreachable!(),
561 };
562 return Ok(packet);
563 }
564 Err(err) => {
565 inner.drop_payload(&PayloadError::Service);
566
567 return if error {
569 Err(err)
570 } else {
571 match err {
573 MqttError::Service(err) => {
574 error = true;
575 pkt = Control::error(err);
576 continue;
577 }
578 _ => Err(err),
579 }
580 };
581 }
582 }
583 }
584}
585
586#[cfg(test)]
587mod tests {
588 use std::{future::Future, pin::Pin};
589
590 use ntex_bytes::{ByteString, Bytes};
591 use ntex_io::{Io, testing::IoTest};
592 use ntex_service::{cfg::SharedCfg, fn_service};
593 use ntex_util::future::{Ready, lazy};
594 use ntex_util::time::{Seconds, sleep};
595
596 use super::*;
597 use crate::v3::{MqttSink, codec};
598
599 #[ntex::test]
600 async fn test_dup_packet_id() {
601 let io = Io::new(IoTest::create().0, SharedCfg::new("DBG"));
602 let codec = codec::Codec::default();
603 let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, Default::default()));
604 let err = Rc::new(RefCell::new(false));
605 let err2 = err.clone();
606
607 let disp = Pipeline::new(Dispatcher::<_, _, ()>::new(
608 shared.clone(),
609 fn_service(|_| async {
610 sleep(Seconds(10)).await;
611 Ok(())
612 }),
613 fn_service(move |ctrl| {
614 if let Control::ProtocolError(_) = ctrl {
615 *err2.borrow_mut() = true;
616 }
617 Ready::Ok(ControlAck { result: ControlAckKind::Nothing })
618 }),
619 QoS::AtLeastOnce,
620 None,
621 ));
622
623 let mut f: Pin<Box<dyn Future<Output = Result<_, _>>>> =
624 Box::pin(disp.call(DispatchItem::Item(Decoded::Publish(
625 codec::Publish {
626 dup: false,
627 retain: false,
628 qos: QoS::AtLeastOnce,
629 topic: ByteString::new(),
630 packet_id: NonZeroU16::new(1),
631 payload_size: 0,
632 },
633 Bytes::new(),
634 999,
635 ))));
636 let _ = lazy(|cx| Pin::new(&mut f).poll(cx)).await;
637
638 let f = Box::pin(disp.call(DispatchItem::Item(Decoded::Publish(
639 codec::Publish {
640 dup: false,
641 retain: false,
642 qos: QoS::AtLeastOnce,
643 topic: ByteString::new(),
644 packet_id: NonZeroU16::new(1),
645 payload_size: 0,
646 },
647 Bytes::new(),
648 999,
649 ))));
650 assert!(f.await.unwrap().is_none());
651 assert!(*err.borrow());
652 }
653
654 #[ntex::test]
655 async fn test_wr_backpressure() {
656 let io = Io::new(IoTest::create().0, SharedCfg::new("DBG"));
657 let codec = codec::Codec::default();
658 let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, Default::default()));
659
660 let disp = Pipeline::new(Dispatcher::<_, _, ()>::new(
661 shared.clone(),
662 fn_service(|_| Ready::Ok(())),
663 fn_service(|_| Ready::Ok(ControlAck { result: ControlAckKind::Nothing })),
664 QoS::AtLeastOnce,
665 None,
666 ));
667
668 let sink = MqttSink::new(shared.clone());
669 assert!(!sink.is_ready());
670 shared.set_cap(1);
671 assert!(sink.is_ready());
672 assert!(shared.wait_readiness().is_none());
673
674 disp.call(DispatchItem::WBackPressureEnabled).await.unwrap();
675 assert!(!sink.is_ready());
676 let rx = shared.wait_readiness();
677 let rx2 = shared.wait_readiness().unwrap();
678 assert!(rx.is_some());
679
680 let rx = rx.unwrap();
681 disp.call(DispatchItem::WBackPressureDisabled).await.unwrap();
682 assert!(lazy(|cx| rx.poll_recv(cx).is_ready()).await);
683 assert!(!lazy(|cx| rx2.poll_recv(cx).is_ready()).await);
684 }
685}