ntex_mqtt/v5/sink.rs

1use std::num::{NonZeroU16, NonZeroU32};
2use std::{cell::Cell, fmt, future::Future, future::ready, rc::Rc};
3
4use ntex_bytes::{ByteString, Bytes};
5use ntex_util::{channel::pool, future::Either, future::Ready};
6
7use super::codec::{self, EncodeLtd};
8use super::shared::{AckType, MqttShared};
9use crate::{error::EncodeError, error::SendPacketError, types::QoS};
10
11pub struct MqttSink(Rc<MqttShared>);
12
13impl Clone for MqttSink {
14    fn clone(&self) -> Self {
15        MqttSink(self.0.clone())
16    }
17}
18
19impl MqttSink {
20    pub(super) fn new(state: Rc<MqttShared>) -> Self {
21        MqttSink(state)
22    }
23
24    pub(super) fn shared(&self) -> Rc<MqttShared> {
25        self.0.clone()
26    }
27
28    #[inline]
29    /// Check if io stream is open
30    pub fn is_open(&self) -> bool {
31        !self.0.is_closed()
32    }
33
34    #[inline]
35    /// Check if sink is ready
36    pub fn is_ready(&self) -> bool {
37        if self.0.is_closed() {
38            false
39        } else {
40            self.0.is_ready()
41        }
42    }
43
44    #[inline]
45    /// Get client's receive credit
46    pub fn credit(&self) -> usize {
47        self.0.credit()
48    }
49
50    /// Get notification when packet could be send to the peer.
51    ///
52    /// Result indicates if connection is alive
53    pub fn ready(&self) -> impl Future<Output = bool> {
54        if !self.0.is_closed() {
55            self.0
56                .wait_readiness()
57                .map(|rx| Either::Right(async move { rx.await.is_ok() }))
58                .unwrap_or_else(|| Either::Left(ready(true)))
59        } else {
60            Either::Left(ready(false))
61        }
62    }
63
64    #[inline]
65    /// Force close MQTT connection. Dispatcher does not wait for uncompleted
66    /// responses (ending them with error), but it flushes buffers.
67    pub fn force_close(&self) {
68        self.0.force_close();
69    }
70
71    #[inline]
72    /// Close mqtt connection with default Disconnect message
73    pub fn close(&self) {
74        self.0.close(codec::Disconnect::default());
75    }
76
77    #[inline]
78    /// Close mqtt connection
79    pub fn close_with_reason(&self, pkt: codec::Disconnect) {
80        self.0.close(pkt);
81    }
82
83    /// Send ping
84    pub(super) fn ping(&self) -> bool {
85        self.0.encode_packet(codec::Packet::PingRequest).is_ok()
86    }
87
88    #[inline]
89    /// Create publish packet builder
90    pub fn publish<U>(&self, topic: U) -> PublishBuilder
91    where
92        ByteString: From<U>,
93    {
94        self.publish_pkt(codec::Publish {
95            dup: false,
96            retain: false,
97            topic: topic.into(),
98            qos: QoS::AtMostOnce,
99            packet_id: None,
100            payload_size: 0,
101            properties: codec::PublishProperties::default(),
102        })
103    }
104
105    #[inline]
106    /// Create publish builder with publish packet
107    pub fn publish_pkt(&self, packet: codec::Publish) -> PublishBuilder {
108        PublishBuilder::new(self.0.clone(), packet)
109    }
110
111    /// Set publish ack callback
112    ///
113    /// Use non-blocking send, PublishBuilder::send_at_least_once_no_block()
114    /// First argument is packet id, second argument is "disconnected" state
115    pub fn publish_ack_cb<F>(&self, f: F)
116    where
117        F: Fn(codec::PublishAck, bool) + 'static,
118    {
119        self.0.set_publish_ack(Box::new(f));
120    }
121
122    #[inline]
123    /// Create subscribe packet builder
124    pub fn subscribe(&self, id: Option<NonZeroU32>) -> SubscribeBuilder {
125        SubscribeBuilder {
126            id: None,
127            packet: codec::Subscribe {
128                id,
129                packet_id: NonZeroU16::new(1).unwrap(),
130                user_properties: Vec::new(),
131                topic_filters: Vec::new(),
132            },
133            shared: self.0.clone(),
134        }
135    }
136
137    #[inline]
138    /// Create unsubscribe packet builder
139    pub fn unsubscribe(&self) -> UnsubscribeBuilder {
140        UnsubscribeBuilder {
141            id: None,
142            packet: codec::Unsubscribe {
143                packet_id: NonZeroU16::new(1).unwrap(),
144                user_properties: Vec::new(),
145                topic_filters: Vec::new(),
146            },
147            shared: self.0.clone(),
148        }
149    }
150}
151
152impl fmt::Debug for MqttSink {
153    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
154        fmt.debug_struct("MqttSink").finish()
155    }
156}
157
158pub struct PublishBuilder {
159    shared: Rc<MqttShared>,
160    packet: codec::Publish,
161}
162
163impl PublishBuilder {
164    fn new(shared: Rc<MqttShared>, packet: codec::Publish) -> Self {
165        Self { shared, packet }
166    }
167
168    #[inline]
169    /// Set packet id.
170    ///
171    /// Note: if packet id is not set, it gets generated automatically.
172    /// Packet id management should not be mixed, it should be auto-generated
173    /// or set by user. Otherwise collisions could occure.
174    ///
175    /// panics if id is 0
176    pub fn packet_id(mut self, id: u16) -> Self {
177        let id = NonZeroU16::new(id).expect("id 0 is not allowed");
178        self.packet.packet_id = Some(id);
179        self
180    }
181
182    #[inline]
183    /// This might be re-delivery of an earlier attempt to send the Packet.
184    pub fn dup(mut self, val: bool) -> Self {
185        self.packet.dup = val;
186        self
187    }
188
189    #[inline]
190    /// Set retain flag
191    pub fn retain(mut self, val: bool) -> Self {
192        self.packet.retain = val;
193        self
194    }
195
196    #[inline]
197    /// Set publish packet properties
198    pub fn properties<F>(mut self, f: F) -> Self
199    where
200        F: FnOnce(&mut codec::PublishProperties),
201    {
202        f(&mut self.packet.properties);
203        self
204    }
205
206    #[inline]
207    /// Set publish packet properties
208    pub fn set_properties<F>(&mut self, f: F)
209    where
210        F: FnOnce(&mut codec::PublishProperties),
211    {
212        f(&mut self.packet.properties);
213    }
214
215    #[inline]
216    /// Get size of the publish packet
217    pub fn size(&self, payload_size: usize) -> u32 {
218        (self.packet.encoded_size(u32::MAX) + payload_size) as u32
219    }
220
221    #[inline]
222    /// Send publish packet with QoS 0
223    pub fn send_at_most_once(mut self, payload: Bytes) -> Result<(), SendPacketError> {
224        if !self.shared.is_closed() {
225            log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
226            self.packet.qos = QoS::AtMostOnce;
227            self.packet.payload_size = payload.len() as u32;
228            self.shared
229                .encode_publish(self.packet, Some(payload))
230                .map_err(SendPacketError::Encode)
231                .map(|_| ())
232        } else {
233            log::error!("Mqtt sink is disconnected");
234            Err(SendPacketError::Disconnected)
235        }
236    }
237
238    /// Send publish packet with QoS 0
239    pub fn stream_at_most_once(
240        mut self,
241        size: u32,
242    ) -> Result<StreamingPayload, SendPacketError> {
243        if !self.shared.is_closed() {
244            log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
245
246            let stream = StreamingPayload {
247                rx: Cell::new(None),
248                shared: self.shared.clone(),
249                inprocess: Cell::new(true),
250            };
251
252            self.packet.qos = QoS::AtMostOnce;
253            self.packet.payload_size = size;
254            self.shared
255                .encode_publish(self.packet, None)
256                .map_err(SendPacketError::Encode)
257                .map(|_| stream)
258        } else {
259            log::error!("Mqtt sink is disconnected");
260            Err(SendPacketError::Disconnected)
261        }
262    }
263
264    /// Send publish packet with QoS 1
265    pub fn send_at_least_once(
266        mut self,
267        payload: Bytes,
268    ) -> impl Future<Output = Result<codec::PublishAck, SendPacketError>> {
269        if !self.shared.is_closed() {
270            self.packet.qos = QoS::AtLeastOnce;
271            self.packet.payload_size = payload.len() as u32;
272
273            // handle client receive maximum
274            if let Some(rx) = self.shared.wait_readiness() {
275                Either::Left(Either::Left(async move {
276                    if rx.await.is_err() {
277                        return Err(SendPacketError::Disconnected);
278                    }
279                    self.send_at_least_once_inner(payload).await
280                }))
281            } else {
282                Either::Left(Either::Right(self.send_at_least_once_inner(payload)))
283            }
284        } else {
285            Either::Right(Ready::Err(SendPacketError::Disconnected))
286        }
287    }
288
289    /// Non-blocking send publish packet with QoS 1
290    ///
291    /// Panics if sink is not ready or publish ack callback is not set
292    pub fn send_at_least_once_no_block(
293        mut self,
294        payload: Bytes,
295    ) -> Result<(), SendPacketError> {
296        if !self.shared.is_closed() {
297            // check readiness
298            if !self.shared.is_ready() {
299                panic!("Mqtt sink is not ready");
300            }
301            self.packet.qos = codec::QoS::AtLeastOnce;
302            self.packet.payload_size = payload.len() as u32;
303
304            let idx = self.shared.set_publish_id(&mut self.packet);
305
306            log::trace!("Publish (QoS1) to {:#?}", self.packet);
307            self.shared.wait_publish_response_no_block(
308                idx,
309                AckType::Publish,
310                self.packet,
311                Some(payload),
312            )
313        } else {
314            Err(SendPacketError::Disconnected)
315        }
316    }
317
318    /// Send publish packet with QoS 1
319    pub fn stream_at_least_once(
320        mut self,
321        size: u32,
322    ) -> (impl Future<Output = Result<codec::PublishAck, SendPacketError>>, StreamingPayload)
323    {
324        let (tx, rx) = self.shared.pool.waiters.channel();
325        let stream = StreamingPayload {
326            rx: Cell::new(Some(rx)),
327            shared: self.shared.clone(),
328            inprocess: Cell::new(false),
329        };
330
331        if !self.shared.is_closed() {
332            self.packet.qos = QoS::AtLeastOnce;
333            self.packet.payload_size = size;
334
335            // handle client receive maximum
336            let fut = if let Some(rx) = self.shared.wait_readiness() {
337                Either::Left(Either::Left(async move {
338                    if rx.await.is_err() {
339                        return Err(SendPacketError::Disconnected);
340                    }
341                    self.stream_at_least_once_inner(tx, None).await
342                }))
343            } else {
344                Either::Left(Either::Right(self.stream_at_least_once_inner(tx, None)))
345            };
346            (fut, stream)
347        } else {
348            (Either::Right(Ready::Err(SendPacketError::Disconnected)), stream)
349        }
350    }
351
352    async fn send_at_least_once_inner(
353        mut self,
354        payload: Bytes,
355    ) -> Result<codec::PublishAck, SendPacketError> {
356        // packet id
357        let idx = self.shared.set_publish_id(&mut self.packet);
358
359        // send publish to client
360        log::trace!("Publish (QoS1) to {:#?}", self.packet);
361        self.shared
362            .wait_publish_response(idx, AckType::Publish, self.packet, Some(payload))?
363            .await
364            .map(|pkt| pkt.publish())
365            .map_err(|_| SendPacketError::Disconnected)
366    }
367
368    async fn stream_at_least_once_inner(
369        mut self,
370        tx: pool::Sender<()>,
371        chunk: Option<Bytes>,
372    ) -> Result<codec::PublishAck, SendPacketError> {
373        // packet id
374        let idx = self.shared.set_publish_id(&mut self.packet);
375
376        // send publish to client
377        log::trace!("Publish (QoS1) to {:#?}", self.packet);
378
379        if tx.is_canceled() {
380            Err(SendPacketError::StreamingCancelled)
381        } else {
382            let rx =
383                self.shared.wait_publish_response(idx, AckType::Publish, self.packet, chunk);
384            let _ = tx.send(());
385
386            rx?.await.map(|pkt| pkt.publish()).map_err(|_| SendPacketError::Disconnected)
387        }
388    }
389
390    /// Send publish packet with QoS 2
391    pub fn send_exactly_once(
392        mut self,
393        payload: Bytes,
394    ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
395        if !self.shared.is_closed() {
396            self.packet.qos = codec::QoS::ExactlyOnce;
397            self.packet.payload_size = payload.len() as u32;
398
399            // handle client receive maximum
400            if let Some(rx) = self.shared.wait_readiness() {
401                Either::Left(Either::Left(async move {
402                    if rx.await.is_err() {
403                        return Err(SendPacketError::Disconnected);
404                    }
405                    self.send_exactly_once_inner(payload).await
406                }))
407            } else {
408                Either::Left(Either::Right(self.send_exactly_once_inner(payload)))
409            }
410        } else {
411            Either::Right(Ready::Err(SendPacketError::Disconnected))
412        }
413    }
414
415    fn send_exactly_once_inner(
416        mut self,
417        payload: Bytes,
418    ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
419        let shared = self.shared.clone();
420        let idx = shared.set_publish_id(&mut self.packet);
421        log::trace!("Publish (QoS2) to {:#?}", self.packet);
422
423        let rx =
424            shared.wait_publish_response(idx, AckType::Receive, self.packet, Some(payload));
425        async move {
426            rx?.await
427                .map(move |ack| PublishReceived::new(ack.receive(), shared))
428                .map_err(|_| SendPacketError::Disconnected)
429        }
430    }
431}
432
433/// Publish released for QoS2
434pub struct PublishReceived {
435    ack: codec::PublishAck,
436    result: Option<codec::PublishAck2>,
437    shared: Rc<MqttShared>,
438}
439
440impl PublishReceived {
441    fn new(ack: codec::PublishAck, shared: Rc<MqttShared>) -> Self {
442        let packet_id = ack.packet_id;
443        Self {
444            ack,
445            shared,
446            result: Some(codec::PublishAck2 {
447                packet_id,
448                reason_code: codec::PublishAck2Reason::Success,
449                properties: codec::UserProperties::default(),
450                reason_string: None,
451            }),
452        }
453    }
454
455    /// Returns reference to auth packet
456    pub fn packet(&self) -> &codec::PublishAck {
457        &self.ack
458    }
459
460    /// Update user properties
461    #[inline]
462    pub fn properties<F>(mut self, f: F) -> Self
463    where
464        F: FnOnce(&mut codec::UserProperties),
465    {
466        f(&mut self.result.as_mut().unwrap().properties);
467        self
468    }
469
470    /// Set ack reason string
471    #[inline]
472    pub fn reason(mut self, reason: ByteString) -> Self {
473        self.result.as_mut().unwrap().reason_string = Some(reason);
474        self
475    }
476
477    /// Release publish
478    pub async fn release(mut self) -> Result<(), SendPacketError> {
479        let rx = self.shared.release_publish(self.result.take().unwrap())?;
480
481        rx.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
482    }
483}
484
485impl Drop for PublishReceived {
486    fn drop(&mut self) {
487        if let Some(ack) = self.result.take() {
488            let _ = self.shared.release_publish(ack);
489        }
490    }
491}
492
493/// Subscribe packet builder
494pub struct SubscribeBuilder {
495    id: Option<NonZeroU16>,
496    packet: codec::Subscribe,
497    shared: Rc<MqttShared>,
498}
499
500impl SubscribeBuilder {
501    #[inline]
502    /// Set packet id.
503    ///
504    /// panics if id is 0
505    pub fn packet_id(mut self, id: u16) -> Self {
506        if let Some(id) = NonZeroU16::new(id) {
507            self.id = Some(id);
508            self
509        } else {
510            panic!("id 0 is not allowed");
511        }
512    }
513
514    #[inline]
515    /// Add topic filter
516    pub fn topic_filter(
517        mut self,
518        filter: ByteString,
519        opts: codec::SubscriptionOptions,
520    ) -> Self {
521        self.packet.topic_filters.push((filter, opts));
522        self
523    }
524
525    #[inline]
526    /// Add user property
527    pub fn property(mut self, key: ByteString, value: ByteString) -> Self {
528        self.packet.user_properties.push((key, value));
529        self
530    }
531
532    #[inline]
533    /// Get size of the subscribe packet
534    pub fn size(&self) -> u32 {
535        self.packet.encoded_size(u32::MAX) as u32
536    }
537
538    /// Send subscribe packet
539    pub async fn send(self) -> Result<codec::SubscribeAck, SendPacketError> {
540        let shared = self.shared;
541        let mut packet = self.packet;
542
543        if !shared.is_closed() {
544            // handle client receive maximum
545            if let Some(rx) = shared.wait_readiness() {
546                if rx.await.is_err() {
547                    return Err(SendPacketError::Disconnected);
548                }
549            }
550
551            // allocate packet id
552            packet.packet_id = self.id.unwrap_or_else(|| shared.next_id());
553
554            // send subscribe to client
555            log::trace!("Sending subscribe packet {:#?}", packet);
556
557            let rx = shared.wait_response(packet.packet_id, AckType::Subscribe)?;
558            match shared.encode_packet(codec::Packet::Subscribe(packet)) {
559                Ok(_) => {
560                    // wait ack from peer
561                    rx.await
562                        .map_err(|_| SendPacketError::Disconnected)
563                        .map(|pkt| pkt.subscribe())
564                }
565                Err(err) => Err(SendPacketError::Encode(err)),
566            }
567        } else {
568            Err(SendPacketError::Disconnected)
569        }
570    }
571}
572
573/// Unsubscribe packet builder
574pub struct UnsubscribeBuilder {
575    id: Option<NonZeroU16>,
576    packet: codec::Unsubscribe,
577    shared: Rc<MqttShared>,
578}
579
580impl UnsubscribeBuilder {
581    #[inline]
582    /// Set packet id.
583    ///
584    /// panics if id is 0
585    pub fn packet_id(mut self, id: u16) -> Self {
586        if let Some(id) = NonZeroU16::new(id) {
587            self.id = Some(id);
588            self
589        } else {
590            panic!("id 0 is not allowed");
591        }
592    }
593
594    #[inline]
595    /// Add topic filter
596    pub fn topic_filter(mut self, filter: ByteString) -> Self {
597        self.packet.topic_filters.push(filter);
598        self
599    }
600
601    #[inline]
602    /// Add user property
603    pub fn property(mut self, key: ByteString, value: ByteString) -> Self {
604        self.packet.user_properties.push((key, value));
605        self
606    }
607
608    #[inline]
609    /// Get size of the unsubscribe packet
610    pub fn size(&self) -> u32 {
611        self.packet.encoded_size(u32::MAX) as u32
612    }
613
614    /// Send unsubscribe packet
615    pub async fn send(self) -> Result<codec::UnsubscribeAck, SendPacketError> {
616        let shared = self.shared;
617        let mut packet = self.packet;
618
619        if !shared.is_closed() {
620            // handle client receive maximum
621            if let Some(rx) = shared.wait_readiness() {
622                if rx.await.is_err() {
623                    return Err(SendPacketError::Disconnected);
624                }
625            }
626            // allocate packet id
627            packet.packet_id = self.id.unwrap_or_else(|| shared.next_id());
628
629            // send unsubscribe to client
630            log::trace!("Sending unsubscribe packet {:#?}", packet);
631
632            let rx = shared.wait_response(packet.packet_id, AckType::Unsubscribe)?;
633            match shared.encode_packet(codec::Packet::Unsubscribe(packet)) {
634                Ok(_) => {
635                    // wait ack from peer
636                    rx.await
637                        .map_err(|_| SendPacketError::Disconnected)
638                        .map(|pkt| pkt.unsubscribe())
639                }
640                Err(err) => Err(SendPacketError::Encode(err)),
641            }
642        } else {
643            Err(SendPacketError::Disconnected)
644        }
645    }
646}
647
648pub struct StreamingPayload {
649    shared: Rc<MqttShared>,
650    rx: Cell<Option<pool::Receiver<()>>>,
651    inprocess: Cell<bool>,
652}
653
654impl Drop for StreamingPayload {
655    fn drop(&mut self) {
656        if self.inprocess.get() && self.shared.is_streaming() {
657            self.shared.streaming_dropped();
658        }
659    }
660}
661
662impl StreamingPayload {
663    /// Send payload chunk
664    pub async fn send(&self, chunk: Bytes) -> Result<(), SendPacketError> {
665        if let Some(rx) = self.rx.take() {
666            if rx.await.is_err() {
667                return Err(SendPacketError::StreamingCancelled);
668            }
669            log::trace!("Publish is encoded, ready to process payload");
670            self.inprocess.set(true);
671        }
672
673        if !self.inprocess.get() {
674            Err(EncodeError::UnexpectedPayload.into())
675        } else {
676            log::trace!("Sending payload chunk: {:?}", chunk.len());
677            self.shared.want_payload_stream().await?;
678
679            if !self.shared.encode_publish_payload(chunk)? {
680                self.inprocess.set(false);
681            }
682            Ok(())
683        }
684    }
685}