use crate::Term;
use anyhow::{anyhow, Result};
use std::collections::{HashMap, HashSet, VecDeque};
/// A term of the CHR term language: a variable, a symbolic constant, an
/// integer, or a compound term `name(args...)`.
///
/// The derived `PartialEq`/`Hash` are purely structural (two `Var`s with the
/// same name are equal); no unification is implied.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ChrTerm {
/// A logic variable, identified by its name.
Var(String),
/// A symbolic (atom-like) constant.
Const(String),
/// An integer; guard comparisons (`<`, `<=`, `>`, `>=`) only succeed on these.
Int(i64),
/// A compound term: functor name plus argument terms.
Func(String, Vec<ChrTerm>),
}
impl ChrTerm {
pub fn var(name: &str) -> Self {
Self::Var(name.to_string())
}
pub fn const_(name: &str) -> Self {
Self::Const(name.to_string())
}
pub fn int(n: i64) -> Self {
Self::Int(n)
}
pub fn is_var(&self) -> bool {
matches!(self, Self::Var(_))
}
pub fn var_name(&self) -> Option<&str> {
match self {
Self::Var(n) => Some(n),
_ => None,
}
}
pub fn apply_subst(&self, subst: &Substitution) -> Self {
match self {
Self::Var(v) => subst.get(v).cloned().unwrap_or_else(|| self.clone()),
Self::Const(_) | Self::Int(_) => self.clone(),
Self::Func(name, args) => Self::Func(
name.clone(),
args.iter().map(|a| a.apply_subst(subst)).collect(),
),
}
}
pub fn variables(&self) -> HashSet<String> {
let mut vars = HashSet::new();
self.collect_vars(&mut vars);
vars
}
fn collect_vars(&self, vars: &mut HashSet<String>) {
match self {
Self::Var(v) => {
vars.insert(v.clone());
}
Self::Const(_) | Self::Int(_) => {}
Self::Func(_, args) => {
for arg in args {
arg.collect_vars(vars);
}
}
}
}
}
impl std::fmt::Display for ChrTerm {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Var(v) => write!(f, "{}", v),
Self::Const(c) => write!(f, "{}", c),
Self::Int(n) => write!(f, "{}", n),
Self::Func(name, args) => {
write!(f, "{}(", name)?;
for (i, arg) in args.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", arg)?;
}
write!(f, ")")
}
}
}
}
/// Variable bindings: maps a variable name to the term it is bound to.
pub type Substitution = HashMap<String, ChrTerm>;
/// A CHR constraint `name(args...)`, as stored in the constraint store or used
/// as a rule head/body pattern.
///
/// Note: the derived equality/hash include `id`; the store's own duplicate
/// check compares only `name` and `args`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Constraint {
/// Constraint (predicate) name, e.g. `"leq"` or `"="`.
pub name: String,
/// Argument terms.
pub args: Vec<ChrTerm>,
/// Store-assigned identifier; `0` for freshly built constraints until
/// `ConstraintStore::add` overwrites it.
pub id: usize,
}
impl Constraint {
    /// Builds a constraint with the given name and arguments (id `0` until stored).
    pub fn new(name: &str, args: Vec<ChrTerm>) -> Self {
        Self {
            name: name.to_string(),
            args,
            id: 0,
        }
    }

    /// `name(arg)` where `arg` is a variable.
    pub fn unary(name: &str, arg: &str) -> Self {
        Self::new(name, vec![ChrTerm::var(arg)])
    }

    /// `name(arg1, arg2)` where both arguments are variables.
    pub fn binary(name: &str, arg1: &str, arg2: &str) -> Self {
        Self::new(name, vec![ChrTerm::var(arg1), ChrTerm::var(arg2)])
    }

    /// `name(arg1, arg2)` where both arguments are symbolic constants.
    pub fn binary_const(name: &str, arg1: &str, arg2: &str) -> Self {
        Self::new(name, vec![ChrTerm::const_(arg1), ChrTerm::const_(arg2)])
    }

    /// Equation constraint `arg1 = arg2` over two variables.
    pub fn eq(arg1: &str, arg2: &str) -> Self {
        Self::new("=", vec![ChrTerm::var(arg1), ChrTerm::var(arg2)])
    }

    /// Disequality constraint `arg1 \= arg2` over two variables.
    pub fn neq(arg1: &str, arg2: &str) -> Self {
        Self::new("\\=", vec![ChrTerm::var(arg1), ChrTerm::var(arg2)])
    }

    /// `leq(arg1, arg2)` over two variables.
    pub fn leq(arg1: &str, arg2: &str) -> Self {
        Self::new("leq", vec![ChrTerm::var(arg1), ChrTerm::var(arg2)])
    }

    /// Applies `subst` to every argument, keeping the name and id.
    pub fn apply_subst(&self, subst: &Substitution) -> Self {
        Self {
            name: self.name.clone(),
            args: self.args.iter().map(|a| a.apply_subst(subst)).collect(),
            id: self.id,
        }
    }

    /// Names of all variables occurring in the arguments.
    pub fn variables(&self) -> HashSet<String> {
        self.args.iter().flat_map(ChrTerm::variables).collect()
    }

    /// One-way matching of this (stored) constraint against `pattern`.
    ///
    /// Variables in `pattern` are bound in `subst` to the corresponding terms
    /// of `self`; an already-bound pattern variable must equal the term it meets.
    /// Variables inside `self` are treated as plain symbols, never bound.
    ///
    /// Returns `true` on success, with `subst` extended by the new bindings.
    /// On failure `subst` is left exactly as it was on entry: any bindings
    /// made during the failed attempt are rolled back.
    pub fn matches(&self, pattern: &Constraint, subst: &mut Substitution) -> bool {
        if self.name != pattern.name || self.args.len() != pattern.args.len() {
            return false;
        }
        let mut newly_bound = Vec::new();
        let matched = self
            .args
            .iter()
            .zip(&pattern.args)
            .all(|(arg, pat)| Self::term_matches(arg, pat, subst, &mut newly_bound));
        if !matched {
            // Undo partial bindings so a failed match has no side effects.
            for var in newly_bound {
                subst.remove(&var);
            }
        }
        matched
    }

