1use std::cell::{Cell, RefCell};
2use std::{marker::PhantomData, num::NonZeroU16, rc::Rc, task::Context};
3
4use ntex_io::DispatchItem;
5use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory, cfg::Cfg, cfg::SharedCfg};
6use ntex_util::services::buffer::{BufferService, BufferServiceError};
7use ntex_util::{HashSet, future::join, services::inflight::InFlightService};
8
9use crate::error::{DecodeError, HandshakeError, MqttError, PayloadError, ProtocolError};
10use crate::payload::{Payload, PayloadStatus, PlSender};
11use crate::{MqttServiceConfig, types::QoS, types::packet_type};
12
13use super::codec::{Decoded, Encoded, Packet};
14use super::control::{Control, ControlAck, ControlAckKind, Subscribe, Unsubscribe};
15use super::{Session, publish::Publish, shared::Ack, shared::MqttShared};
16
17pub(super) fn factory<St, T, C, E>(
19 publish: T,
20 control: C,
21) -> impl ServiceFactory<
22 DispatchItem<Rc<MqttShared>>,
23 (SharedCfg, Session<St>),
24 Response = Option<Encoded>,
25 Error = MqttError<E>,
26 InitError = MqttError<E>,
27>
28where
29 St: 'static,
30 T: ServiceFactory<Publish, Session<St>, Response = ()> + 'static,
31 C: ServiceFactory<Control<E>, Session<St>, Response = ControlAck> + 'static,
32 E: From<C::Error> + From<C::InitError> + From<T::Error> + From<T::InitError> + 'static,
33{
34 let factories = Rc::new((publish, control));
35
36 ntex_service::fn_factory_with_config(
37 async move |(cfg, session): (SharedCfg, Session<St>)| {
38 let sink = session.sink().shared();
40 let fut = join(factories.0.create(session.clone()), factories.1.create(session));
41 let (publish, control) = fut.await;
42
43 let publish = publish.map_err(|e| MqttError::Service(e.into()))?;
44 let control = control.map_err(|e| MqttError::Service(e.into()))?;
45
46 let control = BufferService::new(
47 16,
48 InFlightService::new(1, control),
50 )
51 .map_err(|err| match err {
52 BufferServiceError::Service(e) => MqttError::Service(E::from(e)),
53 BufferServiceError::RequestCanceled => {
54 MqttError::Handshake(HandshakeError::Disconnected(None))
55 }
56 });
57
58 let cfg: Cfg<MqttServiceConfig> = cfg.get();
59 Ok(Dispatcher::<_, _, E>::new(sink, publish, control, cfg))
60 },
61 )
62}
63
64impl crate::inflight::SizedRequest for DispatchItem<Rc<MqttShared>> {
65 fn size(&self) -> u32 {
66 match self {
67 DispatchItem::Item(Decoded::Packet(_, size))
68 | DispatchItem::Item(Decoded::Publish(_, _, size)) => *size,
69 _ => 0,
70 }
71 }
72
73 fn is_publish(&self) -> bool {
74 matches!(self, DispatchItem::Item(Decoded::Publish(..)))
75 }
76
77 fn is_chunk(&self) -> bool {
78 matches!(self, DispatchItem::Item(Decoded::PayloadChunk(..)))
79 }
80}
81
/// MQTT v3 server-side dispatcher for a single connection.
///
/// Accepted PUBLISH packets are passed to the `publish` service; most other
/// packets and connection events are forwarded to the control service held in
/// `inner`.
pub(crate) struct Dispatcher<T, C: Service<Control<E>>, E> {
    /// Service invoked for each accepted PUBLISH packet.
    publish: T,
    /// State shared with tasks spawned from `Service::poll`.
    inner: Rc<Inner<C>>,
    /// Per-connection configuration (max QoS, payload buffer size, ...).
    cfg: Cfg<MqttServiceConfig>,
    /// `E` only appears in trait bounds.
    _t: PhantomData<(E,)>,
}
89
/// Dispatcher state that is shared (via `Rc`) with spawned tasks.
struct Inner<C> {
    /// Control service, wrapped in a pipeline.
    control: Pipeline<C>,
    /// Shared connection state: used to close the connection, process acks
    /// and obtain the log tag.
    sink: Rc<MqttShared>,
    /// Sender half of the PUBLISH payload that is currently being streamed, if any.
    payload: Cell<Option<PlSender>>,
    /// Packet ids of received PUBLISH (when they carry one), SUBSCRIBE and
    /// UNSUBSCRIBE packets that have not been acknowledged yet; used to detect
    /// reuse of an id that is still in use.
    inflight: RefCell<HashSet<NonZeroU16>>,
}
96
97impl<T, C, E> Dispatcher<T, C, E>
98where
99 E: From<T::Error>,
100 T: Service<Publish, Response = ()>,
101 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
102{
103 pub(crate) fn new(
104 sink: Rc<MqttShared>,
105 publish: T,
106 control: C,
107 cfg: Cfg<MqttServiceConfig>,
108 ) -> Self {
109 Self {
110 cfg,
111 publish,
112 inner: Rc::new(Inner {
113 sink,
114 payload: Cell::new(None),
115 control: Pipeline::new(control),
116 inflight: RefCell::new(HashSet::default()),
117 }),
118 _t: PhantomData,
119 }
120 }
121
122 fn tag(&self) -> &'static str {
123 self.inner.sink.tag()
124 }
125}
126
127impl<C> Inner<C> {
128 fn drop_payload<PErr>(&self, err: &PErr)
129 where
130 PErr: Clone,
131 PayloadError: From<PErr>,
132 {
133 if let Some(pl) = self.payload.take() {
134 pl.set_error(err.clone().into());
135 }
136 }
137}
138
139impl<T, C, E> Service<DispatchItem<Rc<MqttShared>>> for Dispatcher<T, C, E>
140where
141 E: From<T::Error> + 'static,
142 T: Service<Publish, Response = ()> + 'static,
143 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>> + 'static,
144{
145 type Response = Option<Encoded>;
146 type Error = MqttError<E>;
147
148 #[inline]
149 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
150 let (res1, res2) = join(ctx.ready(&self.publish), self.inner.control.ready()).await;
151 let result = if let Err(e) = res1 {
152 if res2.is_err() {
153 Err(MqttError::Service(e.into()))
154 } else {
155 match self.inner.control.call(Control::error(e.into())).await {
156 Ok(_) => {
157 self.inner.sink.close();
158 Ok(())
159 }
160 Err(err) => Err(err),
161 }
162 }
163 } else {
164 res2
165 };
166
167 if result.is_ok() {
168 if let Some(pl) = self.inner.payload.take() {
169 self.inner.payload.set(Some(pl.clone()));
170 if pl.ready().await != PayloadStatus::Ready {
171 self.inner.sink.close();
172 }
173 }
174 }
175 result
176 }
177
178 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
179 if let Err(e) = self.publish.poll(cx) {
180 let inner = self.inner.clone();
181 ntex_rt::spawn(async move {
182 if inner.control.call(Control::error(e.into())).await.is_ok() {
183 inner.sink.close();
184 }
185 });
186 }
187 self.inner.control.poll(cx)
188 }
189
190 async fn shutdown(&self) {
191 self.inner.drop_payload(&PayloadError::Disconnected);
192 self.inner.sink.close();
193 let _ = self.inner.control.call(Control::closed()).await;
194
195 self.publish.shutdown().await;
196 self.inner.control.shutdown().await;
197 }
198
199 async fn call(
200 &self,
201 req: DispatchItem<Rc<MqttShared>>,
202 ctx: ServiceCtx<'_, Self>,
203 ) -> Result<Self::Response, Self::Error> {
204 log::trace!("{}; Dispatch v3 packet: {:#?}", self.tag(), req);
205
206 match req {
207 DispatchItem::Item(Decoded::Publish(publish, payload, size)) => {
208 if publish.topic.contains(['#', '+']) {
209 return control(
210 Control::proto_error(
211 ProtocolError::generic_violation(
212 "PUBLISH packet's topic name contains wildcard character [MQTT-3.3.2-2]"
213 )
214 ),
215 &self.inner,
216 ).await;
217 }
218
219 let inner = self.inner.as_ref();
220 let packet_id = publish.packet_id;
221
222 if let Some(pid) = packet_id {
224 if !inner.inflight.borrow_mut().insert(pid) {
225 log::trace!(
226 "{}: Duplicated packet id for publish packet: {:?}",
227 self.tag(),
228 pid
229 );
230 return control(
231 Control::proto_error(
232 ProtocolError::generic_violation("PUBLISH received with packet id that is already in use [MQTT-2.2.1-3]")
233 ),
234 &self.inner,
235 ).await;
236 }
237 }
238
239 if publish.qos > self.cfg.max_qos {
241 log::trace!(
242 "{}: Max allowed QoS is violated, max {:?} provided {:?}",
243 self.tag(),
244 self.cfg.max_qos,
245 publish.qos
246 );
247 return control(
248 Control::proto_error(ProtocolError::generic_violation(
249 match publish.qos {
250 QoS::AtLeastOnce => "PUBLISH with QoS 1 is not supported",
251 QoS::ExactlyOnce => "PUBLISH with QoS 2 is not supported",
252 QoS::AtMostOnce => unreachable!(), },
254 )),
255 &self.inner,
256 )
257 .await;
258 }
259
260 if inner.sink.is_closed()
261 && !self
262 .cfg
263 .handle_qos_after_disconnect
264 .map(|max_qos| publish.qos <= max_qos)
265 .unwrap_or_default()
266 {
267 return Ok(None);
268 }
269
270 let payload = if publish.payload_size == payload.len() as u32 {
271 Payload::from_bytes(payload)
272 } else {
273 let (pl, sender) =
274 Payload::from_stream(payload, self.cfg.max_payload_buffer_size);
275 self.inner.payload.set(Some(sender));
276 pl
277 };
278
279 publish_fn(
280 &self.publish,
281 Publish::new(publish, payload, size),
282 packet_id,
283 inner,
284 ctx,
285 )
286 .await
287 }
288 DispatchItem::Item(Decoded::PayloadChunk(buf, eof)) => {
289 if let Some(pl) = self.inner.payload.take() {
290 pl.feed_data(buf);
291 if eof {
292 pl.feed_eof();
293 } else {
294 self.inner.payload.set(Some(pl));
295 }
296 Ok(None)
297 } else {
298 control(
299 Control::proto_error(ProtocolError::Decode(
300 DecodeError::UnexpectedPayload,
301 )),
302 &self.inner,
303 )
304 .await
305 }
306 }
307 DispatchItem::Item(Decoded::Packet(Packet::PublishAck { packet_id }, _)) => {
308 if let Err(e) = self.inner.sink.pkt_ack(Ack::Publish(packet_id)) {
309 control(Control::proto_error(e), &self.inner).await
310 } else {
311 Ok(None)
312 }
313 }
314 DispatchItem::Item(Decoded::Packet(Packet::PublishReceived { packet_id }, _)) => {
315 if let Err(e) = self.inner.sink.pkt_ack(Ack::Receive(packet_id)) {
316 control(Control::proto_error(e), &self.inner).await
317 } else {
318 Ok(None)
319 }
320 }
321 DispatchItem::Item(Decoded::Packet(Packet::PublishRelease { packet_id }, _)) => {
322 if self.inner.inflight.borrow().contains(&packet_id) {
323 control(Control::pubrel(packet_id), &self.inner).await
324 } else {
325 control(
326 Control::proto_error(ProtocolError::unexpected_packet(
327 packet_type::PUBREL,
328 "Unknown packet-id in PublishRelease packet",
329 )),
330 &self.inner,
331 )
332 .await
333 }
334 }
335 DispatchItem::Item(Decoded::Packet(Packet::PublishComplete { packet_id }, _)) => {
336 if let Err(e) = self.inner.sink.pkt_ack(Ack::Complete(packet_id)) {
337 control(Control::proto_error(e), &self.inner).await
338 } else {
339 Ok(None)
340 }
341 }
342 DispatchItem::Item(Decoded::Packet(Packet::PingRequest, _)) => {
343 control(Control::ping(), &self.inner).await
344 }
345 DispatchItem::Item(Decoded::Packet(
346 Packet::Subscribe { packet_id, topic_filters },
347 size,
348 )) => {
349 if self.inner.sink.is_closed() {
350 return Ok(None);
351 }
352
353 if topic_filters.iter().any(|(tf, _)| !crate::topic::is_valid(tf)) {
354 return control(
355 Control::proto_error(ProtocolError::generic_violation(
356 "Topic filter is malformed [MQTT-4.7.1-*]",
357 )),
358 &self.inner,
359 )
360 .await;
361 }
362
363 if !self.inner.inflight.borrow_mut().insert(packet_id) {
364 log::trace!(
365 "{}: Duplicated packet id for subscribe packet: {:?}",
366 self.tag(),
367 packet_id
368 );
369 return control(
370 Control::proto_error(ProtocolError::generic_violation(
371 "SUBSCRIBE received with packet id that is already in use [MQTT-2.2.1-3]"
372 )),
373 &self.inner,
374 ).await;
375 }
376
377 control(
378 Control::subscribe(Subscribe::new(packet_id, size, topic_filters)),
379 &self.inner,
380 )
381 .await
382 }
383 DispatchItem::Item(Decoded::Packet(
384 Packet::Unsubscribe { packet_id, topic_filters },
385 size,
386 )) => {
387 if self.inner.sink.is_closed() {
388 return Ok(None);
389 }
390
391 if topic_filters.iter().any(|tf| !crate::topic::is_valid(tf)) {
392 return control(
393 Control::proto_error(ProtocolError::generic_violation(
394 "Topic filter is malformed [MQTT-4.7.1-*]",
395 )),
396 &self.inner,
397 )
398 .await;
399 }
400
401 if !self.inner.inflight.borrow_mut().insert(packet_id) {
402 log::trace!(
403 "{}: Duplicated packet id for unsubscribe packet: {:?}",
404 self.tag(),
405 packet_id
406 );
407 return control(
408 Control::proto_error(ProtocolError::generic_violation(
409 "UNSUBSCRIBE received with packet id that is already in use [MQTT-2.2.1-3]"
410 )),
411 &self.inner,
412 ).await;
413 }
414
415 control(
416 Control::unsubscribe(Unsubscribe::new(packet_id, size, topic_filters)),
417 &self.inner,
418 )
419 .await
420 }
421 DispatchItem::Item(Decoded::Packet(Packet::Disconnect, _)) => {
422 control(Control::remote_disconnect(), &self.inner).await
423 }
424 DispatchItem::Item(_) => Ok(None),
425 DispatchItem::EncoderError(err) => {
426 let err = ProtocolError::Encode(err);
427 self.inner.drop_payload(&err);
428 control(Control::proto_error(err), &self.inner).await
429 }
430 DispatchItem::KeepAliveTimeout => {
431 self.inner.drop_payload(&ProtocolError::KeepAliveTimeout);
432 control(Control::proto_error(ProtocolError::KeepAliveTimeout), &self.inner)
433 .await
434 }
435 DispatchItem::ReadTimeout => {
436 self.inner.drop_payload(&ProtocolError::ReadTimeout);
437 control(Control::proto_error(ProtocolError::ReadTimeout), &self.inner).await
438 }
439 DispatchItem::DecoderError(err) => {
440 let err = ProtocolError::Decode(err);
441 self.inner.drop_payload(&err);
442 control(Control::proto_error(err), &self.inner).await
443 }
444 DispatchItem::Disconnect(err) => {
445 self.inner.drop_payload(&PayloadError::Disconnected);
446 control(Control::peer_gone(err), &self.inner).await
447 }
448 DispatchItem::WBackPressureEnabled => {
449 self.inner.sink.enable_wr_backpressure();
450 control(Control::wr_backpressure(true), &self.inner).await
451 }
452 DispatchItem::WBackPressureDisabled => {
453 self.inner.sink.disable_wr_backpressure();
454 control(Control::wr_backpressure(false), &self.inner).await
455 }
456 }
457 }
458}
459
460async fn publish_fn<'f, T, C, E>(
462 svc: &'f T,
463 pkt: Publish,
464 packet_id: Option<NonZeroU16>,
465 inner: &'f Inner<C>,
466 ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
467) -> Result<Option<Encoded>, MqttError<E>>
468where
469 E: From<T::Error>,
470 T: Service<Publish, Response = ()>,
471 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
472{
473 let qos2 = pkt.qos() == QoS::ExactlyOnce;
474 match ctx.call(svc, pkt).await {
475 Ok(_) => {
476 log::trace!(
477 "{}: Publish result for packet {:?} is ready",
478 inner.sink.tag(),
479 packet_id
480 );
481
482 if let Some(packet_id) = packet_id {
483 if qos2 {
484 Ok(Some(Encoded::Packet(Packet::PublishReceived { packet_id })))
485 } else {
486 inner.inflight.borrow_mut().remove(&packet_id);
487 Ok(Some(Encoded::Packet(Packet::PublishAck { packet_id })))
488 }
489 } else {
490 Ok(None)
491 }
492 }
493 Err(e) => control(Control::error(e.into()), inner).await,
494 }
495}
496
497async fn control<C, E>(
498 mut pkt: Control<E>,
499 inner: &Inner<C>,
500) -> Result<Option<Encoded>, MqttError<E>>
501where
502 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
503{
504 let mut error = matches!(pkt, Control::Error(_) | Control::ProtocolError(_));
505
506 loop {
507 match inner.control.call(pkt).await {
508 Ok(item) => {
509 let packet = match item.result {
510 ControlAckKind::Ping => Some(Encoded::Packet(Packet::PingResponse)),
511 ControlAckKind::Subscribe(res) => {
512 inner.inflight.borrow_mut().remove(&res.packet_id);
513 Some(Encoded::Packet(Packet::SubscribeAck {
514 status: res.codes,
515 packet_id: res.packet_id,
516 }))
517 }
518 ControlAckKind::Unsubscribe(res) => {
519 inner.inflight.borrow_mut().remove(&res.packet_id);
520 Some(Encoded::Packet(Packet::UnsubscribeAck {
521 packet_id: res.packet_id,
522 }))
523 }
524 ControlAckKind::Disconnect => {
525 inner.drop_payload(&PayloadError::Service);
526 inner.sink.close();
527 None
528 }
529 ControlAckKind::Closed | ControlAckKind::Nothing => None,
530 ControlAckKind::PublishRelease(packet_id) => {
531 inner.inflight.borrow_mut().remove(&packet_id);
532 Some(Encoded::Packet(Packet::PublishComplete { packet_id }))
533 }
534 ControlAckKind::PublishAck(_) => unreachable!(),
535 };
536 return Ok(packet);
537 }
538 Err(err) => {
539 inner.drop_payload(&PayloadError::Service);
540
541 return if error {
543 Err(err)
544 } else {
545 match err {
547 MqttError::Service(err) => {
548 error = true;
549 pkt = Control::error(err);
550 continue;
551 }
552 _ => Err(err),
553 }
554 };
555 }
556 }
557 }
558}
559
#[cfg(test)]
mod tests {
    use std::{future::Future, pin::Pin};

