// rmqtt 0.20.0
//
// MQTT Server for v3.1, v3.1.1 and v5.0 protocols
// Documentation
//! MQTT Inflight Message Management System
//!
//! Provides reliable message delivery tracking for QoS 1 and 2 with:
//! - Outbound message retransmission
//! - Inbound message deduplication
//! - Configurable window sizes
//! - Automatic expiry handling
//!
//! ## Core Functionality
//! 1. **Outbound Tracking (`OutInflight`)**:
//!    - Manages unacknowledged publishes (QoS 1/2)
//!    - Handles retransmission timeouts
//!    - Maintains packet ID sequencing
//!    - Supports configurable capacity limits
//!
//! 2. **Inbound Tracking (`InInflight`)**:
//!    - Detects duplicate messages (QoS 2)
//!    - Enforces maximum window size
//!    - Provides packet ID lifecycle management
//!
//! ## Key Features
//! - Dual interval timing (retry/expiry)
//! - Event hooks for push/pop operations
//! - Atomic packet ID generation
//! - Time-based message expiry
//! - Statistics integration
//!
//! ## Implementation Details
//! - DequeMap for O(1) front access
//! - BTreeSet for efficient deduplication
//! - Atomic packet-ID counter (shared between clones via `Arc`); the queues themselves are not synchronized
//! - Zero-cost status tracking
//!
//! Configuration Parameters:
//! - `cap`: Maximum concurrent outbound messages
//! - `retry_interval`: Retransmission delay (ms)
//! - `expiry_interval`: Message expiry timeout (ms)
//! - `max_inflight`: Maximum inbound window size
//!
//! Usage Patterns:
//! 1. Assign packet IDs via `next_id()`
//! 2. Track outbound messages with `push_back()`
//! 3. Process acknowledgements with `remove()`
//! 4. Handle timeouts via `pop_front_timeout()`
//! 5. Manage inbound flow with `add()`/`remove()`
//!
//! Note: Implements MQTT spec requirements for:
//! - Packet ID uniqueness (2.2.1)
//! - QoS flow control (4.6)
//! - Message expiry (3.3.2.3.2)

use std::collections::BTreeSet;
use std::num::NonZeroU16;
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::Arc;
use std::time::Duration;

use anyhow::anyhow;
use itertools::Itertools;
use rust_box::dequemap::DequeBTreeMap as DequeMap;
use serde::{Deserialize, Serialize};

use crate::context::ServerContext;
use crate::net::MqttError;
use crate::queue::OnEventFn;
use crate::types::{From, PacketId, Publish, TimestampMillis};
use crate::types::{QoS, Reason};
use crate::utils::timestamp_millis;
use crate::Result;

/// Outbound inflight storage: a deque-ordered map from packet id to the tracked message.
type OutQueues = DequeMap<PacketId, OutInflightMessage>;

/// Delivery state of an outbound message tracked by [`OutInflight`].
///
/// NOTE(review): the per-variant meanings below are inferred from the names and the MQTT
/// QoS 1/2 handshakes; confirm against the code that calls `update_status`.
#[derive(Debug, Eq, PartialEq, Clone, Copy, Serialize, Deserialize)]
pub enum MomentStatus {
    /// Presumably: PUBLISH sent, waiting for PUBACK (QoS 1) — confirm.
    UnAck,
    /// Presumably: QoS 2 PUBLISH sent, waiting for PUBREC — confirm.
    UnReceived,
    /// Presumably: PUBREL sent, waiting for PUBCOMP (QoS 2) — confirm.
    UnComplete,
}

/// An outbound message awaiting acknowledgement, as stored in [`OutInflight`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutInflightMessage {
    /// The PUBLISH being delivered; its `packet_id` is the key used by `OutInflight::push_back`.
    pub publish: Publish,
    /// Origin of the message.
    pub from: From,
    /// Current acknowledgement state.
    pub status: MomentStatus,
    /// Millisecond timestamp of creation or of the last status change; timeouts are measured
    /// from this instant.
    pub update_time: TimestampMillis,
}

impl OutInflightMessage {
    /// Creates a new inflight entry with the given status, stamped with the current time.
    #[inline]
    pub fn new(status: MomentStatus, from: From, publish: Publish) -> Self {
        Self { publish, from, status, update_time: timestamp_millis() }
    }

