Skip to main content

rns_core/transport/
dedup.rs

1use alloc::collections::BTreeSet;
2
/// Double-buffered packet hash deduplication.
///
/// Hashes are kept in two [`BTreeSet`]s, `current` and `previous`. New hashes
/// always go into `current`. Rotation is *not* automatic: the caller invokes
/// [`PacketHashlist::maybe_rotate`], which, once `current` holds more than
/// `max_size / 2` entries, moves `current` into `previous` (dropping the old
/// `previous`) and starts a fresh, empty `current`. A hash is therefore
/// forgotten on the second rotation after it was last added.
#[derive(Debug)]
pub struct PacketHashlist {
    /// Hashes added since the most recent rotation.
    current: BTreeSet<[u8; 32]>,
    /// Hashes that made up `current` before the most recent rotation.
    previous: BTreeSet<[u8; 32]>,
    /// `maybe_rotate` rotates once `current.len() > max_size / 2`.
    max_size: usize,
}

impl PacketHashlist {
    /// Creates an empty hashlist.
    ///
    /// `max_size` controls rotation: [`PacketHashlist::maybe_rotate`] rotates
    /// once the current generation exceeds `max_size / 2` entries.
    pub fn new(max_size: usize) -> Self {
        PacketHashlist {
            current: BTreeSet::new(),
            previous: BTreeSet::new(),
            max_size,
        }
    }

    /// Returns `true` if `hash` is present in the current or previous generation.
    #[must_use]
    pub fn is_duplicate(&self, hash: &[u8; 32]) -> bool {
        self.current.contains(hash) || self.previous.contains(hash)
    }

    /// Inserts `hash` into the current generation.
    ///
    /// This never rotates; call [`PacketHashlist::maybe_rotate`] afterwards to
    /// keep memory bounded. Adding a hash already in the current generation
    /// has no effect. A hash present only in the previous generation is copied
    /// into the current one, so it survives the next rotation.
    pub fn add(&mut self, hash: [u8; 32]) {
        self.current.insert(hash);
    }

    /// Rotates generations if the current set holds more than `max_size / 2`
    /// hashes.
    ///
    /// On rotation, the current set becomes the previous set, the old previous
    /// set is dropped, and a new empty current set is started.
    ///
    /// Returns `true` if a rotation occurred.
    pub fn maybe_rotate(&mut self) -> bool {
        if self.current.len() <= self.max_size / 2 {
            return false;
        }
        // Assigning over `previous` drops the oldest generation.
        self.previous = core::mem::take(&mut self.current);
        true
    }

    /// Total number of tracked hashes (current + previous).
    ///
    /// A hash present in both generations is counted twice.
    #[must_use]
    pub fn len(&self) -> usize {
        self.current.len() + self.previous.len()
    }

