//! `small_map/inline.rs` — inline (stack-allocated) small-map implementation.

1use core::{
2    hash::{BuildHasher, Hash},
3    iter::FusedIterator,
4    mem::{self, transmute, MaybeUninit},
5    ptr::NonNull,
6};
7
8use crate::{
9    raw::{
10        h2,
11        iter::{RawIntoIter, RawIter},
12        util::{equivalent_key, likely, make_hash, Bucket, InsertSlot, SizedTypeProperties},
13        BitMaskWord, Group, RawIterInner, DELETED, EMPTY,
14    },
15    Equivalent,
16};
17
/// Inline (array-backed) hash map holding at most `N` entries, hashed with
/// builder `S`. All storage lives in the struct itself — no heap allocation.
#[derive(Clone)]
pub struct Inline<const N: usize, K, V, S> {
    raw: RawInline<N, (K, V)>,
    // `Option` exists only so `take_hasher` can move `S` out; `S` is always
    // present before drop.
    hash_builder: Option<S>,
}
24
/// Raw fixed-capacity, SwissTable-style storage: one control byte per slot
/// (`aligned_groups`) plus `N` possibly-uninitialized data slots.
/// `len` counts the initialized (full) slots.
struct RawInline<const N: usize, T> {
    aligned_groups: AlignedGroups<N>,
    len: usize,
    data: [MaybeUninit<T>; N],
}
30
impl<const N: usize, T: Clone> Clone for RawInline<N, T> {
    /// Clones the table, compacting live entries to the front (tombstones and
    /// gaps are not copied). This is sound because lookups scan all slots
    /// linearly from index 0 and each control byte is moved alongside its
    /// entry — slot placement never depends on the hash.
    ///
    /// NOTE(review): if `T::clone` panics partway through, entries already
    /// cloned into the local `data` array are leaked (never dropped). Memory
    /// leak only, not unsoundness — consider a drop guard.
    #[inline]
    fn clone(&self) -> Self {
        let mut aligned_groups = AlignedGroups {
            groups: [EMPTY; N],
            _align: [],
        };
        // SAFETY: an array of `MaybeUninit` needs no initialization.
        let mut data = unsafe { MaybeUninit::<[MaybeUninit<T>; N]>::uninit().assume_init() };
        let mut new_idx = 0;

        for i in 0..N {
            let ctrl = self.aligned_groups.groups[i];
            // Only copy valid entries (not EMPTY and not DELETED).
            // EMPTY = 0b1111_1111, DELETED = 0b1000_0000.
            // Valid h2 values have the high bit unset (0x00-0x7F).
            if ctrl & 0x80 == 0 {
                // SAFETY: a clear high bit marks this slot as initialized.
                data[new_idx] = MaybeUninit::new(unsafe { self.data[i].assume_init_ref().clone() });
                aligned_groups.groups[new_idx] = ctrl;
                new_idx += 1;
            }
        }

        Self {
            aligned_groups,
            len: new_idx,
            data,
        }
    }
}
60
/// Control-byte array for `N` slots, over-aligned to `Group` so SIMD group
/// loads are valid. With `repr(C)`, the zero-sized `_align` field raises the
/// struct's alignment to `Group`'s, and the struct's size is rounded up to a
/// multiple of that alignment — so a partial tail-group load stays inside the
/// struct, reading (garbage) padding bytes that callers must mask off.
/// (Assumes `Group`'s alignment is at least `Group::WIDTH`, as for SIMD
/// vectors — confirm for the word-sized fallback implementation.)
#[repr(C)]
#[derive(Clone, Copy)]
pub(crate) struct AlignedGroups<const N: usize> {
    groups: [u8; N],
    _align: [Group; 0],
}
67
impl<const N: usize> AlignedGroups<N> {
    /// Returns a raw pointer to the control byte at `index`.
    ///
    /// # Safety
    /// `index` must lie within the (padding-extended) control area.
    /// NOTE(review): the `*mut` is derived from `&self` and is later written
    /// through (`set_ctrl_h2`, `erase`, `retain`); confirm this passes Miri's
    /// aliasing rules.
    #[inline]
    unsafe fn ctrl(&self, index: usize) -> *mut u8 {
        self.groups.as_ptr().add(index).cast_mut()
    }

