// tycho_util/transactional/hashmap.rs

1use std::hash::Hash;
2
3use crate::transactional::Transactional;
4use crate::{FastHashMap, FastHashSet};
5
6pub struct TransactionalHashMap<K, V> {
7    inner: FastHashMap<K, V>,
8    tx: Option<MapTx<K, V>>,
9}
10
11struct MapTx<K, V> {
12    added: FastHashSet<K>,
13    removed: FastHashMap<K, V>,
14}
15
16impl<K: Hash + Eq + Clone, V: Transactional> TransactionalHashMap<K, V> {
17    pub fn new() -> Self {
18        Self {
19            inner: FastHashMap::default(),
20            tx: None,
21        }
22    }
23
24    pub fn insert(&mut self, key: K, value: V) -> bool {
25        let old = self.inner.insert(key.clone(), value);
26        if let Some(tx) = &mut self.tx {
27            match old {
28                Some(old_value) => {
29                    if !tx.added.contains(&key) && !tx.removed.contains_key(&key) {
30                        tx.removed.insert(key, old_value);
31                    }
32                    true
33                }
34                None => {
35                    if !tx.removed.contains_key(&key) {
36                        tx.added.insert(key);
37                    }
38                    false
39                }
40            }
41        } else {
42            old.is_some()
43        }
44    }
45
46    pub fn remove(&mut self, key: &K) -> bool {
47        let Some(value) = self.inner.remove(key) else {
48            return false;
49        };
50        if let Some(tx) = &mut self.tx {
51            if tx.added.remove(key) {
52                return true;
53            }
54            if !tx.removed.contains_key(key) {
55                tx.removed.insert(key.clone(), value);
56            }
57        }
58        true
59    }
60
61    pub fn get(&self, key: &K) -> Option<&V> {
62        self.inner.get(key)
63    }
64
65    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
66        self.inner.get_mut(key)
67    }
68
69    pub fn get_disjoint_mut<const N: usize>(&mut self, keys: [&K; N]) -> [Option<&mut V>; N] {
70        self.inner.get_disjoint_mut(keys)
71    }
72
73    pub fn contains_key(&self, key: &K) -> bool {
74        self.inner.contains_key(key)
75    }
76
77    pub fn len(&self) -> usize {
78        self.inner.len()
79    }
80
81    pub fn is_empty(&self) -> bool {
82        self.inner.is_empty()
83    }
84
85    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
86        self.inner.iter()
87    }
88
89    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&K, &mut V)> {
90        self.inner.iter_mut()
91    }
92
93    pub fn values(&self) -> impl Iterator<Item = &V> {
94        self.inner.values()
95    }
96
97    pub fn values_mut(&mut self) -> impl Iterator<Item = &mut V> {
98        self.inner.values_mut()
99    }
100
101    pub fn keys(&self) -> impl Iterator<Item = &K> {
102        self.inner.keys()
103    }
104}
105
106impl<K: Hash + Eq, V> Default for TransactionalHashMap<K, V> {
107    fn default() -> Self {
108        Self {
109            inner: FastHashMap::default(),
110            tx: None,
111        }
112    }
113}
114
115impl<K: Hash + Eq, V> From<FastHashMap<K, V>> for TransactionalHashMap<K, V> {
116    fn from(inner: FastHashMap<K, V>) -> Self {
117        Self { inner, tx: None }
118    }
119}
120
121impl<K: Hash + Eq + Clone, V: Transactional> Transactional for TransactionalHashMap<K, V> {
122    fn begin(&mut self) {
123        debug_assert!(self.tx.is_none());
124        for v in self.inner.values_mut() {
125            v.begin();
126        }
127        self.tx = Some(MapTx {
128            added: FastHashSet::default(),
129            removed: FastHashMap::default(),
130        });
131    }
132
133    fn commit(&mut self) {
134        self.tx = None;
135        for v in self.inner.values_mut() {
136            if v.in_tx() {
137                v.commit();
138            }
139        }
140    }
141
142    fn rollback(&mut self) {
143        if let Some(tx) = self.tx.take() {
144            for key in tx.added {
145                self.inner.remove(&key);
146            }
147            for (key, value) in tx.removed {
148                self.inner.insert(key, value);
149            }
150        }
151        for v in self.inner.values_mut() {
152            if v.in_tx() {
153                v.rollback();
154            }
155        }
156    }
157
158    fn in_tx(&self) -> bool {
159        self.tx.is_some()
160    }
161}