// bradis/linked_hash_set.rs

1use crate::db::KeyRef;
2use std::{
3    cmp::{Eq, PartialEq},
4    hash::{Hash, Hasher},
5    marker::PhantomData,
6    ptr::NonNull,
7};
8
9use hashbrown::{Equivalent, HashSet};
10
/// Optional pointer to a neighbouring node; `None` marks either end of the list.
type Link<T> = Option<NonNull<Node<T>>>;

/// One heap-allocated element of the linked list that is embedded in the hash set.
#[derive(Debug)]
struct Node<T> {
    next: Link<T>,
    prev: Link<T>,
    value: T,
}

/// Hash-set entry that points at a list node.
///
/// Equality and hashing are delegated to the node's `value`, so the hash set behaves as if it
/// stored the values themselves while the list links refer to the same allocations.
#[derive(Debug)]
struct NodePointer<T>(NonNull<Node<T>>);

// SAFETY: the pointee is owned by the set holding this entry and is only reached through that
// set, so sending the entry to another thread amounts to sending the `T` it contains.
unsafe impl<T: Send> Send for NodePointer<T> {}

impl<T: PartialEq> PartialEq for NodePointer<T> {
    fn eq(&self, other: &Self) -> bool {
        // SAFETY: entries only ever point at live nodes owned by the set.
        let (lhs, rhs) = unsafe { (self.0.as_ref(), other.0.as_ref()) };
        lhs.value == rhs.value
    }
}

impl<T: Eq> Eq for NodePointer<T> {}

impl<T: Hash> Hash for NodePointer<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // SAFETY: entries only ever point at live nodes owned by the set.
        let node = unsafe { self.0.as_ref() };
        // Hash exactly like the bare value so that lookups through `Wrapper` (which hashes the
        // borrowed key) land in the same bucket.
        node.value.hash(state);
    }
}
41
/// Borrowed lookup key for the hash set.
///
/// Wraps a reference to a key of type `T` (any `?Sized` type) so that it can be compared
/// against a stored `NodePointer` without allocating a node. The derived `Hash` hashes only
/// the wrapped value, which matches how `NodePointer` hashes its node's value.
#[derive(Eq, Hash, PartialEq)]
struct Wrapper<'a, T: ?Sized>(&'a T);

