1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
/*!
A simple, generic scoped symbol table, used when building a rain IR graph from a rain AST
*/

use std::hash::Hash;
use indexmap::{IndexMap, Equivalent};

/// A simple, generic, scoped symbol table.
///
/// Every key maps to a stack of `(value, depth)` definitions with the innermost one last,
/// and every open scope remembers which keys it defined so they can be unwound when it closes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SymbolTable<K, V>
where
    K: Hash + Eq,
{
    /// Definition stack for each key: `(value, depth)` pairs, innermost definition last
    symbols: IndexMap<K, Vec<(V, usize)>>,
    /// One entry per open scope (position = depth), listing the `symbols` indices of the
    /// keys defined in that scope
    scopes: Vec<Vec<usize>>,
}

impl<K: Hash + Eq, V> SymbolTable<K, V> {
    /// Create a new, empty symbol table at depth 0
    pub fn new() -> SymbolTable<K, V> {
        Self::with_capacity(0)
    }
    /// Create an empty symbol table at depth 0, pre-allocating room for `n` distinct keys
    pub fn with_capacity(n: usize) -> SymbolTable<K, V> {
        SymbolTable {
            symbols: IndexMap::with_capacity(n),
            scopes: vec![Vec::new()]
        }
    }
    /// Get the current depth, where 0 is the outermost scope
    pub fn depth(&self) -> usize {
        // `scopes` always contains the root scope: it is created by `with_capacity` and
        // `jump_to_depth` never shrinks it below one entry.
        self.scopes.len() - 1
    }
    /// Register `value` for `key` at the current depth.
    ///
    /// If `key` already has a definition at the current depth, that definition is replaced
    /// and its old value is returned. Otherwise the new definition shadows any definition
    /// from an outer scope (which becomes visible again once this scope is popped) and
    /// `None` is returned.
    pub fn def(&mut self, key: K, value: V) -> Option<V> {
        let depth = self.depth();
        let entry = self.symbols.entry(key);
        // Keys are never removed from `symbols`, so this index stays valid for as long as
        // the scope that records it below is open.
        let index = entry.index();
        let defs = entry.or_insert_with(Vec::new);
        if let Some((old_value, old_depth)) = defs.last_mut() {
            if *old_depth == depth {
                // Redefinition within the same scope: replace in place, no new scope entry
                return Some(std::mem::replace(old_value, value));
            }
        }
        defs.push((value, depth));
        // Remember the key in the current scope so closing the scope pops this definition
        self.scopes
            .last_mut()
            .expect("a symbol table always has a root scope")
            .push(index);
        None
    }
    /// Get the definition of a current symbol, along with its depth, if any
    pub fn get_full<Q>(&self, key: &Q) -> Option<(&V, usize)>
    where Q: ?Sized + Hash + Equivalent<K> {
        self.symbols
            .get(key)
            .and_then(|defs| defs.last())
            .map(|(v, d)| (v, *d))
    }
    /// Get the definition of a current symbol
    pub fn get<Q>(&self, key: &Q) -> Option<&V>
    where Q: ?Sized + Hash + Equivalent<K> {
        self.get_full(key).map(|(v, _)| v)
    }
    /// Mutably get the definition of a current symbol, along with its depth, if any
    pub fn get_full_mut<Q>(&mut self, key: &Q) -> Option<(&mut V, usize)>
    where Q: ?Sized + Hash + Equivalent<K> {
        self.symbols
            .get_mut(key)
            .and_then(|defs| defs.last_mut())
            .map(|(v, d)| (v, *d))
    }
    /// Mutably get the definition of a current symbol, at whatever depth it was defined, if any
    pub fn get_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
    where Q: ?Sized + Hash + Equivalent<K> {
        self.get_full_mut(key).map(|(v, _)| v)
    }
    /// Try to mutably get the definition of a current symbol at the current depth.
    ///
    /// Returns `None` if the symbol is undefined or only defined in an outer scope.
    pub fn try_get_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
    where Q: ?Sized + Hash + Equivalent<K> {
        let curr_depth = self.depth();
        self.get_full_mut(key)
            .filter(|(_, depth)| *depth == curr_depth)
            .map(|(value, _)| value)
    }
    /// Jump to a given depth, removing obsolete definitions.
    ///
    /// Jumping deeper opens empty scopes. Jumping shallower closes every scope above `depth`
    /// and pops the definitions each closed scope introduced, so shadowed outer definitions
    /// become visible again. Keys themselves stay in the table (possibly with no live
    /// definitions), which keeps the indices recorded by the remaining scopes valid.
    pub fn jump_to_depth(&mut self, depth: usize) {
        let target = depth + 1;
        if self.scopes.len() < target {
            self.scopes.resize_with(target, Vec::new);
        }
        while self.scopes.len() > target {
            let closed = self.scopes.pop().expect("scopes.len() > target >= 1");
            for ix in closed {
                // Every recorded index refers to a key that is still present
                if let Some((_, defs)) = self.symbols.get_index_mut(ix) {
                    defs.pop();
                }
            }
        }
    }
    /// Add a level of depth
    pub fn push(&mut self) { self.jump_to_depth(self.depth() + 1); }
    /// Try to remove a level of depth. Does nothing if depth = 0
    pub fn pop(&mut self) {
        self.jump_to_depth(self.depth().saturating_sub(1))
    }
}

