1use super::unionfind::UnionFind;
8use hashbrown::raw::{Bucket, RawTable};
9use std::hash::{Hash, Hasher};
10use std::marker::PhantomData;
11
12pub trait CtxEq<V1: ?Sized, V2: ?Sized> {
24 fn ctx_eq(&self, a: &V1, b: &V2, uf: &mut UnionFind) -> bool;
27}
28
29pub trait CtxHash<Value: ?Sized>: CtxEq<Value, Value> {
31 fn ctx_hash(&self, value: &Value, uf: &mut UnionFind) -> u64;
34}
35
/// A context that compares and hashes values using their own `Eq` and `Hash`
/// implementations, ignoring the union-find entirely.
///
/// It carries no state, so it is freely `Copy`; `Debug` is derived so it can
/// appear in assertion messages and in other types' `Debug` output.
#[derive(Clone, Copy, Debug, Default)]
pub struct NullCtx;
40
41impl<V: Eq + Hash> CtxEq<V, V> for NullCtx {
42 fn ctx_eq(&self, a: &V, b: &V, _: &mut UnionFind) -> bool {
43 a.eq(b)
44 }
45}
46impl<V: Eq + Hash> CtxHash<V> for NullCtx {
47 fn ctx_hash(&self, value: &V, _: &mut UnionFind) -> u64 {
48 let mut state = fxhash::FxHasher::default();
49 value.hash(&mut state);
50 state.finish()
51 }
52}
53
/// One element of a [`CtxHashMap`]: a key, its value, and the key's hash.
///
/// The hash is cached (truncated to 32 bits) so that probes can cheaply
/// reject non-matching slots before invoking the context's equality, and so
/// the table can recompute slot positions on growth without needing the
/// context or the union-find.
struct BucketData<K, V> {
    /// Truncated context hash of `k`.
    hash: u32,
    /// The stored key.
    k: K,
    /// The value associated with `k`.
    v: V,
}
67
68pub struct CtxHashMap<K, V> {
70 raw: RawTable<BucketData<K, V>>,
71}
72
73impl<K, V> CtxHashMap<K, V> {
74 pub fn new() -> Self {
76 Self {
77 raw: RawTable::new(),
78 }
79 }
80
81 pub fn with_capacity(capacity: usize) -> Self {
84 Self {
85 raw: RawTable::with_capacity(capacity),
86 }
87 }
88}
89
90impl<K, V> CtxHashMap<K, V> {
91 pub fn insert<Ctx: CtxEq<K, K> + CtxHash<K>>(
94 &mut self,
95 k: K,
96 v: V,
97 ctx: &Ctx,
98 uf: &mut UnionFind,
99 ) -> Option<V> {
100 let hash = ctx.ctx_hash(&k, uf) as u32;
101 match self.raw.find(hash as u64, |bucket| {
102 hash == bucket.hash && ctx.ctx_eq(&bucket.k, &k, uf)
103 }) {
104 Some(bucket) => {
105 let data = unsafe { bucket.as_mut() };
106 Some(std::mem::replace(&mut data.v, v))
107 }
108 None => {
109 let data = BucketData { hash, k, v };
110 self.raw
111 .insert_entry(hash as u64, data, |bucket| bucket.hash as u64);
112 None
113 }
114 }
115 }
116
117 pub fn get<'a, Q, Ctx: CtxEq<K, Q> + CtxHash<Q> + CtxHash<K>>(
119 &'a self,
120 k: &Q,
121 ctx: &Ctx,
122 uf: &mut UnionFind,
123 ) -> Option<&'a V> {
124 let hash = ctx.ctx_hash(k, uf) as u32;
125 self.raw
126 .find(hash as u64, |bucket| {
127 hash == bucket.hash && ctx.ctx_eq(&bucket.k, k, uf)
128 })
129 .map(|bucket| {
130 let data = unsafe { bucket.as_ref() };
131 &data.v
132 })
133 }
134
135 #[inline(always)]
138 pub fn entry<'a, Ctx: CtxEq<K, K> + CtxHash<K>>(
139 &'a mut self,
140 k: K,
141 ctx: &'a Ctx,
142 uf: &mut UnionFind,
143 ) -> Entry<'a, K, V> {
144 let hash = ctx.ctx_hash(&k, uf) as u32;
145 match self.raw.find(hash as u64, |bucket| {
146 hash == bucket.hash && ctx.ctx_eq(&bucket.k, &k, uf)
147 }) {
148 Some(bucket) => Entry::Occupied(OccupiedEntry {
149 bucket,
150 _phantom: PhantomData,
151 }),
152 None => Entry::Vacant(VacantEntry {
153 raw: &mut self.raw,
154 hash,
155 key: k,
156 }),
157 }
158 }
159}
160
161pub enum Entry<'a, K: 'a, V> {
163 Occupied(OccupiedEntry<'a, K, V>),
164 Vacant(VacantEntry<'a, K, V>),
165}
166
167pub struct OccupiedEntry<'a, K, V> {
169 bucket: Bucket<BucketData<K, V>>,
170 _phantom: PhantomData<&'a ()>,
171}
172
173impl<'a, K: 'a, V> OccupiedEntry<'a, K, V> {
174 pub fn get(&self) -> &'a V {
176 let bucket = unsafe { self.bucket.as_ref() };
177 &bucket.v
178 }
179}
180
181pub struct VacantEntry<'a, K, V> {
183 raw: &'a mut RawTable<BucketData<K, V>>,
184 hash: u32,
185 key: K,
186}
187
188impl<'a, K, V> VacantEntry<'a, K, V> {
189 pub fn insert(self, v: V) -> &'a V {
191 let bucket = self.raw.insert(
192 self.hash as u64,
193 BucketData {
194 hash: self.hash,
195 k: self.key,
196 v,
197 },
198 |bucket| bucket.hash as u64,
199 );
200 let data = unsafe { bucket.as_ref() };
201 &data.v
202 }
203}
204
#[cfg(test)]
mod test {
    use super::*;
    use std::hash::Hash;

