1use std::{cell::Cell, fmt, future::Future, future::ready, num::NonZeroU16, rc::Rc};
2
3use ntex_bytes::{ByteString, Bytes};
4use ntex_util::{channel::pool, future::Either, future::Ready};
5
6use crate::v3::shared::{AckType, MqttShared};
7use crate::v3::{codec, error::SendPacketError};
8use crate::{error::EncodeError, types::QoS};
9
10pub struct MqttSink(Rc<MqttShared>);
11
12impl Clone for MqttSink {
13 fn clone(&self) -> Self {
14 MqttSink(self.0.clone())
15 }
16}
17
18impl MqttSink {
19 pub(crate) fn new(state: Rc<MqttShared>) -> Self {
20 MqttSink(state)
21 }
22
23 pub(super) fn shared(&self) -> Rc<MqttShared> {
24 self.0.clone()
25 }
26
27 #[inline]
28 pub fn is_open(&self) -> bool {
30 !self.0.is_closed()
31 }
32
33 #[inline]
34 pub fn is_ready(&self) -> bool {
36 if self.0.is_closed() {
37 false
38 } else {
39 self.0.is_ready()
40 }
41 }
42
43 #[inline]
44 pub fn credit(&self) -> usize {
46 self.0.credit()
47 }
48
49 pub fn ready(&self) -> impl Future<Output = bool> {
53 if !self.0.is_closed() {
54 self.0
55 .wait_readiness()
56 .map(|rx| Either::Right(async move { rx.await.is_ok() }))
57 .unwrap_or_else(|| Either::Left(ready(true)))
58 } else {
59 Either::Left(ready(false))
60 }
61 }
62
63 #[inline]
64 pub fn close(&self) {
66 self.0.close();
67 }
68
69 #[inline]
70 pub fn force_close(&self) {
73 self.0.force_close();
74 }
75
76 #[inline]
77 pub(super) fn ping(&self) -> bool {
79 self.0.encode_packet(codec::Packet::PingRequest).is_ok()
80 }
81
82 #[inline]
83 pub fn publish<U>(&self, topic: U) -> PublishBuilder
85 where
86 ByteString: From<U>,
87 {
88 self.publish_pkt(codec::Publish {
89 dup: false,
90 retain: false,
91 topic: topic.into(),
92 qos: codec::QoS::AtMostOnce,
93 packet_id: None,
94 payload_size: 0,
95 })
96 }
97
98 #[inline]
99 pub fn publish_pkt(&self, packet: codec::Publish) -> PublishBuilder {
101 PublishBuilder { packet, shared: self.0.clone() }
102 }
103
104 pub fn publish_ack_cb<F>(&self, f: F)
109 where
110 F: Fn(NonZeroU16, bool) + 'static,
111 {
112 self.0.set_publish_ack(Box::new(f));
113 }
114
115 #[inline]
116 pub fn subscribe(&self) -> SubscribeBuilder {
120 SubscribeBuilder { id: None, topic_filters: Vec::new(), shared: self.0.clone() }
121 }
122
123 #[inline]
124 pub fn unsubscribe(&self) -> UnsubscribeBuilder {
126 UnsubscribeBuilder { id: None, topic_filters: Vec::new(), shared: self.0.clone() }
127 }
128}
129
130impl fmt::Debug for MqttSink {
131 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
132 fmt.debug_struct("MqttSink").finish()
133 }
134}
135
136pub struct PublishBuilder {
137 packet: codec::Publish,
138 shared: Rc<MqttShared>,
139}
140
141impl PublishBuilder {
142 #[inline]
143 pub fn packet_id(mut self, id: u16) -> Self {
151 let id = NonZeroU16::new(id).expect("id 0 is not allowed");
152 self.packet.packet_id = Some(id);
153 self
154 }
155
156 #[inline]
157 pub fn dup(mut self, val: bool) -> Self {
159 self.packet.dup = val;
160 self
161 }
162
163 #[inline]
164 pub fn retain(mut self) -> Self {
166 self.packet.retain = true;
167 self
168 }
169
170 #[inline]
171 pub fn size(&self, payload_size: usize) -> u32 {
173 (codec::encode::get_encoded_publish_size(&self.packet) + payload_size) as u32
174 }
175
176 #[inline]
177 pub fn send_at_most_once(mut self, payload: Bytes) -> Result<(), SendPacketError> {
179 if !self.shared.is_closed() {
180 log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
181 self.packet.qos = codec::QoS::AtMostOnce;
182 self.packet.payload_size = payload.len() as u32;
183 self.shared
184 .encode_publish(self.packet, Some(payload))
185 .map_err(SendPacketError::Encode)
186 .map(|_| ())
187 } else {
188 log::error!("Mqtt sink is disconnected");
189 Err(SendPacketError::Disconnected)
190 }
191 }
192
193 pub fn stream_at_most_once(
195 mut self,
196 size: u32,
197 ) -> Result<StreamingPayload, SendPacketError> {
198 if !self.shared.is_closed() {
199 log::trace!("Publish (QoS-0) to {:?}", self.packet.topic);
200
201 let stream = StreamingPayload {
202 rx: Cell::new(None),
203 shared: self.shared.clone(),
204 inprocess: Cell::new(true),
205 };
206
207 self.packet.qos = QoS::AtMostOnce;
208 self.packet.payload_size = size;
209 self.shared
210 .encode_publish(self.packet, None)
211 .map_err(SendPacketError::Encode)
212 .map(|_| stream)
213 } else {
214 log::error!("Mqtt sink is disconnected");
215 Err(SendPacketError::Disconnected)
216 }
217 }
218
219 pub fn send_at_least_once(
221 mut self,
222 payload: Bytes,
223 ) -> impl Future<Output = Result<(), SendPacketError>> {
224 if !self.shared.is_closed() {
225 self.packet.qos = codec::QoS::AtLeastOnce;
226 self.packet.payload_size = payload.len() as u32;
227
228 if let Some(rx) = self.shared.wait_readiness() {
230 Either::Left(Either::Left(async move {
231 if rx.await.is_err() {
232 return Err(SendPacketError::Disconnected);
233 }
234 self.send_at_least_once_inner(payload).await
235 }))
236 } else {
237 Either::Left(Either::Right(self.send_at_least_once_inner(payload)))
238 }
239 } else {
240 Either::Right(Ready::Err(SendPacketError::Disconnected))
241 }
242 }
243
244 pub fn send_at_least_once_no_block(
248 mut self,
249 payload: Bytes,
250 ) -> Result<(), SendPacketError> {
251 if !self.shared.is_closed() {
252 if !self.shared.is_ready() {
254 panic!("Mqtt sink is not ready");
255 }
256 self.packet.qos = codec::QoS::AtLeastOnce;
257 self.packet.payload_size = payload.len() as u32;
258 let idx = self.shared.set_publish_id(&mut self.packet);
259
260 log::trace!("Publish (QoS1) to {:#?}", self.packet);
261
262 self.shared.wait_publish_response_no_block(
263 idx,
264 AckType::Publish,
265 self.packet,
266 Some(payload),
267 )
268 } else {
269 Err(SendPacketError::Disconnected)
270 }
271 }
272
273 fn send_at_least_once_inner(
274 mut self,
275 payload: Bytes,
276 ) -> impl Future<Output = Result<(), SendPacketError>> {
277 let idx = self.shared.set_publish_id(&mut self.packet);
278 log::trace!("Publish (QoS1) to {:#?}", self.packet);
279
280 let rx = self.shared.wait_publish_response(
281 idx,
282 AckType::Publish,
283 self.packet,
284 Some(payload),
285 );
286 async move { rx?.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected) }
287 }
288
289 pub fn send_exactly_once(
291 mut self,
292 payload: Bytes,
293 ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
294 if !self.shared.is_closed() {
295 self.packet.qos = codec::QoS::ExactlyOnce;
296 self.packet.payload_size = payload.len() as u32;
297
298 if let Some(rx) = self.shared.wait_readiness() {
300 Either::Left(Either::Left(async move {
301 if rx.await.is_err() {
302 return Err(SendPacketError::Disconnected);
303 }
304 self.send_exactly_once_inner(payload).await
305 }))
306 } else {
307 Either::Left(Either::Right(self.send_exactly_once_inner(payload)))
308 }
309 } else {
310 Either::Right(Ready::Err(SendPacketError::Disconnected))
311 }
312 }
313
314 fn send_exactly_once_inner(
315 mut self,
316 payload: Bytes,
317 ) -> impl Future<Output = Result<PublishReceived, SendPacketError>> {
318 let idx = self.shared.set_publish_id(&mut self.packet);
319 log::trace!("Publish (QoS2) to {:#?}", self.packet);
320
321 let rx = self.shared.wait_publish_response(
322 idx,
323 AckType::Receive,
324 self.packet,
325 Some(payload),
326 );
327 async move {
328 rx?.await
329 .map(move |_| PublishReceived { packet_id: Some(idx), shared: self.shared })
330 .map_err(|_| SendPacketError::Disconnected)
331 }
332 }
333
334 pub fn stream_at_least_once(
336 mut self,
337 size: u32,
338 ) -> (impl Future<Output = Result<(), SendPacketError>>, StreamingPayload) {
339 let (tx, rx) = self.shared.pool.waiters.channel();
340 let stream = StreamingPayload {
341 rx: Cell::new(Some(rx)),
342 shared: self.shared.clone(),
343 inprocess: Cell::new(false),
344 };
345
346 if !self.shared.is_closed() {
347 self.packet.qos = QoS::AtLeastOnce;
348 self.packet.payload_size = size;
349
350 let fut = if let Some(rx) = self.shared.wait_readiness() {
352 Either::Left(Either::Left(async move {
353 if rx.await.is_err() {
354 return Err(SendPacketError::Disconnected);
355 }
356 self.stream_at_least_once_inner(tx).await
357 }))
358 } else {
359 Either::Left(Either::Right(self.stream_at_least_once_inner(tx)))
360 };
361 (fut, stream)
362 } else {
363 (Either::Right(Ready::Err(SendPacketError::Disconnected)), stream)
364 }
365 }
366
367 async fn stream_at_least_once_inner(
368 mut self,
369 tx: pool::Sender<()>,
370 ) -> Result<(), SendPacketError> {
371 let idx = self.shared.set_publish_id(&mut self.packet);
373
374 log::trace!("Publish (QoS1) to {:#?}", self.packet);
376
377 if tx.is_canceled() {
378 Err(SendPacketError::StreamingCancelled)
379 } else {
380 let rx =
381 self.shared.wait_publish_response(idx, AckType::Publish, self.packet, None);
382 let _ = tx.send(());
383
384 rx?.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
385 }
386 }
387}
388
389pub struct PublishReceived {
391 packet_id: Option<NonZeroU16>,
392 shared: Rc<MqttShared>,
393}
394
395impl PublishReceived {
396 pub async fn release(mut self) -> Result<(), SendPacketError> {
398 let rx = self.shared.release_publish(self.packet_id.take().unwrap())?;
399
400 rx.await.map(|_| ()).map_err(|_| SendPacketError::Disconnected)
401 }
402}
403
404impl Drop for PublishReceived {
405 fn drop(&mut self) {
406 if let Some(id) = self.packet_id.take() {
407 let _ = self.shared.release_publish(id);
408 }
409 }
410}
411
412pub struct SubscribeBuilder {
414 id: Option<NonZeroU16>,
415 shared: Rc<MqttShared>,
416 topic_filters: Vec<(ByteString, codec::QoS)>,
417}
418
419impl SubscribeBuilder {
420 #[inline]
421 pub fn packet_id(mut self, id: u16) -> Self {
425 if let Some(id) = NonZeroU16::new(id) {
426 self.id = Some(id);
427 self
428 } else {
429 panic!("id 0 is not allowed");
430 }
431 }
432
433 #[inline]
434 pub fn topic_filter(mut self, filter: ByteString, qos: codec::QoS) -> Self {
436 self.topic_filters.push((filter, qos));
437 self
438 }
439
440 #[inline]
441 pub fn size(&self) -> u32 {
443 codec::encode::get_encoded_subscribe_size(&self.topic_filters) as u32
444 }
445
446 pub async fn send(self) -> Result<Vec<codec::SubscribeReturnCode>, SendPacketError> {
448 if !self.shared.is_closed() {
449 if let Some(rx) = self.shared.wait_readiness() {
451 if rx.await.is_err() {
452 return Err(SendPacketError::Disconnected);
453 }
454 }
455 let idx = self.id.unwrap_or_else(|| self.shared.next_id());
456 let rx = self.shared.wait_response(idx, AckType::Subscribe)?;
457
458 log::trace!(
460 "Sending subscribe packet id: {} filters:{:?}",
461 idx,
462 self.topic_filters
463 );
464
465 match self.shared.encode_packet(codec::Packet::Subscribe {
466 packet_id: idx,
467 topic_filters: self.topic_filters,
468 }) {
469 Ok(_) => {
470 rx.await
472 .map_err(|_| SendPacketError::Disconnected)
473 .map(|pkt| pkt.subscribe())
474 }
475 Err(err) => Err(SendPacketError::Encode(err)),
476 }
477 } else {
478 Err(SendPacketError::Disconnected)
479 }
480 }
481}
482
483pub struct UnsubscribeBuilder {
485 id: Option<NonZeroU16>,
486 shared: Rc<MqttShared>,
487 topic_filters: Vec<ByteString>,
488}
489
490impl UnsubscribeBuilder {
491 #[inline]
492 pub fn packet_id(mut self, id: u16) -> Self {
496 if let Some(id) = NonZeroU16::new(id) {
497 self.id = Some(id);
498 self
499 } else {
500 panic!("id 0 is not allowed");
501 }
502 }
503
504 #[inline]
505 pub fn topic_filter(mut self, filter: ByteString) -> Self {
507 self.topic_filters.push(filter);
508 self
509 }
510
511 #[inline]
512 pub fn size(&self) -> u32 {
514 codec::encode::get_encoded_unsubscribe_size(&self.topic_filters) as u32
515 }
516
517 pub async fn send(self) -> Result<(), SendPacketError> {
519 let shared = self.shared;
520 let filters = self.topic_filters;
521
522 if !shared.is_closed() {
523 if let Some(rx) = shared.wait_readiness() {
525 if rx.await.is_err() {
526 return Err(SendPacketError::Disconnected);
527 }
528 }
529 let idx = self.id.unwrap_or_else(|| shared.next_id());
531 let rx = shared.wait_response(idx, AckType::Unsubscribe)?;
532
533 log::trace!("Sending unsubscribe packet id: {} filters:{:?}", idx, filters);
535
536 match shared.encode_packet(codec::Packet::Unsubscribe {
537 packet_id: idx,
538 topic_filters: filters,
539 }) {
540 Ok(_) => {
541 rx.await.map_err(|_| SendPacketError::Disconnected).map(|_| ())
543 }
544 Err(err) => Err(SendPacketError::Encode(err)),
545 }
546 } else {
547 Err(SendPacketError::Disconnected)
548 }
549 }
550}
551
552pub struct StreamingPayload {
553 shared: Rc<MqttShared>,
554 rx: Cell<Option<pool::Receiver<()>>>,
555 inprocess: Cell<bool>,
556}
557
558impl Drop for StreamingPayload {
559 fn drop(&mut self) {
560 if self.inprocess.get() && self.shared.is_streaming() {
561 self.shared.streaming_dropped();
562 }
563 }
564}
565
566impl StreamingPayload {
567 pub async fn send(&self, chunk: Bytes) -> Result<(), SendPacketError> {
569 if let Some(rx) = self.rx.take() {
570 if rx.await.is_err() {
571 return Err(SendPacketError::StreamingCancelled);
572 }
573 log::trace!("Publish is encoded, ready to process payload");
574 self.inprocess.set(true);
575 }
576
577 if !self.inprocess.get() {
578 Err(EncodeError::UnexpectedPayload.into())
579 } else {
580 log::trace!("Sending payload chunk: {:?}", chunk.len());
581 self.shared.want_payload_stream().await?;
582
583 if !self.shared.encode_publish_payload(chunk)? {
584 self.inprocess.set(false);
585 }
586 Ok(())
587 }
588 }
589}