1use kyu_common::id::TableId;
7use kyu_common::{KyuError, KyuResult};
8use kyu_types::LogicalType;
9use smol_str::SmolStr;
10
/// Metadata recorded for one variable bound in a [`BinderScope`].
#[derive(Clone, Debug)]
pub struct VariableInfo {
    /// Index taken from the owning scope's running counter when the variable was created.
    pub index: u32,
    /// Logical type supplied when the variable was created.
    pub data_type: LogicalType,
    /// Optional table identifier supplied by the caller of `define`;
    /// `None` for variables created by `new_from_projection`.
    pub table_id: Option<TableId>,
    /// Variable name, stored lowercased.
    pub name: SmolStr,
}
20
/// A stack of variable frames plus a running counter used to number variables.
pub struct BinderScope {
    /// Frame stack; the last element is the innermost (current) frame.
    frames: Vec<ScopeFrame>,
    /// Index handed to the next variable that gets created. It is only ever
    /// incremented; popping or replacing a frame does not rewind it.
    next_index: u32,
}
29
/// One level of the scope stack.
struct ScopeFrame {
    /// `(lowercased name, info)` pairs in the order they were added.
    variables: Vec<(SmolStr, VariableInfo)>,
}
33
34impl Default for BinderScope {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl BinderScope {
41 pub fn new() -> Self {
42 Self {
43 frames: vec![ScopeFrame {
44 variables: Vec::new(),
45 }],
46 next_index: 0,
47 }
48 }
49
50 pub fn push_frame(&mut self) {
52 self.frames.push(ScopeFrame {
53 variables: Vec::new(),
54 });
55 }
56
57 pub fn pop_frame(&mut self) {
59 if self.frames.len() > 1 {
60 self.frames.pop();
61 }
62 }
63
64 pub fn define(
68 &mut self,
69 name: &str,
70 data_type: LogicalType,
71 table_id: Option<TableId>,
72 ) -> KyuResult<VariableInfo> {
73 let lower = SmolStr::new(name.to_lowercase());
74
75 let frame = self.frames.last().unwrap();
77 if frame.variables.iter().any(|(n, _)| *n == lower) {
78 return Err(KyuError::Binder(format!(
79 "variable '{name}' already defined in this scope"
80 )));
81 }
82
83 let index = self.next_index;
84 self.next_index += 1;
85
86 let info = VariableInfo {
87 index,
88 data_type,
89 table_id,
90 name: lower.clone(),
91 };
92
93 let frame = self.frames.last_mut().unwrap();
94 frame.variables.push((lower, info.clone()));
95
96 Ok(info)
97 }
98
99 pub fn resolve(&self, name: &str) -> Option<&VariableInfo> {
101 let lower = name.to_lowercase();
102 for frame in self.frames.iter().rev() {
103 for (n, info) in &frame.variables {
104 if n.as_str() == lower {
105 return Some(info);
106 }
107 }
108 }
109 None
110 }
111
112 pub fn num_variables(&self) -> u32 {
114 self.next_index
115 }
116
117 pub fn current_variables(&self) -> &[(SmolStr, VariableInfo)] {
119 &self.frames.last().unwrap().variables
120 }
121
122 pub fn new_from_projection(&mut self, projected: Vec<(SmolStr, LogicalType)>) {
127 if self.frames.len() > 1 {
129 self.frames.pop();
130 } else {
131 self.frames.last_mut().unwrap().variables.clear();
132 }
133
134 let mut new_frame = ScopeFrame {
136 variables: Vec::with_capacity(projected.len()),
137 };
138 for (name, data_type) in projected {
139 let lower = SmolStr::new(name.to_lowercase());
140 let index = self.next_index;
141 self.next_index += 1;
142 new_frame.variables.push((
143 lower.clone(),
144 VariableInfo {
145 index,
146 data_type,
147 table_id: None,
148 name: lower,
149 },
150 ));
151 }
152 self.frames.push(new_frame);
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// The first variable defined gets index 0 and can be found again by name.
    #[test]
    fn define_and_resolve() {
        let mut scope = BinderScope::new();

        let defined = scope.define("x", LogicalType::Int64, None).unwrap();
        assert_eq!(defined.index, 0);
        assert_eq!(defined.data_type, LogicalType::Int64);

        let looked_up = scope.resolve("x").unwrap();
        assert_eq!(looked_up.index, 0);
    }

    /// Lookups ignore the casing of both the defined and the requested name.
    #[test]
    fn case_insensitive_resolve() {
        let mut scope = BinderScope::new();
        scope
            .define("Person", LogicalType::Node, Some(TableId(1)))
            .unwrap();

        let spellings = vec!["person", "PERSON", "Person"];
        for spelling in spellings {
            assert!(scope.resolve(spelling).is_some());
        }
    }

    /// Defining the same name twice in one frame is rejected.
    #[test]
    fn duplicate_in_same_frame_errors() {
        let mut scope = BinderScope::new();
        scope.define("x", LogicalType::Int64, None).unwrap();

        let second = scope.define("x", LogicalType::String, None);
        assert!(second.is_err());
    }

    /// Indices are handed out in order, and the counter reflects the total.
    #[test]
    fn sequential_indices() {
        let mut scope = BinderScope::new();

        let indices: Vec<u32> = vec![("a", LogicalType::Int64), ("b", LogicalType::String)]
            .into_iter()
            .map(|(name, data_type)| scope.define(name, data_type, None).unwrap().index)
            .collect();

        assert_eq!(indices, vec![0, 1]);
        assert_eq!(scope.num_variables(), 2);
    }

    /// Variables of a pushed frame vanish when it is popped; outer ones stay visible.
    #[test]
    fn push_pop_frame() {
        let mut scope = BinderScope::new();
        scope.define("outer", LogicalType::Int64, None).unwrap();

        scope.push_frame();
        scope.define("inner", LogicalType::String, None).unwrap();

        // Inside the pushed frame both names are visible.
        for visible in vec!["outer", "inner"] {
            assert!(scope.resolve(visible).is_some());
        }

        scope.pop_frame();

        // After popping, only the outer name remains.
        assert!(scope.resolve("outer").is_some());
        assert!(scope.resolve("inner").is_none());
    }

    /// A name redefined in an inner frame hides the outer definition.
    #[test]
    fn inner_frame_shadows_outer() {
        let mut scope = BinderScope::new();
        scope.define("x", LogicalType::Int64, None).unwrap();

        scope.push_frame();
        scope.define("x", LogicalType::String, None).unwrap();

        let visible = scope.resolve("x").unwrap();
        assert_eq!(visible.data_type, LogicalType::String);
    }

    /// Unknown names resolve to `None`.
    #[test]
    fn resolve_not_found() {
        let scope = BinderScope::new();
        let missing = scope.resolve("nonexistent");
        assert!(missing.is_none());
    }

    /// The innermost frame lists its variables in definition order.
    #[test]
    fn current_variables() {
        let mut scope = BinderScope::new();
        scope.define("a", LogicalType::Int64, None).unwrap();
        scope.define("b", LogicalType::String, None).unwrap();

        let names: Vec<&str> = scope
            .current_variables()
            .iter()
            .map(|(name, _)| name.as_str())
            .collect();
        assert_eq!(names, vec!["a", "b"]);
    }

    /// A projection replaces the previously visible variables.
    #[test]
    fn new_from_projection() {
        let mut scope = BinderScope::new();
        scope.define("old_var", LogicalType::Int64, None).unwrap();

        let projected = vec![
            (SmolStr::new("name"), LogicalType::String),
            (SmolStr::new("age"), LogicalType::Int64),
        ];
        scope.new_from_projection(projected);

        // The pre-projection variable is gone.
        assert!(scope.resolve("old_var").is_none());

        // Only the projected columns are visible now.
        assert!(scope.resolve("name").is_some());
        assert!(scope.resolve("age").is_some());
        assert_eq!(scope.current_variables().len(), 2);
    }

    /// The table id passed to `define` is kept on both the returned and stored entry.
    #[test]
    fn table_id_preserved() {
        let mut scope = BinderScope::new();

        let defined = scope
            .define("p", LogicalType::Node, Some(TableId(42)))
            .unwrap();
        assert_eq!(defined.table_id, Some(TableId(42)));

        let stored = scope.resolve("p").unwrap();
        assert_eq!(stored.table_id, Some(TableId(42)));
    }
}