Skip to main content

ntex_mqtt/v3/
shared.rs

1#![allow(clippy::type_complexity)]
2use std::{cell::Cell, cell::RefCell, collections::VecDeque, fmt, num, rc::Rc};
3
4use ntex_bytes::{Bytes, BytesMut};
5use ntex_codec::{Decoder, Encoder};
6use ntex_io::IoRef;
7use ntex_util::{HashSet, channel::pool};
8
9use crate::error::{DecodeError, EncodeError, ProtocolError, SendPacketError};
10use crate::types::packet_type;
11use crate::v3::codec::{self, Encoded, Publish};
12
13#[derive(Debug)]
14pub(super) enum Ack {
15    Publish(num::NonZeroU16),
16    Receive(num::NonZeroU16),
17    Complete(num::NonZeroU16),
18    Subscribe { packet_id: num::NonZeroU16, status: Vec<codec::SubscribeReturnCode> },
19    Unsubscribe(num::NonZeroU16),
20}
21
22#[derive(Copy, Clone, Debug)]
23pub(super) enum AckType {
24    Publish,
25    Receive,
26    Complete,
27    Subscribe,
28    Unsubscribe,
29}
30
31pub(super) struct MqttSinkPool {
32    queue: pool::Pool<Ack>,
33    pub(super) waiters: pool::Pool<()>,
34}
35
36impl Default for MqttSinkPool {
37    fn default() -> Self {
38        Self { queue: pool::new(), waiters: pool::new() }
39    }
40}
41
42bitflags::bitflags! {
43    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
44    struct Flags: u8 {
45        const CLIENT         = 0b0000_0001;
46        const WRB_ENABLED    = 0b0000_0010; // write-backpressure
47        const ON_PUBLISH_ACK = 0b0000_0100; // on-publish-ack callback
48
49        const DISCONNECT     = 0b0100_0000; // Disconnect frame is sent
50        const STOPPED        = 0b1000_0000; // DispatchItem::Stop() is sent
51    }
52}
53
54pub struct MqttShared {
55    io: IoRef,
56    cap: Cell<usize>,
57    queues: RefCell<MqttSharedQueues>,
58    inflight_idx: Cell<u16>,
59    flags: Cell<Flags>,
60    encode_error: Cell<Option<EncodeError>>,
61    streaming_waiter: Cell<Option<pool::Sender<()>>>,
62    streaming_remaining: Cell<Option<num::NonZeroU32>>,
63    on_publish_ack: Cell<Option<Box<dyn Fn(num::NonZeroU16, bool)>>>,
64    pub(super) codec: codec::Codec,
65    pub(super) pool: Rc<MqttSinkPool>,
66}
67
68impl fmt::Debug for MqttShared {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        f.debug_struct("MqttShared").finish()
71    }
72}
73
74#[derive(Debug)]
75struct MqttSharedQueues {
76    inflight: VecDeque<(num::NonZeroU16, Option<pool::Sender<Ack>>, AckType)>,
77    inflight_ids: HashSet<num::NonZeroU16>,
78    waiters: VecDeque<pool::Sender<()>>,
79    rx: Option<pool::Receiver<Ack>>,
80}
81
82impl MqttShared {
83    pub(super) fn new(
84        io: IoRef,
85        codec: codec::Codec,
86        client: bool,
87        pool: Rc<MqttSinkPool>,
88    ) -> Self {
89        Self {
90            io,
91            codec,
92            pool,
93            cap: Cell::new(0),
94            flags: Cell::new(if client { Flags::CLIENT } else { Flags::empty() }),
95            queues: RefCell::new(MqttSharedQueues {
96                inflight: VecDeque::with_capacity(8),
97                inflight_ids: HashSet::default(),
98                waiters: VecDeque::new(),
99                rx: None,
100            }),
101            inflight_idx: Cell::new(0),
102            encode_error: Cell::new(None),
103            streaming_waiter: Cell::new(None),
104            streaming_remaining: Cell::new(None),
105            on_publish_ack: Cell::new(None),
106        }
107    }
108
109    pub(super) fn tag(&self) -> &'static str {
110        self.io.tag()
111    }
112
113    pub(super) fn close(&self) {
114        if self.flags.get().contains(Flags::CLIENT) && !self.is_disconnect_sent() {
115            let _ = self.encode_packet(codec::Packet::Disconnect);
116        }
117        self.io.close();
118        self.clear_queues();
119    }
120
121    pub(super) fn force_close(&self) {
122        self.io.force_close();
123        self.clear_queues();
124    }
125
126    pub(super) fn streaming_dropped(&self) {
127        self.force_close();
128        self.encode_error.set(Some(EncodeError::PublishIncomplete));
129    }
130
131    pub(super) fn is_streaming(&self) -> bool {
132        self.streaming_remaining.get().is_some()
133    }
134
135    pub(super) fn is_closed(&self) -> bool {
136        self.io.is_closed()
137    }
138
139    pub(super) fn is_ready(&self) -> bool {
140        self.credit() > 0 && !self.flags.get().contains(Flags::WRB_ENABLED)
141    }
142
143    pub(super) fn is_disconnect_sent(&self) -> bool {
144        let mut flags = self.flags.get();
145        let sent = flags.contains(Flags::DISCONNECT);
146        if !sent {
147            flags.insert(Flags::DISCONNECT);
148            self.flags.set(flags);
149        }
150        sent
151    }
152
153    pub(super) fn is_dispatcher_stopped(&self) -> bool {
154        let mut flags = self.flags.get();
155        let stopped = flags.contains(Flags::STOPPED);
156        if !stopped {
157            flags.insert(Flags::STOPPED);
158            self.flags.set(flags);
159        }
160        stopped
161    }
162
163    pub(super) fn credit(&self) -> usize {
164        self.cap.get().saturating_sub(self.queues.borrow().inflight.len())
165    }
166
167    pub(super) fn next_id(&self) -> num::NonZeroU16 {
168        let idx = self.inflight_idx.get() + 1;
169        let idx = if idx == u16::MAX {
170            self.inflight_idx.set(0);
171            u16::MAX
172        } else {
173            self.inflight_idx.set(idx);
174            idx
175        };
176        num::NonZeroU16::new(idx).unwrap()
177    }
178
179    /// publish packet id
180    pub(super) fn set_publish_id(&self, pkt: &mut Publish) -> num::NonZeroU16 {
181        if let Some(idx) = pkt.packet_id {
182            idx
183        } else {
184            let idx = self.next_id();
185            pkt.packet_id = Some(idx);
186            idx
187        }
188    }
189
190    pub(super) fn set_cap(&self, cap: usize) {
191        let mut queues = self.queues.borrow_mut();
192
193        // wake up queued request (receive max limit)
194        'outer: for _ in 0..cap {
195            while let Some(tx) = queues.waiters.pop_front() {
196                if tx.send(()).is_ok() {
197                    continue 'outer;
198                }
199            }
200            break;
201        }
202        self.cap.set(cap);
203    }
204
205    pub(super) fn set_publish_ack(&self, f: Box<dyn Fn(num::NonZeroU16, bool)>) {
206        let mut flags = self.flags.get();
207        flags.insert(Flags::ON_PUBLISH_ACK);
208        self.flags.set(flags);
209        self.on_publish_ack.set(Some(f));
210    }
211
212    pub(super) fn encode_packet(&self, pkt: codec::Packet) -> Result<(), EncodeError> {
213        self.check_streaming()?;
214        self.io.encode(pkt.into(), &self.codec)
215    }
216
217    pub(super) fn encode_publish(
218        &self,
219        pkt: Publish,
220        payload: Option<Bytes>,
221    ) -> Result<(), EncodeError> {
222        self.check_streaming()?;
223        self.enable_streaming(&pkt, payload.as_ref());
224        self.io.encode(Encoded::Publish(pkt, payload), &self.codec)
225    }
226
227    pub(super) fn encode_publish_payload(&self, payload: Bytes) -> Result<bool, EncodeError> {
228        if let Some(remaining) = self.streaming_remaining.get() {
229            let len = payload.len() as u32;
230            if len > remaining.get() {
231                self.force_close();
232                Err(EncodeError::OverPublishSize)
233            } else {
234                self.io.encode(Encoded::PayloadChunk(payload), &self.codec)?;
235                self.streaming_remaining.set(num::NonZeroU32::new(remaining.get() - len));
236                Ok(self.streaming_remaining.get().is_some())
237            }
238        } else {
239            Err(EncodeError::UnexpectedPayload)
240        }
241    }
242
243    fn clear_queues(&self) {
244        let mut queues = self.queues.borrow_mut();
245        queues.waiters.clear();
246
247        if let Some(cb) = self.on_publish_ack.take() {
248            for (idx, tx, _) in queues.inflight.drain(..) {
249                if tx.is_none() {
250                    (*cb)(idx, true);
251                }
252            }
253        } else {
254            queues.inflight.clear();
255        }
256    }
257
258    pub(super) fn enable_wr_backpressure(&self) {
259        let mut flags = self.flags.get();
260        flags.insert(Flags::WRB_ENABLED);
261        self.flags.set(flags);
262    }
263
264    pub(super) fn disable_wr_backpressure(&self) {
265        let mut flags = self.flags.get();
266        flags.remove(Flags::WRB_ENABLED);
267        self.flags.set(flags);
268
269        // streaming waiter
270        if let Some(tx) = self.streaming_waiter.take()
271            && tx.send(()).is_ok()
272        {
273            return;
274        }
275
276        // check if there are waiters
277        let mut queues = self.queues.borrow_mut();
278        if queues.inflight.len() < self.cap.get() {
279            let mut num = self.cap.get() - queues.inflight.len();
280            while num > 0 {
281                if let Some(tx) = queues.waiters.pop_front() {
282                    if tx.send(()).is_ok() {
283                        num -= 1;
284                    }
285                } else {
286                    break;
287                }
288            }
289        }
290    }
291
292    pub(super) async fn want_payload_stream(&self) -> Result<(), SendPacketError> {
293        if self.is_closed() {
294            Err(SendPacketError::Disconnected)
295        } else if self.flags.get().contains(Flags::WRB_ENABLED) {
296            let (tx, rx) = self.pool.waiters.channel();
297            self.streaming_waiter.set(Some(tx));
298            if rx.await.is_ok() {
299                Ok(())
300            } else {
301                Err(SendPacketError::Disconnected)
302            }
303        } else {
304            Ok(())
305        }
306    }
307
308    fn check_streaming(&self) -> Result<(), EncodeError> {
309        if self.streaming_remaining.get().is_some() {
310            Err(EncodeError::ExpectPayload)
311        } else {
312            Ok(())
313        }
314    }
315
316    fn enable_streaming(&self, pkt: &Publish, payload: Option<&Bytes>) {
317        let len = payload.map_or(0, Bytes::len);
318        self.streaming_remaining.set(num::NonZeroU32::new(pkt.payload_size - len as u32));
319    }
320
321    pub(super) fn pkt_ack(&self, ack: Ack) -> Result<(), ProtocolError> {
322        self.pkt_ack_inner(ack).inspect_err(|_| {
323            self.close();
324        })
325    }
326
327    fn pkt_ack_inner(&self, pkt: Ack) -> Result<(), ProtocolError> {
328        let mut queues = self.queues.borrow_mut();
329
330        // check ack order
331        if let Some((idx, tx, tp)) = queues.inflight.pop_front() {
332            if idx != pkt.packet_id() {
333                log::trace!(
334                    "MQTT protocol error: packet id order does not match; expected {}, got: {}",
335                    idx,
336                    pkt.packet_id()
337                );
338                Err(ProtocolError::packet_id_mismatch())
339            } else if matches!(pkt, Ack::Receive(_)) {
340                // get publish ack channel
341                log::trace!("Ack packet with id: {}", pkt.packet_id());
342
343                if let Some(tx) = tx {
344                    let _ = tx.send(pkt);
345                }
346                let (tx, rx) = self.pool.queue.channel();
347                queues.rx = Some(rx);
348                queues.inflight.push_back((idx, Some(tx), AckType::Complete));
349                Ok(())
350            } else if matches!(pkt, Ack::Complete(_)) {
351                // get publish ack channel
352                log::trace!("Ack packet with id: {}", pkt.packet_id());
353                queues.inflight_ids.remove(&pkt.packet_id());
354                queues.rx.take();
355
356                if let Some(tx) = tx {
357                    let _ = tx.send(pkt);
358                }
359
360                // wake up queued request (receive max limit)
361                while let Some(tx) = queues.waiters.pop_front() {
362                    if tx.send(()).is_ok() {
363                        break;
364                    }
365                }
366                Ok(())
367            } else {
368                // get publish ack channel
369                log::trace!("Ack packet with id: {}", pkt.packet_id());
370                queues.inflight_ids.remove(&pkt.packet_id());
371
372                if pkt.is_match(tp) {
373                    if let Some(tx) = tx {
374                        let _ = tx.send(pkt);
375                    } else {
376                        let cb = self.on_publish_ack.take().unwrap();
377                        (*cb)(pkt.packet_id(), false);
378                        self.on_publish_ack.set(Some(cb));
379                    }
380
381                    // wake up queued request (receive max limit)
382                    while let Some(tx) = queues.waiters.pop_front() {
383                        if tx.send(()).is_ok() {
384                            break;
385                        }
386                    }
387                    Ok(())
388                } else {
389                    log::trace!("MQTT protocol error, unexpected packet");
390                    Err(ProtocolError::unexpected_packet(pkt.packet_type(), tp.expected_str()))
391                }
392            }
393        } else {
394            log::trace!("Unexpected PUBACK packet: {:?}", pkt.packet_id());
395            Err(ProtocolError::generic_violation(
396                "Received PUBACK packet while there are no unacknowledged PUBLISH packets",
397            ))
398        }
399    }
400
401    /// Register ack in response channel
402    pub(super) fn wait_response(
403        &self,
404        id: num::NonZeroU16,
405        ack: AckType,
406    ) -> Result<pool::Receiver<Ack>, SendPacketError> {
407        let mut queues = self.queues.borrow_mut();
408        if queues.inflight_ids.contains(&id) {
409            Err(SendPacketError::PacketIdInUse(id))
410        } else {
411            let (tx, rx) = self.pool.queue.channel();
412            queues.inflight.push_back((id, Some(tx), ack));
413            queues.inflight_ids.insert(id);
414            Ok(rx)
415        }
416    }
417
418    /// Register ack in response channel
419    pub(super) fn wait_publish_response(
420        &self,
421        id: num::NonZeroU16,
422        ack: AckType,
423        pkt: Publish,
424        payload: Option<Bytes>,
425    ) -> Result<pool::Receiver<Ack>, SendPacketError> {
426        self.check_streaming()?;
427        self.enable_streaming(&pkt, payload.as_ref());
428
429        let mut queues = self.queues.borrow_mut();
430        if queues.inflight_ids.contains(&id) {
431            Err(SendPacketError::PacketIdInUse(id))
432        } else {
433            match self.io.encode(Encoded::Publish(pkt, payload), &self.codec) {
434                Ok(()) => {
435                    let (tx, rx) = self.pool.queue.channel();
436                    queues.inflight.push_back((id, Some(tx), ack));
437                    queues.inflight_ids.insert(id);
438                    Ok(rx)
439                }
440                Err(e) => Err(SendPacketError::Encode(e)),
441            }
442        }
443    }
444
445    /// Register ack in response channel
446    pub(super) fn wait_publish_response_no_block(
447        &self,
448        id: num::NonZeroU16,
449        ack: AckType,
450        pkt: Publish,
451        payload: Option<Bytes>,
452    ) -> Result<(), SendPacketError> {
453        self.check_streaming()?;
454        self.enable_streaming(&pkt, payload.as_ref());
455
456        let mut queues = self.queues.borrow_mut();
457        if queues.inflight_ids.contains(&id) {
458            Err(SendPacketError::PacketIdInUse(id))
459        } else {
460            match self.io.encode(Encoded::Publish(pkt, payload), &self.codec) {
461                Ok(()) => {
462                    assert!(
463                        self.flags.get().contains(Flags::ON_PUBLISH_ACK),
464                        "Publish ack callback is not set"
465                    );
466                    queues.inflight.push_back((id, None, ack));
467                    queues.inflight_ids.insert(id);
468                    Ok(())
469                }
470                Err(e) => Err(SendPacketError::Encode(e)),
471            }
472        }
473    }
474
475    pub(super) fn wait_readiness(&self) -> Option<pool::Receiver<()>> {
476        let mut queues = self.queues.borrow_mut();
477
478        if queues.inflight.len() >= self.cap.get()
479            || self.flags.get().contains(Flags::WRB_ENABLED)
480        {
481            let (tx, rx) = self.pool.waiters.channel();
482            queues.waiters.push_back(tx);
483            Some(rx)
484        } else {
485            None
486        }
487    }
488
489    /// Register ack in response channel
490    pub(super) fn release_publish(
491        &self,
492        id: num::NonZeroU16,
493    ) -> Result<pool::Receiver<Ack>, SendPacketError> {
494        let Some(rx) = self.queues.borrow_mut().rx.take() else {
495            return Err(SendPacketError::UnexpectedRelease);
496        };
497        match self.io.encode(
498            Encoded::Packet(codec::Packet::PublishRelease { packet_id: id }),
499            &self.codec,
500        ) {
501            Ok(()) => Ok(rx),
502            Err(e) => Err(SendPacketError::Encode(e)),
503        }
504    }
505}
506
507impl Encoder for MqttShared {
508    type Item = Encoded;
509    type Error = EncodeError;
510
511    #[inline]
512    fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
513        self.codec.encode(item, dst)
514    }
515}
516
517impl Decoder for MqttShared {
518    type Item = codec::Decoded;
519    type Error = DecodeError;
520
521    #[inline]
522    fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
523        self.codec.decode(src)
524    }
525}
526
527impl Ack {
528    pub(super) fn packet_type(&self) -> u8 {
529        match self {
530            Ack::Publish(_) => packet_type::PUBACK,
531            Ack::Receive(_) => packet_type::PUBREC,
532            Ack::Complete(_) => packet_type::PUBCOMP,
533            Ack::Subscribe { .. } => packet_type::SUBACK,
534            Ack::Unsubscribe(_) => packet_type::UNSUBACK,
535        }
536    }
537
538    pub(super) fn packet_id(&self) -> num::NonZeroU16 {
539        match self {
540            Ack::Subscribe { packet_id, .. } => *packet_id,
541            Ack::Publish(id) | Ack::Receive(id) | Ack::Complete(id) | Ack::Unsubscribe(id) => {
542                *id
543            }
544        }
545    }
546
547    pub(super) fn subscribe(self) -> Vec<codec::SubscribeReturnCode> {
548        if let Ack::Subscribe { status, .. } = self {
549            status
550        } else {
551            panic!()
552        }
553    }
554
555    pub(super) fn is_match(&self, tp: AckType) -> bool {
556        match (self, tp) {
557            (Ack::Publish(_), AckType::Publish)
558            | (Ack::Receive(_), AckType::Receive)
559            | (Ack::Subscribe { .. }, AckType::Subscribe)
560            | (Ack::Unsubscribe(_), AckType::Unsubscribe) => true,
561            (_, _) => false,
562        }
563    }
564}
565
566impl AckType {
567    pub(super) fn expected_str(self) -> &'static str {
568        match self {
569            AckType::Publish => "Expected PUBACK packet",
570            AckType::Receive => "Expected PUBREC packet",
571            AckType::Complete => "Expected PUBCOMP packet",
572            AckType::Subscribe => "Expected SUBACK packet",
573            AckType::Unsubscribe => "Expected UNSUBACK packet",
574        }
575    }
576}