ntex_mqtt/v5/
sink.rs

1use std::num::{NonZeroU16, NonZeroU32};
2use std::{cell::Cell, fmt, future::ready, future::Future, rc::Rc};
3
4use ntex_bytes::{ByteString, Bytes};
5use ntex_util::{channel::pool, future::Either, future::Ready};
6
7use super::codec::{self, EncodeLtd};
8use super::shared::{AckType, MqttShared};
9use crate::{error::EncodeError, error::SendPacketError, types::QoS};
10
11pub struct MqttSink(Rc<MqttShared>);
12
13impl Clone for MqttSink {
14    fn clone(&self) -> Self {
15        MqttSink(self.0.clone())
16    }
17}
18
19impl MqttSink {
20    pub(super) fn new(state: Rc<MqttShared>) -> Self {
21        MqttSink(state)
22    }
23
24    pub(super) fn shared(&self) -> Rc<MqttShared> {
25        self.0.clone()
26    }
27
28    #[inline]
29    /// Check if io stream is open
30    pub fn is_open(&self) -> bool {
31        !self.0.is_closed()
32    }
33
34    #[inline]
35    /// Check if sink is ready
36    pub fn is_ready(&self) -> bool {
37        if self.0.is_closed() {
38            false
39        } else {
40            self.0.is_ready()
41        }
42    }
43
44    #[inline]
45    /// Get client's receive credit
46    pub fn credit(&self) -> usize {
47        self.0.credit()
48    }
49
50    /// Get notification when packet could be send to the peer.
51    ///
52    /// Result indicates if connection is alive
53    pub fn ready(&self) -> impl Future<Output = bool> {
54        if !self.0.is_closed() {
55            self.0
56                .wait_readiness()
57                .map(|rx| Either::Right(async move { rx.await.is_ok() }))
58                .unwrap_or_else(|| Either::Left(ready(true)))
59        } else {
60            Either::Left(ready(false))
61        }
62    }
63
64    #[inline]
65    /// Force close MQTT connection. Dispatcher does not wait for uncompleted
66    /// responses (ending them with error), but it flushes buffers.
67    pub fn force_close(&self) {
68        self.0.force_close();
69    }
70
71    #[inline]
72    /// Close mqtt connection with default Disconnect message
73    pub fn close(&self) {
74        self.0.close(codec::Disconnect::default());
75    }
76
77    #[inline]
78    /// Close mqtt connection
79    pub fn close_with_reason(&self, pkt: codec::Disconnect) {
80        self.0.close(pkt);
81    }
82
83    /// Send ping
84    pub(super) fn ping(&self) -> bool {
85        self.0.encode_packet(codec::Packet::PingRequest).is_ok()
86    }
87
88    #[inline]
89    /// Create publish packet builder
90    pub fn publish<U>(&self, topic: U, payload: Bytes) -> PublishBuilder
91    where
92        ByteString: From<U>,
93    {
94        self.publish_pkt(
95            codec::Publish {
96                dup: false,
97                retain: false,
98                topic: topic.into(),
99                qos: QoS::AtMostOnce,
100                packet_id: None,
101                payload_size: 0,
102                properties: codec::PublishProperties::default(),
103            },
104            payload,
105        )
106    }
107
108    #[inline]
109    /// Create publish builder with publish packet
110    pub fn publish_pkt(&self, packet: codec::Publish, payload: Bytes) -> PublishBuilder {
111        PublishBuilder::new(self.0.clone(), packet, payload)
112    }
113
114    /// Set publish ack callback
115    ///
116    /// Use non-blocking send, PublishBuilder::send_at_least_once_no_block()
117    /// First argument is packet id, second argument is "disconnected" state
118    pub fn publish_ack_cb<F>(&self, f: F)
119    where
120        F: Fn(codec::PublishAck, bool) + 'static,
121    {
122        self.0.set_publish_ack(Box::new(f));
123    }
124
125    #[inline]
126    /// Create subscribe packet builder
127    pub fn subscribe(&self, id: Option<NonZeroU32>) -> SubscribeBuilder {
128        SubscribeBuilder {
129            id: None,
130            packet: codec::Subscribe {
131                id,
132                packet_id: NonZeroU16::new(1).unwrap(),
133                user_properties: Vec::new(),
134                topic_filters: Vec::new(),
135            },
136            shared: self.0.clone(),
137        }
138    }
139
140    #[inline]
141    /// Create unsubscribe packet builder
142    pub fn unsubscribe(&self) -> UnsubscribeBuilder {
143        UnsubscribeBuilder {
144            id: None,
145            packet: codec::Unsubscribe {
146                packet_id: NonZeroU16::new(1).unwrap(),
147                user_properties: Vec::new(),
148                topic_filters: Vec::new(),
149            },
150            shared: self.0.clone(),
151        }
152    }
153}
154
155impl fmt::Debug for MqttSink {
156    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
157        fmt.debug_struct("MqttSink").finish()
158    }
159}
160
161pub struct PublishBuilder {
162    shared: Rc<MqttShared>,
163    packet: codec::Publish,
164    payload: Bytes,
165}
166
167impl PublishBuilder {
168    fn new(shared: Rc<MqttShared>, mut packet: codec::Publish, payload: Bytes) -> Self {
169        packet.payload_size = payload.len() as u32;
170        Self { shared, packet, payload }
171    }
172
173    #[inline]
174    /// Set packet id.
175    ///
176    /// Note: if packet id is not set, it gets generated automatically.
177    /// Packet id management should not be mixed, it should be auto-generated
178    /// or set by user. Otherwise collisions could occure.
179    ///
180    /// panics if id is 0
181    pub fn packet_id(mut self, id: u16) -> Self {
182        let id = NonZeroU16::new(id).expect("id 0 is not allowed");
183        self.packet.packet_id = Some(id);
184        self
185    }
186
187    #[inline]
188    /// This might be re-delivery of an earlier attempt to send the Packet.
189    pub fn dup(mut self, val: bool) -> Self {
190        self.packet.dup = val;
191        self
192    }
193
194    #[inline]
195    /// Set retain flag
196    pub fn retain(mut self, val: bool) -> Self {
197        self.packet.retain = val;
198        self
199    }
200
201    #[inline]
202    /// Set publish packet properties
203    pub fn properties<F>(mut self, f: F) -> Self
204    where
205        F: FnOnce(&mut codec::PublishProperties),
206    {
207        f(&mut self.packet.properties);
208        self
209    }
210
211    #[inline]
212    /// Set publish packet properties
213    pub fn set_properties<F>(&mut self, f: F)
214    where
215        F: FnOnce(&mut codec::PublishProperties),
216    {
217        f(&mut self.packet.properties);
218    }
219
220    #[inline]
221    /// Get size of the publish packet
222    pub fn size(&self) -> u32 {
223        self.packet.encoded_size(u32::MAX) as u32
224    }
225
226    /// Create streamimng publish builder
227    pub fn streaming(mut self, size: u32) -> (StreamingPublishBuilder, StreamingPayload) {
228        self.packet.payload_size = size;
229        let payload = if self.payload.is_empty() { None } else { Some(self.payload) };
230
231        let (tx, rx) = self.shared.pool.waiters.channel();
232        (
233            StreamingPublishBuilder {
234                size,
235                payload,
236                tx: Some(tx),
237                shared: self.shared.clone(),
238                packet: self.packet,
239            },
240            StreamingPayload {
241                rx: Cell::new(Some(rx)),
242                shared: self.shared.clone(),
243                inprocess: Cell::new(false),
244            },
245        )
246    }
247
248    #[inline]
249    /// Send publish packet with QoS 0
250    pub fn send_at_most_once(mut self) -> Result<(), SendPacketError> {
251        if !self.shared.is_closed() {
252            log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
253            self.packet.qos = QoS::AtMostOnce;
254            self.shared
255                .encode_publish(self.packet, Some(self.payload))
256                .map_err(SendPacketError::Encode)
257                .map(|_| ())
258        } else {
259            log::error!("Mqtt sink is disconnected");
260            Err(SendPacketError::Disconnected)
261        }
262    }
263
264    /// Send publish packet with QoS 1
265    pub fn send_at_least_once(
266        mut self,
267    ) -> impl Future<Output = Result<codec::PublishAck, SendPacketError>> {
268        if !self.shared.is_closed() {
269            self.packet.qos = QoS::AtLeastOnce;
270
271            // handle client receive maximum
272            if let Some(rx) = self.shared.wait_readiness() {
273                Either::Left(Either::Left(async move {
274                    if rx.await.is_err() {
275                        return Err(SendPacketError::Disconnected);
276                    }
277                    self.send_at_least_once_inner().await
278                }))
279            } else {
280                Either::Left(Either::Right(self.send_at_least_once_inner()))
281            }
282        } else {
283            Either::Right(Ready::Err(SendPacketError::Disconnected))
284        }
285    }
286
287    /// Non-blocking send publish packet with QoS 1
288    ///
289    /// Panics if sink is not ready or publish ack callback is not set
290    pub fn send_at_least_once_no_block(mut self) -> Result<(), SendPacketError> {
291        if !self.shared.is_closed() {
292            // check readiness
293            if !self.shared.is_ready() {
294                panic!("Mqtt sink is not ready");
295            }
296            self.packet.qos = codec::QoS::AtLeastOnce;
297            let idx = self.shared.set_publish_id(&mut self.packet);
298
299            log::trace!("Publish (QoS1) to {:#?}", self.packet);
300            self.shared.wait_publish_response_no_block(
301                idx,
302                AckType::Publish,
303                self.packet,
304                Some(self.payload),
305            )
306        } else {
307            Err(SendPacketError::Disconnected)
308        }
309    }
310
311    async fn send_at_least_once_inner(mut self) -> Result<codec::PublishAck, SendPacketError> {
312        // packet id
313        let idx = self.shared.set_publish_id(&mut self.packet);
314
315        // send publish to client
316        log::trace!("Publish (QoS1) to {:#?}", self.packet);
317        self.shared
318            .wait_publish_response(idx, AckType::Publish, self.packet, Some(self.payload))?
319            .await
320            .map(|pkt| pkt.publish())
321            .map_err(|_| SendPacketError::Disconnected)
322    }
323
324    /// Send publish packet with QoS 2
325    pub fn send_exactly_once(
326        mut self,
327    ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
328        if !self.shared.is_closed() {
329            self.packet.qos = codec::QoS::ExactlyOnce;
330
331            // handle client receive maximum
332            if let Some(rx) = self.shared.wait_readiness() {
333                Either::Left(Either::Left(async move {
334                    if rx.await.is_err() {
335                        return Err(SendPacketError::Disconnected);
336                    }
337                    self.send_exactly_once_inner().await
338                }))
339            } else {
340                Either::Left(Either::Right(self.send_exactly_once_inner()))
341            }
342        } else {
343            Either::Right(Ready::Err(SendPacketError::Disconnected))
344        }
345    }
346
347    fn send_exactly_once_inner(
348        mut self,
349    ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
350        let shared = self.shared.clone();
351        let idx = shared.set_publish_id(&mut self.packet);
352        log::trace!("Publish (QoS2) to {:#?}", self.packet);
353
354        let rx = shared.wait_publish_response(
355            idx,
356            AckType::Receive,
357            self.packet,
358            Some(self.payload),
359        );
360        async move {
361            rx?.await
362                .map(move |ack| PublishReceived::new(ack.receive(), shared))
363                .map_err(|_| SendPacketError::Disconnected)
364        }
365    }
366}
367
368/// Publish released for QoS2
369pub struct PublishReceived {
370    ack: codec::PublishAck,
371    result: Option<codec::PublishAck2>,
372    shared: Rc<MqttShared>,
373}
374
375impl PublishReceived {
376    fn new(ack: codec::PublishAck, shared: Rc<MqttShared>) -> Self {
377        let packet_id = ack.packet_id;
378        Self {
379            ack,
380            shared,
381            result: Some(codec::PublishAck2 {
382                packet_id,
383                reason_code: codec::PublishAck2Reason::Success,
384                properties: codec::UserProperties::default(),
385                reason_string: None,
386            }),
387        }
388    }
389
390    /// Returns reference to auth packet
391    pub fn packet(&self) -> &codec::PublishAck {
392        &self.ack
393    }
394
395    /// Update user properties
396    #[inline]
397    pub fn properties<F>(mut self, f: F) -> Self
398    where
399        F: FnOnce(&mut codec::UserProperties),
400    {
401        f(&mut self.result.as_mut().unwrap().properties);
402        self
403    }
404
405    /// Set ack reason string
406    #[inline]
407    pub fn reason(mut self, reason: ByteString) -> Self {
408        self.result.as_mut().unwrap().reason_string = Some(reason);
409        self
410    }
411
412    /// Release publish
413    pub async fn release(mut self) -> Result<(), SendPacketError> {
414        let rx = self.shared.release_publish(self.result.take().unwrap())?;
415
416        rx.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
417    }
418}
419
420impl Drop for PublishReceived {
421    fn drop(&mut self) {
422        if let Some(ack) = self.result.take() {
423            self.shared.release_publish(ack);
424        }
425    }
426}
427
428/// Subscribe packet builder
429pub struct SubscribeBuilder {
430    id: Option<NonZeroU16>,
431    packet: codec::Subscribe,
432    shared: Rc<MqttShared>,
433}
434
435impl SubscribeBuilder {
436    #[inline]
437    /// Set packet id.
438    ///
439    /// panics if id is 0
440    pub fn packet_id(mut self, id: u16) -> Self {
441        if let Some(id) = NonZeroU16::new(id) {
442            self.id = Some(id);
443            self
444        } else {
445            panic!("id 0 is not allowed");
446        }
447    }
448
449    #[inline]
450    /// Add topic filter
451    pub fn topic_filter(
452        mut self,
453        filter: ByteString,
454        opts: codec::SubscriptionOptions,
455    ) -> Self {
456        self.packet.topic_filters.push((filter, opts));
457        self
458    }
459
460    #[inline]
461    /// Add user property
462    pub fn property(mut self, key: ByteString, value: ByteString) -> Self {
463        self.packet.user_properties.push((key, value));
464        self
465    }
466
467    #[inline]
468    /// Get size of the subscribe packet
469    pub fn size(&self) -> u32 {
470        self.packet.encoded_size(u32::MAX) as u32
471    }
472
473    /// Send subscribe packet
474    pub async fn send(self) -> Result<codec::SubscribeAck, SendPacketError> {
475        let shared = self.shared;
476        let mut packet = self.packet;
477
478        if !shared.is_closed() {
479            // handle client receive maximum
480            if let Some(rx) = shared.wait_readiness() {
481                if rx.await.is_err() {
482                    return Err(SendPacketError::Disconnected);
483                }
484            }
485
486            // allocate packet id
487            packet.packet_id = self.id.unwrap_or_else(|| shared.next_id());
488
489            // send subscribe to client
490            log::trace!("Sending subscribe packet {:#?}", packet);
491
492            let rx = shared.wait_response(packet.packet_id, AckType::Subscribe)?;
493            match shared.encode_packet(codec::Packet::Subscribe(packet)) {
494                Ok(_) => {
495                    // wait ack from peer
496                    rx.await
497                        .map_err(|_| SendPacketError::Disconnected)
498                        .map(|pkt| pkt.subscribe())
499                }
500                Err(err) => Err(SendPacketError::Encode(err)),
501            }
502        } else {
503            Err(SendPacketError::Disconnected)
504        }
505    }
506}
507
508/// Unsubscribe packet builder
509pub struct UnsubscribeBuilder {
510    id: Option<NonZeroU16>,
511    packet: codec::Unsubscribe,
512    shared: Rc<MqttShared>,
513}
514
515impl UnsubscribeBuilder {
516    #[inline]
517    /// Set packet id.
518    ///
519    /// panics if id is 0
520    pub fn packet_id(mut self, id: u16) -> Self {
521        if let Some(id) = NonZeroU16::new(id) {
522            self.id = Some(id);
523            self
524        } else {
525            panic!("id 0 is not allowed");
526        }
527    }
528
529    #[inline]
530    /// Add topic filter
531    pub fn topic_filter(mut self, filter: ByteString) -> Self {
532        self.packet.topic_filters.push(filter);
533        self
534    }
535
536    #[inline]
537    /// Add user property
538    pub fn property(mut self, key: ByteString, value: ByteString) -> Self {
539        self.packet.user_properties.push((key, value));
540        self
541    }
542
543    #[inline]
544    /// Get size of the unsubscribe packet
545    pub fn size(&self) -> u32 {
546        self.packet.encoded_size(u32::MAX) as u32
547    }
548
549    /// Send unsubscribe packet
550    pub async fn send(self) -> Result<codec::UnsubscribeAck, SendPacketError> {
551        let shared = self.shared;
552        let mut packet = self.packet;
553
554        if !shared.is_closed() {
555            // handle client receive maximum
556            if let Some(rx) = shared.wait_readiness() {
557                if rx.await.is_err() {
558                    return Err(SendPacketError::Disconnected);
559                }
560            }
561            // allocate packet id
562            packet.packet_id = self.id.unwrap_or_else(|| shared.next_id());
563
564            // send unsubscribe to client
565            log::trace!("Sending unsubscribe packet {:#?}", packet);
566
567            let rx = shared.wait_response(packet.packet_id, AckType::Unsubscribe)?;
568            match shared.encode_packet(codec::Packet::Unsubscribe(packet)) {
569                Ok(_) => {
570                    // wait ack from peer
571                    rx.await
572                        .map_err(|_| SendPacketError::Disconnected)
573                        .map(|pkt| pkt.unsubscribe())
574                }
575                Err(err) => Err(SendPacketError::Encode(err)),
576            }
577        } else {
578            Err(SendPacketError::Disconnected)
579        }
580    }
581}
582
583pub struct StreamingPublishBuilder {
584    shared: Rc<MqttShared>,
585    packet: codec::Publish,
586    payload: Option<Bytes>,
587    size: u32,
588    tx: Option<pool::Sender<()>>,
589}
590
591impl StreamingPublishBuilder {
592    fn notify_payload_streamer(&mut self) -> Result<(), SendPacketError> {
593        if let Some(tx) = self.tx.take() {
594            tx.send(()).map_err(|_| SendPacketError::StreamingCancelled)
595        } else {
596            Ok(())
597        }
598    }
599
600    /// Send publish packet with QoS 0
601    pub fn send_at_most_once(mut self) -> Result<(), SendPacketError> {
602        if !self.shared.is_closed() {
603            log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
604            self.notify_payload_streamer()?;
605
606            self.packet.qos = QoS::AtMostOnce;
607            self.shared
608                .encode_publish(self.packet, self.payload)
609                .map_err(SendPacketError::Encode)
610                .map(|_| ())
611        } else {
612            log::error!("Mqtt sink is disconnected");
613            Err(SendPacketError::Disconnected)
614        }
615    }
616
617    /// Send publish packet with QoS 1
618    pub fn send_at_least_once(
619        mut self,
620    ) -> impl Future<Output = Result<codec::PublishAck, SendPacketError>> {
621        if !self.shared.is_closed() {
622            self.packet.qos = QoS::AtLeastOnce;
623
624            // handle client receive maximum
625            if let Some(rx) = self.shared.wait_readiness() {
626                Either::Left(Either::Left(async move {
627                    if rx.await.is_err() {
628                        return Err(SendPacketError::Disconnected);
629                    }
630                    self.send_at_least_once_inner().await
631                }))
632            } else {
633                Either::Left(Either::Right(self.send_at_least_once_inner()))
634            }
635        } else {
636            Either::Right(Ready::Err(SendPacketError::Disconnected))
637        }
638    }
639
640    /// Non-blocking send publish packet with QoS 1
641    ///
642    /// Panics if sink is not ready or publish ack callback is not set
643    pub fn send_at_least_once_no_block(mut self) -> Result<(), SendPacketError> {
644        if !self.shared.is_closed() {
645            // check readiness
646            if !self.shared.is_ready() {
647                panic!("Mqtt sink is not ready");
648            }
649            self.packet.qos = codec::QoS::AtLeastOnce;
650            let tx = self.tx.take().unwrap();
651            let idx = self.shared.set_publish_id(&mut self.packet);
652
653            if tx.is_canceled() {
654                Err(SendPacketError::StreamingCancelled)
655            } else {
656                log::trace!("Publish (QoS1) to {:#?}", self.packet);
657                let _ = tx.send(());
658                self.shared.wait_publish_response_no_block(
659                    idx,
660                    AckType::Publish,
661                    self.packet,
662                    self.payload,
663                )
664            }
665        } else {
666            Err(SendPacketError::Disconnected)
667        }
668    }
669
670    async fn send_at_least_once_inner(mut self) -> Result<codec::PublishAck, SendPacketError> {
671        // packet id
672        let idx = self.shared.set_publish_id(&mut self.packet);
673
674        // send publish to client
675        log::trace!("Publish (QoS1) to {:#?}", self.packet);
676
677        let tx = self.tx.take().unwrap();
678        if tx.is_canceled() {
679            Err(SendPacketError::StreamingCancelled)
680        } else {
681            let rx = self.shared.wait_publish_response(
682                idx,
683                AckType::Publish,
684                self.packet,
685                self.payload,
686            );
687            let _ = tx.send(());
688
689            rx?.await.map(|pkt| pkt.publish()).map_err(|_| SendPacketError::Disconnected)
690        }
691    }
692}
693
694pub struct StreamingPayload {
695    shared: Rc<MqttShared>,
696    rx: Cell<Option<pool::Receiver<()>>>,
697    inprocess: Cell<bool>,
698}
699
700impl StreamingPayload {
701    fn drop(&mut self) {
702        if self.inprocess.get() {
703            if self.shared.is_streaming() {
704                self.shared.streaming_dropped();
705            }
706        }
707    }
708}
709
710impl StreamingPayload {
711    /// Send payload chunk
712    pub async fn send(&self, chunk: Bytes) -> Result<(), SendPacketError> {
713        if let Some(rx) = self.rx.take() {
714            if rx.await.is_err() {
715                return Err(SendPacketError::StreamingCancelled);
716            }
717            log::trace!("Publish is encoded, ready to process payload");
718            self.inprocess.set(true);
719        }
720
721        if !self.inprocess.get() {
722            Err(EncodeError::UnexpectedPayload.into())
723        } else {
724            log::trace!("Sending payload chunk: {:?}", chunk.len());
725            self.shared.want_payload_stream().await?;
726
727            if !self.shared.encode_publish_payload(chunk)? {
728                self.inprocess.set(false);
729            }
730            Ok(())
731        }
732    }
733}