use super::engine::AstEngine;
use super::error::OperationError;
use super::fill_hole::parse_length_expr;
use super::result::ApplyResult;
use super::types::{ConstraintDef, ConstraintDefKind, VarType};
use crate::constraint::Expression;
use crate::constraint::{Constraint, ConstraintId, ExpectedType};
use crate::structure::{NodeId, NodeKind, Reference, StructureAst};
impl AstEngine {
    /// Attaches a new constraint, built from `constraint_def`, to the node `target`.
    ///
    /// The definition is converted into a [`Constraint`] whose subject is `target`.
    /// Name references inside its expressions are then resolved against the current
    /// structure, and the result is registered in the constraint set under `target`.
    ///
    /// On success, the returned [`ApplyResult`] lists the new id in
    /// `created_constraints`; every other list is empty.
    ///
    /// # Errors
    ///
    /// Returns [`OperationError::NodeNotFound`] when `target` is not part of the
    /// structure.
    pub(crate) fn add_constraint_op(
        &mut self,
        target: NodeId,
        constraint_def: &ConstraintDef,
    ) -> Result<ApplyResult, OperationError> {
        if !self.structure.contains(target) {
            return Err(OperationError::NodeNotFound { node: target });
        }

        // Build the constraint and bind its textual names to concrete nodes
        // before it is stored.
        let resolved = {
            let mut draft = convert_def_to_constraint(target, &constraint_def.kind);
            resolve_constraint_references(&self.structure, &mut draft);
            draft
        };
        let new_id = self.constraints.add(Some(target), resolved);

        Ok(ApplyResult {
            created_nodes: Vec::new(),
            removed_nodes: Vec::new(),
            created_constraints: vec![new_id],
            affected_constraints: Vec::new(),
        })
    }

