1use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use shape_ast::ast::{FunctionDef, VarKind};
7use shape_value::ValueWord;
8
9#[derive(Debug, Clone)]
11pub struct Closure {
12 pub function: Arc<FunctionDef>,
14 pub captured_env: CapturedEnvironment,
16}
17
18impl PartialEq for Closure {
19 fn eq(&self, other: &Self) -> bool {
20 Arc::ptr_eq(&self.function, &other.function) && self.captured_env == other.captured_env
23 }
24}
25
/// A set of captured variable bindings with an optional parent environment.
///
/// Each environment holds its own `bindings` and may point to a `parent`
/// environment. `CapturedEnvironment::lookup` checks `bindings` first and
/// then falls back to the parent chain.
#[derive(Debug, Clone)]
pub struct CapturedEnvironment {
    /// Bindings captured directly into this environment, keyed by variable name.
    pub bindings: HashMap<String, CapturedBinding>,
    /// Environment consulted when a name is not found in `bindings`.
    pub parent: Option<Box<CapturedEnvironment>>,
}
34
35#[derive(Debug, Clone)]
37pub struct CapturedBinding {
38 pub value: ValueWord,
40 pub kind: VarKind,
42 pub is_mutable: bool,
44}
45
46impl PartialEq for CapturedBinding {
47 fn eq(&self, other: &Self) -> bool {
48 self.kind == other.kind
49 && self.is_mutable == other.is_mutable
50 && self.value.vw_equals(&other.value)
51 }
52}
53
54impl PartialEq for CapturedEnvironment {
55 fn eq(&self, other: &Self) -> bool {
56 self.bindings == other.bindings && self.parent == other.parent
57 }
58}
59
60impl Default for CapturedEnvironment {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl CapturedEnvironment {
67 pub fn new() -> Self {
69 Self {
70 bindings: HashMap::new(),
71 parent: None,
72 }
73 }
74
75 pub fn with_parent(parent: CapturedEnvironment) -> Self {
77 Self {
78 bindings: HashMap::new(),
79 parent: Some(Box::new(parent)),
80 }
81 }
82
83 pub fn capture(&mut self, name: String, value: ValueWord, kind: VarKind) {
85 let is_mutable = matches!(kind, VarKind::Var);
86 self.bindings.insert(
87 name,
88 CapturedBinding {
89 value,
90 kind,
91 is_mutable,
92 },
93 );
94 }
95
96 pub fn lookup(&self, name: &str) -> Option<&CapturedBinding> {
98 self.bindings
99 .get(name)
100 .or_else(|| self.parent.as_ref().and_then(|p| p.lookup(name)))
101 }
102
103 pub fn lookup_mut(&mut self, name: &str) -> Option<&mut CapturedBinding> {
105 if self.bindings.contains_key(name) {
106 self.bindings.get_mut(name)
107 } else if let Some(parent) = &mut self.parent {
108 parent.lookup_mut(name)
109 } else {
110 None
111 }
112 }
113
114 pub fn all_captured_names(&self) -> Vec<String> {
116 let mut names: Vec<String> = self.bindings.keys().cloned().collect();
117
118 if let Some(parent) = &self.parent {
119 for name in parent.all_captured_names() {
120 if !names.contains(&name) {
121 names.push(name);
122 }
123 }
124 }
125
126 names
127 }
128}
129
130pub struct EnvironmentAnalyzer {
132 scope_stack: Vec<HashMap<String, bool>>, captured_vars: HashMap<String, usize>, mutated_captures: HashSet<String>,
138 function_scope_level: usize,
143}
144
145impl Default for EnvironmentAnalyzer {
146 fn default() -> Self {
147 Self {
148 scope_stack: vec![HashMap::new()],
149 captured_vars: HashMap::new(),
150 mutated_captures: HashSet::new(),
151 function_scope_level: 1,
152 }
153 }
154}
155
156impl EnvironmentAnalyzer {
157 pub fn new() -> Self {
158 Self {
159 scope_stack: vec![HashMap::new()], captured_vars: HashMap::new(),
161 mutated_captures: HashSet::new(),
162 function_scope_level: 1, }
164 }
165
166 pub fn enter_scope(&mut self) {
168 self.scope_stack.push(HashMap::new());
169 }
170
171 pub fn exit_scope(&mut self) {
173 self.scope_stack.pop();
174 }
175
176 pub fn define_variable(&mut self, name: &str) {
178 if let Some(current_scope) = self.scope_stack.last_mut() {
179 current_scope.insert(name.to_string(), true);
180 }
181 }
182
183 pub fn check_variable_reference(&mut self, name: &str) {
189 for (level, scope) in self.scope_stack.iter().enumerate().rev() {
191 if scope.contains_key(name) {
192 if level < self.function_scope_level {
195 self.captured_vars.insert(name.to_string(), level);
196 }
197 return;
198 }
199 }
200 }
201
202 pub fn mark_capture_mutated(&mut self, name: &str) {
204 for (level, scope) in self.scope_stack.iter().enumerate().rev() {
206 if scope.contains_key(name) {
207 if level < self.function_scope_level {
208 self.captured_vars.insert(name.to_string(), level);
209 self.mutated_captures.insert(name.to_string());
210 }
211 return;
212 }
213 }
214 }
215
216 pub fn get_captured_vars(&self) -> Vec<String> {
218 self.captured_vars.keys().cloned().collect()
219 }
220
221 pub fn get_mutated_captures(&self) -> HashSet<String> {
223 self.mutated_captures.clone()
224 }
225
226 pub fn analyze_function(function: &FunctionDef, outer_scope_vars: &[String]) -> Vec<String> {
228 let mut analyzer = Self::new();
229
230 for var in outer_scope_vars {
232 analyzer.define_variable(var);
233 }
234
235 analyzer.enter_scope();
237 analyzer.function_scope_level = analyzer.scope_stack.len() - 1;
240
241 for param in &function.params {
243 for name in param.get_identifiers() {
244 analyzer.define_variable(&name);
245 }
246 }
247
248 for stmt in &function.body {
250 analyzer.analyze_statement(stmt);
251 }
252
253 analyzer.get_captured_vars()
254 }
255
256 pub fn analyze_function_with_mutability(
259 function: &FunctionDef,
260 outer_scope_vars: &[String],
261 ) -> (Vec<String>, HashSet<String>) {
262 let mut analyzer = Self::new();
263
264 for var in outer_scope_vars {
266 analyzer.define_variable(var);
267 }
268
269 analyzer.enter_scope();
271 analyzer.function_scope_level = analyzer.scope_stack.len() - 1;
272
273 for param in &function.params {
275 for name in param.get_identifiers() {
276 analyzer.define_variable(&name);
277 }
278 }
279
280 for stmt in &function.body {
282 analyzer.analyze_statement(stmt);
283 }
284
285 (
286 analyzer.get_captured_vars(),
287 analyzer.get_mutated_captures(),
288 )
289 }
290
291 fn analyze_statement(&mut self, stmt: &shape_ast::ast::Statement) {
293 use shape_ast::ast::Statement;
294
295 match stmt {
296 Statement::Return(expr, _) => {
297 if let Some(expr) = expr {
298 self.analyze_expr(expr);
299 }
300 }
301 Statement::Expression(expr, _) => {
302 self.analyze_expr(expr);
303 }
304 Statement::VariableDecl(decl, _) => {
305 if let Some(value) = &decl.value {
307 self.analyze_expr(value);
308 }
309 if let Some(name) = decl.pattern.as_identifier() {
311 self.define_variable(name);
312 } else {
313 for name in decl.pattern.get_identifiers() {
315 self.define_variable(&name);
316 }
317 }
318 }
319 Statement::Assignment(assign, _) => {
320 self.analyze_expr(&assign.value);
321 if let Some(name) = assign.pattern.as_identifier() {
322 self.mark_capture_mutated(name);
324 self.check_variable_reference(name);
325 } else {
326 for name in assign.pattern.get_identifiers() {
328 self.mark_capture_mutated(&name);
329 self.check_variable_reference(&name);
330 }
331 }
332 }
333 Statement::If(if_stmt, _) => {
334 self.analyze_expr(&if_stmt.condition);
335 self.enter_scope();
336 for stmt in &if_stmt.then_body {
337 self.analyze_statement(stmt);
338 }
339 self.exit_scope();
340
341 if let Some(else_body) = &if_stmt.else_body {
342 self.enter_scope();
343 for stmt in else_body {
344 self.analyze_statement(stmt);
345 }
346 self.exit_scope();
347 }
348 }
349 Statement::While(while_loop, _) => {
350 self.analyze_expr(&while_loop.condition);
351 self.enter_scope();
352 for stmt in &while_loop.body {
353 self.analyze_statement(stmt);
354 }
355 self.exit_scope();
356 }
357 Statement::For(for_loop, _) => {
358 self.enter_scope();
359
360 match &for_loop.init {
362 shape_ast::ast::ForInit::ForIn { pattern, iter } => {
363 self.analyze_expr(iter);
364 for name in pattern.get_identifiers() {
366 self.define_variable(&name);
367 }
368 }
369 shape_ast::ast::ForInit::ForC {
370 init: _,
371 condition,
372 update,
373 } => {
374 self.analyze_expr(condition);
377 self.analyze_expr(update);
378 }
379 }
380
381 for stmt in &for_loop.body {
382 self.analyze_statement(stmt);
383 }
384
385 self.exit_scope();
386 }
387 Statement::Break(_) | Statement::Continue(_) => {
388 }
390 Statement::Extend(ext, _) => {
391 for method in &ext.methods {
392 self.enter_scope();
393 self.define_variable("self");
394 for param in &method.params {
395 for name in param.get_identifiers() {
396 self.define_variable(&name);
397 }
398 }
399 for stmt in &method.body {
400 self.analyze_statement(stmt);
401 }
402 self.exit_scope();
403 }
404 }
405 Statement::RemoveTarget(_) => {}
406 Statement::SetParamType { .. }
407 | Statement::SetReturnType { .. }
408 | Statement::SetReturnExpr { .. } => {}
409 Statement::ReplaceModuleExpr { expression, .. } => {
410 self.analyze_expr(expression);
411 }
412 Statement::ReplaceBodyExpr { expression, .. } => {
413 self.analyze_expr(expression);
414 }
415 Statement::ReplaceBody { body, .. } => {
416 for stmt in body {
417 self.analyze_statement(stmt);
418 }
419 }
420 }
421 }
422
423 fn analyze_expr(&mut self, expr: &shape_ast::ast::Expr) {
425 use shape_ast::ast::Expr;
426
427 match expr {
428 Expr::Identifier(name, _) => {
429 self.check_variable_reference(name);
430 }
431 Expr::Literal(..)
432 | Expr::DataRef(..)
433 | Expr::DataDateTimeRef(..)
434 | Expr::TimeRef(..)
435 | Expr::PatternRef(..) => {
436 }
438 Expr::DataRelativeAccess {
439 reference,
440 index: _,
441 ..
442 } => {
443 self.analyze_expr(reference);
444 }
446 Expr::BinaryOp { left, right, .. } => {
447 self.analyze_expr(left);
448 self.analyze_expr(right);
449 }
450 Expr::FuzzyComparison { left, right, .. } => {
451 self.analyze_expr(left);
452 self.analyze_expr(right);
453 }
454 Expr::UnaryOp { operand, .. } => {
455 self.analyze_expr(operand);
456 }
457 Expr::FunctionCall { name, args, .. } => {
458 self.check_variable_reference(name);
461 for arg in args {
462 self.analyze_expr(arg);
463 }
464 }
465 Expr::EnumConstructor { payload, .. } => {
466 use shape_ast::ast::EnumConstructorPayload;
467 match payload {
468 EnumConstructorPayload::Unit => {}
469 EnumConstructorPayload::Tuple(values) => {
470 for value in values {
471 self.analyze_expr(value);
472 }
473 }
474 EnumConstructorPayload::Struct(fields) => {
475 for (_, value) in fields {
476 self.analyze_expr(value);
477 }
478 }
479 }
480 }
481 Expr::PropertyAccess { object, .. } => {
482 self.analyze_expr(object);
483 }
484 Expr::Conditional {
485 condition,
486 then_expr,
487 else_expr,
488 ..
489 } => {
490 self.analyze_expr(condition);
491 self.analyze_expr(then_expr);
492 if let Some(else_e) = else_expr {
493 self.analyze_expr(else_e);
494 }
495 }
496 Expr::Array(elements, _) => {
497 for elem in elements {
498 self.analyze_expr(elem);
499 }
500 }
501 Expr::TableRows(rows, _) => {
502 for row in rows {
503 for elem in row {
504 self.analyze_expr(elem);
505 }
506 }
507 }
508 Expr::ListComprehension(comp, _) => {
509 self.enter_scope();
511
512 for clause in &comp.clauses {
514 for name in clause.pattern.get_identifiers() {
516 self.define_variable(&name);
517 }
518
519 self.analyze_expr(&clause.iterable);
521
522 if let Some(filter) = &clause.filter {
524 self.analyze_expr(filter);
525 }
526 }
527
528 self.analyze_expr(&comp.element);
530
531 self.exit_scope();
532 }
533 Expr::Object(entries, _) => {
534 use shape_ast::ast::ObjectEntry;
535 for entry in entries {
536 match entry {
537 ObjectEntry::Field { value, .. } => self.analyze_expr(value),
538 ObjectEntry::Spread(spread_expr) => self.analyze_expr(spread_expr),
539 }
540 }
541 }
542 Expr::IndexAccess {
543 object,
544 index,
545 end_index,
546 ..
547 } => {
548 self.analyze_expr(object);
549 self.analyze_expr(index);
550 if let Some(end) = end_index {
551 self.analyze_expr(end);
552 }
553 }
554 Expr::Block(block, _) => {
555 self.enter_scope();
556 for item in &block.items {
557 match item {
558 shape_ast::ast::BlockItem::VariableDecl(decl) => {
559 if let Some(value) = &decl.value {
560 self.analyze_expr(value);
561 }
562 if let Some(name) = decl.pattern.as_identifier() {
563 self.define_variable(name);
564 }
565 }
566 shape_ast::ast::BlockItem::Assignment(assign) => {
567 self.analyze_expr(&assign.value);
568 if let Some(name) = assign.pattern.as_identifier() {
569 self.mark_capture_mutated(name);
570 self.check_variable_reference(name);
571 } else {
572 for name in assign.pattern.get_identifiers() {
573 self.mark_capture_mutated(&name);
574 self.check_variable_reference(&name);
575 }
576 }
577 }
578 shape_ast::ast::BlockItem::Statement(stmt) => {
579 self.analyze_statement(stmt);
580 }
581 shape_ast::ast::BlockItem::Expression(expr) => {
582 self.analyze_expr(expr);
583 }
584 }
585 }
586 self.exit_scope();
587 }
588 Expr::TypeAssertion { expr, .. } => {
589 self.analyze_expr(expr);
590 }
591 Expr::InstanceOf { expr, .. } => {
592 self.analyze_expr(expr);
593 }
594 Expr::FunctionExpr {
595 params,
596 return_type: _,
597 body,
598 ..
599 } => {
600 let saved_function_scope_level = self.function_scope_level;
602 self.enter_scope();
603 self.function_scope_level = self.scope_stack.len() - 1;
604
605 for param in params {
606 for name in param.get_identifiers() {
607 self.define_variable(&name);
608 }
609 }
610
611 for stmt in body {
612 self.analyze_statement(stmt);
613 }
614
615 self.exit_scope();
616 self.function_scope_level = saved_function_scope_level;
617
618 self.captured_vars
622 .retain(|_, level| *level < saved_function_scope_level);
623 self.mutated_captures
624 .retain(|name| self.captured_vars.contains_key(name));
625 }
626 Expr::Duration(..) => {
627 }
629
630 Expr::If(if_expr, _) => {
632 self.analyze_expr(&if_expr.condition);
633 self.analyze_expr(&if_expr.then_branch);
634 if let Some(else_branch) = &if_expr.else_branch {
635 self.analyze_expr(else_branch);
636 }
637 }
638
639 Expr::While(while_expr, _) => {
640 self.analyze_expr(&while_expr.condition);
641 self.analyze_expr(&while_expr.body);
642 }
643
644 Expr::For(for_expr, _) => {
645 self.enter_scope();
646 self.analyze_pattern(&for_expr.pattern);
648 self.analyze_expr(&for_expr.iterable);
649 self.analyze_expr(&for_expr.body);
650 self.exit_scope();
651 }
652
653 Expr::Loop(loop_expr, _) => {
654 self.analyze_expr(&loop_expr.body);
655 }
656
657 Expr::Let(let_expr, _) => {
658 if let Some(value) = &let_expr.value {
659 self.analyze_expr(value);
660 }
661 self.enter_scope();
662 self.analyze_pattern(&let_expr.pattern);
663 self.analyze_expr(&let_expr.body);
664 self.exit_scope();
665 }
666
667 Expr::Assign(assign, _) => {
668 self.analyze_expr(&assign.value);
669 self.analyze_expr(&assign.target);
670 }
671
672 Expr::Break(value, _) => {
673 if let Some(val) = value {
674 self.analyze_expr(val);
675 }
676 }
677
678 Expr::Continue(_) => {
679 }
681
682 Expr::Return(value, _) => {
683 if let Some(val) = value {
684 self.analyze_expr(val);
685 }
686 }
687
688 Expr::MethodCall { receiver, args, .. } => {
689 self.analyze_expr(receiver);
690 for arg in args {
691 self.analyze_expr(arg);
692 }
693 }
694
695 Expr::Match(match_expr, _) => {
696 self.analyze_expr(&match_expr.scrutinee);
697 for arm in &match_expr.arms {
698 self.enter_scope();
699 self.analyze_pattern(&arm.pattern);
700 if let Some(guard) = &arm.guard {
701 self.analyze_expr(guard);
702 }
703 self.analyze_expr(&arm.body);
704 self.exit_scope();
705 }
706 }
707
708 Expr::Unit(_) => {
709 }
711
712 Expr::Spread(inner_expr, _) => {
713 self.analyze_expr(inner_expr);
715 }
716
717 Expr::DateTime(..) => {
718 }
720 Expr::Range { start, end, .. } => {
721 if let Some(s) = start {
723 self.analyze_expr(s);
724 }
725 if let Some(e) = end {
726 self.analyze_expr(e);
727 }
728 }
729
730 Expr::TimeframeContext { expr, .. } => {
731 self.analyze_expr(expr);
733 }
734
735 Expr::TryOperator(inner, _) => {
736 self.analyze_expr(inner);
738 }
739 Expr::UsingImpl { expr, .. } => {
740 self.analyze_expr(expr);
741 }
742
743 Expr::Await(inner, _) => {
744 self.analyze_expr(inner);
746 }
747
748 Expr::SimulationCall { params, .. } => {
749 for (_, value_expr) in params {
752 self.analyze_expr(value_expr);
753 }
754 }
755
756 Expr::WindowExpr(_, _) => {
757 }
759
760 Expr::FromQuery(from_query, _) => {
761 self.analyze_expr(&from_query.source);
763 for clause in &from_query.clauses {
764 match clause {
765 shape_ast::ast::QueryClause::Where(pred) => {
766 self.analyze_expr(pred);
767 }
768 shape_ast::ast::QueryClause::OrderBy(specs) => {
769 for spec in specs {
770 self.analyze_expr(&spec.key);
771 }
772 }
773 shape_ast::ast::QueryClause::GroupBy { element, key, .. } => {
774 self.analyze_expr(element);
775 self.analyze_expr(key);
776 }
777 shape_ast::ast::QueryClause::Join {
778 source,
779 left_key,
780 right_key,
781 ..
782 } => {
783 self.analyze_expr(source);
784 self.analyze_expr(left_key);
785 self.analyze_expr(right_key);
786 }
787 shape_ast::ast::QueryClause::Let { value, .. } => {
788 self.analyze_expr(value);
789 }
790 }
791 }
792 self.analyze_expr(&from_query.select);
793 }
794 Expr::StructLiteral { fields, .. } => {
795 for (_, value_expr) in fields {
796 self.analyze_expr(value_expr);
797 }
798 }
799 Expr::Join(join_expr, _) => {
800 for branch in &join_expr.branches {
801 self.analyze_expr(&branch.expr);
802 }
803 }
804 Expr::Annotated { target, .. } => {
805 self.analyze_expr(target);
806 }
807 Expr::AsyncLet(async_let, _) => {
808 self.analyze_expr(&async_let.expr);
809 }
810 Expr::AsyncScope(inner, _) => {
811 self.analyze_expr(inner);
812 }
813 Expr::Comptime(stmts, _) => {
814 for stmt in stmts {
815 self.analyze_statement(stmt);
816 }
817 }
818 Expr::ComptimeFor(cf, _) => {
819 self.analyze_expr(&cf.iterable);
820 for stmt in &cf.body {
821 self.analyze_statement(stmt);
822 }
823 }
824 Expr::Reference { expr: inner, .. } => {
825 self.analyze_expr(inner);
826 }
827 }
828 }
829
830 fn analyze_pattern(&mut self, pattern: &shape_ast::ast::Pattern) {
832 use shape_ast::ast::Pattern;
833
834 match pattern {
835 Pattern::Identifier(name) => {
836 self.define_variable(name);
837 }
838 Pattern::Typed { name, .. } => {
839 self.define_variable(name);
840 }
841 Pattern::Wildcard | Pattern::Literal(_) => {
842 }
844 Pattern::Array(patterns) => {
845 for p in patterns {
846 self.analyze_pattern(p);
847 }
848 }
849 Pattern::Object(fields) => {
850 for (_, p) in fields {
851 self.analyze_pattern(p);
852 }
853 }
854 Pattern::Constructor { fields, .. } => match fields {
855 shape_ast::ast::PatternConstructorFields::Unit => {}
856 shape_ast::ast::PatternConstructorFields::Tuple(patterns) => {
857 for p in patterns {
858 self.analyze_pattern(p);
859 }
860 }
861 shape_ast::ast::PatternConstructorFields::Struct(fields) => {
862 for (_, p) in fields {
863 self.analyze_pattern(p);
864 }
865 }
866 },
867 }
868 }
869}