/// Lets `HashSet<NodePointer<T>>` be queried with any `Q: KeyRef<T>`.
///
/// NOTE(review): `KeyRef` lives in `crate::db`; its equivalence is assumed to agree with `T`'s
/// `Eq` and `Hash` (as `Borrow` requires) — confirm against its definition.
impl<Q, T> Equivalent<NodePointer<T>> for Wrapper<'_, Q>
where
    Q: KeyRef<T> + ?Sized,
{
    fn equivalent(&self, key: &NodePointer<T>) -> bool {
        // SAFETY: every `NodePointer` stored in the set points at a live node owned by the set.
        unsafe { self.0.equivalent(&key.0.as_ref().value) }
    }
}
53
/// An insertion-ordered set with constant-time membership tests and removal.
///
/// Each element lives in a heap-allocated node. The node is threaded onto a doubly linked
/// list, which gives the order, and referenced from a hash set, which gives the lookups. This
/// suits places that need an ordered collection of unique elements with fast removal by value
/// — for instance, the list of subscribers to a particular PUBSUB key.
pub struct LinkedHashSet<T> {
    // First node in list order, or `None` when the set is empty.
    front: Link<T>,
    // Last node in list order, or `None` when the set is empty.
    back: Link<T>,
    // One entry per node; entries are hashed and compared by the node's value. `Drop` frees
    // nodes by draining this set.
    set: HashSet<NodePointer<T>>,
}
62
63impl<T: Eq + Hash + std::fmt::Debug> std::fmt::Debug for LinkedHashSet<T> {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_set().entries(self.iter()).finish()?;
66        Ok(())
67    }
68}
69
70impl<T> Drop for LinkedHashSet<T> {
71    fn drop(&mut self) {
72        for node in self.set.drain() {
73            unsafe { drop(Box::from_raw(node.0.as_ptr())) };
74        }
75    }
76}
77
// SAFETY: the set exclusively owns all of its nodes, which are reachable only through this
// value. Moving the set to another thread therefore only moves the `T`s it contains.
unsafe impl<T: Send> Send for LinkedHashSet<T> {}
79
80impl<T: Eq + Hash> Default for LinkedHashSet<T> {
81    fn default() -> Self {
82        LinkedHashSet {
83            front: None,
84            back: None,
85            set: HashSet::default(),
86        }
87    }
88}
89
90impl<T: Clone + Eq + Hash> Clone for LinkedHashSet<T> {
91    fn clone(&self) -> Self {
92        let mut set = LinkedHashSet::new();
93        for t in self.iter() {
94            set.insert_back(t.clone());
95        }
96        set
97    }
98}
99
100impl<T: Eq + Hash> LinkedHashSet<T> {
101    pub fn new() -> Self {
102        LinkedHashSet::default()
103    }
104
105    /// Is this set empty?
106    pub fn is_empty(&self) -> bool {
107        self.set.is_empty()
108    }
109
110    /// The number of elements in the set
111    pub fn len(&self) -> usize {
112        self.set.len()
113    }
114
115    /// Insert an element into the set at the back of the list
116    pub fn insert_back(&mut self, value: T) {
117        if self.set.contains(&Wrapper(&value)) {
118            return;
119        }
120
121        let node = Box::leak(Box::new(Node {
122            prev: self.back,
123            next: None,
124            value,
125        }))
126        .into();
127
128        // Update the back of the list
129        if let Some(mut back) = self.back {
130            unsafe { back.as_mut() }.next = Some(node);
131        }
132        self.back = Some(node);
133
134        // Update the front of the list
135        if self.front.is_none() {
136            self.front = Some(node);
137        }
138
139        self.set.insert(NodePointer(node));
140    }
141
142    /// Remove an element from the set
143    pub fn remove<Q>(&mut self, value: &Q) -> Option<T>
144    where
145        Q: KeyRef<T> + ?Sized,
146    {
147        let node = self.set.take(&Wrapper(value))?;
148        let node = *unsafe { Box::from_raw(node.0.as_ptr()) };
149
150        let next = node.next;
151        let prev = node.prev;
152
153        // Update the previous node
154        if let Some(mut prev) = prev {
155            unsafe { prev.as_mut() }.next = next;
156        } else {
157            self.front = next;
158        }
159
160        // Update the next node
161        if let Some(mut next) = next {
162            unsafe { next.as_mut() }.prev = prev;
163        } else {
164            self.back = prev;
165        }
166
167        Some(node.value)
168    }
169
170    /// The front element
171    pub fn front(&self) -> Option<&T> {
172        self.front.map(|node| &unsafe { node.as_ref() }.value)
173    }
174
175    /// The back element
176    #[cfg(test)]
177    pub fn back(&self) -> Option<&T> {
178        self.back.map(|node| &unsafe { node.as_ref() }.value)
179    }
180
181    /// An iterator over the elements of the set
182    pub fn iter(&self) -> impl Iterator<Item = &T> {
183        Iter {
184            next: self.front,
185            phantom: PhantomData,
186        }
187    }
188}
189
190struct Iter<'a, T> {
191    next: Link<T>,
192    phantom: PhantomData<&'a T>,
193}
194
195impl<'a, T: 'a> Iterator for Iter<'a, T> {
196    type Item = &'a T;
197
198    fn next(&mut self) -> Option<Self::Item> {
199        let node = self.next?;
200        let node = unsafe { node.as_ref() };
201        self.next = node.next;
202        Some(&node.value)
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_twice() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        set.insert_back(1);
        set.insert_back(1);
        let items: Vec<_> = set.iter().collect();
        assert_eq!(items, vec![&1]);
    }

    #[test]
    fn insert_back() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        assert_eq!(set.front(), None);
        assert_eq!(set.back(), None);

        set.insert_back(1);
        let items: Vec<_> = set.iter().collect();
        assert_eq!(items, vec![&1]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.front(), Some(&1));
        assert_eq!(set.back(), Some(&1));

        set.insert_back(2);
        let items: Vec<_> = set.iter().collect();
        assert_eq!(items, vec![&1, &2]);
        assert_eq!(set.len(), 2);
        assert_eq!(set.front(), Some(&1));
        assert_eq!(set.back(), Some(&2));

        set.insert_back(3);
        let items: Vec<_> = set.iter().collect();
        assert_eq!(items, vec![&1, &2, &3]);
        assert_eq!(set.len(), 3);
        assert_eq!(set.front(), Some(&1));
        assert_eq!(set.back(), Some(&3));

        set.remove(&2);
        let items: Vec<&i64> = set.iter().collect();
        assert_eq!(items, vec![&1, &3]);
        assert_eq!(set.len(), 2);
        assert_eq!(set.front(), Some(&1));
        assert_eq!(set.back(), Some(&3));

        set.remove(&1);
        let items: Vec<&i64> = set.iter().collect();
        assert_eq!(items, vec![&3]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.front(), Some(&3));
        assert_eq!(set.back(), Some(&3));

        set.remove(&3);

        assert_eq!(set.iter().count(), 0);
        assert_eq!(set.len(), 0);
        assert_eq!(set.front(), None);
        assert_eq!(set.back(), None);
    }

    #[test]
    fn borrow() {
        let mut set: LinkedHashSet<Vec<u8>> = LinkedHashSet::new();
        set.insert_back(b"foo".to_vec());
        assert_eq!(set.len(), 1);

        set.remove(&b"foo"[..]);
        assert!(set.is_empty());
    }

    /// `remove` hands back the stored value once, then reports absence.
    #[test]
    fn remove_returns_the_removed_value() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        set.insert_back(1);
        set.insert_back(2);

        assert_eq!(set.remove(&1), Some(1));
        assert_eq!(set.remove(&1), None);
        assert_eq!(set.remove(&42), None);
        assert_eq!(set.len(), 1);
        assert_eq!(set.front(), Some(&2));
        assert_eq!(set.back(), Some(&2));
    }

    /// Removing and re-inserting an element moves it to the back of the list.
    #[test]
    fn reinserted_element_moves_to_back() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        set.insert_back(1);
        set.insert_back(2);
        set.remove(&1);
        set.insert_back(1);

        let items: Vec<_> = set.iter().collect();
        assert_eq!(items, vec![&2, &1]);
        assert_eq!(set.front(), Some(&2));
        assert_eq!(set.back(), Some(&1));
    }

    /// A clone keeps the original order and does not share nodes with the source.
    #[test]
    fn clone_preserves_order_and_is_independent() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        for &i in &[3, 1, 2] {
            set.insert_back(i);
        }

        let copy = set.clone();
        set.remove(&1);

        let items: Vec<_> = copy.iter().collect();
        assert_eq!(items, vec![&3, &1, &2]);
        assert_eq!(copy.len(), 3);
        assert_eq!(set.len(), 2);
    }

    /// `Debug` renders the elements as a set, in list order.
    #[test]
    fn debug_lists_front_to_back() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        set.insert_back(2);
        set.insert_back(1);
        assert_eq!(format!("{:?}", set), "{2, 1}");
    }
}