keyed_set/
lib.rs

1//! # Keyed Set: a hashbrown-based HashSet that indexes based on projections of its elements.
2//! Ever wanted a `HashMap<K, V>`, but where `V` actually contains `K` (or at least can be projected to it)?
3//! Well this is it.
4//!
5//! The easiest way to define a projection is through a closure that you pass at construction, but you may also define your own key extractors as ZSTs that implement `Default` to gain a `Default` constructor for your Keyed Sets.
6
7#![no_std]
8
9use core::{
10    hash::{BuildHasher, Hash, Hasher},
11    marker::PhantomData,
12};
13
14use hashbrown::{
15    hash_map::DefaultHashBuilder,
16    raw::{RawIntoIter, RawIter, RawTable},
17};
18
19/// A `HashMap<K, V>` where `K` is a part of `V`
20#[derive(Clone)]
21pub struct KeyedSet<T, Extractor, S = DefaultHashBuilder> {
22    inner: hashbrown::raw::RawTable<T>,
23    hash_builder: S,
24    extractor: Extractor,
25}
26
27impl<T, Extractor: Default, S: Default> Default for KeyedSet<T, Extractor, S> {
28    fn default() -> Self {
29        Self {
30            inner: Default::default(),
31            hash_builder: Default::default(),
32            extractor: Default::default(),
33        }
34    }
35}
36
37impl<'a, T, Extractor, S> IntoIterator for &'a KeyedSet<T, Extractor, S> {
38    type Item = &'a T;
39    type IntoIter = Iter<'a, T>;
40    fn into_iter(self) -> Self::IntoIter {
41        self.iter()
42    }
43}
44impl<'a, T, Extractor, S> IntoIterator for &'a mut KeyedSet<T, Extractor, S> {
45    type Item = &'a mut T;
46    type IntoIter = IterMut<'a, T>;
47    fn into_iter(self) -> Self::IntoIter {
48        self.iter_mut()
49    }
50}
/// A projection from a value to its key, which is how [`KeyedSet`] learns the key of each
/// value it stores.
///
/// The lifetime parameter lets the key borrow from the value it was extracted from.
pub trait KeyExtractor<'a, T> {
    /// The key type produced by this extractor.
    type Key: Hash;
    /// Projects `from` onto its key.
    fn extract(&self, from: &'a T) -> Self::Key;
}
/// Any closure (or function) from `&T` to a hashable type is a key extractor.
impl<'a, T: 'a, U: Hash, F: Fn(&'a T) -> U> KeyExtractor<'a, T> for F {
    type Key = U;
    fn extract(&self, from: &'a T) -> Self::Key {
        let project = self;
        project(from)
    }
}
/// The unit extractor uses the whole value as its own key.
impl<'a, T: 'a + Hash> KeyExtractor<'a, T> for () {
    type Key = &'a T;
    fn extract(&self, from: &'a T) -> Self::Key {
        let whole_value: &'a T = from;
        whole_value
    }
}
70impl<T, Extractor> KeyedSet<T, Extractor>
71where
72    Extractor: for<'a> KeyExtractor<'a, T>,
73    for<'a> <Extractor as KeyExtractor<'a, T>>::Key: core::hash::Hash,
74{
75    /// Construct a new map where the key is extracted from the value using `extractor`.`
76    pub fn new(extractor: Extractor) -> Self {
77        Self {
78            inner: Default::default(),
79            hash_builder: Default::default(),
80            extractor,
81        }
82    }
83}
84
85impl<T: core::fmt::Debug, Extractor, S> core::fmt::Debug for KeyedSet<T, Extractor, S> {
86    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
87        write!(f, "KeyedSet {{")?;
88        for v in self.iter() {
89            write!(f, "{:?}, ", v)?;
90        }
91        write!(f, "}}")
92    }
93}
94
95#[allow(clippy::manual_hash_one)]
96impl<T, Extractor, S> KeyedSet<T, Extractor, S>
97where
98    Extractor: for<'a> KeyExtractor<'a, T>,
99    for<'a> <Extractor as KeyExtractor<'a, T>>::Key: core::hash::Hash,
100    S: BuildHasher,
101{
102    /// Inserts a value into the map.
103    pub fn insert(&mut self, value: T) -> Option<T>
104    where
105        for<'a, 'b> <Extractor as KeyExtractor<'a, T>>::Key:
106            PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
107    {
108        let key = self.extractor.extract(&value);
109        let mut hasher = self.hash_builder.build_hasher();
110        key.hash(&mut hasher);
111        let hash = hasher.finish();
112        match self
113            .inner
114            .get_mut(hash, |i| self.extractor.extract(i).eq(&key))
115        {
116            Some(bucket) => {
117                core::mem::drop(key);
118                Some(core::mem::replace(bucket, value))
119            }
120            None => {
121                core::mem::drop(key);
122                let hasher = make_hasher(&self.hash_builder, &self.extractor);
123                self.inner.insert(hash, value, hasher);
124                None
125            }
126        }
127    }
128    /// Obtain an entry in the map, allowing mutable access to the value associated to that key if it exists.
129    pub fn entry<'a, K>(&'a mut self, key: K) -> Entry<'a, T, Extractor, K, S>
130    where
131        K: core::hash::Hash,
132        for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
133    {
134        <Self as IEntry<T, Extractor, S, DefaultBorrower>>::entry(self, key)
135    }
136    /// Similar to [`KeyedSet::insert`], but returns a mutable reference to the inserted value instead of the previous value.
137    pub fn write(&mut self, value: T) -> &mut T
138    where
139        for<'a, 'b> <Extractor as KeyExtractor<'a, T>>::Key:
140            PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
141    {
142        let key = self.extractor.extract(&value);
143        let mut hasher = self.hash_builder.build_hasher();
144        key.hash(&mut hasher);
145        let hash = hasher.finish();
146        match self
147            .inner
148            .get_mut(hash, |i| self.extractor.extract(i).eq(&key))
149        {
150            Some(bucket) => {
151                core::mem::drop(key);
152                *bucket = value;
153                unsafe { core::mem::transmute(bucket) }
154            }
155            None => {
156                core::mem::drop(key);
157                let hasher = make_hasher(&self.hash_builder, &self.extractor);
158                let bucket = self.inner.insert(hash, value, hasher);
159                unsafe { &mut *bucket.as_ptr() }
160            }
161        }
162    }
163    /// Access the value associated to the key immutably.
164    pub fn get<K>(&self, key: &K) -> Option<&T>
165    where
166        K: core::hash::Hash,
167        for<'a> <Extractor as KeyExtractor<'a, T>>::Key: core::hash::Hash + PartialEq<K>,
168    {
169        let mut hasher = self.hash_builder.build_hasher();
170        key.hash(&mut hasher);
171        let hash = hasher.finish();
172        self.inner.get(hash, |i| self.extractor.extract(i).eq(key))
173    }
174    /// Access the value associated to the key mutably.
175    ///
176    /// The returned [`KeyedSetGuard`] will panic on drop if the value is modified in a way that modifies its key.
177    pub fn get_mut<'a, K>(&'a mut self, key: &'a K) -> Option<KeyedSetGuard<'a, K, T, Extractor>>
178    where
179        K: core::hash::Hash,
180        for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
181    {
182        let mut hasher = self.hash_builder.build_hasher();
183        key.hash(&mut hasher);
184        let hash = hasher.finish();
185        self.inner
186            .get_mut(hash, |i| self.extractor.extract(i).eq(key))
187            .map(|guarded| KeyedSetGuard {
188                guarded,
189                key,
190                extractor: &self.extractor,
191            })
192    }
193    /// Access the value associated to the key mutably.
194    ///
195    /// # Safety
196    /// Mutating the value in a way that mutates its key may lead to undefined behaviour.
197    pub unsafe fn get_mut_unguarded<'a, K>(&'a mut self, key: &K) -> Option<&'a mut T>
198    where
199        K: core::hash::Hash,
200        for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
201    {
202        let mut hasher = self.hash_builder.build_hasher();
203        key.hash(&mut hasher);
204        let hash = hasher.finish();
205        self.inner
206            .get_mut(hash, |i| self.extractor.extract(i).eq(key))
207    }
208    /// Remove the value associated to the key, returning it if it exists.
209    pub fn remove<K>(&mut self, key: &K) -> Option<T>
210    where
211        K: core::hash::Hash,
212        for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
213    {
214        let mut hasher = self.hash_builder.build_hasher();
215        key.hash(&mut hasher);
216        let hash = hasher.finish();
217        self.inner
218            .remove_entry(hash, |i| self.extractor.extract(i).eq(key))
219    }
220    /// Returns an iterator that drains elements that match the provided predicate, while removing them from the set.
221    ///
222    /// Note that [`DrainFilter`] WILL iterate fully on drop, ensuring that all elements matching your predicate are always removed, even if you fail to iterate.
223    pub fn drain_where<F: FnMut(&mut T) -> bool>(&mut self, predicate: F) -> DrainFilter<T, F> {
224        DrainFilter {
225            predicate,
226            iter: unsafe { self.inner.iter() },
227            table: &mut self.inner,
228        }
229    }
230    /// Returns an iterator that drains elements from the collection, without affecting the collection's capacity.
231    ///
232    /// Note that [`Drain`] WILL iterate fully on drop, ensuring that all elements are indeed removed, even if you fail to iterate.
233    pub fn drain(&mut self) -> Drain<T> {
234        Drain {
235            iter: unsafe { self.inner.iter() },
236            table: &mut self.inner,
237        }
238    }
239}
240/// An iterator over a [`KeyedSet`] that steals the values from it.
241pub struct Drain<'a, T> {
242    iter: RawIter<T>,
243    table: &'a mut RawTable<T>,
244}
245
246impl<'a, T> Drop for Drain<'a, T> {
247    fn drop(&mut self) {
248        for _ in self {}
249    }
250}
251
252impl<'a, T> Iterator for Drain<'a, T> {
253    type Item = T;
254    fn next(&mut self) -> Option<Self::Item> {
255        Some(unsafe { self.table.remove(self.iter.next()?).0 })
256    }
257}
258/// An iterator over a [`KeyedSet`] that only steals values that match a given predicate.
259pub struct DrainFilter<'a, T, F: FnMut(&mut T) -> bool> {
260    predicate: F,
261    iter: RawIter<T>,
262    table: &'a mut RawTable<T>,
263}
264
265impl<'a, T, F: FnMut(&mut T) -> bool> Drop for DrainFilter<'a, T, F> {
266    fn drop(&mut self) {
267        for _ in self {}
268    }
269}
270
271impl<'a, T, F: FnMut(&mut T) -> bool> Iterator for DrainFilter<'a, T, F> {
272    type Item = T;
273    fn next(&mut self) -> Option<Self::Item> {
274        unsafe {
275            for item in &mut self.iter {
276                if (self.predicate)(item.as_mut()) {
277                    return Some(self.table.remove(item).0);
278                }
279            }
280        }
281        None
282    }
283}
284/// The trait magic that allows [`KeyedSet::entry`] to work.
285pub trait IEntry<T, Extractor, S, Borrower = DefaultBorrower>
286where
287    Extractor: for<'a> KeyExtractor<'a, T>,
288    for<'a> <Extractor as KeyExtractor<'a, T>>::Key: core::hash::Hash,
289    S: BuildHasher,
290{
291    /// Access the entry for `key`.
292    fn entry<'a, K>(&'a mut self, key: K) -> Entry<'a, T, Extractor, K, S>
293    where
294        Borrower: IBorrower<K>,
295        <Borrower as IBorrower<K>>::Borrowed: core::hash::Hash,
296        for<'z> <Extractor as KeyExtractor<'z, T>>::Key:
297            core::hash::Hash + PartialEq<<Borrower as IBorrower<K>>::Borrowed>;
298}
299impl<T, Extractor, S, Borrower> IEntry<T, Extractor, S, Borrower> for KeyedSet<T, Extractor, S>
300where
301    Extractor: for<'a> KeyExtractor<'a, T>,
302    for<'a> <Extractor as KeyExtractor<'a, T>>::Key: core::hash::Hash,
303    S: BuildHasher,
304{
305    fn entry<'a, K>(&'a mut self, key: K) -> Entry<'a, T, Extractor, K, S>
306    where
307        Borrower: IBorrower<K>,
308        <Borrower as IBorrower<K>>::Borrowed: core::hash::Hash,
309        for<'z> <Extractor as KeyExtractor<'z, T>>::Key:
310            core::hash::Hash + PartialEq<<Borrower as IBorrower<K>>::Borrowed>,
311    {
312        match unsafe { self.get_mut_unguarded(Borrower::borrow(&key)) } {
313            Some(entry) => Entry::OccupiedEntry(unsafe { core::mem::transmute(entry) }),
314            None => Entry::Vacant(VacantEntry { set: self, key }),
315        }
316    }
317}
/// The default borrowing strategy: a value is borrowed as itself.
pub struct DefaultBorrower;
/// Allows defining alternatives to [`core::ops::Deref`] for viewing a `T` as another type.
pub trait IBorrower<T> {
    /// The type a `T` is viewed as.
    type Borrowed;
    /// Views `value` in its borrowed representation.
    fn borrow(value: &T) -> &Self::Borrowed;
}
impl<T> IBorrower<T> for DefaultBorrower {
    type Borrowed = T;

