1#![no_std]
8
9use core::{
10 hash::{BuildHasher, Hash, Hasher},
11 marker::PhantomData,
12};
13
14use hashbrown::{
15 hash_map::DefaultHashBuilder,
16 raw::{RawIntoIter, RawIter, RawTable},
17};
18
19#[derive(Clone)]
21pub struct KeyedSet<T, Extractor, S = DefaultHashBuilder> {
22 inner: hashbrown::raw::RawTable<T>,
23 hash_builder: S,
24 extractor: Extractor,
25}
26
27impl<T, Extractor: Default, S: Default> Default for KeyedSet<T, Extractor, S> {
28 fn default() -> Self {
29 Self {
30 inner: Default::default(),
31 hash_builder: Default::default(),
32 extractor: Default::default(),
33 }
34 }
35}
36
37impl<'a, T, Extractor, S> IntoIterator for &'a KeyedSet<T, Extractor, S> {
38 type Item = &'a T;
39 type IntoIter = Iter<'a, T>;
40 fn into_iter(self) -> Self::IntoIter {
41 self.iter()
42 }
43}
44impl<'a, T, Extractor, S> IntoIterator for &'a mut KeyedSet<T, Extractor, S> {
45 type Item = &'a mut T;
46 type IntoIter = IterMut<'a, T>;
47 fn into_iter(self) -> Self::IntoIter {
48 self.iter_mut()
49 }
50}
/// Computes the key under which a `T` is identified.
///
/// The lifetime `'a` is that of the borrowed element, so the key may borrow from it.
pub trait KeyExtractor<'a, T> {
    /// The key type; it must be hashable.
    type Key: Hash;

    /// Returns the key of `from`.
    fn extract(&self, from: &'a T) -> Self::Key;
}
58impl<'a, T: 'a, U: Hash, F: Fn(&'a T) -> U> KeyExtractor<'a, T> for F {
59 type Key = U;
60 fn extract(&self, from: &'a T) -> Self::Key {
61 self(from)
62 }
63}
64impl<'a, T: 'a + Hash> KeyExtractor<'a, T> for () {
65 type Key = &'a T;
66 fn extract(&self, from: &'a T) -> Self::Key {
67 from
68 }
69}
70impl<T, Extractor> KeyedSet<T, Extractor>
71where
72 Extractor: for<'a> KeyExtractor<'a, T>,
73 for<'a> <Extractor as KeyExtractor<'a, T>>::Key: core::hash::Hash,
74{
75 pub fn new(extractor: Extractor) -> Self {
77 Self {
78 inner: Default::default(),
79 hash_builder: Default::default(),
80 extractor,
81 }
82 }
83}
84
85impl<T: core::fmt::Debug, Extractor, S> core::fmt::Debug for KeyedSet<T, Extractor, S> {
86 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
87 write!(f, "KeyedSet {{")?;
88 for v in self.iter() {
89 write!(f, "{:?}, ", v)?;
90 }
91 write!(f, "}}")
92 }
93}
94
95#[allow(clippy::manual_hash_one)]
96impl<T, Extractor, S> KeyedSet<T, Extractor, S>
97where
98 Extractor: for<'a> KeyExtractor<'a, T>,
99 for<'a> <Extractor as KeyExtractor<'a, T>>::Key: core::hash::Hash,
100 S: BuildHasher,
101{
102 pub fn insert(&mut self, value: T) -> Option<T>
104 where
105 for<'a, 'b> <Extractor as KeyExtractor<'a, T>>::Key:
106 PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
107 {
108 let key = self.extractor.extract(&value);
109 let mut hasher = self.hash_builder.build_hasher();
110 key.hash(&mut hasher);
111 let hash = hasher.finish();
112 match self
113 .inner
114 .get_mut(hash, |i| self.extractor.extract(i).eq(&key))
115 {
116 Some(bucket) => {
117 core::mem::drop(key);
118 Some(core::mem::replace(bucket, value))
119 }
120 None => {
121 core::mem::drop(key);
122 let hasher = make_hasher(&self.hash_builder, &self.extractor);
123 self.inner.insert(hash, value, hasher);
124 None
125 }
126 }
127 }
128 pub fn entry<'a, K>(&'a mut self, key: K) -> Entry<'a, T, Extractor, K, S>
130 where
131 K: core::hash::Hash,
132 for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
133 {
134 <Self as IEntry<T, Extractor, S, DefaultBorrower>>::entry(self, key)
135 }
136 pub fn write(&mut self, value: T) -> &mut T
138 where
139 for<'a, 'b> <Extractor as KeyExtractor<'a, T>>::Key:
140 PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
141 {
142 let key = self.extractor.extract(&value);
143 let mut hasher = self.hash_builder.build_hasher();
144 key.hash(&mut hasher);
145 let hash = hasher.finish();
146 match self
147 .inner
148 .get_mut(hash, |i| self.extractor.extract(i).eq(&key))
149 {
150 Some(bucket) => {
151 core::mem::drop(key);
152 *bucket = value;
153 unsafe { core::mem::transmute(bucket) }
154 }
155 None => {
156 core::mem::drop(key);
157 let hasher = make_hasher(&self.hash_builder, &self.extractor);
158 let bucket = self.inner.insert(hash, value, hasher);
159 unsafe { &mut *bucket.as_ptr() }
160 }
161 }
162 }
163 pub fn get<K>(&self, key: &K) -> Option<&T>
165 where
166 K: core::hash::Hash,
167 for<'a> <Extractor as KeyExtractor<'a, T>>::Key: core::hash::Hash + PartialEq<K>,
168 {
169 let mut hasher = self.hash_builder.build_hasher();
170 key.hash(&mut hasher);
171 let hash = hasher.finish();
172 self.inner.get(hash, |i| self.extractor.extract(i).eq(key))
173 }
174 pub fn get_mut<'a, K>(&'a mut self, key: &'a K) -> Option<KeyedSetGuard<'a, K, T, Extractor>>
178 where
179 K: core::hash::Hash,
180 for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
181 {
182 let mut hasher = self.hash_builder.build_hasher();
183 key.hash(&mut hasher);
184 let hash = hasher.finish();
185 self.inner
186 .get_mut(hash, |i| self.extractor.extract(i).eq(key))
187 .map(|guarded| KeyedSetGuard {
188 guarded,
189 key,
190 extractor: &self.extractor,
191 })
192 }
193 pub unsafe fn get_mut_unguarded<'a, K>(&'a mut self, key: &K) -> Option<&'a mut T>
198 where
199 K: core::hash::Hash,
200 for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
201 {
202 let mut hasher = self.hash_builder.build_hasher();
203 key.hash(&mut hasher);
204 let hash = hasher.finish();
205 self.inner
206 .get_mut(hash, |i| self.extractor.extract(i).eq(key))
207 }
208 pub fn remove<K>(&mut self, key: &K) -> Option<T>
210 where
211 K: core::hash::Hash,
212 for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
213 {
214 let mut hasher = self.hash_builder.build_hasher();
215 key.hash(&mut hasher);
216 let hash = hasher.finish();
217 self.inner
218 .remove_entry(hash, |i| self.extractor.extract(i).eq(key))
219 }
220 pub fn drain_where<F: FnMut(&mut T) -> bool>(&mut self, predicate: F) -> DrainFilter<T, F> {
224 DrainFilter {
225 predicate,
226 iter: unsafe { self.inner.iter() },
227 table: &mut self.inner,
228 }
229 }
230 pub fn drain(&mut self) -> Drain<T> {
234 Drain {
235 iter: unsafe { self.inner.iter() },
236 table: &mut self.inner,
237 }
238 }
239}
240pub struct Drain<'a, T> {
242 iter: RawIter<T>,
243 table: &'a mut RawTable<T>,
244}
245
246impl<'a, T> Drop for Drain<'a, T> {
247 fn drop(&mut self) {
248 for _ in self {}
249 }
250}
251
252impl<'a, T> Iterator for Drain<'a, T> {
253 type Item = T;
254 fn next(&mut self) -> Option<Self::Item> {
255 Some(unsafe { self.table.remove(self.iter.next()?).0 })
256 }
257}
258pub struct DrainFilter<'a, T, F: FnMut(&mut T) -> bool> {
260 predicate: F,
261 iter: RawIter<T>,
262 table: &'a mut RawTable<T>,
263}
264
265impl<'a, T, F: FnMut(&mut T) -> bool> Drop for DrainFilter<'a, T, F> {
266 fn drop(&mut self) {
267 for _ in self {}
268 }
269}
270
271impl<'a, T, F: FnMut(&mut T) -> bool> Iterator for DrainFilter<'a, T, F> {
272 type Item = T;
273 fn next(&mut self) -> Option<Self::Item> {
274 unsafe {
275 for item in &mut self.iter {
276 if (self.predicate)(item.as_mut()) {
277 return Some(self.table.remove(item).0);
278 }
279 }
280 }
281 None
282 }
283}
284pub trait IEntry<T, Extractor, S, Borrower = DefaultBorrower>
286where
287 Extractor: for<'a> KeyExtractor<'a, T>,
288 for<'a> <Extractor as KeyExtractor<'a, T>>::Key: core::hash::Hash,
289 S: BuildHasher,
290{
291 fn entry<'a, K>(&'a mut self, key: K) -> Entry<'a, T, Extractor, K, S>
293 where
294 Borrower: IBorrower<K>,
295 <Borrower as IBorrower<K>>::Borrowed: core::hash::Hash,
296 for<'z> <Extractor as KeyExtractor<'z, T>>::Key:
297 core::hash::Hash + PartialEq<<Borrower as IBorrower<K>>::Borrowed>;
298}
299impl<T, Extractor, S, Borrower> IEntry<T, Extractor, S, Borrower> for KeyedSet<T, Extractor, S>
300where
301 Extractor: for<'a> KeyExtractor<'a, T>,
302 for<'a> <Extractor as KeyExtractor<'a, T>>::Key: core::hash::Hash,
303 S: BuildHasher,
304{
305 fn entry<'a, K>(&'a mut self, key: K) -> Entry<'a, T, Extractor, K, S>
306 where
307 Borrower: IBorrower<K>,
308 <Borrower as IBorrower<K>>::Borrowed: core::hash::Hash,
309 for<'z> <Extractor as KeyExtractor<'z, T>>::Key:
310 core::hash::Hash + PartialEq<<Borrower as IBorrower<K>>::Borrowed>,
311 {
312 match unsafe { self.get_mut_unguarded(Borrower::borrow(&key)) } {
313 Some(entry) => Entry::OccupiedEntry(unsafe { core::mem::transmute(entry) }),
314 None => Entry::Vacant(VacantEntry { set: self, key }),
315 }
316 }
317}
/// The borrower used when none is specified: it borrows a key as itself.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultBorrower;
/// Turns a reference to a `T` into a reference to some borrowed form of it.
pub trait IBorrower<T> {
    /// The borrowed form.
    type Borrowed;

