use super::expression::Expression;
use crate::errors::{Result, RuleEngineError};
use crate::types::Value;
use crate::Facts;
use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct Bindings {
bindings: HashMap<String, Value>,
}
impl Bindings {
pub fn new() -> Self {
Self {
bindings: HashMap::new(),
}
}
pub fn bind(&mut self, var_name: String, value: Value) -> Result<()> {
if let Some(existing) = self.bindings.get(&var_name) {
if existing != &value {
return Err(RuleEngineError::ExecutionError(format!(
"Variable binding conflict: {} is already bound to {:?}, cannot rebind to {:?}",
var_name, existing, value
)));
}
} else {
self.bindings.insert(var_name, value);
}
Ok(())
}
pub fn get(&self, var_name: &str) -> Option<&Value> {
self.bindings.get(var_name)
}
pub fn is_bound(&self, var_name: &str) -> bool {
self.bindings.contains_key(var_name)
}
pub fn merge(&mut self, other: &Bindings) -> Result<()> {
for (var, val) in &other.bindings {
self.bind(var.clone(), val.clone())?;
}
Ok(())
}
pub fn as_map(&self) -> &HashMap<String, Value> {
&self.bindings
}
pub fn len(&self) -> usize {
self.bindings.len()
}
pub fn is_empty(&self) -> bool {
self.bindings.is_empty()
}
pub fn clear(&mut self) {
self.bindings.clear();
}
pub fn from_map(map: HashMap<String, Value>) -> Self {
Self { bindings: map }
}
pub fn into_map(self) -> HashMap<String, Value> {
self.bindings
}
pub fn to_map(&self) -> HashMap<String, Value> {
self.bindings.clone()
}
}
impl Default for Bindings {
fn default() -> Self {
Self::new()
}
}
/// Stateless helper for unifying expressions and for evaluating them against
/// [`Facts`] under a set of variable [`Bindings`].
pub struct Unifier;

impl Unifier {
    /// Structurally unifies `left` with `right`, extending `bindings`.
    ///
    /// Unification rules:
    /// - A variable always unifies with itself.
    /// - An unbound variable unifies with a literal, or with a bound variable,
    ///   by binding it to that value.
    /// - A bound variable behaves like its value.
    /// - Two *distinct* unbound variables do not unify (there is no aliasing).
    /// - A variable does not unify with a field or a compound expression.
    /// - Compound expressions unify when they have the same shape (and, for
    ///   comparisons, the same operator) and all of their children unify.
    ///
    /// Returns `Ok(true)` on success. On `Ok(false)` or `Err`, every binding
    /// made during this call is undone, so a failed attempt never leaves
    /// partial bindings behind in `bindings`.
    ///
    /// # Errors
    /// Returns an error if binding a variable fails (see [`Bindings::bind`]).
    pub fn unify(left: &Expression, right: &Expression, bindings: &mut Bindings) -> Result<bool> {
        // Variables newly bound by this call, so they can be rolled back on failure.
        let mut trail = Vec::new();
        let result = Self::unify_inner(left, right, bindings, &mut trail);
        if !matches!(result, Ok(true)) {
            for var in &trail {
                bindings.bindings.remove(var);
            }
        }
        result
    }

