1use std::num::{NonZeroU16, NonZeroU32};
2use std::{cell::Cell, fmt, future::ready, future::Future, rc::Rc};
3
4use ntex_bytes::{ByteString, Bytes};
5use ntex_util::{channel::pool, future::Either, future::Ready};
6
7use super::codec::{self, EncodeLtd};
8use super::shared::{AckType, MqttShared};
9use crate::{error::EncodeError, error::SendPacketError, types::QoS};
10
11pub struct MqttSink(Rc<MqttShared>);
12
13impl Clone for MqttSink {
14 fn clone(&self) -> Self {
15 MqttSink(self.0.clone())
16 }
17}
18
19impl MqttSink {
20 pub(super) fn new(state: Rc<MqttShared>) -> Self {
21 MqttSink(state)
22 }
23
24 pub(super) fn shared(&self) -> Rc<MqttShared> {
25 self.0.clone()
26 }
27
28 #[inline]
29 pub fn is_open(&self) -> bool {
31 !self.0.is_closed()
32 }
33
34 #[inline]
35 pub fn is_ready(&self) -> bool {
37 if self.0.is_closed() {
38 false
39 } else {
40 self.0.is_ready()
41 }
42 }
43
44 #[inline]
45 pub fn credit(&self) -> usize {
47 self.0.credit()
48 }
49
50 pub fn ready(&self) -> impl Future<Output = bool> {
54 if !self.0.is_closed() {
55 self.0
56 .wait_readiness()
57 .map(|rx| Either::Right(async move { rx.await.is_ok() }))
58 .unwrap_or_else(|| Either::Left(ready(true)))
59 } else {
60 Either::Left(ready(false))
61 }
62 }
63
64 #[inline]
65 pub fn force_close(&self) {
68 self.0.force_close();
69 }
70
71 #[inline]
72 pub fn close(&self) {
74 self.0.close(codec::Disconnect::default());
75 }
76
77 #[inline]
78 pub fn close_with_reason(&self, pkt: codec::Disconnect) {
80 self.0.close(pkt);
81 }
82
83 pub(super) fn ping(&self) -> bool {
85 self.0.encode_packet(codec::Packet::PingRequest).is_ok()
86 }
87
88 #[inline]
89 pub fn publish<U>(&self, topic: U, payload: Bytes) -> PublishBuilder
91 where
92 ByteString: From<U>,
93 {
94 self.publish_pkt(
95 codec::Publish {
96 dup: false,
97 retain: false,
98 topic: topic.into(),
99 qos: QoS::AtMostOnce,
100 packet_id: None,
101 payload_size: 0,
102 properties: codec::PublishProperties::default(),
103 },
104 payload,
105 )
106 }
107
108 #[inline]
109 pub fn publish_pkt(&self, packet: codec::Publish, payload: Bytes) -> PublishBuilder {
111 PublishBuilder::new(self.0.clone(), packet, payload)
112 }
113
114 pub fn publish_ack_cb<F>(&self, f: F)
119 where
120 F: Fn(codec::PublishAck, bool) + 'static,
121 {
122 self.0.set_publish_ack(Box::new(f));
123 }
124
125 #[inline]
126 pub fn subscribe(&self, id: Option<NonZeroU32>) -> SubscribeBuilder {
128 SubscribeBuilder {
129 id: None,
130 packet: codec::Subscribe {
131 id,
132 packet_id: NonZeroU16::new(1).unwrap(),
133 user_properties: Vec::new(),
134 topic_filters: Vec::new(),
135 },
136 shared: self.0.clone(),
137 }
138 }
139
140 #[inline]
141 pub fn unsubscribe(&self) -> UnsubscribeBuilder {
143 UnsubscribeBuilder {
144 id: None,
145 packet: codec::Unsubscribe {
146 packet_id: NonZeroU16::new(1).unwrap(),
147 user_properties: Vec::new(),
148 topic_filters: Vec::new(),
149 },
150 shared: self.0.clone(),
151 }
152 }
153}
154
155impl fmt::Debug for MqttSink {
156 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
157 fmt.debug_struct("MqttSink").finish()
158 }
159}
160
161pub struct PublishBuilder {
162 shared: Rc<MqttShared>,
163 packet: codec::Publish,
164 payload: Bytes,
165}
166
167impl PublishBuilder {
168 fn new(shared: Rc<MqttShared>, mut packet: codec::Publish, payload: Bytes) -> Self {
169 packet.payload_size = payload.len() as u32;
170 Self { shared, packet, payload }
171 }
172
173 #[inline]
174 pub fn packet_id(mut self, id: u16) -> Self {
182 let id = NonZeroU16::new(id).expect("id 0 is not allowed");
183 self.packet.packet_id = Some(id);
184 self
185 }
186
187 #[inline]
188 pub fn dup(mut self, val: bool) -> Self {
190 self.packet.dup = val;
191 self
192 }
193
194 #[inline]
195 pub fn retain(mut self, val: bool) -> Self {
197 self.packet.retain = val;
198 self
199 }
200
201 #[inline]
202 pub fn properties<F>(mut self, f: F) -> Self
204 where
205 F: FnOnce(&mut codec::PublishProperties),
206 {
207 f(&mut self.packet.properties);
208 self
209 }
210
211 #[inline]
212 pub fn set_properties<F>(&mut self, f: F)
214 where
215 F: FnOnce(&mut codec::PublishProperties),
216 {
217 f(&mut self.packet.properties);
218 }
219
220 #[inline]
221 pub fn size(&self) -> u32 {
223 self.packet.encoded_size(u32::MAX) as u32
224 }
225
226 pub fn streaming(mut self, size: u32) -> (StreamingPublishBuilder, StreamingPayload) {
228 self.packet.payload_size = size;
229 let payload = if self.payload.is_empty() { None } else { Some(self.payload) };
230
231 let (tx, rx) = self.shared.pool.waiters.channel();
232 (
233 StreamingPublishBuilder {
234 size,
235 payload,
236 tx: Some(tx),
237 shared: self.shared.clone(),
238 packet: self.packet,
239 },
240 StreamingPayload {
241 rx: Cell::new(Some(rx)),
242 shared: self.shared.clone(),
243 inprocess: Cell::new(false),
244 },
245 )
246 }
247
248 #[inline]
249 pub fn send_at_most_once(mut self) -> Result<(), SendPacketError> {
251 if !self.shared.is_closed() {
252 log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
253 self.packet.qos = QoS::AtMostOnce;
254 self.shared
255 .encode_publish(self.packet, Some(self.payload))
256 .map_err(SendPacketError::Encode)
257 .map(|_| ())
258 } else {
259 log::error!("Mqtt sink is disconnected");
260 Err(SendPacketError::Disconnected)
261 }
262 }
263
264 pub fn send_at_least_once(
266 mut self,
267 ) -> impl Future<Output = Result<codec::PublishAck, SendPacketError>> {
268 if !self.shared.is_closed() {
269 self.packet.qos = QoS::AtLeastOnce;
270
271 if let Some(rx) = self.shared.wait_readiness() {
273 Either::Left(Either::Left(async move {
274 if rx.await.is_err() {
275 return Err(SendPacketError::Disconnected);
276 }
277 self.send_at_least_once_inner().await
278 }))
279 } else {
280 Either::Left(Either::Right(self.send_at_least_once_inner()))
281 }
282 } else {
283 Either::Right(Ready::Err(SendPacketError::Disconnected))
284 }
285 }
286
287 pub fn send_at_least_once_no_block(mut self) -> Result<(), SendPacketError> {
291 if !self.shared.is_closed() {
292 if !self.shared.is_ready() {
294 panic!("Mqtt sink is not ready");
295 }
296 self.packet.qos = codec::QoS::AtLeastOnce;
297 let idx = self.shared.set_publish_id(&mut self.packet);
298
299 log::trace!("Publish (QoS1) to {:#?}", self.packet);
300 self.shared.wait_publish_response_no_block(
301 idx,
302 AckType::Publish,
303 self.packet,
304 Some(self.payload),
305 )
306 } else {
307 Err(SendPacketError::Disconnected)
308 }
309 }
310
311 async fn send_at_least_once_inner(mut self) -> Result<codec::PublishAck, SendPacketError> {
312 let idx = self.shared.set_publish_id(&mut self.packet);
314
315 log::trace!("Publish (QoS1) to {:#?}", self.packet);
317 self.shared
318 .wait_publish_response(idx, AckType::Publish, self.packet, Some(self.payload))?
319 .await
320 .map(|pkt| pkt.publish())
321 .map_err(|_| SendPacketError::Disconnected)
322 }
323
324 pub fn send_exactly_once(
326 mut self,
327 ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
328 if !self.shared.is_closed() {
329 self.packet.qos = codec::QoS::ExactlyOnce;
330
331 if let Some(rx) = self.shared.wait_readiness() {
333 Either::Left(Either::Left(async move {
334 if rx.await.is_err() {
335 return Err(SendPacketError::Disconnected);
336 }
337 self.send_exactly_once_inner().await
338 }))
339 } else {
340 Either::Left(Either::Right(self.send_exactly_once_inner()))
341 }
342 } else {
343 Either::Right(Ready::Err(SendPacketError::Disconnected))
344 }
345 }
346
347 fn send_exactly_once_inner(
348 mut self,
349 ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
350 let shared = self.shared.clone();
351 let idx = shared.set_publish_id(&mut self.packet);
352 log::trace!("Publish (QoS2) to {:#?}", self.packet);
353
354 let rx = shared.wait_publish_response(
355 idx,
356 AckType::Receive,
357 self.packet,
358 Some(self.payload),
359 );
360 async move {
361 rx?.await
362 .map(move |ack| PublishReceived::new(ack.receive(), shared))
363 .map_err(|_| SendPacketError::Disconnected)
364 }
365 }
366}
367
368pub struct PublishReceived {
370 ack: codec::PublishAck,
371 result: Option<codec::PublishAck2>,
372 shared: Rc<MqttShared>,
373}
374
375impl PublishReceived {
376 fn new(ack: codec::PublishAck, shared: Rc<MqttShared>) -> Self {
377 let packet_id = ack.packet_id;
378 Self {
379 ack,
380 shared,
381 result: Some(codec::PublishAck2 {
382 packet_id,
383 reason_code: codec::PublishAck2Reason::Success,
384 properties: codec::UserProperties::default(),
385 reason_string: None,
386 }),
387 }
388 }
389
390 pub fn packet(&self) -> &codec::PublishAck {
392 &self.ack
393 }
394
395 #[inline]
397 pub fn properties<F>(mut self, f: F) -> Self
398 where
399 F: FnOnce(&mut codec::UserProperties),
400 {
401 f(&mut self.result.as_mut().unwrap().properties);
402 self
403 }
404
405 #[inline]
407 pub fn reason(mut self, reason: ByteString) -> Self {
408 self.result.as_mut().unwrap().reason_string = Some(reason);
409 self
410 }
411
412 pub async fn release(mut self) -> Result<(), SendPacketError> {
414 let rx = self.shared.release_publish(self.result.take().unwrap())?;
415
416 rx.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
417 }
418}
419
420impl Drop for PublishReceived {
421 fn drop(&mut self) {
422 if let Some(ack) = self.result.take() {
423 self.shared.release_publish(ack);
424 }
425 }
426}
427
428pub struct SubscribeBuilder {
430 id: Option<NonZeroU16>,
431 packet: codec::Subscribe,
432 shared: Rc<MqttShared>,
433}
434
435impl SubscribeBuilder {
436 #[inline]
437 pub fn packet_id(mut self, id: u16) -> Self {
441 if let Some(id) = NonZeroU16::new(id) {
442 self.id = Some(id);
443 self
444 } else {
445 panic!("id 0 is not allowed");
446 }
447 }
448
449 #[inline]
450 pub fn topic_filter(
452 mut self,
453 filter: ByteString,
454 opts: codec::SubscriptionOptions,
455 ) -> Self {
456 self.packet.topic_filters.push((filter, opts));
457 self
458 }
459
460 #[inline]
461 pub fn property(mut self, key: ByteString, value: ByteString) -> Self {
463 self.packet.user_properties.push((key, value));
464 self
465 }
466
467 #[inline]
468 pub fn size(&self) -> u32 {
470 self.packet.encoded_size(u32::MAX) as u32
471 }
472
473 pub async fn send(self) -> Result<codec::SubscribeAck, SendPacketError> {
475 let shared = self.shared;
476 let mut packet = self.packet;
477
478 if !shared.is_closed() {
479 if let Some(rx) = shared.wait_readiness() {
481 if rx.await.is_err() {
482 return Err(SendPacketError::Disconnected);
483 }
484 }
485
486 packet.packet_id = self.id.unwrap_or_else(|| shared.next_id());
488
489 log::trace!("Sending subscribe packet {:#?}", packet);
491
492 let rx = shared.wait_response(packet.packet_id, AckType::Subscribe)?;
493 match shared.encode_packet(codec::Packet::Subscribe(packet)) {
494 Ok(_) => {
495 rx.await
497 .map_err(|_| SendPacketError::Disconnected)
498 .map(|pkt| pkt.subscribe())
499 }
500 Err(err) => Err(SendPacketError::Encode(err)),
501 }
502 } else {
503 Err(SendPacketError::Disconnected)
504 }
505 }
506}
507
508pub struct UnsubscribeBuilder {
510 id: Option<NonZeroU16>,
511 packet: codec::Unsubscribe,
512 shared: Rc<MqttShared>,
513}
514
515impl UnsubscribeBuilder {
516 #[inline]
517 pub fn packet_id(mut self, id: u16) -> Self {
521 if let Some(id) = NonZeroU16::new(id) {
522 self.id = Some(id);
523 self
524 } else {
525 panic!("id 0 is not allowed");
526 }
527 }
528
529 #[inline]
530 pub fn topic_filter(mut self, filter: ByteString) -> Self {
532 self.packet.topic_filters.push(filter);
533 self
534 }
535
536 #[inline]
537 pub fn property(mut self, key: ByteString, value: ByteString) -> Self {
539 self.packet.user_properties.push((key, value));
540 self
541 }
542
543 #[inline]
544 pub fn size(&self) -> u32 {
546 self.packet.encoded_size(u32::MAX) as u32
547 }
548
549 pub async fn send(self) -> Result<codec::UnsubscribeAck, SendPacketError> {
551 let shared = self.shared;
552 let mut packet = self.packet;
553
554 if !shared.is_closed() {
555 if let Some(rx) = shared.wait_readiness() {
557 if rx.await.is_err() {
558 return Err(SendPacketError::Disconnected);
559 }
560 }
561 packet.packet_id = self.id.unwrap_or_else(|| shared.next_id());
563
564 log::trace!("Sending unsubscribe packet {:#?}", packet);
566
567 let rx = shared.wait_response(packet.packet_id, AckType::Unsubscribe)?;
568 match shared.encode_packet(codec::Packet::Unsubscribe(packet)) {
569 Ok(_) => {
570 rx.await
572 .map_err(|_| SendPacketError::Disconnected)
573 .map(|pkt| pkt.unsubscribe())
574 }
575 Err(err) => Err(SendPacketError::Encode(err)),
576 }
577 } else {
578 Err(SendPacketError::Disconnected)
579 }
580 }
581}
582
583pub struct StreamingPublishBuilder {
584 shared: Rc<MqttShared>,
585 packet: codec::Publish,
586 payload: Option<Bytes>,
587 size: u32,
588 tx: Option<pool::Sender<()>>,
589}
590
591impl StreamingPublishBuilder {
592 fn notify_payload_streamer(&mut self) -> Result<(), SendPacketError> {
593 if let Some(tx) = self.tx.take() {
594 tx.send(()).map_err(|_| SendPacketError::StreamingCancelled)
595 } else {
596 Ok(())
597 }
598 }
599
600 pub fn send_at_most_once(mut self) -> Result<(), SendPacketError> {
602 if !self.shared.is_closed() {
603 log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
604 self.notify_payload_streamer()?;
605
606 self.packet.qos = QoS::AtMostOnce;
607 self.shared
608 .encode_publish(self.packet, self.payload)
609 .map_err(SendPacketError::Encode)
610 .map(|_| ())
611 } else {
612 log::error!("Mqtt sink is disconnected");
613 Err(SendPacketError::Disconnected)
614 }
615 }
616
617 pub fn send_at_least_once(
619 mut self,
620 ) -> impl Future<Output = Result<codec::PublishAck, SendPacketError>> {
621 if !self.shared.is_closed() {
622 self.packet.qos = QoS::AtLeastOnce;
623
624 if let Some(rx) = self.shared.wait_readiness() {
626 Either::Left(Either::Left(async move {
627 if rx.await.is_err() {
628 return Err(SendPacketError::Disconnected);
629 }
630 self.send_at_least_once_inner().await
631 }))
632 } else {
633 Either::Left(Either::Right(self.send_at_least_once_inner()))
634 }
635 } else {
636 Either::Right(Ready::Err(SendPacketError::Disconnected))
637 }
638 }
639
640 pub fn send_at_least_once_no_block(mut self) -> Result<(), SendPacketError> {
644 if !self.shared.is_closed() {
645 if !self.shared.is_ready() {
647 panic!("Mqtt sink is not ready");
648 }
649 self.packet.qos = codec::QoS::AtLeastOnce;
650 let tx = self.tx.take().unwrap();
651 let idx = self.shared.set_publish_id(&mut self.packet);
652
653 if tx.is_canceled() {
654 Err(SendPacketError::StreamingCancelled)
655 } else {
656 log::trace!("Publish (QoS1) to {:#?}", self.packet);
657 let _ = tx.send(());
658 self.shared.wait_publish_response_no_block(
659 idx,
660 AckType::Publish,
661 self.packet,
662 self.payload,
663 )
664 }
665 } else {
666 Err(SendPacketError::Disconnected)
667 }
668 }
669
670 async fn send_at_least_once_inner(mut self) -> Result<codec::PublishAck, SendPacketError> {
671 let idx = self.shared.set_publish_id(&mut self.packet);
673
674 log::trace!("Publish (QoS1) to {:#?}", self.packet);
676
677 let tx = self.tx.take().unwrap();
678 if tx.is_canceled() {
679 Err(SendPacketError::StreamingCancelled)
680 } else {
681 let rx = self.shared.wait_publish_response(
682 idx,
683 AckType::Publish,
684 self.packet,
685 self.payload,
686 );
687 let _ = tx.send(());
688
689 rx?.await.map(|pkt| pkt.publish()).map_err(|_| SendPacketError::Disconnected)
690 }
691 }
692}
693
694pub struct StreamingPayload {
695 shared: Rc<MqttShared>,
696 rx: Cell<Option<pool::Receiver<()>>>,
697 inprocess: Cell<bool>,
698}
699
700impl StreamingPayload {
701 fn drop(&mut self) {
702 if self.inprocess.get() {
703 if self.shared.is_streaming() {
704 self.shared.streaming_dropped();
705 }
706 }
707 }
708}
709
710impl StreamingPayload {
711 pub async fn send(&self, chunk: Bytes) -> Result<(), SendPacketError> {
713 if let Some(rx) = self.rx.take() {
714 if rx.await.is_err() {
715 return Err(SendPacketError::StreamingCancelled);
716 }
717 log::trace!("Publish is encoded, ready to process payload");
718 self.inprocess.set(true);
719 }
720
721 if !self.inprocess.get() {
722 Err(EncodeError::UnexpectedPayload.into())
723 } else {
724 log::trace!("Sending payload chunk: {:?}", chunk.len());
725 self.shared.want_payload_stream().await?;
726
727 if !self.shared.encode_publish_payload(chunk)? {
728 self.inprocess.set(false);
729 }
730 Ok(())
731 }
732 }
733}