hugr_model/v0/scope/symbol.rs
1use std::{borrow::Cow, hash::BuildHasherDefault};
2
3use indexmap::IndexMap;
4use rustc_hash::FxHasher;
5use thiserror::Error;
6
7use crate::v0::table::{NodeId, RegionId};
8
9type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<FxHasher>>;
10
11/// Symbol binding table that keeps track of symbol resolution and scoping.
12///
13/// Nodes may introduce a symbol so that other parts of the IR can refer to the
14/// node. Symbols have an associated name and are scoped via regions. A symbol
15/// can shadow another symbol with the same name from an outer region, but
16/// within any single region each symbol name must be unique.
17///
18/// When a symbol is referred to directly by the id of the node, the symbol must
19/// be in scope at the point of reference as if the reference was by name. This
20/// guarantees that transformations between directly indexed and named formats
21/// are always valid.
22///
23/// # Examples
24///
25/// ```
26/// # pub use hugr_model::v0::table::{NodeId, RegionId};
27/// # pub use hugr_model::v0::scope::SymbolTable;
28/// let mut symbols = SymbolTable::new();
29/// symbols.enter(RegionId(0));
30/// symbols.insert("foo", NodeId(0)).unwrap();
31/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0));
32/// symbols.enter(RegionId(1));
33/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0));
34/// symbols.insert("foo", NodeId(1)).unwrap();
35/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(1));
36/// assert!(!symbols.is_visible(NodeId(0)));
37/// symbols.exit();
38/// assert_eq!(symbols.resolve("foo").unwrap(), NodeId(0));
39/// assert!(symbols.is_visible(NodeId(0)));
40/// assert!(!symbols.is_visible(NodeId(1)));
41/// ```
42#[derive(Debug, Clone, Default)]
43pub struct SymbolTable<'a> {
44 symbols: FxIndexMap<&'a str, BindingIndex>,
45 bindings: FxIndexMap<NodeId, Binding>,
46 scopes: FxIndexMap<RegionId, Scope>,
47}
48
49impl<'a> SymbolTable<'a> {
50 /// Create a new symbol table.
51 #[must_use]
52 pub fn new() -> Self {
53 Self::default()
54 }
55
56 /// Enter a new scope for the given region.
57 pub fn enter(&mut self, region: RegionId) {
58 self.scopes.insert(
59 region,
60 Scope {
61 binding_stack: self.bindings.len(),
62 },
63 );
64 }
65
66 /// Exit a previously entered scope.
67 ///
68 /// # Panics
69 ///
70 /// Panics if there are no remaining open scopes.
71 pub fn exit(&mut self) {
72 let (_, scope) = self.scopes.pop().unwrap();
73
74 for _ in scope.binding_stack..self.bindings.len() {
75 let (_, binding) = self.bindings.pop().unwrap();
76
77 if let Some(shadows) = binding.shadows {
78 self.symbols[binding.symbol_index] = shadows;
79 } else {
80 let last = self.symbols.pop();
81 debug_assert_eq!(last.unwrap().1, self.bindings.len());
82 }
83 }
84 }
85
86 /// Insert a new symbol into the current scope.
87 ///
88 /// # Errors
89 ///
90 /// Returns an error if the symbol is already defined in the current scope.
91 /// In the case of an error the table remains unchanged.
92 ///
93 /// # Panics
94 ///
95 /// Panics if there is no current scope.
96 pub fn insert(&mut self, name: &'a str, node: NodeId) -> Result<(), DuplicateSymbolError<'_>> {
97 let scope_depth = self.scopes.len() as u16 - 1;
98 let (symbol_index, shadowed) = self.symbols.insert_full(name, self.bindings.len());
99
100 if let Some(shadowed) = shadowed {
101 let (shadowed_node, shadowed_binding) = self.bindings.get_index(shadowed).unwrap();
102 if shadowed_binding.scope_depth == scope_depth {
103 self.symbols.insert(name, shadowed);
104 return Err(DuplicateSymbolError(name.into(), node, *shadowed_node));
105 }
106 }
107
108 self.bindings.insert(
109 node,
110 Binding {
111 scope_depth,
112 shadows: shadowed,
113 symbol_index,
114 },
115 );
116
117 Ok(())
118 }
119
120 /// Check whether a symbol is currently visible in the current scope.
121 #[must_use]
122 pub fn is_visible(&self, node: NodeId) -> bool {
123 let Some(binding) = self.bindings.get(&node) else {
124 return false;
125 };
126
127 // Check that the symbol has not been shadowed at this point.
128 self.symbols[binding.symbol_index] == binding.symbol_index
129 }
130
131 /// Tries to resolve a symbol name in the current scope.
132 pub fn resolve(&self, name: &'a str) -> Result<NodeId, UnknownSymbolError<'_>> {
133 let index = *self
134 .symbols
135 .get(name)
136 .ok_or(UnknownSymbolError(name.into()))?;
137
138 // NOTE: The unwrap is safe because the `symbols` map
139 // points to valid indices in the `bindings` map.
140 let (node, _) = self.bindings.get_index(index).unwrap();
141 Ok(*node)
142 }
143
144 /// Returns the depth of the given region, if it corresponds to a currently open scope.
145 #[must_use]
146 pub fn region_to_depth(&self, region: RegionId) -> Option<ScopeDepth> {
147 Some(self.scopes.get_index_of(®ion)? as _)
148 }
149
150 /// Returns the region corresponding to the scope at the given depth.
151 #[must_use]
152 pub fn depth_to_region(&self, depth: ScopeDepth) -> Option<RegionId> {
153 let (region, _) = self.scopes.get_index(depth as _)?;
154 Some(*region)
155 }
156
157 /// Resets the symbol table to its initial state while maintaining its
158 /// allocated memory.
159 pub fn clear(&mut self) {
160 self.symbols.clear();
161 self.bindings.clear();
162 self.scopes.clear();
163 }
164}
165
166#[derive(Debug, Clone, Copy)]
167struct Binding {
168 /// The depth of the scope in which this binding is defined.
169 scope_depth: ScopeDepth,
170
171 /// The index of the binding that is shadowed by this one, if any.
172 shadows: Option<BindingIndex>,
173
174 /// The index of this binding's symbol in the symbol table.
175 ///
176 /// The symbol table always points to the currently visible binding for a
177 /// symbol. Therefore this index is only valid if this binding is not shadowed.
178 /// In particular, we detect shadowing by checking if the entry in the symbol
179 /// table at this index does indeed point to this binding.
180 symbol_index: SymbolIndex,
181}
182
/// Bookkeeping record for a scope that is currently open.
#[derive(Debug, Clone, Copy)]
struct Scope {
    /// Number of entries on the `bindings` stack at the moment this scope was
    /// entered; everything above this height belongs to the scope.
    binding_stack: usize,
}
188
/// Position of a binding within the binding stack.
type BindingIndex = usize;

/// Position of a symbol name within the symbol table.
type SymbolIndex = usize;

/// Nesting depth of a scope, counted from the outermost open scope at depth 0.
pub type ScopeDepth = u16;
193
194/// Error that occurs when trying to resolve an unknown symbol.
195#[derive(Debug, Clone, Error)]
196#[error("symbol name `{0}` not found in this scope")]
197pub struct UnknownSymbolError<'a>(pub Cow<'a, str>);
198
199/// Error that occurs when trying to introduce a symbol that is already defined in the current scope.
200#[derive(Debug, Clone, Error)]
201#[error("symbol `{0}` is already defined in this scope")]
202pub struct DuplicateSymbolError<'a>(pub Cow<'a, str>, pub NodeId, pub NodeId);