// compact_dict/dict.rs

#[path = "./keys_container.rs"]
mod keys_container;
use keys_container::KeysContainer;

use std::convert::{TryFrom, TryInto};
use std::simd::prelude::*;

#[path = "./ahash.rs"]
pub mod ahash;
// use ahash::{StrHash, FxStrHash, MojoAHashStrHash, AHashStrHash};
11
12/// Open-addressing dictionary with linear probing and 1-based slot_to_index like your Mojo Dict.
13/// - V: Copy (to mirror Copyable & Movable)
14/// - H: BuildHasher/StrHash (default aHash RandomState)
15/// - KC: key-count integer (u32 default)
16/// - KO: key-offset integer for KeysContainer (u32 default)
17#[cfg_attr(feature = "rkyv", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))]
18pub struct Dict<
19    V: Copy,
20    H: ahash::StrHash = ahash::MojoAHashStrHash,
21    KC: TryInto<usize> + From<u8> + From<u16> + TryFrom<u32> + TryFrom<usize> + Copy + PartialEq = u32,
22    KO: TryFrom<usize> + Copy + TryInto<usize> = u32,
23    const DESTRUCTIVE: bool = true,
24    const CACHING_HASHES: bool = true,
25> {
26    keys: KeysContainer<KO>,
27    key_hashes: Option<Vec<KC>>, // present if CACHING_HASHES
28    values: Vec<V>,
29    slot_to_index: Vec<KC>,        // 0 = empty, else index+1
30    deleted_mask: Option<Vec<u8>>, // bit per key index if DESTRUCTIVE
31    count: usize,                  // active (non-deleted) entries
32    capacity: usize,               // power of two, >= 8
33    hasher: H,
34}
35
36#[allow(dead_code)]
37impl<
38    V: Copy,
39    H: ahash::StrHash + Default,
40    KC: TryInto<usize> + From<u8> + From<u16> + TryFrom<u32> + TryFrom<usize> + Copy + PartialEq,
41    KO: TryFrom<usize> + Copy + TryInto<usize>,
42    const DESTRUCTIVE: bool,
43    const CACHING_HASHES: bool,
44> Dict<V, H, KC, KO, DESTRUCTIVE, CACHING_HASHES>
45{
46    pub fn new(capacity: usize) -> Self {
47
48        let capacity = capacity.max(8).next_power_of_two();
49
50        let slot_to_index = vec![
51            KC::try_from(0usize)
52                .ok()
53                .expect("1 usize -> KeyEndType conversion failed");
54            capacity
55        ];
56        let key_hashes = if CACHING_HASHES {
57            Some(vec![
58                KC::try_from(0usize)
59                    .ok()
60                    .expect("3 usize -> KeyEndType conversion failed");
61                capacity
62            ])
63        } else {
64            None
65        };
66
67        let deleted_mask = if DESTRUCTIVE {
68            // one bit per key index; we size by capacity/8 like your code (mask for keys)
69            Some(vec![0u8; capacity >> 3])
70        } else {
71            None
72        };
73
74        Self {
75            keys: KeysContainer::<KO>::new(capacity),
76            key_hashes,
77            values: Vec::with_capacity(capacity),
78            slot_to_index,
79            deleted_mask,
80            count: 0,
81            capacity,
82            hasher: H::default(),
83        }
84    }
85
86    #[inline]
87    pub fn len(&self) -> usize {
88        self.count
89    }
90
91    #[inline]
92    pub fn contains(&self, key: &str) -> bool {
93        self.find_key_index(key) != 0
94    }
95
96    #[inline]
97    fn load_slot(&self, slot: usize) -> usize {
98        self.slot_to_index[slot]
99            .try_into()
100            .ok()
101            .expect("KeyEndType -> usize conversion failed")
102    }
103
104    #[inline]
105    fn store_slot(&mut self, slot: usize, val: usize) {
106        self.slot_to_index[slot] = KC::try_from(val)
107            .ok()
108            .expect("4 usize -> KeyEndType conversion failed");
109    }
110
111    #[inline(always)]
112    fn probe_simd_32(&self, key_hash_truncated: usize, slot: usize, _modulo_mask: usize) -> Option<usize> {
113        if !CACHING_HASHES || std::mem::size_of::<KC>() != 4 {
114            return None; // Only support u32 hashes for now
115        }
116        
117        // We want to load 16 elements (64 bytes) of u32 from the key_hashes array from slot
118        // and compare them with the truncated hash.
119        let hashes = self.key_hashes.as_ref().unwrap();
120        let target: Simd<u32, 16> = Simd::splat(key_hash_truncated as u32);
121        
122        let current_slot = slot;
123        
124        // Ensure we don't read past the end of the array, modulo takes care of wrapping.
125        // We'll just do a single lane sweep.
126        if current_slot + 16 <= self.capacity {
127            let chunk = &hashes[current_slot..current_slot+16];
128            // Read through unsafe pointer because KC might vary but we verified it's 4 bytes
129            let ptr: *const u32 = chunk.as_ptr() as *const u32;
130            let loaded: Simd<u32, 16> = unsafe { Simd::from_slice(std::slice::from_raw_parts(ptr, 16)) };
131            
132            let cmp = loaded.simd_eq(target);
133            if cmp.any() {
134                // Find first match
135                return Some(current_slot + cmp.first_set().unwrap());
136            }
137        }
138        None
139    }
140    
141    #[inline]
142    fn is_deleted(&self, index: usize) -> bool {
143        if !DESTRUCTIVE {
144            return false;
145        }
146        let dm = self.deleted_mask.as_ref().unwrap();
147        let byte = index >> 3;
148        let bit = index & 7;
149        (dm[byte] & (1 << bit)) != 0
150    }
151
152    #[inline]
153    fn set_deleted(&mut self, index: usize) {
154        if !DESTRUCTIVE {
155            return;
156        }
157        let dm = self.deleted_mask.as_mut().unwrap();
158        let byte = index >> 3;
159        let bit = index & 7;
160        dm[byte] |= 1 << bit;
161    }
162
163    #[inline]
164    fn clear_deleted(&mut self, index: usize) {
165        if !DESTRUCTIVE {
166            return;
167        }
168        let dm = self.deleted_mask.as_mut().unwrap();
169        let byte = index >> 3;
170        let bit = index & 7;
171        dm[byte] &= !(1 << bit);
172    }
173
174    fn maybe_rehash(&mut self) {
175        // Mojo used: if self.count / self.capacity >= 0.87 -> rehash
176        // We'll emulate with >= 87% load factor.
177        if self.count * 100 >= self.capacity * 87 {
178            self.rehash();
179        }
180    }
181
182    fn rehash(&mut self) {
183        let old_cap = self.capacity;
184        let old_slots = std::mem::take(&mut self.slot_to_index);
185        let old_hashes = if CACHING_HASHES {
186            std::mem::take(&mut self.key_hashes)
187        } else {
188            None
189        };
190
191        self.capacity <<= 1;
192        self.slot_to_index = vec![
193            KC::try_from(0usize)
194                .ok()
195                .expect("5 usize -> KeyEndType conversion failed");
196            self.capacity
197        ];
198
199        if DESTRUCTIVE {
200            let mut new_mask = vec![0u8; self.capacity >> 3];
201            if let Some(old_mask) = self.deleted_mask.as_ref() {
202                let to_copy = old_mask.len().min(new_mask.len());
203                new_mask[..to_copy].copy_from_slice(&old_mask[..to_copy]);
204            }
205            self.deleted_mask = Some(new_mask);
206        }
207
208        let modulo_mask = self.capacity - 1;
209        let mut new_hashes = if CACHING_HASHES {
210            Some(vec![
211                KC::try_from(0usize)
212                    .ok()
213                    .expect("6 usize -> KeyEndType conversion failed");
214                self.capacity
215            ])
216        } else {
217            None
218        };
219
220        for i in 0..old_cap {
221            let key_index: usize = old_slots[i]
222                .try_into()
223                .ok()
224                .expect("KeyEndType -> usize conversion failed");
225            if key_index == 0 {
226                continue;
227            }
228
229            let idx0 = key_index - 1;
230            let k = self.keys.get(idx0).unwrap();
231
232            // pull cached (already truncated) or recompute and truncate
233            let key_hash_truncated: usize = if CACHING_HASHES {
234                old_hashes.as_ref().unwrap()[i]
235                    .try_into()
236                    .ok()
237                    .expect("KC -> usize conversion failed")
238            } else {
239                let h = self.hasher.hash(k);
240                (h as usize) & ((1 << (std::mem::size_of::<KC>() * 8)) - 1)
241            };
242
243            let mut slot = key_hash_truncated & modulo_mask;
244            loop {
245                if self.load_slot(slot) == 0 {
246                    self.store_slot(slot, key_index);
247                    if CACHING_HASHES {
248                        new_hashes.as_mut().unwrap()[slot] = KC::try_from(key_hash_truncated)
249                            .ok()
250                            .expect("7 usize -> KeyEndType conversion failed");
251                    }
252                    break;
253                }
254                slot = (slot + 1) & modulo_mask;
255            }
256        }
257
258        if CACHING_HASHES {
259            self.key_hashes = new_hashes;
260        }
261        let _ = old_hashes;
262    }
263
264    pub fn put(&mut self, key: &str, value: V) {
265
266        self.maybe_rehash();
267
268        let key_hash_u64 = self.hasher.hash(key);
269        let key_hash_truncated =
270            (key_hash_u64 as usize) & ((1 << (std::mem::size_of::<KC>() * 8)) - 1);
271
272        let modulo_mask = self.capacity - 1;
273        let mut slot = key_hash_truncated & modulo_mask;
274
275        loop {
276            let key_index = self.load_slot(slot);
277            if key_index == 0 {
278                // insert fresh
279                self.keys.add(key);
280                if CACHING_HASHES {
281                    self.key_hashes.as_mut().unwrap()[slot] = KC::try_from(key_hash_truncated)
282                        .ok()
283                        .expect("8 usize -> KeyEndType conversion failed");
284                }
285                self.values.push(value);
286                self.store_slot(slot, self.keys.len()); // 1-based
287                self.count += 1;
288                return;
289            }
290
291            // collision path
292            if CACHING_HASHES {
293                let other_hash: usize = self.key_hashes.as_ref().unwrap()[slot]
294                    .try_into()
295                    .ok()
296                    .expect("KC -> usize conversion failed");
297                if other_hash == key_hash_truncated {
298                    let other_key = self.keys.get(key_index - 1).unwrap();
299                    if other_key == key {
300                        // replace value
301                        let idx0 = key_index - 1;
302                        self.values[idx0] = value;
303                        if DESTRUCTIVE && self.is_deleted(idx0) {
304                            self.count += 1;
305                            self.clear_deleted(idx0);
306                        }
307                        return;
308                    }
309                }
310            } else {
311                let other_key = self.keys.get(key_index - 1).unwrap();
312                if other_key == key {
313                    let idx0 = key_index - 1;
314                    self.values[idx0] = value;
315                    if DESTRUCTIVE && self.is_deleted(idx0) {
316                        self.count += 1;
317                        self.clear_deleted(idx0);
318                    }
319                    return;
320                }
321            }
322
323            slot = (slot + 1) & modulo_mask;
324        }
325    }
326
327    pub fn get_or(&self, key: &str, default: V) -> V {
328        let key_index = self.find_key_index(key);
329        if key_index == 0 {
330            return default;
331        }
332        if DESTRUCTIVE {
333            if self.is_deleted(key_index - 1) {
334                return default;
335            }
336        }
337        self.values[key_index - 1]
338    }
339
340    // pub fn calc(&mut self, key: &str, f: impl Fn(V) -> V) {
341    //     let key_index = self.find_key_index(key);
342    //     if key_index != 0 {
343    //         let idx0 = key_index - 1;
344    //         self.values[idx0] = f(self.values[idx0]);
345    //     }
346    // }
347
348    // pub fn delete(&mut self, key: &str) {
349    //     if !DESTRUCTIVE {
350    //         return;
351    //     }
352    //     let key_index = self.find_key_index(key);
353    //     if key_index == 0 {
354    //         return;
355    //     }
356    //     let idx0 = key_index - 1;
357    //     if !self.is_deleted(idx0) {
358    //         self.count -= 1;
359    //     }
360    //     self.set_deleted(idx0);
361    // }
362
363    // pub fn upsert(&mut self, key: &str, update: impl Fn(Option<V>) -> V) {
364    //     let mut key_index = self.find_key_index(key);
365    //     if key_index == 0 {
366    //         let v = update(None);
367    //         self.put(key, v);
368    //         return;
369    //     }
370    //     key_index -= 1;
371    //
372    //     if DESTRUCTIVE && self.is_deleted(key_index) {
373    //         self.values[key_index] = update(None);
374    //         return;
375    //     }
376    //     self.values[key_index] = update(Some(self.values[key_index]));
377    // }
378
379    pub fn clear(&mut self) {
380        self.values.clear();
381        self.keys.clear();
382        for x in &mut self.slot_to_index {
383            *x = KC::try_from(0usize)
384                .ok()
385                .expect("9 usize -> KeyEndType conversion failed");
386        }
387        if DESTRUCTIVE {
388            for b in self.deleted_mask.as_mut().unwrap().iter_mut() {
389                *b = 0;
390            }
391        }
392        self.count = 0;
393    }
394
395    #[inline]
396    fn find_key_index(&self, key: &str) -> usize {
397        let key_hash_u64 = self.hasher.hash(key);
398        let key_hash_truncated = (key_hash_u64 as usize) & ((1 << (std::mem::size_of::<KC>() * 8)) - 1);
399        let modulo_mask = self.capacity - 1;
400        let mut slot = key_hash_truncated & modulo_mask;
401
402        // Fast SIMD path for cached u32 hashes
403        if CACHING_HASHES && std::mem::size_of::<KC>() == 4 {
404            let target: Simd<u32, 16> = Simd::splat(key_hash_truncated as u32);
405            let empty_target: Simd<u32, 16> = Simd::splat(0);
406            let hashes = self.key_hashes.as_ref().unwrap();
407            
408            let mut current_slot = slot;
409            loop {
410                // Read from cache to check for 0 quickly
411                let key_index = self.load_slot(current_slot);
412                if key_index == 0 {
413                    return 0; // Empty slot found, key doesn't exist
414                }
415
416                if current_slot + 16 <= self.capacity {
417                    let chunk = &hashes[current_slot..current_slot+16];
418                    let ptr: *const u32 = chunk.as_ptr() as *const u32;
419                    let loaded: Simd<u32, 16> = unsafe { Simd::from_slice(std::slice::from_raw_parts(ptr, 16)) };
420                    
421                    let cmp = loaded.simd_eq(target);
422                    let empty_cmp = loaded.simd_eq(empty_target);
423                    
424                    let match_mask = cmp.to_bitmask();
425                    let empty_mask = empty_cmp.to_bitmask();
426
427                    // If we have matches, we must check them all in order
428                    if match_mask != 0 {
429                        let mut remaining_matches = match_mask;
430                        while remaining_matches != 0 {
431                            let match_idx = remaining_matches.trailing_zeros();
432                            
433                            // If we see an empty slot BEFORE this match, the chain is broken
434                            if empty_mask != 0 && empty_mask.trailing_zeros() < match_idx {
435                                // Double check if it's truly an empty slot (index == 0)
436                                let empty_idx = empty_mask.trailing_zeros();
437                                if self.load_slot(current_slot + empty_idx as usize) == 0 {
438                                    return 0;
439                                }
440                            }
441
442                            let match_slot = current_slot + match_idx as usize;
443                            let candidate_index = self.load_slot(match_slot);
444                            if candidate_index != 0 {
445                                let other_key = self.keys.get(candidate_index - 1).unwrap();
446                                if other_key == key {
447                                    return candidate_index;
448                                }
449                            }
450                            
451                            // Clear this bit and continue checking other matches
452                            remaining_matches &= !(1 << match_idx);
453                        }
454                    }
455
456                    // No matches in this 16-lane, or false positives. Can we skip ahead?
457                    if empty_mask != 0 {
458                        // Yes! There is an empty slot in this block.
459                        // We must process up to that empty slot to ensure correctness.
460                        let empty_idx = empty_mask.trailing_zeros();
461                        if self.load_slot(current_slot + empty_idx as usize) == 0 {
462                            // We hit a true empty slot, end of chain
463                            return 0;
464                        }
465                        // It was an empty hash (0) but index != 0 (maybe deleted but rehashing?)
466                        // We just fall through to the scalar advance
467                    } else {
468                        // Block is completely full with no spaces, we can advance by 16 safely!
469                        current_slot = (current_slot + 16) & modulo_mask;
470                        continue;
471                    }
472                }
473
474                // Standard fallback linear probe check for edges or when SIMD cannot jump
475                let other_hash: usize = hashes[current_slot]
476                    .try_into()
477                    .ok()
478                    .expect("KC -> usize conversion failed");
479                    
480                if other_hash == key_hash_truncated {
481                    let other_key = self.keys.get(key_index - 1).unwrap();
482                    if other_key == key {
483                        return key_index;
484                    }
485                }
486
487                current_slot = (current_slot + 1) & modulo_mask;
488            }
489        }
490
491        // --- Original scalar fallback path ---
492        loop {
493            let key_index = self.load_slot(slot);
494            if key_index == 0 {
495                return 0;
496            }
497
498            if CACHING_HASHES {
499                let other_hash: usize = self.key_hashes.as_ref().unwrap()[slot]
500                    .try_into()
501                    .ok()
502                    .expect("KC -> usize conversion failed");
503                if other_hash == key_hash_truncated {
504                    let other_key = self.keys.get(key_index - 1).unwrap();
505                    if other_key == key {
506                        return key_index;
507                    }
508                }
509            } else {
510                let other_key = self.keys.get(key_index - 1).unwrap();
511                if other_key == key {
512                    return key_index;
513                }
514            }
515
516            slot = (slot + 1) & modulo_mask;
517        }
518    }
519
520    /// Debug print similar to your `debug()` method.
521    pub fn debug(&self) {
522        println!("Dict count: {} and capacity: {}", self.count, self.capacity);
523        println!("KeyMap:");
524        for i in 0..self.capacity {
525            print!(
526                "{}{}",
527                self.slot_to_index[i]
528                    .try_into()
529                    .ok()
530                    .expect("KC -> usize conversion failed"),
531                if i + 1 < self.capacity { ", " } else { "\n" }
532            );
533        }
534        println!("Keys:");
535        print!("({})[", self.keys.len());
536        for i in 0..self.keys.len() {
537            if i > 0 {
538                print!(", ");
539            }
540            print!("{}", self.keys.get(i).unwrap());
541        }
542        println!("]");
543        if CACHING_HASHES {
544            println!("KeyHashes:");
545            for i in 0..self.capacity {
546                let v = if self.load_slot(i) > 0 {
547                    (self.key_hashes.as_ref().unwrap()[i]
548                        .try_into()
549                        .ok()
550                        .expect("KC -> usize conversion failed")) as usize
551                } else {
552                    0
553                };
554                print!("{}{}", v, if i + 1 < self.capacity { ", " } else { "\n" });
555            }
556        }
557    }
558}
559
/// Read-only lookups on the rkyv-archived form of [`Dict`].
#[cfg(feature = "rkyv")]
impl<
    V: Copy + rkyv::Archive,
    H: ahash::StrHash + rkyv::Archive + Default,
    KC: TryInto<usize> + From<u8> + From<u16> + TryFrom<u32> + TryFrom<usize> + Copy + PartialEq + rkyv::Archive,
    KO: TryFrom<usize> + Copy + TryInto<usize> + rkyv::Archive,
    const DESTRUCTIVE: bool,
    const CACHING_HASHES: bool,
