1#![allow(clippy::type_complexity)]
2use std::{cell::Cell, cell::RefCell, collections::VecDeque, fmt, num, rc::Rc};
3
4use ntex_bytes::{Bytes, BytesMut};
5use ntex_codec::{Decoder, Encoder};
6use ntex_io::IoRef;
7use ntex_util::{HashSet, channel::pool};
8
9use crate::error::{DecodeError, EncodeError, ProtocolError, SendPacketError};
10use crate::types::packet_type;
11use crate::v3::codec::{self, Encoded, Publish};
12
13#[derive(Debug)]
14pub(super) enum Ack {
15 Publish(num::NonZeroU16),
16 Receive(num::NonZeroU16),
17 Complete(num::NonZeroU16),
18 Subscribe { packet_id: num::NonZeroU16, status: Vec<codec::SubscribeReturnCode> },
19 Unsubscribe(num::NonZeroU16),
20}
21
22#[derive(Copy, Clone, Debug)]
23pub(super) enum AckType {
24 Publish,
25 Receive,
26 Complete,
27 Subscribe,
28 Unsubscribe,
29}
30
31pub(super) struct MqttSinkPool {
32 queue: pool::Pool<Ack>,
33 pub(super) waiters: pool::Pool<()>,
34}
35
36impl Default for MqttSinkPool {
37 fn default() -> Self {
38 Self { queue: pool::new(), waiters: pool::new() }
39 }
40}
41
42bitflags::bitflags! {
43 #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
44 struct Flags: u8 {
45 const CLIENT = 0b0000_0001;
46 const WRB_ENABLED = 0b0000_0010; const ON_PUBLISH_ACK = 0b0000_0100; const DISCONNECT = 0b0100_0000; const STOPPED = 0b1000_0000; }
52}
53
54pub struct MqttShared {
55 io: IoRef,
56 cap: Cell<usize>,
57 queues: RefCell<MqttSharedQueues>,
58 inflight_idx: Cell<u16>,
59 flags: Cell<Flags>,
60 encode_error: Cell<Option<EncodeError>>,
61 streaming_waiter: Cell<Option<pool::Sender<()>>>,
62 streaming_remaining: Cell<Option<num::NonZeroU32>>,
63 on_publish_ack: Cell<Option<Box<dyn Fn(num::NonZeroU16, bool)>>>,
64 pub(super) codec: codec::Codec,
65 pub(super) pool: Rc<MqttSinkPool>,
66}
67
68impl fmt::Debug for MqttShared {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 f.debug_struct("MqttShared").finish()
71 }
72}
73
74#[derive(Debug)]
75struct MqttSharedQueues {
76 inflight: VecDeque<(num::NonZeroU16, Option<pool::Sender<Ack>>, AckType)>,
77 inflight_ids: HashSet<num::NonZeroU16>,
78 waiters: VecDeque<pool::Sender<()>>,
79 rx: Option<pool::Receiver<Ack>>,
80}
81
82impl MqttShared {
83 pub(super) fn new(
84 io: IoRef,
85 codec: codec::Codec,
86 client: bool,
87 pool: Rc<MqttSinkPool>,
88 ) -> Self {
89 Self {
90 io,
91 codec,
92 pool,
93 cap: Cell::new(0),
94 flags: Cell::new(if client { Flags::CLIENT } else { Flags::empty() }),
95 queues: RefCell::new(MqttSharedQueues {
96 inflight: VecDeque::with_capacity(8),
97 inflight_ids: HashSet::default(),
98 waiters: VecDeque::new(),
99 rx: None,
100 }),
101 inflight_idx: Cell::new(0),
102 encode_error: Cell::new(None),
103 streaming_waiter: Cell::new(None),
104 streaming_remaining: Cell::new(None),
105 on_publish_ack: Cell::new(None),
106 }
107 }
108
109 pub(super) fn tag(&self) -> &'static str {
110 self.io.tag()
111 }
112
113 pub(super) fn close(&self) {
114 if self.flags.get().contains(Flags::CLIENT) && !self.is_disconnect_sent() {
115 let _ = self.encode_packet(codec::Packet::Disconnect);
116 }
117 self.io.close();
118 self.clear_queues();
119 }
120
121 pub(super) fn force_close(&self) {
122 self.io.force_close();
123 self.clear_queues();
124 }
125
126 pub(super) fn streaming_dropped(&self) {
127 self.force_close();
128 self.encode_error.set(Some(EncodeError::PublishIncomplete));
129 }
130
131 pub(super) fn is_streaming(&self) -> bool {
132 self.streaming_remaining.get().is_some()
133 }
134
135 pub(super) fn is_closed(&self) -> bool {
136 self.io.is_closed()
137 }
138
139 pub(super) fn is_ready(&self) -> bool {
140 self.credit() > 0 && !self.flags.get().contains(Flags::WRB_ENABLED)
141 }
142
143 pub(super) fn is_disconnect_sent(&self) -> bool {
144 let mut flags = self.flags.get();
145 let sent = flags.contains(Flags::DISCONNECT);
146 if !sent {
147 flags.insert(Flags::DISCONNECT);
148 self.flags.set(flags);
149 }
150 sent
151 }
152
153 pub(super) fn is_dispatcher_stopped(&self) -> bool {
154 let mut flags = self.flags.get();
155 let stopped = flags.contains(Flags::STOPPED);
156 if !stopped {
157 flags.insert(Flags::STOPPED);
158 self.flags.set(flags);
159 }
160 stopped
161 }
162
163 pub(super) fn credit(&self) -> usize {
164 self.cap.get().saturating_sub(self.queues.borrow().inflight.len())
165 }
166
167 pub(super) fn next_id(&self) -> num::NonZeroU16 {
168 let idx = self.inflight_idx.get() + 1;
169 let idx = if idx == u16::MAX {
170 self.inflight_idx.set(0);
171 u16::MAX
172 } else {
173 self.inflight_idx.set(idx);
174 idx
175 };
176 num::NonZeroU16::new(idx).unwrap()
177 }
178
179 pub(super) fn set_publish_id(&self, pkt: &mut Publish) -> num::NonZeroU16 {
181 if let Some(idx) = pkt.packet_id {
182 idx
183 } else {
184 let idx = self.next_id();
185 pkt.packet_id = Some(idx);
186 idx
187 }
188 }
189
190 pub(super) fn set_cap(&self, cap: usize) {
191 let mut queues = self.queues.borrow_mut();
192
193 'outer: for _ in 0..cap {
195 while let Some(tx) = queues.waiters.pop_front() {
196 if tx.send(()).is_ok() {
197 continue 'outer;
198 }
199 }
200 break;
201 }
202 self.cap.set(cap);
203 }
204
205 pub(super) fn set_publish_ack(&self, f: Box<dyn Fn(num::NonZeroU16, bool)>) {
206 let mut flags = self.flags.get();
207 flags.insert(Flags::ON_PUBLISH_ACK);
208 self.flags.set(flags);
209 self.on_publish_ack.set(Some(f));
210 }
211
212 pub(super) fn encode_packet(&self, pkt: codec::Packet) -> Result<(), EncodeError> {
213 self.check_streaming()?;
214 self.io.encode(pkt.into(), &self.codec)
215 }
216
217 pub(super) fn encode_publish(
218 &self,
219 pkt: Publish,
220 payload: Option<Bytes>,
221 ) -> Result<(), EncodeError> {
222 self.check_streaming()?;
223 self.enable_streaming(&pkt, payload.as_ref());
224 self.io.encode(Encoded::Publish(pkt, payload), &self.codec)
225 }
226
227 pub(super) fn encode_publish_payload(&self, payload: Bytes) -> Result<bool, EncodeError> {
228 if let Some(remaining) = self.streaming_remaining.get() {
229 let len = payload.len() as u32;
230 if len > remaining.get() {
231 self.force_close();
232 Err(EncodeError::OverPublishSize)
233 } else {
234 self.io.encode(Encoded::PayloadChunk(payload), &self.codec)?;
235 self.streaming_remaining.set(num::NonZeroU32::new(remaining.get() - len));
236 Ok(self.streaming_remaining.get().is_some())
237 }
238 } else {
239 Err(EncodeError::UnexpectedPayload)
240 }
241 }
242
243 fn clear_queues(&self) {
244 let mut queues = self.queues.borrow_mut();
245 queues.waiters.clear();
246
247 if let Some(cb) = self.on_publish_ack.take() {
248 for (idx, tx, _) in queues.inflight.drain(..) {
249 if tx.is_none() {
250 (*cb)(idx, true);
251 }
252 }
253 } else {
254 queues.inflight.clear();
255 }
256 }
257
258 pub(super) fn enable_wr_backpressure(&self) {
259 let mut flags = self.flags.get();
260 flags.insert(Flags::WRB_ENABLED);
261 self.flags.set(flags);
262 }
263
264 pub(super) fn disable_wr_backpressure(&self) {
265 let mut flags = self.flags.get();
266 flags.remove(Flags::WRB_ENABLED);
267 self.flags.set(flags);
268
269 if let Some(tx) = self.streaming_waiter.take()
271 && tx.send(()).is_ok()
272 {
273 return;
274 }
275
276 let mut queues = self.queues.borrow_mut();
278 if queues.inflight.len() < self.cap.get() {
279 let mut num = self.cap.get() - queues.inflight.len();
280 while num > 0 {
281 if let Some(tx) = queues.waiters.pop_front() {
282 if tx.send(()).is_ok() {
283 num -= 1;
284 }
285 } else {
286 break;
287 }
288 }
289 }
290 }
291
292 pub(super) async fn want_payload_stream(&self) -> Result<(), SendPacketError> {
293 if self.is_closed() {
294 Err(SendPacketError::Disconnected)
295 } else if self.flags.get().contains(Flags::WRB_ENABLED) {
296 let (tx, rx) = self.pool.waiters.channel();
297 self.streaming_waiter.set(Some(tx));
298 if rx.await.is_ok() {
299 Ok(())
300 } else {
301 Err(SendPacketError::Disconnected)
302 }
303 } else {
304 Ok(())
305 }
306 }
307
308 fn check_streaming(&self) -> Result<(), EncodeError> {
309 if self.streaming_remaining.get().is_some() {
310 Err(EncodeError::ExpectPayload)
311 } else {
312 Ok(())
313 }
314 }
315
316 fn enable_streaming(&self, pkt: &Publish, payload: Option<&Bytes>) {
317 let len = payload.map_or(0, Bytes::len);
318 self.streaming_remaining.set(num::NonZeroU32::new(pkt.payload_size - len as u32));
319 }
320
321 pub(super) fn pkt_ack(&self, ack: Ack) -> Result<(), ProtocolError> {
322 self.pkt_ack_inner(ack).inspect_err(|_| {
323 self.close();
324 })
325 }
326
327 fn pkt_ack_inner(&self, pkt: Ack) -> Result<(), ProtocolError> {
328 let mut queues = self.queues.borrow_mut();
329
330 if let Some((idx, tx, tp)) = queues.inflight.pop_front() {
332 if idx != pkt.packet_id() {
333 log::trace!(
334 "MQTT protocol error: packet id order does not match; expected {}, got: {}",
335 idx,
336 pkt.packet_id()
337 );
338 Err(ProtocolError::packet_id_mismatch())
339 } else if matches!(pkt, Ack::Receive(_)) {
340 log::trace!("Ack packet with id: {}", pkt.packet_id());
342
343 if let Some(tx) = tx {
344 let _ = tx.send(pkt);
345 }
346 let (tx, rx) = self.pool.queue.channel();
347 queues.rx = Some(rx);
348 queues.inflight.push_back((idx, Some(tx), AckType::Complete));
349 Ok(())
350 } else if matches!(pkt, Ack::Complete(_)) {
351 log::trace!("Ack packet with id: {}", pkt.packet_id());
353 queues.inflight_ids.remove(&pkt.packet_id());
354 queues.rx.take();
355
356 if let Some(tx) = tx {
357 let _ = tx.send(pkt);
358 }
359
360 while let Some(tx) = queues.waiters.pop_front() {
362 if tx.send(()).is_ok() {
363 break;
364 }
365 }
366 Ok(())
367 } else {
368 log::trace!("Ack packet with id: {}", pkt.packet_id());
370 queues.inflight_ids.remove(&pkt.packet_id());
371
372 if pkt.is_match(tp) {
373 if let Some(tx) = tx {
374 let _ = tx.send(pkt);
375 } else {
376 let cb = self.on_publish_ack.take().unwrap();
377 (*cb)(pkt.packet_id(), false);
378 self.on_publish_ack.set(Some(cb));
379 }
380
381 while let Some(tx) = queues.waiters.pop_front() {
383 if tx.send(()).is_ok() {
384 break;
385 }
386 }
387 Ok(())
388 } else {
389 log::trace!("MQTT protocol error, unexpected packet");
390 Err(ProtocolError::unexpected_packet(pkt.packet_type(), tp.expected_str()))
391 }
392 }
393 } else {
394 log::trace!("Unexpected PUBACK packet: {:?}", pkt.packet_id());
395 Err(ProtocolError::generic_violation(
396 "Received PUBACK packet while there are no unacknowledged PUBLISH packets",
397 ))
398 }
399 }
400
401 pub(super) fn wait_response(
403 &self,
404 id: num::NonZeroU16,
405 ack: AckType,
406 ) -> Result<pool::Receiver<Ack>, SendPacketError> {
407 let mut queues = self.queues.borrow_mut();
408 if queues.inflight_ids.contains(&id) {
409 Err(SendPacketError::PacketIdInUse(id))
410 } else {
411 let (tx, rx) = self.pool.queue.channel();
412 queues.inflight.push_back((id, Some(tx), ack));
413 queues.inflight_ids.insert(id);
414 Ok(rx)
415 }
416 }
417
418 pub(super) fn wait_publish_response(
420 &self,
421 id: num::NonZeroU16,
422 ack: AckType,
423 pkt: Publish,
424 payload: Option<Bytes>,
425 ) -> Result<pool::Receiver<Ack>, SendPacketError> {
426 self.check_streaming()?;
427 self.enable_streaming(&pkt, payload.as_ref());
428
429 let mut queues = self.queues.borrow_mut();
430 if queues.inflight_ids.contains(&id) {
431 Err(SendPacketError::PacketIdInUse(id))
432 } else {
433 match self.io.encode(Encoded::Publish(pkt, payload), &self.codec) {
434 Ok(()) => {
435 let (tx, rx) = self.pool.queue.channel();
436 queues.inflight.push_back((id, Some(tx), ack));
437 queues.inflight_ids.insert(id);
438 Ok(rx)
439 }
440 Err(e) => Err(SendPacketError::Encode(e)),
441 }
442 }
443 }
444
445 pub(super) fn wait_publish_response_no_block(
447 &self,
448 id: num::NonZeroU16,
449 ack: AckType,
450 pkt: Publish,
451 payload: Option<Bytes>,
452 ) -> Result<(), SendPacketError> {
453 self.check_streaming()?;
454 self.enable_streaming(&pkt, payload.as_ref());
455
456 let mut queues = self.queues.borrow_mut();
457 if queues.inflight_ids.contains(&id) {
458 Err(SendPacketError::PacketIdInUse(id))
459 } else {
460 match self.io.encode(Encoded::Publish(pkt, payload), &self.codec) {
461 Ok(()) => {
462 assert!(
463 self.flags.get().contains(Flags::ON_PUBLISH_ACK),
464 "Publish ack callback is not set"
465 );
466 queues.inflight.push_back((id, None, ack));
467 queues.inflight_ids.insert(id);
468 Ok(())
469 }
470 Err(e) => Err(SendPacketError::Encode(e)),
471 }
472 }
473 }
474
475 pub(super) fn wait_readiness(&self) -> Option<pool::Receiver<()>> {
476 let mut queues = self.queues.borrow_mut();
477
478 if queues.inflight.len() >= self.cap.get()
479 || self.flags.get().contains(Flags::WRB_ENABLED)
480 {
481 let (tx, rx) = self.pool.waiters.channel();
482 queues.waiters.push_back(tx);
483 Some(rx)
484 } else {
485 None
486 }
487 }
488
489 pub(super) fn release_publish(
491 &self,
492 id: num::NonZeroU16,
493 ) -> Result<pool::Receiver<Ack>, SendPacketError> {
494 let Some(rx) = self.queues.borrow_mut().rx.take() else {
495 return Err(SendPacketError::UnexpectedRelease);
496 };
497 match self.io.encode(
498 Encoded::Packet(codec::Packet::PublishRelease { packet_id: id }),
499 &self.codec,
500 ) {
501 Ok(()) => Ok(rx),
502 Err(e) => Err(SendPacketError::Encode(e)),
503 }
504 }
505}
506
507impl Encoder for MqttShared {
508 type Item = Encoded;
509 type Error = EncodeError;
510
511 #[inline]
512 fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
513 self.codec.encode(item, dst)
514 }
515}
516
517impl Decoder for MqttShared {
518 type Item = codec::Decoded;
519 type Error = DecodeError;
520
521 #[inline]
522 fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
523 self.codec.decode(src)
524 }
525}
526
527impl Ack {
528 pub(super) fn packet_type(&self) -> u8 {
529 match self {
530 Ack::Publish(_) => packet_type::PUBACK,
531 Ack::Receive(_) => packet_type::PUBREC,
532 Ack::Complete(_) => packet_type::PUBCOMP,
533 Ack::Subscribe { .. } => packet_type::SUBACK,
534 Ack::Unsubscribe(_) => packet_type::UNSUBACK,
535 }
536 }
537
538 pub(super) fn packet_id(&self) -> num::NonZeroU16 {
539 match self {
540 Ack::Subscribe { packet_id, .. } => *packet_id,
541 Ack::Publish(id) | Ack::Receive(id) | Ack::Complete(id) | Ack::Unsubscribe(id) => {
542 *id
543 }
544 }
545 }
546
547 pub(super) fn subscribe(self) -> Vec<codec::SubscribeReturnCode> {
548 if let Ack::Subscribe { status, .. } = self {
549 status
550 } else {
551 panic!()
552 }
553 }
554
555 pub(super) fn is_match(&self, tp: AckType) -> bool {
556 match (self, tp) {
557 (Ack::Publish(_), AckType::Publish)
558 | (Ack::Receive(_), AckType::Receive)
559 | (Ack::Subscribe { .. }, AckType::Subscribe)
560 | (Ack::Unsubscribe(_), AckType::Unsubscribe) => true,
561 (_, _) => false,
562 }
563 }
564}
565
566impl AckType {
567 pub(super) fn expected_str(self) -> &'static str {
568 match self {
569 AckType::Publish => "Expected PUBACK packet",
570 AckType::Receive => "Expected PUBREC packet",
571 AckType::Complete => "Expected PUBCOMP packet",
572 AckType::Subscribe => "Expected SUBACK packet",
573 AckType::Unsubscribe => "Expected UNSUBACK packet",
574 }
575 }
576}