1use std::cmp;
2use std::collections::HashMap;
3use std::hash::Hash;
4
/// A hash map whose insertions can be rolled back to an earlier checkpoint.
///
/// Every [`insert`](StackMap::insert) records the key together with the value it
/// replaced (or `None` if the key was absent), which is what allows a later
/// restore to put the previous state back.
#[derive(Clone, Debug)]
pub struct StackMap<K, V> {
    /// The current key/value state.
    mapping: HashMap<K, V>,
    /// Undo log: one `(key, previous value)` entry per insert, oldest first.
    history: Vec<(K, Option<V>)>,
}
10
11impl<K, V> std::default::Default for StackMap<K, V> {
12 fn default() -> Self {
13 Self {
14 mapping: Default::default(),
15 history: Default::default(),
16 }
17 }
18}
19
/// Opaque marker for a `StackMap` state, produced by `StackMap::checkpoint` and
/// consumed by `StackMap::restore`.
///
/// It stores only the length of the map's undo history at the moment it was
/// taken. Because it is `Copy`, the same checkpoint can be restored repeatedly
/// (e.g. when backtracking over several alternatives from one branch point),
/// provided no restore to an *older* checkpoint has happened in between.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StackMapCheckpoint(usize);
21
22impl<K, V> From<HashMap<K, V>> for StackMap<K, V> {
23 fn from(value: HashMap<K, V>) -> Self {
24 Self {
25 mapping: value,
26 history: Default::default(),
27 }
28 }
29}
30
31impl<K, V> StackMap<K, V>
32where
33 K: Clone + cmp::Eq + Hash,
34 V: Clone,
35{
36 pub fn insert(&mut self, key: K, value: V) {
37 let previous = self.mapping.insert(key.clone(), value);
38 self.history.push((key, previous));
39 }
40
41 pub fn checkpoint(&self) -> StackMapCheckpoint {
42 StackMapCheckpoint(self.history.len())
43 }
44
45 pub fn restore(&mut self, checkpoint: StackMapCheckpoint) {
46 while self.history.len() > checkpoint.0 {
47 self.pop();
48 }
49 }
50
51 fn pop(&mut self) {
52 if let Some((key, value)) = self.history.pop() {
53 if let Some(value) = value {
54 self.mapping.insert(key, value);
55 } else {
56 self.mapping.remove(&key);
57 }
58 } else {
59 panic!("pop called more than push");
60 }
61 }
62
63 pub fn lookup(&self, key: &K) -> Option<&V> {
64 self.mapping.get(key)
65 }
66}
67
#[cfg(test)]
mod tests {
    use super::StackMap;
    use std::collections::HashMap;

    /// Asserts that keys "a", "b" and "c" currently map to `expected`, in order.
    fn assert_abc(map: &StackMap<&str, i32>, expected: [i32; 3]) {
        for (key, value) in ["a", "b", "c"].iter().zip(expected.iter()) {
            assert_eq!(map.lookup(key), Some(value), "unexpected value for key {:?}", key);
        }
    }

    #[test]
    fn test_checkpoint_restore() {
        let mut map = StackMap::default();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);
        assert_abc(&map, [1, 2, 3]);
        let checkpoint1 = map.checkpoint();

        map.insert("a", 4);
        map.insert("b", 5);
        map.insert("c", 6);
        assert_abc(&map, [4, 5, 6]);
        let checkpoint2 = map.checkpoint();

        map.insert("a", 7);
        map.insert("b", 8);
        map.insert("c", 9);
        assert_abc(&map, [7, 8, 9]);

        map.restore(checkpoint2);
        assert_abc(&map, [4, 5, 6]);

        map.restore(checkpoint1);
        assert_abc(&map, [1, 2, 3]);
    }

    #[test]
    fn test_restore_removes_keys_added_after_checkpoint() {
        let mut map = StackMap::default();
        map.insert("a", 1);
        let checkpoint = map.checkpoint();
        map.insert("b", 2);
        // Overwriting a key that was itself new since the checkpoint must still
        // end with the key removed, not with its first value.
        map.insert("b", 3);
        assert_eq!(map.lookup(&"b"), Some(&3));

        map.restore(checkpoint);
        assert_eq!(map.lookup(&"a"), Some(&1));
        assert_eq!(map.lookup(&"b"), None);
    }

    #[test]
    fn test_restore_preserves_initial_mapping() {
        let mut initial = HashMap::new();
        initial.insert("a", 1);
        let mut map = StackMap::from(initial);
        let checkpoint = map.checkpoint();
        map.insert("a", 2);
        map.insert("z", 26);

        map.restore(checkpoint);
        assert_eq!(map.lookup(&"a"), Some(&1));
        assert_eq!(map.lookup(&"z"), None);
    }
}