use std::collections::HashMap;
use std::fmt;
use tensorlogic_ir::TLExpr;
/// Types that inference can assign to a logic expression.
///
/// `Relation` carries an arity and `Var` carries a type-variable id; the
/// remaining variants are plain tags.
#[derive(Debug, Clone, PartialEq)]
pub enum TLType {
    /// Truth-valued expression.
    Bool,
    /// Numeric expression.
    Numeric,
    /// Relation with the given arity.
    Relation(usize),
    /// Set-valued expression.
    Set,
    /// Fuzzy expression.
    Fuzzy,
    /// Probabilistic expression.
    Probabilistic,
    /// Type variable, identified by its numeric id.
    Var(usize),
    /// Type that could not be determined.
    Unknown,
}

impl TLType {
    /// Returns `false` for a type variable and `true` for every other
    /// variant (including `Unknown`).
    pub fn is_ground(&self) -> bool {
        // Spelled out exhaustively so that adding a variant forces a decision here.
        match self {
            Self::Var(_) => false,
            Self::Bool
            | Self::Numeric
            | Self::Relation(_)
            | Self::Set
            | Self::Fuzzy
            | Self::Probabilistic
            | Self::Unknown => true,
        }
    }

    /// Returns the variant's name without its payload, so every `Relation`
    /// is `"Relation"` and every `Var` is `"Var"`.
    pub fn display_name(&self) -> &'static str {
        match self {
            Self::Bool => "Bool",
            Self::Numeric => "Numeric",
            Self::Relation(_) => "Relation",
            Self::Set => "Set",
            Self::Fuzzy => "Fuzzy",
            Self::Probabilistic => "Probabilistic",
            Self::Var(_) => "Var",
            Self::Unknown => "Unknown",
        }
    }
}
impl fmt::Display for TLType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TLType::Relation(n) => write!(f, "Relation({})", n),
TLType::Var(id) => write!(f, "Var({})", id),
other => write!(f, "{}", other.display_name()),
}
}
}
/// Source of fresh type variables with consecutive ids, starting at 0.
#[derive(Debug, Default)]
pub struct TyVarCounter {
    /// Id that the next call to `fresh` will hand out.
    next: usize,
}

impl TyVarCounter {
    /// Creates a counter whose first variable is `Var(0)`.
    pub fn new() -> Self {
        Self { next: 0 }
    }

    /// Returns a type variable with the current id and advances the counter.
    pub fn fresh(&mut self) -> TLType {
        let var = TLType::Var(self.next);
        self.next += 1;
        var
    }
}
/// Mapping from type-variable ids to the types bound to them.
#[derive(Debug, Default, Clone)]
pub struct Substitution {
    /// Bindings keyed by type-variable id.
    map: HashMap<usize, TLType>,
}

impl Substitution {
    /// Creates an empty substitution.
    pub fn new() -> Self {
        Self::default()
    }

    /// Binds `var` to `ty`, replacing any earlier binding for `var`.
    pub fn bind(&mut self, var: usize, ty: TLType) {
        self.map.insert(var, ty);
    }

    /// Returns the type directly bound to `var`, without following chains.
    pub fn lookup(&self, var: usize) -> Option<&TLType> {
        self.map.get(&var)
    }

    /// Resolves `ty` under this substitution.
    ///
    /// A non-variable type is returned unchanged. A variable is followed
    /// through variable-to-variable bindings until it reaches either a
    /// non-variable type (returned) or an unbound variable (returned as is).
    /// If the chain loops back onto a variable already visited, that variable
    /// is returned instead of looping forever.
    pub fn apply(&self, ty: &TLType) -> TLType {
        match ty {
            TLType::Var(id) => {
                let mut current_id = *id;
                // Ids already seen on this walk; guards against cyclic bindings.
                let mut visited = Vec::new();
                loop {
                    if visited.contains(&current_id) {
                        return TLType::Var(current_id);
                    }
                    visited.push(current_id);
                    match self.map.get(&current_id) {
                        None => return TLType::Var(current_id),
                        Some(TLType::Var(next_id)) => current_id = *next_id,
                        Some(other) => return other.clone(),
                    }
                }
            }
            other => other.clone(),
        }
    }