    fn borrow(value: &T) -> &T {
        value
    }
}
334impl<T, Extractor, S> KeyedSet<T, Extractor, S> {
335    /// Iterate over the [`KeyedSet`]'s values immutably.
336    pub fn iter(&self) -> Iter<T> {
337        Iter {
338            inner: unsafe { self.inner.iter() },
339            marker: PhantomData,
340        }
341    }
342    /// Iterate over the [`KeyedSet`]'s values mutably.
343    pub fn iter_mut(&mut self) -> IterMut<T> {
344        IterMut {
345            inner: unsafe { self.inner.iter() },
346            marker: PhantomData,
347        }
348    }
349    /// Returns the number of elements in the [`KeyedSet`]
350    pub fn len(&self) -> usize {
351        self.inner.len()
352    }
353    /// Returns `true` if the [`KeyedSet`] is empty.
354    pub fn is_empty(&self) -> bool {
355        self.inner.is_empty()
356    }
357}
358/// A guard that allows mutating a value, but which panics if the new value once dropped doesn't have the same key.
359pub struct KeyedSetGuard<'a, K, T, Extractor>
360where
361    Extractor: for<'z> KeyExtractor<'z, T>,
362    for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
363{
364    guarded: &'a mut T,
365    key: &'a K,
366    extractor: &'a Extractor,
367}
368impl<'a, K, T, Extractor> core::ops::Deref for KeyedSetGuard<'a, K, T, Extractor>
369where
370    Extractor: for<'z> KeyExtractor<'z, T>,
371    for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
372{
373    type Target = T;
374
375    fn deref(&self) -> &Self::Target {
376        self.guarded
377    }
378}
379impl<'a, K, T, Extractor> core::ops::DerefMut for KeyedSetGuard<'a, K, T, Extractor>
380where
381    Extractor: for<'z> KeyExtractor<'z, T>,
382    for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
383{
384    fn deref_mut(&mut self) -> &mut Self::Target {
385        self.guarded
386    }
387}
388impl<'a, K, T, Extractor> Drop for KeyedSetGuard<'a, K, T, Extractor>
389where
390    Extractor: for<'z> KeyExtractor<'z, T>,
391    for<'z> <Extractor as KeyExtractor<'z, T>>::Key: core::hash::Hash + PartialEq<K>,
392{
393    fn drop(&mut self) {
394        if !self.extractor.extract(&*self.guarded).eq(self.key) {
395            panic!("KeyedSetGuard dropped with new value that would change the key, breaking the internal table's invariants.")
396        }
397    }
398}
399
400/// An iterator over the [`KeyedSet`] by value.
401pub struct IntoIter<T>(RawIntoIter<T>);
402
403impl<T> ExactSizeIterator for IntoIter<T> {
404    fn len(&self) -> usize {
405        self.0.len()
406    }
407}
408impl<T> Iterator for IntoIter<T> {
409    type Item = T;
410    fn next(&mut self) -> Option<Self::Item> {
411        self.0.next()
412    }
413}
414
415/// An iterator over the [`KeyedSet`] by reference.
416pub struct Iter<'a, T> {
417    inner: RawIter<T>,
418    marker: PhantomData<&'a ()>,
419}
420impl<'a, T: 'a> Iterator for Iter<'a, T> {
421    type Item = &'a T;
422    fn next(&mut self) -> Option<Self::Item> {
423        self.inner.next().map(|b| unsafe { b.as_ref() })
424    }
425}
426impl<'a, T: 'a> ExactSizeIterator for Iter<'a, T> {
427    fn len(&self) -> usize {
428        self.inner.len()
429    }
430}
431/// An iterator over the [`KeyedSet`] by mutable reference.
432pub struct IterMut<'a, T> {
433    inner: RawIter<T>,
434    marker: PhantomData<&'a mut ()>,
435}
436impl<'a, T: 'a> Iterator for IterMut<'a, T> {
437    type Item = &'a mut T;
438    fn next(&mut self) -> Option<Self::Item> {
439        self.inner.next().map(|b| unsafe { b.as_mut() })
440    }
441}
442impl<'a, T: 'a> ExactSizeIterator for IterMut<'a, T> {
443    fn len(&self) -> usize {
444        self.inner.len()
445    }
446}
447
448/// A vacant entry into a [`KeyedSet`]
449pub struct VacantEntry<'a, T: 'a, Extractor, K, S> {
450    /// The inner set
451    pub set: &'a mut KeyedSet<T, Extractor, S>,
452    /// The key fort he entry.
453    pub key: K,
454}
455/// An entry into a [`KeyedSet`], allowing in-place modification of the value associated with the key if it exists.
456pub enum Entry<'a, T, Extractor, K, S = DefaultHashBuilder> {
457    /// The key was not yet present in the [`KeyedSet`].
458    Vacant(VacantEntry<'a, T, Extractor, K, S>),
459    /// The key was already present in the [`KeyedSet`].
460    OccupiedEntry(&'a mut T),
461}
462
463impl<'a, T: 'a, Extractor, S, K> Entry<'a, T, Extractor, K, S>
464where
465    S: BuildHasher,
466    for<'z> Extractor: KeyExtractor<'z, T>,
467    for<'z, 'b> <Extractor as KeyExtractor<'z, T>>::Key:
468        PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
469{
470    /// Get a mutable reference to the value if present, or assign a value constructed by `f` if it wasn't.
471    pub fn get_or_insert_with(self, f: impl FnOnce(K) -> T) -> &'a mut T {
472        match self {
473            Entry::Vacant(entry) => entry.insert_with(f),
474            Entry::OccupiedEntry(entry) => entry,
475        }
476    }
477    /// A shortcut for `entry.get_or_insert_with(Into::into)`
478    pub fn get_or_insert_with_into(self) -> &'a mut T
479    where
480        K: Into<T>,
481    {
482        self.get_or_insert_with(|k| k.into())
483    }
484}
485impl<'a, K, T, Extractor, S> VacantEntry<'a, T, Extractor, K, S>
486where
487    S: BuildHasher,
488    for<'z> Extractor: KeyExtractor<'z, T>,
489    for<'z, 'b> <Extractor as KeyExtractor<'z, T>>::Key:
490        PartialEq<<Extractor as KeyExtractor<'b, T>>::Key>,
491{
492    /// Inserts a value constructed from the entry's key using `f`.
493    pub fn insert_with<F: FnOnce(K) -> T>(self, f: F) -> &'a mut T {
494        self.set.write(f(self.key))
495    }
496}
497
498#[allow(clippy::manual_hash_one)]
499fn make_hasher<'a, S: BuildHasher, Extractor, T>(
500    hash_builder: &'a S,
501    extractor: &'a Extractor,
502) -> impl Fn(&T) -> u64 + 'a
503where
504    Extractor: for<'b> KeyExtractor<'b, T>,
505    for<'b> <Extractor as KeyExtractor<'b, T>>::Key: core::hash::Hash,
506{
507    move |value| {
508        let key = extractor.extract(value);
509        let mut hasher = hash_builder.build_hasher();
510        key.hash(&mut hasher);
511        hasher.finish()
512    }
513}
514
/// Smoke test: insertion replaces values with an equal key, lookups go through the extracted
/// key, and a vacant entry builds its value from the key.
#[test]
fn test() {
    let mut set = KeyedSet::new(|pair: &(u64, u64)| pair.0);
    assert_eq!(set.len(), 0);
    set.insert((0, 0));
    // Same key (first component), so the previous value is replaced and handed back.
    let replaced = set.insert((0, 1));
    assert_eq!(replaced, Some((0, 0)));
    assert_eq!(set.len(), 1);
    assert_eq!(set.get(&0), Some(&(0, 1)));
    assert!(set.get(&1).is_none());
    let inserted = *set.entry(12).get_or_insert_with(|k| (k, k));
    assert_eq!(inserted, (12, 12));
}