1use std::{
2 borrow::Borrow,
3 collections::{hash_map::RandomState, HashMap},
4 hash::{BuildHasher, Hash},
5 mem::take,
6 ops::Index,
7};
8
/// A stack of hash maps ("layers") that is searched as a single map.
///
/// Layers are kept in insertion order: index `0` is the oldest layer and the
/// last element is the newest one. The lookup methods in this file walk the
/// layers from newest to oldest, so an entry in a newer layer shadows an entry
/// with the same key in an older layer.
///
/// `S` is the hasher builder shared by every layer; it defaults to
/// [`RandomState`], the same default `HashMap` uses.
#[derive(Clone)]
pub struct ChainMap<K, V, S = RandomState> {
    /// The layers, oldest first.
    pub(crate) maps: Vec<HashMap<K, V, S>>,
}
13
14impl<K: Hash + Eq, V, S: BuildHasher> ChainMap<K, V, S>
15where
16 K: Hash + Eq,
17 S: BuildHasher,
18{
19 pub fn new(map: HashMap<K, V, S>) -> Self {
20 Self { maps: vec![map] }
21 }
22 pub fn insert(&mut self, key: K, value: V) -> Option<V> {
25 let map = self.maps.last_mut()?;
26 map.insert(key, value)
27 }
28
29 pub fn insert_at(&mut self, idx: usize, key: K, value: V) -> Result<Option<V>, crate::Error> {
30 if let Some(map) = self.maps.get_mut(idx) {
31 Ok(map.insert(key, value))
32 } else {
33 Err(crate::Error::IndexOutOfRange)
34 }
35 }
36
37 pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V>
43 where
44 K: Borrow<Q>,
45 Q: Hash + Eq,
46 {
47 for map in self.maps.iter().rev() {
48 if let Some(v) = map.get(key) {
49 return Some(v);
50 }
51 }
52 None
53 }
54 pub fn get_mut<Q: ?Sized>(&mut self, key: &Q) -> Option<&mut V>
60 where
61 K: Borrow<Q>,
62 Q: Hash + Eq,
63 {
64 for map in self.maps.iter_mut().rev() {
65 if let Some(v) = map.get_mut(key) {
66 return Some(v);
67 }
68 }
69 None
70 }
71
72 pub fn get_before<Q: ?Sized>(&self, idx: usize, key: &Q) -> Option<&V>
73 where
74 K: Borrow<Q>,
75 Q: Hash + Eq,
76 {
77 let iter = if idx >= self.maps.len() {
78 self.maps.iter()
79 } else {
80 self.maps[0..idx].iter()
81 };
82
83 for map in iter.rev() {
84 if let Some(v) = map.get(key) {
85 return Some(v);
86 }
87 }
88 None
89 }
90
91 pub fn get_before_mut<Q: ?Sized>(&mut self, idx: usize, key: &Q) -> Option<&mut V>
92 where
93 K: Borrow<Q>,
94 Q: Hash + Eq,
95 {
96 let iter = if idx >= self.maps.len() {
97 self.maps.iter_mut()
98 } else {
99 self.maps[0..idx].iter_mut()
100 };
101
102 for map in iter.rev() {
103 if let Some(v) = map.get_mut(key) {
104 return Some(v);
105 }
106 }
107 None
108 }
109
110 pub fn new_child_with(&mut self, map: HashMap<K, V, S>) {
111 self.maps.push(map);
112 }
113
114 pub fn last_has<Q: ?Sized>(&self, key: &Q) -> bool
115 where
116 K: Borrow<Q>,
117 Q: Hash + Eq,
118 {
119 self.has_at(self.maps.len() - 1, key)
120 }
121
122 pub fn has_at<Q: ?Sized>(&self, idx: usize, key: &Q) -> bool
123 where
124 K: Borrow<Q>,
125 Q: Hash + Eq,
126 {
127 if let Some(map) = self.maps.get(idx) {
128 map.contains_key(key)
129 } else {
130 false
131 }
132 }
133
134 pub fn child_len(&self) -> usize {
135 self.maps.len()
136 }
137
138 pub fn get_last_index<Q: ?Sized>(&self, key: &Q) -> Option<usize>
139 where
140 K: Borrow<Q>,
141 Q: Hash + Eq,
142 {
143 for (i, map) in self.maps.iter().enumerate().rev() {
144 if map.contains_key(key) {
145 return Some(i);
146 }
147 }
148 None
149 }
150}
151
152impl<K: Hash + Eq, V, S: BuildHasher + Default> ChainMap<K, V, S> {
153 pub fn new_child(&mut self) {
154 self.maps.push(HashMap::default());
155 }
156
157 pub fn remove_child(&mut self) -> Option<HashMap<K, V, S>> {
158 if self.maps.len() == 1 {
159 let ret = take(&mut self.maps[0]);
160 Some(ret)
161 } else {
162 self.maps.pop()
163 }
164 }
165
166 pub fn split_off(&mut self, idx: usize) -> Self {
167 let maps = self.maps.split_off(idx);
168 Self {
169 maps,
170 }
171 }
172
173 pub fn append(&mut self, other: &mut Self) {
174 self.maps.append(&mut other.maps);
175 }
176}
177
178impl<K, V> Default for ChainMap<K, V>
179where
180 K: Hash + Eq,
181{
182 fn default() -> Self {
183 Self {
184 maps: vec![HashMap::new()],
185 }
186 }
187}
188
189impl<K, Q: ?Sized, V, S> Index<&Q> for ChainMap<K, V, S>
190where
191 K: Eq + Hash + Borrow<Q>,
192 Q: Eq + Hash,
193 S: BuildHasher,
194{
195 type Output = V;
196
197 #[inline]
203 fn index(&self, key: &Q) -> &V {
204 self.get(key).expect("no entry found for key")
205 }
206}
207
208impl<K, V, S> PartialEq for ChainMap<K, V, S>
209where
210 K: Eq + Hash,
211 V: PartialEq,
212 S: std::hash::BuildHasher,
213{
214 fn eq(&self, other: &ChainMap<K, V, S>) -> bool {
215 self.maps == other.maps
216 }
217}
218
219impl<K, V, S> Eq for ChainMap<K, V, S>
220where
221 K: Eq + Hash,
222 V: Eq,
223 S: BuildHasher,
224{
225}
226
227impl<K, V, S> core::fmt::Debug for ChainMap<K, V, S>
228where
229 K: Eq + Hash + core::fmt::Debug,
230 V: core::fmt::Debug,
231 S: BuildHasher,
232{
233 fn fmt(&self, f: &mut core::fmt::Formatter) -> std::fmt::Result {
234 f.debug_struct("ChainMap")
235 .field("maps", &self.maps)
236 .finish()
237 }
238}
239
/// Unit tests for `ChainMap`. They reach into the crate-private `maps` field
/// where the layer layout itself is what is being checked.
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn initialization() {
        let mut test_map = HashMap::new();
        test_map.insert("test", 1);
        let chain_map = ChainMap::new(test_map);

        assert_eq!(chain_map.maps.len(), 1);
        assert_eq!(chain_map.maps[0].get("test"), Some(&1));
    }

    #[test]
    fn initialization_default() {
        let chain_map: ChainMap<(), ()> = ChainMap::default();

        assert_eq!(chain_map.maps.len(), 1);
        assert!(chain_map.maps[0].is_empty());
    }

    #[test]
    fn insert() {
        let mut chain_map = ChainMap::default();
        assert!(chain_map.insert("test", 1).is_none());

        assert_eq!(chain_map.maps[0].get("test"), Some(&1));
    }

    #[test]
    fn insert_at() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("banana", "milk");
        chain_map.new_child();

        chain_map.insert_at(0, "strawberry", "soda").unwrap();
        assert_eq!(chain_map.maps[0].get("strawberry"), Some(&"soda"));
        assert_eq!(chain_map.maps[1].get("strawberry"), None);
    }

    #[test]
    #[should_panic = "IndexOutOfRange"]
    fn insert_at_out_of_bounds() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("banana", "milk");
        chain_map.new_child();

        chain_map.insert_at(37, "strawberry", "soda").unwrap();
    }

    #[test]
    fn get() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);

        assert_eq!(chain_map.get(&"test"), Some(&1));
    }

    #[test]
    fn get_none() {
        let chain_map: ChainMap<&str, ()> = ChainMap::default();
        assert_eq!(chain_map.get(&"test"), None);
    }

    #[test]
    fn get_mut() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);

        let test_value = chain_map.get_mut(&"test");
        assert_eq!(test_value, Some(&mut 1));
        *test_value.unwrap() += 1;
        let changed = chain_map.get(&"test");
        assert_eq!(changed, Some(&2));
    }

    #[test]
    fn get_mut_outer() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("outer", 1);
        chain_map.new_child();
        chain_map.insert("inner", 2);
        let ret = chain_map.get_mut("outer").unwrap();
        *ret += 9000;

        let changed = chain_map.get(&"outer");
        assert_eq!(changed, Some(&9001));
    }

    #[test]
    fn index() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);

        assert_eq!(chain_map[&"test"], 1);
    }

    #[test]
    fn new_child() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);
        chain_map.new_child();
        assert_eq!(chain_map.maps.len(), 2);
    }

    #[test]
    fn scopes() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);
        chain_map.insert("y", 2);
        chain_map.new_child();
        chain_map.insert("x", 1);
        assert_eq!(chain_map.get("x"), Some(&1));
        assert_eq!(chain_map.get("y"), Some(&2));
    }

    #[test]
    fn remove_child() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);
        chain_map.insert("y", 2);
        chain_map.new_child();
        chain_map.insert("x", 1);
        let ret = chain_map.remove_child().unwrap();
        assert_eq!(ret.get("x"), Some(&1));
        assert_eq!(chain_map.get("x"), Some(&0));
    }

    #[test]
    fn remove_child_length_1() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);
        let _ = chain_map.remove_child();
        assert_eq!(chain_map.get("x"), None);
        assert_eq!(chain_map.maps.len(), 1);
    }

    #[test]
    fn has_at_exists() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);

        assert!(chain_map.has_at(0, &"x"));
    }

    #[test]
    fn has_at_doesnt_exist() {
        let chain_map: ChainMap<&str, ()> = ChainMap::default();

        assert!(!chain_map.has_at(11, &"x"));
    }

    #[test]
    fn last_has_true() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);
        chain_map.new_child();
        chain_map.insert("y", 1);

        assert!(chain_map.last_has(&"y"));
    }

    #[test]
    fn last_has_false() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);
        chain_map.new_child();
        chain_map.insert("y", 1);

        assert!(!chain_map.last_has(&"x"));
    }

    #[test]
    fn child_len() {
        let mut chain_map: ChainMap<&str, ()> = ChainMap::default();
        assert_eq!(chain_map.child_len(), 1);

        for i in 2..100 {
            chain_map.new_child();
            assert_eq!(chain_map.child_len(), i);
        }
    }

    #[test]
    fn get_before_exists() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);
        chain_map.new_child();
        chain_map.insert("test", 2);

        assert_eq!(chain_map.get_before(1, &"test"), Some(&1));
    }

    #[test]
    fn get_before_edge_indices() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);
        chain_map.new_child();
        chain_map.insert("test", 2);

        // Index 0 has no layers below it; an index past the end means "all".
        assert_eq!(chain_map.get_before(0, "test"), None);
        assert_eq!(chain_map.get_before(99, "test"), Some(&2));
    }

    #[test]
    fn get_before_mut_exists() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);
        chain_map.new_child();
        chain_map.insert("test", 2);

        let test_value = chain_map.get_before_mut(1, &"test");
        assert_eq!(test_value, Some(&mut 1));
        *test_value.unwrap() += 2;
        let changed = chain_map.get_before(1, &"test");
        assert_eq!(changed, Some(&3));
        let child = chain_map.get("test");
        assert_eq!(child, Some(&2));
    }

    #[test]
    fn get_last_index_exists() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test1", 1);
        chain_map.new_child();
        chain_map.insert("test2", 2);

        assert_eq!(chain_map.get_last_index("test1"), Some(0));
        assert_eq!(chain_map.get_last_index("test2"), Some(1));
    }

    #[test]
    fn get_last_index_doesnt_exist() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test1", 1);
        chain_map.new_child();
        chain_map.insert("test2", 2);

        assert_eq!(chain_map.get_last_index("shmee"), None);
    }

    #[test]
    fn split_off_and_append() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("a", 1);
        chain_map.new_child();
        chain_map.insert("b", 2);

        let mut tail = chain_map.split_off(1);
        assert_eq!(chain_map.child_len(), 1);
        assert_eq!(tail.child_len(), 1);
        assert_eq!(chain_map.get("b"), None);
        assert_eq!(tail.get("b"), Some(&2));

        chain_map.append(&mut tail);
        assert_eq!(chain_map.child_len(), 2);
        assert_eq!(chain_map.get("a"), Some(&1));
        assert_eq!(chain_map.get("b"), Some(&2));
    }

    #[test]
    fn equality_and_debug() {
        let mut left = ChainMap::default();
        left.insert("k", 1);
        let mut right = ChainMap::default();
        right.insert("k", 1);
        assert_eq!(left, right);

        right.new_child();
        assert_ne!(left, right);

        assert_eq!(format!("{:?}", left), r#"ChainMap { maps: [{"k": 1}] }"#);
    }

    #[test]
    fn custom_hasher() {
        use std::hash::BuildHasherDefault;
        // `hashers` is not part of std; presumably a dev-dependency of this
        // crate.
        use hashers::oz::DJB2Hasher;
        let hm = HashMap::with_hasher(BuildHasherDefault::<DJB2Hasher>::default());
        let mut cm = ChainMap::new(hm);
        cm.insert("test1", 1);
        cm.new_child();
        cm.insert("test1", 1);
        cm.remove_child();
        // Previously a bare `cm["test1"];` statement, which only checked that
        // indexing did not panic.
        assert_eq!(cm["test1"], 1);
    }
}