1use super::action::Action;
2use super::error::OperationError;
3use super::result::{ApplyResult, PreviewResult};
4use crate::constraint::{ConstraintSet, Expression};
5use crate::structure::{NodeId, NodeKind, Reference, StructureAst};
6
7#[derive(Debug, Clone)]
11pub struct AstEngine {
12 pub structure: StructureAst,
14 pub constraints: ConstraintSet,
16}
17
18impl AstEngine {
19 #[must_use]
21 pub fn new() -> Self {
22 Self {
23 structure: StructureAst::new(),
24 constraints: ConstraintSet::new(),
25 }
26 }
27
28 pub fn apply(&mut self, action: &Action) -> Result<ApplyResult, OperationError> {
33 match action {
34 Action::FillHole { target, fill } => self.fill_hole(*target, fill),
35 Action::AddConstraint { target, constraint } => {
36 self.add_constraint_op(*target, constraint)
37 }
38 Action::RemoveConstraint { constraint_id } => self.remove_constraint_op(*constraint_id),
39 Action::ReplaceNode {
40 target,
41 replacement,
42 } => self.replace_node(*target, replacement),
43 Action::AddSlotElement {
44 parent,
45 slot_name,
46 element,
47 } => self.add_slot_element(*parent, slot_name, element),
48 Action::RemoveSlotElement {
49 parent,
50 slot_name,
51 child,
52 } => self.remove_slot_element(*parent, slot_name, *child),
53 Action::IntroduceMultiTestCase {
54 count_var_name,
55 sum_bound,
56 } => self.introduce_multi_test_case(count_var_name, sum_bound.as_ref()),
57 Action::AddSibling { target, element } => self.add_sibling(*target, element),
58 Action::AddChoiceVariant {
59 choice,
60 tag_value,
61 first_element,
62 } => self.add_choice_variant(*choice, tag_value, first_element),
63 }
64 }
65
66 pub fn preview(&self, action: &Action) -> Result<PreviewResult, OperationError> {
75 let mut clone = self.clone();
76 let result = clone.apply(action)?;
77
78 let new_holes_created = result
80 .created_nodes
81 .iter()
82 .copied()
83 .filter(|&id| {
84 clone
85 .structure
86 .get(id)
87 .is_some_and(|n| matches!(n.kind(), NodeKind::Hole { .. }))
88 })
89 .collect();
90
91 let mut constraints_affected = result.created_constraints;
93 constraints_affected.extend(result.affected_constraints);
94
95 Ok(PreviewResult {
96 new_holes_created,
97 constraints_affected,
98 })
99 }
100}
101
102impl Default for AstEngine {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108impl AstEngine {
109 pub(crate) fn resolve_structure_references(&mut self, node_id: NodeId) {
114 let Some(node) = self.structure.get(node_id) else {
115 return;
116 };
117 let kind = node.kind().clone();
118 match kind {
119 NodeKind::Array { name, mut length } => {
120 Self::resolve_expr_refs(&self.structure, node_id, &mut length);
121 if let Some(n) = self.structure.get_mut(node_id) {
122 n.set_kind(NodeKind::Array { name, length });
123 }
124 }
125 NodeKind::Matrix {
126 name,
127 mut rows,
128 mut cols,
129 } => {
130 Self::resolve_ref(&self.structure, node_id, &mut rows);
131 Self::resolve_ref(&self.structure, node_id, &mut cols);
132 if let Some(n) = self.structure.get_mut(node_id) {
133 n.set_kind(NodeKind::Matrix { name, rows, cols });
134 }
135 }
136 NodeKind::Repeat {
137 mut count,
138 index_var,
139 body,
140 } => {
141 Self::resolve_expr_refs(&self.structure, node_id, &mut count);
142 if let Some(n) = self.structure.get_mut(node_id) {
143 n.set_kind(NodeKind::Repeat {
144 count,
145 index_var,
146 body,
147 });
148 }
149 }
150 _ => {}
151 }
152 }
153
154 fn resolve_ref(structure: &StructureAst, _owner: NodeId, reference: &mut Reference) {
156 if let Reference::Unresolved(name) = reference {
157 if let Some(target_id) = Self::find_node_by_name_static(structure, name.as_str()) {
158 *reference = Reference::VariableRef(target_id);
159 }
160 }
161 }
162
163 fn resolve_expr_refs(structure: &StructureAst, owner: NodeId, expr: &mut Expression) {
165 match expr {
166 Expression::Var(reference) => {
167 Self::resolve_ref(structure, owner, reference);
168 }
169 Expression::BinOp { lhs, rhs, .. } => {
170 Self::resolve_expr_refs(structure, owner, lhs);
171 Self::resolve_expr_refs(structure, owner, rhs);
172 }
173 Expression::Pow { base, exp } => {
174 Self::resolve_expr_refs(structure, owner, base);
175 Self::resolve_expr_refs(structure, owner, exp);
176 }
177 Expression::FnCall { args, .. } => {
178 for arg in args {
179 Self::resolve_expr_refs(structure, owner, arg);
180 }
181 }
182 Expression::Lit(_) => {}
183 }
184 }
185
186 fn find_node_by_name_static(structure: &StructureAst, name: &str) -> Option<NodeId> {
188 for node in structure.iter() {
189 let node_name = match node.kind() {
190 NodeKind::Scalar { name }
191 | NodeKind::Array { name, .. }
192 | NodeKind::Matrix { name, .. } => Some(name.as_str()),
193 _ => None,
194 };
195 if node_name == Some(name) {
196 return Some(node.id());
197 }
198 }
199 None
200 }
201}