    /// Number of bindings stored.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` when no variable is bound.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }
}
/// Errors reported by type inference and unification.
#[derive(Debug)]
pub enum TypeInferError {
/// Two types could not be unified. Both sides are stored already rendered
/// as strings; `unify` fills `expected` from its first argument and `got`
/// from its second.
UnificationFailed { expected: String, got: String },
/// A variable name had no binding. Not constructed in the visible source.
UnboundVariable(String),
/// The type variable with the given id occurs in the type (rendered as a
/// string) that it was about to be bound to.
OccursCheck(usize, String),
/// A named item received a different number of arguments than expected.
/// Not constructed in the visible source.
ArityMismatch {
name: String,
expected: usize,
got: usize,
},
}
impl fmt::Display for TypeInferError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TypeInferError::UnificationFailed { expected, got } => {
write!(f, "type mismatch: expected {}, got {}", expected, got)
}
TypeInferError::UnboundVariable(name) => {
write!(f, "unbound variable: {}", name)
}
TypeInferError::OccursCheck(id, ty) => {
write!(
f,
"occurs check failed: type variable Var({}) occurs in {}",
id, ty
)
}
TypeInferError::ArityMismatch {
name,
expected,
got,
} => {
write!(
f,
"arity mismatch for '{}': expected {} args, got {}",
name, expected, got
)
}
}
}
}
impl std::error::Error for TypeInferError {}
/// Typing environment: maps variable names to their types.
#[derive(Debug, Clone, Default)]
pub struct TypeEnv {
    /// Types keyed by variable name.
    bindings: HashMap<String, TLType>,
}

impl TypeEnv {
    /// Creates an environment with no bindings.
    pub fn new() -> Self {
        Self {
            bindings: HashMap::new(),
        }
    }

    /// Builder-style variant of `bind`: consumes the environment, adds (or
    /// overwrites) the binding and returns the environment.
    pub fn with(mut self, var: impl Into<String>, ty: TLType) -> Self {
        self.bind(var, ty);
        self
    }

    /// Binds `var` to `ty` in place, replacing any earlier binding.
    pub fn bind(&mut self, var: impl Into<String>, ty: TLType) {
        let name: String = var.into();
        self.bindings.insert(name, ty);
    }

    /// Returns the type bound to `var`, if any.
    pub fn lookup(&self, var: &str) -> Option<&TLType> {
        self.bindings.get(var)
    }