    /// Matches `term` against `pattern`, recording names of variables bound by
    /// this call in `newly_bound` so the caller can roll them back.
    fn term_matches(
        term: &ChrTerm,
        pattern: &ChrTerm,
        subst: &mut Substitution,
        newly_bound: &mut Vec<String>,
    ) -> bool {
        match (term, pattern) {
            (_, ChrTerm::Var(v)) => match subst.get(v) {
                Some(bound) => bound == term,
                None => {
                    subst.insert(v.clone(), term.clone());
                    newly_bound.push(v.clone());
                    true
                }
            },
            (ChrTerm::Const(c1), ChrTerm::Const(c2)) => c1 == c2,
            (ChrTerm::Int(n1), ChrTerm::Int(n2)) => n1 == n2,
            (ChrTerm::Func(n1, args1), ChrTerm::Func(n2, args2)) => {
                n1 == n2
                    && args1.len() == args2.len()
                    && args1
                        .iter()
                        .zip(args2)
                        .all(|(a1, a2)| Self::term_matches(a1, a2, subst, newly_bound))
            }
            _ => false,
        }
    }
}
impl std::fmt::Display for Constraint {
    /// Prints `name` alone for nullary constraints, otherwise `name(a1, a2, ...)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.name)?;
        if self.args.is_empty() {
            return Ok(());
        }
        let rendered: Vec<String> = self.args.iter().map(|arg| arg.to_string()).collect();
        write!(f, "({})", rendered.join(", "))
    }
}
/// The three CHR rule kinds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChrRuleType {
/// `H <=> G | B`: matched head constraints are removed, body is added.
Simplification,
/// `H ==> G | B`: head constraints are kept, body is added.
Propagation,
/// `K \ R <=> G | B`: `K` is kept, `R` is removed, body is added.
Simpagation,
}
/// A rule guard, evaluated against the substitution produced by head matching.
///
/// Ordered comparisons only succeed when both sides are `ChrTerm::Int` after
/// substitution; otherwise they evaluate to `false`.
#[derive(Debug, Clone)]
pub enum Guard {
/// Always holds.
True,
/// Structural equality of both sides after substitution.
Equal(ChrTerm, ChrTerm),
/// Structural inequality of both sides after substitution.
NotEqual(ChrTerm, ChrTerm),
LessThan(ChrTerm, ChrTerm),
LessEq(ChrTerm, ChrTerm),
GreaterThan(ChrTerm, ChrTerm),
GreaterEq(ChrTerm, ChrTerm),
/// Holds when every sub-guard holds (an empty list holds).
And(Vec<Guard>),
/// Holds when at least one sub-guard holds (an empty list fails).
Or(Vec<Guard>),
/// Negation of the inner guard.
Not(Box<Guard>),
/// Named builtin test. Only `"false"` fails; any other name (including unknown
/// ones) holds, and the arguments are currently ignored.
Builtin(String, Vec<ChrTerm>),
}
impl Guard {
pub fn evaluate(&self, subst: &Substitution) -> bool {
match self {
Self::True => true,
Self::Equal(t1, t2) => {
let v1 = t1.apply_subst(subst);
let v2 = t2.apply_subst(subst);
v1 == v2
}
Self::NotEqual(t1, t2) => {
let v1 = t1.apply_subst(subst);
let v2 = t2.apply_subst(subst);
v1 != v2
}
Self::LessThan(t1, t2) => Self::compare_terms(t1, t2, subst, |a, b| a < b),
Self::LessEq(t1, t2) => Self::compare_terms(t1, t2, subst, |a, b| a <= b),
Self::GreaterThan(t1, t2) => Self::compare_terms(t1, t2, subst, |a, b| a > b),
Self::GreaterEq(t1, t2) => Self::compare_terms(t1, t2, subst, |a, b| a >= b),
Self::And(guards) => guards.iter().all(|g| g.evaluate(subst)),
Self::Or(guards) => guards.iter().any(|g| g.evaluate(subst)),
Self::Not(inner) => !inner.evaluate(subst),
Self::Builtin(name, _args) => {
match name.as_str() {
"true" => true,
"false" => false,
_ => true, }
}
}
}
fn compare_terms<F>(t1: &ChrTerm, t2: &ChrTerm, subst: &Substitution, cmp: F) -> bool
where
F: Fn(i64, i64) -> bool,
{
let v1 = t1.apply_subst(subst);
let v2 = t2.apply_subst(subst);
match (&v1, &v2) {
(ChrTerm::Int(n1), ChrTerm::Int(n2)) => cmp(*n1, *n2),
_ => false,
}
}
}
/// A CHR rule: kept head, removed head, guard and body.
#[derive(Debug, Clone)]
pub struct ChrRule {
/// Rule name (used for display and propagation bookkeeping).
pub name: String,
/// Kind of rule.
pub rule_type: ChrRuleType,
/// Head constraints that stay in the store when the rule fires.
pub kept_head: Vec<Constraint>,
/// Head constraints that are removed from the store when the rule fires.
pub removed_head: Vec<Constraint>,
/// Condition that must hold under the head-matching substitution.
pub guard: Guard,
/// Constraints added (after substitution) when the rule fires.
pub body: Vec<Constraint>,
/// Set to 0 by all constructors in this file; not read by the engine here,
/// which tries rules in insertion order instead.
pub priority: i32,
}
impl ChrRule {
pub fn simplification(
name: &str,
head: Vec<Constraint>,
guards: Vec<Guard>,
body: Vec<Constraint>,
) -> Self {
Self {
name: name.to_string(),
rule_type: ChrRuleType::Simplification,
kept_head: vec![],
removed_head: head,
guard: if guards.is_empty() {
Guard::True
} else {
Guard::And(guards)
},
body,
priority: 0,
}
}
pub fn propagation(
name: &str,
head: Vec<Constraint>,
guards: Vec<Guard>,
body: Vec<Constraint>,
) -> Self {
Self {
name: name.to_string(),
rule_type: ChrRuleType::Propagation,
kept_head: head,
removed_head: vec![],
guard: if guards.is_empty() {
Guard::True
} else {
Guard::And(guards)
},
body,
priority: 0,
}
}
pub fn simpagation(
name: &str,
kept: Vec<Constraint>,
removed: Vec<Constraint>,
guards: Vec<Guard>,
body: Vec<Constraint>,
) -> Self {
Self {
name: name.to_string(),
rule_type: ChrRuleType::Simpagation,
kept_head: kept,
removed_head: removed,
guard: if guards.is_empty() {
Guard::True
} else {
Guard::And(guards)
},
body,
priority: 0,
}
}
pub fn all_head(&self) -> Vec<&Constraint> {
self.kept_head
.iter()
.chain(self.removed_head.iter())
.collect()
}
pub fn variables(&self) -> HashSet<String> {
let mut vars = HashSet::new();
for c in &self.kept_head {
vars.extend(c.variables());
}
for c in &self.removed_head {
vars.extend(c.variables());
}
for c in &self.body {
vars.extend(c.variables());
}
vars
}
}
impl std::fmt::Display for ChrRule {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}: ", self.name)?;
if !self.kept_head.is_empty() {
for (i, c) in self.kept_head.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", c)?;
}
if !self.removed_head.is_empty() {
write!(f, " \\ ")?;
}
}
for (i, c) in self.removed_head.iter().enumerate() {
if i > 0 || !self.kept_head.is_empty() {
write!(f, ", ")?;
}
write!(f, "{}", c)?;
}
match self.rule_type {
ChrRuleType::Simplification => write!(f, " <=> ")?,
ChrRuleType::Propagation => write!(f, " ==> ")?,
ChrRuleType::Simpagation => write!(f, " <=> ")?,
}
if self.body.is_empty() {
write!(f, "true")?;
} else {
for (i, c) in self.body.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", c)?;
}
}
Ok(())
}
}
/// Records which propagation-rule instances have already fired, so the same
/// rule is not re-applied to the same tuple of constraint ids.
#[derive(Debug, Default)]
struct PropagationHistory {
/// `(rule identifier, matched constraint ids in head order)` pairs.
applied: HashSet<(String, Vec<usize>)>,
}
impl PropagationHistory {
    /// An empty history.
    fn new() -> Self {
        Default::default()
    }