impl<K: Hash + Eq, V> Default for SymbolTable<K, V> {
    /// Equivalent to [`SymbolTable::new`]
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn two_layer_symbol_table_works() {
        let mut symbols = SymbolTable::<&str, usize>::new();
        symbols.def("x", 3);
        symbols.def("y", 7);
        symbols.push();
        symbols.def("x", 9);
        symbols.def("z", 1);
        assert_eq!(symbols.get_full("x"), Some((&9, 1)));
        assert_eq!(symbols.get_full("y"), Some((&7, 0)));
        assert_eq!(symbols.get_full("z"), Some((&1, 1)));
        assert_eq!(symbols.try_get_mut("x"), Some(&mut 9));
        assert_eq!(symbols.try_get_mut("y"), None);
        assert_eq!(symbols.try_get_mut("z"), Some(&mut 1));
        symbols.pop();
        assert_eq!(symbols.get_full("x"), Some((&3, 0)));
        assert_eq!(symbols.get_full("y"), Some((&7, 0)));
        assert_eq!(symbols.get_full("z"), None);
        assert_eq!(symbols.try_get_mut("x"), Some(&mut 3));
        assert_eq!(symbols.try_get_mut("y"), Some(&mut 7));
        assert_eq!(symbols.try_get_mut("z"), None);
    }

    #[test]
    fn redefinition_at_same_depth_returns_old_value() {
        let mut symbols = SymbolTable::<&str, usize>::new();
        assert_eq!(symbols.def("x", 1), None);
        assert_eq!(symbols.def("x", 2), Some(1));
        assert_eq!(symbols.get_full("x"), Some((&2, 0)));
        symbols.push();
        // Shadowing a definition from an outer scope is not a redefinition
        assert_eq!(symbols.def("x", 3), None);
        assert_eq!(symbols.def("x", 4), Some(3));
        assert_eq!(symbols.get_full("x"), Some((&4, 1)));
        symbols.pop();
        assert_eq!(symbols.get_full("x"), Some((&2, 0)));
    }

    #[test]
    fn jump_to_depth_opens_and_closes_scopes() {
        let mut symbols = SymbolTable::<&str, usize>::new();
        symbols.def("a", 0);
        symbols.jump_to_depth(3);
        assert_eq!(symbols.depth(), 3);
        symbols.def("a", 3);
        symbols.def("b", 3);
        assert_eq!(symbols.get_full("a"), Some((&3, 3)));
        symbols.jump_to_depth(1);
        assert_eq!(symbols.depth(), 1);
        assert_eq!(symbols.get_full("a"), Some((&0, 0)));
        assert_eq!(symbols.get("b"), None);
        // Popping past the outermost scope is a no-op
        symbols.pop();
        symbols.pop();
        assert_eq!(symbols.depth(), 0);
        assert_eq!(symbols.get("a"), Some(&0));
    }
}