    /// Sets a new status and refreshes `update_time` to now, restarting the timeout clock.
    #[inline]
    fn update_status(&mut self, status: MomentStatus) {
        self.update_time = timestamp_millis();
        self.status = status;
    }

    /// Returns `true` when at least `interval_millis` milliseconds have elapsed since
    /// `update_time`.
    ///
    /// An `interval_millis` that is not positive disables the timeout, so this returns `false`.
    #[inline]
    pub fn timeout(&self, interval_millis: TimestampMillis) -> bool {
        // Sample the clock once so the logged elapsed value is exactly the one compared below.
        let elapsed = timestamp_millis() - self.update_time;
        log::debug!("interval_millis:{} {}", interval_millis, elapsed);
        interval_millis > 0 && elapsed >= interval_millis
    }
}

/// Tracks outbound QoS 1/2 messages that were sent but are not yet fully acknowledged,
/// keyed by packet id in a deque-ordered map (`push_back` adds, timeouts are checked at the front).
///
/// Note: cloning shares the packet-id counter and the event hooks (both behind `Arc`), while
/// the queue itself is cloned.
#[derive(Clone)]
pub struct OutInflight {
    /// Soft capacity consulted only by `has_credit`; `push_back` does not enforce it.
    cap: usize,
    /// Timeout in milliseconds: the smaller non-zero of the retry/expiry intervals; `0` disables it.
    interval: TimestampMillis,
    /// Next packet-id candidate; starts at 1 and wraps.
    next: Arc<AtomicU16>,
    /// Inflight messages keyed by packet id.
    queues: OutQueues,
    /// Hook invoked when a message is added by `push_back`.
    on_push_fn: Option<Arc<dyn OnEventFn>>,
    /// Hook invoked when a message leaves the queue (pop, remove, or replacement on `push_back`).
    on_pop_fn: Option<Arc<dyn OnEventFn>>,
}

impl OutInflight {
    /// Creates an empty outbound tracker.
    ///
    /// * `cap` - soft capacity consulted by [`has_credit`](Self::has_credit).
    /// * `retry_interval`, `expiry_interval` - intervals in milliseconds; `0` disables one of
    ///   them. The effective timeout is the smaller non-zero value, or `0` (no timeouts) when
    ///   both are `0`.
    ///
    /// Packet-id allocation starts at `1`.
    #[inline]
    pub fn new(cap: usize, retry_interval: TimestampMillis, expiry_interval: TimestampMillis) -> Self {
        let interval = Self::interval(retry_interval, expiry_interval);
        Self {
            cap,
            interval,
            next: Arc::new(AtomicU16::new(1)),
            queues: OutQueues::default(),
            on_push_fn: None,
            on_pop_fn: None,
        }
    }

    /// Registers a hook invoked each time a message is added by [`push_back`](Self::push_back).
    #[inline]
    pub fn on_push<F>(mut self, f: F) -> Self
    where
        F: OnEventFn,
    {
        self.on_push_fn = Some(Arc::new(f));
        self
    }

    /// Registers a hook invoked each time a message leaves the queue: via `pop_front`,
    /// `remove`, or replacement by a `push_back` with an already-tracked packet id.
    #[inline]
    pub fn on_pop<F>(mut self, f: F) -> Self
    where
        F: OnEventFn,
    {
        self.on_pop_fn = Some(Arc::new(f));
        self
    }

    /// Combines the retry and expiry intervals: the smaller non-zero value, or `0` if both are `0`.
    #[inline]
    fn interval(retry_interval: TimestampMillis, expiry_interval: TimestampMillis) -> TimestampMillis {
        match (retry_interval, expiry_interval) {
            (0, 0) => 0,
            (0, expiry_interval) => expiry_interval,
            (retry_interval, 0) => retry_interval,
            (retry_interval, expiry_interval) => retry_interval.min(expiry_interval),
        }
    }