    /// Builds the set key for a rule identifier plus ordered constraint ids.
    fn key(rule_name: &str, constraint_ids: &[usize]) -> (String, Vec<usize>) {
        (rule_name.to_owned(), constraint_ids.to_vec())
    }

    /// Whether `rule_name` already fired on exactly these ids (order matters).
    fn has_fired(&self, rule_name: &str, constraint_ids: &[usize]) -> bool {
        self.applied.contains(&Self::key(rule_name, constraint_ids))
    }

    /// Marks `rule_name` as having fired on these ids.
    fn record(&mut self, rule_name: &str, constraint_ids: &[usize]) {
        self.applied.insert(Self::key(rule_name, constraint_ids));
    }

    /// Forgets every recorded firing.
    fn clear(&mut self) {
        self.applied.clear()
    }
}
/// Constraint store with a by-name index.
///
/// Removed constraints are not deleted from `constraints`; their ids are
/// recorded in `removed` and dropped from `index`.
#[derive(Debug, Default)]
pub struct ConstraintStore {
/// Every constraint ever added (including removed ones), in insertion order.
constraints: Vec<Constraint>,
/// Constraint name -> ids of constraints with that name.
index: HashMap<String, HashSet<usize>>,
/// Id to assign to the next added constraint.
next_id: usize,
/// Ids of constraints that have been removed.
removed: HashSet<usize>,
}
impl ConstraintStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `constraint`, overwriting its `id` with a fresh one, and returns that id.
    ///
    /// Ids are strictly increasing and constraints are only ever appended, so
    /// `constraints` stays sorted by id (relied on by [`ConstraintStore::get`]).
    pub fn add(&mut self, mut constraint: Constraint) -> usize {
        let id = self.next_id;
        self.next_id += 1;
        constraint.id = id;
        self.index
            .entry(constraint.name.clone())
            .or_default()
            .insert(id);
        self.constraints.push(constraint);
        id
    }

    /// Marks the constraint `id` as removed.
    ///
    /// Unknown or already-removed ids are ignored, so `removed` only ever holds
    /// ids of stored constraints and [`ConstraintStore::len`] cannot underflow.
    pub fn remove(&mut self, id: usize) {
        if let Some(pos) = self.position(id) {
            // `insert` returns false when the id was already removed.
            if self.removed.insert(id) {
                if let Some(ids) = self.index.get_mut(&self.constraints[pos].name) {
                    ids.remove(&id);
                }
            }
        }
    }

    /// Position of `id` in `constraints` (binary search; the vector is sorted by id).
    fn position(&self, id: usize) -> Option<usize> {
        self.constraints.binary_search_by_key(&id, |c| c.id).ok()
    }

    /// The live constraint with this id, or `None` if unknown or removed.
    pub fn get(&self, id: usize) -> Option<&Constraint> {
        if self.removed.contains(&id) {
            return None;
        }
        self.position(id).map(|pos| &self.constraints[pos])
    }

    /// All live constraints in insertion order.
    pub fn all(&self) -> Vec<&Constraint> {
        self.constraints
            .iter()
            .filter(|c| !self.removed.contains(&c.id))
            .collect()
    }

    /// Live constraints named `name`, in insertion (id) order.
    pub fn by_name(&self, name: &str) -> Vec<&Constraint> {
        let mut ids: Vec<usize> = match self.index.get(name) {
            Some(ids) => ids.iter().copied().collect(),
            None => return Vec::new(),
        };
        // HashSet iteration order is unspecified; sorting makes matching (and
        // therefore rule firing) deterministic.
        ids.sort_unstable();
        ids.into_iter().filter_map(|id| self.get(id)).collect()
    }

    /// Whether a live constraint with the same name and arguments exists (ids ignored).
    pub fn contains(&self, constraint: &Constraint) -> bool {
        self.by_name(&constraint.name)
            .iter()
            .any(|c| c.args == constraint.args)
    }

    /// Number of live constraints.
    pub fn len(&self) -> usize {
        // `removed` is a subset of stored ids (see `remove`); saturate defensively.
        self.constraints.len().saturating_sub(self.removed.len())
    }

    /// `true` when no live constraints remain.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Drops every constraint and restarts id numbering at 0.
    pub fn clear(&mut self) {
        self.constraints.clear();
        self.index.clear();
        self.next_id = 0;
        self.removed.clear();
    }
}
/// Engine configuration.
#[derive(Debug, Clone)]
pub struct ChrConfig {
/// Upper bound on main-loop iterations per `ChrEngine::solve` call.
pub max_iterations: usize,
/// Whether propagation rules consult/record the propagation history, which
/// prevents re-firing on the same constraint tuple.
pub use_history: bool,
/// Not read by the engine code in this file (reserved).
pub simplify: bool,
}
impl Default for ChrConfig {
fn default() -> Self {
Self {
max_iterations: 10000,
use_history: true,
simplify: true,
}
}
}
/// Counters collected while the engine runs.
#[derive(Debug, Default, Clone)]
pub struct ChrStats {
/// Rule firings of any kind.
pub rule_applications: usize,
/// Propagation-rule firings counted by the engine.
pub propagations: usize,
/// Simplification and simpagation firings.
pub simplifications: usize,
/// Main-loop iterations of the most recent `solve` call (reset when it starts).
pub iterations: usize,
/// Constraints added through `ChrEngine::add_constraint`, including rule-body constraints.
pub constraints_added: usize,
/// Constraints removed from the store by simplification/simpagation firings.
pub constraints_removed: usize,
}
/// CHR engine: a list of rules applied to a constraint store until no rule
/// fires or the iteration limit is reached.
#[derive(Debug)]
pub struct ChrEngine {
/// Rules, tried in insertion order.
rules: Vec<ChrRule>,
/// Current constraint store.
store: ConstraintStore,
/// Already-fired propagation instances (consulted when `config.use_history`).
history: PropagationHistory,
config: ChrConfig,
stats: ChrStats,
/// Ids of added constraints. Filled by `add_constraint` and emptied by `clear`;
/// the rule loop in this file does not pop from it.
work_queue: VecDeque<usize>,
}
impl Default for ChrEngine {
fn default() -> Self {
Self::new()
}
}
impl ChrEngine {
    /// Creates an engine with the default [`ChrConfig`].
    pub fn new() -> Self {
        Self::with_config(ChrConfig::default())
    }

