bradis/
linked_hash_set.rs1use crate::db::KeyRef;
2use std::{
3 cmp::{Eq, PartialEq},
4 hash::{Hash, Hasher},
5 marker::PhantomData,
6 ptr::NonNull,
7};
8
9use hashbrown::{Equivalent, HashSet};
10
/// Nullable, non-owning pointer from one list node to a neighbour.
type Link<T> = Option<NonNull<Node<T>>>;

/// A heap-allocated entry of the insertion-ordered list.
///
/// The allocation is owned through the `NodePointer` kept in the set's hash
/// table; `next`/`prev` (and the set's `front`/`back`) only borrow it by raw
/// pointer.
#[derive(Debug)]
struct Node<T> {
    /// Successor in insertion order; `None` for the last node.
    next: Link<T>,
    /// Predecessor in insertion order; `None` for the first node.
    prev: Link<T>,
    /// The element stored in this node.
    value: T,
}
20
21#[derive(Debug)]
22struct NodePointer<T>(NonNull<Node<T>>);
23
24unsafe impl<T: Send> Send for NodePointer<T> {}
25
26impl<T: PartialEq> PartialEq for NodePointer<T> {
27 fn eq(&self, other: &Self) -> bool {
28 unsafe { self.0.as_ref().value == other.0.as_ref().value }
29 }
30}
31
32impl<T: Eq> Eq for NodePointer<T> {}
33
34impl<T: Hash> Hash for NodePointer<T> {
35 fn hash<H: Hasher>(&self, state: &mut H) {
36 unsafe {
37 self.0.as_ref().value.hash(state);
38 }
39 }
40}
41
42#[derive(Eq, Hash, PartialEq)]
43struct Wrapper<'a, T: ?Sized>(&'a T);
44
45impl<Q, T> Equivalent<NodePointer<T>> for Wrapper<'_, Q>
46where
47 Q: KeyRef<T> + ?Sized,
48{
49 fn equivalent(&self, key: &NodePointer<T>) -> bool {
50 unsafe { self.0.equivalent(&key.0.as_ref().value) }
51 }
52}
53
54pub struct LinkedHashSet<T> {
58 front: Link<T>,
59 back: Link<T>,
60 set: HashSet<NodePointer<T>>,
61}
62
63impl<T: Eq + Hash + std::fmt::Debug> std::fmt::Debug for LinkedHashSet<T> {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 f.debug_set().entries(self.iter()).finish()?;
66 Ok(())
67 }
68}
69
70impl<T> Drop for LinkedHashSet<T> {
71 fn drop(&mut self) {
72 for node in self.set.drain() {
73 unsafe { drop(Box::from_raw(node.0.as_ptr())) };
74 }
75 }
76}
77
78unsafe impl<T: Send> Send for LinkedHashSet<T> {}
79
80impl<T: Eq + Hash> Default for LinkedHashSet<T> {
81 fn default() -> Self {
82 LinkedHashSet {
83 front: None,
84 back: None,
85 set: HashSet::default(),
86 }
87 }
88}
89
90impl<T: Clone + Eq + Hash> Clone for LinkedHashSet<T> {
91 fn clone(&self) -> Self {
92 let mut set = LinkedHashSet::new();
93 for t in self.iter() {
94 set.insert_back(t.clone());
95 }
96 set
97 }
98}
99
100impl<T: Eq + Hash> LinkedHashSet<T> {
101 pub fn new() -> Self {
102 LinkedHashSet::default()
103 }
104
105 pub fn is_empty(&self) -> bool {
107 self.set.is_empty()
108 }
109
110 pub fn len(&self) -> usize {
112 self.set.len()
113 }
114
115 pub fn insert_back(&mut self, value: T) {
117 if self.set.contains(&Wrapper(&value)) {
118 return;
119 }
120
121 let node = Box::leak(Box::new(Node {
122 prev: self.back,
123 next: None,
124 value,
125 }))
126 .into();
127
128 if let Some(mut back) = self.back {
130 unsafe { back.as_mut() }.next = Some(node);
131 }
132 self.back = Some(node);
133
134 if self.front.is_none() {
136 self.front = Some(node);
137 }
138
139 self.set.insert(NodePointer(node));
140 }
141
142 pub fn remove<Q>(&mut self, value: &Q) -> Option<T>
144 where
145 Q: KeyRef<T> + ?Sized,
146 {
147 let node = self.set.take(&Wrapper(value))?;
148 let node = *unsafe { Box::from_raw(node.0.as_ptr()) };
149
150 let next = node.next;
151 let prev = node.prev;
152
153 if let Some(mut prev) = prev {
155 unsafe { prev.as_mut() }.next = next;
156 } else {
157 self.front = next;
158 }
159
160 if let Some(mut next) = next {
162 unsafe { next.as_mut() }.prev = prev;
163 } else {
164 self.back = prev;
165 }
166
167 Some(node.value)
168 }
169
170 pub fn front(&self) -> Option<&T> {
172 self.front.map(|node| &unsafe { node.as_ref() }.value)
173 }
174
175 #[cfg(test)]
177 pub fn back(&self) -> Option<&T> {
178 self.back.map(|node| &unsafe { node.as_ref() }.value)
179 }
180
181 pub fn iter(&self) -> impl Iterator<Item = &T> {
183 Iter {
184 next: self.front,
185 phantom: PhantomData,
186 }
187 }
188}
189
190struct Iter<'a, T> {
191 next: Link<T>,
192 phantom: PhantomData<&'a T>,
193}
194
195impl<'a, T: 'a> Iterator for Iter<'a, T> {
196 type Item = &'a T;
197
198 fn next(&mut self) -> Option<Self::Item> {
199 let node = self.next?;
200 let node = unsafe { node.as_ref() };
201 self.next = node.next;
202 Some(&node.value)
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_twice() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        set.insert_back(1);
        set.insert_back(1);
        let items: Vec<_> = set.iter().collect();
        assert_eq!(items, vec![&1]);
    }

    #[test]
    fn insert_back() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        assert_eq!(set.front(), None);
        assert_eq!(set.back(), None);

        set.insert_back(1);
        let items: Vec<_> = set.iter().collect();
        assert_eq!(items, vec![&1]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.front(), Some(&1));
        assert_eq!(set.back(), Some(&1));

        set.insert_back(2);
        let items: Vec<_> = set.iter().collect();
        assert_eq!(items, vec![&1, &2]);
        assert_eq!(set.len(), 2);
        assert_eq!(set.front(), Some(&1));
        assert_eq!(set.back(), Some(&2));

        set.insert_back(3);
        let items: Vec<_> = set.iter().collect();
        assert_eq!(items, vec![&1, &2, &3]);
        assert_eq!(set.len(), 3);
        assert_eq!(set.front(), Some(&1));
        assert_eq!(set.back(), Some(&3));

        set.remove(&2);
        let items: Vec<&i64> = set.iter().collect();
        assert_eq!(items, vec![&1, &3]);
        assert_eq!(set.len(), 2);
        assert_eq!(set.front(), Some(&1));
        assert_eq!(set.back(), Some(&3));

        set.remove(&1);
        let items: Vec<&i64> = set.iter().collect();
        assert_eq!(items, vec![&3]);
        assert_eq!(set.len(), 1);
        assert_eq!(set.front(), Some(&3));
        assert_eq!(set.back(), Some(&3));

        set.remove(&3);

        assert_eq!(set.iter().count(), 0);
        assert_eq!(set.len(), 0);
        assert_eq!(set.front(), None);
        assert_eq!(set.back(), None);
    }

    #[test]
    fn borrow() {
        let mut set: LinkedHashSet<Vec<u8>> = LinkedHashSet::new();
        set.insert_back(b"foo".to_vec());
        assert_eq!(set.len(), 1);

        set.remove(&b"foo"[..]);
        assert!(set.is_empty());
    }

    // Re-inserting a value that is already present must not move it.
    #[test]
    fn reinserting_existing_value_keeps_position() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        set.insert_back(1);
        set.insert_back(2);
        set.insert_back(1);
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1, &2]);
        assert_eq!(set.len(), 2);
        assert_eq!(set.front(), Some(&1));
        assert_eq!(set.back(), Some(&2));
    }

    #[test]
    fn remove_returns_value_only_once() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        set.insert_back(1);
        set.insert_back(2);

        assert_eq!(set.remove(&2), Some(2));
        assert_eq!(set.remove(&2), None);
        assert_eq!(set.remove(&42), None);
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&1]);
        assert_eq!(set.len(), 1);
    }

    // A removed value that is inserted again goes to the back.
    #[test]
    fn reinsert_after_remove_goes_to_back() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        set.insert_back(1);
        set.insert_back(2);
        set.remove(&1);
        set.insert_back(1);
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![&2, &1]);
        assert_eq!(set.front(), Some(&2));
        assert_eq!(set.back(), Some(&1));
    }

    #[test]
    fn clone_is_ordered_and_independent() {
        let mut set: LinkedHashSet<Vec<u8>> = LinkedHashSet::new();
        set.insert_back(b"a".to_vec());
        set.insert_back(b"b".to_vec());
        set.insert_back(b"c".to_vec());

        let mut copy = set.clone();
        assert_eq!(copy.remove(&b"b"[..]), Some(b"b".to_vec()));

        assert_eq!(set.len(), 3);
        assert_eq!(
            set.iter().collect::<Vec<_>>(),
            vec![&b"a".to_vec(), &b"b".to_vec(), &b"c".to_vec()]
        );
        assert_eq!(copy.len(), 2);
        assert_eq!(
            copy.iter().collect::<Vec<_>>(),
            vec![&b"a".to_vec(), &b"c".to_vec()]
        );
    }

    #[test]
    fn debug_lists_front_to_back() {
        let mut set: LinkedHashSet<i64> = LinkedHashSet::new();
        set.insert_back(3);
        set.insert_back(1);
        assert_eq!(format!("{:?}", set), "{3, 1}");
        assert_eq!(format!("{:?}", LinkedHashSet::<i64>::new()), "{}");
    }
}