use crate::planning::semantics::{
ArithmeticComputation, ComparisonComputation, DataPath, Expression, ExpressionKind,
LiteralValue, SemanticConversionTarget, ValueKind,
};
use crate::{Error, OperationResult};
use serde::ser::{Serialize, SerializeStruct, Serializer};
use std::fmt;
use std::sync::Arc;
/// A boolean condition over data values, as produced from rule expressions.
///
/// Leaves are constants, single-data comparisons against a literal, or a data path used
/// directly as a boolean; inner nodes are `and`, `or` and `not`.
#[derive(Clone, Debug, PartialEq)]
pub enum Constraint {
    /// Always satisfied.
    True,
    /// Never satisfied.
    False,
    /// `data op value`: one data path compared against a literal value.
    Comparison {
        data: DataPath,
        op: ComparisonComputation,
        value: Arc<LiteralValue>,
    },
    /// A data path used directly as a boolean condition.
    Data(DataPath),
    /// Both sub-constraints hold.
    And(Box<Constraint>, Box<Constraint>),
    /// At least one sub-constraint holds.
    Or(Box<Constraint>, Box<Constraint>),
    /// The sub-constraint does not hold.
    Not(Box<Constraint>),
}
impl Constraint {
pub fn is_true(&self) -> bool {
matches!(self, Constraint::True)
}
pub fn is_false(&self) -> bool {
matches!(self, Constraint::False)
}
pub fn and(self, other: Constraint) -> Constraint {
if self.is_false() || other.is_false() {
return Constraint::False;
}
if self.is_true() {
return other;
}
if other.is_true() {
return self;
}
Constraint::And(Box::new(self), Box::new(other))
}
pub fn or(self, other: Constraint) -> Constraint {
if self.is_true() || other.is_true() {
return Constraint::True;
}
if self.is_false() {
return other;
}
if other.is_false() {
return self;
}
Constraint::Or(Box::new(self), Box::new(other))
}
pub fn not(self) -> Constraint {
match self {
Constraint::True => Constraint::False,
Constraint::False => Constraint::True,
Constraint::Not(inner) => *inner,
other => Constraint::Not(Box::new(other)),
}
}
pub fn simplify(self) -> Result<Constraint, Error> {
let mut atoms: Vec<Constraint> = Vec::new();
if let Some(bexpr) = to_bool_expr(&self, &mut atoms) {
const MAX_ATOMS: usize = 64;
if atoms.len() <= MAX_ATOMS {
let theory = build_numeric_theory_closure(&atoms)?;
let combined = boolean_expression::Expr::and(bexpr, theory);
let simplified = combined.simplify_via_bdd();
return Ok(from_bool_expr(&simplified, &atoms));
}
}
Ok(self)
}
pub fn from_expression(expr: &Expression) -> Result<Constraint, Error> {
use ExpressionKind;
enum WorkItem {
Process(usize),
BuildAnd,
ApplyNot,
}
let mut expr_pool: Vec<Expression> = Vec::new();
let mut work_stack: Vec<WorkItem> = Vec::new();
let mut constraint_stack: Vec<Constraint> = Vec::new();
let root_idx = expr_pool.len();
expr_pool.push(expr.clone());
work_stack.push(WorkItem::Process(root_idx));
while let Some(work) = work_stack.pop() {
match work {
WorkItem::Process(expr_idx) => {
let current_expr = &expr_pool[expr_idx];
let expr_kind = current_expr.kind.clone();
let expr_source = current_expr.source_location.clone();
match expr_kind {
ExpressionKind::Literal(lit) => match &lit.value {
ValueKind::Boolean(bool_val) => {
if *bool_val {
constraint_stack.push(Constraint::True);
} else {
constraint_stack.push(Constraint::False);
}
}
_ => {
return Err(Error::validation(
"Constraint expression must be boolean",
expr_source.clone(),
None::<String>,
));
}
},
ExpressionKind::DataPath(data_path) => {
constraint_stack.push(Constraint::Data(data_path.clone()));
}
ExpressionKind::Comparison(left, op, right) => {
match Self::from_comparison(&left, &op, &right) {
Ok(comparison_constraint) => {
constraint_stack.push(comparison_constraint);
}
Err(e) => return Err(e),
}
}
ExpressionKind::LogicalAnd(left, right) => {
let left_idx = expr_pool.len();
expr_pool.push((*left).clone());
let right_idx = expr_pool.len();
expr_pool.push((*right).clone());
work_stack.push(WorkItem::BuildAnd);
work_stack.push(WorkItem::Process(right_idx));
work_stack.push(WorkItem::Process(left_idx));
}
ExpressionKind::LogicalNegation(inner, _) => {
let inner_idx = expr_pool.len();
expr_pool.push((*inner).clone());
work_stack.push(WorkItem::ApplyNot);
work_stack.push(WorkItem::Process(inner_idx));
}
other => {
return Err(Error::validation(
format!(
"Cannot convert expression kind to constraint: {:?}",
std::mem::discriminant(&other)
),
current_expr.source_location.clone(),
None::<String>,
));
}
}
}
WorkItem::BuildAnd => {
let right = constraint_stack
.pop()
.expect("Internal error: missing right constraint for And");
let left = constraint_stack
.pop()
.expect("Internal error: missing left constraint for And");
constraint_stack.push(left.and(right));
}
WorkItem::ApplyNot => {
let inner = constraint_stack
.pop()
.expect("Internal error: missing constraint for Not");
constraint_stack.push(inner.not());
}
}
}
Ok(constraint_stack
.pop()
.expect("Internal error: no constraint result from expression conversion"))
}
fn from_comparison(
left: &Expression,
op: &ComparisonComputation,
right: &Expression,
) -> Result<Constraint, Error> {
use ExpressionKind;
fn inversion_err_for(
left: &Expression,
message: impl Into<String>,
suggestion: Option<String>,
) -> Error {
Error::inversion(message, left.source_location.clone(), suggestion)
}
if let ExpressionKind::DataPath(data_path) = &left.kind {
if let ExpressionKind::Literal(value) = &right.kind {
return Ok(Constraint::Comparison {
data: data_path.clone(),
op: op.clone(),
value: Arc::new(value.as_ref().clone()),
});
}
}
if let ExpressionKind::Literal(value) = &left.kind {
if let ExpressionKind::DataPath(data_path) = &right.kind {
let flipped_op = flip_comparison_operator(op);
return Ok(Constraint::Comparison {
data: data_path.clone(),
op: flipped_op,
value: Arc::new(value.as_ref().clone()),
});
}
}
if let ExpressionKind::Literal(left_val) = &left.kind {
if let ExpressionKind::Literal(right_val) = &right.kind {
if let Some(result) = evaluate_literal_comparison(left_val, op, right_val) {
return Ok(if result {
Constraint::True
} else {
Constraint::False
});
}
}
}
if op.is_equal() || op.is_not_equal() {
if let ExpressionKind::Comparison(inner_left, inner_op, inner_right) = &left.kind {
if let ExpressionKind::Literal(lit) = &right.kind {
if let ValueKind::Boolean(bool_val) = &lit.value {
let inner_constraint =
Self::from_comparison(inner_left, inner_op, inner_right)?;
let should_negate = if op.is_equal() { !*bool_val } else { *bool_val };
if should_negate {
return Ok(inner_constraint.not());
}
return Ok(inner_constraint);
}
}
}
if let ExpressionKind::Literal(lit) = &left.kind {
if let ValueKind::Boolean(bool_val) = &lit.value {
if let ExpressionKind::Comparison(inner_left, inner_op, inner_right) =
&right.kind
{
let inner_constraint =
Self::from_comparison(inner_left, inner_op, inner_right)?;
let should_negate = if op.is_equal() { !*bool_val } else { *bool_val };
if should_negate {
return Ok(inner_constraint.not());
}
return Ok(inner_constraint);
}
}
}
}
if matches!(&left.kind, ExpressionKind::Veto(_))
|| matches!(&right.kind, ExpressionKind::Veto(_))
{
return Ok(if op.is_not_equal() {
Constraint::True
} else {
Constraint::False
});
}
if let Some(rewritten) = try_rewrite_comparison_to_atomic(left, op, right) {
return Ok(rewritten);
}
Err(inversion_err_for(
left,
format!(
"Cannot invert condition yet: unsupported comparison shape: {:?} {:?} {:?}",
left, op, right
),
Some(
"Try rewriting the unless condition into a simple comparison between a single data and a literal (e.g. x > 10)."
.to_string(),
),
))
}
pub fn collect_data(&self) -> Vec<DataPath> {
let mut paths = Vec::new();
let mut stack = vec![self];
while let Some(constraint) = stack.pop() {
match constraint {
Constraint::True | Constraint::False => {}
Constraint::Comparison { data, .. } => {
paths.push(data.clone());
}
Constraint::Data(data_path) => {
paths.push(data_path.clone());
}
Constraint::And(left, right) | Constraint::Or(left, right) => {
stack.push(left.as_ref());
stack.push(right.as_ref());
}
Constraint::Not(inner) => {
stack.push(inner.as_ref());
}
}
}
paths.sort_by_key(|a| a.to_string());
paths.dedup();
paths
}
}
fn try_rewrite_comparison_to_atomic(
left: &Expression,
op: &ComparisonComputation,
right: &Expression,
) -> Option<Constraint> {
use ExpressionKind;
let left = constant_fold_expression(left).unwrap_or_else(|| left.clone());
let right = constant_fold_expression(right).unwrap_or_else(|| right.clone());
let (left, right) = match (&left.kind, &right.kind) {
(ExpressionKind::UnitConversion(inner, target), ExpressionKind::Literal(_)) => {
if is_monotone_unit_conversion_target(target) {
((**inner).clone(), right.clone())
} else {
(left.clone(), right.clone())
}
}
(ExpressionKind::Literal(_), ExpressionKind::UnitConversion(inner, target)) => {
if is_monotone_unit_conversion_target(target) {
(left.clone(), (**inner).clone())
} else {
(left.clone(), right.clone())
}
}
_ => (left.clone(), right.clone()),
};
let (expr, mut op_norm, lit) = match (&left.kind, &right.kind) {
(ExpressionKind::Literal(l), _) => {
let flipped = flip_comparison_operator(op);
(right.clone(), flipped, l.as_ref().clone())
}
(_, ExpressionKind::Literal(r)) => (left.clone(), op.clone(), r.as_ref().clone()),
_ => return None,
};
let mut data = Vec::new();
collect_data_paths(&expr, &mut data);
data.sort_by_key(|fp| fp.to_string());
data.dedup();
if data.len() != 1 {
return None;
}
let data = data[0].clone();
let (new_op, new_value) = isolate_linear_comparison(&expr, &data, &op_norm, &lit)?;
op_norm = new_op;
Some(Constraint::Comparison {
data,
op: op_norm,
value: Arc::new(new_value),
})
}
/// Returns `true` for conversion targets that this module treats as order-preserving
/// (duration, scale-unit and ratio-unit conversions). Comparisons through such a conversion
/// can have the conversion stripped.
fn is_monotone_unit_conversion_target(target: &SemanticConversionTarget) -> bool {
    let order_preserving = matches!(
        target,
        SemanticConversionTarget::Duration(_)
            | SemanticConversionTarget::ScaleUnit(_)
            | SemanticConversionTarget::RatioUnit(_)
    );
    order_preserving
}
fn collect_data_paths(expr: &Expression, out: &mut Vec<DataPath>) {
use ExpressionKind;
let mut stack: Vec<&Expression> = vec![expr];
while let Some(e) = stack.pop() {
match &e.kind {
ExpressionKind::DataPath(fp) => out.push(fp.clone()),
ExpressionKind::Arithmetic(l, _, r)
| ExpressionKind::Comparison(l, _, r)
| ExpressionKind::LogicalAnd(l, r) => {
stack.push(l.as_ref());
stack.push(r.as_ref());
}
ExpressionKind::LogicalNegation(inner, _)
| ExpressionKind::UnitConversion(inner, _)
| ExpressionKind::MathematicalComputation(_, inner) => {
stack.push(inner.as_ref());
}
ExpressionKind::DateRelative(_, date_expr, tolerance) => {
stack.push(date_expr.as_ref());
if let Some(tol) = tolerance {
stack.push(tol.as_ref());
}
}
ExpressionKind::DateCalendar(_, _, date_expr) => {
stack.push(date_expr.as_ref());
}
ExpressionKind::Literal(_)
| ExpressionKind::Veto(_)
| ExpressionKind::RulePath(_)
| ExpressionKind::Now => {}
}
}
}
/// Returns `true` when the data path `target` occurs anywhere inside `expr`.
fn contains_data(expr: &Expression, target: &DataPath) -> bool {
    let mut referenced = Vec::new();
    collect_data_paths(expr, &mut referenced);
    referenced.contains(target)
}
fn constant_fold_expression(expr: &Expression) -> Option<Expression> {
use ExpressionKind;
match &expr.kind {
ExpressionKind::Literal(_) => Some(expr.clone()),
ExpressionKind::DataPath(_) => None,
ExpressionKind::UnitConversion(inner, target) => {
let folded_inner = constant_fold_expression(inner)?;
if let ExpressionKind::Literal(lit) = &folded_inner.kind {
match crate::computation::convert_unit(lit.as_ref(), target) {
OperationResult::Value(v) => Some(Expression::with_source(
ExpressionKind::Literal(Box::new(v.as_ref().clone())),
expr.source_location.clone(),
)),
_ => None,
}
} else {
None
}
}
ExpressionKind::Arithmetic(left, op, right) => {
let left_folded = constant_fold_expression(left)?;
let right_folded = constant_fold_expression(right)?;
match (&left_folded.kind, &right_folded.kind) {
(ExpressionKind::Literal(l), ExpressionKind::Literal(r)) => {
match crate::computation::arithmetic_operation(l.as_ref(), op, r.as_ref()) {
OperationResult::Value(v) => Some(Expression::with_source(
ExpressionKind::Literal(Box::new(v.as_ref().clone())),
expr.source_location.clone(),
)),
_ => None,
}
}
_ => None,
}
}
_ => None,
}
}
/// Reverses the direction of an ordering operator (`<` ↔ `>`, `<=` ↔ `>=`) and leaves
/// `is` / `is not` untouched.
///
/// Used when isolating an unknown requires multiplying or dividing by a negative constant, or
/// subtracting the unknown from a constant. The mapping is the same as
/// `flip_comparison_operator`.
fn flip_inequality(op: &ComparisonComputation) -> ComparisonComputation {
    match op {
        ComparisonComputation::Is | ComparisonComputation::IsNot => op.clone(),
        ComparisonComputation::LessThan => ComparisonComputation::GreaterThan,
        ComparisonComputation::LessThanOrEqual => ComparisonComputation::GreaterThanOrEqual,
        ComparisonComputation::GreaterThan => ComparisonComputation::LessThan,
        ComparisonComputation::GreaterThanOrEqual => ComparisonComputation::LessThanOrEqual,
    }
}
fn isolate_linear_comparison(
expr: &Expression,
unknown: &DataPath,
op: &ComparisonComputation,
bound: &LiteralValue,
) -> Option<(ComparisonComputation, LiteralValue)> {
use ExpressionKind;
match &expr.kind {
ExpressionKind::DataPath(fp) if fp == unknown => Some((op.clone(), bound.clone())),
ExpressionKind::UnitConversion(inner, target)
if is_monotone_unit_conversion_target(target) =>
{
isolate_linear_comparison(inner, unknown, op, bound)
}
ExpressionKind::Arithmetic(left, arithmetic_op, right) => {
let left_contains = contains_data(left, unknown);
let right_contains = contains_data(right, unknown);
if left_contains && right_contains {
return None;
}
if left_contains {
let right_lit = constant_fold_expression(right)?;
let ExpressionKind::Literal(c) = right_lit.kind else {
return None;
};
isolate_through_arithmetic_left(left, arithmetic_op, c.as_ref(), op, bound, unknown)
} else if right_contains {
let left_lit = constant_fold_expression(left)?;
let ExpressionKind::Literal(c) = left_lit.kind else {
return None;
};
isolate_through_arithmetic_right(
right,
arithmetic_op,
c.as_ref(),
op,
bound,
unknown,
)
} else {
None
}
}
_ => None,
}
}
/// Undoes `inner OPERATION constant` on the left of `op bound` and continues isolating `inner`.
///
/// * `+ c` → bound − c; `- c` → bound + c.
/// * `* c` → bound ÷ c; `/ c` → bound × c. A zero `c` is rejected, and a negative `c` reverses
///   ordering operators.
///
/// Any other operation, or a non-numeric constant for `*` / `/`, yields `None`.
fn isolate_through_arithmetic_left(
    inner_with_unknown: &Expression,
    operation: &ArithmeticComputation,
    constant: &LiteralValue,
    op: &ComparisonComputation,
    bound: &LiteralValue,
    unknown: &DataPath,
) -> Option<(ComparisonComputation, LiteralValue)> {
    let (next_op, next_bound) = match operation {
        ArithmeticComputation::Add => (op.clone(), lit_sub(bound, constant)?),
        ArithmeticComputation::Subtract => (op.clone(), lit_add(bound, constant)?),
        ArithmeticComputation::Multiply | ArithmeticComputation::Divide => {
            let factor = constant_as_number(constant)?;
            if factor.is_zero() {
                return None;
            }
            let reverses_order =
                factor.is_sign_negative() && !op.is_equal() && !op.is_not_equal();
            let adjusted_op = if reverses_order {
                flip_inequality(op)
            } else {
                op.clone()
            };
            let adjusted_bound = if matches!(operation, ArithmeticComputation::Multiply) {
                lit_div_number(bound, factor)?
            } else {
                lit_mul_number(bound, factor)?
            };
            (adjusted_op, adjusted_bound)
        }
        _ => return None,
    };
    isolate_linear_comparison(inner_with_unknown, unknown, &next_op, &next_bound)
}
/// Undoes `constant OPERATION inner` on the left of `op bound` and continues isolating `inner`.
///
/// * `c + inner` → bound − c.
/// * `c - inner` → c − bound, with ordering operators reversed.
/// * `c * inner` → bound ÷ c. A zero `c` is rejected, and a negative `c` reverses ordering
///   operators.
///
/// `c / inner` is not linear in `inner` and, like any other operation, yields `None`.
fn isolate_through_arithmetic_right(
    inner_with_unknown: &Expression,
    operation: &ArithmeticComputation,
    constant: &LiteralValue,
    op: &ComparisonComputation,
    bound: &LiteralValue,
    unknown: &DataPath,
) -> Option<(ComparisonComputation, LiteralValue)> {
    let (next_op, next_bound) = match operation {
        ArithmeticComputation::Add => (op.clone(), lit_sub(bound, constant)?),
        ArithmeticComputation::Subtract => {
            let reflected = lit_sub(constant, bound)?;
            let mirrored_op = if op.is_equal() || op.is_not_equal() {
                op.clone()
            } else {
                flip_inequality(op)
            };
            (mirrored_op, reflected)
        }
        ArithmeticComputation::Multiply => {
            let factor = constant_as_number(constant)?;
            if factor.is_zero() {
                return None;
            }
            let adjusted_op =
                if factor.is_sign_negative() && !op.is_equal() && !op.is_not_equal() {
                    flip_inequality(op)
                } else {
                    op.clone()
                };
            (adjusted_op, lit_div_number(bound, factor)?)
        }
        // Includes `Divide`: `c / inner` cannot be isolated linearly.
        _ => return None,
    };
    isolate_linear_comparison(inner_with_unknown, unknown, &next_op, &next_bound)
}
/// Returns the decimal payload of a plain number literal, or `None` for any other kind.
fn constant_as_number(lit: &LiteralValue) -> Option<rust_decimal::Decimal> {
    if let ValueKind::Number(n) = &lit.value {
        Some(*n)
    } else {
        None
    }
}
/// Computes `a + b` for two literals of the same kind, keeping `a`'s type (and unit).
///
/// Supports numbers, and scales or durations that share both type and unit with each other.
/// Returns `None` for any other combination, or when the sum is outside the decimal range.
fn lit_add(a: &LiteralValue, b: &LiteralValue) -> Option<LiteralValue> {
    // `checked_add` rather than `+`: rust_decimal's operator panics on overflow, while an
    // unrepresentable bound should just leave the comparison un-rewritten.
    match (&a.value, &b.value) {
        (ValueKind::Number(la), ValueKind::Number(lb)) => Some(LiteralValue::number_with_type(
            la.checked_add(*lb)?,
            a.lemma_type.clone(),
        )),
        (ValueKind::Scale(la, lua), ValueKind::Scale(lb, lub))
            if a.lemma_type == b.lemma_type && lua == lub =>
        {
            Some(LiteralValue::scale_with_type(
                la.checked_add(*lb)?,
                lua.clone(),
                a.lemma_type.clone(),
            ))
        }
        (ValueKind::Duration(la, lua), ValueKind::Duration(lb, lub))
            if a.lemma_type == b.lemma_type && lua == lub =>
        {
            Some(LiteralValue::duration_with_type(
                la.checked_add(*lb)?,
                lua.clone(),
                a.lemma_type.clone(),
            ))
        }
        _ => None,
    }
}
/// Computes `a - b` for two literals of the same kind, keeping `a`'s type (and unit).
///
/// Supports numbers, and scales or durations that share both type and unit with each other.
/// Returns `None` for any other combination, or when the difference is outside the decimal
/// range.
fn lit_sub(a: &LiteralValue, b: &LiteralValue) -> Option<LiteralValue> {
    // `checked_sub` rather than `-`: rust_decimal's operator panics on overflow.
    match (&a.value, &b.value) {
        (ValueKind::Number(la), ValueKind::Number(lb)) => Some(LiteralValue::number_with_type(
            la.checked_sub(*lb)?,
            a.lemma_type.clone(),
        )),
        (ValueKind::Scale(la, lua), ValueKind::Scale(lb, lub))
            if a.lemma_type == b.lemma_type && lua == lub =>
        {
            Some(LiteralValue::scale_with_type(
                la.checked_sub(*lb)?,
                lua.clone(),
                a.lemma_type.clone(),
            ))
        }
        (ValueKind::Duration(la, lua), ValueKind::Duration(lb, lub))
            if a.lemma_type == b.lemma_type && lua == lub =>
        {
            Some(LiteralValue::duration_with_type(
                la.checked_sub(*lb)?,
                lua.clone(),
                a.lemma_type.clone(),
            ))
        }
        _ => None,
    }
}
/// Multiplies the magnitude of a number, scale or duration literal by `c`, keeping its type and
/// unit.
///
/// Returns `None` for other literal kinds or when the product is outside the decimal range.
fn lit_mul_number(a: &LiteralValue, c: rust_decimal::Decimal) -> Option<LiteralValue> {
    // `checked_mul` rather than `*`: rust_decimal's operator panics on overflow.
    match &a.value {
        ValueKind::Number(n) => Some(LiteralValue::number_with_type(
            n.checked_mul(c)?,
            a.lemma_type.clone(),
        )),
        ValueKind::Scale(n, u) => Some(LiteralValue::scale_with_type(
            n.checked_mul(c)?,
            u.clone(),
            a.lemma_type.clone(),
        )),
        ValueKind::Duration(n, u) => Some(LiteralValue::duration_with_type(
            n.checked_mul(c)?,
            u.clone(),
            a.lemma_type.clone(),
        )),
        _ => None,
    }
}
/// Divides the magnitude of a number, scale or duration literal by `c`, keeping its type and
/// unit.
///
/// Returns `None` when `c` is zero, for other literal kinds, or when the quotient is outside the
/// decimal range (e.g. a very large bound divided by a small fraction).
fn lit_div_number(a: &LiteralValue, c: rust_decimal::Decimal) -> Option<LiteralValue> {
    if c.is_zero() {
        return None;
    }
    // `checked_div` rather than `/`: rust_decimal's operator panics on overflow.
    match &a.value {
        ValueKind::Number(n) => Some(LiteralValue::number_with_type(
            n.checked_div(c)?,
            a.lemma_type.clone(),
        )),
        ValueKind::Scale(n, u) => Some(LiteralValue::scale_with_type(
            n.checked_div(c)?,
            u.clone(),
            a.lemma_type.clone(),
        )),
        ValueKind::Duration(n, u) => Some(LiteralValue::duration_with_type(
            n.checked_div(c)?,
            u.clone(),
            a.lemma_type.clone(),
        )),
        _ => None,
    }
}
/// Builds the "theory" formula relating comparison atoms that constrain the same data path.
///
/// For every pair of comparison atoms `a`, `b` on one data path, their value domains are
/// compared:
/// * `dom(a) ⊆ dom(b)` adds `¬a ∨ b` (and symmetrically for `b ⊆ a`);
/// * `dom(a) ∩ dom(b) = ∅` adds `¬(a ∧ b)`.
///
/// Terminal indices are positions in `atoms`. Conjoining the result with a constraint lets the
/// BDD detect numeric implications and contradictions.
///
/// # Errors
///
/// Propagates the first error returned by `domain_for_comparison_atom`.
fn build_numeric_theory_closure(
    atoms: &[Constraint],
) -> Result<boolean_expression::Expr<usize>, Error> {
    use boolean_expression::Expr;
    // Group comparison atoms by data path. Groups live in a Vec in first-appearance order, and
    // the map only records each path's slot. Clause order, and which error is reported when
    // several atoms fail, therefore no longer depend on HashMap iteration order.
    let mut slot_of: std::collections::HashMap<&DataPath, usize> =
        std::collections::HashMap::new();
    let mut groups: Vec<Vec<_>> = Vec::new();
    for (idx, atom) in atoms.iter().enumerate() {
        if let Constraint::Comparison { data, op, value } = atom {
            let slot = *slot_of.entry(data).or_insert_with(|| {
                groups.push(Vec::new());
                groups.len() - 1
            });
            groups[slot].push((idx, op, value));
        }
    }
    let mut theory = Expr::Const(true);
    for group in &groups {
        for (i, &(a_idx, a_op, a_val)) in group.iter().enumerate() {
            for &(b_idx, b_op, b_val) in &group[i + 1..] {
                let a_dom =
                    crate::inversion::domain::domain_for_comparison_atom(a_op, a_val.as_ref())?;
                let b_dom =
                    crate::inversion::domain::domain_for_comparison_atom(b_op, b_val.as_ref())?;
                if a_dom.is_subset_of(&b_dom) {
                    // a → b
                    theory = Expr::and(
                        theory,
                        Expr::or(Expr::not(Expr::Terminal(a_idx)), Expr::Terminal(b_idx)),
                    );
                }
                if b_dom.is_subset_of(&a_dom) {
                    // b → a
                    theory = Expr::and(
                        theory,
                        Expr::or(Expr::not(Expr::Terminal(b_idx)), Expr::Terminal(a_idx)),
                    );
                }
                if a_dom.intersect(&b_dom).is_empty() {
                    // a and b are mutually exclusive
                    theory = Expr::and(
                        theory,
                        Expr::not(Expr::and(Expr::Terminal(a_idx), Expr::Terminal(b_idx))),
                    );
                }
            }
        }
    }
    Ok(theory)
}
/// Statically evaluates a comparison between two literals.
///
/// * Text and boolean pairs support only `is` / `is not`.
/// * Number and ratio pairs support every operator; the second ratio field is ignored.
///
/// Returns `None` for any other operator or kind pairing, leaving the caller to handle it.
fn evaluate_literal_comparison(
    left: &LiteralValue,
    op: &ComparisonComputation,
    right: &LiteralValue,
) -> Option<bool> {
    // Kinds with equality but no ordering.
    fn equality_only<T: PartialEq + ?Sized>(
        op: &ComparisonComputation,
        l: &T,
        r: &T,
    ) -> Option<bool> {
        if op.is_equal() {
            Some(l == r)
        } else if op.is_not_equal() {
            Some(l != r)
        } else {
            None
        }
    }
    // Kinds with a full ordering.
    fn ordered<T: PartialOrd + ?Sized>(op: &ComparisonComputation, l: &T, r: &T) -> Option<bool> {
        let outcome = match op {
            ComparisonComputation::Is => l == r,
            ComparisonComputation::IsNot => l != r,
            ComparisonComputation::LessThan => l < r,
            ComparisonComputation::LessThanOrEqual => l <= r,
            ComparisonComputation::GreaterThan => l > r,
            ComparisonComputation::GreaterThanOrEqual => l >= r,
        };
        Some(outcome)
    }
    match (&left.value, &right.value) {
        (ValueKind::Text(l), ValueKind::Text(r)) => equality_only(op, l, r),
        (ValueKind::Boolean(l), ValueKind::Boolean(r)) => equality_only(op, l, r),
        (ValueKind::Number(l), ValueKind::Number(r)) => ordered(op, l, r),
        (ValueKind::Ratio(l, _), ValueKind::Ratio(r, _)) => ordered(op, l, r),
        _ => None,
    }
}
/// Returns the operator that keeps a comparison equivalent when its operands are swapped
/// (`a < b` ⇔ `b > a`). `is` and `is not` are symmetric and map to themselves.
fn flip_comparison_operator(op: &ComparisonComputation) -> ComparisonComputation {
    match op {
        ComparisonComputation::GreaterThanOrEqual => ComparisonComputation::LessThanOrEqual,
        ComparisonComputation::GreaterThan => ComparisonComputation::LessThan,
        ComparisonComputation::LessThanOrEqual => ComparisonComputation::GreaterThanOrEqual,
        ComparisonComputation::LessThan => ComparisonComputation::GreaterThan,
        ComparisonComputation::IsNot => ComparisonComputation::IsNot,
        ComparisonComputation::Is => ComparisonComputation::Is,
    }
}
/// Returns the index in `atoms` of an atom structurally equal to `constraint`.
///
/// If no such atom exists, a clone of `constraint` is appended first and its new index is
/// returned.
fn find_or_add_atom(constraint: &Constraint, atoms: &mut Vec<Constraint>) -> usize {
    let existing = atoms
        .iter()
        .position(|atom| constraints_structurally_equal(atom, constraint));
    existing.unwrap_or_else(|| {
        atoms.push(constraint.clone());
        atoms.len() - 1
    })
}
/// Lowers `constraint` to a `boolean_expression::Expr`, replacing each leaf with a terminal
/// index into `atoms`.
///
/// Leaves are comparisons and boolean data. Structurally equal leaves share one index, and
/// previously unseen leaves are appended to `atoms`. The `Option` only guards against an
/// inconsistent internal work stack; the traversal below always yields `Some`.
fn to_bool_expr(
    constraint: &Constraint,
    atoms: &mut Vec<Constraint>,
) -> Option<boolean_expression::Expr<usize>> {
    use boolean_expression::Expr;
    // Work items borrow sub-constraints. Cloning each subtree on visit would cost
    // O(size × depth) for deep trees.
    enum WorkItem<'a> {
        Visit(&'a Constraint),
        BuildAnd,
        BuildOr,
        BuildNot,
    }
    let mut stack = vec![WorkItem::Visit(constraint)];
    let mut expr_stack: Vec<Expr<usize>> = Vec::new();
    while let Some(work) = stack.pop() {
        match work {
            WorkItem::Visit(c) => match c {
                Constraint::True => expr_stack.push(Expr::Const(true)),
                Constraint::False => expr_stack.push(Expr::Const(false)),
                Constraint::And(left, right) => {
                    // Left is pushed last so it is lowered first.
                    stack.push(WorkItem::BuildAnd);
                    stack.push(WorkItem::Visit(right.as_ref()));
                    stack.push(WorkItem::Visit(left.as_ref()));
                }
                Constraint::Or(left, right) => {
                    stack.push(WorkItem::BuildOr);
                    stack.push(WorkItem::Visit(right.as_ref()));
                    stack.push(WorkItem::Visit(left.as_ref()));
                }
                Constraint::Not(inner) => {
                    stack.push(WorkItem::BuildNot);
                    stack.push(WorkItem::Visit(inner.as_ref()));
                }
                Constraint::Comparison { .. } | Constraint::Data(_) => {
                    let idx = find_or_add_atom(c, atoms);
                    expr_stack.push(Expr::Terminal(idx));
                }
            },
            WorkItem::BuildAnd => {
                let right = expr_stack.pop()?;
                let left = expr_stack.pop()?;
                expr_stack.push(Expr::and(left, right));
            }
            WorkItem::BuildOr => {
                let right = expr_stack.pop()?;
                let left = expr_stack.pop()?;
                expr_stack.push(Expr::or(left, right));
            }
            WorkItem::BuildNot => {
                let inner = expr_stack.pop()?;
                expr_stack.push(Expr::not(inner));
            }
        }
    }
    expr_stack.pop()
}
/// Equality used to deduplicate BDD atoms.
///
/// Only leaf variants are compared: constants, comparisons (same data, operator and value) and
/// boolean data. Any pair involving `And`, `Or` or `Not` is reported as unequal.
fn constraints_structurally_equal(a: &Constraint, b: &Constraint) -> bool {
    match (a, b) {
        (Constraint::True, Constraint::True) | (Constraint::False, Constraint::False) => true,
        (Constraint::Data(lhs), Constraint::Data(rhs)) => lhs == rhs,
        (
            Constraint::Comparison {
                data: lhs_data,
                op: lhs_op,
                value: lhs_value,
            },
            Constraint::Comparison {
                data: rhs_data,
                op: rhs_op,
                value: rhs_value,
            },
        ) => lhs_data == rhs_data && lhs_op == rhs_op && lhs_value == rhs_value,
        _ => false,
    }
}
/// Rebuilds a `Constraint` from a boolean expression whose terminals index into `atoms`.
///
/// Sub-results are combined with `Constraint::and` / `or` / `not`, so constants fold away as
/// the tree is reassembled. The walk uses an explicit stack of borrowed nodes.
///
/// # Panics
///
/// Panics via `unreachable!` if a terminal index is out of bounds for `atoms` or the internal
/// stack underflows. Both indicate a bug in the caller.
fn from_bool_expr(bool_expr: &boolean_expression::Expr<usize>, atoms: &[Constraint]) -> Constraint {
    use boolean_expression::Expr;
    enum Task<'a> {
        Lower(&'a Expr<usize>),
        Conjoin,
        Disjoin,
        Negate,
    }
    let mut tasks = vec![Task::Lower(bool_expr)];
    let mut built: Vec<Constraint> = Vec::new();
    while let Some(task) = tasks.pop() {
        match task {
            Task::Lower(node) => match node {
                Expr::Const(value) => built.push(if *value {
                    Constraint::True
                } else {
                    Constraint::False
                }),
                Expr::Terminal(i) => {
                    let atom = atoms.get(*i).cloned().unwrap_or_else(|| {
                        unreachable!(
                            "BUG: bool_expr terminal index {} out of bounds (atoms len {})",
                            i,
                            atoms.len()
                        )
                    });
                    built.push(atom);
                }
                Expr::Not(inner) => {
                    tasks.push(Task::Negate);
                    tasks.push(Task::Lower(inner.as_ref()));
                }
                Expr::And(lhs, rhs) => {
                    tasks.push(Task::Conjoin);
                    tasks.push(Task::Lower(rhs.as_ref()));
                    tasks.push(Task::Lower(lhs.as_ref()));
                }
                Expr::Or(lhs, rhs) => {
                    tasks.push(Task::Disjoin);
                    tasks.push(Task::Lower(rhs.as_ref()));
                    tasks.push(Task::Lower(lhs.as_ref()));
                }
            },
            Task::Conjoin => {
                let rhs = built.pop().unwrap_or_else(|| {
                    unreachable!("BUG: constraint stack underflow in CombineAnd (right)")
                });
                let lhs = built.pop().unwrap_or_else(|| {
                    unreachable!("BUG: constraint stack underflow in CombineAnd (left)")
                });
                built.push(lhs.and(rhs));
            }
            Task::Disjoin => {
                let rhs = built.pop().unwrap_or_else(|| {
                    unreachable!("BUG: constraint stack underflow in CombineOr (right)")
                });
                let lhs = built.pop().unwrap_or_else(|| {
                    unreachable!("BUG: constraint stack underflow in CombineOr (left)")
                });
                built.push(lhs.or(rhs));
            }
            Task::Negate => {
                let inner = built
                    .pop()
                    .unwrap_or_else(|| unreachable!("BUG: constraint stack underflow in ApplyNot"));
                built.push(inner.not());
            }
        }
    }
    built
        .pop()
        .unwrap_or_else(|| unreachable!("BUG: constraint stack empty after evaluation"))
}
impl fmt::Display for Constraint {
    /// Renders the constraint using `and`, `or` and `not` keywords.
    ///
    /// Parentheses appear where `format_with_parens` asks for them and around a negated
    /// `and` / `or`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Constraint::True => f.write_str("true"),
            Constraint::False => f.write_str("false"),
            Constraint::Comparison { data, op, value } => write!(f, "{} {} {}", data, op, value),
            Constraint::Data(path) => write!(f, "{}", path),
            Constraint::And(lhs, rhs) => write!(
                f,
                "{} and {}",
                format_with_parens(lhs, self),
                format_with_parens(rhs, self)
            ),
            Constraint::Or(lhs, rhs) => write!(
                f,
                "{} or {}",
                format_with_parens(lhs, self),
                format_with_parens(rhs, self)
            ),
            Constraint::Not(inner) => {
                let grouped = matches!(inner.as_ref(), Constraint::And(..) | Constraint::Or(..));
                if grouped {
                    write!(f, "not ({})", inner)
                } else {
                    write!(f, "not {}", inner)
                }
            }
        }
    }
}
/// Renders `inner` as a child of `parent`.
///
/// The printer treats `and` as binding tighter than `or`, so only an `or` nested directly under
/// an `and` is parenthesised.
fn format_with_parens(inner: &Constraint, parent: &Constraint) -> String {
    let rendered = inner.to_string();
    if matches!(parent, Constraint::And(..)) && matches!(inner, Constraint::Or(..)) {
        format!("({})", rendered)
    } else {
        rendered
    }
}
impl Serialize for Constraint {
    /// Serializes as a `Constraint` struct whose `"type"` field names the variant.
    ///
    /// Fields per variant:
    /// * comparisons: `data` and `op` as display strings, plus `value`;
    /// * boolean data: `data`;
    /// * `and` / `or`: `left` and `right`;
    /// * `not`: `inner`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            Constraint::True | Constraint::False => {
                let tag = if matches!(self, Constraint::True) {
                    "true"
                } else {
                    "false"
                };
                let mut out = serializer.serialize_struct("Constraint", 1)?;
                out.serialize_field("type", tag)?;
                out.end()
            }
            Constraint::Comparison { data, op, value } => {
                let mut out = serializer.serialize_struct("Constraint", 4)?;
                out.serialize_field("type", "comparison")?;
                out.serialize_field("data", &data.to_string())?;
                out.serialize_field("op", &op.to_string())?;
                out.serialize_field("value", value)?;
                out.end()
            }
            Constraint::Data(path) => {
                let mut out = serializer.serialize_struct("Constraint", 2)?;
                out.serialize_field("type", "data")?;
                out.serialize_field("data", &path.to_string())?;
                out.end()
            }
            Constraint::And(left, right) | Constraint::Or(left, right) => {
                let tag = if matches!(self, Constraint::And(..)) {
                    "and"
                } else {
                    "or"
                };
                let mut out = serializer.serialize_struct("Constraint", 3)?;
                out.serialize_field("type", tag)?;
                out.serialize_field("left", left)?;
                out.serialize_field("right", right)?;
                out.end()
            }
            Constraint::Not(inner) => {
                let mut out = serializer.serialize_struct("Constraint", 2)?;
                out.serialize_field("type", "not")?;
                out.serialize_field("inner", inner)?;
                out.end()
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    /// Numeric literal with the default number type.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    /// Unqualified data path named `name`.
    fn data(name: &str) -> DataPath {
        DataPath::new(vec![], name.to_string())
    }

    /// Comparison atom `data_name op val`.
    fn comparison(data_name: &str, op: ComparisonComputation, val: i64) -> Constraint {
        Constraint::Comparison {
            data: data(data_name),
            op,
            value: Arc::new(num(val)),
        }
    }

    /// Simplifies `constraint` and asserts that the result is the constant `false`.
    fn assert_simplifies_to_false(constraint: Constraint) {
        let simplified = constraint.simplify().unwrap();
        assert!(
            simplified.is_false(),
            "Expected contradiction to simplify to false, got: {}",
            simplified
        );
    }

    #[test]
    fn test_constraint_and_short_circuit() {
        let flag = Constraint::Data(data("x"));
        assert!(matches!(Constraint::True.and(flag.clone()), Constraint::Data(_)));
        assert!(matches!(Constraint::False.and(flag), Constraint::False));
    }

    #[test]
    fn test_constraint_or_short_circuit() {
        let flag = Constraint::Data(data("x"));
        assert!(matches!(Constraint::False.or(flag.clone()), Constraint::Data(_)));
        assert!(matches!(Constraint::True.or(flag), Constraint::True));
    }

    #[test]
    fn test_constraint_not_double_negation() {
        let original = Constraint::Data(data("x"));
        let round_tripped = original.clone().not().not();
        assert_eq!(original, round_tripped);
    }

    #[test]
    fn test_constraint_display_simple() {
        let over_18 = comparison("age", ComparisonComputation::GreaterThan, 18);
        assert_eq!(over_18.to_string(), "age > 18");
    }

    #[test]
    fn test_constraint_display_and() {
        let combined = Constraint::And(
            Box::new(comparison("age", ComparisonComputation::GreaterThan, 18)),
            Box::new(Constraint::Data(data("is_employee")).not()),
        );
        assert_eq!(combined.to_string(), "age > 18 and not is_employee");
    }

    #[test]
    fn test_collect_data() {
        let constraint = Constraint::And(
            Box::new(comparison("age", ComparisonComputation::GreaterThan, 18)),
            Box::new(Constraint::Data(data("is_employee"))),
        );
        assert_eq!(constraint.collect_data().len(), 2);
    }

    #[test]
    fn test_simplify_tautology() {
        let a = comparison("x", ComparisonComputation::GreaterThan, 10);
        let b = Constraint::Data(data("flag"));
        // (a and b) or (a and not b) reduces to a.
        let expr = a.clone().and(b.clone()).or(a.and(b.not()));
        assert_eq!(expr.simplify().unwrap().to_string(), "x > 10");
    }

    #[test]
    fn test_simplify_contradiction() {
        let x_is_1 = comparison("x", ComparisonComputation::Is, 1);
        let x_is_2 = comparison("x", ComparisonComputation::Is, 2);
        assert_simplifies_to_false(x_is_1.and(x_is_2));
    }

    #[test]
    fn test_simplify_detects_ordering_implication_contradiction() {
        let above_5 = comparison("x", ComparisonComputation::GreaterThan, 5);
        let above_3 = comparison("x", ComparisonComputation::GreaterThan, 3);
        assert_simplifies_to_false(above_5.and(above_3.not()));
    }

    #[test]
    fn test_simplify_detects_neq_contradiction() {
        let eq = comparison("x", ComparisonComputation::Is, 5);
        let neq = comparison("x", ComparisonComputation::IsNot, 5);
        assert_simplifies_to_false(eq.and(neq));
    }
}