1use std::cmp;
2use std::collections::HashMap;
3use std::hash::Hash;
4
/// A hash map with an undo log, supporting checkpoints and rollback.
///
/// Every [`StackMap::insert`] records the value it displaced (if any), so that
/// [`StackMap::restore`] can rewind the map to the state it had when a
/// [`StackMapCheckpoint`] was taken.
#[derive(Debug, Clone)]
pub struct StackMap<K, V> {
    /// Current key-to-value bindings.
    mapping: HashMap<K, V>,
    /// Undo log: one `(key, displaced value)` entry per insert, oldest first.
    /// `None` means the key was absent before that insert.
    history: Vec<(K, Option<V>)>,
}

// Written by hand rather than derived: `#[derive(Default)]` would add
// `K: Default` and `V: Default` bounds that an empty map does not need.
impl<K, V> Default for StackMap<K, V> {
    /// Creates an empty map with an empty undo log.
    fn default() -> Self {
        Self {
            mapping: HashMap::new(),
            history: Vec::new(),
        }
    }
}

/// Opaque marker for a position in a [`StackMap`]'s undo log.
///
/// It stores only the length of the undo log at the moment
/// [`StackMap::checkpoint`] was called; it holds no reference to the map.
#[derive(Debug)]
pub struct StackMapCheckpoint(usize);

impl<K, V> From<HashMap<K, V>> for StackMap<K, V> {
    /// Wraps an existing map, starting with an empty undo log.
    ///
    /// The entries already in `value` are not recorded in the undo log, so
    /// `restore` never removes them; it can only bring them back to these
    /// initial values after later overwrites.
    fn from(value: HashMap<K, V>) -> Self {
        Self {
            mapping: value,
            history: Vec::new(),
        }
    }
}

// Only `K` needs bounds: the key is cloned once per insert so that it can live
// in both the map and the undo log. Values are only ever moved, so `V` is
// unconstrained.
impl<K, V> StackMap<K, V>
where
    K: Clone + cmp::Eq + Hash,
{
    /// Binds `key` to `value`, recording whatever value was displaced so the
    /// insert can later be undone by [`StackMap::restore`].
    pub fn insert(&mut self, key: K, value: V) {
        let displaced = self.mapping.insert(key.clone(), value);
        self.history.push((key, displaced));
    }

    /// Returns a marker for the current state of the map.
    ///
    /// Passing it to [`StackMap::restore`] undoes every insert made after this
    /// call.
    pub fn checkpoint(&self) -> StackMapCheckpoint {
        StackMapCheckpoint(self.history.len())
    }

    /// Undoes inserts, newest first, until the undo log is no longer than it
    /// was when `checkpoint` was taken.
    ///
    /// If the undo log is already that short or shorter (for example because
    /// an older checkpoint has been restored in the meantime), the map is left
    /// unchanged.
    pub fn restore(&mut self, checkpoint: StackMapCheckpoint) {
        while self.history.len() > checkpoint.0 {
            self.pop();
        }
    }

    /// Undoes the most recent insert.
    ///
    /// # Panics
    ///
    /// Panics if the undo log is empty. `restore` only calls this while the
    /// log is longer than the checkpoint, so it never triggers the panic.
    fn pop(&mut self) {
        match self.history.pop() {
            // The key had a value before the insert: put it back.
            Some((key, Some(previous))) => {
                self.mapping.insert(key, previous);
            }
            // The key was absent before the insert: remove it again.
            Some((key, None)) => {
                self.mapping.remove(&key);
            }
            None => panic!("pop called more than push"),
        }
    }

    /// Returns a reference to the value currently bound to `key`, or `None` if
    /// the key has no binding.
    pub fn lookup(&self, key: &K) -> Option<&V> {
        self.mapping.get(key)
    }
}
67
#[cfg(test)]
mod tests {
    use super::StackMap;

    /// The three keys exercised by the test, in insertion order.
    const KEYS: [&str; 3] = ["a", "b", "c"];

    /// Inserts `values[i]` under `KEYS[i]`, in key order.
    fn store_all(map: &mut StackMap<&'static str, i32>, values: [i32; 3]) {
        for (key, value) in KEYS.iter().zip(values.iter()) {
            map.insert(*key, *value);
        }
    }

    /// Asserts that `KEYS[i]` currently maps to `expected[i]`, in key order.
    fn assert_all(map: &StackMap<&'static str, i32>, expected: [i32; 3]) {
        for (key, value) in KEYS.iter().zip(expected.iter()) {
            assert_eq!(map.lookup(key), Some(value));
        }
    }

    /// Layers three generations of values over the same keys, then unwinds
    /// them one checkpoint at a time, newest first.
    #[test]
    fn test_checkpoint_restore() {
        let mut map = StackMap::default();

        store_all(&mut map, [1, 2, 3]);
        assert_all(&map, [1, 2, 3]);
        let first = map.checkpoint();

        store_all(&mut map, [4, 5, 6]);
        assert_all(&map, [4, 5, 6]);
        let second = map.checkpoint();

        store_all(&mut map, [7, 8, 9]);
        assert_all(&map, [7, 8, 9]);

        // Rolling back to the newer checkpoint exposes the second generation.
        map.restore(second);
        assert_all(&map, [4, 5, 6]);

        // Rolling back further exposes the first generation.
        map.restore(first);
        assert_all(&map, [1, 2, 3]);
    }
}