    /// Creates an engine with the given configuration, no rules and an empty store.
    pub fn with_config(config: ChrConfig) -> Self {
        Self {
            rules: Vec::new(),
            store: ConstraintStore::new(),
            history: PropagationHistory::new(),
            config,
            stats: ChrStats::default(),
            work_queue: VecDeque::new(),
        }
    }

    /// Appends a rule. Rules are tried in insertion order, so earlier rules win.
    pub fn add_rule(&mut self, rule: ChrRule) {
        self.rules.push(rule);
    }

    /// Appends several rules, preserving their order.
    pub fn add_rules(&mut self, rules: Vec<ChrRule>) {
        self.rules.extend(rules);
    }

    /// Adds a constraint to the store (assigning a fresh id) and counts it.
    pub fn add_constraint(&mut self, constraint: Constraint) {
        let id = self.store.add(constraint);
        self.work_queue.push_back(id);
        self.stats.constraints_added += 1;
    }

    /// Adds several constraints in order.
    pub fn add_constraints(&mut self, constraints: Vec<Constraint>) {
        for c in constraints {
            self.add_constraint(c);
        }
    }

    /// Live constraints currently in the store, in insertion order.
    pub fn constraints(&self) -> Vec<&Constraint> {
        self.store.all()
    }

    /// Statistics accumulated since construction or the last [`ChrEngine::clear`].
    pub fn stats(&self) -> &ChrStats {
        &self.stats
    }

    /// Resets store, propagation history, statistics and work queue. Rules are kept.
    pub fn clear(&mut self) {
        self.store.clear();
        self.history.clear();
        self.stats = ChrStats::default();
        self.work_queue.clear();
    }

    /// Fires rules until none applies or `config.max_iterations` is reached,
    /// then returns a copy of the live constraints.
    ///
    /// Each iteration fires at most one rule instance: the first rule (in
    /// insertion order) having a head match whose guard holds and which, for
    /// propagation rules with `use_history`, has not yet fired on the same
    /// constraints.
    ///
    /// Hitting the iteration limit is *not* reported as an error; the result may
    /// then not be a fixpoint. Compare [`ChrStats::iterations`] with the limit
    /// to detect this.
    ///
    /// # Errors
    /// Propagates errors from matching/application (currently always `Ok`).
    pub fn solve(&mut self) -> Result<Vec<Constraint>> {
        self.stats.iterations = 0;
        while self.stats.iterations < self.config.max_iterations {
            self.stats.iterations += 1;
            let mut applied = false;
            for rule_idx in 0..self.rules.len() {
                if self.try_apply_rule(rule_idx)? {
                    applied = true;
                    break; // restart from the first rule after every firing
                }
            }
            if !applied {
                break;
            }
        }
        // The queue is never consumed by this rule loop; drop the ids so it does
        // not grow without bound across repeated add/solve cycles.
        self.work_queue.clear();
        Ok(self.store.all().into_iter().cloned().collect())
    }

