use crate::planning::semantics::{
ArithmeticComputation, ComparisonComputation, DataPath, Expression, ExpressionKind,
LiteralValue, MathematicalComputation, NegationType, RulePath, SemanticConversionTarget,
Source,
};
use crate::planning::{ExecutableRule, ExecutionPlan};
use crate::OperationResult;
use serde::ser::{Serialize, SerializeMap, Serializer};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use super::constraint::Constraint;
/// A set of branch choices: for each rule path, the index of the branch
/// (into that rule's `branches`) selected in one enumerated scenario.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct World(HashMap<RulePath, usize>);

impl World {
    /// Creates a world in which no rule has a branch assigned yet.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the branch index recorded for `rule_path`, if one was assigned.
    pub fn get(&self, rule_path: &RulePath) -> Option<&usize> {
        let Self(assignments) = self;
        assignments.get(rule_path)
    }

    /// Records `branch_idx` for `rule_path`, returning the previously recorded
    /// index when one existed.
    pub fn insert(&mut self, rule_path: RulePath, branch_idx: usize) -> Option<usize> {
        let Self(assignments) = self;
        assignments.insert(rule_path, branch_idx)
    }

    /// Iterates over all `(rule path, branch index)` assignments in arbitrary order.
    pub fn iter(&self) -> impl Iterator<Item = (&RulePath, &usize)> {
        let Self(assignments) = self;
        assignments.iter()
    }
}
impl Serialize for World {
    /// Serializes the world as a map keyed by each rule path's `to_string()`
    /// form, with the chosen branch index as the value.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let assignments = &self.0;
        let mut state = serializer.serialize_map(Some(assignments.len()))?;
        assignments
            .iter()
            .try_for_each(|(rule_path, branch_idx)| {
                state.serialize_entry(&rule_path.to_string(), branch_idx)
            })?;
        state.end()
    }
}
/// An enumerated world whose target-rule result reduced to a concrete outcome.
#[derive(Debug, Clone)]
pub(super) struct WorldSolution {
/// Branch chosen for each rule that was enumerated.
pub world: World,
/// Condition under which this world yields `outcome`.
pub constraint: Constraint,
/// Concrete result for this world (a value, or a veto).
pub outcome: OperationResult,
}
/// An enumerated world whose target-rule result is still a symbolic
/// (non-literal) expression rather than a concrete value.
#[derive(Debug, Clone)]
pub(super) struct WorldArithmeticSolution {
/// Branch chosen for each rule that was enumerated.
pub world: World,
/// Condition under which this world applies.
pub constraint: Constraint,
/// The target rule's result after rule substitution, data hydration and
/// constant folding.
pub outcome_expression: Expression,
}
/// Output of world enumeration, split by whether the target's result reduced
/// to a concrete outcome or remained a symbolic expression.
#[derive(Debug, Clone)]
pub(super) struct EnumerationResult {
/// Worlds with a concrete outcome (including `true`/`false` splits of boolean results).
pub literal_solutions: Vec<WorldSolution>,
/// Worlds whose result is still a symbolic expression.
pub arithmetic_solutions: Vec<WorldArithmeticSolution>,
}
/// Enumerates combinations of branch choices for a target rule and the rules
/// it transitively depends on.
pub(super) struct WorldEnumerator<'a> {
/// Plan the rules (and data values) are read from.
plan: &'a ExecutionPlan,
/// Target rule plus its transitive rule dependencies, in `plan.rules` order.
rules_in_order: Vec<RulePath>,
/// Lookup from each path in `rules_in_order` to its rule in `plan`.
rule_cache: HashMap<RulePath, &'a ExecutableRule>,
}
impl<'a> WorldEnumerator<'a> {
pub(super) fn new(
plan: &'a ExecutionPlan,
target_rule: &RulePath,
) -> Result<Self, crate::Error> {
let rule_map: HashMap<RulePath, &ExecutableRule> =
plan.rules.iter().map(|r| (r.path.clone(), r)).collect();
let dependent_rules = collect_transitive_dependencies(target_rule, &rule_map)?;
let rules_in_order: Vec<RulePath> = plan
.rules
.iter()
.filter(|r| dependent_rules.contains(&r.path))
.map(|r| r.path.clone())
.collect();
let rule_cache: HashMap<RulePath, &ExecutableRule> = rules_in_order
.iter()
.filter_map(|path| rule_map.get(path).map(|r| (path.clone(), *r)))
.collect();
Ok(Self {
plan,
rules_in_order,
rule_cache,
})
}
pub(super) fn enumerate(
&mut self,
provided_data: &HashSet<DataPath>,
) -> Result<EnumerationResult, crate::Error> {
if self.rules_in_order.is_empty() {
return Ok(EnumerationResult {
literal_solutions: vec![],
arithmetic_solutions: vec![],
});
}
let mut current_worlds: Vec<(World, Constraint)> = vec![(World::new(), Constraint::True)];
for rule_path in &self.rules_in_order.clone() {
let rule_node = match self.rule_cache.get(rule_path) {
Some(node) => *node,
None => continue,
};
let mut next_worlds = Vec::new();
for (world, accumulated_constraint) in current_worlds {
for (branch_idx, branch) in rule_node.branches.iter().enumerate() {
let mut new_world = world.clone();
new_world.insert(rule_path.clone(), branch_idx);
let mut branch_constraint = if let Some(ref condition) = branch.condition {
let substituted_condition = substitute_rules_in_expression(
&Arc::new(condition.clone()),
&new_world,
self.plan,
)?;
let hydrated_condition = hydrate_data_in_expression(
&Arc::new(substituted_condition),
self.plan,
provided_data,
)?;
Constraint::from_expression(&hydrated_condition)?
} else {
Constraint::True
};
for later_branch in rule_node.branches.iter().skip(branch_idx + 1) {
if let Some(ref later_condition) = later_branch.condition {
let substituted_later = substitute_rules_in_expression(
&Arc::new(later_condition.clone()),
&new_world,
self.plan,
)?;
let hydrated_later = hydrate_data_in_expression(
&Arc::new(substituted_later),
self.plan,
provided_data,
)?;
let later_constraint = Constraint::from_expression(&hydrated_later)?;
branch_constraint = branch_constraint.and(later_constraint.not());
}
}
let combined_constraint = accumulated_constraint.clone().and(branch_constraint);
let simplified = combined_constraint.simplify()?;
if !simplified.is_false() {
next_worlds.push((new_world, simplified));
}
}
}
current_worlds = next_worlds;
if current_worlds.is_empty() {
break;
}
}
let target_rule_path = self
.rules_in_order
.last()
.unwrap_or_else(|| unreachable!("BUG: no rules in order for world enumeration"));
let mut literal_solutions = Vec::new();
let mut arithmetic_solutions = Vec::new();
for (world, constraint) in current_worlds {
if let Some(&branch_idx) = world.get(target_rule_path) {
if let Some(rule_node) = self.rule_cache.get(target_rule_path) {
if branch_idx < rule_node.branches.len() {
let branch = &rule_node.branches[branch_idx];
let substituted_result = substitute_rules_in_expression(
&Arc::new(branch.result.clone()),
&world,
self.plan,
)?;
let hydrated_result = hydrate_data_in_expression(
&Arc::new(substituted_result),
self.plan,
provided_data,
)?;
let folded_result = try_constant_fold_expression(&hydrated_result)
.unwrap_or(hydrated_result.clone());
if let Some(outcome) = extract_outcome(&folded_result) {
literal_solutions.push(WorldSolution {
world,
constraint,
outcome,
});
} else if is_boolean_expression(&folded_result) {
let (true_solutions, false_solutions) =
create_boolean_expression_solutions(
world,
constraint,
&folded_result,
)?;
literal_solutions.extend(true_solutions);
literal_solutions.extend(false_solutions);
} else if is_arithmetic_expression(&folded_result) {
arithmetic_solutions.push(WorldArithmeticSolution {
world,
constraint,
outcome_expression: folded_result,
});
}
}
}
}
}
Ok(EnumerationResult {
literal_solutions,
arithmetic_solutions,
})
}
}
/// Returns `target_rule` together with every rule it references, directly or
/// transitively, via branch conditions or results.
///
/// Traversal is breadth-first. A rule is enqueued at most once, so reference
/// cycles terminate. Paths absent from `rule_map` are still included, but
/// their dependencies are not explored.
///
/// # Errors
/// Currently never returns `Err`.
fn collect_transitive_dependencies(
    target_rule: &RulePath,
    rule_map: &HashMap<RulePath, &ExecutableRule>,
) -> Result<HashSet<RulePath>, crate::Error> {
    let mut reached: HashSet<RulePath> = HashSet::new();
    reached.insert(target_rule.clone());
    let mut frontier: VecDeque<RulePath> = VecDeque::from(vec![target_rule.clone()]);
    while let Some(current) = frontier.pop_front() {
        let rule = match rule_map.get(&current) {
            Some(rule) => rule,
            None => continue,
        };
        for dependency in extract_rule_dependencies(rule) {
            if !reached.contains(&dependency) {
                reached.insert(dependency.clone());
                frontier.push_back(dependency);
            }
        }
    }
    Ok(reached)
}
/// Collects the rule paths referenced by any branch of `rule`, in either its
/// condition or its result.
fn extract_rule_dependencies(rule: &ExecutableRule) -> HashSet<RulePath> {
    let mut referenced = HashSet::new();
    rule.branches.iter().for_each(|branch| {
        if let Some(condition) = &branch.condition {
            extract_rule_paths_from_expression(condition, &mut referenced);
        }
        extract_rule_paths_from_expression(&branch.result, &mut referenced);
    });
    referenced
}
/// Adds every rule path referenced anywhere inside `expr` to `paths`.
///
/// The tree is walked with an explicit stack rather than recursion, matching
/// the work-stack traversals used for substitution and hydration, so that
/// deeply nested expressions cannot overflow the call stack. Visit order is
/// irrelevant because results go into a set.
fn extract_rule_paths_from_expression(expr: &Expression, paths: &mut HashSet<RulePath>) {
    let mut pending: Vec<&Expression> = vec![expr];
    while let Some(current) = pending.pop() {
        match &current.kind {
            ExpressionKind::RulePath(rp) => {
                paths.insert(rp.clone());
            }
            ExpressionKind::Arithmetic(left, _, right)
            | ExpressionKind::Comparison(left, _, right)
            | ExpressionKind::LogicalAnd(left, right) => {
                pending.push(left);
                pending.push(right);
            }
            ExpressionKind::LogicalNegation(inner, _)
            | ExpressionKind::UnitConversion(inner, _)
            | ExpressionKind::MathematicalComputation(_, inner) => {
                pending.push(inner);
            }
            ExpressionKind::DateRelative(_, date_expr, tolerance) => {
                pending.push(date_expr);
                if let Some(tol) = tolerance {
                    pending.push(tol);
                }
            }
            ExpressionKind::DateCalendar(_, _, date_expr) => {
                pending.push(date_expr);
            }
            ExpressionKind::Literal(_)
            | ExpressionKind::DataPath(_)
            | ExpressionKind::Veto(_)
            | ExpressionKind::Now => {}
        }
    }
}
/// Inlines the branch results selected by `world` into `expr`.
///
/// A `RulePath` reference is replaced by its branch's result expression when
/// all of these hold:
/// * `world` assigns a branch to the rule;
/// * the rule exists in `plan`;
/// * that branch index exists.
///
/// The inlined result is itself substituted in turn. References that do not
/// meet these conditions are kept as `RulePath` nodes. All other nodes are
/// rebuilt with their original source locations.
///
/// The traversal uses an explicit work stack instead of recursion.
///
/// # Errors
/// Currently never returns `Err`.
///
/// # Panics
/// Panics (via `unreachable!`) if a rule is reached again while its own result
/// is being inlined, i.e. on a circular rule reference.
fn substitute_rules_in_expression(
    expr: &Arc<Expression>,
    world: &World,
    plan: &ExecutionPlan,
) -> Result<Expression, crate::Error> {
    /// One pending traversal step. `Process` visits a node; the `Build*`
    /// variants reassemble a node from operands already on `result_pool`.
    enum WorkItem {
        Process(Arc<Expression>),
        BuildArithmetic(ArithmeticComputation, Option<Source>),
        BuildComparison(ComparisonComputation, Option<Source>),
        BuildLogicalAnd(Option<Source>),
        BuildLogicalNegation(NegationType, Option<Source>),
        BuildUnitConversion(SemanticConversionTarget, Option<Source>),
        BuildMathematicalComputation(MathematicalComputation, Option<Source>),
        /// Leaves the scope of a rule whose result was inlined.
        PopVisitedRules,
    }

    /// Pops the operands of a binary node; the right operand was pushed last.
    fn pop_operands(
        result_pool: &mut Vec<Expression>,
        node: &str,
    ) -> (Arc<Expression>, Arc<Expression>) {
        let right = result_pool.pop().unwrap_or_else(|| {
            unreachable!(
                "BUG: missing right expression for {} during rule substitution",
                node
            )
        });
        let left = result_pool.pop().unwrap_or_else(|| {
            unreachable!(
                "BUG: missing left expression for {} during rule substitution",
                node
            )
        });
        (Arc::new(left), Arc::new(right))
    }

    /// Pops the single operand of a unary node.
    fn pop_operand(result_pool: &mut Vec<Expression>, node: &str) -> Arc<Expression> {
        let inner = result_pool
            .pop()
            .unwrap_or_else(|| panic!("Internal error: missing expression for {}", node));
        Arc::new(inner)
    }

    let mut work_stack: Vec<WorkItem> = vec![WorkItem::Process(Arc::clone(expr))];
    let mut result_pool: Vec<Expression> = Vec::new();
    // Top entry = rules whose results are being inlined on the current path.
    // A new entry is pushed on entering an inlined result. It is popped by
    // `PopVisitedRules`, which (LIFO) runs only after that whole subtree is rebuilt.
    let mut visited_rules_stack: Vec<HashSet<RulePath>> = vec![HashSet::new()];

    while let Some(work) = work_stack.pop() {
        match work {
            WorkItem::Process(e) => {
                let source_loc = e.source_location.clone();
                match &e.kind {
                    ExpressionKind::RulePath(rule_path) => {
                        let visited = visited_rules_stack.last().expect("visited_rules_stack should never be empty when processing RulePath expressions");
                        if visited.contains(rule_path) {
                            unreachable!(
                                "BUG: circular rule reference detected during substitution: {}",
                                rule_path
                            );
                        }
                        if let Some(&branch_idx) = world.get(rule_path) {
                            if let Some(rule) = plan.get_rule_by_path(rule_path) {
                                if let Some(branch) = rule.branches.get(branch_idx) {
                                    let mut new_visited = visited.clone();
                                    new_visited.insert(rule_path.clone());
                                    visited_rules_stack.push(new_visited);
                                    work_stack.push(WorkItem::PopVisitedRules);
                                    work_stack.push(WorkItem::Process(Arc::new(
                                        branch.result.clone(),
                                    )));
                                    continue;
                                }
                            }
                        }
                        result_pool.push(Expression::with_source(
                            ExpressionKind::RulePath(rule_path.clone()),
                            source_loc,
                        ));
                    }
                    // Binary nodes: schedule the rebuild, then right, then left,
                    // so the left operand is processed first.
                    ExpressionKind::Arithmetic(left, op, right) => {
                        work_stack.push(WorkItem::BuildArithmetic(op.clone(), source_loc));
                        work_stack.push(WorkItem::Process(Arc::clone(right)));
                        work_stack.push(WorkItem::Process(Arc::clone(left)));
                    }
                    ExpressionKind::Comparison(left, op, right) => {
                        work_stack.push(WorkItem::BuildComparison(op.clone(), source_loc));
                        work_stack.push(WorkItem::Process(Arc::clone(right)));
                        work_stack.push(WorkItem::Process(Arc::clone(left)));
                    }
                    ExpressionKind::LogicalAnd(left, right) => {
                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
                        work_stack.push(WorkItem::Process(Arc::clone(right)));
                        work_stack.push(WorkItem::Process(Arc::clone(left)));
                    }
                    ExpressionKind::LogicalNegation(inner, neg_type) => {
                        work_stack.push(WorkItem::BuildLogicalNegation(
                            neg_type.clone(),
                            source_loc,
                        ));
                        work_stack.push(WorkItem::Process(Arc::clone(inner)));
                    }
                    ExpressionKind::UnitConversion(inner, unit) => {
                        work_stack.push(WorkItem::BuildUnitConversion(unit.clone(), source_loc));
                        work_stack.push(WorkItem::Process(Arc::clone(inner)));
                    }
                    ExpressionKind::MathematicalComputation(func, inner) => {
                        work_stack.push(WorkItem::BuildMathematicalComputation(
                            func.clone(),
                            source_loc,
                        ));
                        work_stack.push(WorkItem::Process(Arc::clone(inner)));
                    }
                    // Leaves and date nodes are copied through unchanged.
                    ExpressionKind::Literal(_)
                    | ExpressionKind::DataPath(_)
                    | ExpressionKind::Veto(_)
                    | ExpressionKind::Now
                    | ExpressionKind::DateRelative(..)
                    | ExpressionKind::DateCalendar(..) => {
                        result_pool.push(Expression::with_source(e.kind.clone(), source_loc));
                    }
                }
            }
            WorkItem::BuildArithmetic(op, source_loc) => {
                let (left, right) = pop_operands(&mut result_pool, "Arithmetic");
                result_pool.push(Expression::with_source(
                    ExpressionKind::Arithmetic(left, op, right),
                    source_loc,
                ));
            }
            WorkItem::BuildComparison(op, source_loc) => {
                let (left, right) = pop_operands(&mut result_pool, "Comparison");
                result_pool.push(Expression::with_source(
                    ExpressionKind::Comparison(left, op, right),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalAnd(source_loc) => {
                let (left, right) = pop_operands(&mut result_pool, "LogicalAnd");
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalAnd(left, right),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
                let inner = pop_operand(&mut result_pool, "LogicalNegation");
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalNegation(inner, neg_type),
                    source_loc,
                ));
            }
            WorkItem::BuildUnitConversion(unit, source_loc) => {
                let inner = pop_operand(&mut result_pool, "UnitConversion");
                result_pool.push(Expression::with_source(
                    ExpressionKind::UnitConversion(inner, unit),
                    source_loc,
                ));
            }
            WorkItem::BuildMathematicalComputation(func, source_loc) => {
                let inner = pop_operand(&mut result_pool, "MathematicalComputation");
                result_pool.push(Expression::with_source(
                    ExpressionKind::MathematicalComputation(func, inner),
                    source_loc,
                ));
            }
            WorkItem::PopVisitedRules => {
                visited_rules_stack.pop();
            }
        }
    }
    Ok(result_pool
        .pop()
        .unwrap_or_else(|| unreachable!("BUG: no result from substitution")))
}
/// Replaces references to provided data with their literal values.
///
/// A `DataPath` node is replaced by a literal when both hold:
/// * the path is in `provided_data`;
/// * `plan.data` has a value for it.
///
/// All other nodes, including rule references, are rebuilt unchanged with
/// their original source locations. The traversal uses an explicit work stack
/// instead of recursion.
///
/// # Errors
/// Currently never returns `Err`.
fn hydrate_data_in_expression(
    expr: &Arc<Expression>,
    plan: &ExecutionPlan,
    provided_data: &HashSet<DataPath>,
) -> Result<Expression, crate::Error> {
    /// One pending traversal step. `Process` visits a node; the `Build*`
    /// variants reassemble a node from operands already on `result_pool`.
    enum WorkItem {
        Process(Arc<Expression>),
        BuildArithmetic(ArithmeticComputation, Option<Source>),
        BuildComparison(ComparisonComputation, Option<Source>),
        BuildLogicalAnd(Option<Source>),
        BuildLogicalNegation(NegationType, Option<Source>),
        BuildUnitConversion(SemanticConversionTarget, Option<Source>),
        BuildMathematicalComputation(MathematicalComputation, Option<Source>),
    }

    /// Pops the operands of a binary node; the right operand was pushed last.
    fn pop_operands(
        result_pool: &mut Vec<Expression>,
        node: &str,
    ) -> (Arc<Expression>, Arc<Expression>) {
        let right = result_pool
            .pop()
            .unwrap_or_else(|| unreachable!("BUG: missing right expression for {}", node));
        let left = result_pool
            .pop()
            .unwrap_or_else(|| unreachable!("BUG: missing left expression for {}", node));
        (Arc::new(left), Arc::new(right))
    }

    /// Pops the single operand of a unary node.
    fn pop_operand(result_pool: &mut Vec<Expression>, node: &str) -> Arc<Expression> {
        let inner = result_pool
            .pop()
            .unwrap_or_else(|| panic!("Internal error: missing expression for {}", node));
        Arc::new(inner)
    }

    let mut work_stack: Vec<WorkItem> = vec![WorkItem::Process(Arc::clone(expr))];
    let mut result_pool: Vec<Expression> = Vec::new();

    while let Some(work) = work_stack.pop() {
        match work {
            WorkItem::Process(e) => {
                let source_loc = e.source_location.clone();
                match &e.kind {
                    ExpressionKind::DataPath(data_path) => {
                        let provided_value = if provided_data.contains(data_path) {
                            plan.data.get(data_path).and_then(|d| d.value())
                        } else {
                            None
                        };
                        let kind = match provided_value {
                            Some(lit) => ExpressionKind::Literal(Box::new(lit.clone())),
                            None => ExpressionKind::DataPath(data_path.clone()),
                        };
                        result_pool.push(Expression::with_source(kind, source_loc));
                    }
                    // Binary nodes: schedule the rebuild, then right, then left,
                    // so the left operand is processed first.
                    ExpressionKind::Arithmetic(left, op, right) => {
                        work_stack.push(WorkItem::BuildArithmetic(op.clone(), source_loc));
                        work_stack.push(WorkItem::Process(Arc::clone(right)));
                        work_stack.push(WorkItem::Process(Arc::clone(left)));
                    }
                    ExpressionKind::Comparison(left, op, right) => {
                        work_stack.push(WorkItem::BuildComparison(op.clone(), source_loc));
                        work_stack.push(WorkItem::Process(Arc::clone(right)));
                        work_stack.push(WorkItem::Process(Arc::clone(left)));
                    }
                    ExpressionKind::LogicalAnd(left, right) => {
                        work_stack.push(WorkItem::BuildLogicalAnd(source_loc));
                        work_stack.push(WorkItem::Process(Arc::clone(right)));
                        work_stack.push(WorkItem::Process(Arc::clone(left)));
                    }
                    ExpressionKind::LogicalNegation(inner, neg_type) => {
                        work_stack.push(WorkItem::BuildLogicalNegation(
                            neg_type.clone(),
                            source_loc,
                        ));
                        work_stack.push(WorkItem::Process(Arc::clone(inner)));
                    }
                    ExpressionKind::UnitConversion(inner, unit) => {
                        work_stack.push(WorkItem::BuildUnitConversion(unit.clone(), source_loc));
                        work_stack.push(WorkItem::Process(Arc::clone(inner)));
                    }
                    ExpressionKind::MathematicalComputation(func, inner) => {
                        work_stack.push(WorkItem::BuildMathematicalComputation(
                            func.clone(),
                            source_loc,
                        ));
                        work_stack.push(WorkItem::Process(Arc::clone(inner)));
                    }
                    // Leaves, rule references and date nodes are copied through unchanged.
                    ExpressionKind::Literal(_)
                    | ExpressionKind::RulePath(_)
                    | ExpressionKind::Veto(_)
                    | ExpressionKind::Now
                    | ExpressionKind::DateRelative(..)
                    | ExpressionKind::DateCalendar(..) => {
                        result_pool.push(Expression::with_source(e.kind.clone(), source_loc));
                    }
                }
            }
            WorkItem::BuildArithmetic(op, source_loc) => {
                let (left, right) = pop_operands(&mut result_pool, "Arithmetic");
                result_pool.push(Expression::with_source(
                    ExpressionKind::Arithmetic(left, op, right),
                    source_loc,
                ));
            }
            WorkItem::BuildComparison(op, source_loc) => {
                let (left, right) = pop_operands(&mut result_pool, "Comparison");
                result_pool.push(Expression::with_source(
                    ExpressionKind::Comparison(left, op, right),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalAnd(source_loc) => {
                let (left, right) = pop_operands(&mut result_pool, "LogicalAnd");
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalAnd(left, right),
                    source_loc,
                ));
            }
            WorkItem::BuildLogicalNegation(neg_type, source_loc) => {
                let inner = pop_operand(&mut result_pool, "LogicalNegation");
                result_pool.push(Expression::with_source(
                    ExpressionKind::LogicalNegation(inner, neg_type),
                    source_loc,
                ));
            }
            WorkItem::BuildUnitConversion(unit, source_loc) => {
                let inner = pop_operand(&mut result_pool, "UnitConversion");
                result_pool.push(Expression::with_source(
                    ExpressionKind::UnitConversion(inner, unit),
                    source_loc,
                ));
            }
            WorkItem::BuildMathematicalComputation(func, source_loc) => {
                let inner = pop_operand(&mut result_pool, "MathematicalComputation");
                result_pool.push(Expression::with_source(
                    ExpressionKind::MathematicalComputation(func, inner),
                    source_loc,
                ));
            }
        }
    }
    Ok(result_pool
        .pop()
        .expect("Internal error: no result from hydration"))
}
/// Converts a terminal expression into an operation outcome.
///
/// * A literal becomes `OperationResult::Value`.
/// * A veto becomes a user-defined veto carrying the veto's message.
///
/// Every other expression kind yields `None`.
fn extract_outcome(expr: &Expression) -> Option<OperationResult> {
    if let ExpressionKind::Literal(lit) = &expr.kind {
        let value = lit.as_ref().clone();
        return Some(OperationResult::Value(Box::new(value)));
    }
    if let ExpressionKind::Veto(ve) = &expr.kind {
        let message = ve.message.clone();
        return Some(OperationResult::Veto(
            crate::evaluation::operations::VetoType::UserDefined { message },
        ));
    }
    None
}
/// Returns `true` when the top-level node of `expr` is a comparison, a logical
/// AND, or a logical negation. Nested nodes are not inspected.
fn is_boolean_expression(expr: &Expression) -> bool {
    let kind = &expr.kind;
    matches!(
        kind,
        ExpressionKind::Comparison(..)
            | ExpressionKind::LogicalAnd(..)
            | ExpressionKind::LogicalNegation(..)
    )
}
/// Returns `true` when `expr` should be reported as a symbolic (arithmetic-like)
/// result. That is the case for:
/// * an arithmetic node;
/// * a mathematical computation;
/// * a bare data reference;
/// * a unit conversion whose operand itself qualifies.
///
/// The match is exhaustive on purpose, so adding an `ExpressionKind` variant
/// forces a decision here.
fn is_arithmetic_expression(expr: &Expression) -> bool {
    match &expr.kind {
        ExpressionKind::Arithmetic(..)
        | ExpressionKind::MathematicalComputation(..)
        | ExpressionKind::DataPath(_) => true,
        ExpressionKind::UnitConversion(inner, _) => is_arithmetic_expression(inner),
        ExpressionKind::Literal(_)
        | ExpressionKind::RulePath(_)
        | ExpressionKind::Comparison(..)
        | ExpressionKind::LogicalAnd(..)
        | ExpressionKind::LogicalNegation(..)
        | ExpressionKind::DateRelative(..)
        | ExpressionKind::DateCalendar(..)
        | ExpressionKind::Veto(_)
        | ExpressionKind::Now => false,
    }
}
/// Splits a world whose result is a boolean expression into two candidate
/// solutions: one yielding `true`, constrained by the expression, and one
/// yielding `false`, constrained by its negation.
///
/// Each candidate is kept only if its combined constraint does not simplify
/// to false. The first returned vector holds the `true` solution, the second
/// the `false` one; each contains at most one element.
///
/// # Errors
/// Propagates errors from constraint construction and simplification.
fn create_boolean_expression_solutions(
    world: World,
    base_constraint: Constraint,
    boolean_expr: &Expression,
) -> Result<(Vec<WorldSolution>, Vec<WorldSolution>), crate::Error> {
    let holds = Constraint::from_expression(boolean_expr)?;

    let when_true = base_constraint.clone().and(holds.clone()).simplify()?;
    let mut true_solutions = Vec::new();
    if !when_true.is_false() {
        true_solutions.push(WorldSolution {
            world: world.clone(),
            constraint: when_true,
            outcome: OperationResult::Value(Box::new(LiteralValue::from_bool(true))),
        });
    }

    let when_false = base_constraint.and(holds.not()).simplify()?;
    let mut false_solutions = Vec::new();
    if !when_false.is_false() {
        false_solutions.push(WorldSolution {
            world,
            constraint: when_false,
            outcome: OperationResult::Value(Box::new(LiteralValue::from_bool(false))),
        });
    }

    Ok((true_solutions, false_solutions))
}
/// Folds constant sub-expressions of `expr` into literals where possible.
///
/// * A literal is returned as-is (cloned).
/// * For arithmetic and comparison nodes, both operands are folded first. If
///   both become literals and the operation evaluates to a value, the node
///   collapses to a literal. Otherwise the node is rebuilt around the folded
///   operands. This happens, for example, when an operation yields a veto or
///   a comparison yields a non-boolean.
/// * Every other kind returns `None`. Such nodes, and anything beneath them,
///   are not folded.
pub(crate) fn try_constant_fold_expression(expr: &Expression) -> Option<Expression> {
    /// Folds an operand, or clones it unchanged only when it is not foldable.
    fn fold_operand(operand: &Expression) -> Expression {
        try_constant_fold_expression(operand).unwrap_or_else(|| operand.clone())
    }

    match &expr.kind {
        ExpressionKind::Literal(_) => Some(expr.clone()),
        ExpressionKind::Arithmetic(left, op, right) => {
            let left_folded = fold_operand(left);
            let right_folded = fold_operand(right);
            if let (ExpressionKind::Literal(left_val), ExpressionKind::Literal(right_val)) =
                (&left_folded.kind, &right_folded.kind)
            {
                if let Some(result) = evaluate_arithmetic(left_val.as_ref(), op, right_val.as_ref())
                {
                    return Some(Expression::with_source(
                        ExpressionKind::Literal(Box::new(result)),
                        expr.source_location.clone(),
                    ));
                }
            }
            Some(Expression::with_source(
                ExpressionKind::Arithmetic(
                    Arc::new(left_folded),
                    op.clone(),
                    Arc::new(right_folded),
                ),
                expr.source_location.clone(),
            ))
        }
        ExpressionKind::Comparison(left, op, right) => {
            let left_folded = fold_operand(left);
            let right_folded = fold_operand(right);
            if let (ExpressionKind::Literal(left_val), ExpressionKind::Literal(right_val)) =
                (&left_folded.kind, &right_folded.kind)
            {
                if let Some(result) = evaluate_comparison(left_val.as_ref(), op, right_val.as_ref())
                {
                    return Some(Expression::with_source(
                        ExpressionKind::Literal(Box::new(LiteralValue::from_bool(result))),
                        expr.source_location.clone(),
                    ));
                }
            }
            Some(Expression::with_source(
                ExpressionKind::Comparison(
                    Arc::new(left_folded),
                    op.clone(),
                    Arc::new(right_folded),
                ),
                expr.source_location.clone(),
            ))
        }
        _ => None,
    }
}
/// Applies `op` to two literal operands.
///
/// Returns the resulting literal, or `None` when the operation produces a veto.
fn evaluate_arithmetic(
    left: &LiteralValue,
    op: &ArithmeticComputation,
    right: &LiteralValue,
) -> Option<LiteralValue> {
    use crate::computation::arithmetic_operation;
    let outcome = arithmetic_operation(left, op, right);
    if let OperationResult::Value(lit) = outcome {
        Some(lit.as_ref().clone())
    } else {
        None
    }
}
/// Compares two literal operands with `op`.
///
/// Returns the boolean result. It returns `None` when the operation yields a
/// non-value outcome, or a value that is not a boolean.
fn evaluate_comparison(
    left: &LiteralValue,
    op: &ComparisonComputation,
    right: &LiteralValue,
) -> Option<bool> {
    use crate::computation::comparison_operation;
    use crate::planning::semantics::ValueKind;
    let lit = match comparison_operation(left, op, right) {
        OperationResult::Value(lit) => lit,
        _ => return None,
    };
    if let ValueKind::Boolean(flag) = &lit.value {
        Some(*flag)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::planning::semantics::ValueKind;
    use rust_decimal::Decimal;

    /// Wraps `val` in a literal expression with no source location.
    fn literal_expr(val: LiteralValue) -> Expression {
        let kind = ExpressionKind::Literal(Box::new(val));
        Expression::with_source(kind, None)
    }

    /// Builds a numeric literal from an integer.
    fn num(n: i64) -> LiteralValue {
        LiteralValue::number(Decimal::from(n))
    }

    #[test]
    fn test_world_new() {
        assert!(World::new().0.is_empty());
    }

    #[test]
    fn test_world_insert_and_get() {
        let rule_path = RulePath {
            segments: vec![],
            rule: "test_rule".to_string(),
        };
        let mut world = World::new();
        world.insert(rule_path.clone(), 2);
        assert_eq!(world.get(&rule_path), Some(&2));
    }

    #[test]
    fn test_constant_fold_arithmetic() {
        let sum = Expression::with_source(
            ExpressionKind::Arithmetic(
                Arc::new(literal_expr(num(10))),
                ArithmeticComputation::Add,
                Arc::new(literal_expr(num(5))),
            ),
            None,
        );
        let folded = try_constant_fold_expression(&sum).unwrap();
        let value = match &folded.kind {
            ExpressionKind::Literal(lit) => &lit.value,
            _ => panic!("Expected literal number"),
        };
        match value {
            ValueKind::Number(n) => assert_eq!(*n, Decimal::from(15)),
            _ => panic!("Expected literal number"),
        }
    }

    #[test]
    fn test_constant_fold_comparison() {
        let greater = Expression::with_source(
            ExpressionKind::Comparison(
                Arc::new(literal_expr(num(10))),
                ComparisonComputation::GreaterThan,
                Arc::new(literal_expr(num(5))),
            ),
            None,
        );
        let folded = try_constant_fold_expression(&greater).unwrap();
        let value = match &folded.kind {
            ExpressionKind::Literal(lit) => &lit.value,
            _ => panic!("Expected literal boolean"),
        };
        match value {
            ValueKind::Boolean(b) => assert!(*b),
            _ => panic!("Expected literal boolean"),
        }
    }
}