    /// Recursive worker for [`Unifier::unify`]. Every variable it newly binds
    /// is pushed onto `trail`.
    fn unify_inner(
        left: &Expression,
        right: &Expression,
        bindings: &mut Bindings,
        trail: &mut Vec<String>,
    ) -> Result<bool> {
        match (left, right) {
            // A variable trivially unifies with itself, whether bound or not.
            (Expression::Variable(a), Expression::Variable(b)) if a == b => Ok(true),
            (Expression::Variable(var), expr) => match bindings.get(var).cloned() {
                // Bound: unify the variable's current value with the other side.
                Some(bound) => {
                    Self::unify_inner(&Expression::Literal(bound), expr, bindings, trail)
                }
                None => match Self::expression_to_value(expr, bindings)? {
                    Some(value) => {
                        bindings.bind(var.clone(), value)?;
                        trail.push(var.clone());
                        Ok(true)
                    }
                    None => Ok(false),
                },
            },
            // Put the variable on the left and reuse the arm above.
            (_, Expression::Variable(_)) => Self::unify_inner(right, left, bindings, trail),
            (Expression::Literal(v1), Expression::Literal(v2)) => Ok(v1 == v2),
            (Expression::Field(f1), Expression::Field(f2)) => Ok(f1 == f2),
            (
                Expression::Comparison {
                    left: l1,
                    operator: op1,
                    right: r1,
                },
                Expression::Comparison {
                    left: l2,
                    operator: op2,
                    right: r2,
                },
            ) => Ok(op1 == op2
                && Self::unify_inner(l1, l2, bindings, trail)?
                && Self::unify_inner(r1, r2, bindings, trail)?),
            (
                Expression::And {
                    left: l1,
                    right: r1,
                },
                Expression::And {
                    left: l2,
                    right: r2,
                },
            )
            | (
                Expression::Or {
                    left: l1,
                    right: r1,
                },
                Expression::Or {
                    left: l2,
                    right: r2,
                },
            ) => Ok(Self::unify_inner(l1, l2, bindings, trail)?
                && Self::unify_inner(r1, r2, bindings, trail)?),
            (Expression::Not(e1), Expression::Not(e2)) => {
                Self::unify_inner(e1, e2, bindings, trail)
            }
            _ => Ok(false),
        }
    }

    /// Checks whether `expr` holds against `facts` under `bindings`.
    ///
    /// - A variable matches if it is bound.
    /// - A field matches if it is present in `facts`.
    /// - A literal always matches.
    /// - A comparison matches if its evaluated operands satisfy the operator.
    /// - `And`/`Or`/`Not` combine their children with short-circuiting.
    ///
    /// `bindings` is only read, never modified.
    ///
    /// # Errors
    /// Returns an error if a comparison operand cannot be evaluated (unbound
    /// variable or missing field), if the operator is unsupported, or if the
    /// operands cannot be ordered.
    pub fn match_expression(
        expr: &Expression,
        facts: &Facts,
        bindings: &mut Bindings,
    ) -> Result<bool> {
        match expr {
            Expression::Variable(var) => Ok(bindings.is_bound(var)),
            Expression::Field(field_name) => Ok(facts.get(field_name).is_some()),
            Expression::Literal(_) => Ok(true),
            Expression::Comparison {
                left,
                operator,
                right,
            } => {
                let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
                let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
                Self::apply_operator(operator, &left_val, &right_val)
            }
            Expression::And { left, right } => {
                Ok(Self::match_expression(left, facts, bindings)?
                    && Self::match_expression(right, facts, bindings)?)
            }
            Expression::Or { left, right } => {
                Ok(Self::match_expression(left, facts, bindings)?
                    || Self::match_expression(right, facts, bindings)?)
            }
            Expression::Not(inner) => Ok(!Self::match_expression(inner, facts, bindings)?),
        }
    }

    /// Evaluates `expr` to a [`Value`] using `facts` for fields and `bindings`
    /// for variables.
    ///
    /// Comparisons and logical operators yield `Value::Boolean`. Logical
    /// operators short-circuit and use `Value::to_bool` for truthiness.
    ///
    /// # Errors
    /// Returns an error for an unbound variable, a missing field, an
    /// unsupported operator, or operands that cannot be ordered.
    pub fn evaluate_with_bindings(
        expr: &Expression,
        facts: &Facts,
        bindings: &Bindings,
    ) -> Result<Value> {
        match expr {
            Expression::Variable(var) => bindings.get(var).cloned().ok_or_else(|| {
                RuleEngineError::ExecutionError(format!("Unbound variable: {}", var))
            }),
            Expression::Field(field) => facts.get(field).ok_or_else(|| {
                RuleEngineError::ExecutionError(format!("Field not found: {}", field))
            }),
            Expression::Literal(val) => Ok(val.clone()),
            Expression::Comparison {
                left,
                operator,
                right,
            } => {
                let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
                let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
                Self::apply_operator(operator, &left_val, &right_val).map(Value::Boolean)
            }
            Expression::And { left, right } => {
                let result = Self::evaluate_with_bindings(left, facts, bindings)?.to_bool()
                    && Self::evaluate_with_bindings(right, facts, bindings)?.to_bool();
                Ok(Value::Boolean(result))
            }
            Expression::Or { left, right } => {
                let result = Self::evaluate_with_bindings(left, facts, bindings)?.to_bool()
                    || Self::evaluate_with_bindings(right, facts, bindings)?.to_bool();
                Ok(Value::Boolean(result))
            }
            Expression::Not(inner) => {
                let value = Self::evaluate_with_bindings(inner, facts, bindings)?;
                Ok(Value::Boolean(!value.to_bool()))
            }
        }
    }