    /// Key under which propagation history is recorded for rule `rule_idx`.
    ///
    /// Includes the rule's index so distinct rules that share a name (e.g. the
    /// parser's default `"rule"`) do not suppress each other. Indices are stable
    /// because rules are only ever appended.
    fn history_key(rule_idx: usize, rule: &ChrRule) -> String {
        format!("{}#{}", rule_idx, rule.name)
    }

    /// Tries to fire rule `rule_idx` once; returns whether it fired.
    fn try_apply_rule(&mut self, rule_idx: usize) -> Result<bool> {
        let rule = self.rules[rule_idx].clone();
        let history_key = Self::history_key(rule_idx, &rule);
        let is_propagation = rule.rule_type == ChrRuleType::Propagation;
        let matches = self.find_matching_constraints(&rule)?;
        for (subst, matched_ids) in matches {
            if is_propagation
                && self.config.use_history
                && self.history.has_fired(&history_key, &matched_ids)
            {
                continue;
            }
            if !rule.guard.evaluate(&subst) {
                continue;
            }
            self.apply_rule(&rule, &history_key, &subst, &matched_ids)?;
            return Ok(true);
        }
        Ok(false)
    }

    /// All ways of matching the rule's heads (kept first, then removed) against
    /// distinct live constraints, as `(substitution, ids in head order)`.
    /// A rule without heads yields a single empty match.
    fn find_matching_constraints(&self, rule: &ChrRule) -> Result<Vec<(Substitution, Vec<usize>)>> {
        let heads = rule.all_head();
        if heads.is_empty() {
            return Ok(vec![(Substitution::new(), vec![])]);
        }
        self.match_remaining(&heads, &Substitution::new(), &[])
    }

    /// Recursively extends a partial match (`subst`, `matched_ids`) over the
    /// `remaining` head patterns, never reusing an already-matched constraint.
    fn match_remaining(
        &self,
        remaining: &[&Constraint],
        subst: &Substitution,
        matched_ids: &[usize],
    ) -> Result<Vec<(Substitution, Vec<usize>)>> {
        let (pattern, rest) = match remaining.split_first() {
            Some(split) => split,
            None => return Ok(vec![(subst.clone(), matched_ids.to_vec())]),
        };
        let mut results = Vec::new();
        for candidate in self.store.by_name(&pattern.name) {
            if matched_ids.contains(&candidate.id) {
                continue;
            }
            let mut extended = subst.clone();
            // Match against the unsubstituted pattern: `matches` already checks
            // bound pattern variables against `extended`. Substituting first would
            // let variable names occurring inside stored terms be captured by
            // pattern variables of the same name.
            if candidate.matches(pattern, &mut extended) {
                let mut ids = matched_ids.to_vec();
                ids.push(candidate.id);
                results.extend(self.match_remaining(rest, &extended, &ids)?);
            }
        }
        Ok(results)
    }

    /// Fires `rule` for one match: records propagation history, removes the
    /// removed-head matches and adds the substituted body.
    ///
    /// `matched_ids` lists kept-head matches first (the order of
    /// [`ChrRule::all_head`]), so only ids after `kept_head.len()` are removed.
    /// Body constraints other than `=` equations are skipped when an identical
    /// (name + args) constraint is already live.
    fn apply_rule(
        &mut self,
        rule: &ChrRule,
        history_key: &str,
        subst: &Substitution,
        matched_ids: &[usize],
    ) -> Result<()> {
        self.stats.rule_applications += 1;
        match rule.rule_type {
            ChrRuleType::Propagation => {
                if self.config.use_history {
                    self.history.record(history_key, matched_ids);
                }
                self.stats.propagations += 1;
            }
            ChrRuleType::Simplification | ChrRuleType::Simpagation => {
                for id in matched_ids.iter().skip(rule.kept_head.len()) {
                    self.store.remove(*id);
                    self.stats.constraints_removed += 1;
                }
                self.stats.simplifications += 1;
            }
        }
        for body_constraint in &rule.body {
            let new_constraint = body_constraint.apply_subst(subst);
            let is_equation = new_constraint.name == "=" && new_constraint.args.len() == 2;
            if is_equation || !self.store.contains(&new_constraint) {
                self.add_constraint(new_constraint);
            }
        }
        Ok(())
    }

    /// Converts an OxiRS term to a CHR term. Constants that parse as `i64`
    /// become `ChrTerm::Int`; literals become `ChrTerm::Const`.
    pub fn term_from_oxirs(term: &Term) -> ChrTerm {
        match term {
            Term::Variable(v) => ChrTerm::Var(v.clone()),
            Term::Constant(c) => {
                if let Ok(n) = c.parse::<i64>() {
                    ChrTerm::Int(n)
                } else {
                    ChrTerm::Const(c.clone())
                }
            }
            Term::Literal(l) => ChrTerm::Const(l.clone()),
            Term::Function { name, args } => ChrTerm::Func(
                name.clone(),
                args.iter().map(Self::term_from_oxirs).collect(),
            ),
        }
    }

