1use std::{cell::Cell, cell::RefCell, marker, num, rc::Rc, task::Context};
2
3use ntex_bytes::ByteString;
4use ntex_io::DispatchItem;
5use ntex_service::cfg::{Cfg, SharedCfg};
6use ntex_service::{self as service, Pipeline, Service, ServiceCtx, ServiceFactory};
7use ntex_util::services::inflight::InFlightService;
8use ntex_util::services::{buffer::BufferService, buffer::BufferServiceError};
9use ntex_util::{HashMap, HashSet, future::join};
10
11use crate::error::{DecodeError, HandshakeError, MqttError, PayloadError, ProtocolError};
12use crate::payload::{Payload, PayloadStatus, PlSender};
13use crate::{MqttServiceConfig, types::QoS};
14
15use super::Session;
16use super::codec::{self, Decoded, DisconnectReasonCode, Encoded, Packet};
17use super::control::{Control, ControlAck};
18use super::publish::{Publish, PublishAck};
19use super::shared::{Ack, MqttShared};
20
21pub(super) fn factory<St, T, C, E>(
23 publish: T,
24 control: C,
25) -> impl ServiceFactory<
26 DispatchItem<Rc<MqttShared>>,
27 (SharedCfg, Session<St>),
28 Response = Option<Encoded>,
29 Error = MqttError<E>,
30 InitError = MqttError<E>,
31>
32where
33 St: 'static,
34 E: From<T::Error> + From<T::InitError> + From<C::Error> + From<C::InitError> + 'static,
35 T: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
36 C: ServiceFactory<Control<E>, Session<St>, Response = ControlAck> + 'static,
37 PublishAck: TryFrom<T::Error, Error = E>,
38{
39 let factories = Rc::new((publish, control));
40
41 service::fn_factory_with_config(async move |(cfg, ses): (SharedCfg, Session<St>)| {
42 let cfg: Cfg<MqttServiceConfig> = cfg.get();
43
44 let sink = ses.sink().shared();
46 let (publish, control) =
47 join(factories.0.create(ses.clone()), factories.1.create(ses)).await;
48
49 let publish = publish.map_err(|e| MqttError::Service(e.into()))?;
50 let control = control.map_err(|e| MqttError::Service(e.into()))?;
51
52 let control = BufferService::new(
53 16,
54 InFlightService::new(1, control),
56 )
57 .map_err(|err| match err {
58 BufferServiceError::Service(e) => MqttError::Service(E::from(e)),
59 BufferServiceError::RequestCanceled => {
60 MqttError::Handshake(HandshakeError::Disconnected(None))
61 }
62 });
63
64 Ok(Dispatcher::<_, _, E>::new(sink, publish, control, cfg))
65 })
66}
67
68impl crate::inflight::SizedRequest for DispatchItem<Rc<MqttShared>> {
69 fn size(&self) -> u32 {
70 match self {
71 DispatchItem::Item(Decoded::Packet(_, size))
72 | DispatchItem::Item(Decoded::Publish(_, _, size)) => *size,
73 _ => 0,
74 }
75 }
76
77 fn is_publish(&self) -> bool {
78 matches!(self, DispatchItem::Item(Decoded::Publish(..)))
79 }
80
81 fn is_chunk(&self) -> bool {
82 matches!(self, DispatchItem::Item(Decoded::PayloadChunk(..)))
83 }
84}
85
86pub(crate) struct Dispatcher<T, C: Service<Control<E>>, E> {
88 publish: T,
89 inner: Rc<Inner<C>>,
90 cfg: Cfg<MqttServiceConfig>,
91 _t: marker::PhantomData<E>,
92}
93
94struct Inner<C> {
95 control: Pipeline<C>,
96 sink: Rc<MqttShared>,
97 info: RefCell<PublishInfo>,
98 payload: Cell<Option<PlSender>>,
99}
100
101struct PublishInfo {
102 inflight: HashSet<num::NonZeroU16>,
103 aliases: HashMap<num::NonZeroU16, ByteString>,
104}
105
106impl<T, C, E> Dispatcher<T, C, E>
107where
108 E: From<T::Error>,
109 T: Service<Publish, Response = PublishAck>,
110 PublishAck: TryFrom<T::Error, Error = E>,
111 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
112{
113 fn new(sink: Rc<MqttShared>, publish: T, control: C, cfg: Cfg<MqttServiceConfig>) -> Self {
114 Self {
115 cfg,
116 publish,
117 inner: Rc::new(Inner {
118 sink,
119 payload: Cell::new(None),
120 control: Pipeline::new(control),
121 info: RefCell::new(PublishInfo {
122 aliases: HashMap::default(),
123 inflight: HashSet::default(),
124 }),
125 }),
126 _t: marker::PhantomData,
127 }
128 }
129
130 fn tag(&self) -> &'static str {
131 self.inner.sink.tag()
132 }
133}
134
135impl<C> Inner<C> {
136 fn drop_payload<PErr>(&self, err: &PErr)
137 where
138 PErr: Clone,
139 PayloadError: From<PErr>,
140 {
141 if let Some(pl) = self.payload.take() {
142 pl.set_error(err.clone().into());
143 }
144 }
145}
146
147impl<T, C, E> Service<DispatchItem<Rc<MqttShared>>> for Dispatcher<T, C, E>
148where
149 E: From<T::Error> + 'static,
150 T: Service<Publish, Response = PublishAck> + 'static,
151 PublishAck: TryFrom<T::Error, Error = E>,
152 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>> + 'static,
153{
154 type Response = Option<Encoded>;
155 type Error = MqttError<E>;
156
157 async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
158 let (res1, res2) = join(ctx.ready(&self.publish), self.inner.control.ready()).await;
159 let result = if let Err(e) = res1 {
160 if res2.is_err() {
161 Err(MqttError::Service(e.into()))
162 } else {
163 match self.inner.control.call(Control::error(e.into())).await {
164 Ok(res) => {
165 if res.disconnect {
166 self.inner.sink.drop_sink();
167 }
168 Ok(())
169 }
170 Err(err) => Err(err),
171 }
172 }
173 } else {
174 res2
175 };
176
177 if result.is_ok() {
178 if let Some(pl) = self.inner.payload.take() {
179 self.inner.payload.set(Some(pl.clone()));
180 if pl.ready().await != PayloadStatus::Ready {
181 self.inner.sink.force_close();
182 }
183 }
184 }
185 result
186 }
187
188 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
189 if let Err(e) = self.publish.poll(cx) {
190 let inner = self.inner.clone();
191 ntex_rt::spawn(async move {
192 if let Ok(res) = inner.control.call(Control::error(e.into())).await {
193 if res.disconnect {
194 inner.sink.drop_sink();
195 }
196 }
197 });
198 }
199 self.inner.control.poll(cx)
200 }
201
202 async fn shutdown(&self) {
203 log::trace!("{}: Shutdown v5 dispatcher", self.tag());
204 self.inner.drop_payload(&PayloadError::Disconnected);
205 self.inner.sink.drop_sink();
206 let _ = self.inner.control.call(Control::closed()).await;
207
208 self.publish.shutdown().await;
209 self.inner.control.shutdown().await;
210 }
211
212 #[allow(clippy::await_holding_refcell_ref)]
213 async fn call(
214 &self,
215 request: DispatchItem<Rc<MqttShared>>,
216 ctx: ServiceCtx<'_, Self>,
217 ) -> Result<Self::Response, Self::Error> {
218 log::trace!("{}: Dispatch v5 packet: {:#?}", self.tag(), request);
219
220 match request {
221 DispatchItem::Item(Decoded::Publish(mut publish, payload, size)) => {
222 let info = self.inner.as_ref();
223 let packet_id = publish.packet_id;
224
225 if publish.topic.contains(['#', '+']) {
226 return control(
227 Control::proto_error(
228 ProtocolError::generic_violation(
229 "PUBLISH packet's topic name contains wildcard character [MQTT-3.3.2-2]"
230 )
231 ),
232 &self.inner,
233 0,
234 ).await;
235 }
236
237 {
238 let mut inner = info.info.borrow_mut();
239 let state = &self.inner.sink;
240
241 if let Some(pid) = packet_id {
242 let receive_max = state.receive_max();
244 if receive_max != 0 && inner.inflight.len() >= receive_max as usize {
245 log::trace!(
246 "{}: Receive maximum exceeded: max: {} in-flight: {}",
247 self.tag(),
248 receive_max,
249 inner.inflight.len()
250 );
251 drop(inner);
252 return control(
253 Control::proto_error(
254 ProtocolError::violation(
255 DisconnectReasonCode::ReceiveMaximumExceeded,
256 "Number of in-flight messages exceeds set maximum [MQTT-3.3.4-7]"
257 )
258 ),
259 &self.inner,
260 0,
261 ).await;
262 }
263
264 if publish.qos > state.max_qos() {
266 log::trace!(
267 "{}: Max allowed QoS is violated, max {:?} provided {:?}",
268 self.tag(),
269 state.max_qos(),
270 publish.qos
271 );
272 drop(inner);
273 return control(
274 Control::proto_error(ProtocolError::violation(
275 DisconnectReasonCode::QosNotSupported,
276 "PUBLISH QoS is higher than supported [MQTT-3.2.2-11]",
277 )),
278 &self.inner,
279 0,
280 )
281 .await;
282 }
283 if publish.retain && !state.codec.retain_available() {
284 log::trace!("{}: Retain is not available but is set", self.tag());
285 drop(inner);
286 return control(
287 Control::proto_error(ProtocolError::violation(
288 DisconnectReasonCode::RetainNotSupported,
289 "RETAIN is not supported [MQTT-3.2.2-14]",
290 )),
291 &self.inner,
292 0,
293 )
294 .await;
295 }
296
297 if !inner.inflight.insert(pid) {
299 let _ = self.inner.sink.encode_packet(codec::Packet::PublishAck(
300 codec::PublishAck {
301 packet_id: pid,
302 reason_code: codec::PublishAckReason::PacketIdentifierInUse,
303 ..Default::default()
304 },
305 ));
306 return Ok(None);
307 }
308 }
309
310 if let Some(alias) = publish.properties.topic_alias {
312 if publish.topic.is_empty() {
313 match inner.aliases.get(&alias) {
315 Some(aliased_topic) => publish.topic = aliased_topic.clone(),
316 None => {
317 drop(inner);
318 return control(
319 Control::proto_error(ProtocolError::violation(
320 DisconnectReasonCode::TopicAliasInvalid,
321 "Unknown topic alias",
322 )),
323 &self.inner,
324 0,
325 )
326 .await;
327 }
328 }
329 } else {
330 match inner.aliases.entry(alias) {
332 std::collections::hash_map::Entry::Occupied(mut entry) => {
333 if entry.get().as_str() != publish.topic.as_str() {
334 let mut topic = publish.topic.clone();
335 topic.trimdown();
336 entry.insert(topic);
337 }
338 }
339 std::collections::hash_map::Entry::Vacant(entry) => {
340 if alias.get() > state.topic_alias_max() {
341 drop(inner);
342 return control(
343 Control::proto_error(
344 ProtocolError::generic_violation(
345 "Topic alias is greater than max allowed [MQTT-3.2.2-17]",
346 )
347 ),
348 &self.inner,
349 0,
350 ).await;
351 }
352 let mut topic = publish.topic.clone();
353 topic.trimdown();
354 entry.insert(topic);
355 }
356 }
357 }
358 }
359
360 if state.is_closed()
361 && !self
362 .cfg
363 .handle_qos_after_disconnect
364 .map(|max_qos| publish.qos <= max_qos)
365 .unwrap_or_default()
366 {
367 return Ok(None);
368 }
369 }
370
371 let payload = if publish.payload_size == payload.len() as u32 {
372 Payload::from_bytes(payload)
373 } else {
374 let (pl, sender) =
375 Payload::from_stream(payload, self.cfg.max_payload_buffer_size);
376 self.inner.payload.set(Some(sender));
377 pl
378 };
379
380 publish_fn(
381 &self.publish,
382 Publish::new(publish, payload, size),
383 packet_id.map(|v| v.get()).unwrap_or(0),
384 info,
385 ctx,
386 )
387 .await
388 }
389 DispatchItem::Item(Decoded::PayloadChunk(buf, eof)) => {
390 if let Some(pl) = self.inner.payload.take() {
391 pl.feed_data(buf);
392 if eof {
393 pl.feed_eof();
394 } else {
395 self.inner.payload.set(Some(pl));
396 }
397 Ok(None)
398 } else {
399 control(
400 Control::proto_error(ProtocolError::Decode(
401 DecodeError::UnexpectedPayload,
402 )),
403 &self.inner,
404 0,
405 )
406 .await
407 }
408 }
409 DispatchItem::Item(Decoded::Packet(Packet::PublishAck(packet), _)) => {
410 if let Err(err) = self.inner.sink.pkt_ack(Ack::Publish(packet)) {
411 control(Control::proto_error(err), &self.inner, 0).await
412 } else {
413 Ok(None)
414 }
415 }
416 DispatchItem::Item(Decoded::Packet(Packet::PublishReceived(pkt), _)) => {
417 if let Err(e) = self.inner.sink.pkt_ack(Ack::Receive(pkt)) {
418 control(Control::proto_error(e), &self.inner, 0).await
419 } else {
420 Ok(None)
421 }
422 }
423 DispatchItem::Item(Decoded::Packet(Packet::PublishRelease(ack), size)) => {
424 if self.inner.info.borrow().inflight.contains(&ack.packet_id) {
425 control(Control::pubrel(ack, size), &self.inner, 0).await
426 } else {
427 Ok(Some(Encoded::Packet(codec::Packet::PublishComplete(
428 codec::PublishAck2 {
429 packet_id: ack.packet_id,
430 reason_code: codec::PublishAck2Reason::PacketIdNotFound,
431 properties: codec::UserProperties::default(),
432 reason_string: None,
433 },
434 ))))
435 }
436 }
437 DispatchItem::Item(Decoded::Packet(Packet::PublishComplete(pkt), _)) => {
438 if let Err(e) = self.inner.sink.pkt_ack(Ack::Complete(pkt)) {
439 control(Control::proto_error(e), &self.inner, 0).await
440 } else {
441 Ok(None)
442 }
443 }
444 DispatchItem::Item(Decoded::Packet(Packet::Auth(pkt), size)) => {
445 if self.inner.sink.is_closed() {
446 return Ok(None);
447 }
448
449 control(Control::auth(pkt, size), &self.inner, 0).await
450 }
451 DispatchItem::Item(Decoded::Packet(Packet::PingRequest, _)) => {
452 control(Control::ping(), &self.inner, 0).await
453 }
454 DispatchItem::Item(Decoded::Packet(Packet::Disconnect(pkt), size)) => {
455 control(Control::remote_disconnect(pkt, size), &self.inner, 0).await
456 }
457 DispatchItem::Item(Decoded::Packet(Packet::Subscribe(pkt), size)) => {
458 if self.inner.sink.is_closed() {
459 return Ok(None);
460 }
461
462 if pkt.topic_filters.iter().any(|(tf, _)| !crate::topic::is_valid(tf)) {
463 return control(
464 Control::proto_error(ProtocolError::generic_violation(
465 "Topic filter is malformed [MQTT-4.7.1-*]",
466 )),
467 &self.inner,
468 0,
469 )
470 .await;
471 }
472
473 if pkt.id.is_some() && !self.inner.sink.codec.sub_ids_available() {
474 log::trace!(
475 "{}: Subscription Identifiers are not supported but was set",
476 self.tag()
477 );
478 return control(
479 Control::proto_error(ProtocolError::violation(
480 DisconnectReasonCode::SubscriptionIdentifiersNotSupported,
481 "Subscription Identifiers are not supported",
482 )),
483 &self.inner,
484 0,
485 )
486 .await;
487 }
488
489 if !self.inner.info.borrow_mut().inflight.insert(pkt.packet_id) {
491 let _ = self.inner.sink.encode_packet(codec::Packet::SubscribeAck(
493 codec::SubscribeAck {
494 packet_id: pkt.packet_id,
495 status: pkt
496 .topic_filters
497 .iter()
498 .map(|_| codec::SubscribeAckReason::PacketIdentifierInUse)
499 .collect(),
500 properties: codec::UserProperties::new(),
501 reason_string: None,
502 },
503 ));
504 return Ok(None);
505 }
506 let id = pkt.packet_id;
507 control(Control::subscribe(pkt, size), &self.inner, id.get()).await
508 }
509 DispatchItem::Item(Decoded::Packet(Packet::Unsubscribe(pkt), size)) => {
510 if self.inner.sink.is_closed() {
511 return Ok(None);
512 }
513
514 if pkt.topic_filters.iter().any(|tf| !crate::topic::is_valid(tf)) {
515 return control(
516 Control::proto_error(ProtocolError::generic_violation(
517 "Topic filter is malformed [MQTT-4.7.1-*]",
518 )),
519 &self.inner,
520 0,
521 )
522 .await;
523 }
524
525 if !self.inner.info.borrow_mut().inflight.insert(pkt.packet_id) {
527 let _ = self.inner.sink.encode_packet(codec::Packet::UnsubscribeAck(
529 codec::UnsubscribeAck {
530 packet_id: pkt.packet_id,
531 status: pkt
532 .topic_filters
533 .iter()
534 .map(|_| codec::UnsubscribeAckReason::PacketIdentifierInUse)
535 .collect(),
536 properties: codec::UserProperties::new(),
537 reason_string: None,
538 },
539 ));
540 return Ok(None);
541 }
542 let id = pkt.packet_id;
543 control(Control::unsubscribe(pkt, size), &self.inner, id.get()).await
544 }
545 DispatchItem::Item(Decoded::Packet(_, _)) => Ok(None),
546 DispatchItem::EncoderError(err) => {
547 let err = ProtocolError::Encode(err);
548 self.inner.drop_payload(&err);
549 control(Control::proto_error(err), &self.inner, 0).await
550 }
551 DispatchItem::KeepAliveTimeout => {
552 self.inner.drop_payload(&ProtocolError::KeepAliveTimeout);
553 control(Control::proto_error(ProtocolError::KeepAliveTimeout), &self.inner, 0)
554 .await
555 }
556 DispatchItem::ReadTimeout => {
557 self.inner.drop_payload(&ProtocolError::ReadTimeout);
558 control(Control::proto_error(ProtocolError::ReadTimeout), &self.inner, 0).await
559 }
560 DispatchItem::DecoderError(err) => {
561 let err = ProtocolError::Decode(err);
562 self.inner.drop_payload(&err);
563 control(Control::proto_error(err), &self.inner, 0).await
564 }
565 DispatchItem::Disconnect(err) => {
566 self.inner.drop_payload(&PayloadError::Disconnected);
567 control(Control::peer_gone(err), &self.inner, 0).await
568 }
569 DispatchItem::WBackPressureEnabled => {
570 self.inner.sink.enable_wr_backpressure();
571 control(Control::wr_backpressure(true), &self.inner, 0).await
572 }
573 DispatchItem::WBackPressureDisabled => {
574 self.inner.sink.disable_wr_backpressure();
575 control(Control::wr_backpressure(false), &self.inner, 0).await
576 }
577 }
578 }
579}
580
581async fn publish_fn<'f, T, C, E>(
583 publish: &T,
584 pkt: Publish,
585 packet_id: u16,
586 inner: &'f Inner<C>,
587 ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
588) -> Result<Option<Encoded>, MqttError<E>>
589where
590 E: From<T::Error>,
591 T: Service<Publish, Response = PublishAck>,
592 PublishAck: TryFrom<T::Error, Error = E>,
593 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
594{
595 let qos2 = pkt.qos() == QoS::ExactlyOnce;
596 let ack = match ctx.call(publish, pkt).await {
597 Ok(ack) => ack,
598 Err(e) => {
599 if packet_id != 0 {
600 match PublishAck::try_from(e) {
601 Ok(ack) => ack,
602 Err(e) => return control(Control::error(e), inner, 0).await,
603 }
604 } else {
605 return control(Control::error(e.into()), inner, 0).await;
606 }
607 }
608 };
609 if let Some(id) = num::NonZeroU16::new(packet_id) {
610 let ack = if qos2 {
611 codec::Packet::PublishReceived(codec::PublishAck {
612 packet_id: id,
613 reason_code: ack.reason_code,
614 reason_string: ack.reason_string,
615 properties: ack.properties,
616 })
617 } else {
618 inner.info.borrow_mut().inflight.remove(&id);
619 codec::Packet::PublishAck(codec::PublishAck {
620 packet_id: id,
621 reason_code: ack.reason_code,
622 reason_string: ack.reason_string,
623 properties: ack.properties,
624 })
625 };
626 Ok(Some(Encoded::Packet(ack)))
627 } else {
628 Ok(None)
629 }
630}
631
632async fn control<C, E>(
633 pkt: Control<E>,
634 inner: &Inner<C>,
635 packet_id: u16,
636) -> Result<Option<Encoded>, MqttError<E>>
637where
638 C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
639{
640 let mut error = matches!(pkt, Control::Error(_) | Control::ProtocolError(_));
641
642 let result = match inner.control.call(pkt).await {
643 Ok(result) => {
644 if let Some(id) = num::NonZeroU16::new(packet_id) {
645 inner.info.borrow_mut().inflight.remove(&id);
646 }
647 result
648 }
649 Err(err) => {
650 inner.drop_payload(&PayloadError::Service);
651
652 if error {
654 inner.sink.drop_sink();
655 return Err(err);
656 } else {
657 match err {
659 MqttError::Service(err) => {
660 error = true;
661 inner.control.call(Control::error(err)).await?
662 }
663 _ => return Err(err),
664 }
665 }
666 }
667 };
668
669 let response = if error {
670 if let Some(pkt) = result.packet {
671 let _ = inner.sink.encode_packet(pkt);
672 }
673 Ok(None)
674 } else {
675 Ok(result.packet.map(Encoded::Packet))
676 };
677
678 if result.disconnect {
679 inner.drop_payload(&PayloadError::Service);
680 inner.sink.drop_sink();
681 }
682 response
683}
684
#[cfg(test)]
mod tests {
    use ntex_io::{Io, testing::IoTest};
    use ntex_service::{cfg::SharedCfg, fn_service};
    use ntex_util::future::{Ready, lazy};