    /// Applies a comparison `operator` to two already-evaluated operands.
    ///
    /// This logic is shared by `match_expression` and `evaluate_with_bindings`
    /// so that both agree on operator semantics.
    ///
    /// # Errors
    /// Returns an error for operators other than `==`, `!=`, `>`, `<`, `>=`,
    /// `<=`, or when an ordering operator receives operands that
    /// `compare_values` cannot order.
    fn apply_operator(
        operator: &crate::types::Operator,
        left: &Value,
        right: &Value,
    ) -> Result<bool> {
        let result = match operator {
            crate::types::Operator::Equal => left == right,
            crate::types::Operator::NotEqual => left != right,
            crate::types::Operator::GreaterThan => Self::compare_values(left, right)? > 0,
            crate::types::Operator::LessThan => Self::compare_values(left, right)? < 0,
            crate::types::Operator::GreaterThanOrEqual => {
                Self::compare_values(left, right)? >= 0
            }
            crate::types::Operator::LessThanOrEqual => Self::compare_values(left, right)? <= 0,
            _ => {
                return Err(RuleEngineError::ExecutionError(format!(
                    "Unsupported operator: {:?}",
                    operator
                )));
            }
        };
        Ok(result)
    }

    /// Converts `expr` to a concrete value when that is possible without facts.
    ///
    /// Literals yield their value, and variables yield their binding (if any).
    /// Anything else yields `None`.
    fn expression_to_value(expr: &Expression, bindings: &Bindings) -> Result<Option<Value>> {
        match expr {
            Expression::Literal(val) => Ok(Some(val.clone())),
            Expression::Variable(var) => Ok(bindings.get(var).cloned()),
            _ => Ok(None),
        }
    }