    /// Converts a CHR term back to an OxiRS term. Integers become decimal
    /// `Term::Constant`s, so a `Term::Literal` does not round-trip.
    pub fn term_to_oxirs(term: &ChrTerm) -> Term {
        match term {
            ChrTerm::Var(v) => Term::Variable(v.clone()),
            ChrTerm::Const(c) => Term::Constant(c.clone()),
            ChrTerm::Int(n) => Term::Constant(n.to_string()),
            ChrTerm::Func(name, args) => Term::Function {
                name: name.clone(),
                args: args.iter().map(Self::term_to_oxirs).collect(),
            },
        }
    }
}
/// Text parser for CHR rules of the form `name: head <=> guard | body` or
/// `name: head ==> guard | body` (see `ChrParser::parse_rule`).
pub struct ChrParser;
impl ChrParser {
    /// Parses one CHR rule written as `[name:] head (<=> | ==>) [guard |] body[.]`.
    ///
    /// * Without a `name:` prefix the rule is named `"rule"`.
    /// * A head `kept \ removed` with both sides non-empty yields a simpagation
    ///   rule. Otherwise `<=>` gives a simplification rule (head removed) and
    ///   `==>` gives a propagation rule (head kept).
    /// * The guard may be a comma-separated conjunction, e.g. `X > 0, Y \= Z`.
    /// * A body of `true` (or nothing) adds no constraints.
    /// * One trailing `.` (the usual CHR rule terminator) is ignored.
    ///
    /// # Errors
    /// Fails when neither `<=>` nor `==>` occurs, when a term is empty, or
    /// when a compound term's parentheses are unbalanced.
    pub fn parse_rule(input: &str) -> Result<ChrRule> {
        let input = input.trim();
        let input = input.strip_suffix('.').unwrap_or(input).trim_end();

        let (name, rest) = match input.find(':') {
            Some(colon_pos) => (
                input[..colon_pos].trim().to_string(),
                input[colon_pos + 1..].trim(),
            ),
            None => ("rule".to_string(), input),
        };

        // `<=>` takes precedence over `==>` should both occur.
        let (rule_type, head_str, body_str) = if let Some((head, body)) = rest.split_once("<=>") {
            (ChrRuleType::Simplification, head.trim(), body.trim())
        } else if let Some((head, body)) = rest.split_once("==>") {
            (ChrRuleType::Propagation, head.trim(), body.trim())
        } else {
            return Err(anyhow!("Invalid CHR rule syntax: missing <=> or ==>"));
        };

        let (kept_head, removed_head) = match head_str.split_once('\\') {
            Some((kept, removed)) => (
                Self::parse_constraints(kept.trim())?,
                Self::parse_constraints(removed.trim())?,
            ),
            None => {
                let constraints = Self::parse_constraints(head_str)?;
                match rule_type {
                    ChrRuleType::Propagation => (constraints, vec![]),
                    ChrRuleType::Simplification | ChrRuleType::Simpagation => (vec![], constraints),
                }
            }
        };

        let (guard, body) = match body_str.split_once('|') {
            Some((guard, body)) => (
                Self::parse_guard(guard.trim())?,
                Self::parse_constraints(body.trim())?,
            ),
            None => (Guard::True, Self::parse_constraints(body_str)?),
        };

        let rule_type = if !kept_head.is_empty() && !removed_head.is_empty() {
            ChrRuleType::Simpagation
        } else {
            rule_type
        };

        Ok(ChrRule {
            name,
            rule_type,
            kept_head,
            removed_head,
            guard,
            body,
            priority: 0,
        })
    }

    /// Splits `input` at commas not nested inside parentheses, returning the
    /// trimmed, non-empty pieces.
    fn split_top_level(input: &str) -> Vec<&str> {
        let mut parts = Vec::new();
        let mut depth = 0i32;
        let mut start = 0;
        for (i, c) in input.char_indices() {
            match c {
                '(' => depth += 1,
                ')' => depth -= 1,
                ',' if depth == 0 => {
                    parts.push(input[start..i].trim());
                    start = i + 1;
                }
                _ => {}
            }
        }
        parts.push(input[start..].trim());
        parts.retain(|part| !part.is_empty());
        parts
    }

    /// Byte offset of the first `op` that is not nested inside parentheses.
    fn find_top_level(input: &str, op: &str) -> Option<usize> {
        let mut depth = 0i32;
        for (i, c) in input.char_indices() {
            match c {
                '(' => depth += 1,
                ')' => depth -= 1,
                _ if depth == 0 && input[i..].starts_with(op) => return Some(i),
                _ => {}
            }
        }
        None
    }