    /// Returns a copy of this environment with `var` bound to `ty`; the
    /// receiver is left untouched.
    pub fn extend(&self, var: impl Into<String>, ty: TLType) -> TypeEnv {
        self.clone().with(var, ty)
    }
}
/// An expression annotated with its inferred type, together with the
/// annotations of its sub-expressions.
#[derive(Debug, Clone)]
pub struct TypedExpr {
/// Copy of the expression this node describes.
pub expr: TLExpr,
/// Type assigned to `expr`; `annotate_with` stores `TLType::Unknown` here
/// when inference for this node fails.
pub ty: TLType,
/// Annotations for the sub-expressions of `expr`, in the order produced by
/// `collect_children`.
pub children: Vec<TypedExpr>,
}
/// Everything `infer_type` produces for one expression.
pub struct TypeInferResult {
/// Annotated tree for the analysed expression.
pub typed_expr: TypedExpr,
/// Substitution left behind by inferring the root expression.
pub subst: Substitution,
/// Number of bindings in `subst` whose bound type is ground (not a `Var`).
pub inferred_vars: usize,
/// Errors collected while inferring and annotating; empty on success.
pub errors: Vec<TypeInferError>,
}
/// Occurs check: returns `true` when `ty`, resolved under `subst`, is the
/// type variable `id` itself.
fn occurs_in(id: usize, ty: &TLType, subst: &Substitution) -> bool {
    matches!(subst.apply(ty), TLType::Var(found) if found == id)
}
/// Unifies `t1` with `t2`, recording any new variable binding in `subst`.
///
/// Both types are first resolved under `subst`. Identical types unify
/// trivially; an unresolved variable on either side is bound to the other
/// side after an occurs check; anything else is a mismatch.
///
/// # Errors
///
/// * `TypeInferError::UnificationFailed` when the resolved types differ;
///   `expected` is rendered from `t1` and `got` from `t2`.
/// * `TypeInferError::OccursCheck` when a variable would be bound to itself.
pub fn unify(t1: &TLType, t2: &TLType, subst: &mut Substitution) -> Result<(), TypeInferError> {
    let lhs = subst.apply(t1);
    let rhs = subst.apply(t2);
    match (&lhs, &rhs) {
        // The same variable on both sides: nothing to record.
        (TLType::Var(a), TLType::Var(b)) if a == b => Ok(()),
        // A variable on either side is bound to the opposite type. When both
        // sides are (distinct) variables the left one is the one that is bound.
        (TLType::Var(var), bound) | (bound, TLType::Var(var)) => {
            if occurs_in(*var, bound, subst) {
                return Err(TypeInferError::OccursCheck(*var, bound.to_string()));
            }
            subst.bind(*var, bound.clone());
            Ok(())
        }
        // Relations only agree when their arities agree.
        (TLType::Relation(n), TLType::Relation(m)) if n != m => {
            Err(TypeInferError::UnificationFailed {
                expected: format!("Relation({})", n),
                got: format!("Relation({})", m),
            })
        }
        // Remaining pairs are variable-free: they unify exactly when equal.
        (known_lhs, known_rhs) if known_lhs == known_rhs => Ok(()),
        (known_lhs, known_rhs) => Err(TypeInferError::UnificationFailed {
            expected: known_lhs.to_string(),
            got: known_rhs.to_string(),
        }),
    }
}
/// Infers the type of `expr`.
///
/// Sub-expressions are inferred recursively, left to right, and the first
/// failure aborts inference. Whenever a construct requires an operand of a
/// particular type, the operand's inferred type is unified with that
/// requirement; new bindings go into `subst`, and `counter` supplies the
/// fresh type variable used by comparisons.
///
/// `env` is extended by `Let` with the bound variable's type; no arm in this
/// function reads it back.
///
/// # Errors
///
/// Returns whatever `unify` reports. The required type is always passed to
/// `unify` first, so a `UnificationFailed` names the required type as
/// `expected` and the operand's inferred type as `got`.
pub fn infer(
    expr: &TLExpr,
    env: &TypeEnv,
    subst: &mut Substitution,
    counter: &mut TyVarCounter,
) -> Result<TLType, TypeInferError> {
    match expr {
        // ---- Leaves -------------------------------------------------------
        TLExpr::Constant(_) => Ok(TLType::Numeric),
        TLExpr::EmptySet => Ok(TLType::Set),
        TLExpr::SymbolLiteral(_) => Ok(TLType::Unknown),
        TLExpr::Nominal { .. } | TLExpr::AllDifferent { .. } | TLExpr::Abducible { .. } => {
            Ok(TLType::Bool)
        }
        // A predicate without arguments is a proposition; otherwise only its
        // arity is recorded.
        TLExpr::Pred { args, .. } => {
            if args.is_empty() {
                Ok(TLType::Bool)
            } else {
                Ok(TLType::Relation(args.len()))
            }
        }

        // ---- Boolean operators: Bool operands, Bool result ----------------
        TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) | TLExpr::Imply(lhs, rhs) => {
            infer_expecting(lhs, &TLType::Bool, env, subst, counter)?;
            infer_expecting(rhs, &TLType::Bool, env, subst, counter)?;
            Ok(TLType::Bool)
        }
        TLExpr::Until { before, after }
        | TLExpr::WeakUntil { before, after }
        | TLExpr::Release {
            released: before,
            releaser: after,
        }
        | TLExpr::StrongRelease {
            released: before,
            releaser: after,
        } => {
            infer_expecting(before, &TLType::Bool, env, subst, counter)?;
            infer_expecting(after, &TLType::Bool, env, subst, counter)?;
            Ok(TLType::Bool)
        }
        TLExpr::Not(operand)
        | TLExpr::Next(operand)
        | TLExpr::Eventually(operand)
        | TLExpr::Always(operand)
        | TLExpr::Box(operand)
        | TLExpr::Diamond(operand)
        | TLExpr::Somewhere { formula: operand }
        | TLExpr::Everywhere { formula: operand }
        | TLExpr::At {
            formula: operand, ..
        } => {
            infer_expecting(operand, &TLType::Bool, env, subst, counter)?;
            Ok(TLType::Bool)
        }
        TLExpr::ForAll { body, .. }
        | TLExpr::Exists { body, .. }
        | TLExpr::CountingExists { body, .. }
        | TLExpr::CountingForAll { body, .. }
        | TLExpr::ExactCount { body, .. }
        | TLExpr::Majority { body, .. } => {
            infer_expecting(body, &TLType::Bool, env, subst, counter)?;
            Ok(TLType::Bool)
        }

        // ---- Arithmetic: Numeric operands, Numeric result -----------------
        TLExpr::Add(lhs, rhs)
        | TLExpr::Sub(lhs, rhs)
        | TLExpr::Mul(lhs, rhs)
        | TLExpr::Div(lhs, rhs)
        | TLExpr::Pow(lhs, rhs)
        | TLExpr::Mod(lhs, rhs)
        | TLExpr::Min(lhs, rhs)
        | TLExpr::Max(lhs, rhs) => {
            infer_expecting(lhs, &TLType::Numeric, env, subst, counter)?;
            infer_expecting(rhs, &TLType::Numeric, env, subst, counter)?;
            Ok(TLType::Numeric)
        }
        TLExpr::Abs(operand)
        | TLExpr::Floor(operand)
        | TLExpr::Ceil(operand)
        | TLExpr::Round(operand)
        | TLExpr::Sqrt(operand)
        | TLExpr::Exp(operand)
        | TLExpr::Log(operand)
        | TLExpr::Sin(operand)
        | TLExpr::Cos(operand)
        | TLExpr::Tan(operand) => {
            infer_expecting(operand, &TLType::Numeric, env, subst, counter)?;
            Ok(TLType::Numeric)
        }

        // ---- Comparisons: both sides share one type, result is Bool -------
        TLExpr::Eq(lhs, rhs)
        | TLExpr::Lt(lhs, rhs)
        | TLExpr::Gt(lhs, rhs)
        | TLExpr::Lte(lhs, rhs)
        | TLExpr::Gte(lhs, rhs) => {
            let lhs_ty = infer(lhs, env, subst, counter)?;
            let rhs_ty = infer(rhs, env, subst, counter)?;
            let shared = counter.fresh();
            // The shared variable goes first so that a mismatch reports the
            // left operand's type as `expected` and the right one's as `got`.
            unify(&shared, &lhs_ty, subst)?;
            unify(&shared, &rhs_ty, subst)?;
            Ok(TLType::Bool)
        }

        // ---- Conditionals and bindings ------------------------------------
        TLExpr::IfThenElse {
            condition,
            then_branch,
            else_branch,
        } => {
            infer_expecting(condition, &TLType::Bool, env, subst, counter)?;
            let then_ty = infer(then_branch, env, subst, counter)?;
            let else_ty = infer(else_branch, env, subst, counter)?;
            unify(&then_ty, &else_ty, subst)?;
            Ok(subst.apply(&then_ty))
        }
        TLExpr::Let { var, value, body } => {
            let value_ty = infer(value, env, subst, counter)?;
            let extended = env.extend(var.clone(), value_ty);
            infer(body, &extended, subst, counter)
        }
        TLExpr::Match { scrutinee, arms } => {
            infer(scrutinee, env, subst, counter)?;
            // The first arm fixes the result type; later arms must agree.
            let mut result_ty: Option<TLType> = None;
            for (_, arm_body) in arms {
                let arm_ty = infer(arm_body, env, subst, counter)?;
                match &result_ty {
                    Some(first_ty) => unify(first_ty, &arm_ty, subst)?,
                    None => result_ty = Some(arm_ty),
                }
            }
            // A match without arms is typed as Bool.
            Ok(result_ty.unwrap_or(TLType::Bool))
        }

        // ---- Fuzzy: operands are only checked for errors ------------------
        TLExpr::SoftForAll { body, .. } | TLExpr::SoftExists { body, .. } => {
            infer(body, env, subst, counter)?;
            Ok(TLType::Fuzzy)
        }
        TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
            infer(left, env, subst, counter)?;
            infer(right, env, subst, counter)?;
            Ok(TLType::Fuzzy)
        }
        TLExpr::FuzzyNot { expr: operand, .. } => {
            infer(operand, env, subst, counter)?;
            Ok(TLType::Fuzzy)
        }
        TLExpr::FuzzyImplication {
            premise,
            conclusion,
            ..
        } => {
            infer(premise, env, subst, counter)?;
            infer(conclusion, env, subst, counter)?;
            Ok(TLType::Fuzzy)
        }

        // ---- Probabilistic ------------------------------------------------
        TLExpr::WeightedRule { rule, .. } => {
            infer(rule, env, subst, counter)?;
            Ok(TLType::Probabilistic)
        }
        TLExpr::ProbabilisticChoice { alternatives } => {
            for (_, alternative) in alternatives {
                infer(alternative, env, subst, counter)?;
            }
            Ok(TLType::Probabilistic)
        }

        // ---- Wrappers that take the type of their inner expression --------
        TLExpr::Score(inner)
        | TLExpr::Explain { formula: inner }
        | TLExpr::LeastFixpoint { body: inner, .. }
        | TLExpr::GreatestFixpoint { body: inner, .. }
        | TLExpr::Lambda { body: inner, .. } => infer(inner, env, subst, counter),

        TLExpr::Aggregate { body, .. } => {
            infer(body, env, subst, counter)?;
            Ok(TLType::Numeric)
        }

        // ---- Sets ---------------------------------------------------------
        TLExpr::SetUnion { left, right }
        | TLExpr::SetIntersection { left, right }
        | TLExpr::SetDifference { left, right } => {
            infer_expecting(left, &TLType::Set, env, subst, counter)?;
            infer_expecting(right, &TLType::Set, env, subst, counter)?;
            Ok(TLType::Set)
        }
        TLExpr::SetCardinality { set } => {
            infer_expecting(set, &TLType::Set, env, subst, counter)?;
            Ok(TLType::Numeric)
        }
        TLExpr::SetComprehension { condition, .. } => {
            infer_expecting(condition, &TLType::Bool, env, subst, counter)?;
            Ok(TLType::Set)
        }
        TLExpr::SetMembership { element, set } => {
            infer(element, env, subst, counter)?;
            infer_expecting(set, &TLType::Set, env, subst, counter)?;
            Ok(TLType::Bool)
        }

        // ---- Application and constraints ----------------------------------
        TLExpr::Apply { function, argument } => {
            infer(function, env, subst, counter)?;
            infer(argument, env, subst, counter)?;
            Ok(TLType::Unknown)
        }
        TLExpr::GlobalCardinality { values, .. } => {
            for value in values {
                infer(value, env, subst, counter)?;
            }
            Ok(TLType::Bool)
        }
    }
}