    /// Returns the base pointer of the control array.
    #[inline]
    pub(crate) fn as_ptr(&self) -> NonNull<u8> {
        // SAFETY: an array's base pointer is never null.
        unsafe { NonNull::new_unchecked(self.groups.as_ptr() as _) }
    }
}
79
impl<const N: usize, T> Drop for RawInline<N, T> {
    #[inline]
    fn drop(&mut self) {
        // SAFETY: called exactly once, and the control bytes and `len`
        // accurately describe which slots are initialized.
        unsafe { self.drop_elements() }
    }
}
86
impl<const N: usize, T> RawInline<N, T> {
    /// Drops every initialized element in place (called from `Drop`).
    ///
    /// # Safety
    /// The control bytes and `len` must accurately describe which `data`
    /// slots are initialized, and the table must not be used afterwards.
    #[inline]
    unsafe fn drop_elements(&mut self) {
        // Nothing to do for trivially-droppable `T` or an empty table.
        if T::NEEDS_DROP && self.len != 0 {
            unsafe {
                // Bitwise-copy the controls and data into a temporary
                // `RawIntoIter` and let its destructor drop each remaining
                // element. `self`'s own fields are not touched again.
                drop(RawIntoIter {
                    inner: self.raw_iter_inner(),
                    aligned_groups: (&self.aligned_groups as *const AlignedGroups<N>).read(),
                    data: (&self.data as *const [MaybeUninit<T>; N]).read(),
                });
            }
        }
    }

    /// Gets a reference to an element in the table.
    #[inline]
    fn get(&self, hash: u64, eq: impl FnMut(&T) -> bool) -> Option<&T> {
        // Avoid `Option::map` because it bloats LLVM IR.
        match self.find(hash, eq) {
            Some(bucket) => Some(unsafe { bucket.as_ref() }),
            None => None,
        }
    }

    /// Gets a mutable reference to an element in the table.
    #[inline]
    fn get_mut(&mut self, hash: u64, eq: impl FnMut(&T) -> bool) -> Option<&mut T> {
        // Avoid `Option::map` because it bloats LLVM IR.
        match self.find(hash, eq) {
            Some(bucket) => Some(unsafe { bucket.as_mut() }),
            None => None,
        }
    }

    /// Number of whole `Group`s covered by the `N` control bytes.
    const UNCHECKED_GROUP: usize = N / Group::WIDTH;
    /// Mask selecting only the `N % Group::WIDTH` valid control bytes of the
    /// final, partial group; the rest of that group load is struct padding
    /// and must be ignored.
    const TAIL_MASK: BitMaskWord = Group::LOWEST_MASK[N % Group::WIDTH];

    /// Searches for an element in the table.
    ///
    /// Unlike a classic swiss-table probe, this scans every group linearly
    /// from index 0 (slot placement never depends on the hash), matching
    /// control bytes against `h2(hash)` and confirming candidates with `eq`.
    #[inline]
    fn find(&self, hash: u64, mut eq: impl FnMut(&T) -> bool) -> Option<Bucket<T>> {
        unsafe {
            let h2_hash = h2(hash);
            let mut probe_pos = 0;

            // Manually expand the loop over the whole groups.
            for _ in 0..Self::UNCHECKED_GROUP {
                let group = Group::load(self.aligned_groups.ctrl(probe_pos));
                let matches = group.match_byte(h2_hash);
                for bit in matches {
                    let index = probe_pos + bit;
                    // SAFETY: a h2 match implies a full (initialized) slot.
                    if likely(eq(self.data.get_unchecked(index).assume_init_ref())) {
                        return Some(self.bucket(index));
                    }
                }
                probe_pos += Group::WIDTH;
            }
            if !N.is_multiple_of(Group::WIDTH) {
                // This load reads past `groups[N]` into the struct's trailing
                // alignment padding (guaranteed by `_align: [Group; 0]`);
                // the garbage bytes are discarded by `TAIL_MASK` below.
                let group = Group::load(self.aligned_groups.ctrl(probe_pos));
                // Clear invalid tail.
                let matches = group.match_byte(h2_hash).and(Self::TAIL_MASK);
                for bit in matches {
                    let index = probe_pos + bit;
                    if likely(eq(self.data.get_unchecked(index).assume_init_ref())) {
                        return Some(self.bucket(index));
                    }
                }
            }
            None
        }
    }

