1use std::num::{NonZeroU16, NonZeroU32};
2use std::{cell::Cell, fmt, future::Future, future::ready, rc::Rc};
3
4use ntex_bytes::{ByteString, Bytes};
5use ntex_util::{channel::pool, future::Either, future::Ready};
6
7use super::codec::{self, EncodeLtd};
8use super::shared::{AckType, MqttShared};
9use crate::{error::EncodeError, error::SendPacketError, types::QoS};
10
/// Handle used to send MQTT packets over a connection.
///
/// The sink is a thin wrapper around the connection state shared through an
/// `Rc<MqttShared>`; cloning the sink yields another handle to the same state.
pub struct MqttSink(Rc<MqttShared>);
12
13impl Clone for MqttSink {
14 fn clone(&self) -> Self {
15 MqttSink(self.0.clone())
16 }
17}
18
19impl MqttSink {
20 pub(super) fn new(state: Rc<MqttShared>) -> Self {
21 MqttSink(state)
22 }
23
24 pub(super) fn shared(&self) -> Rc<MqttShared> {
25 self.0.clone()
26 }
27
28 #[inline]
29 pub fn is_open(&self) -> bool {
31 !self.0.is_closed()
32 }
33
34 #[inline]
35 pub fn is_ready(&self) -> bool {
37 if self.0.is_closed() {
38 false
39 } else {
40 self.0.is_ready()
41 }
42 }
43
44 #[inline]
45 pub fn credit(&self) -> usize {
47 self.0.credit()
48 }
49
50 pub fn ready(&self) -> impl Future<Output = bool> {
54 if !self.0.is_closed() {
55 self.0
56 .wait_readiness()
57 .map(|rx| Either::Right(async move { rx.await.is_ok() }))
58 .unwrap_or_else(|| Either::Left(ready(true)))
59 } else {
60 Either::Left(ready(false))
61 }
62 }
63
64 #[inline]
65 pub fn force_close(&self) {
68 self.0.force_close();
69 }
70
71 #[inline]
72 pub fn close(&self) {
74 self.0.close(codec::Disconnect::default());
75 }
76
77 #[inline]
78 pub fn close_with_reason(&self, pkt: codec::Disconnect) {
80 self.0.close(pkt);
81 }
82
83 pub(super) fn ping(&self) -> bool {
85 self.0.encode_packet(codec::Packet::PingRequest).is_ok()
86 }
87
88 #[inline]
89 pub fn publish<U>(&self, topic: U) -> PublishBuilder
91 where
92 ByteString: From<U>,
93 {
94 self.publish_pkt(codec::Publish {
95 dup: false,
96 retain: false,
97 topic: topic.into(),
98 qos: QoS::AtMostOnce,
99 packet_id: None,
100 payload_size: 0,
101 properties: codec::PublishProperties::default(),
102 })
103 }
104
105 #[inline]
106 pub fn publish_pkt(&self, packet: codec::Publish) -> PublishBuilder {
108 PublishBuilder::new(self.0.clone(), packet)
109 }
110
111 pub fn publish_ack_cb<F>(&self, f: F)
116 where
117 F: Fn(codec::PublishAck, bool) + 'static,
118 {
119 self.0.set_publish_ack(Box::new(f));
120 }
121
122 #[inline]
123 pub fn subscribe(&self, id: Option<NonZeroU32>) -> SubscribeBuilder {
125 SubscribeBuilder {
126 id: None,
127 packet: codec::Subscribe {
128 id,
129 packet_id: NonZeroU16::new(1).unwrap(),
130 user_properties: Vec::new(),
131 topic_filters: Vec::new(),
132 },
133 shared: self.0.clone(),
134 }
135 }
136
137 #[inline]
138 pub fn unsubscribe(&self) -> UnsubscribeBuilder {
140 UnsubscribeBuilder {
141 id: None,
142 packet: codec::Unsubscribe {
143 packet_id: NonZeroU16::new(1).unwrap(),
144 user_properties: Vec::new(),
145 topic_filters: Vec::new(),
146 },
147 shared: self.0.clone(),
148 }
149 }
150}
151
152impl fmt::Debug for MqttSink {
153 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
154 fmt.debug_struct("MqttSink").finish()
155 }
156}
157
/// Builder for an outgoing publish packet.
///
/// Created by `MqttSink::publish` and `MqttSink::publish_pkt`. The `send_*`
/// and `stream_*` methods consume the builder and hand the packet to the
/// connection.
pub struct PublishBuilder {
    /// Connection state the packet is sent through.
    shared: Rc<MqttShared>,
    /// Publish packet being assembled.
    packet: codec::Publish,
}
162
163impl PublishBuilder {
164 fn new(shared: Rc<MqttShared>, packet: codec::Publish) -> Self {
165 Self { shared, packet }
166 }
167
168 #[inline]
169 pub fn packet_id(mut self, id: u16) -> Self {
177 let id = NonZeroU16::new(id).expect("id 0 is not allowed");
178 self.packet.packet_id = Some(id);
179 self
180 }
181
182 #[inline]
183 pub fn dup(mut self, val: bool) -> Self {
185 self.packet.dup = val;
186 self
187 }
188
189 #[inline]
190 pub fn retain(mut self, val: bool) -> Self {
192 self.packet.retain = val;
193 self
194 }
195
196 #[inline]
197 pub fn properties<F>(mut self, f: F) -> Self
199 where
200 F: FnOnce(&mut codec::PublishProperties),
201 {
202 f(&mut self.packet.properties);
203 self
204 }
205
206 #[inline]
207 pub fn set_properties<F>(&mut self, f: F)
209 where
210 F: FnOnce(&mut codec::PublishProperties),
211 {
212 f(&mut self.packet.properties);
213 }
214
215 #[inline]
216 pub fn size(&self, payload_size: usize) -> u32 {
218 (self.packet.encoded_size(u32::MAX) + payload_size) as u32
219 }
220
221 #[inline]
222 pub fn send_at_most_once(mut self, payload: Bytes) -> Result<(), SendPacketError> {
224 if !self.shared.is_closed() {
225 log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
226 self.packet.qos = QoS::AtMostOnce;
227 self.packet.payload_size = payload.len() as u32;
228 self.shared
229 .encode_publish(self.packet, Some(payload))
230 .map_err(SendPacketError::Encode)
231 .map(|_| ())
232 } else {
233 log::error!("Mqtt sink is disconnected");
234 Err(SendPacketError::Disconnected)
235 }
236 }
237
238 pub fn stream_at_most_once(
240 mut self,
241 size: u32,
242 ) -> Result<StreamingPayload, SendPacketError> {
243 if !self.shared.is_closed() {
244 log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
245
246 let stream = StreamingPayload {
247 rx: Cell::new(None),
248 shared: self.shared.clone(),
249 inprocess: Cell::new(true),
250 };
251
252 self.packet.qos = QoS::AtMostOnce;
253 self.packet.payload_size = size;
254 self.shared
255 .encode_publish(self.packet, None)
256 .map_err(SendPacketError::Encode)
257 .map(|_| stream)
258 } else {
259 log::error!("Mqtt sink is disconnected");
260 Err(SendPacketError::Disconnected)
261 }
262 }
263
264 pub fn send_at_least_once(
266 mut self,
267 payload: Bytes,
268 ) -> impl Future<Output = Result<codec::PublishAck, SendPacketError>> {
269 if !self.shared.is_closed() {
270 self.packet.qos = QoS::AtLeastOnce;
271 self.packet.payload_size = payload.len() as u32;
272
273 if let Some(rx) = self.shared.wait_readiness() {
275 Either::Left(Either::Left(async move {
276 if rx.await.is_err() {
277 return Err(SendPacketError::Disconnected);
278 }
279 self.send_at_least_once_inner(payload).await
280 }))
281 } else {
282 Either::Left(Either::Right(self.send_at_least_once_inner(payload)))
283 }
284 } else {
285 Either::Right(Ready::Err(SendPacketError::Disconnected))
286 }
287 }
288
289 pub fn send_at_least_once_no_block(
293 mut self,
294 payload: Bytes,
295 ) -> Result<(), SendPacketError> {
296 if !self.shared.is_closed() {
297 if !self.shared.is_ready() {
299 panic!("Mqtt sink is not ready");
300 }
301 self.packet.qos = codec::QoS::AtLeastOnce;
302 self.packet.payload_size = payload.len() as u32;
303
304 let idx = self.shared.set_publish_id(&mut self.packet);
305
306 log::trace!("Publish (QoS1) to {:#?}", self.packet);
307 self.shared.wait_publish_response_no_block(
308 idx,
309 AckType::Publish,
310 self.packet,
311 Some(payload),
312 )
313 } else {
314 Err(SendPacketError::Disconnected)
315 }
316 }
317
318 pub fn stream_at_least_once(
320 mut self,
321 size: u32,
322 ) -> (impl Future<Output = Result<codec::PublishAck, SendPacketError>>, StreamingPayload)
323 {
324 let (tx, rx) = self.shared.pool.waiters.channel();
325 let stream = StreamingPayload {
326 rx: Cell::new(Some(rx)),
327 shared: self.shared.clone(),
328 inprocess: Cell::new(false),
329 };
330
331 if !self.shared.is_closed() {
332 self.packet.qos = QoS::AtLeastOnce;
333 self.packet.payload_size = size;
334
335 let fut = if let Some(rx) = self.shared.wait_readiness() {
337 Either::Left(Either::Left(async move {
338 if rx.await.is_err() {
339 return Err(SendPacketError::Disconnected);
340 }
341 self.stream_at_least_once_inner(tx, None).await
342 }))
343 } else {
344 Either::Left(Either::Right(self.stream_at_least_once_inner(tx, None)))
345 };
346 (fut, stream)
347 } else {
348 (Either::Right(Ready::Err(SendPacketError::Disconnected)), stream)
349 }
350 }
351
352 async fn send_at_least_once_inner(
353 mut self,
354 payload: Bytes,
355 ) -> Result<codec::PublishAck, SendPacketError> {
356 let idx = self.shared.set_publish_id(&mut self.packet);
358
359 log::trace!("Publish (QoS1) to {:#?}", self.packet);
361 self.shared
362 .wait_publish_response(idx, AckType::Publish, self.packet, Some(payload))?
363 .await
364 .map(|pkt| pkt.publish())
365 .map_err(|_| SendPacketError::Disconnected)
366 }
367
368 async fn stream_at_least_once_inner(
369 mut self,
370 tx: pool::Sender<()>,
371 chunk: Option<Bytes>,
372 ) -> Result<codec::PublishAck, SendPacketError> {
373 let idx = self.shared.set_publish_id(&mut self.packet);
375
376 log::trace!("Publish (QoS1) to {:#?}", self.packet);
378
379 if tx.is_canceled() {
380 Err(SendPacketError::StreamingCancelled)
381 } else {
382 let rx =
383 self.shared.wait_publish_response(idx, AckType::Publish, self.packet, chunk);
384 let _ = tx.send(());
385
386 rx?.await.map(|pkt| pkt.publish()).map_err(|_| SendPacketError::Disconnected)
387 }
388 }
389
390 pub fn send_exactly_once(
392 mut self,
393 payload: Bytes,
394 ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
395 if !self.shared.is_closed() {
396 self.packet.qos = codec::QoS::ExactlyOnce;
397 self.packet.payload_size = payload.len() as u32;
398
399 if let Some(rx) = self.shared.wait_readiness() {
401 Either::Left(Either::Left(async move {
402 if rx.await.is_err() {
403 return Err(SendPacketError::Disconnected);
404 }
405 self.send_exactly_once_inner(payload).await
406 }))
407 } else {
408 Either::Left(Either::Right(self.send_exactly_once_inner(payload)))
409 }
410 } else {
411 Either::Right(Ready::Err(SendPacketError::Disconnected))
412 }
413 }
414
415 fn send_exactly_once_inner(
416 mut self,
417 payload: Bytes,
418 ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
419 let shared = self.shared.clone();
420 let idx = shared.set_publish_id(&mut self.packet);
421 log::trace!("Publish (QoS2) to {:#?}", self.packet);
422
423 let rx =
424 shared.wait_publish_response(idx, AckType::Receive, self.packet, Some(payload));
425 async move {
426 rx?.await
427 .map(move |ack| PublishReceived::new(ack.receive(), shared))
428 .map_err(|_| SendPacketError::Disconnected)
429 }
430 }
431}
432
/// Acknowledgement received for an exactly-once publish, together with the
/// `PublishAck2` packet prepared to release it.
///
/// Call [`PublishReceived::release`] to pass the prepared packet on and wait
/// for the result. If the value is dropped without calling `release`, the
/// prepared packet is still passed to `MqttShared::release_publish`, but the
/// outcome is ignored.
pub struct PublishReceived {
    /// Acknowledgement received from the peer.
    ack: codec::PublishAck,
    /// Packet to release the publish with; `None` once it has been handed over.
    result: Option<codec::PublishAck2>,
    /// Connection state used to release the publish.
    shared: Rc<MqttShared>,
}
439
440impl PublishReceived {
441 fn new(ack: codec::PublishAck, shared: Rc<MqttShared>) -> Self {
442 let packet_id = ack.packet_id;
443 Self {
444 ack,
445 shared,
446 result: Some(codec::PublishAck2 {
447 packet_id,
448 reason_code: codec::PublishAck2Reason::Success,
449 properties: codec::UserProperties::default(),
450 reason_string: None,
451 }),
452 }
453 }
454
455 pub fn packet(&self) -> &codec::PublishAck {
457 &self.ack
458 }
459
460 #[inline]
462 pub fn properties<F>(mut self, f: F) -> Self
463 where
464 F: FnOnce(&mut codec::UserProperties),
465 {
466 f(&mut self.result.as_mut().unwrap().properties);
467 self
468 }
469
470 #[inline]
472 pub fn reason(mut self, reason: ByteString) -> Self {
473 self.result.as_mut().unwrap().reason_string = Some(reason);
474 self
475 }
476
477 pub async fn release(mut self) -> Result<(), SendPacketError> {
479 let rx = self.shared.release_publish(self.result.take().unwrap())?;
480
481 rx.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
482 }
483}
484
485impl Drop for PublishReceived {
486 fn drop(&mut self) {
487 if let Some(ack) = self.result.take() {
488 let _ = self.shared.release_publish(ack);
489 }
490 }
491}
492
/// Builder for an outgoing subscribe packet, created by `MqttSink::subscribe`.
pub struct SubscribeBuilder {
    /// Explicit packet id chosen by the caller; when `None`, `send` asks the
    /// shared state for the next id.
    id: Option<NonZeroU16>,
    /// Subscribe packet being assembled.
    packet: codec::Subscribe,
    /// Connection state the packet is sent through.
    shared: Rc<MqttShared>,
}
499
500impl SubscribeBuilder {
501 #[inline]
502 pub fn packet_id(mut self, id: u16) -> Self {
506 if let Some(id) = NonZeroU16::new(id) {
507 self.id = Some(id);
508 self
509 } else {
510 panic!("id 0 is not allowed");
511 }
512 }
513
514 #[inline]
515 pub fn topic_filter(
517 mut self,
518 filter: ByteString,
519 opts: codec::SubscriptionOptions,
520 ) -> Self {
521 self.packet.topic_filters.push((filter, opts));
522 self
523 }
524
525 #[inline]
526 pub fn property(mut self, key: ByteString, value: ByteString) -> Self {
528 self.packet.user_properties.push((key, value));
529 self
530 }
531
532 #[inline]
533 pub fn size(&self) -> u32 {
535 self.packet.encoded_size(u32::MAX) as u32
536 }
537
538 pub async fn send(self) -> Result<codec::SubscribeAck, SendPacketError> {
540 let shared = self.shared;
541 let mut packet = self.packet;
542
543 if !shared.is_closed() {
544 if let Some(rx) = shared.wait_readiness() {
546 if rx.await.is_err() {
547 return Err(SendPacketError::Disconnected);
548 }
549 }
550
551 packet.packet_id = self.id.unwrap_or_else(|| shared.next_id());
553
554 log::trace!("Sending subscribe packet {:#?}", packet);
556
557 let rx = shared.wait_response(packet.packet_id, AckType::Subscribe)?;
558 match shared.encode_packet(codec::Packet::Subscribe(packet)) {
559 Ok(_) => {
560 rx.await
562 .map_err(|_| SendPacketError::Disconnected)
563 .map(|pkt| pkt.subscribe())
564 }
565 Err(err) => Err(SendPacketError::Encode(err)),
566 }
567 } else {
568 Err(SendPacketError::Disconnected)
569 }
570 }
571}
572
/// Builder for an outgoing unsubscribe packet, created by `MqttSink::unsubscribe`.
pub struct UnsubscribeBuilder {
    /// Explicit packet id chosen by the caller; when `None`, `send` asks the
    /// shared state for the next id.
    id: Option<NonZeroU16>,
    /// Unsubscribe packet being assembled.
    packet: codec::Unsubscribe,
    /// Connection state the packet is sent through.
    shared: Rc<MqttShared>,
}
579
580impl UnsubscribeBuilder {
581 #[inline]
582 pub fn packet_id(mut self, id: u16) -> Self {
586 if let Some(id) = NonZeroU16::new(id) {
587 self.id = Some(id);
588 self
589 } else {
590 panic!("id 0 is not allowed");
591 }
592 }
593
594 #[inline]
595 pub fn topic_filter(mut self, filter: ByteString) -> Self {
597 self.packet.topic_filters.push(filter);
598 self
599 }
600
601 #[inline]
602 pub fn property(mut self, key: ByteString, value: ByteString) -> Self {
604 self.packet.user_properties.push((key, value));
605 self
606 }
607
608 #[inline]
609 pub fn size(&self) -> u32 {
611 self.packet.encoded_size(u32::MAX) as u32
612 }
613
614 pub async fn send(self) -> Result<codec::UnsubscribeAck, SendPacketError> {
616 let shared = self.shared;
617 let mut packet = self.packet;
618
619 if !shared.is_closed() {
620 if let Some(rx) = shared.wait_readiness() {
622 if rx.await.is_err() {
623 return Err(SendPacketError::Disconnected);
624 }
625 }
626 packet.packet_id = self.id.unwrap_or_else(|| shared.next_id());
628
629 log::trace!("Sending unsubscribe packet {:#?}", packet);
631
632 let rx = shared.wait_response(packet.packet_id, AckType::Unsubscribe)?;
633 match shared.encode_packet(codec::Packet::Unsubscribe(packet)) {
634 Ok(_) => {
635 rx.await
637 .map_err(|_| SendPacketError::Disconnected)
638 .map(|pkt| pkt.unsubscribe())
639 }
640 Err(err) => Err(SendPacketError::Encode(err)),
641 }
642 } else {
643 Err(SendPacketError::Disconnected)
644 }
645 }
646}
647
/// Handle for sending the payload of a publish packet in chunks.
///
/// Returned by `PublishBuilder::stream_at_most_once` and
/// `PublishBuilder::stream_at_least_once`.
pub struct StreamingPayload {
    /// Connection state the payload chunks are written to.
    shared: Rc<MqttShared>,
    /// Receiver awaited by the first `send` call before any chunk is written;
    /// `None` when there is nothing to wait for or once it has been consumed.
    rx: Cell<Option<pool::Receiver<()>>>,
    /// `true` while payload chunks may be sent through this handle.
    inprocess: Cell<bool>,
}
653
654impl Drop for StreamingPayload {
655 fn drop(&mut self) {
656 if self.inprocess.get() && self.shared.is_streaming() {
657 self.shared.streaming_dropped();
658 }
659 }
660}
661
662impl StreamingPayload {
663 pub async fn send(&self, chunk: Bytes) -> Result<(), SendPacketError> {
665 if let Some(rx) = self.rx.take() {
666 if rx.await.is_err() {
667 return Err(SendPacketError::StreamingCancelled);
668 }
669 log::trace!("Publish is encoded, ready to process payload");
670 self.inprocess.set(true);
671 }
672
673 if !self.inprocess.get() {
674 Err(EncodeError::UnexpectedPayload.into())
675 } else {
676 log::trace!("Sending payload chunk: {:?}", chunk.len());
677 self.shared.want_payload_stream().await?;
678
679 if !self.shared.encode_publish_payload(chunk)? {
680 self.inprocess.set(false);
681 }
682 Ok(())
683 }
684 }
685}