1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
use std::cmp;
use std::collections::HashMap;
use std::hash::Hash;

/// A `HashMap` whose insertions can be rolled back to an earlier
/// [`StackMapCheckpoint`].
///
/// Every [`insert`](StackMap::insert) records the value it displaced (or that
/// the key was absent) in an undo log; [`restore`](StackMap::restore) replays
/// that log backwards. [`lookup`](StackMap::lookup) sees only the latest
/// binding of each key.
#[derive(Debug, Clone)]
pub struct StackMap<K, V> {
    /// Current key -> value bindings.
    mapping: HashMap<K, V>,
    /// Undo log: one entry per `insert`, holding the key and the value it
    /// replaced (`None` if the key was not present before that insert).
    history: Vec<(K, Option<V>)>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add
// `K: Default` and `V: Default` bounds, which an empty map does not need.
impl<K, V> std::default::Default for StackMap<K, V> {
    fn default() -> Self {
        Self {
            mapping: Default::default(),
            history: Default::default(),
        }
    }
}

/// Opaque marker for a point in a [`StackMap`]'s insertion history.
///
/// Obtained from [`StackMap::checkpoint`] and consumed by
/// [`StackMap::restore`]. It records only the length of the undo log, so it is
/// meaningful only for the map that produced it.
#[derive(Debug)]
pub struct StackMapCheckpoint(usize);

/// Wraps an existing `HashMap` as the base layer of a [`StackMap`].
///
/// The initial entries are not recorded in the undo log, so no checkpoint can
/// remove or revert them; only later insertions are undoable.
impl<K, V> From<HashMap<K, V>> for StackMap<K, V> {
    fn from(value: HashMap<K, V>) -> Self {
        Self {
            mapping: value,
            history: Default::default(),
        }
    }
}

impl<K, V> StackMap<K, V>
where
    K: Clone + cmp::Eq + Hash,
{
    /// Binds `key` to `value`, shadowing any existing binding until a
    /// [`restore`](Self::restore) to a checkpoint taken before this call.
    ///
    /// The key is cloned because it is stored both in the map and in the
    /// undo log.
    pub fn insert(&mut self, key: K, value: V) {
        let previous = self.mapping.insert(key.clone(), value);
        self.history.push((key, previous));
    }

    /// Returns a marker for the current state, to be passed to
    /// [`restore`](Self::restore) later.
    #[must_use]
    pub fn checkpoint(&self) -> StackMapCheckpoint {
        StackMapCheckpoint(self.history.len())
    }

    /// Undoes every `insert` made since `checkpoint` was taken, newest first.
    ///
    /// Overwritten values are reinstated and keys that were absent at the
    /// checkpoint are removed again.
    ///
    /// A checkpoint only records the undo-log length. If the log is already
    /// no longer than that (e.g. the map was restored to an earlier
    /// checkpoint), this is a no-op. Using a checkpoint after such an earlier
    /// restore plus new inserts rewinds to that *length*, not to the state the
    /// checkpoint was taken in; this misuse is not detected.
    pub fn restore(&mut self, checkpoint: StackMapCheckpoint) {
        // Clamp so a stale checkpoint beyond the current history stays a
        // no-op instead of an out-of-range `drain` panic.
        let keep = checkpoint.0.min(self.history.len());
        // Undo in reverse order so a key inserted several times since the
        // checkpoint ends up with the value it had at the checkpoint.
        for (key, previous) in self.history.drain(keep..).rev() {
            match previous {
                Some(value) => {
                    self.mapping.insert(key, value);
                }
                None => {
                    self.mapping.remove(&key);
                }
            }
        }
    }

    /// Returns the value currently bound to `key`, if any.
    pub fn lookup(&self, key: &K) -> Option<&V> {
        self.mapping.get(key)
    }
}

#[cfg(test)]
mod tests {
    use super::StackMap;
    use std::collections::HashMap;

    /// The keys exercised by the layered checkpoint test, in a fixed order.
    const KEYS: [&str; 3] = ["a", "b", "c"];

    /// Binds `"a"`, `"b"` and `"c"` to `values`, in that order.
    fn insert_abc(map: &mut StackMap<&'static str, i32>, values: [i32; 3]) {
        for (key, value) in KEYS.iter().zip(values.iter()) {
            map.insert(*key, *value);
        }
    }

    /// Asserts that `"a"`, `"b"` and `"c"` are currently bound to `expected`.
    fn assert_abc(map: &StackMap<&'static str, i32>, expected: [i32; 3]) {
        for (key, want) in KEYS.iter().zip(expected.iter()) {
            assert_eq!(map.lookup(key), Some(want), "key {:?}", key);
        }
    }

    #[test]
    fn test_checkpoint_restore() {
        let mut map = StackMap::default();
        insert_abc(&mut map, [1, 2, 3]);
        assert_abc(&map, [1, 2, 3]);
        let checkpoint1 = map.checkpoint();

        insert_abc(&mut map, [4, 5, 6]);
        assert_abc(&map, [4, 5, 6]);
        let checkpoint2 = map.checkpoint();

        insert_abc(&mut map, [7, 8, 9]);
        assert_abc(&map, [7, 8, 9]);

        map.restore(checkpoint2);
        assert_abc(&map, [4, 5, 6]);

        map.restore(checkpoint1);
        assert_abc(&map, [1, 2, 3]);
    }

    /// Keys first bound after a checkpoint must disappear on restore, even if
    /// they were bound more than once in between.
    #[test]
    fn test_restore_removes_keys_added_after_checkpoint() {
        let mut map = StackMap::default();
        map.insert("a", 1);
        let checkpoint = map.checkpoint();
        map.insert("b", 2);
        map.insert("b", 3);
        assert_eq!(map.lookup(&"b"), Some(&3));

        map.restore(checkpoint);
        assert_eq!(map.lookup(&"a"), Some(&1));
        assert_eq!(map.lookup(&"b"), None);
    }

    /// Entries supplied through `From<HashMap>` form a base layer that a
    /// restore reinstates rather than removes.
    #[test]
    fn test_from_hashmap_entries_survive_restore() {
        let mut base = HashMap::new();
        base.insert("a", 1);
        let mut map = StackMap::from(base);
        let checkpoint = map.checkpoint();
        map.insert("a", 2);
        assert_eq!(map.lookup(&"a"), Some(&2));

        map.restore(checkpoint);
        assert_eq!(map.lookup(&"a"), Some(&1));
    }
}