    use ntex_bytes::{ByteString, Bytes};
    use ntex_io::{Io, testing::IoTest};
    use ntex_service::{cfg::SharedCfg, fn_service};
    use ntex_util::future::{Ready, lazy};
    use ntex_util::time::{Seconds, sleep};

    use super::*;
    use crate::v3::{MqttSink, codec};

    /// Builds a QoS 1 PUBLISH item with an empty topic and an empty inline payload.
    fn qos1_publish(packet_id: u16) -> DispatchItem<Rc<MqttShared>> {
        let packet = codec::Publish {
            dup: false,
            retain: false,
            qos: QoS::AtLeastOnce,
            topic: ByteString::new(),
            packet_id: NonZeroU16::new(packet_id),
            payload_size: 0,
        };
        DispatchItem::Item(Decoded::Publish(packet, Bytes::new(), 999))
    }

    /// A PUBLISH reusing the id of a still-running PUBLISH must be reported to
    /// the control service as a protocol error.
    #[ntex::test]
    async fn test_dup_packet_id() {
        let cfg: SharedCfg = SharedCfg::new("DBG")
            .add(MqttServiceConfig::new().set_max_qos(QoS::AtLeastOnce))
            .into();

        let io = Io::new(IoTest::create().0, cfg);
        let shared = Rc::new(MqttShared::new(
            io.get_ref(),
            codec::Codec::default(),
            false,
            Default::default(),
        ));
        let saw_proto_err = Rc::new(Cell::new(false));
        let saw_proto_err_flag = saw_proto_err.clone();

        let disp = Pipeline::new(Dispatcher::<_, _, ()>::new(
            shared.clone(),
            // publish handler that does not finish within the test
            fn_service(|_| async {
                sleep(Seconds(10)).await;
                Ok(())
            }),
            fn_service(move |ctrl| {
                if let Control::ProtocolError(_) = ctrl {
                    saw_proto_err_flag.set(true);
                }
                Ready::Ok(ControlAck { result: ControlAckKind::Nothing })
            }),
            cfg.get(),
        ));

        // poll the first publish once so that its packet id gets reserved
        let mut first: Pin<Box<dyn Future<Output = Result<_, _>>>> =
            Box::pin(disp.call(qos1_publish(1)));
        let _ = lazy(|cx| Pin::new(&mut first).poll(cx)).await;

        let second = Box::pin(disp.call(qos1_publish(1)));
        assert!(second.await.unwrap().is_none());
        assert!(saw_proto_err.get());
    }

