ntex_mqtt/v5/
dispatcher.rs

1use std::{cell::Cell, cell::RefCell, marker, num, rc::Rc, task::Context};
2
3use ntex_bytes::ByteString;
4use ntex_io::DispatchItem;
5use ntex_service::cfg::{Cfg, SharedCfg};
6use ntex_service::{self as service, Pipeline, Service, ServiceCtx, ServiceFactory};
7use ntex_util::services::inflight::InFlightService;
8use ntex_util::services::{buffer::BufferService, buffer::BufferServiceError};
9use ntex_util::{HashMap, HashSet, future::join};
10
11use crate::error::{DecodeError, HandshakeError, MqttError, PayloadError, ProtocolError};
12use crate::payload::{Payload, PayloadStatus, PlSender};
13use crate::{MqttServiceConfig, types::QoS};
14
15use super::Session;
16use super::codec::{self, Decoded, DisconnectReasonCode, Encoded, Packet};
17use super::control::{Control, ControlAck};
18use super::publish::{Publish, PublishAck};
19use super::shared::{Ack, MqttShared};
20
21/// MQTT 5 protocol dispatcher
22pub(super) fn factory<St, T, C, E>(
23    publish: T,
24    control: C,
25) -> impl ServiceFactory<
26    DispatchItem<Rc<MqttShared>>,
27    (SharedCfg, Session<St>),
28    Response = Option<Encoded>,
29    Error = MqttError<E>,
30    InitError = MqttError<E>,
31>
32where
33    St: 'static,
34    E: From<T::Error> + From<T::InitError> + From<C::Error> + From<C::InitError> + 'static,
35    T: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
36    C: ServiceFactory<Control<E>, Session<St>, Response = ControlAck> + 'static,
37    PublishAck: TryFrom<T::Error, Error = E>,
38{
39    let factories = Rc::new((publish, control));
40
41    service::fn_factory_with_config(async move |(cfg, ses): (SharedCfg, Session<St>)| {
42        let cfg: Cfg<MqttServiceConfig> = cfg.get();
43
44        // create services
45        let sink = ses.sink().shared();
46        let (publish, control) =
47            join(factories.0.create(ses.clone()), factories.1.create(ses)).await;
48
49        let publish = publish.map_err(|e| MqttError::Service(e.into()))?;
50        let control = control.map_err(|e| MqttError::Service(e.into()))?;
51
52        let control = BufferService::new(
53            16,
54            // limit number of in-flight messages
55            InFlightService::new(1, control),
56        )
57        .map_err(|err| match err {
58            BufferServiceError::Service(e) => MqttError::Service(E::from(e)),
59            BufferServiceError::RequestCanceled => {
60                MqttError::Handshake(HandshakeError::Disconnected(None))
61            }
62        });
63
64        Ok(Dispatcher::<_, _, E>::new(sink, publish, control, cfg.handle_qos_after_disconnect))
65    })
66}
67
68impl crate::inflight::SizedRequest for DispatchItem<Rc<MqttShared>> {
69    fn size(&self) -> u32 {
70        match self {
71            DispatchItem::Item(Decoded::Packet(_, size))
72            | DispatchItem::Item(Decoded::Publish(_, _, size)) => *size,
73            _ => 0,
74        }
75    }
76
77    fn is_publish(&self) -> bool {
78        matches!(self, DispatchItem::Item(Decoded::Publish(..)))
79    }
80
81    fn is_chunk(&self) -> bool {
82        matches!(self, DispatchItem::Item(Decoded::PayloadChunk(..)))
83    }
84}
85
86/// Mqtt protocol dispatcher
87pub(crate) struct Dispatcher<T, C: Service<Control<E>>, E> {
88    publish: T,
89    handle_qos_after_disconnect: Option<QoS>,
90    inner: Rc<Inner<C>>,
91    _t: marker::PhantomData<E>,
92}
93
/// Dispatcher state shared between the service methods and spawned tasks.
struct Inner<C> {
    /// Pipeline around the control service.
    control: Pipeline<C>,
    /// Shared connection state.
    sink: Rc<MqttShared>,
    /// In-flight packet ids and topic aliases.
    info: RefCell<PublishInfo>,
    /// Sender half of the publish payload that is currently being streamed,
    /// if any; payload chunks are fed into it.
    payload: Cell<Option<PlSender>>,
}
100
/// Per-connection bookkeeping for incoming packets.
struct PublishInfo {
    /// Packet ids registered by incoming PUBLISH, SUBSCRIBE and UNSUBSCRIBE
    /// packets that have not been released yet.
    inflight: HashSet<num::NonZeroU16>,
    /// Topic names recorded for the topic aliases seen in PUBLISH packets.
    aliases: HashMap<num::NonZeroU16, ByteString>,
}
105
106impl<T, C, E> Dispatcher<T, C, E>
107where
108    E: From<T::Error>,
109    T: Service<Publish, Response = PublishAck>,
110    PublishAck: TryFrom<T::Error, Error = E>,
111    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
112{
113    fn new(
114        sink: Rc<MqttShared>,
115        publish: T,
116        control: C,
117        handle_qos_after_disconnect: Option<QoS>,
118    ) -> Self {
119        Self {
120            publish,
121            handle_qos_after_disconnect,
122            inner: Rc::new(Inner {
123                sink,
124                payload: Cell::new(None),
125                control: Pipeline::new(control),
126                info: RefCell::new(PublishInfo {
127                    aliases: HashMap::default(),
128                    inflight: HashSet::default(),
129                }),
130            }),
131            _t: marker::PhantomData,
132        }
133    }
134
135    fn tag(&self) -> &'static str {
136        self.inner.sink.tag()
137    }
138}
139
140impl<C> Inner<C> {
141    fn drop_payload<PErr>(&self, err: &PErr)
142    where
143        PErr: Clone,
144        PayloadError: From<PErr>,
145    {
146        if let Some(pl) = self.payload.take() {
147            pl.set_error(err.clone().into());
148        }
149    }
150}
151
152impl<T, C, E> Service<DispatchItem<Rc<MqttShared>>> for Dispatcher<T, C, E>
153where
154    E: From<T::Error> + 'static,
155    T: Service<Publish, Response = PublishAck> + 'static,
156    PublishAck: TryFrom<T::Error, Error = E>,
157    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>> + 'static,
158{
159    type Response = Option<Encoded>;
160    type Error = MqttError<E>;
161
162    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
163        let (res1, res2) =
164            join(ctx.ready(&self.publish), ctx.ready(self.inner.control.get_ref())).await;
165        let result = if let Err(e) = res1 {
166            if res2.is_err() {
167                Err(MqttError::Service(e.into()))
168            } else {
169                match ctx
170                    .call_nowait(self.inner.control.get_ref(), Control::error(e.into()))
171                    .await
172                {
173                    Ok(res) => {
174                        if res.disconnect {
175                            self.inner.sink.drop_sink();
176                        }
177                        Ok(())
178                    }
179                    Err(err) => Err(err),
180                }
181            }
182        } else {
183            res2
184        };
185
186        if result.is_ok() {
187            if let Some(pl) = self.inner.payload.take() {
188                if pl.ready().await == PayloadStatus::Ready {
189                    self.inner.payload.set(Some(pl));
190                } else {
191                    self.inner.sink.force_close();
192                }
193            }
194        }
195        result
196    }
197
198    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
199        if let Err(e) = self.publish.poll(cx) {
200            let inner = self.inner.clone();
201            ntex_rt::spawn(async move {
202                if let Ok(res) = inner.control.call_nowait(Control::error(e.into())).await {
203                    if res.disconnect {
204                        inner.sink.drop_sink();
205                    }
206                }
207            });
208        }
209        self.inner.control.poll(cx)
210    }
211
212    async fn shutdown(&self) {
213        self.inner.drop_payload(&PayloadError::Disconnected);
214        self.inner.sink.drop_sink();
215        let _ = self.inner.control.call(Control::closed()).await;
216
217        self.publish.shutdown().await;
218        self.inner.control.shutdown().await;
219    }
220
221    #[allow(clippy::await_holding_refcell_ref)]
222    async fn call(
223        &self,
224        request: DispatchItem<Rc<MqttShared>>,
225        ctx: ServiceCtx<'_, Self>,
226    ) -> Result<Self::Response, Self::Error> {
227        log::trace!("{}: Dispatch v5 packet: {:#?}", self.tag(), request);
228
229        match request {
230            DispatchItem::Item(Decoded::Publish(mut publish, payload, size)) => {
231                let info = self.inner.as_ref();
232                let packet_id = publish.packet_id;
233
234                if publish.topic.contains(['#', '+']) {
235                    return control(
236                        Control::proto_error(
237                            ProtocolError::generic_violation(
238                                "PUBLISH packet's topic name contains wildcard character [MQTT-3.3.2-2]"
239                            )
240                        ),
241                        &self.inner,
242                        ctx,
243                        0,
244                    ).await;
245                }
246
247                {
248                    let mut inner = info.info.borrow_mut();
249                    let state = &self.inner.sink;
250
251                    if let Some(pid) = packet_id {
252                        // check for receive maximum
253                        let receive_max = state.receive_max();
254                        if receive_max != 0 && inner.inflight.len() >= receive_max as usize {
255                            log::trace!(
256                                "{}: Receive maximum exceeded: max: {} in-flight: {}",
257                                self.tag(),
258                                receive_max,
259                                inner.inflight.len()
260                            );
261                            drop(inner);
262                            return control(
263                                Control::proto_error(
264                                    ProtocolError::violation(
265                                        DisconnectReasonCode::ReceiveMaximumExceeded,
266                                        "Number of in-flight messages exceeds set maximum [MQTT-3.3.4-7]"
267                                    )
268                                ),
269                                &self.inner,
270                                ctx,
271                                0,
272                            ).await;
273                        }
274
275                        // check max allowed qos
276                        if publish.qos > state.max_qos() {
277                            log::trace!(
278                                "{}: Max allowed QoS is violated, max {:?} provided {:?}",
279                                self.tag(),
280                                state.max_qos(),
281                                publish.qos
282                            );
283                            drop(inner);
284                            return control(
285                                Control::proto_error(ProtocolError::violation(
286                                    DisconnectReasonCode::QosNotSupported,
287                                    "PUBLISH QoS is higher than supported [MQTT-3.2.2-11]",
288                                )),
289                                &self.inner,
290                                ctx,
291                                0,
292                            )
293                            .await;
294                        }
295                        if publish.retain && !state.codec.retain_available() {
296                            log::trace!("{}: Retain is not available but is set", self.tag());
297                            drop(inner);
298                            return control(
299                                Control::proto_error(ProtocolError::violation(
300                                    DisconnectReasonCode::RetainNotSupported,
301                                    "RETAIN is not supported [MQTT-3.2.2-14]",
302                                )),
303                                &self.inner,
304                                ctx,
305                                0,
306                            )
307                            .await;
308                        }
309
310                        // check for duplicated packet id
311                        if !inner.inflight.insert(pid) {
312                            let _ = self.inner.sink.encode_packet(codec::Packet::PublishAck(
313                                codec::PublishAck {
314                                    packet_id: pid,
315                                    reason_code: codec::PublishAckReason::PacketIdentifierInUse,
316                                    ..Default::default()
317                                },
318                            ));
319                            return Ok(None);
320                        }
321                    }
322
323                    // handle topic aliases
324                    if let Some(alias) = publish.properties.topic_alias {
325                        if publish.topic.is_empty() {
326                            // lookup topic by provided alias
327                            match inner.aliases.get(&alias) {
328                                Some(aliased_topic) => publish.topic = aliased_topic.clone(),
329                                None => {
330                                    drop(inner);
331                                    return control(
332                                        Control::proto_error(ProtocolError::violation(
333                                            DisconnectReasonCode::TopicAliasInvalid,
334                                            "Unknown topic alias",
335                                        )),
336                                        &self.inner,
337                                        ctx,
338                                        0,
339                                    )
340                                    .await;
341                                }
342                            }
343                        } else {
344                            // record new alias
345                            match inner.aliases.entry(alias) {
346                                std::collections::hash_map::Entry::Occupied(mut entry) => {
347                                    if entry.get().as_str() != publish.topic.as_str() {
348                                        let mut topic = publish.topic.clone();
349                                        topic.trimdown();
350                                        entry.insert(topic);
351                                    }
352                                }
353                                std::collections::hash_map::Entry::Vacant(entry) => {
354                                    if alias.get() > state.topic_alias_max() {
355                                        drop(inner);
356                                        return control(
357                                                Control::proto_error(
358                                                    ProtocolError::generic_violation(
359                                                        "Topic alias is greater than max allowed [MQTT-3.2.2-17]",
360                                                    )
361                                                ),
362                                                &self.inner,
363                                            ctx,
364                                            0,
365                                            ).await;
366                                    }
367                                    let mut topic = publish.topic.clone();
368                                    topic.trimdown();
369                                    entry.insert(topic);
370                                }
371                            }
372                        }
373                    }
374
375                    if state.is_closed()
376                        && !self
377                            .handle_qos_after_disconnect
378                            .map(|max_qos| publish.qos <= max_qos)
379                            .unwrap_or_default()
380                    {
381                        return Ok(None);
382                    }
383                }
384
385                let payload = if publish.payload_size == payload.len() as u32 {
386                    Payload::from_bytes(payload)
387                } else {
388                    let (pl, sender) = Payload::from_stream(payload);
389                    self.inner.payload.set(Some(sender));
390                    pl
391                };
392
393                publish_fn(
394                    &self.publish,
395                    Publish::new(publish, payload, size),
396                    packet_id.map(|v| v.get()).unwrap_or(0),
397                    info,
398                    ctx,
399                )
400                .await
401            }
402            DispatchItem::Item(Decoded::PayloadChunk(buf, eof)) => {
403                if let Some(pl) = self.inner.payload.take() {
404                    pl.feed_data(buf);
405                    if eof {
406                        pl.feed_eof();
407                    } else {
408                        self.inner.payload.set(Some(pl));
409                    }
410                    Ok(None)
411                } else {
412                    control(
413                        Control::proto_error(ProtocolError::Decode(
414                            DecodeError::UnexpectedPayload,
415                        )),
416                        &self.inner,
417                        ctx,
418                        0,
419                    )
420                    .await
421                }
422            }
423            DispatchItem::Item(Decoded::Packet(Packet::PublishAck(packet), _)) => {
424                if let Err(err) = self.inner.sink.pkt_ack(Ack::Publish(packet)) {
425                    control(Control::proto_error(err), &self.inner, ctx, 0).await
426                } else {
427                    Ok(None)
428                }
429            }
430            DispatchItem::Item(Decoded::Packet(Packet::PublishReceived(pkt), _)) => {
431                if let Err(e) = self.inner.sink.pkt_ack(Ack::Receive(pkt)) {
432                    control(Control::proto_error(e), &self.inner, ctx, 0).await
433                } else {
434                    Ok(None)
435                }
436            }
437            DispatchItem::Item(Decoded::Packet(Packet::PublishRelease(ack), size)) => {
438                if self.inner.info.borrow().inflight.contains(&ack.packet_id) {
439                    control(Control::pubrel(ack, size), &self.inner, ctx, 0).await
440                } else {
441                    Ok(Some(Encoded::Packet(codec::Packet::PublishComplete(
442                        codec::PublishAck2 {
443                            packet_id: ack.packet_id,
444                            reason_code: codec::PublishAck2Reason::PacketIdNotFound,
445                            properties: codec::UserProperties::default(),
446                            reason_string: None,
447                        },
448                    ))))
449                }
450            }
451            DispatchItem::Item(Decoded::Packet(Packet::PublishComplete(pkt), _)) => {
452                if let Err(e) = self.inner.sink.pkt_ack(Ack::Complete(pkt)) {
453                    control(Control::proto_error(e), &self.inner, ctx, 0).await
454                } else {
455                    Ok(None)
456                }
457            }
458            DispatchItem::Item(Decoded::Packet(Packet::Auth(pkt), size)) => {
459                if self.inner.sink.is_closed() {
460                    return Ok(None);
461                }
462
463                control(Control::auth(pkt, size), &self.inner, ctx, 0).await
464            }
465            DispatchItem::Item(Decoded::Packet(Packet::PingRequest, _)) => {
466                control(Control::ping(), &self.inner, ctx, 0).await
467            }
468            DispatchItem::Item(Decoded::Packet(Packet::Disconnect(pkt), size)) => {
469                control(Control::remote_disconnect(pkt, size), &self.inner, ctx, 0).await
470            }
471            DispatchItem::Item(Decoded::Packet(Packet::Subscribe(pkt), size)) => {
472                if self.inner.sink.is_closed() {
473                    return Ok(None);
474                }
475
476                if pkt.topic_filters.iter().any(|(tf, _)| !crate::topic::is_valid(tf)) {
477                    return control(
478                        Control::proto_error(ProtocolError::generic_violation(
479                            "Topic filter is malformed [MQTT-4.7.1-*]",
480                        )),
481                        &self.inner,
482                        ctx,
483                        0,
484                    )
485                    .await;
486                }
487
488                if pkt.id.is_some() && !self.inner.sink.codec.sub_ids_available() {
489                    log::trace!(
490                        "{}: Subscription Identifiers are not supported but was set",
491                        self.tag()
492                    );
493                    return control(
494                        Control::proto_error(ProtocolError::violation(
495                            DisconnectReasonCode::SubscriptionIdentifiersNotSupported,
496                            "Subscription Identifiers are not supported",
497                        )),
498                        &self.inner,
499                        ctx,
500                        0,
501                    )
502                    .await;
503                }
504
505                // register inflight packet id
506                if !self.inner.info.borrow_mut().inflight.insert(pkt.packet_id) {
507                    // duplicated packet id
508                    let _ = self.inner.sink.encode_packet(codec::Packet::SubscribeAck(
509                        codec::SubscribeAck {
510                            packet_id: pkt.packet_id,
511                            status: pkt
512                                .topic_filters
513                                .iter()
514                                .map(|_| codec::SubscribeAckReason::PacketIdentifierInUse)
515                                .collect(),
516                            properties: codec::UserProperties::new(),
517                            reason_string: None,
518                        },
519                    ));
520                    return Ok(None);
521                }
522                let id = pkt.packet_id;
523                control(Control::subscribe(pkt, size), &self.inner, ctx, id.get()).await
524            }
525            DispatchItem::Item(Decoded::Packet(Packet::Unsubscribe(pkt), size)) => {
526                if self.inner.sink.is_closed() {
527                    return Ok(None);
528                }
529
530                if pkt.topic_filters.iter().any(|tf| !crate::topic::is_valid(tf)) {
531                    return control(
532                        Control::proto_error(ProtocolError::generic_violation(
533                            "Topic filter is malformed [MQTT-4.7.1-*]",
534                        )),
535                        &self.inner,
536                        ctx,
537                        0,
538                    )
539                    .await;
540                }
541
542                // register inflight packet id
543                if !self.inner.info.borrow_mut().inflight.insert(pkt.packet_id) {
544                    // duplicated packet id
545                    let _ = self.inner.sink.encode_packet(codec::Packet::UnsubscribeAck(
546                        codec::UnsubscribeAck {
547                            packet_id: pkt.packet_id,
548                            status: pkt
549                                .topic_filters
550                                .iter()
551                                .map(|_| codec::UnsubscribeAckReason::PacketIdentifierInUse)
552                                .collect(),
553                            properties: codec::UserProperties::new(),
554                            reason_string: None,
555                        },
556                    ));
557                    return Ok(None);
558                }
559                let id = pkt.packet_id;
560                control(Control::unsubscribe(pkt, size), &self.inner, ctx, id.get()).await
561            }
562            DispatchItem::Item(Decoded::Packet(_, _)) => Ok(None),
563            DispatchItem::EncoderError(err) => {
564                let err = ProtocolError::Encode(err);
565                self.inner.drop_payload(&err);
566                control(Control::proto_error(err), &self.inner, ctx, 0).await
567            }
568            DispatchItem::KeepAliveTimeout => {
569                self.inner.drop_payload(&ProtocolError::KeepAliveTimeout);
570                control(
571                    Control::proto_error(ProtocolError::KeepAliveTimeout),
572                    &self.inner,
573                    ctx,
574                    0,
575                )
576                .await
577            }
578            DispatchItem::ReadTimeout => {
579                self.inner.drop_payload(&ProtocolError::ReadTimeout);
580                control(Control::proto_error(ProtocolError::ReadTimeout), &self.inner, ctx, 0)
581                    .await
582            }
583            DispatchItem::DecoderError(err) => {
584                let err = ProtocolError::Decode(err);
585                self.inner.drop_payload(&err);
586                control(Control::proto_error(err), &self.inner, ctx, 0).await
587            }
588            DispatchItem::Disconnect(err) => {
589                self.inner.drop_payload(&PayloadError::Disconnected);
590                control(Control::peer_gone(err), &self.inner, ctx, 0).await
591            }
592            DispatchItem::WBackPressureEnabled => {
593                self.inner.sink.enable_wr_backpressure();
594                control(Control::wr_backpressure(true), &self.inner, ctx, 0).await
595            }
596            DispatchItem::WBackPressureDisabled => {
597                self.inner.sink.disable_wr_backpressure();
598                control(Control::wr_backpressure(false), &self.inner, ctx, 0).await
599            }
600        }
601    }
602}
603
604/// Publish service response future
605async fn publish_fn<'f, T, C, E>(
606    publish: &T,
607    pkt: Publish,
608    packet_id: u16,
609    inner: &'f Inner<C>,
610    ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
611) -> Result<Option<Encoded>, MqttError<E>>
612where
613    E: From<T::Error>,
614    T: Service<Publish, Response = PublishAck>,
615    PublishAck: TryFrom<T::Error, Error = E>,
616    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
617{
618    let qos2 = pkt.qos() == QoS::ExactlyOnce;
619    let ack = match ctx.call(publish, pkt).await {
620        Ok(ack) => ack,
621        Err(e) => {
622            if packet_id != 0 {
623                match PublishAck::try_from(e) {
624                    Ok(ack) => ack,
625                    Err(e) => return control(Control::error(e), inner, ctx, 0).await,
626                }
627            } else {
628                return control(Control::error(e.into()), inner, ctx, 0).await;
629            }
630        }
631    };
632    if let Some(id) = num::NonZeroU16::new(packet_id) {
633        let ack = if qos2 {
634            codec::Packet::PublishReceived(codec::PublishAck {
635                packet_id: id,
636                reason_code: ack.reason_code,
637                reason_string: ack.reason_string,
638                properties: ack.properties,
639            })
640        } else {
641            inner.info.borrow_mut().inflight.remove(&id);
642            codec::Packet::PublishAck(codec::PublishAck {
643                packet_id: id,
644                reason_code: ack.reason_code,
645                reason_string: ack.reason_string,
646                properties: ack.properties,
647            })
648        };
649        Ok(Some(Encoded::Packet(ack)))
650    } else {
651        Ok(None)
652    }
653}
654
655async fn control<'f, T, C, E>(
656    pkt: Control<E>,
657    inner: &'f Inner<C>,
658    ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
659    packet_id: u16,
660) -> Result<Option<Encoded>, MqttError<E>>
661where
662    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
663{
664    let mut error = matches!(pkt, Control::Error(_) | Control::ProtocolError(_));
665
666    let result = match ctx.call(inner.control.get_ref(), pkt).await {
667        Ok(result) => {
668            if let Some(id) = num::NonZeroU16::new(packet_id) {
669                inner.info.borrow_mut().inflight.remove(&id);
670            }
671            result
672        }
673        Err(err) => {
674            inner.drop_payload(&PayloadError::Service);
675
676            // do not handle nested error
677            if error {
678                inner.sink.drop_sink();
679                return Err(err);
680            } else {
681                // handle error from control service
682                match err {
683                    MqttError::Service(err) => {
684                        error = true;
685                        ctx.call(inner.control.get_ref(), Control::error(err)).await?
686                    }
687                    _ => return Err(err),
688                }
689            }
690        }
691    };
692
693    let response = if error {
694        if let Some(pkt) = result.packet {
695            let _ = inner.sink.encode_packet(pkt);
696        }
697        Ok(None)
698    } else {
699        Ok(result.packet.map(Encoded::Packet))
700    };
701
702    if result.disconnect {
703        inner.drop_payload(&PayloadError::Service);
704        inner.sink.drop_sink();
705    }
706    response
707}
708
#[cfg(test)]
mod tests {
    use ntex_io::{Io, testing::IoTest};
    use ntex_service::{cfg::SharedCfg, fn_service};
    use ntex_util::future::{Ready, lazy};

