// ntex_mqtt/v5/dispatcher.rs

1use std::{cell::Cell, cell::RefCell, marker, num, rc::Rc, task::Context};
2
3use ntex_bytes::ByteString;
4use ntex_io::DispatchItem;
5use ntex_service::cfg::{Cfg, SharedCfg};
6use ntex_service::{self as service, Pipeline, Service, ServiceCtx, ServiceFactory};
7use ntex_util::services::inflight::InFlightService;
8use ntex_util::services::{buffer::BufferService, buffer::BufferServiceError};
9use ntex_util::{HashMap, HashSet, future::join};
10
11use crate::error::{DecodeError, HandshakeError, MqttError, PayloadError, ProtocolError};
12use crate::payload::{Payload, PayloadStatus, PlSender};
13use crate::{MqttServiceConfig, types::QoS};
14
15use super::Session;
16use super::codec::{self, Decoded, DisconnectReasonCode, Encoded, Packet};
17use super::control::{Control, ControlAck};
18use super::publish::{Publish, PublishAck};
19use super::shared::{Ack, MqttShared};
20
21/// MQTT 5 protocol dispatcher
22pub(super) fn factory<St, T, C, E>(
23    publish: T,
24    control: C,
25) -> impl ServiceFactory<
26    DispatchItem<Rc<MqttShared>>,
27    (SharedCfg, Session<St>),
28    Response = Option<Encoded>,
29    Error = MqttError<E>,
30    InitError = MqttError<E>,
31>
32where
33    St: 'static,
34    E: From<T::Error> + From<T::InitError> + From<C::Error> + From<C::InitError> + 'static,
35    T: ServiceFactory<Publish, Session<St>, Response = PublishAck> + 'static,
36    C: ServiceFactory<Control<E>, Session<St>, Response = ControlAck> + 'static,
37    PublishAck: TryFrom<T::Error, Error = E>,
38{
39    let factories = Rc::new((publish, control));
40
41    service::fn_factory_with_config(async move |(cfg, ses): (SharedCfg, Session<St>)| {
42        let cfg: Cfg<MqttServiceConfig> = cfg.get();
43
44        // create services
45        let sink = ses.sink().shared();
46        let (publish, control) =
47            join(factories.0.create(ses.clone()), factories.1.create(ses)).await;
48
49        let publish = publish.map_err(|e| MqttError::Service(e.into()))?;
50        let control = control.map_err(|e| MqttError::Service(e.into()))?;
51
52        let control = BufferService::new(
53            16,
54            // limit number of in-flight messages
55            InFlightService::new(1, control),
56        )
57        .map_err(|err| match err {
58            BufferServiceError::Service(e) => MqttError::Service(E::from(e)),
59            BufferServiceError::RequestCanceled => {
60                MqttError::Handshake(HandshakeError::Disconnected(None))
61            }
62        });
63
64        Ok(Dispatcher::<_, _, E>::new(sink, publish, control, cfg))
65    })
66}
67
68impl crate::inflight::SizedRequest for DispatchItem<Rc<MqttShared>> {
69    fn size(&self) -> u32 {
70        match self {
71            DispatchItem::Item(Decoded::Packet(_, size))
72            | DispatchItem::Item(Decoded::Publish(_, _, size)) => *size,
73            _ => 0,
74        }
75    }
76
77    fn is_publish(&self) -> bool {
78        matches!(self, DispatchItem::Item(Decoded::Publish(..)))
79    }
80
81    fn is_chunk(&self) -> bool {
82        matches!(self, DispatchItem::Item(Decoded::PayloadChunk(..)))
83    }
84}
85
/// Mqtt protocol dispatcher
///
/// Routes decoded MQTT v5 items either to the user's publish service
/// (`publish`) or to the control service stored in `inner`.
pub(crate) struct Dispatcher<T, C: Service<Control<E>>, E> {
    /// User service that handles incoming PUBLISH packets.
    publish: T,
    /// Reference-counted so it can be moved into tasks spawned from `poll`.
    inner: Rc<Inner<C>>,
    /// Per-connection MQTT service configuration.
    cfg: Cfg<MqttServiceConfig>,
    _t: marker::PhantomData<E>,
}
93
/// Dispatcher state shared (through `Rc`) between the dispatcher and the tasks it spawns.
struct Inner<C> {
    /// Control service pipeline.
    control: Pipeline<C>,
    /// Connection-wide shared MQTT state (codec flags, limits, outgoing sink).
    sink: Rc<MqttShared>,
    /// In-flight packet ids and topic aliases received from the peer.
    info: RefCell<PublishInfo>,
    /// Sender side of a PUBLISH payload that is still arriving in chunks, if any.
    payload: Cell<Option<PlSender>>,
}
100
/// Per-connection bookkeeping for packets received from the peer.
struct PublishInfo {
    /// Packet ids of received PUBLISH/SUBSCRIBE/UNSUBSCRIBE packets not yet fully handled.
    inflight: HashSet<num::NonZeroU16>,
    /// Topic alias -> topic name mappings established by the peer's PUBLISH packets.
    aliases: HashMap<num::NonZeroU16, ByteString>,
}
105
106impl<T, C, E> Dispatcher<T, C, E>
107where
108    E: From<T::Error>,
109    T: Service<Publish, Response = PublishAck>,
110    PublishAck: TryFrom<T::Error, Error = E>,
111    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
112{
113    fn new(sink: Rc<MqttShared>, publish: T, control: C, cfg: Cfg<MqttServiceConfig>) -> Self {
114        Self {
115            cfg,
116            publish,
117            inner: Rc::new(Inner {
118                sink,
119                payload: Cell::new(None),
120                control: Pipeline::new(control),
121                info: RefCell::new(PublishInfo {
122                    aliases: HashMap::default(),
123                    inflight: HashSet::default(),
124                }),
125            }),
126            _t: marker::PhantomData,
127        }
128    }
129
130    fn tag(&self) -> &'static str {
131        self.inner.sink.tag()
132    }
133}
134
135impl<C> Inner<C> {
136    fn drop_payload<PErr>(&self, err: &PErr)
137    where
138        PErr: Clone,
139        PayloadError: From<PErr>,
140    {
141        if let Some(pl) = self.payload.take() {
142            pl.set_error(err.clone().into());
143        }
144    }
145}
146
147impl<T, C, E> Service<DispatchItem<Rc<MqttShared>>> for Dispatcher<T, C, E>
148where
149    E: From<T::Error> + 'static,
150    T: Service<Publish, Response = PublishAck> + 'static,
151    PublishAck: TryFrom<T::Error, Error = E>,
152    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>> + 'static,
153{
154    type Response = Option<Encoded>;
155    type Error = MqttError<E>;
156
157    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
158        let (res1, res2) = join(ctx.ready(&self.publish), self.inner.control.ready()).await;
159        let result = if let Err(e) = res1 {
160            if res2.is_err() {
161                Err(MqttError::Service(e.into()))
162            } else {
163                match self.inner.control.call(Control::error(e.into())).await {
164                    Ok(res) => {
165                        if res.disconnect {
166                            self.inner.sink.drop_sink();
167                        }
168                        Ok(())
169                    }
170                    Err(err) => Err(err),
171                }
172            }
173        } else {
174            res2
175        };
176
177        if result.is_ok() {
178            if let Some(pl) = self.inner.payload.take() {
179                self.inner.payload.set(Some(pl.clone()));
180                if pl.ready().await != PayloadStatus::Ready {
181                    self.inner.sink.force_close();
182                }
183            }
184        }
185        result
186    }
187
188    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
189        if let Err(e) = self.publish.poll(cx) {
190            let inner = self.inner.clone();
191            ntex_rt::spawn(async move {
192                if let Ok(res) = inner.control.call(Control::error(e.into())).await {
193                    if res.disconnect {
194                        inner.sink.drop_sink();
195                    }
196                }
197            });
198        }
199        self.inner.control.poll(cx)
200    }
201
202    async fn shutdown(&self) {
203        log::trace!("{}: Shutdown v5 dispatcher", self.tag());
204        self.inner.drop_payload(&PayloadError::Disconnected);
205        self.inner.sink.drop_sink();
206        let _ = self.inner.control.call(Control::closed()).await;
207
208        self.publish.shutdown().await;
209        self.inner.control.shutdown().await;
210    }
211
212    #[allow(clippy::await_holding_refcell_ref)]
213    async fn call(
214        &self,
215        request: DispatchItem<Rc<MqttShared>>,
216        ctx: ServiceCtx<'_, Self>,
217    ) -> Result<Self::Response, Self::Error> {
218        log::trace!("{}: Dispatch v5 packet: {:#?}", self.tag(), request);
219
220        match request {
221            DispatchItem::Item(Decoded::Publish(mut publish, payload, size)) => {
222                let info = self.inner.as_ref();
223                let packet_id = publish.packet_id;
224
225                if publish.topic.contains(['#', '+']) {
226                    return control(
227                        Control::proto_error(
228                            ProtocolError::generic_violation(
229                                "PUBLISH packet's topic name contains wildcard character [MQTT-3.3.2-2]"
230                            )
231                        ),
232                        &self.inner,
233                        0,
234                    ).await;
235                }
236
237                {
238                    let mut inner = info.info.borrow_mut();
239                    let state = &self.inner.sink;
240
241                    if let Some(pid) = packet_id {
242                        // check for receive maximum
243                        let receive_max = state.receive_max();
244                        if receive_max != 0 && inner.inflight.len() >= receive_max as usize {
245                            log::trace!(
246                                "{}: Receive maximum exceeded: max: {} in-flight: {}",
247                                self.tag(),
248                                receive_max,
249                                inner.inflight.len()
250                            );
251                            drop(inner);
252                            return control(
253                                Control::proto_error(
254                                    ProtocolError::violation(
255                                        DisconnectReasonCode::ReceiveMaximumExceeded,
256                                        "Number of in-flight messages exceeds set maximum [MQTT-3.3.4-7]"
257                                    )
258                                ),
259                                &self.inner,
260                                0,
261                            ).await;
262                        }
263
264                        // check max allowed qos
265                        if publish.qos > state.max_qos() {
266                            log::trace!(
267                                "{}: Max allowed QoS is violated, max {:?} provided {:?}",
268                                self.tag(),
269                                state.max_qos(),
270                                publish.qos
271                            );
272                            drop(inner);
273                            return control(
274                                Control::proto_error(ProtocolError::violation(
275                                    DisconnectReasonCode::QosNotSupported,
276                                    "PUBLISH QoS is higher than supported [MQTT-3.2.2-11]",
277                                )),
278                                &self.inner,
279                                0,
280                            )
281                            .await;
282                        }
283                        if publish.retain && !state.codec.retain_available() {
284                            log::trace!("{}: Retain is not available but is set", self.tag());
285                            drop(inner);
286                            return control(
287                                Control::proto_error(ProtocolError::violation(
288                                    DisconnectReasonCode::RetainNotSupported,
289                                    "RETAIN is not supported [MQTT-3.2.2-14]",
290                                )),
291                                &self.inner,
292                                0,
293                            )
294                            .await;
295                        }
296
297                        // check for duplicated packet id
298                        if !inner.inflight.insert(pid) {
299                            let _ = self.inner.sink.encode_packet(codec::Packet::PublishAck(
300                                codec::PublishAck {
301                                    packet_id: pid,
302                                    reason_code: codec::PublishAckReason::PacketIdentifierInUse,
303                                    ..Default::default()
304                                },
305                            ));
306                            return Ok(None);
307                        }
308                    }
309
310                    // handle topic aliases
311                    if let Some(alias) = publish.properties.topic_alias {
312                        if publish.topic.is_empty() {
313                            // lookup topic by provided alias
314                            match inner.aliases.get(&alias) {
315                                Some(aliased_topic) => publish.topic = aliased_topic.clone(),
316                                None => {
317                                    drop(inner);
318                                    return control(
319                                        Control::proto_error(ProtocolError::violation(
320                                            DisconnectReasonCode::TopicAliasInvalid,
321                                            "Unknown topic alias",
322                                        )),
323                                        &self.inner,
324                                        0,
325                                    )
326                                    .await;
327                                }
328                            }
329                        } else {
330                            // record new alias
331                            match inner.aliases.entry(alias) {
332                                std::collections::hash_map::Entry::Occupied(mut entry) => {
333                                    if entry.get().as_str() != publish.topic.as_str() {
334                                        let mut topic = publish.topic.clone();
335                                        topic.trimdown();
336                                        entry.insert(topic);
337                                    }
338                                }
339                                std::collections::hash_map::Entry::Vacant(entry) => {
340                                    if alias.get() > state.topic_alias_max() {
341                                        drop(inner);
342                                        return control(
343                                                Control::proto_error(
344                                                    ProtocolError::generic_violation(
345                                                        "Topic alias is greater than max allowed [MQTT-3.2.2-17]",
346                                                    )
347                                                ),
348                                                &self.inner,
349                                            0,
350                                            ).await;
351                                    }
352                                    let mut topic = publish.topic.clone();
353                                    topic.trimdown();
354                                    entry.insert(topic);
355                                }
356                            }
357                        }
358                    }
359
360                    if state.is_closed()
361                        && !self
362                            .cfg
363                            .handle_qos_after_disconnect
364                            .map(|max_qos| publish.qos <= max_qos)
365                            .unwrap_or_default()
366                    {
367                        return Ok(None);
368                    }
369                }
370
371                let payload = if publish.payload_size == payload.len() as u32 {
372                    Payload::from_bytes(payload)
373                } else {
374                    let (pl, sender) =
375                        Payload::from_stream(payload, self.cfg.max_payload_buffer_size);
376                    self.inner.payload.set(Some(sender));
377                    pl
378                };
379
380                publish_fn(
381                    &self.publish,
382                    Publish::new(publish, payload, size),
383                    packet_id.map(|v| v.get()).unwrap_or(0),
384                    info,
385                    ctx,
386                )
387                .await
388            }
389            DispatchItem::Item(Decoded::PayloadChunk(buf, eof)) => {
390                if let Some(pl) = self.inner.payload.take() {
391                    pl.feed_data(buf);
392                    if eof {
393                        pl.feed_eof();
394                    } else {
395                        self.inner.payload.set(Some(pl));
396                    }
397                    Ok(None)
398                } else {
399                    control(
400                        Control::proto_error(ProtocolError::Decode(
401                            DecodeError::UnexpectedPayload,
402                        )),
403                        &self.inner,
404                        0,
405                    )
406                    .await
407                }
408            }
409            DispatchItem::Item(Decoded::Packet(Packet::PublishAck(packet), _)) => {
410                if let Err(err) = self.inner.sink.pkt_ack(Ack::Publish(packet)) {
411                    control(Control::proto_error(err), &self.inner, 0).await
412                } else {
413                    Ok(None)
414                }
415            }
416            DispatchItem::Item(Decoded::Packet(Packet::PublishReceived(pkt), _)) => {
417                if let Err(e) = self.inner.sink.pkt_ack(Ack::Receive(pkt)) {
418                    control(Control::proto_error(e), &self.inner, 0).await
419                } else {
420                    Ok(None)
421                }
422            }
423            DispatchItem::Item(Decoded::Packet(Packet::PublishRelease(ack), size)) => {
424                if self.inner.info.borrow().inflight.contains(&ack.packet_id) {
425                    control(Control::pubrel(ack, size), &self.inner, 0).await
426                } else {
427                    Ok(Some(Encoded::Packet(codec::Packet::PublishComplete(
428                        codec::PublishAck2 {
429                            packet_id: ack.packet_id,
430                            reason_code: codec::PublishAck2Reason::PacketIdNotFound,
431                            properties: codec::UserProperties::default(),
432                            reason_string: None,
433                        },
434                    ))))
435                }
436            }
437            DispatchItem::Item(Decoded::Packet(Packet::PublishComplete(pkt), _)) => {
438                if let Err(e) = self.inner.sink.pkt_ack(Ack::Complete(pkt)) {
439                    control(Control::proto_error(e), &self.inner, 0).await
440                } else {
441                    Ok(None)
442                }
443            }
444            DispatchItem::Item(Decoded::Packet(Packet::Auth(pkt), size)) => {
445                if self.inner.sink.is_closed() {
446                    return Ok(None);
447                }
448
449                control(Control::auth(pkt, size), &self.inner, 0).await
450            }
451            DispatchItem::Item(Decoded::Packet(Packet::PingRequest, _)) => {
452                control(Control::ping(), &self.inner, 0).await
453            }
454            DispatchItem::Item(Decoded::Packet(Packet::Disconnect(pkt), size)) => {
455                control(Control::remote_disconnect(pkt, size), &self.inner, 0).await
456            }
457            DispatchItem::Item(Decoded::Packet(Packet::Subscribe(pkt), size)) => {
458                if self.inner.sink.is_closed() {
459                    return Ok(None);
460                }
461
462                if pkt.topic_filters.iter().any(|(tf, _)| !crate::topic::is_valid(tf)) {
463                    return control(
464                        Control::proto_error(ProtocolError::generic_violation(
465                            "Topic filter is malformed [MQTT-4.7.1-*]",
466                        )),
467                        &self.inner,
468                        0,
469                    )
470                    .await;
471                }
472
473                if pkt.id.is_some() && !self.inner.sink.codec.sub_ids_available() {
474                    log::trace!(
475                        "{}: Subscription Identifiers are not supported but was set",
476                        self.tag()
477                    );
478                    return control(
479                        Control::proto_error(ProtocolError::violation(
480                            DisconnectReasonCode::SubscriptionIdentifiersNotSupported,
481                            "Subscription Identifiers are not supported",
482                        )),
483                        &self.inner,
484                        0,
485                    )
486                    .await;
487                }
488
489                // register inflight packet id
490                if !self.inner.info.borrow_mut().inflight.insert(pkt.packet_id) {
491                    // duplicated packet id
492                    let _ = self.inner.sink.encode_packet(codec::Packet::SubscribeAck(
493                        codec::SubscribeAck {
494                            packet_id: pkt.packet_id,
495                            status: pkt
496                                .topic_filters
497                                .iter()
498                                .map(|_| codec::SubscribeAckReason::PacketIdentifierInUse)
499                                .collect(),
500                            properties: codec::UserProperties::new(),
501                            reason_string: None,
502                        },
503                    ));
504                    return Ok(None);
505                }
506                let id = pkt.packet_id;
507                control(Control::subscribe(pkt, size), &self.inner, id.get()).await
508            }
509            DispatchItem::Item(Decoded::Packet(Packet::Unsubscribe(pkt), size)) => {
510                if self.inner.sink.is_closed() {
511                    return Ok(None);
512                }
513
514                if pkt.topic_filters.iter().any(|tf| !crate::topic::is_valid(tf)) {
515                    return control(
516                        Control::proto_error(ProtocolError::generic_violation(
517                            "Topic filter is malformed [MQTT-4.7.1-*]",
518                        )),
519                        &self.inner,
520                        0,
521                    )
522                    .await;
523                }
524
525                // register inflight packet id
526                if !self.inner.info.borrow_mut().inflight.insert(pkt.packet_id) {
527                    // duplicated packet id
528                    let _ = self.inner.sink.encode_packet(codec::Packet::UnsubscribeAck(
529                        codec::UnsubscribeAck {
530                            packet_id: pkt.packet_id,
531                            status: pkt
532                                .topic_filters
533                                .iter()
534                                .map(|_| codec::UnsubscribeAckReason::PacketIdentifierInUse)
535                                .collect(),
536                            properties: codec::UserProperties::new(),
537                            reason_string: None,
538                        },
539                    ));
540                    return Ok(None);
541                }
542                let id = pkt.packet_id;
543                control(Control::unsubscribe(pkt, size), &self.inner, id.get()).await
544            }
545            DispatchItem::Item(Decoded::Packet(_, _)) => Ok(None),
546            DispatchItem::EncoderError(err) => {
547                let err = ProtocolError::Encode(err);
548                self.inner.drop_payload(&err);
549                control(Control::proto_error(err), &self.inner, 0).await
550            }
551            DispatchItem::KeepAliveTimeout => {
552                self.inner.drop_payload(&ProtocolError::KeepAliveTimeout);
553                control(Control::proto_error(ProtocolError::KeepAliveTimeout), &self.inner, 0)
554                    .await
555            }
556            DispatchItem::ReadTimeout => {
557                self.inner.drop_payload(&ProtocolError::ReadTimeout);
558                control(Control::proto_error(ProtocolError::ReadTimeout), &self.inner, 0).await
559            }
560            DispatchItem::DecoderError(err) => {
561                let err = ProtocolError::Decode(err);
562                self.inner.drop_payload(&err);
563                control(Control::proto_error(err), &self.inner, 0).await
564            }
565            DispatchItem::Disconnect(err) => {
566                self.inner.drop_payload(&PayloadError::Disconnected);
567                control(Control::peer_gone(err), &self.inner, 0).await
568            }
569            DispatchItem::WBackPressureEnabled => {
570                self.inner.sink.enable_wr_backpressure();
571                control(Control::wr_backpressure(true), &self.inner, 0).await
572            }
573            DispatchItem::WBackPressureDisabled => {
574                self.inner.sink.disable_wr_backpressure();
575                control(Control::wr_backpressure(false), &self.inner, 0).await
576            }
577        }
578    }
579}
580
581/// Publish service response future
582async fn publish_fn<'f, T, C, E>(
583    publish: &T,
584    pkt: Publish,
585    packet_id: u16,
586    inner: &'f Inner<C>,
587    ctx: ServiceCtx<'f, Dispatcher<T, C, E>>,
588) -> Result<Option<Encoded>, MqttError<E>>
589where
590    E: From<T::Error>,
591    T: Service<Publish, Response = PublishAck>,
592    PublishAck: TryFrom<T::Error, Error = E>,
593    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
594{
595    let qos2 = pkt.qos() == QoS::ExactlyOnce;
596    let ack = match ctx.call(publish, pkt).await {
597        Ok(ack) => ack,
598        Err(e) => {
599            if packet_id != 0 {
600                match PublishAck::try_from(e) {
601                    Ok(ack) => ack,
602                    Err(e) => return control(Control::error(e), inner, 0).await,
603                }
604            } else {
605                return control(Control::error(e.into()), inner, 0).await;
606            }
607        }
608    };
609    if let Some(id) = num::NonZeroU16::new(packet_id) {
610        let ack = if qos2 {
611            codec::Packet::PublishReceived(codec::PublishAck {
612                packet_id: id,
613                reason_code: ack.reason_code,
614                reason_string: ack.reason_string,
615                properties: ack.properties,
616            })
617        } else {
618            inner.info.borrow_mut().inflight.remove(&id);
619            codec::Packet::PublishAck(codec::PublishAck {
620                packet_id: id,
621                reason_code: ack.reason_code,
622                reason_string: ack.reason_string,
623                properties: ack.properties,
624            })
625        };
626        Ok(Some(Encoded::Packet(ack)))
627    } else {
628        Ok(None)
629    }
630}
631
632async fn control<C, E>(
633    pkt: Control<E>,
634    inner: &Inner<C>,
635    packet_id: u16,
636) -> Result<Option<Encoded>, MqttError<E>>
637where
638    C: Service<Control<E>, Response = ControlAck, Error = MqttError<E>>,
639{
640    let mut error = matches!(pkt, Control::Error(_) | Control::ProtocolError(_));
641
642    let result = match inner.control.call(pkt).await {
643        Ok(result) => {
644            if let Some(id) = num::NonZeroU16::new(packet_id) {
645                inner.info.borrow_mut().inflight.remove(&id);
646            }
647            result
648        }
649        Err(err) => {
650            inner.drop_payload(&PayloadError::Service);
651
652            // do not handle nested error
653            if error {
654                inner.sink.drop_sink();
655                return Err(err);
656            } else {
657                // handle error from control service
658                match err {
659                    MqttError::Service(err) => {
660                        error = true;
661                        inner.control.call(Control::error(err)).await?
662                    }
663                    _ => return Err(err),
664                }
665            }
666        }
667    };
668
669    let response = if error {
670        if let Some(pkt) = result.packet {
671            let _ = inner.sink.encode_packet(pkt);
672        }
673        Ok(None)
674    } else {
675        Ok(result.packet.map(Encoded::Packet))
676    };
677
678    if result.disconnect {
679        inner.drop_payload(&PayloadError::Service);
680        inner.sink.drop_sink();
681    }
682    response
683}
684
#[cfg(test)]
mod tests {
    use ntex_io::{Io, testing::IoTest};
    use ntex_service::{cfg::SharedCfg, fn_service};
    use ntex_util::future::{Ready, lazy};