    /// Write back-pressure events toggle sink readiness and wake the first waiter.
    #[ntex::test]
    async fn test_wr_backpressure() {
        let cfg: SharedCfg = SharedCfg::new("DBG")
            .add(MqttServiceConfig::new().set_max_qos(QoS::AtLeastOnce))
            .into();

        let io = Io::new(IoTest::create().0, cfg);
        let shared = Rc::new(MqttShared::new(
            io.get_ref(),
            codec::Codec::default(),
            false,
            Default::default(),
        ));

        let disp = Pipeline::new(Dispatcher::<_, _, ()>::new(
            shared.clone(),
            fn_service(|_| Ready::Ok(())),
            fn_service(|_| Ready::Ok(ControlAck { result: ControlAckKind::Nothing })),
            cfg.get(),
        ));

        let sink = MqttSink::new(shared.clone());
        assert!(!sink.is_ready());
        shared.set_cap(1);
        assert!(sink.is_ready());
        assert!(shared.wait_readiness().is_none());

        disp.call(DispatchItem::WBackPressureEnabled).await.unwrap();
        assert!(!sink.is_ready());
        let first_waiter = shared.wait_readiness();
        let second_waiter = shared.wait_readiness().unwrap();
        assert!(first_waiter.is_some());
        let first_waiter = first_waiter.unwrap();

        disp.call(DispatchItem::WBackPressureDisabled).await.unwrap();
        assert!(lazy(|cx| first_waiter.poll_recv(cx).is_ready()).await);
        assert!(!lazy(|cx| second_waiter.poll_recv(cx).is_ready()).await);
    }
}