use super::action::Action;
use super::error::OperationError;
use super::result::{ApplyResult, PreviewResult};
use crate::constraint::{ConstraintSet, Expression};
use crate::structure::{NodeId, NodeKind, Reference, StructureAst};
/// Editing engine that owns a structure AST together with its constraint set.
///
/// Actions are applied through [`AstEngine::apply`]; [`AstEngine::preview`] runs an
/// action against a clone so the engine itself is left untouched.
#[derive(Clone, Debug)]
pub struct AstEngine {
    /// The input-structure tree that actions edit.
    pub structure: StructureAst,
    /// Constraints attached to nodes of `structure`.
    pub constraints: ConstraintSet,
}
impl AstEngine {
    /// Creates an engine from freshly constructed `StructureAst::new()` and
    /// `ConstraintSet::new()` values.
    #[must_use]
    pub fn new() -> Self {
        let structure = StructureAst::new();
        let constraints = ConstraintSet::new();
        Self {
            structure,
            constraints,
        }
    }

    /// Applies `action` to this engine by dispatching to the handler for its variant.
    ///
    /// # Errors
    ///
    /// Returns whatever [`OperationError`] the selected handler reports.
    pub fn apply(&mut self, action: &Action) -> Result<ApplyResult, OperationError> {
        // Copy the id-like fields out directly and borrow everything else with `ref`.
        match *action {
            Action::FillHole { target, ref fill } => self.fill_hole(target, fill),
            Action::AddConstraint {
                target,
                ref constraint,
            } => self.add_constraint_op(target, constraint),
            Action::RemoveConstraint { constraint_id } => {
                self.remove_constraint_op(constraint_id)
            }
            Action::ReplaceNode {
                target,
                ref replacement,
            } => self.replace_node(target, replacement),
            Action::AddSlotElement {
                parent,
                ref slot_name,
                ref element,
            } => self.add_slot_element(parent, slot_name, element),
            Action::RemoveSlotElement {
                parent,
                ref slot_name,
                child,
            } => self.remove_slot_element(parent, slot_name, child),
            Action::IntroduceMultiTestCase {
                ref count_var_name,
                ref sum_bound,
            } => self.introduce_multi_test_case(count_var_name, sum_bound.as_ref()),
            Action::AddSibling {
                target,
                ref element,
            } => self.add_sibling(target, element),
            Action::AddChoiceVariant {
                choice,
                ref tag_value,
                ref first_element,
            } => self.add_choice_variant(choice, tag_value, first_element),
        }
    }

    /// Simulates `action` on a clone of this engine; `self` is never modified.
    ///
    /// The result lists:
    /// - the created nodes that are `NodeKind::Hole` in the simulated structure, and
    /// - the created constraints followed by the affected constraints (concatenated,
    ///   not de-duplicated).
    ///
    /// # Errors
    ///
    /// Propagates any [`OperationError`] raised by [`AstEngine::apply`] on the clone.
    pub fn preview(&self, action: &Action) -> Result<PreviewResult, OperationError> {
        let mut sandbox = self.clone();
        let outcome = sandbox.apply(action)?;

        let structure = &sandbox.structure;
        let new_holes_created = outcome
            .created_nodes
            .iter()
            .copied()
            .filter(|id| {
                matches!(
                    structure.get(*id).map(|node| node.kind()),
                    Some(NodeKind::Hole { .. })
                )
            })
            .collect();

        let mut constraints_affected = outcome.created_constraints;
        constraints_affected.extend(outcome.affected_constraints);

        Ok(PreviewResult {
            new_holes_created,
            constraints_affected,
        })
    }
}
impl Default for AstEngine {
fn default() -> Self {
Self::new()
}
}
impl AstEngine {
pub(crate) fn resolve_structure_references(&mut self, node_id: NodeId) {
let Some(node) = self.structure.get(node_id) else {
return;
};
let kind = node.kind().clone();
match kind {
NodeKind::Array { name, mut length } => {
Self::resolve_expr_refs(&self.structure, node_id, &mut length);
if let Some(n) = self.structure.get_mut(node_id) {
n.set_kind(NodeKind::Array { name, length });
}
}
NodeKind::Matrix {
name,
mut rows,
mut cols,
} => {
Self::resolve_ref(&self.structure, node_id, &mut rows);
Self::resolve_ref(&self.structure, node_id, &mut cols);
if let Some(n) = self.structure.get_mut(node_id) {
n.set_kind(NodeKind::Matrix { name, rows, cols });
}
}
NodeKind::Repeat {
mut count,
index_var,
body,
} => {
Self::resolve_expr_refs(&self.structure, node_id, &mut count);
if let Some(n) = self.structure.get_mut(node_id) {
n.set_kind(NodeKind::Repeat {
count,
index_var,
body,
});
}
}
_ => {}
}
}
fn resolve_ref(structure: &StructureAst, _owner: NodeId, reference: &mut Reference) {
if let Reference::Unresolved(name) = reference {
if let Some(target_id) = Self::find_node_by_name_static(structure, name.as_str()) {
*reference = Reference::VariableRef(target_id);
}
}
}
fn resolve_expr_refs(structure: &StructureAst, owner: NodeId, expr: &mut Expression) {
match expr {
Expression::Var(reference) => {
Self::resolve_ref(structure, owner, reference);
}
Expression::BinOp { lhs, rhs, .. } => {
Self::resolve_expr_refs(structure, owner, lhs);
Self::resolve_expr_refs(structure, owner, rhs);
}
Expression::Pow { base, exp } => {
Self::resolve_expr_refs(structure, owner, base);
Self::resolve_expr_refs(structure, owner, exp);
}
Expression::FnCall { args, .. } => {
for arg in args {
Self::resolve_expr_refs(structure, owner, arg);
}
}
Expression::Lit(_) => {}
}
}
fn find_node_by_name_static(structure: &StructureAst, name: &str) -> Option<NodeId> {
for node in structure.iter() {
let node_name = match node.kind() {
NodeKind::Scalar { name }
| NodeKind::Array { name, .. }
| NodeKind::Matrix { name, .. } => Some(name.as_str()),
_ => None,
};
if node_name == Some(name) {
return Some(node.id());
}
}
None
}
}