    use super::*;
    use crate::v5::MqttSink;

    /// Error type used by the test publish/control services.
    #[derive(Debug)]
    struct TestError;

    /// Never converts a `TestError` into a `PublishAck`; the error is always handed back.
    impl TryFrom<TestError> for PublishAck {
        type Error = TestError;

        fn try_from(err: TestError) -> Result<Self, Self::Error> {
            Err(err)
        }
    }

    /// Checks that `WBackPressureEnabled` / `WBackPressureDisabled` toggle sink
    /// readiness, and how the receivers returned by `wait_readiness()` react.
    #[ntex::test]
    async fn test_wr_backpressure() {
        let io = Io::new(IoTest::create().0, SharedCfg::new("DBG"));
        let codec = codec::Codec::default();
        let shared = Rc::new(MqttShared::new(io.get_ref(), codec, Default::default()));

        // Publish service acks everything; control service acks with no packet
        // and no disconnect.
        let disp = Pipeline::new(Dispatcher::<_, _, _>::new(
            shared.clone(),
            fn_service(|p: Publish| Ready::Ok::<_, TestError>(p.ack())),
            fn_service(|_| {
                Ready::Ok::<_, MqttError<TestError>>(ControlAck {
                    packet: None,
                    disconnect: false,
                })
            }),
            Default::default(),
        ));

        // The sink is not ready until a capacity is set; without back-pressure
        // there is nothing to wait for.
        let sink = MqttSink::new(shared.clone());
        assert!(!sink.is_ready());
        shared.set_cap(1);
        assert!(sink.is_ready());
        assert!(shared.wait_readiness().is_none());

        // Enabling back-pressure makes the sink not ready and readiness waitable.
        disp.call(DispatchItem::WBackPressureEnabled).await.unwrap();
        assert!(!sink.is_ready());
        let rx = shared.wait_readiness();
        let rx2 = shared.wait_readiness().unwrap();
        assert!(rx.is_some());

        // After disabling back-pressure, the receiver from the first
        // `wait_readiness()` call is ready while the second one is not.
        let rx = rx.unwrap();
        disp.call(DispatchItem::WBackPressureDisabled).await.unwrap();
        assert!(lazy(|cx| rx.poll_recv(cx).is_ready()).await);
        assert!(!lazy(|cx| rx2.poll_recv(cx).is_ready()).await);
    }
}