    use super::*;
    use crate::v5::MqttSink;

    /// Error type of the test publish service; never converts into an ack.
    #[derive(Debug)]
    struct TestError;

    impl TryFrom<TestError> for PublishAck {
        type Error = TestError;

        fn try_from(err: TestError) -> Result<Self, Self::Error> {
            Err(err)
        }
    }

    /// Write back-pressure items must toggle sink readiness and wake waiters.
    #[ntex::test]
    async fn test_wr_backpressure() {
        let io = Io::new(IoTest::create().0, SharedCfg::new("DBG"));
        let state = Rc::new(MqttShared::new(
            io.get_ref(),
            codec::Codec::default(),
            Default::default(),
        ));

        let dispatcher = Pipeline::new(Dispatcher::<_, _, _>::new(
            state.clone(),
            fn_service(|p: Publish| Ready::Ok::<_, TestError>(p.ack())),
            fn_service(|_| {
                Ready::Ok::<_, MqttError<TestError>>(ControlAck {
                    packet: None,
                    disconnect: false,
                })
            }),
            Default::default(),
        ));

        let sink = MqttSink::new(state.clone());
        assert!(!sink.is_ready());
        state.set_cap(1);
        assert!(sink.is_ready());
        assert!(state.wait_readiness().is_none());

        dispatcher.call(DispatchItem::WBackPressureEnabled).await.unwrap();
        assert!(!sink.is_ready());
        let first = state.wait_readiness();
        let second = state.wait_readiness().unwrap();
        assert!(first.is_some());

        let first = first.unwrap();
        dispatcher.call(DispatchItem::WBackPressureDisabled).await.unwrap();
        assert!(lazy(|cx| first.poll_recv(cx).is_ready()).await);
        assert!(!lazy(|cx| second.poll_recv(cx).is_ready()).await);
    }
}