    /// Three-way compares two values.
    ///
    /// Returns `-1`, `0` or `1`. Supports number/number, string/string and
    /// bool/bool pairs.
    ///
    /// # Errors
    /// Returns an error for mismatched or non-orderable types. It also errors
    /// for NaN numbers: NaN has no ordering, and treating it as "equal" (as a
    /// plain `<`/`>` chain would) makes `NaN >= x` and `NaN <= x` both true.
    fn compare_values(left: &Value, right: &Value) -> Result<i32> {
        let ordering = match (left, right) {
            (Value::Number(a), Value::Number(b)) => a.partial_cmp(b),
            (Value::String(a), Value::String(b)) => Some(a.cmp(b)),
            (Value::Boolean(a), Value::Boolean(b)) => Some(a.cmp(b)),
            _ => None,
        };
        ordering.map(|o| o as i32).ok_or_else(|| {
            RuleEngineError::ExecutionError(format!(
                "Cannot compare values: {:?} and {:?}",
                left, right
            ))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Operator;

    #[test]
    fn test_bindings_basic() {
        let mut bindings = Bindings::new();
        assert!(bindings.is_empty());
        assert_eq!(bindings.len(), 0);
        bindings.bind("X".to_string(), Value::Number(42.0)).unwrap();
        assert!(!bindings.is_empty());
        assert_eq!(bindings.len(), 1);
        assert!(bindings.is_bound("X"));
        assert_eq!(bindings.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_bindings_conflict() {
        let mut bindings = Bindings::new();
        bindings.bind("X".to_string(), Value::Number(42.0)).unwrap();
        assert!(bindings.bind("X".to_string(), Value::Number(42.0)).is_ok());
        assert!(bindings
            .bind("X".to_string(), Value::Number(100.0))
            .is_err());
        // A rejected re-bind must not clobber the original value.
        assert_eq!(bindings.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_bindings_merge() {
        let mut bindings1 = Bindings::new();
        let mut bindings2 = Bindings::new();
        bindings1
            .bind("X".to_string(), Value::Number(42.0))
            .unwrap();
        bindings2
            .bind("Y".to_string(), Value::String("hello".to_string()))
            .unwrap();
        bindings1.merge(&bindings2).unwrap();
        assert_eq!(bindings1.len(), 2);
        assert_eq!(bindings1.get("X"), Some(&Value::Number(42.0)));
        assert_eq!(
            bindings1.get("Y"),
            Some(&Value::String("hello".to_string()))
        );
    }

    #[test]
    fn test_bindings_merge_conflict() {
        let mut bindings1 = Bindings::new();
        let mut bindings2 = Bindings::new();
        bindings1
            .bind("X".to_string(), Value::Number(42.0))
            .unwrap();
        bindings2
            .bind("X".to_string(), Value::Number(100.0))
            .unwrap();
        assert!(bindings1.merge(&bindings2).is_err());
        assert_eq!(bindings1.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_unify_variable_with_literal() {
        let mut bindings = Bindings::new();
        let var = Expression::Variable("X".to_string());
        let lit = Expression::Literal(Value::Number(42.0));
        let result = Unifier::unify(&var, &lit, &mut bindings).unwrap();
        assert!(result);
        assert_eq!(bindings.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_unify_bound_variable() {
        let mut bindings = Bindings::new();
        bindings.bind("X".to_string(), Value::Number(42.0)).unwrap();
        let var = Expression::Variable("X".to_string());
        let lit = Expression::Literal(Value::Number(42.0));
        let result = Unifier::unify(&var, &lit, &mut bindings).unwrap();
        assert!(result);
        // A bound variable is compared by value, so a mismatch is a clean
        // `Ok(false)` rather than a binding-conflict error.
        let lit2 = Expression::Literal(Value::Number(100.0));
        let result2 = Unifier::unify(&var, &lit2, &mut bindings).unwrap();
        assert!(!result2);
        assert_eq!(bindings.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_unify_two_literals() {
        let mut bindings = Bindings::new();
        let lit1 = Expression::Literal(Value::Number(42.0));
        let lit2 = Expression::Literal(Value::Number(42.0));
        let lit3 = Expression::Literal(Value::Number(100.0));
        assert!(Unifier::unify(&lit1, &lit2, &mut bindings).unwrap());
        assert!(!Unifier::unify(&lit1, &lit3, &mut bindings).unwrap());
    }

    #[test]
    fn test_unify_comparison_structurally() {
        let pattern = Expression::Comparison {
            left: Box::new(Expression::Variable("X".to_string())),
            operator: Operator::Equal,
            right: Box::new(Expression::Literal(Value::Number(10.0))),
        };
        let same_op = Expression::Comparison {
            left: Box::new(Expression::Literal(Value::Number(7.0))),
            operator: Operator::Equal,
            right: Box::new(Expression::Literal(Value::Number(10.0))),
        };
        let mut bindings = Bindings::new();
        assert!(Unifier::unify(&pattern, &same_op, &mut bindings).unwrap());
        assert_eq!(bindings.get("X"), Some(&Value::Number(7.0)));

        let other_op = Expression::Comparison {
            left: Box::new(Expression::Literal(Value::Number(7.0))),
            operator: Operator::NotEqual,
            right: Box::new(Expression::Literal(Value::Number(10.0))),
        };
        let mut fresh = Bindings::new();
        assert!(!Unifier::unify(&pattern, &other_op, &mut fresh).unwrap());
        assert!(fresh.is_empty());
    }

    #[test]
    fn test_match_expression_simple() {
        let facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));
        let mut bindings = Bindings::new();
        let expr = Expression::Comparison {
            left: Box::new(Expression::Field("User.IsVIP".to_string())),
            operator: Operator::Equal,
            right: Box::new(Expression::Literal(Value::Boolean(true))),
        };
        let result = Unifier::match_expression(&expr, &facts, &mut bindings).unwrap();
        assert!(result);
    }

    #[test]
    fn test_match_expression_logical_operators() {
        let facts = Facts::new();
        facts.set("Order.Amount", Value::Number(100.0));
        let mut bindings = Bindings::new();
        bindings.bind("X".to_string(), Value::Number(1.0)).unwrap();
        let amount_is_100 = || Expression::Comparison {
            left: Box::new(Expression::Field("Order.Amount".to_string())),
            operator: Operator::Equal,
            right: Box::new(Expression::Literal(Value::Number(100.0))),
        };

        let and_expr = Expression::And {
            left: Box::new(amount_is_100()),
            right: Box::new(Expression::Not(Box::new(Expression::Field(
                "Order.Missing".to_string(),
            )))),
        };
        assert!(Unifier::match_expression(&and_expr, &facts, &mut bindings).unwrap());

        let or_expr = Expression::Or {
            left: Box::new(Expression::Field("Order.Missing".to_string())),
            right: Box::new(Expression::Variable("X".to_string())),
        };
        assert!(Unifier::match_expression(&or_expr, &facts, &mut bindings).unwrap());

        let not_expr = Expression::Not(Box::new(amount_is_100()));
        assert!(!Unifier::match_expression(&not_expr, &facts, &mut bindings).unwrap());
    }

    #[test]
    fn test_evaluate_with_bindings() {
        let facts = Facts::new();
        facts.set("Order.Amount", Value::Number(100.0));
        let mut bindings = Bindings::new();
        bindings.bind("X".to_string(), Value::Number(50.0)).unwrap();
        let var_expr = Expression::Variable("X".to_string());
        let result = Unifier::evaluate_with_bindings(&var_expr, &facts, &bindings).unwrap();
        assert_eq!(result, Value::Number(50.0));
        let field_expr = Expression::Field("Order.Amount".to_string());
        let result = Unifier::evaluate_with_bindings(&field_expr, &facts, &bindings).unwrap();
        assert_eq!(result, Value::Number(100.0));
    }

    #[test]
    fn test_evaluate_with_bindings_errors() {
        let facts = Facts::new();
        let bindings = Bindings::new();
        let unbound = Expression::Variable("Y".to_string());
        assert!(Unifier::evaluate_with_bindings(&unbound, &facts, &bindings).is_err());
        let missing = Expression::Field("Order.Missing".to_string());
        assert!(Unifier::evaluate_with_bindings(&missing, &facts, &bindings).is_err());
    }

    #[test]
    fn test_compare_values() {
        assert_eq!(
            Unifier::compare_values(&Value::Number(10.0), &Value::Number(20.0)).unwrap(),
            -1
        );
        assert_eq!(
            Unifier::compare_values(&Value::Number(20.0), &Value::Number(10.0)).unwrap(),
            1
        );
        assert_eq!(
            Unifier::compare_values(&Value::Number(10.0), &Value::Number(10.0)).unwrap(),
            0
        );
    }

    #[test]
    fn test_compare_values_strings_booleans_and_mismatch() {
        assert_eq!(
            Unifier::compare_values(
                &Value::String("a".to_string()),
                &Value::String("b".to_string())
            )
            .unwrap(),
            -1
        );
        assert_eq!(
            Unifier::compare_values(&Value::Boolean(true), &Value::Boolean(false)).unwrap(),
            1
        );
        assert!(Unifier::compare_values(
            &Value::Number(1.0),
            &Value::String("1".to_string())
        )
        .is_err());
    }
}