    /// Returns how long to wait until the front message reaches the timeout interval.
    ///
    /// Returns `None` when timeouts are disabled or the queue is empty. The result is clamped
    /// to at least 1 ms, so an already-overdue message still yields a positive wait.
    #[inline]
    pub fn get_timeout(&self) -> Option<Duration> {
        if self.interval == 0 {
            return None;
        }
        let (_, m) = self.queues.front()?;
        let t = (self.interval - (timestamp_millis() - m.update_time)).max(1);
        log::debug!("get timeout t: {t}");
        Some(Duration::from_millis(t as u64))
    }

    /// Returns `true` if timeouts are enabled and the front message has timed out.
    #[inline]
    fn front_timeout(&self) -> bool {
        if self.interval == 0 {
            return false;
        }
        if let Some((_, m)) = self.queues.front() {
            if m.timeout(self.interval) {
                return true;
            }
        }
        false
    }

    /// Looks up a tracked message by packet id.
    #[inline]
    pub fn get(&self, packet_id: PacketId) -> Option<&OutInflightMessage> {
        self.queues.get(&packet_id)
    }

    /// Returns the front entry without removing it.
    #[inline]
    pub fn front(&self) -> Option<(&PacketId, &OutInflightMessage)> {
        self.queues.front()
    }

    /// Removes and returns the front message, firing the pop hook if one was removed.
    #[inline]
    pub fn pop_front(&mut self) -> Option<OutInflightMessage> {
        if let Some(msg) = self.queues.pop_front().map(|(_, m)| m) {
            if let Some(f) = self.on_pop_fn.as_ref() {
                f();
            }
            Some(msg)
        } else {
            None
        }
    }

    /// Removes and returns the front message only if it has timed out.
    #[inline]
    pub fn pop_front_timeout(&mut self) -> Option<OutInflightMessage> {
        if self.front_timeout() {
            self.pop_front()
        } else {
            None
        }
    }

    /// Tracks `m` under its publish packet id.
    ///
    /// Fires the push hook before inserting. If an entry with the same packet id was already
    /// tracked, it is replaced, the pop hook fires (keeping hook calls balanced with `len`),
    /// and the replaced entry's packet id is returned. A message without a packet id is
    /// logged and ignored (`None`).
    #[inline]
    pub fn push_back(&mut self, m: OutInflightMessage) -> Option<NonZeroU16> {
        if let Some(packet_id) = m.publish.packet_id {
            if let Some(f) = self.on_push_fn.as_ref() {
                f();
            }
            let old = self.queues.insert(packet_id.get(), m);
            if old.is_some() {
                if let Some(f) = self.on_pop_fn.as_ref() {
                    f();
                }
            }
            old.and_then(|old| old.publish.packet_id)
        } else {
            log::warn!("packet_id is None, inflight message: {m:?}");
            None
        }
    }

    /// Removes the message with `packet_id`, firing the pop hook if it was tracked.
    #[inline]
    pub fn remove(&mut self, packet_id: &PacketId) -> Option<OutInflightMessage> {
        if let Some(msg) = self.queues.remove(packet_id) {
            if let Some(f) = self.on_pop_fn.as_ref() {
                f();
            }
            Some(msg)
        } else {
            None
        }
    }

    /// Updates the status (and refreshes the timestamp) of a tracked message; no-op if absent.
    #[inline]
    pub fn update_status(&mut self, packet_id: &PacketId, s: MomentStatus) {
        if let Some(m) = self.queues.get_mut(packet_id) {
            m.update_status(s);
        }
    }

    /// Number of tracked messages.
    #[inline]
    pub fn len(&self) -> usize {
        self.queues.len()
    }

