Skip to main content

inc_complete/storage/
hashmapped.rs

1use dashmap::DashMap;
2
3use crate::{Cell, storage::StorageFor};
4use std::hash::{BuildHasher, Hash};
5
6use super::Computation;
7
8pub struct HashMapStorage<K, Hasher = rustc_hash::FxBuildHasher>
9where
10    K: Computation + Eq + Hash,
11    Hasher: BuildHasher,
12{
13    key_to_cell: DashMap<K, Cell, Hasher>,
14    cell_to_key: DashMap<Cell, (K, Option<K::Output>), Hasher>,
15}
16
17impl<K, H> Default for HashMapStorage<K, H>
18where
19    K: Computation + Eq + Hash,
20    H: Default + BuildHasher + Clone,
21{
22    fn default() -> Self {
23        Self {
24            key_to_cell: Default::default(),
25            cell_to_key: Default::default(),
26        }
27    }
28}
29
30impl<K, H> StorageFor<K> for HashMapStorage<K, H>
31where
32    K: Clone + Eq + Hash + Computation,
33    K::Output: Eq + Clone,
34    H: BuildHasher + Clone,
35{
36    fn get_cell_for_computation(&self, key: &K) -> Option<Cell> {
37        self.key_to_cell.get(key).map(|value| *value)
38    }
39
40    fn insert_new_cell(&self, cell: Cell, key: K) {
41        self.key_to_cell.insert(key.clone(), cell);
42        self.cell_to_key.insert(cell, (key, None));
43    }
44
45    fn try_get_input(&self, cell: Cell) -> Option<K> {
46        let key_ref = self.cell_to_key.get(&cell)?;
47        Some(key_ref.0.clone())
48    }
49
50    fn get_input(&self, cell: Cell) -> K {
51        self.cell_to_key.get(&cell).unwrap().0.clone()
52    }
53
54    fn get_output(&self, cell: Cell) -> Option<K::Output> {
55        self.cell_to_key.get(&cell).unwrap().1.clone()
56    }
57
58    fn update_output(&self, cell: Cell, new_value: K::Output) -> bool {
59        let mut previous_output = self.cell_to_key.get_mut(&cell).unwrap();
60        let changed = K::ASSUME_CHANGED
61            || previous_output
62                .1
63                .as_ref()
64                .is_none_or(|value| *value != new_value);
65        previous_output.1 = Some(new_value);
66        changed
67    }
68
69    fn gc(&mut self, used_cells: &std::collections::HashSet<Cell>) {
70        // Remove cells that are not in the used set
71        self.cell_to_key.retain(|cell, _| used_cells.contains(cell));
72        self.key_to_cell.retain(|_, cell| used_cells.contains(cell));
73    }
74}
75
76impl<K, H> serde::Serialize for HashMapStorage<K, H>
77where
78    K: serde::Serialize + Computation + Eq + Hash + Clone,
79    K::Output: serde::Serialize + Clone,
80    H: BuildHasher + Clone,
81{
82    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
83    where
84        S: serde::Serializer,
85    {
86        let mut cell_to_key_vec: Vec<(Cell, (K, Option<K::Output>))> =
87            Vec::with_capacity(self.cell_to_key.len());
88
89        for kv in self.cell_to_key.iter() {
90            let cell = *kv.key();
91            let (key, value) = kv.value().clone();
92            cell_to_key_vec.push((cell, (key, value)));
93        }
94
95        cell_to_key_vec.serialize(serializer)
96    }
97}
98
99impl<'de, K, H> serde::Deserialize<'de> for HashMapStorage<K, H>
100where
101    K: serde::Deserialize<'de> + Hash + Eq + Computation + Clone,
102    K::Output: serde::Deserialize<'de>,
103    H: Default + BuildHasher + Clone,
104{
105    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
106    where
107        D: serde::Deserializer<'de>,
108    {
109        let cell_to_key_vec: Vec<(Cell, (K, Option<K::Output>))> =
110            serde::Deserialize::deserialize(deserializer)?;
111
112        let key_to_cell = DashMap::default();
113        let cell_to_key = DashMap::default();
114
115        for (cell, (key, value)) in cell_to_key_vec {
116            key_to_cell.insert(key.clone(), cell);
117            cell_to_key.insert(cell, (key, value));
118        }
119
120        Ok(HashMapStorage {
121            cell_to_key,
122            key_to_cell,
123        })
124    }
125}