mqtt5_protocol/
packet_id.rs

1//! Packet ID generation for MQTT
2
3use std::sync::atomic::{AtomicU16, Ordering};
4use std::sync::Arc;
5
/// Generates unique packet IDs for MQTT messages.
///
/// IDs are handed out sequentially starting at `1` and wrap from `u16::MAX`
/// back to `1`, so `0` (not a valid packet identifier) is never produced.
///
/// Cloning a generator does **not** start a new sequence: the counter lives
/// behind an [`Arc`], so every clone draws from the same shared sequence.
///
/// NOTE(review): the generator does not track which IDs are still in flight;
/// after `u16::MAX` allocations an ID can be reissued while an earlier message
/// using it may still be unacknowledged. Presumably callers guard against
/// this — confirm.
#[derive(Debug, Clone)]
pub struct PacketIdGenerator {
    /// The ID the next call to [`PacketIdGenerator::next`] will return.
    next_id: Arc<AtomicU16>,
}

impl PacketIdGenerator {
    /// Creates a new packet ID generator whose first ID is `1`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            next_id: Arc::new(AtomicU16::new(1)),
        }
    }

    /// Gets the next available packet ID.
    ///
    /// Packet IDs are in the range `1..=65535` (0 is invalid). The counter is
    /// advanced atomically, so concurrent callers — including callers holding
    /// different clones of this generator — each receive a distinct ID until
    /// the 16-bit space wraps around.
    #[must_use]
    pub fn next(&self) -> u16 {
        // `fetch_update` performs the compare-and-swap retry loop and returns
        // the value it replaced, which is the ID being handed out. The closure
        // always returns `Some`, so the `Err` arm is unreachable; matching both
        // arms keeps this method free of a panic path.
        match self
            .next_id
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                Some(Self::successor(current))
            }) {
            Ok(id) | Err(id) => id,
        }
    }

    /// Returns the ID that follows `id`, wrapping `u16::MAX` to `1` so that
    /// `0` is skipped.
    const fn successor(id: u16) -> u16 {
        if id == u16::MAX {
            1
        } else {
            id + 1
        }
    }
}
40
41impl Default for PacketIdGenerator {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // NOTE: locals avoid the name `gen`, which is a reserved keyword in the
    // Rust 2024 edition.

    #[test]
    fn test_packet_id_generation() {
        let generator = PacketIdGenerator::new();

        assert_eq!(generator.next(), 1);
        assert_eq!(generator.next(), 2);
        assert_eq!(generator.next(), 3);
    }

    #[test]
    fn test_packet_id_wraparound() {
        let generator = PacketIdGenerator::new();
        generator.next_id.store(u16::MAX, Ordering::SeqCst);

        assert_eq!(generator.next(), u16::MAX);
        assert_eq!(generator.next(), 1); // Wraps back to 1, skipping 0
        assert_eq!(generator.next(), 2);
    }

    #[test]
    fn test_default_starts_at_one() {
        assert_eq!(PacketIdGenerator::default().next(), 1);
    }

    #[test]
    fn test_clones_share_sequence() {
        let generator = PacketIdGenerator::new();
        let clone = generator.clone();

        assert_eq!(generator.next(), 1);
        assert_eq!(clone.next(), 2);
        assert_eq!(generator.next(), 3);
    }

    #[test]
    fn test_concurrent_access() {
        const THREADS: u16 = 10;
        const IDS_PER_THREAD: u16 = 100;

        // Cloning the generator (rather than wrapping it in another `Arc`)
        // exercises the shared-counter semantics callers rely on.
        let generator = PacketIdGenerator::new();
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let generator = generator.clone();
                thread::spawn(move || {
                    (0..IDS_PER_THREAD)
                        .map(|_| generator.next())
                        .collect::<Vec<_>>()
                })
            })
            .collect();

        let mut all_ids: Vec<u16> = handles
            .into_iter()
            .flat_map(|handle| handle.join().expect("worker thread panicked"))
            .collect();

        // Every ID must be unique, and together the threads must have consumed
        // exactly the first THREADS * IDS_PER_THREAD IDs of the sequence
        // (no 0, no gaps, no duplicates).
        all_ids.sort_unstable();
        let expected: Vec<u16> = (1..=THREADS * IDS_PER_THREAD).collect();
        assert_eq!(all_ids, expected);
    }
}