// mqtt5-protocol 0.12.0
//
// MQTT v5.0 protocol implementation - packets, encoding, and validation
//! Packet ID generation for MQTT

use portable_atomic::{AtomicU16, Ordering};

#[cfg(not(feature = "std"))]
use portable_atomic_util::Arc;
#[cfg(feature = "std")]
use std::sync::Arc;

/// Hands out packet identifiers for MQTT messages.
///
/// The counter is held behind an [`Arc`], so cloning a generator does not
/// start a new sequence: every clone draws from the same shared counter.
#[derive(Clone, Debug)]
pub struct PacketIdGenerator {
    /// Shared counter holding the value the next allocation starts from.
    next_id: Arc<AtomicU16>,
}

impl PacketIdGenerator {
    /// Creates a new packet ID generator whose first ID is `1`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            next_id: Arc::new(AtomicU16::new(1)),
        }
    }

    /// Gets the next available packet ID.
    ///
    /// Packet IDs are in the range `1..=65535`; `0` is never returned because
    /// it is not a valid MQTT packet identifier. After `65535` the sequence
    /// wraps back to `1`.
    ///
    /// Each draw is a single atomic `fetch_add`, so there is no
    /// compare-and-swap retry loop that contending callers could keep failing
    /// (a concern on embedded targets without preemption). The counter wraps
    /// through `0` on overflow, and whichever caller draws that `0` draws
    /// again. A caller's second draw is non-zero unless roughly 65535 other
    /// IDs were allocated in between, so in practice the loop runs at most
    /// twice.
    #[must_use]
    pub fn next(&self) -> u16 {
        loop {
            // `fetch_add` wraps on overflow (u16::MAX -> 0), so each u16 value,
            // including the reserved 0, is handed out exactly once per cycle.
            let id = self.next_id.fetch_add(1, Ordering::SeqCst);
            if id != 0 {
                return id;
            }
        }
    }
}

impl Default for PacketIdGenerator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_packet_id_generation() {
        let generator = PacketIdGenerator::new();

        // Array elements are evaluated left to right, so the draws happen in order.
        let ids = [generator.next(), generator.next(), generator.next()];
        assert_eq!(ids, [1, 2, 3]);
    }

    #[test]
    fn test_packet_id_wraparound() {
        let generator = PacketIdGenerator::new();
        generator.next_id.store(u16::MAX, Ordering::SeqCst);

        let ids = [generator.next(), generator.next(), generator.next()];
        assert_eq!(ids, [u16::MAX, 1, 2]);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_concurrent_access() {
        use std::collections::HashSet;
        use std::sync::Arc;
        use std::thread;

        const THREADS: usize = 10;
        const IDS_PER_THREAD: usize = 100;

        let generator = Arc::new(PacketIdGenerator::new());

        let workers: Vec<_> = (0..THREADS)
            .map(|_| {
                let shared = Arc::clone(&generator);
                thread::spawn(move || {
                    (0..IDS_PER_THREAD)
                        .map(|_| shared.next())
                        .collect::<Vec<u16>>()
                })
            })
            .collect();

        // Every ID drawn across all threads must be distinct.
        let unique: HashSet<u16> = workers
            .into_iter()
            .flat_map(|worker| worker.join().unwrap())
            .collect();
        assert_eq!(unique.len(), THREADS * IDS_PER_THREAD);
    }
}