    /// Searches for an element in the table. If the element is not found,
    /// returns `Err` with the position of a slot where an element with the
    /// same hash could be inserted.
    ///
    /// # Safety-relevant invariant
    /// The final `unwrap_unchecked` assumes at least one EMPTY/DELETED slot
    /// exists when the key is absent, i.e. callers must never call this on a
    /// full table expecting an insert slot.
    #[inline]
    fn find_or_find_insert_slot(
        &mut self,
        hash: u64,
        mut eq: impl FnMut(&T) -> bool,
    ) -> Result<Bucket<T>, InsertSlot> {
        unsafe {
            let mut insert_slot = None;
            let h2_hash = h2(hash);
            let mut probe_pos = 0;

            // Manually expand the loop over the whole groups.
            for _ in 0..Self::UNCHECKED_GROUP {
                let group = Group::load(self.aligned_groups.ctrl(probe_pos));
                let matches = group.match_byte(h2_hash);
                for bit in matches {
                    let index = probe_pos + bit;
                    if likely(eq(self.data.get_unchecked(index).assume_init_ref())) {
                        return Ok(self.bucket(index));
                    }
                }

                // We didn't find the element we were looking for in the group, try to get an
                // insertion slot from the group if we don't have one yet.
                if likely(insert_slot.is_none()) {
                    insert_slot = self.find_insert_slot_in_group(&group, probe_pos);
                }

                // An EMPTY byte means no element was ever placed at or after
                // this position (insertion always takes the lowest free slot),
                // so we can stop searching subsequent groups.
                if likely(group.match_empty().any_bit_set()) {
                    break;
                }
                probe_pos += Group::WIDTH;
            }
            if !N.is_multiple_of(Group::WIDTH) {
                // Tail group: load reads into alignment padding; matches are
                // masked with TAIL_MASK.
                let group = Group::load(self.aligned_groups.ctrl(probe_pos));
                let matches = group.match_byte(h2_hash).and(Self::TAIL_MASK);
                for bit in matches {
                    let index = probe_pos + bit;
                    if likely(eq(self.data.get_unchecked(index).assume_init_ref())) {
                        return Ok(self.bucket(index));
                    }
                }

                // We didn't find the element we were looking for in the group, try to get an
                // insertion slot from the group if we don't have one yet.
                // NOTE(review): unlike the lookup above, this slot search does
                // NOT mask with `TAIL_MASK`, so on a completely full table a
                // garbage padding byte could yield an insert index >= N
                // (out-of-bounds write in `insert_in_slot`). Unreachable as
                // long as callers uphold the not-full invariant — confirm at
                // call sites, or apply the mask here for defense in depth.
                if likely(insert_slot.is_none()) {
                    insert_slot = self.find_insert_slot_in_group(&group, probe_pos);
                }
            }

            // SAFETY: relies on the not-full invariant described above.
            Err(InsertSlot {
                index: insert_slot.unwrap_unchecked(),
            })
        }
    }

    /// Finds the position to insert something in a group: the lowest
    /// EMPTY-or-DELETED byte, or `None` if the group is entirely full.
    #[inline]
    fn find_insert_slot_in_group(&self, group: &Group, probe_seq: usize) -> Option<usize> {
        let bit = group.match_empty_or_deleted().lowest_set_bit();

        if likely(bit.is_some()) {
            // SAFETY: just checked `is_some`.
            let n = unsafe { bit.unwrap_unchecked() };
            return Some(probe_seq + n);
        }
        None
    }

