// ntex_mqtt/v3/sink.rs

1use std::{cell::Cell, fmt, future::Future, future::ready, num::NonZeroU16, rc::Rc};
2
3use ntex_bytes::{ByteString, Bytes};
4use ntex_util::{channel::pool, future::Either, future::Ready};
5
6use crate::v3::shared::{AckType, MqttShared};
7use crate::v3::{codec, error::SendPacketError};
8use crate::{error::EncodeError, types::QoS};
9
10pub struct MqttSink(Rc<MqttShared>);
11
12impl Clone for MqttSink {
13    fn clone(&self) -> Self {
14        MqttSink(self.0.clone())
15    }
16}
17
18impl MqttSink {
19    pub(crate) fn new(state: Rc<MqttShared>) -> Self {
20        MqttSink(state)
21    }
22
23    pub(super) fn shared(&self) -> Rc<MqttShared> {
24        self.0.clone()
25    }
26
27    #[inline]
28    /// Check if io stream is open
29    pub fn is_open(&self) -> bool {
30        !self.0.is_closed()
31    }
32
33    #[inline]
34    /// Check if sink is ready
35    pub fn is_ready(&self) -> bool {
36        if self.0.is_closed() {
37            false
38        } else {
39            self.0.is_ready()
40        }
41    }
42
43    #[inline]
44    /// Get client receive credit
45    pub fn credit(&self) -> usize {
46        self.0.credit()
47    }
48
49    /// Get notification when packet could be send to the peer.
50    ///
51    /// Result indicates if connection is alive
52    pub fn ready(&self) -> impl Future<Output = bool> {
53        if !self.0.is_closed() {
54            self.0
55                .wait_readiness()
56                .map(|rx| Either::Right(async move { rx.await.is_ok() }))
57                .unwrap_or_else(|| Either::Left(ready(true)))
58        } else {
59            Either::Left(ready(false))
60        }
61    }
62
63    #[inline]
64    /// Close mqtt connection
65    pub fn close(&self) {
66        self.0.close();
67    }
68
69    #[inline]
70    /// Force close mqtt connection. mqtt dispatcher does not wait for uncompleted
71    /// responses, but it flushes buffers.
72    pub fn force_close(&self) {
73        self.0.force_close();
74    }
75
76    #[inline]
77    /// Send ping
78    pub(super) fn ping(&self) -> bool {
79        self.0.encode_packet(codec::Packet::PingRequest).is_ok()
80    }
81
82    #[inline]
83    /// Create publish message builder
84    pub fn publish<U>(&self, topic: U) -> PublishBuilder
85    where
86        ByteString: From<U>,
87    {
88        self.publish_pkt(codec::Publish {
89            dup: false,
90            retain: false,
91            topic: topic.into(),
92            qos: codec::QoS::AtMostOnce,
93            packet_id: None,
94            payload_size: 0,
95        })
96    }
97
98    #[inline]
99    /// Create publish builder with publish packet
100    pub fn publish_pkt(&self, packet: codec::Publish) -> PublishBuilder {
101        PublishBuilder { packet, shared: self.0.clone() }
102    }
103
104    /// Set publish ack callback
105    ///
106    /// Use non-blocking send, PublishBuilder::send_at_least_once_no_block()
107    /// First argument is packet id, second argument is "disconnected" state
108    pub fn publish_ack_cb<F>(&self, f: F)
109    where
110        F: Fn(NonZeroU16, bool) + 'static,
111    {
112        self.0.set_publish_ack(Box::new(f));
113    }
114
115    #[inline]
116    /// Create subscribe packet builder
117    ///
118    /// panics if id is 0
119    pub fn subscribe(&self) -> SubscribeBuilder {
120        SubscribeBuilder { id: None, topic_filters: Vec::new(), shared: self.0.clone() }
121    }
122
123    #[inline]
124    /// Create unsubscribe packet builder
125    pub fn unsubscribe(&self) -> UnsubscribeBuilder {
126        UnsubscribeBuilder { id: None, topic_filters: Vec::new(), shared: self.0.clone() }
127    }
128}
129
130impl fmt::Debug for MqttSink {
131    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
132        fmt.debug_struct("MqttSink").finish()
133    }
134}
135
136pub struct PublishBuilder {
137    packet: codec::Publish,
138    shared: Rc<MqttShared>,
139}
140
141impl PublishBuilder {
142    #[inline]
143    /// Set packet id.
144    ///
145    /// Note: if packet id is not set, it gets generated automatically.
146    /// Packet id management should not be mixed, it should be auto-generated
147    /// or set by user. Otherwise collisions could occure.
148    ///
149    /// panics if id is 0
150    pub fn packet_id(mut self, id: u16) -> Self {
151        let id = NonZeroU16::new(id).expect("id 0 is not allowed");
152        self.packet.packet_id = Some(id);
153        self
154    }
155
156    #[inline]
157    /// This might be re-delivery of an earlier attempt to send the Packet.
158    pub fn dup(mut self, val: bool) -> Self {
159        self.packet.dup = val;
160        self
161    }
162
163    #[inline]
164    /// Set retain flag
165    pub fn retain(mut self) -> Self {
166        self.packet.retain = true;
167        self
168    }
169
170    #[inline]
171    /// Get size of the publish packet
172    pub fn size(&self, payload_size: usize) -> u32 {
173        (codec::encode::get_encoded_publish_size(&self.packet) + payload_size) as u32
174    }
175
176    #[inline]
177    /// Send publish packet with QoS 0
178    pub fn send_at_most_once(mut self, payload: Bytes) -> Result<(), SendPacketError> {
179        if !self.shared.is_closed() {
180            log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
181            self.packet.qos = codec::QoS::AtMostOnce;
182            self.packet.payload_size = payload.len() as u32;
183            self.shared
184                .encode_publish(self.packet, Some(payload))
185                .map_err(SendPacketError::Encode)
186                .map(|_| ())
187        } else {
188            log::error!("Mqtt sink is disconnected");
189            Err(SendPacketError::Disconnected)
190        }
191    }
192
193    /// Send publish packet with QoS 0
194    pub fn stream_at_most_once(
195        mut self,
196        size: u32,
197    ) -> Result<StreamingPayload, SendPacketError> {
198        if !self.shared.is_closed() {
199            log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
200
201            let stream = StreamingPayload {
202                rx: Cell::new(None),
203                shared: self.shared.clone(),
204                inprocess: Cell::new(true),
205            };
206
207            self.packet.qos = QoS::AtMostOnce;
208            self.packet.payload_size = size;
209            self.shared
210                .encode_publish(self.packet, None)
211                .map_err(SendPacketError::Encode)
212                .map(|_| stream)
213        } else {
214            log::error!("Mqtt sink is disconnected");
215            Err(SendPacketError::Disconnected)
216        }
217    }
218
219    /// Send publish packet with QoS 1
220    pub fn send_at_least_once(
221        mut self,
222        payload: Bytes,
223    ) -> impl Future<Output = Result<(), SendPacketError>> {
224        if !self.shared.is_closed() {
225            self.packet.qos = codec::QoS::AtLeastOnce;
226            self.packet.payload_size = payload.len() as u32;
227
228            // handle client receive maximum
229            if let Some(rx) = self.shared.wait_readiness() {
230                Either::Left(Either::Left(async move {
231                    if rx.await.is_err() {
232                        return Err(SendPacketError::Disconnected);
233                    }
234                    self.send_at_least_once_inner(payload).await
235                }))
236            } else {
237                Either::Left(Either::Right(self.send_at_least_once_inner(payload)))
238            }
239        } else {
240            Either::Right(Ready::Err(SendPacketError::Disconnected))
241        }
242    }
243
244    /// Non-blocking send publish packet with QoS 1
245    ///
246    /// Panics if sink is not ready or publish ack callback is not set
247    pub fn send_at_least_once_no_block(
248        mut self,
249        payload: Bytes,
250    ) -> Result<(), SendPacketError> {
251        if !self.shared.is_closed() {
252            // check readiness
253            if !self.shared.is_ready() {
254                panic!("Mqtt sink is not ready");
255            }
256            self.packet.qos = codec::QoS::AtLeastOnce;
257            self.packet.payload_size = payload.len() as u32;
258            let idx = self.shared.set_publish_id(&mut self.packet);
259
260            log::trace!("Publish (QoS1) to {:#?}", self.packet);
261
262            self.shared.wait_publish_response_no_block(
263                idx,
264                AckType::Publish,
265                self.packet,
266                Some(payload),
267            )
268        } else {
269            Err(SendPacketError::Disconnected)
270        }
271    }
272
273    fn send_at_least_once_inner(
274        mut self,
275        payload: Bytes,
276    ) -> impl Future<Output = Result<(), SendPacketError>> {
277        let idx = self.shared.set_publish_id(&mut self.packet);
278        log::trace!("Publish (QoS1) to {:#?}", self.packet);
279
280        let rx = self.shared.wait_publish_response(
281            idx,
282            AckType::Publish,
283            self.packet,
284            Some(payload),
285        );
286        async move { rx?.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected) }
287    }
288
289    /// Send publish packet with QoS 2
290    pub fn send_exactly_once(
291        mut self,
292        payload: Bytes,
293    ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
294        if !self.shared.is_closed() {
295            self.packet.qos = codec::QoS::ExactlyOnce;
296            self.packet.payload_size = payload.len() as u32;
297
298            // handle client receive maximum
299            if let Some(rx) = self.shared.wait_readiness() {
300                Either::Left(Either::Left(async move {
301                    if rx.await.is_err() {
302                        return Err(SendPacketError::Disconnected);
303                    }
304                    self.send_exactly_once_inner(payload).await
305                }))
306            } else {
307                Either::Left(Either::Right(self.send_exactly_once_inner(payload)))
308            }
309        } else {
310            Either::Right(Ready::Err(SendPacketError::Disconnected))
311        }
312    }
313
314    fn send_exactly_once_inner(
315        mut self,
316        payload: Bytes,
317    ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
318        let idx = self.shared.set_publish_id(&mut self.packet);
319        log::trace!("Publish (QoS2) to {:#?}", self.packet);
320
321        let rx = self.shared.wait_publish_response(
322            idx,
323            AckType::Receive,
324            self.packet,
325            Some(payload),
326        );
327        async move {
328            rx?.await
329                .map(move |_| PublishReceived { packet_id: Some(idx), shared: self.shared })
330                .map_err(|_| SendPacketError::Disconnected)
331        }
332    }
333
334    /// Send publish packet with QoS 1
335    pub fn stream_at_least_once(
336        mut self,
337        size: u32,
338    ) -> (impl Future<Output = Result<(), SendPacketError>>, StreamingPayload) {
339        let (tx, rx) = self.shared.pool.waiters.channel();
340        let stream = StreamingPayload {
341            rx: Cell::new(Some(rx)),
342            shared: self.shared.clone(),
343            inprocess: Cell::new(false),
344        };
345
346        if !self.shared.is_closed() {
347            self.packet.qos = QoS::AtLeastOnce;
348            self.packet.payload_size = size;
349
350            // handle client receive maximum
351            let fut = if let Some(rx) = self.shared.wait_readiness() {
352                Either::Left(Either::Left(async move {
353                    if rx.await.is_err() {
354                        return Err(SendPacketError::Disconnected);
355                    }
356                    self.stream_at_least_once_inner(tx).await
357                }))
358            } else {
359                Either::Left(Either::Right(self.stream_at_least_once_inner(tx)))
360            };
361            (fut, stream)
362        } else {
363            (Either::Right(Ready::Err(SendPacketError::Disconnected)), stream)
364        }
365    }
366
367    async fn stream_at_least_once_inner(
368        mut self,
369        tx: pool::Sender<()>,
370    ) -> Result<(), SendPacketError> {
371        // packet id
372        let idx = self.shared.set_publish_id(&mut self.packet);
373
374        // send publish to client
375        log::trace!("Publish (QoS1) to {:#?}", self.packet);
376
377        if tx.is_canceled() {
378            Err(SendPacketError::StreamingCancelled)
379        } else {
380            let rx =
381                self.shared.wait_publish_response(idx, AckType::Publish, self.packet, None);
382            let _ = tx.send(());
383
384            rx?.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
385        }
386    }
387}
388
389/// Publish released for QoS2
390pub struct PublishReceived {
391    packet_id: Option<NonZeroU16>,
392    shared: Rc<MqttShared>,
393}
394
395impl PublishReceived {
396    /// Release publish
397    pub async fn release(mut self) -> Result<(), SendPacketError> {
398        let rx = self.shared.release_publish(self.packet_id.take().unwrap())?;
399
400        rx.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
401    }
402}
403
404impl Drop for PublishReceived {
405    fn drop(&mut self) {
406        if let Some(id) = self.packet_id.take() {
407            let _ = self.shared.release_publish(id);
408        }
409    }
410}
411
412/// Subscribe packet builder
413pub struct SubscribeBuilder {
414    id: Option<NonZeroU16>,
415    shared: Rc<MqttShared>,
416    topic_filters: Vec<(ByteString, codec::QoS)>,
417}
418
419impl SubscribeBuilder {
420    #[inline]
421    /// Set packet id.
422    ///
423    /// panics if id is 0
424    pub fn packet_id(mut self, id: u16) -> Self {
425        if let Some(id) = NonZeroU16::new(id) {
426            self.id = Some(id);
427            self
428        } else {
429            panic!("id 0 is not allowed");
430        }
431    }
432
433    #[inline]
434    /// Add topic filter
435    pub fn topic_filter(mut self, filter: ByteString, qos: codec::QoS) -> Self {
436        self.topic_filters.push((filter, qos));
437        self
438    }
439
440    #[inline]
441    /// Get size of the subscribe packet
442    pub fn size(&self) -> u32 {
443        codec::encode::get_encoded_subscribe_size(&self.topic_filters) as u32
444    }
445
446    /// Send subscribe packet
447    pub async fn send(self) -> Result<Vec<codec::SubscribeReturnCode>, SendPacketError> {
448        if !self.shared.is_closed() {
449            // handle client receive maximum
450            if let Some(rx) = self.shared.wait_readiness() {
451                if rx.await.is_err() {
452                    return Err(SendPacketError::Disconnected);
453                }
454            }
455            let idx = self.id.unwrap_or_else(|| self.shared.next_id());
456            let rx = self.shared.wait_response(idx, AckType::Subscribe)?;
457
458            // send subscribe to client
459            log::trace!(
460                "Sending subscribe packet id: {} filters:{:?}",
461                idx,
462                self.topic_filters
463            );
464
465            match self.shared.encode_packet(codec::Packet::Subscribe {
466                packet_id: idx,
467                topic_filters: self.topic_filters,
468            }) {
469                Ok(_) => {
470                    // wait ack from peer
471                    rx.await
472                        .map_err(|_| SendPacketError::Disconnected)
473                        .map(|pkt| pkt.subscribe())
474                }
475                Err(err) => Err(SendPacketError::Encode(err)),
476            }
477        } else {
478            Err(SendPacketError::Disconnected)
479        }
480    }
481}
482
483/// Unsubscribe packet builder
484pub struct UnsubscribeBuilder {
485    id: Option<NonZeroU16>,
486    shared: Rc<MqttShared>,
487    topic_filters: Vec<ByteString>,
488}
489
490impl UnsubscribeBuilder {
491    #[inline]
492    /// Set packet id.
493    ///
494    /// panics if id is 0
495    pub fn packet_id(mut self, id: u16) -> Self {
496        if let Some(id) = NonZeroU16::new(id) {
497            self.id = Some(id);
498            self
499        } else {
500            panic!("id 0 is not allowed");
501        }
502    }
503
504    #[inline]
505    /// Add topic filter
506    pub fn topic_filter(mut self, filter: ByteString) -> Self {
507        self.topic_filters.push(filter);
508        self
509    }
510
511    #[inline]
512    /// Get size of the unsubscribe packet
513    pub fn size(&self) -> u32 {
514        codec::encode::get_encoded_unsubscribe_size(&self.topic_filters) as u32
515    }
516
517    /// Send unsubscribe packet
518    pub async fn send(self) -> Result<(), SendPacketError> {
519        let shared = self.shared;
520        let filters = self.topic_filters;
521
522        if !shared.is_closed() {
523            // handle client receive maximum
524            if let Some(rx) = shared.wait_readiness() {
525                if rx.await.is_err() {
526                    return Err(SendPacketError::Disconnected);
527                }
528            }
529            // allocate packet id
530            let idx = self.id.unwrap_or_else(|| shared.next_id());
531            let rx = shared.wait_response(idx, AckType::Unsubscribe)?;
532
533            // send subscribe to client
534            log::trace!("Sending unsubscribe packet id: {} filters:{:?}", idx, filters);
535
536            match shared.encode_packet(codec::Packet::Unsubscribe {
537                packet_id: idx,
538                topic_filters: filters,
539            }) {
540                Ok(_) => {
541                    // wait ack from peer
542                    rx.await.map_err(|_| SendPacketError::Disconnected).map(|_| ())
543                }
544                Err(err) => Err(SendPacketError::Encode(err)),
545            }
546        } else {
547            Err(SendPacketError::Disconnected)
548        }
549    }
550}
551
552pub struct StreamingPayload {
553    shared: Rc<MqttShared>,
554    rx: Cell<Option<pool::Receiver<()>>>,
555    inprocess: Cell<bool>,
556}
557
558impl Drop for StreamingPayload {
559    fn drop(&mut self) {
560        if self.inprocess.get() && self.shared.is_streaming() {
561            self.shared.streaming_dropped();
562        }
563    }
564}
565
566impl StreamingPayload {
567    /// Send payload chunk
568    pub async fn send(&self, chunk: Bytes) -> Result<(), SendPacketError> {
569        if let Some(rx) = self.rx.take() {
570            if rx.await.is_err() {
571                return Err(SendPacketError::StreamingCancelled);
572            }
573            log::trace!("Publish is encoded, ready to process payload");
574            self.inprocess.set(true);
575        }
576
577        if !self.inprocess.get() {
578            Err(EncodeError::UnexpectedPayload.into())
579        } else {
580            log::trace!("Sending payload chunk: {:?}", chunk.len());
581            self.shared.want_payload_stream().await?;
582
583            if !self.shared.encode_publish_payload(chunk)? {
584                self.inprocess.set(false);
585            }
586            Ok(())
587        }
588    }
589}