mqtt5_protocol/
packet_id.rs1use std::sync::atomic::{AtomicU16, Ordering};
4use std::sync::Arc;
5
/// Thread-safe source of MQTT packet identifiers.
///
/// Identifiers are issued sequentially starting at `1`. After `u16::MAX` the
/// sequence wraps straight back to `1`, so `0` is skipped.
///
/// Cloning is cheap: the counter lives behind an [`Arc`], so a clone draws
/// from the *same* sequence as the generator it was cloned from instead of
/// starting a fresh one.
///
/// The generator only keeps a counter; it does not know which identifiers are
/// still in use, so after a full wrap it will hand them out again.
#[derive(Debug, Clone)]
pub struct PacketIdGenerator {
    /// Identifier that the next call to [`PacketIdGenerator::next`] will return.
    next_id: Arc<AtomicU16>,
}

impl PacketIdGenerator {
    /// Creates a generator whose first identifier is `1`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            next_id: Arc::new(AtomicU16::new(1)),
        }
    }

    /// Returns the next packet identifier and advances the shared counter.
    ///
    /// Safe to call concurrently from several threads and from clones: the
    /// read-and-advance is a single atomic read-modify-write, so no two calls
    /// receive the same value until the sequence wraps past `u16::MAX`.
    #[must_use]
    pub fn next(&self) -> u16 {
        // `fetch_update` runs the CAS retry loop for us (using the weak CAS,
        // which is the appropriate primitive inside a loop) and yields the
        // value it replaced, i.e. the identifier being handed out.
        let previous = self
            .next_id
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |id| {
                Some(Self::successor(id))
            });
        // The closure never returns `None`, so `Err` cannot occur; both
        // variants carry the previously stored value regardless.
        match previous {
            Ok(id) | Err(id) => id,
        }
    }

    /// Identifier that follows `id`, wrapping `u16::MAX` back to `1` so that
    /// `0` is never stored by the generator.
    fn successor(id: u16) -> u16 {
        if id == u16::MAX {
            1
        } else {
            id + 1
        }
    }
}

impl Default for PacketIdGenerator {
    /// Equivalent to [`PacketIdGenerator::new`]: the first identifier is `1`.
    fn default() -> Self {
        Self::new()
    }
}
46
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_packet_id_generation() {
        // `generator`, not `gen`: `gen` is a reserved keyword in Rust 2024.
        let generator = PacketIdGenerator::new();

        assert_eq!(generator.next(), 1);
        assert_eq!(generator.next(), 2);
        assert_eq!(generator.next(), 3);
    }

    #[test]
    fn test_packet_id_wraparound() {
        let generator = PacketIdGenerator::new();
        // Reach into the private counter to jump straight to the wrap point.
        generator.next_id.store(u16::MAX, Ordering::SeqCst);

        assert_eq!(generator.next(), u16::MAX);
        // 0 is skipped: the sequence wraps from u16::MAX straight to 1.
        assert_eq!(generator.next(), 1);
        assert_eq!(generator.next(), 2);
    }

    #[test]
    fn test_clones_share_counter() {
        let original = PacketIdGenerator::new();
        let copy = original.clone();

        assert_eq!(original.next(), 1);
        assert_eq!(copy.next(), 2);
        assert_eq!(original.next(), 3);
    }

    #[test]
    fn test_concurrent_access() {
        const THREADS: u16 = 10;
        const IDS_PER_THREAD: u16 = 100;

        // The generator is already cheaply `Clone` (it shares an `Arc`
        // internally), so no extra `Arc` wrapper is needed; cloning it also
        // exercises the shared-counter semantics.
        let generator = PacketIdGenerator::new();

        // Collect the handles eagerly so every thread is spawned before any
        // join happens; otherwise the threads would run one after another.
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let generator = generator.clone();
                thread::spawn(move || {
                    (0..IDS_PER_THREAD)
                        .map(|_| generator.next())
                        .collect::<Vec<u16>>()
                })
            })
            .collect();

        let mut all_ids: Vec<u16> = handles
            .into_iter()
            .flat_map(|handle| handle.join().expect("worker thread panicked"))
            .collect();
        all_ids.sort_unstable();

        // Every id must be handed out exactly once, and together they must
        // form the contiguous range starting at 1 (no duplicates, no gaps,
        // no 0).
        let expected: Vec<u16> = (1..=THREADS * IDS_PER_THREAD).collect();
        assert_eq!(all_ids, expected);
    }
}