    /// Inserts a new element into the table in the given slot, and returns its
    /// raw bucket.
    ///
    /// # Safety
    /// `slot` must be a valid EMPTY/DELETED slot index below `N`.
    #[inline]
    unsafe fn insert_in_slot(&mut self, hash: u64, slot: InsertSlot, value: T) -> Bucket<T> {
        self.record_item_insert_at(slot.index, hash);
        let bucket = self.bucket(slot.index);
        bucket.write(value);
        bucket
    }

    /// Marks the slot at `index` as full (storing `h2(hash)`) and bumps `len`.
    ///
    /// # Safety
    /// `index` must be a valid slot index below `N`.
    #[inline]
    unsafe fn record_item_insert_at(&mut self, index: usize, hash: u64) {
        self.set_ctrl_h2(index, hash);
        self.len += 1;
    }

    /// Sets the control byte at `index` to `h2(hash)` (high bit clear ⇒ full).
    ///
    /// # Safety
    /// `index` must be a valid slot index below `N`.
    #[inline]
    unsafe fn set_ctrl_h2(&mut self, index: usize, hash: u64) {
        // SAFETY: the caller guarantees `index < N`, so the write stays
        // inside the control array.
        *self.aligned_groups.ctrl(index) = h2(hash);
    }

    /// Finds and removes an element from the table, returning it.
    #[inline]
    fn remove_entry(&mut self, hash: u64, eq: impl FnMut(&T) -> bool) -> Option<T> {
        // Avoid `Option::map` because it bloats LLVM IR.
        match self.find(hash, eq) {
            Some(bucket) => Some(unsafe { self.remove(bucket).0 }),
            None => None,
        }
    }

    /// Removes an element from the table, returning it along with the slot it
    /// vacated (usable for immediate re-insertion).
    ///
    /// # Safety
    /// `item` must refer to a full slot of this table.
    #[inline]
    #[allow(clippy::needless_pass_by_value)]
    unsafe fn remove(&mut self, item: Bucket<T>) -> (T, InsertSlot) {
        self.erase_no_drop(&item);
        (
            // Move the value out; the slot was already tombstoned above.
            item.read(),
            InsertSlot {
                index: self.bucket_index(&item),
            },
        )
    }

    /// Erases an element from the table without dropping it.
    ///
    /// # Safety
    /// `item` must refer to a full slot of this table.
    #[inline]
    unsafe fn erase_no_drop(&mut self, item: &Bucket<T>) {
        let index = self.bucket_index(item);
        self.erase(index);
    }

    /// Returns the index of a bucket from a `Bucket`.
    ///
    /// # Safety
    /// `bucket` must have been produced by `self.bucket(..)`.
    #[inline]
    unsafe fn bucket_index(&self, bucket: &Bucket<T>) -> usize {
        bucket.to_base_index(NonNull::new_unchecked(self.data.as_ptr() as _))
    }

    /// Marks the slot at `index` as DELETED (tombstone) so it no longer
    /// matches as full, and decrements `len`. The element itself is not
    /// dropped here.
    ///
    /// # Safety
    /// `index` must be a full slot index below `N`.
    #[inline]
    unsafe fn erase(&mut self, index: usize) {
        *self.aligned_groups.ctrl(index) = DELETED;
        self.len -= 1;
    }

    /// Returns a pointer to an element in the table.
    ///
    /// # Safety
    /// `index` must be below `N`; dereferencing additionally requires the
    /// slot to be initialized.
    #[inline]
    unsafe fn bucket(&self, index: usize) -> Bucket<T> {
        Bucket::from_base_index(
            // Cast `*mut MaybeUninit<T>` to `*mut T`: same layout.
            NonNull::new_unchecked(transmute::<*mut MaybeUninit<T>, *mut T>(
                self.data.as_ptr().cast_mut(),
            )),
            index,
        )
    }