    use super::*;
    use crate::v5::MqttSink;

    /// Service error that can never be turned into a publish ack.
    #[derive(Debug)]
    struct TestError;

    impl TryFrom<TestError> for PublishAck {
        type Error = TestError;

        fn try_from(err: TestError) -> Result<Self, Self::Error> {
            Err(err)
        }
    }

    /// Write back-pressure items must toggle the readiness of the sink.
    #[ntex::test]
    async fn test_wr_backpressure() {
        let io = Io::new(IoTest::create().0, SharedCfg::new("DBG"));
        let shared = Rc::new(MqttShared::new(
            io.get_ref(),
            codec::Codec::default(),
            Default::default(),
        ));

        let publish_svc = fn_service(|p: Publish| Ready::Ok::<_, TestError>(p.ack()));
        let control_svc = fn_service(|_| {
            Ready::Ok::<_, MqttError<TestError>>(ControlAck { packet: None, disconnect: false })
        });
        let disp = Pipeline::new(Dispatcher::<_, _, _>::new(
            shared.clone(),
            publish_svc,
            control_svc,
            None,
        ));

        // the sink becomes ready once capacity is set
        let sink = MqttSink::new(shared.clone());
        assert!(!sink.is_ready());
        shared.set_cap(1);
        assert!(sink.is_ready());
        assert!(shared.wait_readiness().is_none());

        // enabling back-pressure makes the sink not ready and hands out waiters
        disp.call(DispatchItem::WBackPressureEnabled).await.unwrap();
        assert!(!sink.is_ready());
        let first = shared.wait_readiness();
        let second = shared.wait_readiness().unwrap();
        assert!(first.is_some());

        // after back-pressure is disabled the first waiter is ready, the second is not
        let first = first.unwrap();
        disp.call(DispatchItem::WBackPressureDisabled).await.unwrap();
        assert!(lazy(|cx| first.poll_recv(cx).is_ready()).await);
        assert!(!lazy(|cx| second.poll_recv(cx).is_ready()).await);
    }
}
764}