hugr_model/v0/scope/symbol.rs
1use std::{borrow::Cow, hash::BuildHasherDefault};
2
3use fxhash::FxHasher;
4use indexmap::IndexMap;
5use thiserror::Error;
6
7use crate::v0::table::{NodeId, RegionId};
8
9type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
10
11/// Symbol binding table that keeps track of symbol resolution and scoping.
12///
13/// Nodes may introduce a symbol so that other parts of the IR can refer to the
14/// node. Symbols have an associated name and are scoped via regions. A symbol
15/// can shadow another symbol with the same name from an outer region, but
16/// within any single region each symbol name must be unique.
17///
18/// When a symbol is referred to directly by the id of the node, the symbol must
19/// be in scope at the point of reference as if the reference was by name. This
20/// guarantees that transformations between directly indexed and named formats
21/// are always valid.
22///
23/// # Examples
24///
25/// ```
26/// # pub use hugr_model::v0::table::{NodeId, RegionId};
27/// # pub use hugr_model::v0::scope::SymbolTable;
28/// let mut symbols = SymbolTable::new();
29/// symbols.enter(RegionId(0));
30/// symbols.insert("foo", NodeId(0)).unwrap();
31/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0));
32/// symbols.enter(RegionId(1));
33/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0));
34/// symbols.insert("foo", NodeId(1)).unwrap();
35/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(1));
36/// assert!(!symbols.is_visible(NodeId(0)));
37/// symbols.exit();
38/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0));
39/// assert!(symbols.is_visible(NodeId(0)));
40/// assert!(!symbols.is_visible(NodeId(1)));
41/// ```
42#[derive(Debug, Clone, Default)]
43pub struct SymbolTable<'a> {
44 symbols: FxIndexMap<&'a str, BindingIndex>,
45 bindings: FxIndexMap<NodeId, Binding>,
46 scopes: FxIndexMap<RegionId, Scope>,
47}
48
49impl<'a> SymbolTable<'a> {
50 /// Create a new symbol table.
51 pub fn new() -> Self {
52 Self::default()
53 }
54
55 /// Enter a new scope for the given region.
56 pub fn enter(&mut self, region: RegionId) {
57 self.scopes.insert(
58 region,
59 Scope {
60 binding_stack: self.bindings.len(),
61 },
62 );
63 }
64
65 /// Exit a previously entered scope.
66 ///
67 /// # Panics
68 ///
69 /// Panics if there are no remaining open scopes.
70 pub fn exit(&mut self) {
71 let (_, scope) = self.scopes.pop().unwrap();
72
73 for _ in scope.binding_stack..self.bindings.len() {
74 let (_, binding) = self.bindings.pop().unwrap();
75
76 if let Some(shadows) = binding.shadows {
77 self.symbols[binding.symbol_index] = shadows;
78 } else {
79 let last = self.symbols.pop();
80 debug_assert_eq!(last.unwrap().1, self.bindings.len());
81 }
82 }
83 }
84
85 /// Insert a new symbol into the current scope.
86 ///
87 /// # Errors
88 ///
89 /// Returns an error if the symbol is already defined in the current scope.
90 /// In the case of an error the table remains unchanged.
91 ///
92 /// # Panics
93 ///
94 /// Panics if there is no current scope.
95 pub fn insert(&mut self, name: &'a str, node: NodeId) -> Result<(), DuplicateSymbolError> {
96 let scope_depth = self.scopes.len() as u16 - 1;
97 let (symbol_index, shadowed) = self.symbols.insert_full(name, self.bindings.len());
98
99 if let Some(shadowed) = shadowed {
100 let (shadowed_node, shadowed_binding) = self.bindings.get_index(shadowed).unwrap();
101 if shadowed_binding.scope_depth == scope_depth {
102 self.symbols.insert(name, shadowed);
103 return Err(DuplicateSymbolError(name.into(), node, *shadowed_node));
104 }
105 }
106
107 self.bindings.insert(
108 node,
109 Binding {
110 scope_depth,
111 shadows: shadowed,
112 symbol_index,
113 },
114 );
115
116 Ok(())
117 }
118
119 /// Check whether a symbol is currently visible in the current scope.
120 pub fn is_visible(&self, node: NodeId) -> bool {
121 let Some(binding) = self.bindings.get(&node) else {
122 return false;
123 };
124
125 // Check that the symbol has not been shadowed at this point.
126 self.symbols[binding.symbol_index] == binding.symbol_index
127 }
128
129 /// Tries to resolve a symbol name in the current scope.
130 pub fn resolve(&self, name: &'a str) -> Result<NodeId, UnknownSymbolError> {
131 let index = *self
132 .symbols
133 .get(name)
134 .ok_or(UnknownSymbolError(name.into()))?;
135
136 // NOTE: The unwrap is safe because the `symbols` map
137 // points to valid indices in the `bindings` map.
138 let (node, _) = self.bindings.get_index(index).unwrap();
139 Ok(*node)
140 }
141
142 /// Returns the depth of the given region, if it corresponds to a currently open scope.
143 pub fn region_to_depth(&self, region: RegionId) -> Option<ScopeDepth> {
144 Some(self.scopes.get_index_of(®ion)? as _)
145 }
146
147 /// Returns the region corresponding to the scope at the given depth.
148 pub fn depth_to_region(&self, depth: ScopeDepth) -> Option<RegionId> {
149 let (region, _) = self.scopes.get_index(depth as _)?;
150 Some(*region)
151 }
152
153 /// Resets the symbol table to its initial state while maintaining its
154 /// allocated memory.
155 pub fn clear(&mut self) {
156 self.symbols.clear();
157 self.bindings.clear();
158 self.scopes.clear();
159 }
160}
161
162#[derive(Debug, Clone, Copy)]
163struct Binding {
164 /// The depth of the scope in which this binding is defined.
165 scope_depth: ScopeDepth,
166
167 /// The index of the binding that is shadowed by this one, if any.
168 shadows: Option<BindingIndex>,
169
170 /// The index of this binding's symbol in the symbol table.
171 ///
172 /// The symbol table always points to the currently visible binding for a
173 /// symbol. Therefore this index is only valid if this binding is not shadowed.
174 /// In particular, we detect shadowing by checking if the entry in the symbol
175 /// table at this index does indeed point to this binding.
176 symbol_index: SymbolIndex,
177}
178
/// Bookkeeping data for a single open scope.
#[derive(Clone, Copy, Debug)]
struct Scope {
    /// Number of entries on the `bindings` stack at the moment the scope was
    /// entered.
    binding_stack: usize,
}
184
/// Depth of a scope: its position in the stack of currently open scopes.
pub type ScopeDepth = u16;

/// Position of a binding within the table's stack of bindings.
type BindingIndex = usize;

/// Position of a symbol name within the table's map of symbols.
type SymbolIndex = usize;
189
190/// Error that occurs when trying to resolve an unknown symbol.
191#[derive(Debug, Clone, Error)]
192#[error("symbol name `{0}` not found in this scope")]
193pub struct UnknownSymbolError<'a>(pub Cow<'a, str>);
194
195/// Error that occurs when trying to introduce a symbol that is already defined in the current scope.
196#[derive(Debug, Clone, Error)]
197#[error("symbol `{0}` is already defined in this scope")]
198pub struct DuplicateSymbolError<'a>(pub Cow<'a, str>, pub NodeId, pub NodeId);