// cp_ast_core/operation/engine.rs

1use super::action::Action;
2use super::error::OperationError;
3use super::result::{ApplyResult, PreviewResult};
4use crate::constraint::{ConstraintSet, Expression};
5use crate::structure::{NodeId, NodeKind, Reference, StructureAst};
6
7/// The main AST engine that owns both Structure and Constraint data.
8///
9/// Provides `apply()` to execute actions and `preview()` to dry-run them.
10#[derive(Debug, Clone)]
11pub struct AstEngine {
12    /// The structure AST.
13    pub structure: StructureAst,
14    /// The constraint set.
15    pub constraints: ConstraintSet,
16}
17
18impl AstEngine {
19    /// Create a new engine with empty structure and constraints.
20    #[must_use]
21    pub fn new() -> Self {
22        Self {
23            structure: StructureAst::new(),
24            constraints: ConstraintSet::new(),
25        }
26    }
27
28    /// Apply an action to the AST, returning the result or an error.
29    ///
30    /// # Errors
31    /// Returns `OperationError` if the action cannot be applied.
32    pub fn apply(&mut self, action: &Action) -> Result<ApplyResult, OperationError> {
33        match action {
34            Action::FillHole { target, fill } => self.fill_hole(*target, fill),
35            Action::AddConstraint { target, constraint } => {
36                self.add_constraint_op(*target, constraint)
37            }
38            Action::RemoveConstraint { constraint_id } => self.remove_constraint_op(*constraint_id),
39            Action::ReplaceNode {
40                target,
41                replacement,
42            } => self.replace_node(*target, replacement),
43            Action::AddSlotElement {
44                parent,
45                slot_name,
46                element,
47            } => self.add_slot_element(*parent, slot_name, element),
48            Action::RemoveSlotElement {
49                parent,
50                slot_name,
51                child,
52            } => self.remove_slot_element(*parent, slot_name, *child),
53            Action::IntroduceMultiTestCase {
54                count_var_name,
55                sum_bound,
56            } => self.introduce_multi_test_case(count_var_name, sum_bound.as_ref()),
57            Action::AddSibling { target, element } => self.add_sibling(*target, element),
58            Action::AddChoiceVariant {
59                choice,
60                tag_value,
61                first_element,
62            } => self.add_choice_variant(*choice, tag_value, first_element),
63        }
64    }
65
66    /// Preview an action without applying it (dry-run).
67    ///
68    /// Clones `self`, applies the action on the clone, and derives what
69    /// *would* happen — new holes created and constraints affected —
70    /// without mutating the original engine.
71    ///
72    /// # Errors
73    /// Returns `OperationError` if the action is invalid.
74    pub fn preview(&self, action: &Action) -> Result<PreviewResult, OperationError> {
75        let mut clone = self.clone();
76        let result = clone.apply(action)?;
77
78        // Holes created: nodes that were created AND are Hole kind in the clone.
79        let new_holes_created = result
80            .created_nodes
81            .iter()
82            .copied()
83            .filter(|&id| {
84                clone
85                    .structure
86                    .get(id)
87                    .is_some_and(|n| matches!(n.kind(), NodeKind::Hole { .. }))
88            })
89            .collect();
90
91        // Constraints affected: union of created + affected from ApplyResult.
92        let mut constraints_affected = result.created_constraints;
93        constraints_affected.extend(result.affected_constraints);
94
95        Ok(PreviewResult {
96            new_holes_created,
97            constraints_affected,
98        })
99    }
100}
101
102impl Default for AstEngine {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108impl AstEngine {
109    /// Resolve `Unresolved` variable name references in a structure node's expressions.
110    ///
111    /// Looks up names like "N" in the structure and replaces them with `VariableRef(node_id)`.
112    /// Handles expressions in Array length, Repeat count, and references in Matrix rows/cols.
113    pub(crate) fn resolve_structure_references(&mut self, node_id: NodeId) {
114        let Some(node) = self.structure.get(node_id) else {
115            return;
116        };
117        let kind = node.kind().clone();
118        match kind {
119            NodeKind::Array { name, mut length } => {
120                Self::resolve_expr_refs(&self.structure, node_id, &mut length);
121                if let Some(n) = self.structure.get_mut(node_id) {
122                    n.set_kind(NodeKind::Array { name, length });
123                }
124            }
125            NodeKind::Matrix {
126                name,
127                mut rows,
128                mut cols,
129            } => {
130                Self::resolve_ref(&self.structure, node_id, &mut rows);
131                Self::resolve_ref(&self.structure, node_id, &mut cols);
132                if let Some(n) = self.structure.get_mut(node_id) {
133                    n.set_kind(NodeKind::Matrix { name, rows, cols });
134                }
135            }
136            NodeKind::Repeat {
137                mut count,
138                index_var,
139                body,
140            } => {
141                Self::resolve_expr_refs(&self.structure, node_id, &mut count);
142                if let Some(n) = self.structure.get_mut(node_id) {
143                    n.set_kind(NodeKind::Repeat {
144                        count,
145                        index_var,
146                        body,
147                    });
148                }
149            }
150            _ => {}
151        }
152    }
153
154    /// Resolve Unresolved names in a `Reference` against the structure.
155    fn resolve_ref(structure: &StructureAst, _owner: NodeId, reference: &mut Reference) {
156        if let Reference::Unresolved(name) = reference {
157            if let Some(target_id) = Self::find_node_by_name_static(structure, name.as_str()) {
158                *reference = Reference::VariableRef(target_id);
159            }
160        }
161    }
162
163    /// Resolve Unresolved names in an `Expression` against the structure.
164    fn resolve_expr_refs(structure: &StructureAst, owner: NodeId, expr: &mut Expression) {
165        match expr {
166            Expression::Var(reference) => {
167                Self::resolve_ref(structure, owner, reference);
168            }
169            Expression::BinOp { lhs, rhs, .. } => {
170                Self::resolve_expr_refs(structure, owner, lhs);
171                Self::resolve_expr_refs(structure, owner, rhs);
172            }
173            Expression::Pow { base, exp } => {
174                Self::resolve_expr_refs(structure, owner, base);
175                Self::resolve_expr_refs(structure, owner, exp);
176            }
177            Expression::FnCall { args, .. } => {
178                for arg in args {
179                    Self::resolve_expr_refs(structure, owner, arg);
180                }
181            }
182            Expression::Lit(_) => {}
183        }
184    }
185
186    /// Find a structure node by its variable name.
187    fn find_node_by_name_static(structure: &StructureAst, name: &str) -> Option<NodeId> {
188        for node in structure.iter() {
189            let node_name = match node.kind() {
190                NodeKind::Scalar { name }
191                | NodeKind::Array { name, .. }
192                | NodeKind::Matrix { name, .. } => Some(name.as_str()),
193                _ => None,
194            };
195            if node_name == Some(name) {
196                return Some(node.id());
197            }
198        }
199        None
200    }
201}