sokoban/
hash_table.rs

1use crate::node_allocator::{
2    FromSlice, NodeAllocator, NodeAllocatorMap, NodeField, ZeroCopy, SENTINEL,
3};
4use bytemuck::{Pod, Zeroable};
5use std::collections::hash_map::DefaultHasher;
6use std::hash::Hasher;
7use std::{
8    hash::Hash,
9    ops::{Index, IndexMut},
10};
11
12#[repr(C)]
13#[derive(Default, Copy, Clone)]
14pub struct HashNode<
15    K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
16    V: Default + Copy + Clone + Pod + Zeroable,
17> {
18    pub key: K,
19    pub value: V,
20}
21
22unsafe impl<
23        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
24        V: Default + Copy + Clone + Pod + Zeroable,
25    > Zeroable for HashNode<K, V>
26{
27}
28unsafe impl<
29        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
30        V: Default + Copy + Clone + Pod + Zeroable,
31    > Pod for HashNode<K, V>
32{
33}
34
35impl<
36        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
37        V: Default + Copy + Clone + Pod + Zeroable,
38    > HashNode<K, V>
39{
40    pub fn new(key: K, value: V) -> Self {
41        Self { key, value }
42    }
43}
44
45#[repr(C)]
46#[derive(Copy, Clone)]
47pub struct HashTable<
48    K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
49    V: Default + Copy + Clone + Pod + Zeroable,
50    const NUM_BUCKETS: usize,
51    const MAX_SIZE: usize,
52> {
53    pub buckets: [u32; NUM_BUCKETS],
54    pub allocator: NodeAllocator<HashNode<K, V>, MAX_SIZE, 4>,
55}
56
57unsafe impl<
58        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
59        V: Default + Copy + Clone + Pod + Zeroable,
60        const NUM_BUCKETS: usize,
61        const MAX_SIZE: usize,
62    > Zeroable for HashTable<K, V, NUM_BUCKETS, MAX_SIZE>
63{
64}
65unsafe impl<
66        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
67        V: Default + Copy + Clone + Pod + Zeroable,
68        const NUM_BUCKETS: usize,
69        const MAX_SIZE: usize,
70    > Pod for HashTable<K, V, NUM_BUCKETS, MAX_SIZE>
71{
72}
73
74impl<
75        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
76        V: Default + Copy + Clone + Pod + Zeroable,
77        const NUM_BUCKETS: usize,
78        const MAX_SIZE: usize,
79    > ZeroCopy for HashTable<K, V, NUM_BUCKETS, MAX_SIZE>
80{
81}
82
83impl<
84        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
85        V: Default + Copy + Clone + Pod + Zeroable,
86        const NUM_BUCKETS: usize,
87        const MAX_SIZE: usize,
88    > Default for HashTable<K, V, NUM_BUCKETS, MAX_SIZE>
89{
90    fn default() -> Self {
91        Self::assert_proper_alignment();
92        HashTable {
93            buckets: [SENTINEL; NUM_BUCKETS],
94            allocator: NodeAllocator::<HashNode<K, V>, MAX_SIZE, 4>::default(),
95        }
96    }
97}
98
99impl<
100        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
101        V: Default + Copy + Clone + Pod + Zeroable,
102        const NUM_BUCKETS: usize,
103        const MAX_SIZE: usize,
104    > NodeAllocatorMap<K, V> for HashTable<K, V, NUM_BUCKETS, MAX_SIZE>
105{
106    fn insert(&mut self, key: K, value: V) -> Option<u32> {
107        self._insert(key, value)
108    }
109
110    fn remove(&mut self, key: &K) -> Option<V> {
111        self._remove(key)
112    }
113
114    fn contains(&self, key: &K) -> bool {
115        self.get(key).is_some()
116    }
117
118    fn get(&self, key: &K) -> Option<&V> {
119        let mut hasher = DefaultHasher::new();
120        key.hash(&mut hasher);
121        let bucket_index = hasher.finish() as usize % NUM_BUCKETS;
122        let mut curr_node = self.buckets[bucket_index];
123        while curr_node != SENTINEL {
124            let node = self.get_node(curr_node);
125            if node.key == *key {
126                return Some(&node.value);
127            } else {
128                curr_node = self.get_next(curr_node);
129            }
130        }
131        None
132    }
133
134    fn get_mut(&mut self, key: &K) -> Option<&mut V> {
135        let mut hasher = DefaultHasher::new();
136        key.hash(&mut hasher);
137        let bucket_index = hasher.finish() as usize % NUM_BUCKETS;
138        let head = self.buckets[bucket_index];
139        let mut curr_node = head;
140        while curr_node != SENTINEL {
141            let node = self.get_node(curr_node);
142            if node.key == *key {
143                // If get_mut is called, we move the matched node to the front of the queue
144                let prev = self.get_prev(curr_node);
145                let next = self.get_next(curr_node);
146                if curr_node != head {
147                    self.allocator
148                        .clear_register(curr_node, NodeField::Left as u32);
149                    self.allocator.connect(
150                        prev,
151                        next,
152                        NodeField::Right as u32,
153                        NodeField::Left as u32,
154                    );
155                    self.allocator.connect(
156                        curr_node,
157                        head,
158                        NodeField::Right as u32,
159                        NodeField::Left as u32,
160                    );
161                }
162                self.buckets[bucket_index] = curr_node;
163                return Some(&mut self.get_node_mut(curr_node).value);
164            } else {
165                curr_node = self.get_next(curr_node);
166            }
167        }
168        None
169    }
170
171    fn size(&self) -> usize {
172        self.allocator.size as usize
173    }
174
175    fn len(&self) -> usize {
176        self.allocator.size as usize
177    }
178
179    fn capacity(&self) -> usize {
180        MAX_SIZE
181    }
182
183    fn iter(&self) -> Box<dyn DoubleEndedIterator<Item = (&K, &V)> + '_> {
184        Box::new(self._iter())
185    }
186
187    fn iter_mut(&mut self) -> Box<dyn DoubleEndedIterator<Item = (&K, &mut V)> + '_> {
188        Box::new(self._iter_mut())
189    }
190}
191
192impl<
193        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
194        V: Default + Copy + Clone + Pod + Zeroable,
195        const NUM_BUCKETS: usize,
196        const MAX_SIZE: usize,
197    > FromSlice for HashTable<K, V, NUM_BUCKETS, MAX_SIZE>
198{
199    fn new_from_slice(slice: &mut [u8]) -> &mut Self {
200        Self::assert_proper_alignment();
201        let tab = Self::load_mut_bytes(slice).unwrap();
202        tab.initialize();
203        tab
204    }
205}
206
207impl<
208        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
209        V: Default + Copy + Clone + Pod + Zeroable,
210        const NUM_BUCKETS: usize,
211        const MAX_SIZE: usize,
212    > HashTable<K, V, NUM_BUCKETS, MAX_SIZE>
213{
214    fn assert_proper_alignment() {
215        assert!(NUM_BUCKETS % 2 == 0);
216    }
217
218    pub fn initialize(&mut self) {
219        self.allocator.initialize();
220    }
221
222    pub fn new() -> Self {
223        Self::default()
224    }
225
226    pub fn get_next(&self, index: u32) -> u32 {
227        self.allocator.get_register(index, NodeField::Right as u32)
228    }
229
230    pub fn get_prev(&self, index: u32) -> u32 {
231        self.allocator.get_register(index, NodeField::Left as u32)
232    }
233
234    pub fn get_node(&self, index: u32) -> &HashNode<K, V> {
235        self.allocator.get(index).get_value()
236    }
237
238    pub fn get_node_mut(&mut self, index: u32) -> &mut HashNode<K, V> {
239        self.allocator.get_mut(index).get_value_mut()
240    }
241
242    fn _insert(&mut self, key: K, value: V) -> Option<u32> {
243        let mut hasher = DefaultHasher::new();
244        key.hash(&mut hasher);
245        let bucket_index = hasher.finish() as usize % NUM_BUCKETS;
246        let head = self.buckets[bucket_index];
247        let mut curr_node = head;
248        while curr_node != SENTINEL {
249            let node = self.get_node(curr_node);
250            if node.key == key {
251                self.get_node_mut(curr_node).value = value;
252                return Some(curr_node);
253            } else {
254                curr_node = self.get_next(curr_node);
255            }
256        }
257        if self.len() >= self.capacity() {
258            return None;
259        }
260        let node_index = self.allocator.add_node(HashNode::new(key, value));
261        self.buckets[bucket_index] = node_index;
262        if head != SENTINEL {
263            self.allocator.connect(
264                node_index,
265                head,
266                NodeField::Right as u32,
267                NodeField::Left as u32,
268            );
269        }
270        Some(node_index)
271    }
272
273    pub fn _remove(&mut self, key: &K) -> Option<V> {
274        let mut hasher = DefaultHasher::new();
275        key.hash(&mut hasher);
276        let bucket_index = hasher.finish() as usize % NUM_BUCKETS;
277        let head = self.buckets[bucket_index];
278        let mut curr_node = self.buckets[bucket_index];
279        while curr_node != SENTINEL {
280            let node = self.get_node(curr_node);
281            if node.key == *key {
282                let val = node.value;
283                let prev = self.get_prev(curr_node);
284                let next = self.get_next(curr_node);
285                self.allocator
286                    .clear_register(curr_node, NodeField::Left as u32);
287                self.allocator
288                    .clear_register(curr_node, NodeField::Right as u32);
289                self.allocator.remove_node(curr_node);
290                if head == curr_node {
291                    assert!(prev == SENTINEL);
292                    self.buckets[bucket_index] = next;
293                }
294                self.allocator
295                    .connect(prev, next, NodeField::Right as u32, NodeField::Left as u32);
296                return Some(val);
297            } else {
298                curr_node = self.get_next(curr_node);
299            }
300        }
301        None
302    }
303
304    pub fn contains(&self, key: &K) -> bool {
305        let mut hasher = DefaultHasher::new();
306        key.hash(&mut hasher);
307        let bucket_index = hasher.finish() as usize % NUM_BUCKETS;
308        let mut curr_node = self.buckets[bucket_index];
309        while curr_node != SENTINEL {
310            let node = self.get_node(curr_node);
311            if node.key == *key {
312                return true;
313            } else {
314                curr_node = self.get_next(curr_node);
315            }
316        }
317        false
318    }
319
320    pub fn get_addr(&self, key: &K) -> u32 {
321        let mut hasher = DefaultHasher::new();
322        key.hash(&mut hasher);
323        let bucket_index = hasher.finish() as usize % NUM_BUCKETS;
324        let mut curr_node = self.buckets[bucket_index];
325        while curr_node != SENTINEL {
326            let node = self.get_node(curr_node);
327            if node.key == *key {
328                return curr_node;
329            } else {
330                curr_node = self.get_next(curr_node);
331            }
332        }
333        SENTINEL
334    }
335
336    fn _iter(&self) -> HashTableIterator<'_, K, V, NUM_BUCKETS, MAX_SIZE> {
337        HashTableIterator::<K, V, NUM_BUCKETS, MAX_SIZE> {
338            ht: self,
339            bucket: 0,
340            node: self.buckets[0],
341        }
342    }
343
344    fn _iter_mut(&mut self) -> HashTableIteratorMut<'_, K, V, NUM_BUCKETS, MAX_SIZE> {
345        let node = self.buckets[0];
346        HashTableIteratorMut::<K, V, NUM_BUCKETS, MAX_SIZE> {
347            ht: self,
348            bucket: 0,
349            node,
350        }
351    }
352}
353
354impl<
355        'a,
356        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
357        V: Default + Copy + Clone + Pod + Zeroable,
358        const NUM_BUCKETS: usize,
359        const MAX_SIZE: usize,
360    > IntoIterator for &'a HashTable<K, V, NUM_BUCKETS, MAX_SIZE>
361{
362    type Item = (&'a K, &'a V);
363    type IntoIter = HashTableIterator<'a, K, V, NUM_BUCKETS, MAX_SIZE>;
364
365    fn into_iter(self) -> Self::IntoIter {
366        self._iter()
367    }
368}
369
370impl<
371        'a,
372        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
373        V: Default + Copy + Clone + Pod + Zeroable,
374        const NUM_BUCKETS: usize,
375        const MAX_SIZE: usize,
376    > IntoIterator for &'a mut HashTable<K, V, NUM_BUCKETS, MAX_SIZE>
377{
378    type Item = (&'a K, &'a mut V);
379    type IntoIter = HashTableIteratorMut<'a, K, V, NUM_BUCKETS, MAX_SIZE>;
380
381    fn into_iter(self) -> Self::IntoIter {
382        self._iter_mut()
383    }
384}
385
386pub struct HashTableIterator<
387    'a,
388    K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
389    V: Default + Copy + Clone + Pod + Zeroable,
390    const NUM_BUCKETS: usize,
391    const MAX_SIZE: usize,
392> {
393    ht: &'a HashTable<K, V, NUM_BUCKETS, MAX_SIZE>,
394    bucket: usize,
395    node: u32,
396}
397
398impl<
399        'a,
400        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
401        V: Default + Copy + Clone + Pod + Zeroable,
402        const NUM_BUCKETS: usize,
403        const MAX_SIZE: usize,
404    > Iterator for HashTableIterator<'a, K, V, NUM_BUCKETS, MAX_SIZE>
405{
406    type Item = (&'a K, &'a V);
407
408    fn next(&mut self) -> Option<Self::Item> {
409        if self.bucket < NUM_BUCKETS {
410            while self.node == SENTINEL {
411                self.bucket += 1;
412                if self.bucket == NUM_BUCKETS {
413                    return None;
414                }
415                let head = self.ht.buckets[self.bucket];
416                self.node = head;
417            }
418            let node = self.ht.get_node(self.node);
419            self.node = self.ht.get_next(self.node);
420            Some((&node.key, &node.value))
421        } else {
422            None
423        }
424    }
425}
426
427impl<
428        'a,
429        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
430        V: Default + Copy + Clone + Pod + Zeroable,
431        const NUM_BUCKETS: usize,
432        const MAX_SIZE: usize,
433    > DoubleEndedIterator for HashTableIterator<'a, K, V, NUM_BUCKETS, MAX_SIZE>
434{
435    fn next_back(&mut self) -> Option<Self::Item> {
436        None
437    }
438}
439
440pub struct HashTableIteratorMut<
441    'a,
442    K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
443    V: Default + Copy + Clone + Pod + Zeroable,
444    const NUM_BUCKETS: usize,
445    const MAX_SIZE: usize,
446> {
447    ht: &'a mut HashTable<K, V, NUM_BUCKETS, MAX_SIZE>,
448    bucket: usize,
449    node: u32,
450}
451
452impl<
453        'a,
454        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
455        V: Default + Copy + Clone + Pod + Zeroable,
456        const NUM_BUCKETS: usize,
457        const MAX_SIZE: usize,
458    > Iterator for HashTableIteratorMut<'a, K, V, NUM_BUCKETS, MAX_SIZE>
459{
460    type Item = (&'a K, &'a mut V);
461
462    fn next(&mut self) -> Option<Self::Item> {
463        if self.bucket < NUM_BUCKETS {
464            while self.node == SENTINEL {
465                self.bucket += 1;
466                if self.bucket == NUM_BUCKETS {
467                    return None;
468                }
469                let head = self.ht.buckets[self.bucket];
470                self.node = head;
471            }
472            let ptr = self.node;
473            self.node = self.ht.get_next(self.node);
474            // TODO: How does one remove this unsafe?
475            unsafe {
476                let node =
477                    (*self.ht.allocator.nodes.as_mut_ptr().add((ptr - 1) as usize)).get_value_mut();
478                Some((&node.key, &mut node.value))
479            }
480        } else {
481            None
482        }
483    }
484}
485
486impl<
487        'a,
488        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
489        V: Default + Copy + Clone + Pod + Zeroable,
490        const NUM_BUCKETS: usize,
491        const MAX_SIZE: usize,
492    > DoubleEndedIterator for HashTableIteratorMut<'a, K, V, NUM_BUCKETS, MAX_SIZE>
493{
494    fn next_back(&mut self) -> Option<Self::Item> {
495        None
496    }
497}
498
499impl<
500        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
501        V: Default + Copy + Clone + Pod + Zeroable,
502        const NUM_BUCKETS: usize,
503        const MAX_SIZE: usize,
504    > Index<&K> for HashTable<K, V, NUM_BUCKETS, MAX_SIZE>
505{
506    type Output = V;
507
508    fn index(&self, index: &K) -> &Self::Output {
509        self.get(index).unwrap()
510    }
511}
512
513impl<
514        K: Hash + PartialEq + Copy + Clone + Default + Pod + Zeroable,
515        V: Default + Copy + Clone + Pod + Zeroable,
516        const NUM_BUCKETS: usize,
517        const MAX_SIZE: usize,
518    > IndexMut<&K> for HashTable<K, V, NUM_BUCKETS, MAX_SIZE>
519{
520    fn index_mut(&mut self, index: &K) -> &mut Self::Output {
521        self.get_mut(index).unwrap()
522    }
523}