    /// Splits `lhs op rhs` at the first top-level `op`, trimming both sides.
    fn split_infix<'a>(input: &'a str, op: &str) -> Option<(&'a str, &'a str)> {
        Self::find_top_level(input, op)
            .map(|pos| (input[..pos].trim(), input[pos + op.len()..].trim()))
    }

    /// `true` when every `)` closes an earlier `(` and all are closed.
    fn is_balanced(input: &str) -> bool {
        let mut depth = 0i32;
        for c in input.chars() {
            match c {
                '(' => depth += 1,
                ')' => {
                    depth -= 1;
                    if depth < 0 {
                        return false;
                    }
                }
                _ => {}
            }
        }
        depth == 0
    }

    /// Splits `name(args)` into `(name, args)`; `Ok(None)` when there is no `(`.
    ///
    /// Exactly one closing `)` must end the input and the argument text must be
    /// balanced, otherwise an error is returned.
    fn split_call(input: &str) -> Result<Option<(&str, &str)>> {
        let open = match input.find('(') {
            Some(open) => open,
            None => return Ok(None),
        };
        let args = input[open + 1..]
            .strip_suffix(')')
            .filter(|args| Self::is_balanced(args))
            .ok_or_else(|| anyhow!("Unbalanced parentheses in `{}`", input))?;
        Ok(Some((input[..open].trim(), args)))
    }

    /// Parses a comma-separated constraint list; `true` entries are skipped.
    fn parse_constraints(input: &str) -> Result<Vec<Constraint>> {
        Self::split_top_level(input)
            .into_iter()
            .filter(|part| *part != "true")
            .map(Self::parse_constraint)
            .collect()
    }

    /// Parses one constraint: `lhs \= rhs`, `lhs = rhs`, `name(args)` or a bare name.
    fn parse_constraint(input: &str) -> Result<Constraint> {
        let input = input.trim();
        // `\=` must be recognised before `=`, which it contains.
        if let Some((lhs, rhs)) = Self::split_infix(input, "\\=") {
            return Ok(Constraint::new(
                "\\=",
                vec![Self::parse_term(lhs)?, Self::parse_term(rhs)?],
            ));
        }
        if let Some((lhs, rhs)) = Self::split_infix(input, "=") {
            return Ok(Constraint::new(
                "=",
                vec![Self::parse_term(lhs)?, Self::parse_term(rhs)?],
            ));
        }
        match Self::split_call(input)? {
            Some((name, args)) => Ok(Constraint::new(name, Self::parse_args(args)?)),
            None => Ok(Constraint::new(input, vec![])),
        }
    }

    /// Parses a comma-separated argument list into terms.
    fn parse_args(input: &str) -> Result<Vec<ChrTerm>> {
        Self::split_top_level(input)
            .into_iter()
            .map(Self::parse_term)
            .collect()
    }

    /// Parses a term: `f(args)`, an integer, an uppercase-initial variable, or a constant.
    fn parse_term(input: &str) -> Result<ChrTerm> {
        let input = input.trim();
        if input.is_empty() {
            return Err(anyhow!("Empty term"));
        }
        if let Some((name, args)) = Self::split_call(input)? {
            return Ok(ChrTerm::Func(name.to_string(), Self::parse_args(args)?));
        }
        if let Ok(n) = input.parse::<i64>() {
            return Ok(ChrTerm::Int(n));
        }
        let is_variable = input.chars().next().map_or(false, char::is_uppercase);
        if is_variable {
            Ok(ChrTerm::Var(input.to_string()))
        } else {
            Ok(ChrTerm::Const(input.to_string()))
        }
    }

    /// Parses a guard, which may be a comma-separated conjunction.
    ///
    /// An empty guard or `true` yields `Guard::True`; a single condition is
    /// returned as-is; several become `Guard::And`.
    fn parse_guard(input: &str) -> Result<Guard> {
        let mut guards = Self::split_top_level(input)
            .into_iter()
            .filter(|part| *part != "true")
            .map(Self::parse_single_guard)
            .collect::<Result<Vec<_>>>()?;
        Ok(match guards.len() {
            0 => Guard::True,
            1 => guards.swap_remove(0),
            _ => Guard::And(guards),
        })
    }

    /// Parses one guard condition `lhs OP rhs`, falling back to a builtin name.
    fn parse_single_guard(input: &str) -> Result<Guard> {
        // Multi-character operators come before the operators they contain
        // (`\=`, `==` before `=`; `>=` before `>`; `=<`, `<=` before `<`).
        const OPERATORS: [&str; 8] = ["\\=", "==", ">=", "=<", "<=", ">", "<", "="];
        for &op in &OPERATORS {
            if let Some((lhs, rhs)) = Self::split_infix(input, op) {
                let lhs = Self::parse_term(lhs)?;
                let rhs = Self::parse_term(rhs)?;
                return Ok(match op {
                    "\\=" => Guard::NotEqual(lhs, rhs),
                    "==" | "=" => Guard::Equal(lhs, rhs),
                    ">=" => Guard::GreaterEq(lhs, rhs),
                    "=<" | "<=" => Guard::LessEq(lhs, rhs),
                    ">" => Guard::GreaterThan(lhs, rhs),
                    "<" => Guard::LessThan(lhs, rhs),
                    _ => unreachable!("OPERATORS and this match are kept in sync"),
                });
            }
        }
        Ok(Guard::Builtin(input.to_string(), vec![]))
    }
}
// Unit tests for terms, constraints, the store, the engine and the parser.
#[cfg(test)]
mod tests {
use super::*;
// Constructors and accessors of `ChrTerm`.
#[test]
fn test_chr_term_creation() {
let var = ChrTerm::var("X");
assert!(var.is_var());
assert_eq!(var.var_name(), Some("X"));
let const_ = ChrTerm::const_("foo");
assert!(!const_.is_var());
let int = ChrTerm::int(42);
assert!(matches!(int, ChrTerm::Int(42)));
}
// A bound variable is replaced by its binding.
#[test]
fn test_chr_term_substitution() {
let var = ChrTerm::var("X");
let mut subst = Substitution::new();
subst.insert("X".to_string(), ChrTerm::const_("value"));
let result = var.apply_subst(&subst);
assert_eq!(result, ChrTerm::const_("value"));
}
#[test]
fn test_constraint_creation() {
let c = Constraint::binary("leq", "X", "Y");
assert_eq!(c.name, "leq");
assert_eq!(c.args.len(), 2);
}
// Matching a ground constraint against a pattern binds the pattern variables.
#[test]
fn test_constraint_matching() {
let pattern = Constraint::binary("leq", "X", "Y");
let instance = Constraint::new("leq", vec![ChrTerm::const_("a"), ChrTerm::const_("b")]);
let mut subst = Substitution::new();
assert!(instance.matches(&pattern, &mut subst));
assert_eq!(subst.get("X"), Some(&ChrTerm::const_("a")));
assert_eq!(subst.get("Y"), Some(&ChrTerm::const_("b")));
}
// The simplification builder puts the whole head in `removed_head`.
#[test]
fn test_chr_rule_simplification() {
let rule = ChrRule::simplification(
"reflexivity",
vec![Constraint::binary("leq", "X", "X")],
vec![],
vec![], );
assert_eq!(rule.rule_type, ChrRuleType::Simplification);
assert!(rule.kept_head.is_empty());
assert_eq!(rule.removed_head.len(), 1);
}
// The propagation builder puts the whole head in `kept_head`.
#[test]
fn test_chr_rule_propagation() {
let rule = ChrRule::propagation(
"transitivity",
vec![
Constraint::binary("leq", "X", "Y"),
Constraint::binary("leq", "Y", "Z"),
],
vec![Guard::NotEqual(ChrTerm::var("X"), ChrTerm::var("Z"))],
vec![Constraint::binary("leq", "X", "Z")],
);
assert_eq!(rule.rule_type, ChrRuleType::Propagation);
assert_eq!(rule.kept_head.len(), 2);
}
// Removal hides a constraint from `get` and decrements `len`.
#[test]
fn test_constraint_store() {
let mut store = ConstraintStore::new();
let id1 = store.add(Constraint::binary("leq", "a", "b"));
let id2 = store.add(Constraint::binary("leq", "b", "c"));
assert_eq!(store.len(), 2);
assert!(store.get(id1).is_some());
assert!(store.get(id2).is_some());
store.remove(id1);
assert_eq!(store.len(), 1);
assert!(store.get(id1).is_none());
}
// leq(a, a) is removed by the reflexivity rule, leaving an empty store.
#[test]
fn test_chr_engine_basic() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ChrEngine::new();
engine.add_rule(ChrRule::simplification(
"reflexivity",
vec![Constraint::binary("leq", "X", "X")],
vec![],
vec![],
));
engine.add_constraint(Constraint::new(
"leq",
vec![ChrTerm::const_("a"), ChrTerm::const_("a")],
));
let result = engine.solve()?;
assert!(result.is_empty());
Ok(())
}
// leq(a, b) and leq(b, a) are replaced by an equation.
#[test]
fn test_chr_antisymmetry() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ChrEngine::new();
engine.add_rule(ChrRule::simplification(
"antisymmetry",
vec![
Constraint::binary("leq", "X", "Y"),
Constraint::binary("leq", "Y", "X"),
],
vec![],
vec![Constraint::eq("X", "Y")],
));
engine.add_constraint(Constraint::new(
"leq",
vec![ChrTerm::const_("a"), ChrTerm::const_("b")],
));
engine.add_constraint(Constraint::new(
"leq",
vec![ChrTerm::const_("b"), ChrTerm::const_("a")],
));
let result = engine.solve()?;
assert!(result.iter().any(|c| c.name == "="));
Ok(())
}
// Transitivity derives leq(a, c) from leq(a, b) and leq(b, c).
#[test]
fn test_chr_propagation() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ChrEngine::new();
engine.add_rule(ChrRule::propagation(
"transitivity",
vec![
Constraint::binary("leq", "X", "Y"),
Constraint::binary("leq", "Y", "Z"),
],
vec![],
vec![Constraint::binary("leq", "X", "Z")],
));
engine.add_constraint(Constraint::new(
"leq",
vec![ChrTerm::const_("a"), ChrTerm::const_("b")],
));
engine.add_constraint(Constraint::new(
"leq",
vec![ChrTerm::const_("b"), ChrTerm::const_("c")],
));
let result = engine.solve()?;
assert!(result.iter().any(|c| {
c.name == "leq"
&& c.args.len() == 2
&& c.args[0] == ChrTerm::const_("a")
&& c.args[1] == ChrTerm::const_("c")
}));
Ok(())
}
// An equality guard follows the current binding of X.
#[test]
fn test_chr_guard() {
let guard = Guard::Equal(ChrTerm::var("X"), ChrTerm::const_("value"));
let mut subst = Substitution::new();
subst.insert("X".to_string(), ChrTerm::const_("value"));
assert!(guard.evaluate(&subst));
subst.insert("X".to_string(), ChrTerm::const_("other"));
assert!(!guard.evaluate(&subst));
}
#[test]
fn test_chr_parser_simplification() -> Result<(), Box<dyn std::error::Error>> {
let rule = ChrParser::parse_rule("reflexivity: leq(X, X) <=> true")?;
assert_eq!(rule.name, "reflexivity");
assert_eq!(rule.rule_type, ChrRuleType::Simplification);
assert_eq!(rule.removed_head.len(), 1);
Ok(())
}
#[test]
fn test_chr_parser_propagation() -> Result<(), Box<dyn std::error::Error>> {
let rule = ChrParser::parse_rule("trans: leq(X, Y), leq(Y, Z) ==> leq(X, Z)")?;
assert_eq!(rule.name, "trans");
assert_eq!(rule.rule_type, ChrRuleType::Propagation);
assert_eq!(rule.kept_head.len(), 2);
assert_eq!(rule.body.len(), 1);
Ok(())
}
// Only checks that a rule with a `guard | body` part parses and keeps its name.
#[test]
fn test_chr_parser_with_guard() -> Result<(), Box<dyn std::error::Error>> {
let rule = ChrParser::parse_rule("idempotence: leq(X, Y), leq(X, Y) <=> true | leq(X, Y)")?;
assert_eq!(rule.name, "idempotence");
Ok(())
}
#[test]
fn test_chr_stats() -> Result<(), Box<dyn std::error::Error>> {
let mut engine = ChrEngine::new();
engine.add_rule(ChrRule::simplification(
"test",
vec![Constraint::binary("test", "X", "X")],
vec![],
vec![],
));
engine.add_constraint(Constraint::new(
"test",
vec![ChrTerm::const_("a"), ChrTerm::const_("a")],
));
engine.solve()?;
assert!(engine.stats().rule_applications > 0);
Ok(())
}
// Display checks below only assert on substrings, not the exact rendering.
#[test]
fn test_chr_term_display() {
let term = ChrTerm::Func(
"f".to_string(),
vec![ChrTerm::var("X"), ChrTerm::const_("a")],
);
let display = format!("{}", term);
assert!(display.contains("f("));
assert!(display.contains("X"));
assert!(display.contains("a"));
}
#[test]
fn test_chr_constraint_display() {
let c = Constraint::binary("leq", "X", "Y");
let display = format!("{}", c);
assert!(display.contains("leq"));
assert!(display.contains("X"));
assert!(display.contains("Y"));
}
#[test]
fn test_chr_rule_display() {
let rule = ChrRule::simplification(
"test",
vec![Constraint::binary("p", "X", "Y")],
vec![],
vec![Constraint::binary("q", "X", "Y")],
);
let display = format!("{}", rule);
assert!(display.contains("test"));
assert!(display.contains("<=>"));
}
}