    /// Deletes the constraint identified by `constraint_id`.
    ///
    /// On success, the returned [`ApplyResult`] reports the removed id in
    /// `affected_constraints`; every other list is empty.
    ///
    /// # Errors
    ///
    /// Returns [`OperationError::InvalidOperation`] with action `"RemoveConstraint"`
    /// when no constraint with that id exists.
    pub(crate) fn remove_constraint_op(
        &mut self,
        constraint_id: ConstraintId,
    ) -> Result<ApplyResult, OperationError> {
        // Existence check only; the looked-up value itself is not needed.
        let _ = self
            .constraints
            .get(constraint_id)
            .ok_or_else(|| OperationError::InvalidOperation {
                action: "RemoveConstraint".to_owned(),
                reason: format!("Constraint {constraint_id:?} not found"),
            })?;

        self.constraints.remove(constraint_id);

        Ok(ApplyResult {
            created_nodes: Vec::new(),
            removed_nodes: Vec::new(),
            created_constraints: Vec::new(),
            affected_constraints: vec![constraint_id],
        })
    }
}
/// Builds a [`Constraint`] from a [`ConstraintDefKind`], using the node `target`
/// as the constrained subject.
///
/// Textual bounds and operands are parsed with [`parse_expression`]. The parsed
/// expressions are returned as-is; name resolution is a separate step
/// (`add_constraint_op` runs `resolve_constraint_references` afterwards).
fn convert_def_to_constraint(target: NodeId, kind: &ConstraintDefKind) -> Constraint {
    let subject = Reference::VariableRef(target);

    match kind {
        ConstraintDefKind::Range { lower, upper } => {
            let lower = parse_expression(lower);
            let upper = parse_expression(upper);
            Constraint::Range {
                target: subject,
                lower,
                upper,
            }
        }
        ConstraintDefKind::TypeDecl { typ } => {
            let expected = match typ {
                VarType::Int => ExpectedType::Int,
                VarType::Str => ExpectedType::Str,
                VarType::Char => ExpectedType::Char,
            };
            Constraint::TypeDecl {
                target: subject,
                expected,
            }
        }
        ConstraintDefKind::Relation { op, rhs } => {
            // The target node always forms the left-hand side of the relation.
            let rhs = parse_expression(rhs);
            Constraint::Relation {
                lhs: Expression::Var(subject),
                op: *op,
                rhs,
            }
        }
        ConstraintDefKind::Distinct => Constraint::Distinct {
            elements: subject,
            unit: crate::constraint::DistinctUnit::Element,
        },
        ConstraintDefKind::Sorted { order } => Constraint::Sorted {
            elements: subject,
            order: *order,
        },
        ConstraintDefKind::Property { tag } => {
            let tag = crate::constraint::PropertyTag::Custom(tag.clone());
            Constraint::Property {
                target: subject,
                tag,
            }
        }
        ConstraintDefKind::SumBound { upper, .. } => {
            // NOTE(review): any remaining fields of the SumBound definition
            // (matched by `..`) are ignored, and the bound is always attached to
            // `target`. Confirm this is intended.
            let upper = parse_expression(upper);
            Constraint::SumBound {
                variable: subject,
                upper,
            }
        }
        ConstraintDefKind::CharSet { charset } => Constraint::CharSet {
            target: subject,
            charset: charset.clone(),
        },
        ConstraintDefKind::StringLength { min, max } => {
            let min = parse_expression(min);
            let max = parse_expression(max);
            Constraint::StringLength {
                target: subject,
                min,
                max,
            }
        }
        ConstraintDefKind::Guarantee { description } => Constraint::Guarantee {
            // The description is copied verbatim; `predicate` is left as `None`.
            description: description.clone(),
            predicate: None,
        },
    }
}
/// Parses the textual expression `s` into an [`Expression`].
///
/// This is a direct delegation to `parse_length_expr` (from `super::fill_hole`).
/// Constraint code in this module calls it for every textual bound or operand.
/// NOTE(review): how malformed input is handled depends entirely on
/// `parse_length_expr`; the return type is not a `Result`, so this function never
/// reports a parse error itself.
pub(super) fn parse_expression(s: &str) -> Expression {
    parse_length_expr(s)
}
fn resolve_constraint_references(structure: &StructureAst, constraint: &mut Constraint) {
match constraint {
Constraint::Range { lower, upper, .. } => {
resolve_expression_references(structure, lower);
resolve_expression_references(structure, upper);
}
Constraint::SumBound { upper, .. } => {
resolve_expression_references(structure, upper);
}
Constraint::Relation { lhs, rhs, .. } => {
resolve_expression_references(structure, lhs);
resolve_expression_references(structure, rhs);
}
Constraint::StringLength { min, max, .. } => {
resolve_expression_references(structure, min);
resolve_expression_references(structure, max);
}
_ => {}
}
}
/// Recursively replaces `Reference::Unresolved(name)` variables in `expr` with
/// `Reference::VariableRef` pointing at the node whose name matches.
///
/// The function descends through `BinOp` operands, `Pow` base/exponent and
/// `FnCall` arguments. Names with no matching node are left unresolved. Other
/// expression shapes are not inspected.
fn resolve_expression_references(structure: &StructureAst, expr: &mut Expression) {
    match expr {
        Expression::BinOp { lhs, rhs, .. } => {
            resolve_expression_references(structure, lhs);
            resolve_expression_references(structure, rhs);
        }
        Expression::Pow { base, exp } => {
            resolve_expression_references(structure, base);
            resolve_expression_references(structure, exp);
        }
        Expression::FnCall { args, .. } => {
            args.iter_mut()
                .for_each(|arg| resolve_expression_references(structure, arg));
        }
        Expression::Var(Reference::Unresolved(name)) => {
            let found = find_node_by_name(structure, name.as_str());
            if let Some(id) = found {
                *expr = Expression::Var(Reference::VariableRef(id));
            }
        }
        _ => {}
    }
}
/// Returns the id of the first `Scalar`, `Array` or `Matrix` node named `name`,
/// in `structure.iter()` order, or `None` if no such node exists.
///
/// Nodes of any other kind never match.
fn find_node_by_name(structure: &StructureAst, name: &str) -> Option<NodeId> {
    structure
        .iter()
        .find(|node| {
            matches!(
                node.kind(),
                NodeKind::Scalar { name: candidate }
                    | NodeKind::Array { name: candidate, .. }
                    | NodeKind::Matrix { name: candidate, .. }
                    if candidate.as_str() == name
            )
        })
        .map(|node| node.id())
}