1use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use shape_ast::ast::{FunctionDef, VarKind};
7use shape_value::ValueWord;
8
9#[derive(Debug, Clone)]
11pub struct Closure {
12 pub function: Arc<FunctionDef>,
14 pub captured_env: CapturedEnvironment,
16}
17
18impl PartialEq for Closure {
19 fn eq(&self, other: &Self) -> bool {
20 Arc::ptr_eq(&self.function, &other.function) && self.captured_env == other.captured_env
23 }
24}
25
26#[derive(Debug, Clone)]
28pub struct CapturedEnvironment {
29 pub bindings: HashMap<String, CapturedBinding>,
31 pub parent: Option<Box<CapturedEnvironment>>,
33}
34
35#[derive(Debug, Clone)]
37pub struct CapturedBinding {
38 pub value: ValueWord,
40 pub kind: VarKind,
42 pub is_mutable: bool,
44}
45
46impl PartialEq for CapturedBinding {
47 fn eq(&self, other: &Self) -> bool {
48 self.kind == other.kind
49 && self.is_mutable == other.is_mutable
50 && self.value.vw_equals(&other.value)
51 }
52}
53
54impl PartialEq for CapturedEnvironment {
55 fn eq(&self, other: &Self) -> bool {
56 self.bindings == other.bindings && self.parent == other.parent
57 }
58}
59
60impl Default for CapturedEnvironment {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl CapturedEnvironment {
67 pub fn new() -> Self {
69 Self {
70 bindings: HashMap::new(),
71 parent: None,
72 }
73 }
74
75 pub fn with_parent(parent: CapturedEnvironment) -> Self {
77 Self {
78 bindings: HashMap::new(),
79 parent: Some(Box::new(parent)),
80 }
81 }
82
83 pub fn capture(&mut self, name: String, value: ValueWord, kind: VarKind) {
85 let is_mutable = matches!(kind, VarKind::Var);
86 self.bindings.insert(
87 name,
88 CapturedBinding {
89 value,
90 kind,
91 is_mutable,
92 },
93 );
94 }
95
96 pub fn lookup(&self, name: &str) -> Option<&CapturedBinding> {
98 self.bindings
99 .get(name)
100 .or_else(|| self.parent.as_ref().and_then(|p| p.lookup(name)))
101 }
102
103 pub fn lookup_mut(&mut self, name: &str) -> Option<&mut CapturedBinding> {
105 if self.bindings.contains_key(name) {
106 self.bindings.get_mut(name)
107 } else if let Some(parent) = &mut self.parent {
108 parent.lookup_mut(name)
109 } else {
110 None
111 }
112 }
113
114 pub fn all_captured_names(&self) -> Vec<String> {
116 let mut names: Vec<String> = self.bindings.keys().cloned().collect();
117
118 if let Some(parent) = &self.parent {
119 for name in parent.all_captured_names() {
120 if !names.contains(&name) {
121 names.push(name);
122 }
123 }
124 }
125
126 names
127 }
128}
129
/// Scope-tracking state used to work out which variables a function
/// captures from enclosing scopes.
pub struct EnvironmentAnalyzer {
    /// Lexical scopes, outermost first; each maps a defined name to a flag
    /// (always inserted as `true` by `define_variable`).
    scope_stack: Vec<HashMap<String, bool>>,
    /// Captured names mapped to the scope index where each was found.
    captured_vars: HashMap<String, usize>,
    /// Captured names that are also assigned to.
    mutated_captures: HashSet<String>,
    /// Scope index at which the current function starts; names resolved at a
    /// lower index are treated as captures.
    function_scope_level: usize,
}
144
145impl Default for EnvironmentAnalyzer {
146 fn default() -> Self {
147 Self {
148 scope_stack: vec![HashMap::new()],
149 captured_vars: HashMap::new(),
150 mutated_captures: HashSet::new(),
151 function_scope_level: 1,
152 }
153 }
154}
155
156impl EnvironmentAnalyzer {
157 pub fn new() -> Self {
158 Self {
159 scope_stack: vec![HashMap::new()], captured_vars: HashMap::new(),
161 mutated_captures: HashSet::new(),
162 function_scope_level: 1, }
164 }
165
166 pub fn enter_scope(&mut self) {
168 self.scope_stack.push(HashMap::new());
169 }
170
171 pub fn exit_scope(&mut self) {
173 self.scope_stack.pop();
174 }
175
176 pub fn define_variable(&mut self, name: &str) {
178 if let Some(current_scope) = self.scope_stack.last_mut() {
179 current_scope.insert(name.to_string(), true);
180 }
181 }
182
183 pub fn check_variable_reference(&mut self, name: &str) {
189 for (level, scope) in self.scope_stack.iter().enumerate().rev() {
191 if scope.contains_key(name) {
192 if level < self.function_scope_level {
195 self.captured_vars.insert(name.to_string(), level);
196 }
197 return;
198 }
199 }
200 }
201
202 pub fn mark_capture_mutated(&mut self, name: &str) {
204 for (level, scope) in self.scope_stack.iter().enumerate().rev() {
206 if scope.contains_key(name) {
207 if level < self.function_scope_level {
208 self.captured_vars.insert(name.to_string(), level);
209 self.mutated_captures.insert(name.to_string());
210 }
211 return;
212 }
213 }
214 }
215
216 pub fn get_captured_vars(&self) -> Vec<String> {
218 self.captured_vars.keys().cloned().collect()
219 }
220
221 pub fn get_mutated_captures(&self) -> HashSet<String> {
223 self.mutated_captures.clone()
224 }
225
226 pub fn analyze_function(function: &FunctionDef, outer_scope_vars: &[String]) -> Vec<String> {
228 let mut analyzer = Self::new();
229
230 for var in outer_scope_vars {
232 analyzer.define_variable(var);
233 }
234
235 analyzer.enter_scope();
237 analyzer.function_scope_level = analyzer.scope_stack.len() - 1;
240
241 for param in &function.params {
243 for name in param.get_identifiers() {
244 analyzer.define_variable(&name);
245 }
246 }
247
248 for stmt in &function.body {
250 analyzer.analyze_statement(stmt);
251 }
252
253 analyzer.get_captured_vars()
254 }
255
256 pub fn analyze_function_with_mutability(
259 function: &FunctionDef,
260 outer_scope_vars: &[String],
261 ) -> (Vec<String>, HashSet<String>) {
262 let mut analyzer = Self::new();
263
264 for var in outer_scope_vars {
266 analyzer.define_variable(var);
267 }
268
269 analyzer.enter_scope();
271 analyzer.function_scope_level = analyzer.scope_stack.len() - 1;
272
273 for param in &function.params {
275 for name in param.get_identifiers() {
276 analyzer.define_variable(&name);
277 }
278 }
279
280 for stmt in &function.body {
282 analyzer.analyze_statement(stmt);
283 }
284
285 (
286 analyzer.get_captured_vars(),
287 analyzer.get_mutated_captures(),
288 )
289 }
290
291 fn analyze_statement(&mut self, stmt: &shape_ast::ast::Statement) {
293 use shape_ast::ast::Statement;
294
295 match stmt {
296 Statement::Return(expr, _) => {
297 if let Some(expr) = expr {
298 self.analyze_expr(expr);
299 }
300 }
301 Statement::Expression(expr, _) => {
302 self.analyze_expr(expr);
303 }
304 Statement::VariableDecl(decl, _) => {
305 if let Some(value) = &decl.value {
307 self.analyze_expr(value);
308 }
309 if let Some(name) = decl.pattern.as_identifier() {
311 self.define_variable(name);
312 } else {
313 for name in decl.pattern.get_identifiers() {
315 self.define_variable(&name);
316 }
317 }
318 }
319 Statement::Assignment(assign, _) => {
320 self.analyze_expr(&assign.value);
321 if let Some(name) = assign.pattern.as_identifier() {
322 self.mark_capture_mutated(name);
324 self.check_variable_reference(name);
325 } else {
326 for name in assign.pattern.get_identifiers() {
328 self.mark_capture_mutated(&name);
329 self.check_variable_reference(&name);
330 }
331 }
332 }
333 Statement::If(if_stmt, _) => {
334 self.analyze_expr(&if_stmt.condition);
335 self.enter_scope();
336 for stmt in &if_stmt.then_body {
337 self.analyze_statement(stmt);
338 }
339 self.exit_scope();
340
341 if let Some(else_body) = &if_stmt.else_body {
342 self.enter_scope();
343 for stmt in else_body {
344 self.analyze_statement(stmt);
345 }
346 self.exit_scope();
347 }
348 }
349 Statement::While(while_loop, _) => {
350 self.analyze_expr(&while_loop.condition);
351 self.enter_scope();
352 for stmt in &while_loop.body {
353 self.analyze_statement(stmt);
354 }
355 self.exit_scope();
356 }
357 Statement::For(for_loop, _) => {
358 self.enter_scope();
359
360 match &for_loop.init {
362 shape_ast::ast::ForInit::ForIn { pattern, iter } => {
363 self.analyze_expr(iter);
364 for name in pattern.get_identifiers() {
366 self.define_variable(&name);
367 }
368 }
369 shape_ast::ast::ForInit::ForC {
370 init: _,
371 condition,
372 update,
373 } => {
374 self.analyze_expr(condition);
377 self.analyze_expr(update);
378 }
379 }
380
381 for stmt in &for_loop.body {
382 self.analyze_statement(stmt);
383 }
384
385 self.exit_scope();
386 }
387 Statement::Break(_) | Statement::Continue(_) => {
388 }
390 Statement::Extend(ext, _) => {
391 for method in &ext.methods {
392 self.enter_scope();
393 self.define_variable("self");
394 for param in &method.params {
395 for name in param.get_identifiers() {
396 self.define_variable(&name);
397 }
398 }
399 for stmt in &method.body {
400 self.analyze_statement(stmt);
401 }
402 self.exit_scope();
403 }
404 }
405 Statement::RemoveTarget(_) => {}
406 Statement::SetParamType { .. }
407 | Statement::SetReturnType { .. }
408 | Statement::SetReturnExpr { .. } => {}
409 Statement::SetParamValue { expression, .. } => {
410 self.analyze_expr(expression);
411 }
412 Statement::ReplaceModuleExpr { expression, .. } => {
413 self.analyze_expr(expression);
414 }
415 Statement::ReplaceBodyExpr { expression, .. } => {
416 self.analyze_expr(expression);
417 }
418 Statement::ReplaceBody { body, .. } => {
419 for stmt in body {
420 self.analyze_statement(stmt);
421 }
422 }
423 }
424 }
425
426 fn analyze_expr(&mut self, expr: &shape_ast::ast::Expr) {
428 use shape_ast::ast::Expr;
429
430 match expr {
431 Expr::Identifier(name, _) => {
432 self.check_variable_reference(name);
433 }
434 Expr::Literal(..)
435 | Expr::DataRef(..)
436 | Expr::DataDateTimeRef(..)
437 | Expr::TimeRef(..)
438 | Expr::PatternRef(..) => {
439 }
441 Expr::DataRelativeAccess {
442 reference,
443 index: _,
444 ..
445 } => {
446 self.analyze_expr(reference);
447 }
449 Expr::BinaryOp { left, right, .. } => {
450 self.analyze_expr(left);
451 self.analyze_expr(right);
452 }
453 Expr::FuzzyComparison { left, right, .. } => {
454 self.analyze_expr(left);
455 self.analyze_expr(right);
456 }
457 Expr::UnaryOp { operand, .. } => {
458 self.analyze_expr(operand);
459 }
460 Expr::FunctionCall { name, args, .. } => {
461 self.check_variable_reference(name);
464 for arg in args {
465 self.analyze_expr(arg);
466 }
467 }
468 Expr::QualifiedFunctionCall {
469 namespace,
470 args,
471 ..
472 } => {
473 self.check_variable_reference(namespace);
474 for arg in args {
475 self.analyze_expr(arg);
476 }
477 }
478 Expr::EnumConstructor { payload, .. } => {
479 use shape_ast::ast::EnumConstructorPayload;
480 match payload {
481 EnumConstructorPayload::Unit => {}
482 EnumConstructorPayload::Tuple(values) => {
483 for value in values {
484 self.analyze_expr(value);
485 }
486 }
487 EnumConstructorPayload::Struct(fields) => {
488 for (_, value) in fields {
489 self.analyze_expr(value);
490 }
491 }
492 }
493 }
494 Expr::PropertyAccess { object, .. } => {
495 self.analyze_expr(object);
496 }
497 Expr::Conditional {
498 condition,
499 then_expr,
500 else_expr,
501 ..
502 } => {
503 self.analyze_expr(condition);
504 self.analyze_expr(then_expr);
505 if let Some(else_e) = else_expr {
506 self.analyze_expr(else_e);
507 }
508 }
509 Expr::Array(elements, _) => {
510 for elem in elements {
511 self.analyze_expr(elem);
512 }
513 }
514 Expr::TableRows(rows, _) => {
515 for row in rows {
516 for elem in row {
517 self.analyze_expr(elem);
518 }
519 }
520 }
521 Expr::ListComprehension(comp, _) => {
522 self.enter_scope();
524
525 for clause in &comp.clauses {
527 for name in clause.pattern.get_identifiers() {
529 self.define_variable(&name);
530 }
531
532 self.analyze_expr(&clause.iterable);
534
535 if let Some(filter) = &clause.filter {
537 self.analyze_expr(filter);
538 }
539 }
540
541 self.analyze_expr(&comp.element);
543
544 self.exit_scope();
545 }
546 Expr::Object(entries, _) => {
547 use shape_ast::ast::ObjectEntry;
548 for entry in entries {
549 match entry {
550 ObjectEntry::Field { value, .. } => self.analyze_expr(value),
551 ObjectEntry::Spread(spread_expr) => self.analyze_expr(spread_expr),
552 }
553 }
554 }
555 Expr::IndexAccess {
556 object,
557 index,
558 end_index,
559 ..
560 } => {
561 self.analyze_expr(object);
562 self.analyze_expr(index);
563 if let Some(end) = end_index {
564 self.analyze_expr(end);
565 }
566 }
567 Expr::Block(block, _) => {
568 self.enter_scope();
569 for item in &block.items {
570 match item {
571 shape_ast::ast::BlockItem::VariableDecl(decl) => {
572 if let Some(value) = &decl.value {
573 self.analyze_expr(value);
574 }
575 if let Some(name) = decl.pattern.as_identifier() {
576 self.define_variable(name);
577 }
578 }
579 shape_ast::ast::BlockItem::Assignment(assign) => {
580 self.analyze_expr(&assign.value);
581 if let Some(name) = assign.pattern.as_identifier() {
582 self.mark_capture_mutated(name);
583 self.check_variable_reference(name);
584 } else {
585 for name in assign.pattern.get_identifiers() {
586 self.mark_capture_mutated(&name);
587 self.check_variable_reference(&name);
588 }
589 }
590 }
591 shape_ast::ast::BlockItem::Statement(stmt) => {
592 self.analyze_statement(stmt);
593 }
594 shape_ast::ast::BlockItem::Expression(expr) => {
595 self.analyze_expr(expr);
596 }
597 }
598 }
599 self.exit_scope();
600 }
601 Expr::TypeAssertion { expr, .. } => {
602 self.analyze_expr(expr);
603 }
604 Expr::InstanceOf { expr, .. } => {
605 self.analyze_expr(expr);
606 }
607 Expr::FunctionExpr {
608 params,
609 return_type: _,
610 body,
611 ..
612 } => {
613 let saved_function_scope_level = self.function_scope_level;
615 self.enter_scope();
616 self.function_scope_level = self.scope_stack.len() - 1;
617
618 for param in params {
619 for name in param.get_identifiers() {
620 self.define_variable(&name);
621 }
622 }
623
624 for stmt in body {
625 self.analyze_statement(stmt);
626 }
627
628 self.exit_scope();
629 self.function_scope_level = saved_function_scope_level;
630
631 self.captured_vars
635 .retain(|_, level| *level < saved_function_scope_level);
636 self.mutated_captures
637 .retain(|name| self.captured_vars.contains_key(name));
638 }
639 Expr::Duration(..) => {
640 }
642
643 Expr::If(if_expr, _) => {
645 self.analyze_expr(&if_expr.condition);
646 self.analyze_expr(&if_expr.then_branch);
647 if let Some(else_branch) = &if_expr.else_branch {
648 self.analyze_expr(else_branch);
649 }
650 }
651
652 Expr::While(while_expr, _) => {
653 self.analyze_expr(&while_expr.condition);
654 self.analyze_expr(&while_expr.body);
655 }
656
657 Expr::For(for_expr, _) => {
658 self.enter_scope();
659 self.analyze_pattern(&for_expr.pattern);
661 self.analyze_expr(&for_expr.iterable);
662 self.analyze_expr(&for_expr.body);
663 self.exit_scope();
664 }
665
666 Expr::Loop(loop_expr, _) => {
667 self.analyze_expr(&loop_expr.body);
668 }
669
670 Expr::Let(let_expr, _) => {
671 if let Some(value) = &let_expr.value {
672 self.analyze_expr(value);
673 }
674 self.enter_scope();
675 self.analyze_pattern(&let_expr.pattern);
676 self.analyze_expr(&let_expr.body);
677 self.exit_scope();
678 }
679
680 Expr::Assign(assign, _) => {
681 self.analyze_expr(&assign.value);
682 self.analyze_expr(&assign.target);
683 }
684
685 Expr::Break(value, _) => {
686 if let Some(val) = value {
687 self.analyze_expr(val);
688 }
689 }
690
691 Expr::Continue(_) => {
692 }
694
695 Expr::Return(value, _) => {
696 if let Some(val) = value {
697 self.analyze_expr(val);
698 }
699 }
700
701 Expr::MethodCall { receiver, args, .. } => {
702 self.analyze_expr(receiver);
703 for arg in args {
704 self.analyze_expr(arg);
705 }
706 }
707
708 Expr::Match(match_expr, _) => {
709 self.analyze_expr(&match_expr.scrutinee);
710 for arm in &match_expr.arms {
711 self.enter_scope();
712 self.analyze_pattern(&arm.pattern);
713 if let Some(guard) = &arm.guard {
714 self.analyze_expr(guard);
715 }
716 self.analyze_expr(&arm.body);
717 self.exit_scope();
718 }
719 }
720
721 Expr::Unit(_) => {
722 }
724
725 Expr::Spread(inner_expr, _) => {
726 self.analyze_expr(inner_expr);
728 }
729
730 Expr::DateTime(..) => {
731 }
733 Expr::Range { start, end, .. } => {
734 if let Some(s) = start {
736 self.analyze_expr(s);
737 }
738 if let Some(e) = end {
739 self.analyze_expr(e);
740 }
741 }
742
743 Expr::TimeframeContext { expr, .. } => {
744 self.analyze_expr(expr);
746 }
747
748 Expr::TryOperator(inner, _) => {
749 self.analyze_expr(inner);
751 }
752 Expr::UsingImpl { expr, .. } => {
753 self.analyze_expr(expr);
754 }
755
756 Expr::Await(inner, _) => {
757 self.analyze_expr(inner);
759 }
760
761 Expr::SimulationCall { params, .. } => {
762 for (_, value_expr) in params {
765 self.analyze_expr(value_expr);
766 }
767 }
768
769 Expr::WindowExpr(_, _) => {
770 }
772
773 Expr::FromQuery(from_query, _) => {
774 self.analyze_expr(&from_query.source);
776 for clause in &from_query.clauses {
777 match clause {
778 shape_ast::ast::QueryClause::Where(pred) => {
779 self.analyze_expr(pred);
780 }
781 shape_ast::ast::QueryClause::OrderBy(specs) => {
782 for spec in specs {
783 self.analyze_expr(&spec.key);
784 }
785 }
786 shape_ast::ast::QueryClause::GroupBy { element, key, .. } => {
787 self.analyze_expr(element);
788 self.analyze_expr(key);
789 }
790 shape_ast::ast::QueryClause::Join {
791 source,
792 left_key,
793 right_key,
794 ..
795 } => {
796 self.analyze_expr(source);
797 self.analyze_expr(left_key);
798 self.analyze_expr(right_key);
799 }
800 shape_ast::ast::QueryClause::Let { value, .. } => {
801 self.analyze_expr(value);
802 }
803 }
804 }
805 self.analyze_expr(&from_query.select);
806 }
807 Expr::StructLiteral { fields, .. } => {
808 for (_, value_expr) in fields {
809 self.analyze_expr(value_expr);
810 }
811 }
812 Expr::Join(join_expr, _) => {
813 for branch in &join_expr.branches {
814 self.analyze_expr(&branch.expr);
815 }
816 }
817 Expr::Annotated { target, .. } => {
818 self.analyze_expr(target);
819 }
820 Expr::AsyncLet(async_let, _) => {
821 self.analyze_expr(&async_let.expr);
822 }
823 Expr::AsyncScope(inner, _) => {
824 self.analyze_expr(inner);
825 }
826 Expr::Comptime(stmts, _) => {
827 for stmt in stmts {
828 self.analyze_statement(stmt);
829 }
830 }
831 Expr::ComptimeFor(cf, _) => {
832 self.analyze_expr(&cf.iterable);
833 for stmt in &cf.body {
834 self.analyze_statement(stmt);
835 }
836 }
837 Expr::Reference { expr: inner, .. } => {
838 self.analyze_expr(inner);
839 }
840 }
841 }
842
843 fn analyze_pattern(&mut self, pattern: &shape_ast::ast::Pattern) {
845 use shape_ast::ast::Pattern;
846
847 match pattern {
848 Pattern::Identifier(name) => {
849 self.define_variable(name);
850 }
851 Pattern::Typed { name, .. } => {
852 self.define_variable(name);
853 }
854 Pattern::Wildcard | Pattern::Literal(_) => {
855 }
857 Pattern::Array(patterns) => {
858 for p in patterns {
859 self.analyze_pattern(p);
860 }
861 }
862 Pattern::Object(fields) => {
863 for (_, p) in fields {
864 self.analyze_pattern(p);
865 }
866 }
867 Pattern::Constructor { fields, .. } => match fields {
868 shape_ast::ast::PatternConstructorFields::Unit => {}
869 shape_ast::ast::PatternConstructorFields::Tuple(patterns) => {
870 for p in patterns {
871 self.analyze_pattern(p);
872 }
873 }
874 shape_ast::ast::PatternConstructorFields::Struct(fields) => {
875 for (_, p) in fields {
876 self.analyze_pattern(p);
877 }
878 }
879 },
880 }
881 }
882}