use std::collections::{HashMap, HashSet};
use super::segment::Segment;
use crate::parser::error::ParseError;
use crate::parser::{
Arithmetic, ArithmeticOperator, AtomArg, ComparisonOperator, ConstType, Factor, FlowLogRule,
HeadArg, Predicate,
};
/// Runs [`desugar_rule`] over every rule of every segment.
///
/// Rules for which `desugar_rule` reports `true` are moved out of their
/// segment and appended to `raw_facts`; all other rules stay in their segment
/// in their original relative order.
///
/// # Errors
///
/// Returns the first error produced by `desugar_rule`. At that point the rule
/// list of the segment being processed has already been taken and is left
/// empty, and rules appended to `raw_facts` earlier are not rolled back.
pub(crate) fn desugar_equality_assignments(
    segments: &mut [Segment],
    raw_facts: &mut Vec<FlowLogRule>,
) -> Result<(), ParseError> {
    for segment in segments {
        let rules: &mut Vec<FlowLogRule> = match segment {
            Segment::Plain(plain) => plain,
            Segment::Loop(block) | Segment::Fixpoint(block) => block.rules_mut(),
        };

        // Take ownership of the rules so each one can be routed by value.
        let pending = std::mem::take(rules);
        let mut survivors: Vec<FlowLogRule> = Vec::with_capacity(pending.len());
        for mut candidate in pending {
            let became_fact = desugar_rule(&mut candidate)?;
            let destination = if became_fact {
                &mut *raw_facts
            } else {
                &mut survivors
            };
            destination.push(candidate);
        }
        *rules = survivors;
    }
    Ok(())
}
/// Eliminates equality "assignments" (`x = <expr>`) from the body of `rule`
/// by substituting the assigned expression for the variable everywhere else.
///
/// An equality counts as an assignment when [`as_assignment`] accepts it: one
/// side is a lone variable that is not yet bound, and every variable on the
/// other side is bound. Variables start out bound if they occur as an argument
/// of a positive body atom; each accepted assignment binds its variable, so
/// the body is rescanned until a full pass finds nothing new.
///
/// Returns `Ok(true)` when the body is empty after desugaring and every head
/// argument is (or folds to) a constant, i.e. the rule has become a ground
/// fact. Returns `Ok(false)` when no assignment was found or when body
/// predicates remain.
///
/// # Errors
///
/// * Any error from [`subst_predicate`] while rewriting the remaining body.
/// * `ParseError::GroundRuleNotConst` when the body became empty but a head
///   argument is a variable, an aggregation, or an arithmetic expression that
///   [`fold_const_int`] cannot evaluate.
fn desugar_rule(rule: &mut FlowLogRule) -> Result<bool, ParseError> {
    // Seed the bound set with every variable mentioned by a positive atom.
    let mut known: HashSet<String> = HashSet::new();
    for pred in rule.rhs() {
        let Predicate::PositiveAtom(atom) = pred else {
            continue;
        };
        for arg in atom.arguments() {
            if let AtomArg::Var(name) = arg {
                known.insert(name.clone());
            }
        }
    }

    // Fixpoint scan: body indices consumed as assignments, plus the
    // assignments themselves in discovery order.
    let mut consumed: HashSet<usize> = HashSet::new();
    let mut discovered: Vec<(String, Arithmetic)> = Vec::new();
    let mut made_progress = true;
    while made_progress {
        made_progress = false;
        for (idx, pred) in rule.rhs().iter().enumerate() {
            if consumed.contains(&idx) {
                continue;
            }
            let candidate = match pred {
                Predicate::Compare(cmp) if *cmp.operator() == ComparisonOperator::Equal => {
                    as_assignment(cmp.left(), cmp.right(), &known)
                }
                _ => None,
            };
            if let Some((name, expr)) = candidate {
                known.insert(name.clone());
                consumed.insert(idx);
                discovered.push((name, expr));
                made_progress = true;
            }
        }
    }
    if discovered.is_empty() {
        return Ok(false);
    }

    // Rewrite each assigned expression so that it no longer mentions any
    // previously assigned variable.
    let mut closed: HashMap<String, Arithmetic> = HashMap::new();
    let mut sequence: Vec<String> = Vec::new();
    for (name, mut expr) in discovered {
        for earlier in &sequence {
            subst_arith(&mut expr, earlier, &closed[earlier]);
        }
        sequence.push(name.clone());
        closed.insert(name, expr);
    }

    // Apply the substitutions to the head ...
    for name in &sequence {
        subst_head(rule, name, &closed[name]);
    }

    // ... and to every body predicate that was not itself an assignment.
    let mut remaining: Vec<Predicate> = Vec::with_capacity(rule.rhs().len());
    for (idx, pred) in rule.rhs_mut().iter_mut().enumerate() {
        if consumed.contains(&idx) {
            continue;
        }
        for name in &sequence {
            subst_predicate(pred, name, &closed[name])?;
        }
        remaining.push(pred.clone());
    }
    rule.set_rhs(remaining);

    if !rule.rhs().is_empty() {
        return Ok(false);
    }

    // Empty body: the rule is only acceptable as a fact if every head
    // argument is a constant, folding integer arithmetic where needed.
    let span = rule.head().span();
    for arg in rule.head_mut().head_arguments_mut() {
        match arg {
            HeadArg::Var(_) | HeadArg::Aggregation(_) => {
                return Err(ParseError::GroundRuleNotConst { span });
            }
            HeadArg::Arith(expr) => {
                if !expr.is_const() {
                    let folded =
                        fold_const_int(expr).ok_or(ParseError::GroundRuleNotConst { span })?;
                    *expr = Arithmetic::new(Factor::Const(ConstType::Int(folded)), vec![]);
                }
            }
        }
    }
    Ok(true)
}
/// Evaluates `a` to an `i64` when it consists solely of integer constants and
/// parenthesised groups of integer constants.
///
/// Returns `None` if any factor is something else (variable, non-integer
/// constant, call, cast, ...) or if any step fails its checked arithmetic
/// (overflow, division or remainder by zero).
///
/// The operators in `rest()` are applied strictly left to right, starting
/// from `init()`; no operator precedence is applied in this function.
/// NOTE(review): confirm this matches how `Arithmetic` is evaluated elsewhere
/// (e.g. for `1 + 2 * 3`).
fn fold_const_int(a: &Arithmetic) -> Option<i64> {
    let literal = |factor: &Factor| -> Option<i64> {
        match factor {
            Factor::Const(ConstType::Int(n)) => Some(*n),
            Factor::Group(nested) => fold_const_int(nested),
            _ => None,
        }
    };

    let start = literal(a.init())?;
    a.rest().iter().try_fold(start, |total, (op, factor)| {
        let operand = literal(factor)?;
        match op {
            ArithmeticOperator::Plus => total.checked_add(operand),
            ArithmeticOperator::Minus => total.checked_sub(operand),
            ArithmeticOperator::Multiply => total.checked_mul(operand),
            ArithmeticOperator::Divide => total.checked_div(operand),
            ArithmeticOperator::Modulo => total.checked_rem(operand),
        }
    })
}
/// Interprets the equality `lhs = rhs` as an assignment, if possible.
///
/// The equality is an assignment when one side is a lone variable that is not
/// in `bound`, does not occur on the other side, and every variable on the
/// other side is in `bound`. The result is the assigned variable's name and a
/// clone of the defining expression.
///
/// `lhs` is tried as the assigned variable first; `rhs` is tried only if that
/// fails. Returns `None` when neither orientation qualifies.
fn as_assignment(
    lhs: &Arithmetic,
    rhs: &Arithmetic,
    bound: &HashSet<String>,
) -> Option<(String, Arithmetic)> {
    /// Tries the single orientation `target := source`.
    fn orient(
        target: &Arithmetic,
        source: &Arithmetic,
        bound: &HashSet<String>,
    ) -> Option<(String, Arithmetic)> {
        if !target.is_var() {
            return None;
        }
        let name = target.vars().into_iter().next()?.clone();
        if bound.contains(&name) {
            return None;
        }

        let deps = source.vars();
        // Reject self-referential definitions, then require every dependency
        // to be bound already.
        let usable =
            !deps.iter().any(|dep| **dep == name) && deps.iter().all(|dep| bound.contains(*dep));
        usable.then(|| (name, source.clone()))
    }

    orient(lhs, rhs, bound).or_else(|| orient(rhs, lhs, bound))
}
/// Converts `value` into a single [`Factor`] suitable for splicing in place of
/// a variable.
///
/// An expression with no trailing terms is represented by a clone of its
/// initial factor; anything longer is wrapped in `Factor::Group` so it stays
/// one unit inside the surrounding expression.
fn value_factor(value: &Arithmetic) -> Factor {
    let single_term = value.rest().is_empty();
    if single_term {
        return value.init().clone();
    }
    Factor::Group(Box::new(value.clone()))
}
/// Replaces every occurrence of variable `var` in the head of `rule` with
/// `value`.
///
/// * A bare head variable equal to `var` becomes another bare variable when
///   `value` is itself a lone variable, and an arithmetic argument otherwise.
/// * Arithmetic arguments and the arithmetic inside aggregations are
///   rewritten in place through [`subst_arith`].
/// * Other bare variables are left untouched.
fn subst_head(rule: &mut FlowLogRule, var: &str, value: &Arithmetic) {
    // Built lazily: only needed when a bare head variable actually matches.
    let replacement = || {
        if value.is_var() {
            HeadArg::Var(value.vars()[0].clone())
        } else {
            HeadArg::Arith(value.clone())
        }
    };

    for slot in rule.head_mut().head_arguments_mut() {
        match slot {
            HeadArg::Arith(expr) => subst_arith(expr, var, value),
            HeadArg::Aggregation(agg) => subst_arith(agg.arithmetic_mut(), var, value),
            HeadArg::Var(name) if name == var => *slot = replacement(),
            HeadArg::Var(_) => {}
        }
    }
}
/// Replaces variable `var` with `value` inside a single body predicate.
///
/// * Positive atoms are left unchanged.
/// * Comparisons and function-call predicates have their arithmetic operands
///   rewritten through [`subst_arith`].
/// * In a negated atom, each argument equal to `var` is replaced by the atom
///   argument produced by [`atom_arg_from_value`].
///
/// # Errors
///
/// Returns `ParseError::AssignmentVarInNegation` when `var` occurs in a
/// negated atom and `value` cannot be expressed as a single atom argument
/// (a lone variable or constant).
fn subst_predicate(pred: &mut Predicate, var: &str, value: &Arithmetic) -> Result<(), ParseError> {
    match pred {
        Predicate::PositiveAtom(_) => {}
        Predicate::Compare(cmp) => {
            subst_arith(cmp.left_mut(), var, value);
            subst_arith(cmp.right_mut(), var, value);
        }
        Predicate::FnCall(call) => {
            for operand in call.args_mut() {
                subst_arith(operand, var, value);
            }
        }
        Predicate::NegativeAtom(atom) => {
            let span = atom.span();
            for slot in atom.arguments_mut() {
                if !matches!(&*slot, AtomArg::Var(name) if name == var) {
                    continue;
                }
                // Atom arguments cannot hold compound expressions, so the
                // value must collapse to a single variable or constant.
                let Some(replacement) = atom_arg_from_value(value) else {
                    return Err(ParseError::AssignmentVarInNegation {
                        span,
                        var: var.to_string(),
                    });
                };
                *slot = replacement;
            }
        }
    }
    Ok(())
}
/// Converts `value` into an atom argument when it is a single variable or a
/// single constant.
///
/// Returns `None` for expressions with trailing terms and for any other kind
/// of initial factor.
fn atom_arg_from_value(value: &Arithmetic) -> Option<AtomArg> {
    if value.rest().is_empty() {
        match value.init() {
            Factor::Var(name) => Some(AtomArg::Var(name.clone())),
            Factor::Const(constant) => Some(AtomArg::Const(constant.clone())),
            _ => None,
        }
    } else {
        None
    }
}
/// Replaces variable `var` with `value` in every factor of `arith`: the
/// initial factor first, then each factor of the trailing terms in order.
fn subst_arith(arith: &mut Arithmetic, var: &str, value: &Arithmetic) {
    subst_factor(arith.init_mut(), var, value);
    arith
        .rest_mut()
        .into_iter()
        .for_each(|(_, term)| subst_factor(term, var, value));
}
/// Replaces variable `var` with `value` inside a single factor, recursing
/// into call arguments, casts and parenthesised groups.
///
/// A factor that is exactly the variable `var` is overwritten with
/// [`value_factor`]`(value)`; the spliced-in value is not rescanned.
/// Other variables and constants are left unchanged.
fn subst_factor(factor: &mut Factor, var: &str, value: &Arithmetic) {
    if matches!(&*factor, Factor::Var(name) if name == var) {
        *factor = value_factor(value);
        return;
    }

    match factor {
        Factor::Var(_) | Factor::Const(_) => {}
        Factor::Group(inner) => subst_arith(inner, var, value),
        Factor::Cast(cast) => subst_factor(cast.inner_mut(), var, value),
        Factor::FnCall(call) => {
            for operand in call.args_mut() {
                subst_arith(operand, var, value);
            }
        }
        Factor::Builtin(call) => {
            for operand in call.args_mut() {
                subst_arith(operand, var, value);
            }
        }
    }
}