    /// Returns `true` if no messages are tracked.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.queues.is_empty()
    }

    /// Returns `true` if a message with `packet_id` is tracked.
    #[inline]
    pub fn exist(&self, packet_id: &PacketId) -> bool {
        self.queues.contains_key(packet_id)
    }

    /// Returns `true` while fewer than `cap` messages are tracked.
    #[inline]
    pub fn has_credit(&self) -> bool {
        // Compare rather than subtract: `push_back` does not enforce `cap`, so `len` can exceed
        // it, and `cap - len` on usize would underflow (panic in debug, wrap to "credit" in release).
        self.queues.len() < self.cap
    }

    /// Allocates a packet id that is not currently tracked.
    ///
    /// Ids come from a wrapping counter, skipping `0` and ids present in the queue.
    ///
    /// # Errors
    /// Returns an error if no free id was found within one full cycle of the counter.
    #[inline]
    pub fn next_id(&self) -> Result<PacketId> {
        // A full cycle is 65536 steps (including the skipped 0), which visits every non-zero id
        // once, provided no clone advances the shared counter concurrently.
        for _ in 0..=u16::MAX {
            let packet_id = self.next.fetch_add(1, Ordering::SeqCst);
            if packet_id == 0 {
                continue;
            }
            if !self.queues.contains_key(&packet_id) {
                return Ok(packet_id);
            }
        }
        Err(anyhow!("no packet_id available, should unreachable!()"))
    }

    /// Drains every tracked message (front to back), firing the pop hook for each.
    #[inline]
    pub fn to_inflight_messages(&mut self) -> Vec<OutInflightMessage> {
        let mut inflight_messages = Vec::with_capacity(self.queues.len());
        while let Some(msg) = self.pop_front() {
            //@TODO ..., check message expired
            inflight_messages.push(msg);
        }
        inflight_messages
    }

    /// Returns clones of all tracked messages without modifying the queue.
    #[inline]
    pub fn clone_inflight_messages(&self) -> Vec<OutInflightMessage> {
        self.queues.iter().map(|(_, msg)| msg.clone()).collect_vec()
    }
}

//@TODO Size limit, i.e. the concurrency limit for messages received on the same connection.
// NOTE(review): `max_inflight` below is already checked in `add`; this TODO may be stale — confirm.
/// Tracks packet ids of inbound messages between `add` and `remove`, used to detect
/// duplicate packet ids and to bound the inbound window.
pub struct InInflight {
    /// Packet ids currently tracked.
    cached: BTreeSet<NonZeroU16>,
    /// Only read when the `stats` feature is enabled, hence the dead-code allowance.
    #[allow(dead_code)]
    scx: ServerContext,
    /// Maximum number of ids tracked at once; `add` rejects new ids beyond it.
    max_inflight: u16,
}

/// On drop, subtracts still-tracked ids from the `in_inflights` statistic (when `stats` is enabled).
impl Drop for InInflight {
    fn drop(&mut self) {
        // Ids still cached here were counted up by `add` but never counted down by `remove`.
        #[cfg(feature = "stats")]
        self.scx.stats.in_inflights.decs(self.cached.len() as isize);
    }
}

impl InInflight {
    /// Creates an empty inbound tracker admitting at most `max_inflight` packet ids at once.
    pub(crate) fn new(scx: ServerContext, max_inflight: u16) -> Self {
        Self { cached: BTreeSet::default(), scx, max_inflight }
    }

    /// Records an inbound packet id.
    ///
    /// Returns `Ok(true)` if `pid` was newly added. Returns `Ok(false)` if `pid` is already
    /// tracked and `qos` is not `ExactlyOnce`.
    ///
    /// # Errors
    /// * `MqttError::PacketIdInUse` (converted into `Reason`) - `pid` is already tracked and
    ///   `qos` is `ExactlyOnce`.
    /// * `Reason::InflightWindowFull` - `pid` is new and `max_inflight` ids are already tracked.
    #[inline]
    pub(crate) fn add(&mut self, pid: NonZeroU16, qos: QoS) -> std::result::Result<bool, Reason> {
        // Check for duplicates first: an already-tracked id does not grow the window, so it
        // must not be rejected as "window full".
        if self.cached.contains(&pid) {
            return if matches!(qos, QoS::ExactlyOnce) {
                Err(MqttError::PacketIdInUse(pid).into())
            } else {
                Ok(false)
            };
        }
        if self.cached.len() >= self.max_inflight as usize {
            return Err(Reason::InflightWindowFull);
        }
        self.cached.insert(pid);
        #[cfg(feature = "stats")]
        self.scx.stats.in_inflights.inc();
        Ok(true)
    }

    /// Stops tracking `pid`; returns `true` if it was tracked.
    #[inline]
    pub(crate) fn remove(&mut self, pid: &NonZeroU16) -> bool {
        let removed = self.cached.remove(pid);
        #[cfg(feature = "stats")]
        {
            if removed {
                self.scx.stats.in_inflights.dec();
            }
        }
        removed
    }
}