#![forbid(unsafe_code)]
use std::fmt;
use itertools::Itertools;
use merc_data::DataVariable;
use merc_data::DataVariableRef;
use merc_data::is_data_variable;
use crate::Rule;
use crate::utilities::DataPosition;
use crate::utilities::DataPositionIndexed;
use crate::utilities::DataPositionIterator;
/// A variable together with the positions at which it has been recorded.
///
/// [`derive_equivalence_classes`] builds these from the left-hand side of a rule and only
/// keeps classes with more than one position; [`check_equivalence_classes`] then requires
/// the subterms found at all positions of one class to be equal.
#[derive(Hash, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
pub struct EquivalenceClass {
/// The variable this class belongs to.
pub variable: DataVariable,
/// The positions at which `variable` was found.
pub positions: Vec<DataPosition>,
}
/// Derives the equivalence classes of the variables in the left-hand side of `rule`.
///
/// Every subterm of `rule.lhs` that is a data variable adds its position to the class of
/// that variable. Classes that end up with a single position are discarded, so the result
/// only contains variables recorded at two or more positions. For example, the left-hand
/// side `f(x, h(x))` yields one class for `x` with the positions `[1]` and `[2, 1]`.
pub fn derive_equivalence_classes(rule: &Rule) -> Vec<EquivalenceClass> {
    let mut classes: Vec<EquivalenceClass> = Vec::new();

    for (subterm, position) in DataPositionIterator::new(rule.lhs.copy()) {
        // Only variables take part in equivalence classes; skip every other subterm.
        if !is_data_variable(&subterm) {
            continue;
        }

        let variable = DataVariableRef::from(subterm);
        update_equivalences(&mut classes, &variable, position);
    }

    // A class with a single position imposes no equality constraint.
    classes.retain(|class| class.positions.len() > 1);
    classes
}
/// Checks whether `term` satisfies all equivalence classes in `eqs`.
///
/// Returns `true` if, for every class, the subterm of `term` at the first position of the
/// class is equal to the subterm at each of its other positions. An empty `eqs` is
/// trivially satisfied.
///
/// # Panics
///
/// Panics if a class has no positions. In debug builds it also panics when a class has
/// fewer than two positions.
pub fn check_equivalence_classes<'a, T, P>(term: &'a P, eqs: &[EquivalenceClass]) -> bool
where
    P: DataPositionIndexed<'a, Target<'a> = T> + 'a,
    T: PartialEq,
{
    for class in eqs {
        debug_assert!(
            class.positions.len() >= 2,
            "An equivalence class must contain at least two positions"
        );

        // The first position serves as the representative the others are compared against.
        let (representative, remaining) = class.positions.split_first().unwrap();

        for position in remaining {
            let matches = term.get_data_position(representative) == term.get_data_position(position);
            if !matches {
                return false;
            }
        }
    }

    true
}
/// Records that `variable` occurs at `pos` in the list of equivalence classes `ve`.
///
/// If a class for `variable` already exists, `pos` is appended to it unless the class
/// already contains that position. Otherwise a new class holding only `pos` is added.
///
/// The class is located with a single pass over `ve`; this relies on `ve` holding at most
/// one class per variable, which this function maintains for lists it builds itself.
fn update_equivalences(ve: &mut Vec<EquivalenceClass>, variable: &DataVariableRef<'_>, pos: DataPosition) {
    match ve.iter_mut().find(|ec| ec.variable.copy() == *variable) {
        Some(class) => {
            // Never store the same position twice in one class.
            if !class.positions.contains(&pos) {
                class.positions.push(pos);
            }
        }
        None => ve.push(EquivalenceClass {
            // The class outlives the borrowed reference, so take an owned variable.
            variable: variable.protect(),
            positions: vec![pos],
        }),
    }
}
/// Formats the class as the variable followed by its comma-separated positions in braces.
impl fmt::Display for EquivalenceClass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Lazily joins the positions with ", " while they are being written.
        let positions = self.positions.iter().format(", ");

        write!(f, "{}{{ {} }}", self.variable, positions)
    }
}
#[cfg(test)]
mod tests {
    use merc_data::DataExpression;
    use merc_data::DataVariable;

    use crate::test_utility::create_rewrite_rule;

    use super::*;

    /// Derives the classes of a non-linear left-hand side and checks them against terms
    /// that do and do not respect the derived equalities.
    #[test]
    fn test_derive_equivalence_classes() {
        // `x` occurs twice in the left-hand side, so exactly one class is expected.
        let eq: Vec<EquivalenceClass> =
            derive_equivalence_classes(&create_rewrite_rule("f(x, h(x))", "result", &["x"]).unwrap());

        assert_eq!(
            eq,
            vec![EquivalenceClass {
                variable: DataVariable::new("x").into(),
                positions: vec![DataPosition::new(&[1]), DataPosition::new(&[2, 1])]
            },],
            "The derived equivalence classes are not as expected"
        );

        // Both occurrences of `x` are instantiated with `a(b)`, so the check must succeed.
        let expression = DataExpression::from_string("f(a(b), h(a(b)))").unwrap();

        assert!(
            check_equivalence_classes(&expression, &eq),
            "The equivalence classes are not checked correctly, equivalences: {:?} and term {}",
            &eq,
            &expression
        );

        // The two occurrences of `x` are instantiated differently (`a(b)` versus `b`), so
        // the check must fail; without this case an always-true check would go unnoticed.
        let mismatch = DataExpression::from_string("f(a(b), h(b))").unwrap();

        assert!(
            !check_equivalence_classes(&mismatch, &eq),
            "The equivalence classes should not hold, equivalences: {:?} and term {}",
            &eq,
            &mismatch
        );
    }
}