// hash_chain/map.rs

1use std::{
2    borrow::Borrow,
3    collections::{hash_map::RandomState, HashMap},
4    hash::{BuildHasher, Hash},
5    mem::take,
6    ops::Index,
7};
8
/// A stack of `HashMap`s that is searched as if it were a single map, e.g. for
/// nested scopes.
///
/// Position 0 in `maps` is the root (outermost) map; `new_child` and
/// `new_child_with` push a new innermost map on the end. Lookups such as `get`
/// walk from the last map back to the first and return the first hit, so
/// entries in inner maps shadow entries with the same key in outer maps.
#[derive(Clone)]
pub struct ChainMap<K, V, S = RandomState> {
    /// The layers of the chain, root first. `new`, `Default` and
    /// `remove_child` keep this non-empty, but `split_off(0)` can leave it
    /// empty.
    pub(crate) maps: Vec<HashMap<K, V, S>>,
}
13
14impl<K: Hash + Eq, V, S: BuildHasher> ChainMap<K, V, S>
15where
16    K: Hash + Eq,
17    S: BuildHasher,
18{
19    pub fn new(map: HashMap<K, V, S>) -> Self {
20        Self { maps: vec![map] }
21    }
22    /// Inserts a key-value pair into the map.
23    /// If the map did not have this key present, None is returned.
24    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
25        let map = self.maps.last_mut()?;
26        map.insert(key, value)
27    }
28
29    pub fn insert_at(&mut self, idx: usize, key: K, value: V) -> Result<Option<V>, crate::Error> {
30        if let Some(map) = self.maps.get_mut(idx) {
31            Ok(map.insert(key, value))
32        } else {
33            Err(crate::Error::IndexOutOfRange)
34        }
35    }
36
37    /// Returns the key-value pair corresponding to the supplied key.
38    ///
39    /// The supplied key may be any borrowed form of the map's key type, but
40    /// `Hash` and `Eq` on the borrowed form *must* match those for
41    /// the key type.
42    pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V>
43    where
44        K: Borrow<Q>,
45        Q: Hash + Eq,
46    {
47        for map in self.maps.iter().rev() {
48            if let Some(v) = map.get(key) {
49                return Some(v);
50            }
51        }
52        None
53    }
54    /// Returns a mutable reference to the value corresponding to the key.
55    ///
56    /// The supplied key may be any borrowed form of the map's key type, but
57    /// `Hash` and `Eq` on the borrowed form *must* match those for
58    /// the key type.
59    pub fn get_mut<Q: ?Sized>(&mut self, key: &Q) -> Option<&mut V>
60    where
61        K: Borrow<Q>,
62        Q: Hash + Eq,
63    {
64        for map in self.maps.iter_mut().rev() {
65            if let Some(v) = map.get_mut(key) {
66                return Some(v);
67            }
68        }
69        None
70    }
71
72    pub fn get_before<Q: ?Sized>(&self, idx: usize, key: &Q) -> Option<&V>
73    where
74        K: Borrow<Q>,
75        Q: Hash + Eq,
76    {
77        let iter = if idx >= self.maps.len() {
78            self.maps.iter()
79        } else {
80            self.maps[0..idx].iter()
81        };
82
83        for map in iter.rev() {
84            if let Some(v) = map.get(key) {
85                return Some(v);
86            }
87        }
88        None
89    }
90
91    pub fn get_before_mut<Q: ?Sized>(&mut self, idx: usize, key: &Q) -> Option<&mut V>
92    where
93        K: Borrow<Q>,
94        Q: Hash + Eq,
95    {
96        let iter = if idx >= self.maps.len() {
97            self.maps.iter_mut()
98        } else {
99            self.maps[0..idx].iter_mut()
100        };
101
102        for map in iter.rev() {
103            if let Some(v) = map.get_mut(key) {
104                return Some(v);
105            }
106        }
107        None
108    }
109
110    pub fn new_child_with(&mut self, map: HashMap<K, V, S>) {
111        self.maps.push(map);
112    }
113
114    pub fn last_has<Q: ?Sized>(&self, key: &Q) -> bool
115    where
116        K: Borrow<Q>,
117        Q: Hash + Eq,
118    {
119        self.has_at(self.maps.len() - 1, key)
120    }
121
122    pub fn has_at<Q: ?Sized>(&self, idx: usize, key: &Q) -> bool
123    where
124        K: Borrow<Q>,
125        Q: Hash + Eq,
126    {
127        if let Some(map) = self.maps.get(idx) {
128            map.contains_key(key)
129        } else {
130            false
131        }
132    }
133
134    pub fn child_len(&self) -> usize {
135        self.maps.len()
136    }
137
138    pub fn get_last_index<Q: ?Sized>(&self, key: &Q) -> Option<usize>
139    where
140        K: Borrow<Q>,
141        Q: Hash + Eq,
142    {
143        for (i, map) in self.maps.iter().enumerate().rev() {
144            if map.contains_key(key) {
145                return Some(i);
146            }
147        }
148        None
149    }
150}
151
152impl<K: Hash + Eq, V, S: BuildHasher + Default> ChainMap<K, V, S> {
153    pub fn new_child(&mut self) {
154        self.maps.push(HashMap::default());
155    }
156
157    pub fn remove_child(&mut self) -> Option<HashMap<K, V, S>> {
158        if self.maps.len() == 1 {
159            let ret = take(&mut self.maps[0]);
160            Some(ret)
161        } else {
162            self.maps.pop()
163        }
164    }
165
166    pub fn split_off(&mut self, idx: usize) -> Self {
167        let maps = self.maps.split_off(idx);
168        Self {
169            maps,
170        }
171    }
172
173    pub fn append(&mut self, other: &mut Self) {
174        self.maps.append(&mut other.maps);
175    }
176}
177
178impl<K, V> Default for ChainMap<K, V>
179where
180    K: Hash + Eq,
181{
182    fn default() -> Self {
183        Self {
184            maps: vec![HashMap::new()],
185        }
186    }
187}
188
189impl<K, Q: ?Sized, V, S> Index<&Q> for ChainMap<K, V, S>
190where
191    K: Eq + Hash + Borrow<Q>,
192    Q: Eq + Hash,
193    S: BuildHasher,
194{
195    type Output = V;
196
197    /// Returns a reference to the value corresponding to the supplied key.
198    ///
199    /// # Panics
200    ///
201    /// Panics if the key is not present in the `HashMap`.
202    #[inline]
203    fn index(&self, key: &Q) -> &V {
204        self.get(key).expect("no entry found for key")
205    }
206}
207
208impl<K, V, S> PartialEq for ChainMap<K, V, S>
209where
210    K: Eq + Hash,
211    V: PartialEq,
212    S: std::hash::BuildHasher,
213{
214    fn eq(&self, other: &ChainMap<K, V, S>) -> bool {
215        self.maps == other.maps
216    }
217}
218
219impl<K, V, S> Eq for ChainMap<K, V, S>
220where
221    K: Eq + Hash,
222    V: Eq,
223    S: BuildHasher,
224{
225}
226
227impl<K, V, S> core::fmt::Debug for ChainMap<K, V, S>
228where
229    K: Eq + Hash + core::fmt::Debug,
230    V: core::fmt::Debug,
231    S: BuildHasher,
232{
233    fn fmt(&self, f: &mut core::fmt::Formatter) -> std::fmt::Result {
234        f.debug_struct("ChainMap")
235            .field("maps", &self.maps)
236            .finish()
237    }
238}
239
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn initialization() {
        let mut test_map = HashMap::new();
        test_map.insert("test", 1);
        let chain_map = ChainMap::new(test_map);

        assert!(!chain_map.maps.is_empty());
        assert_eq!(chain_map.maps[0].get("test"), Some(&1));
    }

    #[test]
    fn initialization_default() {
        let chain_map: ChainMap<(), ()> = ChainMap::default();

        assert!(!chain_map.maps.is_empty());
        assert!(chain_map.maps[0].is_empty());
    }

    #[test]
    fn insert() {
        let mut chain_map = ChainMap::default();
        assert!(chain_map.insert("test", 1).is_none());

        assert_eq!(chain_map.maps[0].get("test"), Some(&1));
        // Re-inserting into the same layer hands back the previous value.
        assert_eq!(chain_map.insert("test", 2), Some(1));
    }

    #[test]
    fn insert_at() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("banana", "milk");
        chain_map.new_child();

        chain_map.insert_at(0, "strawberry", "soda").unwrap();
        assert_eq!(chain_map.maps[0].get("strawberry"), Some(&"soda"));
        assert_eq!(chain_map.maps[1].get("strawberry"), None);
    }

    #[test]
    #[should_panic = "IndexOutOfRange"]
    fn insert_at_out_of_bounds() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("banana", "milk");
        chain_map.new_child();

        chain_map.insert_at(37, "strawberry", "soda").unwrap();
    }

    #[test]
    fn get() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);

        assert_eq!(chain_map.get(&"test"), Some(&1));
    }

    #[test]
    fn get_none() {
        let chain_map: ChainMap<&str, ()> = ChainMap::default();
        assert_eq!(chain_map.get(&"test"), None);
    }

    #[test]
    fn get_mut() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);

        let test_value = chain_map.get_mut(&"test");
        assert_eq!(test_value, Some(&mut 1));
        *test_value.unwrap() += 1;
        let changed = chain_map.get(&"test");
        assert_eq!(changed, Some(&2));
    }

    #[test]
    fn get_mut_outer() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("outer", 1);
        chain_map.new_child();
        chain_map.insert("inner", 2);
        let ret = chain_map.get_mut("outer").unwrap();
        *ret += 9000;

        let changed = chain_map.get(&"outer");
        assert_eq!(changed, Some(&9001));
    }

    #[test]
    fn index() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);

        assert_eq!(chain_map[&"test"], 1);
    }

    #[test]
    fn new_child() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);
        chain_map.new_child();
        assert!(chain_map.maps.len() > 1);
    }

    #[test]
    fn scopes() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);
        chain_map.insert("y", 2);
        chain_map.new_child();
        chain_map.insert("x", 1);
        assert_eq!(chain_map.get("x"), Some(&1));
        assert_eq!(chain_map.get("y"), Some(&2));
    }

    #[test]
    fn remove_child() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);
        chain_map.insert("y", 2);
        chain_map.new_child();
        chain_map.insert("x", 1);
        let ret = chain_map.remove_child().unwrap();
        assert_eq!(ret.get("x"), Some(&1));
        assert_eq!(chain_map.get("x"), Some(&0));
    }

    #[test]
    fn remove_child_length_1() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);
        let _ = chain_map.remove_child();
        assert_eq!(chain_map.get("x"), None);
        assert_eq!(chain_map.maps.len(), 1);
    }

    #[test]
    fn has_at_exists() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);

        assert!(chain_map.has_at(0, &"x"));
    }

    #[test]
    fn has_at_doesnt_exist() {
        let chain_map: ChainMap<&str, ()> = ChainMap::default();

        assert!(!chain_map.has_at(11, &"x"));
    }

    #[test]
    fn last_has_true() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);
        chain_map.new_child();
        chain_map.insert("y", 1);

        assert!(chain_map.last_has(&"y"));
    }

    #[test]
    fn last_has_false() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);
        chain_map.new_child();
        chain_map.insert("y", 1);

        assert!(!chain_map.last_has(&"x"));
    }

    #[test]
    fn child_len() {
        let mut chain_map: ChainMap<&str, ()> = ChainMap::default();
        assert_eq!(chain_map.child_len(), 1);

        for i in 2..100 {
            chain_map.new_child();
            assert_eq!(chain_map.child_len(), i);
        }
    }

    #[test]
    fn get_before_exists() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);
        chain_map.new_child();
        chain_map.insert("test", 2);

        assert_eq!(chain_map.get_before(1, &"test"), Some(&1));
    }

    #[test]
    fn get_before_first_index_is_none() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);
        chain_map.new_child();

        // Nothing precedes the root map.
        assert_eq!(chain_map.get_before(0, &"test"), None);
    }

    #[test]
    fn get_before_past_end_searches_everything() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);
        chain_map.new_child();
        chain_map.insert("test", 2);

        assert_eq!(chain_map.get_before(10, &"test"), Some(&2));
    }

    #[test]
    fn get_before_mut_exists() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test", 1);
        chain_map.new_child();
        chain_map.insert("test", 2);

        let test_value = chain_map.get_before_mut(1, &"test");
        assert_eq!(test_value, Some(&mut 1));
        *test_value.unwrap() += 2;
        let changed = chain_map.get_before(1, &"test");
        assert_eq!(changed, Some(&3));
        let child = chain_map.get("test");
        assert_eq!(child, Some(&2));
    }

    #[test]
    fn get_last_index_exists() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test1", 1);
        chain_map.new_child();
        chain_map.insert("test2", 2);

        assert_eq!(chain_map.get_last_index("test1"), Some(0));
        assert_eq!(chain_map.get_last_index("test2"), Some(1));
    }

    #[test]
    fn get_last_index_doesnt_exist() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("test1", 1);
        chain_map.new_child();
        chain_map.insert("test2", 2);

        assert_eq!(chain_map.get_last_index("shmee"), None);
    }

    #[test]
    fn split_off_and_append() {
        let mut chain_map = ChainMap::default();
        chain_map.insert("x", 0);
        chain_map.new_child();
        chain_map.insert("x", 1);

        let mut tail = chain_map.split_off(1);
        assert_eq!(chain_map.child_len(), 1);
        assert_eq!(chain_map.get("x"), Some(&0));
        assert_eq!(tail.get("x"), Some(&1));

        chain_map.append(&mut tail);
        assert_eq!(chain_map.child_len(), 2);
        assert_eq!(chain_map.get("x"), Some(&1));
        assert_eq!(tail.child_len(), 0);
    }

    #[test]
    fn equality_is_layerwise() {
        let mut a = ChainMap::default();
        a.insert("x", 1);
        let mut b = a.clone();
        assert_eq!(a, b);

        // Same lookups, different layering: not equal.
        b.new_child();
        assert_ne!(a, b);
    }

    #[test]
    fn custom_hasher() {
        // Mainly checks that a non-default hasher type-checks through the API.
        use std::hash::BuildHasherDefault;
        use hashers::oz::DJB2Hasher;
        let hm = HashMap::with_hasher(BuildHasherDefault::<DJB2Hasher>::default());
        let mut cm = ChainMap::new(hm);
        cm.insert("test1", 1);
        cm.new_child();
        cm.insert("test1", 1);
        cm.remove_child();
        assert_eq!(cm["test1"], 1);
    }
}