    /// Returns `true` if neither generation holds any hash.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.current.is_empty() && self.previous.is_empty()
    }

    /// Number of hashes in the current generation only.
    #[must_use]
    pub fn current_len(&self) -> usize {
        self.current.len()
    }
}
56
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a distinct test hash: first byte is `seed`, the rest are zero.
    fn make_hash(seed: u8) -> [u8; 32] {
        let mut h = [0u8; 32];
        h[0] = seed;
        h
    }

    #[test]
    fn test_new_hash_not_duplicate() {
        let hl = PacketHashlist::new(100);
        assert!(!hl.is_duplicate(&make_hash(1)));
    }

    #[test]
    fn test_added_hash_is_duplicate() {
        let mut hl = PacketHashlist::new(100);
        let h = make_hash(1);
        hl.add(h);
        assert!(hl.is_duplicate(&h));
    }

    #[test]
    fn test_after_rotation_old_hashes_still_detected() {
        let mut hl = PacketHashlist::new(4); // rotate at > 2
        let h1 = make_hash(1);
        let h2 = make_hash(2);
        let h3 = make_hash(3);
        hl.add(h1);
        hl.add(h2);
        hl.add(h3);

        // Force rotation
        assert!(hl.maybe_rotate());

        // Old hashes should still be found in previous
        assert!(hl.is_duplicate(&h1));
        assert!(hl.is_duplicate(&h2));
        assert!(hl.is_duplicate(&h3));
    }

    #[test]
    fn test_after_second_rotation_oldest_forgotten() {
        let mut hl = PacketHashlist::new(4); // rotate at > 2
        let h1 = make_hash(1);
        let h2 = make_hash(2);
        let h3 = make_hash(3);

        // Add to first generation
        hl.add(h1);
        hl.add(h2);
        hl.add(h3);
        assert!(hl.maybe_rotate()); // h1,h2,h3 now in previous

        // Add to second generation
        let h4 = make_hash(4);
        let h5 = make_hash(5);
        let h6 = make_hash(6);
        hl.add(h4);
        hl.add(h5);
        hl.add(h6);
        assert!(hl.maybe_rotate()); // h4,h5,h6 now in previous; h1,h2,h3 forgotten

        // First generation should be forgotten
        assert!(!hl.is_duplicate(&h1));
        assert!(!hl.is_duplicate(&h2));
        assert!(!hl.is_duplicate(&h3));

        // Second generation should still be detected
        assert!(hl.is_duplicate(&h4));
        assert!(hl.is_duplicate(&h5));
        assert!(hl.is_duplicate(&h6));
    }

    #[test]
    fn test_rotation_triggers_at_threshold() {
        let mut hl = PacketHashlist::new(6); // rotate at > 3
        hl.add(make_hash(1));
        hl.add(make_hash(2));
        hl.add(make_hash(3));
        assert!(!hl.maybe_rotate()); // 3 is not > 3

        hl.add(make_hash(4));
        assert!(hl.maybe_rotate()); // 4 > 3
    }

    #[test]
    fn test_len() {
        let mut hl = PacketHashlist::new(100);
        assert_eq!(hl.len(), 0);

        hl.add(make_hash(1));
        hl.add(make_hash(2));
        assert_eq!(hl.len(), 2);
    }

    #[test]
    fn test_current_len_and_len_across_rotation() {
        let mut hl = PacketHashlist::new(4); // rotate at > 2
        for seed in 1..=3 {
            hl.add(make_hash(seed));
        }
        assert_eq!(hl.current_len(), 3);
        assert_eq!(hl.len(), 3);

        assert!(hl.maybe_rotate());
        // Everything moved to `previous`; `current` starts empty.
        assert_eq!(hl.current_len(), 0);
        assert_eq!(hl.len(), 3);
    }

    #[test]
    fn test_adding_same_hash_twice_does_not_grow() {
        let mut hl = PacketHashlist::new(100);
        let h = make_hash(7);
        hl.add(h);
        hl.add(h);
        assert_eq!(hl.current_len(), 1);
        assert_eq!(hl.len(), 1);
    }

    #[test]
    fn test_readding_previous_hash_survives_next_rotation() {
        let mut hl = PacketHashlist::new(2); // rotate at > 1
        let h1 = make_hash(1);
        let h2 = make_hash(2);
        hl.add(h1);
        hl.add(h2);
        assert!(hl.maybe_rotate()); // h1,h2 in previous

        hl.add(h1); // copied back into current
        hl.add(make_hash(3));
        assert!(hl.maybe_rotate()); // {h1,h3} become previous; {h1,h2} dropped

        assert!(hl.is_duplicate(&h1));
        assert!(!hl.is_duplicate(&h2));
    }

    #[test]
    fn test_zero_max_size_rotates_on_any_entry() {
        let mut hl = PacketHashlist::new(0); // rotate at > 0
        assert!(!hl.maybe_rotate()); // empty: 0 is not > 0
        hl.add(make_hash(1));
        assert!(hl.maybe_rotate());
        assert_eq!(hl.current_len(), 0);
        assert!(hl.is_duplicate(&make_hash(1)));
    }
}