    /// Borrows `value` as [`Self::Borrowed`].
    fn borrow(value: &T) -> &Self::Borrowed;
}
327impl<T> IBorrower<T> for DefaultBorrower {
328 type Borrowed = T;
329
330 fn borrow(value: &T) -> &Self::Borrowed {
331 value
332 }
333}
334impl<T, Extractor, S> KeyedSet<T, Extractor, S> {
335 pub fn iter(&self) -> Iter<T> {
337 Iter {
338 inner: unsafe { self.inner.iter() },
339 marker: PhantomData,
340 }
341 }
342 pub fn iter_mut(&mut self) -> IterMut<T> {
344 IterMut {
345 inner: unsafe { self.inner.iter() },
346 marker: PhantomData,
347 }
348 }
349 pub fn len(&self) -> usize {
351 self.inner.len()
352 }
353 pub fn is_empty(&self) -> bool {
355 self.inner.is_empty()
356 }
357}
358pub struct KeyedSetGuard<'a, K, T, Extractor>
360where
361 Extractor: for<'z> KeyExtractor<'z, T>,
362 for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
363{
364 guarded: &'a mut T,
365 key: &'a K,
366 extractor: &'a Extractor,
367}
368impl<'a, K, T, Extractor> core::ops::Deref for KeyedSetGuard<'a, K, T, Extractor>
369where
370 Extractor: for<'z> KeyExtractor<'z, T>,
371 for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
372{
373 type Target = T;
374
375 fn deref(&self) -> &Self::Target {
376 self.guarded
377 }
378}
379impl<'a, K, T, Extractor> core::ops::DerefMut for KeyedSetGuard<'a, K, T, Extractor>
380where
381 Extractor: for<'z> KeyExtractor<'z, T>,
382 for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
383{
384 fn deref_mut(&mut self) -> &mut Self::Target {
385 self.guarded
386 }
387}
388impl<'a, K, T, Extractor> Drop for KeyedSetGuard<'a, K, T, Extractor>
389where
390 Extractor: for<'z> KeyExtractor<'z, T>,
391 for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
392{
393 fn drop(&mut self) {
394 if !self.extractor.extract(&*self.guarded).eq(self.key) {
395 panic!("KeyedSetGuard dropped with new value that would change the key, breaking the internal table's invariants.")
396 }
397 }
398}
399
400pub struct IntoIter<T>(RawIntoIter<T>);
402
403impl<T> ExactSizeIterator for IntoIter<T> {
404 fn len(&self) -> usize {
405 self.0.len()
406 }
407}
408impl<T> Iterator for IntoIter<T> {
409 type Item = T;
410 fn next(&mut self) -> Option<Self::Item> {
411 self.0.next()
412 }
413}
414
415pub struct Iter<'a, T> {
417 inner: RawIter<T>,
418 marker: PhantomData<&'a ()>,
419}
420impl<'a, T: 'a> Iterator for Iter<'a, T> {
421 type Item = &'a T;
422 fn next(&mut self) -> Option<Self::Item> {
423 self.inner.next().map(|b| unsafe { b.as_ref() })
424 }
425}
426impl<'a, T: 'a> ExactSizeIterator for Iter<'a, T> {
427 fn len(&self) -> usize {
428 self.inner.len()
429 }
430}
431pub struct IterMut<'a, T> {
433 inner: RawIter<T>,
434 marker: PhantomData<&'a mut ()>,
435}
436impl<'a, T: 'a> Iterator for IterMut<'a, T> {
437 type Item = &'a mut T;
438 fn next(&mut self) -> Option<Self::Item> {
439 self.inner.next().map(|b| unsafe { b.as_mut() })
440 }
441}
442impl<'a, T: 'a> ExactSizeIterator for IterMut<'a, T> {
443 fn len(&self) -> usize {
444 self.inner.len()
445 }
446}
447
448pub struct VacantEntry<'a, T: 'a, Extractor, K, S> {
450 pub set: &'a mut KeyedSet<T, Extractor, S>,
452 pub key: K,
454}
455pub enum Entry<'a, T, Extractor, K, S = DefaultHashBuilder> {
457 Vacant(VacantEntry<'a, T, Extractor, K, S>),
459 OccupiedEntry(&'a mut T),
461}
462
463impl<'a, T: 'a, Extractor, S, K> Entry<'a, T, Extractor, K, S>
464where
465 S: BuildHasher,
466 for<'z> Extractor: KeyExtractor<'z, T>,
467 for<'z, 'b> <Extractor as KeyExtractor<'z, T>>::Key:
468 PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
469{
470 pub fn get_or_insert_with(self, f: impl FnOnce(K) -> T) -> &'a mut T {
472 match self {
473 Entry::Vacant(entry) => entry.insert_with(f),
474 Entry::OccupiedEntry(entry) => entry,
475 }
476 }
477 pub fn get_or_insert_with_into(self) -> &'a mut T
479 where
480 K: Into<T>,
481 {
482 self.get_or_insert_with(|k| k.into())
483 }
484}
485impl<'a, K, T, Extractor, S> VacantEntry<'a, T, Extractor, K, S>
486where
487 S: BuildHasher,
488 for<'z> Extractor: KeyExtractor<'z, T>,
489 for<'z, 'b> <Extractor as KeyExtractor<'z, T>>::Key:
490 PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
491{
492 pub fn insert_with<F: FnOnce(K) -> T>(self, f: F) -> &'a mut T {
494 self.set.write(f(self.key))
495 }
496}
497
498#[allow(clippy::manual_hash_one)]
499fn make_hasher<'a, S: BuildHasher, Extractor, T>(
500 hash_builder: &'a S,
501 extractor: &'a Extractor,
502) -> impl Fn(&T) -> u64 + 'a
503where
504 Extractor: for<'b> KeyExtractor<'b, T>,
505 for<'b> <Extractor as KeyExtractor<'b, T>>::Key: core::hash::Hash,
506{
507 move |value| {
508 let key = extractor.extract(value);
509 let mut hasher = hash_builder.build_hasher();
510 key.hash(&mut hasher);
511 hasher.finish()
512 }
513}
514
/// Exercises insertion, replacement, lookup and the entry API on pairs keyed by their first
/// component.
#[test]
fn test() {
    let mut pairs = KeyedSet::new(|pair: &(u64, u64)| pair.0);
    assert_eq!(pairs.len(), 0);

    // A second insert under the same key replaces the first element and hands it back.
    pairs.insert((0, 0));
    let replaced = pairs.insert((0, 1));
    assert_eq!(replaced, Some((0, 0)));
    assert_eq!(pairs.len(), 1);

    // Lookups go through the extracted key.
    assert_eq!(pairs.get(&0), Some(&(0, 1)));
    assert!(pairs.get(&1).is_none());

    // A vacant entry builds its element from the requested key.
    let inserted = pairs.entry(12).get_or_insert_with(|k| (k, k));
    assert_eq!(*inserted, (12, 12));
}