> ArchivedDict<V, H, KC, KO, DESTRUCTIVE, CACHING_HASHES>
where
    V::Archived: Copy + Into<V>,
    KC::Archived: Copy,
    usize: TryFrom<KC::Archived>,
    KO::Archived: Copy,
    usize: TryFrom<KO::Archived>,
{
    /// Returns the value stored for `key`, or `default` if the key is absent or deleted.
    pub fn get_or(&self, key: &str, default: V) -> V {
        self.get(key).unwrap_or(default)
    }

    /// Returns the value stored for `key`, or `None` if the key is absent or deleted.
    pub fn get(&self, key: &str) -> Option<V> {
        let key_index = self.find_key_index(key);
        if key_index == 0 || self.is_deleted(key_index - 1) {
            return None;
        }
        Some(self.values[key_index - 1].into())
    }

    /// Returns `true` if `key` is present and not marked as deleted.
    pub fn contains(&self, key: &str) -> bool {
        let key_index = self.find_key_index(key);
        key_index != 0 && !self.is_deleted(key_index - 1)
    }

    /// Reads the 1-based key index stored in `slot`; an unconvertible entry reads as `0`
    /// (empty).
    #[inline]
    fn load_slot(&self, slot: usize) -> usize {
        usize::try_from(self.slot_to_index[slot]).unwrap_or(0)
    }

    /// Returns `true` if the key at 0-based `index` is flagged as deleted. Always `false`
    /// when `DESTRUCTIVE` is off or no mask was archived.
    #[inline]
    fn is_deleted(&self, index: usize) -> bool {
        if !DESTRUCTIVE {
            return false;
        }
        match self.deleted_mask.as_ref() {
            Some(mask) => (mask[index >> 3] & (1u8 << (index & 7))) != 0,
            None => false,
        }
    }

    /// Mask that truncates a full hash to the bit width of `KC`.
    ///
    /// Uses a checked shift so that a `KC` at least as wide as `usize` yields an all-ones
    /// mask instead of overflowing `1 << bits`.
    #[inline]
    fn hash_mask() -> usize {
        let bits = u32::try_from(core::mem::size_of::<KC>() * 8).unwrap_or(u32::MAX);
        1usize.checked_shl(bits).map_or(usize::MAX, |bit| bit - 1)
    }

    /// Looks up `key` by scalar linear probing and returns its 1-based key index, or `0`
    /// if it is not stored. The deletion mask is not consulted here.
    ///
    /// NOTE(review): the key is hashed with `H::default()`, not with the archived hasher
    /// field -- this presumes `H::default()` reproduces the hasher used when the
    /// dictionary was built; confirm for stateful or seeded hashers.
    #[inline]
    fn find_key_index(&self, key: &str) -> usize {
        let key_hash_u64 = H::default().hash(key);
        let key_hash_truncated = (key_hash_u64 as usize) & Self::hash_mask();

        // A zero or unconvertible capacity means there is no table to probe; bail out
        // instead of underflowing `capacity - 1`.
        let capacity: usize = (self.capacity).try_into().unwrap_or(0);
        let Some(modulo_mask) = capacity.checked_sub(1) else {
            return 0;
        };
        let mut slot = key_hash_truncated & modulo_mask;

        loop {
            let key_index = self.load_slot(slot);
            if key_index == 0 {
                return 0;
            }

            // With cached hashes, only compare key bytes when the cached hash matches.
            let hash_matches = if CACHING_HASHES {
                self.key_hashes.as_ref().is_some_and(|hashes| {
                    usize::try_from(hashes[slot]).unwrap_or(0) == key_hash_truncated
                })
            } else {
                true
            };

            if hash_matches {
                if let Some(other_key) = self.keys.get(key_index - 1) {
                    if other_key == key {
                        return key_index;
                    }
                }
            }

            slot = (slot + 1) & modulo_mask;
        }
    }
}