cao_lang/collections/hash_map.rs

#[cfg(feature = "serde")]
mod serde_impl;

#[cfg(test)]
mod tests;

use std::{
    alloc::Layout,
    borrow::Borrow,
    hash::{Hash, Hasher},
    mem::swap,
    ptr::NonNull,
};

use crate::alloc::{Allocator, SysAllocator};

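/// Maximum load factor: the map grows once `count` exceeds `capacity * MAX_LOAD` (see `needs_grow`)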
pub(crate) const MAX_LOAD: f32 = 0.7;

type ArrayTriplet<K, V> = (NonNull<u8>, NonNull<K>, NonNull<V>);

/// Hash map implemented for Cao-Lang
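///
/// A minimal usage sketch, assuming the default `SysAllocator`:
///
/// ```ignore
/// let mut map: CaoHashMap<u32, &str> = CaoHashMap::default();
/// map.insert(42, "cao").unwrap();
/// assert_eq!(map.get(&42), Some(&"cao"));
/// assert_eq!(map.len(), 1);
/// ```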
pub struct CaoHashMap<K, V, A: Allocator = SysAllocator> {
    /// beginning of the data, and the hash buffer
    /// layout:
    /// [hash hash hash][key key key][value value value]
    data: NonNull<u8>,
    /// beginning of the keys array
    keys: NonNull<K>,
    /// beginning of the values array
    values: NonNull<V>,

    count: usize,
    capacity: usize,

    alloc: A,
}

unsafe impl<K: Send, V: Send, A: Allocator + Send> Send for CaoHashMap<K, V, A> {}
unsafe impl<K: Sync, V: Sync, A: Allocator + Sync> Sync for CaoHashMap<K, V, A> {}

impl<K, V, A> std::fmt::Debug for CaoHashMap<K, V, A>
where
    K: std::fmt::Debug,
    V: std::fmt::Debug,
    A: Allocator,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut state = f.debug_map();
        for (k, v) in self.iter() {
            state.entry(k, v);
        }
        state.finish()
    }
}

impl<K, V, A> Clone for CaoHashMap<K, V, A>
where
    K: Clone + Eq + Hash,
    V: Clone,
    A: Allocator + Clone,
{
    fn clone(&self) -> Self {
        let mut result = CaoHashMap::with_capacity_in(self.capacity, self.alloc.clone()).unwrap();

        // TODO: could use insert with hint
        // or better yet, memcpy hashes, then clone the occupied entries
        for (k, v) in self.iter() {
            result.insert(k.clone(), v.clone()).unwrap();
        }
        result
    }
}

impl<K, V, A: Allocator + Default> Default for CaoHashMap<K, V, A> {
    fn default() -> Self {
        CaoHashMap::with_capacity_in(0, A::default()).unwrap()
    }
}

pub struct Entry<'a, K, V> {
    hash: u64,
    key: K,
    pl: EntryPayload<'a, K, V>,
}

enum EntryPayload<'a, K, V> {
    Occupied(&'a mut V),
    Vacant {
        hash: &'a mut u64,
        key: *mut K,
        value: *mut V,
        count: &'a mut usize,
    },
}

impl<'a, K, V> Entry<'a, K, V> {
    pub fn or_insert_with<F: FnOnce() -> V>(self, fun: F) -> &'a mut V {
        match self.pl {
            EntryPayload::Occupied(res) => res,
            EntryPayload::Vacant {
                hash,
                key,
                value,
                count,
            } => {
                *hash = self.hash;
                unsafe {
                    std::ptr::write(key, self.key);
                    std::ptr::write(value, fun());
                    *count += 1;
                    &mut *value
                }
            }
        }
    }
}

#[derive(Debug, Clone, thiserror::Error)]
pub enum MapError {
    #[error("Failed to allocate memory {0}")]
    AllocError(crate::alloc::AllocError),
}

impl<K, V, A: Allocator> Drop for CaoHashMap<K, V, A> {
    fn drop(&mut self) {
        self.clear();
        let (layout, _) = Self::layout(self.capacity);
        unsafe {
            self.alloc.dealloc(self.data, layout);
        }
    }
}

impl<K, V, A: Allocator> CaoHashMap<K, V, A> {
    pub fn len(&self) -> usize {
        self.count
    }

    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    pub fn with_capacity_in(capacity: usize, alloc: A) -> Result<Self, MapError> {
        let capacity = capacity.max(1);
        let (data, keys, values) = unsafe { Self::alloc_storage(&alloc, capacity)? };
        let mut result = Self {
            data,
            keys,
            values,
            count: 0,
            capacity,
            alloc,
        };
        result.zero_hashes();
        Ok(result)
    }

    /// # Safety
    /// Caller must ensure that the hashes are zeroed
    unsafe fn alloc_storage(alloc: &A, cap: usize) -> Result<ArrayTriplet<K, V>, MapError> {
        let (layout, [ko, vo]) = Self::layout(cap);
        let data = alloc.alloc(layout).map_err(MapError::AllocError)?;
        let keys = data.as_ptr().add(ko).cast();
        let values = data.as_ptr().add(vo).cast();
        Ok((
            data,
            NonNull::new_unchecked(keys),
            NonNull::new_unchecked(values),
        ))
    }

    fn layout(cap: usize) -> (Layout, [usize; 2]) {
        let hash_layout = Layout::array::<u64>(cap).unwrap();
        let keys_layout = Layout::array::<K>(cap).unwrap();
        let values_layout = Layout::array::<V>(cap).unwrap();

        let (result, keys_offset) = hash_layout.extend(keys_layout).unwrap();
        let (result, vals_offset) = result.extend(values_layout).unwrap();

        (result, [keys_offset, vals_offset])
    }

    pub fn clear(&mut self) {
        let handles = self.data.cast::<u64>().as_ptr();
        let keys = self.keys.as_ptr();
        let values = self.values.as_ptr();

        unsafe {
            clear_arrays(handles, keys, values, self.capacity);
        }

        self.count = 0;
    }

    pub fn insert(&mut self, key: K, value: V) -> Result<u64, MapError>
    where
        K: Eq + Hash,
    {
        let h = hash(&key);
        unsafe { self.insert_with_hint(h, key, value).map(|_| h) }
    }

    /// # Safety
    /// Caller must ensure that the hash is correct for the key
    pub unsafe fn insert_with_hint(&mut self, h: u64, key: K, value: V) -> Result<(), MapError>
    where
        K: Eq,
    {
        debug_assert!(h != 0, "Bad handle, 0 values are reserved");

        // find the bucket
        let hashes = self.hashes();
        let keys = self.keys.as_ptr();
        let values = self.values.as_ptr();

        let i = self.find_ind(h, &key);
        if hashes[i] != 0 {
            debug_assert_eq!(hashes[i], h);
            // delete the old entry
            if std::mem::needs_drop::<K>() {
                std::ptr::drop_in_place(keys.add(i));
            }
            if std::mem::needs_drop::<V>() {
                std::ptr::drop_in_place(values.add(i));
            }
        } else {
            self.hashes_mut()[i] = h;
            self.count += 1;
        }
        std::ptr::write(keys.add(i), key);
        std::ptr::write(values.add(i), value);
        // delaying grow so that no grow is triggered if the key overwrites an existing entry
        if Self::needs_grow(self.count, self.capacity) {
            self.grow()?;
        }
        Ok(())
    }

    fn needs_grow(count: usize, capacity: usize) -> bool {
        count as f32 > capacity as f32 * MAX_LOAD
    }

    pub fn reserve(&mut self, additional_cap: usize) -> Result<(), MapError>
    where
        K: Eq,
    {
        unsafe { self.adjust_capacity(self.capacity + additional_cap) }
    }

    fn grow(&mut self) -> Result<(), MapError>
    where
        K: Eq,
    {
        let new_cap = (self.capacity.max(2) * 3) / 2;
        debug_assert!(new_cap > self.capacity);
        unsafe { self.adjust_capacity(new_cap) }
    }

    unsafe fn adjust_capacity(&mut self, capacity: usize) -> Result<(), MapError>
    where
        K: Eq,
    {
        let (mut data, mut keys, mut values) = Self::alloc_storage(&self.alloc, capacity)?;
        swap(&mut self.data, &mut data);
        swap(&mut self.keys, &mut keys);
        swap(&mut self.values, &mut values);
        let capacity = std::mem::replace(&mut self.capacity, capacity);
        self.zero_hashes();
        let count = std::mem::replace(&mut self.count, 0); // insert will increment count

        // copy over the existing values
        for i in 0..capacity {
            let hash = *data.as_ptr().cast::<u64>().add(i);
            if hash != 0 {
                let key = std::ptr::read(keys.as_ptr().add(i));
                let val = std::ptr::read(values.as_ptr().add(i));
                self.insert_with_hint(hash, key, val)?;
            }
        }

        assert_eq!(
            count, self.count,
            "Internal error: moving the values after realloc resulted in inconsistent count"
        );

        // free up the old storage
        let (layout, _) = Self::layout(capacity);
        self.alloc.dealloc(data, layout);

        Ok(())
    }

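    /// Remove the entry belonging to `key`, returning its value if present.
    ///
    /// A minimal sketch, assuming a `map: CaoHashMap<u32, &str>`:
    ///
    /// ```ignore
    /// map.insert(1u32, "one").unwrap();
    /// assert_eq!(map.remove(&1), Some("one"));
    /// assert_eq!(map.remove(&1), None);
    /// ```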
    pub fn remove<Q: ?Sized>(&mut self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: Eq + Hash,
    {
        let hash = hash(key);
        unsafe { self.remove_with_hint(hash, key) }
    }

    /// # Safety
    ///
    /// Hash must be produced from the key
    pub unsafe fn remove_with_hint<Q: ?Sized>(&mut self, hash: u64, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: Eq,
    {
        let i = self.find_ind(hash, key);
        if self.hashes()[i] != 0 {
            if std::mem::needs_drop::<K>() {
                std::ptr::drop_in_place(self.keys.as_ptr().add(i));
            }

            let result = std::ptr::read(self.values.as_ptr().add(i));
            self.hashes_mut()[i] = 0;
            self.count -= 1;

            // Backward-shift deletion: walk the following non-empty buckets and move
            // displaced entries back into the freed slot so linear probing still finds them
            let capacity = self.capacity();
            let mut i = i; // track the last empty slot
            let mut j = (i + 1) % capacity;
            while self.hashes()[j] != 0 {
                // home bucket of the entry at `j`, computed the same way as in `find_ind`
                let home = (self.hashes()[j].wrapping_mul(2654435769) as usize) % capacity;
                // the entry may only be moved back if its home bucket does not lie in the
                // cyclic range (i, j]; otherwise moving it would break its probe chain
                let home_in_range = if i < j {
                    i < home && home <= j
                } else {
                    home > i || home <= j
                };
                if !home_in_range {
                    self.hashes_mut()[i] = self.hashes()[j];
                    self.hashes_mut()[j] = 0;
                    std::ptr::swap(self.keys.as_ptr().add(i), self.keys.as_ptr().add(j));
                    std::ptr::swap(self.values.as_ptr().add(i), self.values.as_ptr().add(j));
                    i = j;
                }
                j = (j + 1) % capacity;
            }

            return Some(result);
        }
        None
    }

    pub fn contains<Q: ?Sized>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: Eq + Hash,
    {
        let hash = hash(key);
        unsafe { self.contains_with_hint(hash, key) }
    }

    /// # Safety
    ///
    /// Hash must be produced from the key
    pub unsafe fn contains_with_hint<Q: ?Sized>(&self, h: u64, k: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: Eq + Hash,
    {
        let i = self.find_ind(h, k);
        self.hashes()[i] != 0
    }

    pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: Eq + Hash,
    {
        let hash = hash(key);
        unsafe { self.get_with_hint(hash, key) }
    }

    /// # Safety
    ///
    /// Hash must be produced from the key
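    ///
    /// A minimal sketch pairing the hash returned by `insert` with the hinted getter,
    /// assuming a `map: CaoHashMap<&str, i32>`:
    ///
    /// ```ignore
    /// let h = map.insert("key", 1).unwrap();
    /// let value = unsafe { map.get_with_hint(h, &"key") };
    /// assert_eq!(value, Some(&1));
    /// ```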
    pub unsafe fn get_with_hint<Q: ?Sized>(&self, h: u64, k: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: Eq,
    {
        let i = self.find_ind(h, k);
        if self.hashes()[i] != 0 {
            Some(&*self.values.as_ptr().add(i))
        } else {
            None
        }
    }

    pub fn get_mut<Q: ?Sized>(&mut self, key: &Q) -> Option<&mut V>
    where
        K: Borrow<Q>,
        Q: Eq + Hash,
    {
        let hash = hash(key);
        unsafe { self.get_with_hint_mut(hash, key) }
    }

    /// # Safety
    ///
    /// Hash must be produced from the key
    pub unsafe fn get_with_hint_mut<Q: ?Sized>(&mut self, h: u64, k: &Q) -> Option<&mut V>
    where
        K: Borrow<Q>,
        Q: Eq + Hash,
    {
        let i = self.find_ind(h, k);
        if self.hashes()[i] != 0 {
            Some(&mut *self.values.as_ptr().add(i))
        } else {
            None
        }
    }

    fn find_ind<Q: ?Sized>(&self, needle: u64, k: &Q) -> usize
    where
        K: Borrow<Q>,
        Q: Eq,
    {
        let len = self.capacity;

        // improve uniformity via Fibonacci hashing (2654435769 is roughly 2^32 / golden ratio);
        // our hashes are 32 bits and sizeof usize is 4 on wasm, so the multiplication spreads
        // the hash bits before truncating to usize
        let mut ind = (needle.wrapping_mul(2654435769) as usize) % len;
        let hashes = self.hashes();
        let keys = self.keys.as_ptr();
        loop {
            unsafe {
                debug_assert!(ind < len);
                let h = hashes[ind];
                if h == 0 || (h == needle && (*keys.add(ind)).borrow() == k) {
                    return ind;
                }
            }
            ind = (ind + 1) % len;
        }
    }

    fn hashes(&self) -> &[u64] {
        unsafe { std::slice::from_raw_parts(self.data.as_ptr().cast(), self.capacity) }
    }

    fn hashes_mut(&mut self) -> &mut [u64] {
        unsafe { std::slice::from_raw_parts_mut(self.data.as_ptr().cast(), self.capacity) }
    }

    /// Zero-out the hash buffer
    ///
    /// Call this function after a fresh alloc of the data buffer
    fn zero_hashes(&mut self) {
        self.hashes_mut().fill(0u64);
    }

    /// This method eagerly allocates new buffers if inserting via the returned entry
    /// would grow the map beyond its max load
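    ///
    /// A minimal sketch, assuming a `map: CaoHashMap<&str, i32>`:
    ///
    /// ```ignore
    /// let counter = map.entry("cao").unwrap().or_insert_with(|| 0);
    /// *counter += 1;
    /// ```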
    pub fn entry(&mut self, key: K) -> Result<Entry<K, V>, MapError>
    where
        K: Eq + Hash,
    {
        let hash = hash(&key);
        let i = self.find_ind(hash, &key);
        let pl;
        if self.hashes()[i] != 0 {
            pl = EntryPayload::Occupied(unsafe { &mut *self.values.as_ptr().add(i) });
        } else {
            // if it would need to grow on insert, then allocate the new buffer now
            if Self::needs_grow(self.count + 1, self.capacity) {
                self.grow()?;
            }
            // growing reallocates and rehashes the buffers, so recompute the slot
            let i = self.find_ind(hash, &key);
            unsafe {
                pl = EntryPayload::Vacant {
                    hash: &mut *self.data.cast::<u64>().as_ptr().add(i),
                    key: self.keys.as_ptr().add(i),
                    value: self.values.as_ptr().add(i),
                    count: &mut self.count,
                }
            }
        }
        Ok(Entry { hash, key, pl })
    }

    pub fn capacity(&self) -> usize {
        self.capacity
    }

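    /// Iterate over the occupied `(key, value)` pairs, in unspecified order.
    ///
    /// A minimal sketch, assuming `K` and `V` implement `Debug`:
    ///
    /// ```ignore
    /// for (k, v) in map.iter() {
    ///     println!("{:?} => {:?}", k, v);
    /// }
    /// ```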
    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
        (0..self.capacity)
            .filter(|i| self.hashes()[*i] != 0)
            .map(|i| unsafe { (&*self.keys.as_ptr().add(i), &*self.values.as_ptr().add(i)) })
    }

    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&K, &mut V)> {
        (0..self.capacity)
            .filter(|i| self.hashes()[*i] != 0)
            .map(|i| unsafe {
                (
                    &*self.keys.as_ptr().add(i),
                    &mut *self.values.as_ptr().add(i),
                )
            })
    }
}

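/// FNV-1a style hasher, masked to 32 bits so the result survives the cast to `usize`
/// on 32-bit targets such as wasm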
struct CaoHasher(u64);
impl Default for CaoHasher {
    fn default() -> Self {
        Self(2166136261)
    }
}

impl Hasher for CaoHasher {
    fn finish(&self) -> u64 {
        self.0
    }

    fn write(&mut self, bytes: &[u8]) {
        const MASK: u64 = u32::MAX as u64;
        let mut hash = self.0;
        for byte in bytes {
            hash ^= *byte as u64;
            hash &= MASK;
            hash *= 16777619;
        }
        self.0 = hash & MASK;
    }
}

fn hash<T: ?Sized + Hash>(t: &T) -> u64 {
    let mut hasher = CaoHasher::default();
    t.hash(&mut hasher);
    let result = hasher.finish();
    debug_assert_ne!(result, 0, "0 hash is reserved");
    result
}

/// # Safety
///
/// Must be called with valid arrays in a CaoHashMap
unsafe fn clear_arrays<K, V>(handles: *mut u64, keys: *mut K, values: *mut V, count: usize) {
    for i in 0..count {
        if (*handles.add(i)) != 0 {
            *handles.add(i) = 0;
            if std::mem::needs_drop::<K>() {
                std::ptr::drop_in_place(keys.add(i));
            }
            if std::mem::needs_drop::<V>() {
                std::ptr::drop_in_place(values.add(i));
            }
        }
    }
}