    /// A key compared and hashed via the string it indexes in `Ctx::vals`,
    /// so distinct keys can be equal under the context.
    #[derive(Clone, Copy, Debug)]
    struct Key {
        index: u32,
    }

    /// Test context: resolves keys through a static string table.
    struct Ctx {
        vals: &'static [&'static str],
    }

    impl CtxEq<Key, Key> for Ctx {
        fn ctx_eq(&self, a: &Key, b: &Key, _: &mut UnionFind) -> bool {
            self.vals[a.index as usize].eq(self.vals[b.index as usize])
        }
    }

    impl CtxHash<Key> for Ctx {
        fn ctx_hash(&self, value: &Key, _: &mut UnionFind) -> u64 {
            let mut state = fxhash::FxHasher::default();
            self.vals[value.index as usize].hash(&mut state);
            state.finish()
        }
    }

    #[test]
    fn test_basic() {
        let ctx = Ctx {
            vals: &["a", "b", "a"],
        };
        let mut uf = UnionFind::new();

        let k0 = Key { index: 0 };
        let k1 = Key { index: 1 };
        let k2 = Key { index: 2 };

        assert!(ctx.ctx_eq(&k0, &k2, &mut uf));
        assert!(!ctx.ctx_eq(&k0, &k1, &mut uf));
        assert!(!ctx.ctx_eq(&k2, &k1, &mut uf));

        let mut map: CtxHashMap<Key, u64> = CtxHashMap::new();
        assert_eq!(map.insert(k0, 42, &ctx, &mut uf), None);
        // k2 is context-equal to k0, so this replaces the existing value.
        assert_eq!(map.insert(k2, 84, &ctx, &mut uf), Some(42));
        assert_eq!(map.get(&k1, &ctx, &mut uf), None);
        assert_eq!(*map.get(&k0, &ctx, &mut uf).unwrap(), 84);
    }

    #[test]
    fn test_entry() {
        // The context is only read, so a shared borrow is all `entry` needs.
        let ctx = Ctx {
            vals: &["a", "b", "a"],
        };
        let mut uf = UnionFind::new();

        let k0 = Key { index: 0 };
        let k1 = Key { index: 1 };
        let k2 = Key { index: 2 };

        let mut map: CtxHashMap<Key, u64> = CtxHashMap::new();
        match map.entry(k0, &ctx, &mut uf) {
            Entry::Vacant(v) => assert_eq!(*v.insert(1), 1),
            Entry::Occupied(_) => panic!("a fresh map must not contain k0"),
        }
        match map.entry(k1, &ctx, &mut uf) {
            Entry::Vacant(_) => {}
            Entry::Occupied(_) => panic!("k1 is not context-equal to any stored key"),
        }
        match map.entry(k2, &ctx, &mut uf) {
            Entry::Occupied(o) => assert_eq!(*o.get(), 1),
            Entry::Vacant(_) => panic!("k2 is context-equal to k0 and must be found"),
        }
        assert_eq!(map.get(&k2, &ctx, &mut uf), Some(&1));
    }

    #[test]
    fn test_many_inserts_survive_growth() {
        // Start tiny so the table has to grow (and rehash) many times.
        let mut uf = UnionFind::new();
        let mut map: CtxHashMap<u32, u32> = CtxHashMap::with_capacity(4);
        for i in 0..1000u32 {
            assert_eq!(map.insert(i, i * 2, &NullCtx, &mut uf), None);
        }
        for i in 0..1000u32 {
            assert_eq!(map.get(&i, &NullCtx, &mut uf), Some(&(i * 2)));
        }
        assert_eq!(map.get(&1000u32, &NullCtx, &mut uf), None);
    }
}