    /// Builds the shared iterator core from the first control group.
    ///
    /// # Safety
    /// The control bytes and `len` must be in sync.
    #[inline]
    unsafe fn raw_iter_inner(&self) -> RawIterInner<T> {
        // `ctrl(0)` is `Group`-aligned thanks to `_align` in `AlignedGroups`.
        // NOTE(review): only the first group's full-mask is seeded here;
        // presumably `RawIterInner` loads subsequent groups itself — confirm
        // in `raw::iter`.
        let init_group = Group::load_aligned(self.aligned_groups.ctrl(0)).match_full();
        RawIterInner::new(init_group, self.len)
    }

    /// Returns a borrowing raw iterator over the full slots.
    #[inline]
    fn iter(&self) -> RawIter<'_, N, T> {
        RawIter {
            // SAFETY: the borrow of `self` keeps controls/data alive and
            // unchanged for the iterator's lifetime.
            inner: unsafe { self.raw_iter_inner() },
            aligned_groups: &self.aligned_groups,
            data: &self.data,
        }
    }
}
327
impl<const N: usize, K, V> RawInline<N, (K, V)> {
    /// Removes every entry for which `f(&key, &mut value)` returns `false`,
    /// dropping removed pairs in place and tombstoning their slots.
    #[inline]
    fn retain<F>(&mut self, f: &mut F)
    where
        F: FnMut(&K, &mut V) -> bool,
    {
        // Linear scan over all N slots; occupancy comes from the control byte.
        for i in 0..N {
            let ctrl = self.aligned_groups.groups[i];
            // Only process valid entries (not EMPTY and not DELETED):
            // a clear high bit marks a full slot.
            if ctrl & 0x80 == 0 {
                // SAFETY: the control byte guarantees this slot is initialized.
                let (k, v) = unsafe { self.data[i].assume_init_mut() };
                if !f(k, v) {
                    unsafe {
                        // Tombstone first, then drop the pair in place, so the
                        // control bytes never claim a dropped slot is full.
                        *self.aligned_groups.ctrl(i) = DELETED;
                        core::ptr::drop_in_place(self.data[i].as_mut_ptr());
                    }
                    self.len -= 1;
                }
            }
        }
    }
}
350
impl<const N: usize, T> IntoIterator for RawInline<N, T> {
    type Item = T;
    type IntoIter = RawIntoIter<N, T>;