/// Infers the type of `expr` and unifies it with the required type `expected`.
///
/// `expected` is handed to `unify` first, so a mismatch is reported as
/// "expected <required>, got <inferred>".
fn infer_expecting(
    expr: &TLExpr,
    expected: &TLType,
    env: &TypeEnv,
    subst: &mut Substitution,
    counter: &mut TyVarCounter,
) -> Result<(), TypeInferError> {
    let actual = infer(expr, env, subst, counter)?;
    unify(expected, &actual, subst)
}
/// Infers the type of `expr` under `env` and returns the annotated tree, the
/// resulting substitution and every error that was collected.
///
/// This function never fails: when root inference errors, the error is
/// recorded and the root is typed `TLType::Unknown`.
///
/// NOTE(review): a root failure is pushed here and then again by
/// `annotate_with`, which re-runs `infer` on the same expression, so the
/// same error can appear more than once in `errors` — confirm whether that
/// is intended.
pub fn infer_type(expr: &TLExpr, env: &TypeEnv) -> TypeInferResult {
    let mut errors: Vec<TypeInferError> = Vec::new();
    let mut subst = Substitution::new();
    let mut root_counter = TyVarCounter::new();

    let root_ty = infer(expr, env, &mut subst, &mut root_counter).unwrap_or_else(|err| {
        errors.push(err);
        TLType::Unknown
    });

    // Count the bindings that already point at a non-variable type.
    let inferred_vars = subst
        .map
        .values()
        .filter(|bound| bound.is_ground())
        .count();

    // The annotation pass starts from the root substitution but numbers its
    // type variables with a counter of its own; only its children are kept.
    let mut annotation_counter = TyVarCounter::new();
    let TypedExpr { children, .. } =
        annotate_with(expr, env, &subst, &mut annotation_counter, &mut errors);

    let typed_expr = TypedExpr {
        expr: expr.clone(),
        ty: subst.apply(&root_ty),
        children,
    };

    TypeInferResult {
        typed_expr,
        subst,
        inferred_vars,
        errors,
    }
}
/// Builds the typed tree for `expr`, failing when the root expression does
/// not type-check.
///
/// The tree is annotated first, with any errors from that pass discarded;
/// the root type is then inferred separately and its error, if any, is the
/// one returned.
///
/// # Errors
///
/// Returns the error produced by `infer` for the root expression.
pub fn annotate(expr: &TLExpr, env: &TypeEnv) -> Result<TypedExpr, TypeInferError> {
    let mut counter = TyVarCounter::new();
    let mut subst = Substitution::new();

    // Errors from the annotation pass are dropped; the root inference below
    // decides the outcome.
    let mut discarded = Vec::new();
    let TypedExpr {
        expr: annotated_expr,
        children,
        ..
    } = annotate_with(expr, env, &subst, &mut counter, &mut discarded);

    let root_ty = infer(expr, env, &mut subst, &mut counter)?;
    Ok(TypedExpr {
        expr: annotated_expr,
        ty: subst.apply(&root_ty),
        children,
    })
}
/// Annotates `expr` and, recursively, every sub-expression listed by
/// `collect_children`.
///
/// Inference for this node runs on a private copy of `subst`; the children
/// are then annotated against that copy. A failing node is typed
/// `TLType::Unknown` and its error is appended to `errors`. `counter` is
/// shared across the whole traversal.
fn annotate_with(
    expr: &TLExpr,
    env: &TypeEnv,
    subst: &Substitution,
    counter: &mut TyVarCounter,
    errors: &mut Vec<TypeInferError>,
) -> TypedExpr {
    let mut local_subst = subst.clone();
    let ty = match infer(expr, env, &mut local_subst, counter) {
        Err(err) => {
            errors.push(err);
            TLType::Unknown
        }
        Ok(inferred) => local_subst.apply(&inferred),
    };

    // Children are visited in order, after this node has been typed.
    let child_exprs = collect_children(expr);
    let mut children = Vec::with_capacity(child_exprs.len());
    for child in child_exprs {
        children.push(annotate_with(child, env, &local_subst, counter, errors));
    }

    TypedExpr {
        expr: expr.clone(),
        ty,
        children,
    }
}
/// Returns references to the direct sub-expressions of `expr`, in the order
/// they are listed in each arm below. Leaf expressions yield an empty vector.
fn collect_children(expr: &TLExpr) -> Vec<&TLExpr> {
match expr {
// Binary tuple variants: left operand, then right operand.
TLExpr::And(l, r)
| TLExpr::Or(l, r)
| TLExpr::Imply(l, r)
| TLExpr::Add(l, r)
| TLExpr::Sub(l, r)
| TLExpr::Mul(l, r)
| TLExpr::Div(l, r)
| TLExpr::Pow(l, r)
| TLExpr::Mod(l, r)
| TLExpr::Min(l, r)
| TLExpr::Max(l, r)
| TLExpr::Eq(l, r)
| TLExpr::Lt(l, r)
| TLExpr::Gt(l, r)
| TLExpr::Lte(l, r)
| TLExpr::Gte(l, r) => vec![l.as_ref(), r.as_ref()],
// Variants wrapping exactly one sub-expression.
TLExpr::Not(e)
| TLExpr::Score(e)
| TLExpr::Abs(e)
| TLExpr::Floor(e)
| TLExpr::Ceil(e)
| TLExpr::Round(e)
| TLExpr::Sqrt(e)
| TLExpr::Exp(e)
| TLExpr::Log(e)
| TLExpr::Sin(e)
| TLExpr::Cos(e)
| TLExpr::Tan(e)
| TLExpr::Next(e)
| TLExpr::Eventually(e)
| TLExpr::Always(e)
| TLExpr::Box(e)
| TLExpr::Diamond(e)
| TLExpr::WeightedRule { rule: e, .. }
| TLExpr::FuzzyNot { expr: e, .. }
| TLExpr::LeastFixpoint { body: e, .. }
| TLExpr::GreatestFixpoint { body: e, .. }
| TLExpr::Lambda { body: e, .. }
| TLExpr::SetCardinality { set: e }
| TLExpr::Somewhere { formula: e }
| TLExpr::Everywhere { formula: e }
| TLExpr::Explain { formula: e }
| TLExpr::At { formula: e, .. } => vec![e.as_ref()],
// Binder-style variants: only the body (or condition) is a child.
TLExpr::ForAll { body, .. }
| TLExpr::Exists { body, .. }
| TLExpr::SoftForAll { body, .. }
| TLExpr::SoftExists { body, .. }
| TLExpr::Aggregate { body, .. }
| TLExpr::CountingExists { body, .. }
| TLExpr::CountingForAll { body, .. }
| TLExpr::ExactCount { body, .. }
| TLExpr::Majority { body, .. }
| TLExpr::SetComprehension {
condition: body, ..
} => vec![body.as_ref()],
// Condition first, then the two branches.
TLExpr::IfThenElse {
condition,
then_branch,
else_branch,
} => vec![
condition.as_ref(),
then_branch.as_ref(),
else_branch.as_ref(),
],
// Bound value first, then the body.
TLExpr::Let { value, body, .. } => vec![value.as_ref(), body.as_ref()],
TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
vec![left.as_ref(), right.as_ref()]
}
TLExpr::FuzzyImplication {
premise,
conclusion,
..
} => vec![premise.as_ref(), conclusion.as_ref()],
TLExpr::Until { before, after }
| TLExpr::WeakUntil { before, after }
| TLExpr::Release {
released: before,
releaser: after,
}
| TLExpr::StrongRelease {
released: before,
releaser: after,
} => vec![before.as_ref(), after.as_ref()],
TLExpr::SetUnion { left, right }
| TLExpr::SetIntersection { left, right }
| TLExpr::SetDifference { left, right } => vec![left.as_ref(), right.as_ref()],
TLExpr::SetMembership { element, set } => vec![element.as_ref(), set.as_ref()],
TLExpr::Apply { function, argument } => vec![function.as_ref(), argument.as_ref()],
// Each alternative contributes its expression; the first tuple element is skipped.
TLExpr::ProbabilisticChoice { alternatives } => {
alternatives.iter().map(|(_, e)| e).collect()
}
// Only `values` are treated as children.
TLExpr::GlobalCardinality { values, .. } => values.iter().collect(),
// Leaves: no sub-expressions.
TLExpr::Constant(_)
| TLExpr::Pred { .. }
| TLExpr::EmptySet
| TLExpr::AllDifferent { .. }
| TLExpr::Abducible { .. }
| TLExpr::Nominal { .. }
| TLExpr::SymbolLiteral(_) => vec![],
// Scrutinee first, then each arm's body in order.
TLExpr::Match { scrutinee, arms } => {
let mut children = vec![scrutinee.as_ref()];
children.extend(arms.iter().map(|(_, b)| b.as_ref()));
children
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TCoNormKind, TLExpr, TNormKind};

    /// Builds a zero-argument predicate, i.e. a proposition.
    fn prop(name: &str) -> TLExpr {
        TLExpr::pred(name, vec![])
    }

    /// Runs `infer` on `expr` with an empty environment, an empty
    /// substitution and a fresh counter.
    fn infer_closed(expr: &TLExpr) -> Result<TLType, TypeInferError> {
        let env = TypeEnv::new();
        let mut subst = Substitution::new();
        let mut counter = TyVarCounter::new();
        infer(expr, &env, &mut subst, &mut counter)
    }

    /// Like `infer_closed`, but panics when inference fails.
    fn infer_ok(expr: &TLExpr) -> TLType {
        infer_closed(expr).unwrap()
    }

    #[test]
    fn test_constant_is_numeric() {
        assert_eq!(infer_ok(&TLExpr::Constant(42.0)), TLType::Numeric);
    }

    #[test]
    fn test_zero_arity_pred_is_bool() {
        assert_eq!(infer_ok(&prop("p")), TLType::Bool);
    }

    #[test]
    fn test_binary_pred_is_relation2() {
        use tensorlogic_ir::Term;
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        assert_eq!(infer_ok(&expr), TLType::Relation(2));
    }

    #[test]
    fn test_and_bool_bool_is_bool() {
        let expr = TLExpr::and(prop("p"), prop("q"));
        assert_eq!(infer_ok(&expr), TLType::Bool);
    }

    #[test]
    fn test_add_numeric_numeric_is_numeric() {
        let expr = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        assert_eq!(infer_ok(&expr), TLType::Numeric);
    }

    #[test]
    fn test_not_bool_is_bool() {
        let expr = TLExpr::negate(prop("p"));
        assert_eq!(infer_ok(&expr), TLType::Bool);
    }

    #[test]
    fn test_forall_bool_body_is_bool() {
        let expr = TLExpr::forall("x", "Entity", prop("P"));
        assert_eq!(infer_ok(&expr), TLType::Bool);
    }

    #[test]
    fn test_soft_exists_is_fuzzy() {
        let expr = TLExpr::SoftExists {
            var: "x".into(),
            domain: "D".into(),
            body: Box::new(prop("P")),
            temperature: 1.0,
        };
        assert_eq!(infer_ok(&expr), TLType::Fuzzy);
    }

    #[test]
    fn test_tnorm_is_fuzzy() {
        let expr = TLExpr::TNorm {
            kind: TNormKind::Product,
            left: Box::new(TLExpr::Constant(0.7)),
            right: Box::new(TLExpr::Constant(0.3)),
        };
        assert_eq!(infer_ok(&expr), TLType::Fuzzy);
    }

    #[test]
    fn test_probabilistic_choice_is_probabilistic() {
        let expr = TLExpr::ProbabilisticChoice {
            alternatives: vec![(0.6, prop("A")), (0.4, prop("B"))],
        };
        assert_eq!(infer_ok(&expr), TLType::Probabilistic);
    }

    #[test]
    fn test_eq_numeric_is_bool() {
        let expr = TLExpr::Eq(
            Box::new(TLExpr::Constant(1.0)),
            Box::new(TLExpr::Constant(1.0)),
        );
        assert_eq!(infer_ok(&expr), TLType::Bool);
    }

    #[test]
    fn test_ifthenelse_condition_is_bool() {
        let expr = TLExpr::IfThenElse {
            condition: Box::new(prop("cond")),
            then_branch: Box::new(TLExpr::Constant(1.0)),
            else_branch: Box::new(TLExpr::Constant(0.0)),
        };
        assert_eq!(infer_ok(&expr), TLType::Numeric);
    }

    #[test]
    fn test_ifthenelse_numeric_condition_fails() {
        let expr = TLExpr::IfThenElse {
            condition: Box::new(TLExpr::Constant(1.0)),
            then_branch: Box::new(TLExpr::Constant(2.0)),
            else_branch: Box::new(TLExpr::Constant(3.0)),
        };
        let result = infer_closed(&expr);
        assert!(result.is_err(), "expected type error for Numeric condition");
    }

    #[test]
    fn test_let_binding_extends_env() {
        let expr = TLExpr::Let {
            var: "x".into(),
            value: Box::new(TLExpr::Constant(1.0)),
            body: Box::new(TLExpr::and(prop("p"), prop("q"))),
        };
        assert_eq!(infer_ok(&expr), TLType::Bool);
    }

    #[test]
    fn test_unify_bool_bool_ok() {
        let mut subst = Substitution::new();
        assert!(unify(&TLType::Bool, &TLType::Bool, &mut subst).is_ok());
    }

    #[test]
    fn test_unify_bool_numeric_fails() {
        let mut subst = Substitution::new();
        let result = unify(&TLType::Bool, &TLType::Numeric, &mut subst);
        assert!(result.is_err());
        match result.unwrap_err() {
            TypeInferError::UnificationFailed { .. } => {}
            e => panic!("expected UnificationFailed, got {:?}", e),
        }
    }

    #[test]
    fn test_unify_var_resolves() {
        let mut subst = Substitution::new();
        unify(&TLType::Var(0), &TLType::Bool, &mut subst).unwrap();
        assert_eq!(subst.apply(&TLType::Var(0)), TLType::Bool);
    }

    #[test]
    fn test_occurs_check_no_infinite_loop() {
        let mut subst = Substitution::new();
        let result = unify(&TLType::Var(0), &TLType::Relation(0), &mut subst);
        assert!(result.is_ok(), "Var(0) vs Relation(0) should unify fine");
        assert_eq!(subst.apply(&TLType::Var(0)), TLType::Relation(0));
    }

    #[test]
    fn test_infer_type_result() {
        let expr = TLExpr::and(prop("p"), prop("q"));
        let result = infer_type(&expr, &TypeEnv::new());
        assert_eq!(result.typed_expr.ty, TLType::Bool);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_annotate_root_type() {
        let expr = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        let typed = annotate(&expr, &TypeEnv::new()).unwrap();
        assert_eq!(typed.ty, TLType::Numeric);
    }

    #[test]
    fn test_is_ground_false_for_var() {
        assert!(!TLType::Var(99).is_ground());
        assert!(TLType::Bool.is_ground());
        assert!(TLType::Numeric.is_ground());
        assert!(TLType::Relation(3).is_ground());
    }

    #[test]
    fn test_type_env_extend_non_destructive() {
        let env = TypeEnv::new().with("x", TLType::Bool);
        let extended = env.extend("y", TLType::Numeric);
        assert!(env.lookup("x").is_some());
        assert!(env.lookup("y").is_none());
        assert!(extended.lookup("x").is_some());
        assert!(extended.lookup("y").is_some());
    }

    #[test]
    fn test_substitution_apply_chases_chain() {
        let mut subst = Substitution::new();
        subst.bind(0, TLType::Var(1));
        subst.bind(1, TLType::Var(2));
        subst.bind(2, TLType::Numeric);
        assert_eq!(subst.apply(&TLType::Var(0)), TLType::Numeric);
    }

    #[test]
    fn test_nested_and_or_not_is_bool() {
        let expr = TLExpr::and(TLExpr::or(prop("p"), prop("q")), TLExpr::negate(prop("r")));
        let result = infer_type(&expr, &TypeEnv::new());
        assert_eq!(result.typed_expr.ty, TLType::Bool);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_tconorm_is_fuzzy() {
        let expr = TLExpr::TCoNorm {
            kind: TCoNormKind::Maximum,
            left: Box::new(TLExpr::Constant(0.2)),
            right: Box::new(TLExpr::Constant(0.8)),
        };
        assert_eq!(infer_ok(&expr), TLType::Fuzzy);
    }

    #[test]
    fn test_weighted_rule_is_probabilistic() {
        let expr = TLExpr::WeightedRule {
            weight: 0.9,
            rule: Box::new(TLExpr::imply(prop("A"), prop("B"))),
        };
        assert_eq!(infer_ok(&expr), TLType::Probabilistic);
    }
}