    /// Converts the table into an owning iterator by bitwise-copying the
    /// control bytes and data out, then `mem::forget`-ting `self` so its
    /// `Drop` impl cannot double-drop the moved elements.
    #[inline]
    fn into_iter(self) -> RawIntoIter<N, T> {
        let ret = unsafe {
            RawIntoIter {
                inner: self.raw_iter_inner(),
                // SAFETY: plain bitwise copies; ownership of the elements is
                // transferred to `ret`, and `self` is forgotten below.
                aligned_groups: (&self.aligned_groups as *const AlignedGroups<N>).read(),
                data: (&self.data as *const [MaybeUninit<T>; N]).read(),
            }
        };
        mem::forget(self);
        ret
    }
}
368
/// Borrowing iterator over the entries of an [`Inline`] map,
/// yielding `(&K, &V)` pairs.
pub struct Iter<'a, const N: usize, K, V> {
    inner: RawIter<'a, N, (K, V)>,
}
372
/// Owning iterator over the entries of an [`Inline`] map,
/// yielding `(K, V)` pairs.
pub struct IntoIter<const N: usize, K, V> {
    inner: RawIntoIter<N, (K, V)>,
}
376
377impl<'a, const N: usize, K, V> Iterator for Iter<'a, N, K, V> {
378    type Item = (&'a K, &'a V);
379
380    #[inline]
381    fn next(&mut self) -> Option<(&'a K, &'a V)> {
382        match self.inner.next() {
383            Some(kv) => Some((&kv.0, &kv.1)),
384            None => None,
385        }
386    }
387    #[inline]
388    fn size_hint(&self) -> (usize, Option<usize>) {
389        self.inner.size_hint()
390    }
391}
392
393impl<const N: usize, K, V> Iterator for IntoIter<N, K, V> {
394    type Item = (K, V);
395
396    #[inline]
397    fn next(&mut self) -> Option<(K, V)> {
398        self.inner.next()
399    }
400    #[inline]
401    fn size_hint(&self) -> (usize, Option<usize>) {
402        self.inner.size_hint()
403    }
404}
// Exact lengths forward to the raw iterators, which track the remaining
// element count precisely.
impl<'a, const N: usize, K, V> ExactSizeIterator for Iter<'a, N, K, V> {
    #[inline]
    fn len(&self) -> usize {
        self.inner.len()
    }
}
impl<const N: usize, K, V> ExactSizeIterator for IntoIter<N, K, V> {
    #[inline]
    fn len(&self) -> usize {
        self.inner.len()
    }
}
// These impls assert that once the raw iterators return `None` they stay
// exhausted — upheld in `raw::iter` (not visible in this file).
impl<'a, const N: usize, K, V> FusedIterator for Iter<'a, N, K, V> {}
impl<const N: usize, K, V> FusedIterator for IntoIter<N, K, V> {}
419
420impl<const N: usize, K, V, S> IntoIterator for Inline<N, K, V, S> {
421    type Item = (K, V);
422    type IntoIter = IntoIter<N, K, V>;
423
424    fn into_iter(self) -> Self::IntoIter {
425        IntoIter {
426            inner: self.raw.into_iter(),
427        }
428    }
429}
430
impl<const N: usize, K, V, S> Inline<N, K, V, S> {
    /// Returns a borrowing iterator over `(&K, &V)` pairs.
    #[inline]
    pub(crate) fn iter(&self) -> Iter<'_, N, K, V> {
        Iter {
            inner: self.raw.iter(),
        }
    }

    /// Creates an empty inline map that hashes with `hash_builder`.
    ///
    /// # Panics
    /// Panics if `N == 0` (at compile time when evaluated in const context).
    #[inline]
    pub(crate) const fn new(hash_builder: S) -> Self {
        assert!(N != 0, "SmallMap cannot be initialized with zero size.");
        Self {
            raw: RawInline {
                aligned_groups: AlignedGroups {
                    groups: [EMPTY; N],
                    _align: [],
                },
                len: 0,
                // SAFETY: an array of `MaybeUninit` needs no initialization.
                // TODO: use uninit_array when stable
                data: unsafe { MaybeUninit::<[MaybeUninit<(K, V)>; N]>::uninit().assume_init() },
            },
            hash_builder: Some(hash_builder),
        }
    }

    /// Returns `true` if the map holds no entries.
    #[inline]
    pub(crate) fn is_empty(&self) -> bool {
        self.raw.len == 0
    }

    /// Returns `true` when every one of the `N` inline slots is occupied.
    #[inline]
    pub(crate) fn is_full(&self) -> bool {
        self.raw.len == N
    }

    /// Returns the number of entries currently stored.
    #[inline]
    pub(crate) fn len(&self) -> usize {
        self.raw.len
    }

    /// Moves the hash builder out, leaving `None` behind.
    ///
    /// # Safety
    /// The hasher must still be present (this must not have been called
    /// before); afterwards, any method that hashes keys will panic in
    /// `hash_builder()`.
    #[inline]
    pub(crate) unsafe fn take_hasher(&mut self) -> S {
        self.hash_builder.take().unwrap_unchecked()
    }

    /// Borrows the hash builder; panics if it was taken via `take_hasher`.
    #[inline]
    fn hash_builder(&self) -> &S {
        self.hash_builder.as_ref().unwrap()
    }
}
483
impl<const N: usize, K, V, S> Inline<N, K, V, S>
where
    K: Eq + Hash,
    S: BuildHasher,
{
    /// Returns a reference to the value corresponding to the key.
    #[inline]
    pub(crate) fn get<Q>(&self, k: &Q) -> Option<&V>
    where
        Q: ?Sized + Hash + Equivalent<K>,
    {
        // Avoid `Option::map` because it bloats LLVM IR.
        match self.get_inner(k) {
            Some((_, v)) => Some(v),
            None => None,
        }
    }

    /// Returns a mutable reference to the value corresponding to the key.
    #[inline]
    pub(crate) fn get_mut<Q>(&mut self, k: &Q) -> Option<&mut V>
    where
        Q: ?Sized + Hash + Equivalent<K>,
    {
        // Avoid `Option::map` because it bloats LLVM IR.
        match self.get_inner_mut(k) {
            Some((_, v)) => Some(v),
            None => None,
        }
    }

    /// Returns the key-value pair corresponding to the supplied key.
    #[inline]
    pub(crate) fn get_key_value<Q>(&self, k: &Q) -> Option<(&K, &V)>
    where
        Q: ?Sized + Hash + Equivalent<K>,
    {
        // Avoid `Option::map` because it bloats LLVM IR.
        match self.get_inner(k) {
            Some((key, value)) => Some((key, value)),
            None => None,
        }
    }

    /// Inserts a key-value pair into the map. If the key was already present,
    /// replaces the value and returns the old one (the stored key is kept).
    ///
    /// NOTE(review): inserting a *new* key while `is_full()` reaches
    /// `unwrap_unchecked` on a missing insert slot inside
    /// `find_or_find_insert_slot` (UB). Callers must guarantee spare
    /// capacity — confirm at the call sites of this method.
    #[inline]
    pub(crate) fn insert(&mut self, k: K, v: V) -> Option<V> {
        let hash = make_hash::<K, S>(self.hash_builder(), &k);
        match self.raw.find_or_find_insert_slot(hash, equivalent_key(&k)) {
            // SAFETY: the bucket returned by a successful find is initialized.
            Ok(bucket) => Some(mem::replace(unsafe { &mut bucket.as_mut().1 }, v)),
            Err(slot) => {
                unsafe {
                    self.raw.insert_in_slot(hash, slot, (k, v));
                }
                None
            }
        }
    }

    /// Removes a key from the map, returning the stored key and value if the
    /// key was previously in the map. The vacated inline slot is reusable.
    #[inline]
    pub(crate) fn remove_entry<Q>(&mut self, k: &Q) -> Option<(K, V)>
    where
        Q: ?Sized + Hash + Equivalent<K>,
    {
        let hash = make_hash::<Q, S>(self.hash_builder(), k);
        self.raw.remove_entry(hash, equivalent_key(k))
    }

    /// Retains only the elements specified by the predicate.
    #[inline]
    pub(crate) fn retain<F>(&mut self, f: &mut F)
    where
        F: FnMut(&K, &mut V) -> bool,
    {
        self.raw.retain(f);
    }

    /// Shared lookup core: skips hashing entirely when the map is empty.
    #[inline]
    fn get_inner<Q>(&self, k: &Q) -> Option<&(K, V)>
    where
        Q: ?Sized + Hash + Equivalent<K>,
    {
        if self.is_empty() {
            None
        } else {
            let hash = make_hash::<Q, S>(self.hash_builder(), k);
            self.raw.get(hash, equivalent_key(k))
        }
    }

    /// Mutable variant of [`Self::get_inner`].
    #[inline]
    fn get_inner_mut<Q>(&mut self, k: &Q) -> Option<&mut (K, V)>
    where
        Q: ?Sized + Hash + Equivalent<K>,
    {
        if self.is_empty() {
            None
        } else {
            let hash = make_hash::<Q, S>(self.hash_builder(), k);
            self.raw.get_mut(hash, equivalent_key(k))
        }
    }
}