use std::collections::HashMap;
use tensorlogic_ir::TLExpr;
/// One instruction of the stack-based bytecode VM.
///
/// The per-variant notes describe how `execute_with_stats` interprets each
/// instruction. For two-operand instructions the right operand `b` is popped
/// first and the left operand `a` second.
#[derive(Debug, Clone, PartialEq)]
pub enum Instruction {
/// Push a numeric literal.
PushNum(f64),
/// Push a boolean literal.
PushBool(bool),
/// Push a symbol (string) literal.
PushSym(String),
/// Discard the top of the stack.
Pop,
/// Push a copy of the top of the stack.
Dup,
/// Pop two numbers and push `a + b`.
Add,
/// Pop two numbers and push `a - b`.
Sub,
/// Pop two numbers and push `a * b`.
Mul,
/// Pop two numbers and push `a / b`; a zero divisor is reported as
/// `VmError::DivisionByZero`.
Div,
/// Pop two numbers and push `a.powf(b)`.
Pow,
/// Pop two numbers and push the remainder `a % b`.
Mod,
/// Pop a number and push its negation.
Neg,
/// Pop a number and push its absolute value.
Abs,
/// Pop a number and push its square root.
Sqrt,
/// Pop a number and push `e` raised to it.
Exp,
/// Pop a number and push its natural logarithm.
Log,
/// Pop two numbers and push the smaller one.
Min,
/// Pop two numbers and push the larger one.
Max,
/// Pop two values of any type and push whether they are equal; values of
/// different types are never equal.
Eq,
/// Pop two values of any type and push whether they differ.
Ne,
/// Pop two numbers and push `a < b` as a boolean.
Lt,
/// Pop two numbers and push `a <= b` as a boolean.
Le,
/// Pop two numbers and push `a > b` as a boolean.
Gt,
/// Pop two numbers and push `a >= b` as a boolean.
Ge,
/// Pop two values and push whether both are truthy.
And,
/// Pop two values and push whether at least one is truthy.
Or,
/// Pop a value and push the negation of its truthiness.
Not,
/// Pop a value; if it is falsy, push `Bool(false)` and jump to the target
/// index, otherwise fall through to the next instruction.
JumpIfFalse(usize),
/// Pop a value; if it is truthy, push `Bool(true)` and jump to the target
/// index, otherwise fall through to the next instruction.
JumpIfTrue(usize),
/// Jump unconditionally to the target index.
Jump(usize),
/// Push a copy of the named variable; an unknown name is reported as
/// `VmError::UnboundVariable`.
LoadVar(String),
/// Pop a value and bind it to the named variable.
StoreVar(String),
/// Pop two numbers and push `a * b`.
TNorm,
/// Pop two numbers and push `a + b - a * b`.
TCoNorm,
/// Pop a number and push `1.0 - a`.
FuzzyNot,
/// Pop the top of the stack and stop, returning it as the program result.
Halt,
}
/// A linear sequence of VM instructions.
///
/// Jump operands are absolute indices into [`BytecodeProgram::instructions`].
#[derive(Debug, Clone)]
pub struct BytecodeProgram {
    /// The instructions, in execution order.
    pub instructions: Vec<Instruction>,
}

impl Default for BytecodeProgram {
    /// An empty program, identical to [`BytecodeProgram::new`].
    fn default() -> Self {
        BytecodeProgram {
            instructions: Vec::new(),
        }
    }
}

impl BytecodeProgram {
    /// Creates a program with no instructions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `instr` and returns the index it was stored at.
    pub fn push(&mut self, instr: Instruction) -> usize {
        self.instructions.push(instr);
        self.instructions.len() - 1
    }

    /// Rewrites the target of the jump instruction stored at `idx`.
    ///
    /// If the instruction at `idx` is not a jump, debug builds fail an
    /// assertion and release builds leave the program untouched.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds.
    pub fn patch_jump(&mut self, idx: usize, target: usize) {
        match &mut self.instructions[idx] {
            Instruction::Jump(dest)
            | Instruction::JumpIfTrue(dest)
            | Instruction::JumpIfFalse(dest) => *dest = target,
            non_jump => debug_assert!(
                false,
                "patch_jump called on non-jump instruction: {:?}",
                non_jump
            ),
        }
    }

    /// Number of instructions in the program.
    pub fn len(&self) -> usize {
        self.instructions.len()
    }

    /// Returns `true` when the program holds no instructions.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// A runtime value held on the VM stack or bound in a [`VmEnv`].
#[derive(Debug, Clone, PartialEq)]
pub enum VmValue {
    /// A floating-point number.
    Num(f64),
    /// A boolean.
    Bool(bool),
    /// A symbol, stored as its string name.
    Sym(String),
}

impl VmValue {
    /// Returns the number inside a `Num`.
    ///
    /// # Errors
    ///
    /// Returns `VmError::TypeMismatch` (expected `"Num"`) for any other variant.
    pub fn as_num(&self) -> Result<f64, VmError> {
        if let VmValue::Num(n) = self {
            return Ok(*n);
        }
        Err(VmError::TypeMismatch {
            expected: "Num",
            got: self.type_name(),
        })
    }

    /// Returns the flag inside a `Bool`.
    ///
    /// # Errors
    ///
    /// Returns `VmError::TypeMismatch` (expected `"Bool"`) for any other variant.
    pub fn as_bool(&self) -> Result<bool, VmError> {
        if let VmValue::Bool(flag) = self {
            return Ok(*flag);
        }
        Err(VmError::TypeMismatch {
            expected: "Bool",
            got: self.type_name(),
        })
    }

    /// Truthiness used by the logical and conditional-jump instructions:
    /// a boolean is itself, a number is truthy unless it equals `0.0`, and a
    /// symbol is truthy unless it is the empty string.
    pub fn is_truthy(&self) -> bool {
        match self {
            VmValue::Bool(flag) => *flag,
            VmValue::Sym(text) => !text.is_empty(),
            VmValue::Num(n) => *n != 0.0,
        }
    }

    /// Static name of the variant, as used in type-mismatch errors.
    fn type_name(&self) -> &'static str {
        match self {
            VmValue::Num(_) => "Num",
            VmValue::Bool(_) => "Bool",
            VmValue::Sym(_) => "Sym",
        }
    }
}
/// Errors raised while executing a bytecode program.
#[derive(Debug)]
pub enum VmError {
    /// An instruction needed more operands than the stack held.
    StackUnderflow,
    /// An operand had the wrong runtime type.
    TypeMismatch {
        /// Name of the type the instruction required.
        expected: &'static str,
        /// Name of the type that was actually found.
        got: &'static str,
    },
    /// A variable was loaded that has no binding in the environment.
    UnboundVariable(String),
    /// A division had a zero divisor.
    DivisionByZero,
    /// The instruction pointer left the program; carries the offending index.
    InvalidInstruction(usize),
    /// The program contained no instructions at all.
    ProgramEmpty,
}

impl std::fmt::Display for VmError {
    /// Writes the human-readable message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages are written directly; only the variants that
            // carry data go through the formatting machinery.
            Self::StackUnderflow => f.write_str("VM stack underflow"),
            Self::DivisionByZero => f.write_str("division by zero"),
            Self::ProgramEmpty => f.write_str("program contains no instructions"),
            Self::TypeMismatch { expected, got } => {
                write!(f, "type mismatch: expected {}, got {}", expected, got)
            }
            Self::UnboundVariable(name) => write!(f, "unbound variable: '{}'", name),
            Self::InvalidInstruction(ip) => write!(f, "invalid instruction pointer: {}", ip),
        }
    }
}

impl std::error::Error for VmError {}
/// Variable bindings visible to a running program, keyed by name.
#[derive(Debug, Clone, Default)]
pub struct VmEnv {
    /// Current value of each bound variable.
    vars: HashMap<String, VmValue>,
}

impl VmEnv {
    /// Creates an environment with no bindings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Binds `name` to `val`, replacing any existing binding.
    pub fn set(&mut self, name: impl Into<String>, val: VmValue) {
        let key: String = name.into();
        self.vars.insert(key, val);
    }

    /// Binds `name` to the number `val`.
    pub fn set_num(&mut self, name: impl Into<String>, val: f64) {
        let wrapped = VmValue::Num(val);
        self.set(name, wrapped);
    }

    /// Binds `name` to the boolean `val`.
    pub fn set_bool(&mut self, name: impl Into<String>, val: bool) {
        let wrapped = VmValue::Bool(val);
        self.set(name, wrapped);
    }

    /// Looks up the value bound to `name`, if any.
    pub fn get(&self, name: &str) -> Option<&VmValue> {
        self.vars.get(name)
    }

    /// Number of bound variables.
    pub fn len(&self) -> usize {
        self.vars.len()
    }

    /// Returns `true` when nothing is bound.
    pub fn is_empty(&self) -> bool {
        self.vars.is_empty()
    }
}
/// Counters collected by `execute_with_stats` during a single run.
#[derive(Debug, Default, Clone)]
pub struct VmStats {
/// Number of instructions dispatched, including the final `Halt`.
pub instructions_executed: usize,
/// Largest stack size observed after any instruction; at `Halt` the result
/// value still counts towards the depth.
pub max_stack_depth: usize,
/// Number of jumps that actually redirected execution (a conditional jump
/// that falls through is not counted).
pub jumps_taken: usize,
}
/// Errors raised while lowering an expression to bytecode.
#[derive(Debug)]
pub enum CompileError {
    /// The expression has no bytecode lowering; carries a description of it.
    UnsupportedExpr(String),
    /// The expression nests deeper than the configured maximum depth.
    MaxDepthExceeded,
}

impl std::fmt::Display for CompileError {
    /// Writes the human-readable message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MaxDepthExceeded => {
                f.write_str("expression depth exceeds configured maximum")
            }
            Self::UnsupportedExpr(desc) => {
                write!(f, "unsupported expression in bytecode compiler: {}", desc)
            }
        }
    }
}

impl std::error::Error for CompileError {}
struct Compiler {
program: BytecodeProgram,
max_depth: usize,
}
impl Compiler {
fn new(max_depth: usize) -> Self {
Self {
program: BytecodeProgram::new(),
max_depth,
}
}
fn compile_expr(&mut self, expr: &TLExpr, depth: usize) -> Result<(), CompileError> {
if depth > self.max_depth {
return Err(CompileError::MaxDepthExceeded);
}
match expr {
TLExpr::Constant(c) => {
self.program.push(Instruction::PushNum(*c));
}
TLExpr::Pred { name, args } if args.is_empty() => {
self.program.push(Instruction::LoadVar(name.clone()));
}
TLExpr::Add(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Add);
}
TLExpr::Sub(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Sub);
}
TLExpr::Mul(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Mul);
}
TLExpr::Div(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Div);
}
TLExpr::Pow(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Pow);
}
TLExpr::Mod(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Mod);
}
TLExpr::Abs(a) => {
self.compile_expr(a, depth + 1)?;
self.program.push(Instruction::Abs);
}
TLExpr::Sqrt(a) => {
self.compile_expr(a, depth + 1)?;
self.program.push(Instruction::Sqrt);
}
TLExpr::Exp(a) => {
self.compile_expr(a, depth + 1)?;
self.program.push(Instruction::Exp);
}
TLExpr::Log(a) => {
self.compile_expr(a, depth + 1)?;
self.program.push(Instruction::Log);
}
TLExpr::Min(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Min);
}
TLExpr::Max(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Max);
}
TLExpr::Eq(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Eq);
}
TLExpr::Lt(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Lt);
}
TLExpr::Gt(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Gt);
}
TLExpr::Lte(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Le);
}
TLExpr::Gte(a, b) => {
self.compile_expr(a, depth + 1)?;
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Ge);
}
TLExpr::And(a, b) => {
self.compile_expr(a, depth + 1)?;
let jump_idx = self.program.push(Instruction::JumpIfFalse(0));
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Not);
self.program.push(Instruction::Not);
let end = self.program.len();
self.program.patch_jump(jump_idx, end);
}
TLExpr::Or(a, b) => {
self.compile_expr(a, depth + 1)?;
let jump_idx = self.program.push(Instruction::JumpIfTrue(0));
self.compile_expr(b, depth + 1)?;
self.program.push(Instruction::Not);
self.program.push(Instruction::Not);
let end = self.program.len();
self.program.patch_jump(jump_idx, end);
}
TLExpr::Not(a) => {
self.compile_expr(a, depth + 1)?;
self.program.push(Instruction::Not);
}
TLExpr::IfThenElse {
condition,
then_branch,
else_branch,
} => {
self.compile_expr(condition, depth + 1)?;
let jf_idx = self.program.push(Instruction::JumpIfFalse(0));
self.compile_expr(then_branch, depth + 1)?;
let jump_idx = self.program.push(Instruction::Jump(0));
let else_start = self.program.len();
self.program.patch_jump(jf_idx, else_start);
self.compile_expr(else_branch, depth + 1)?;
let end = self.program.len();
self.program.patch_jump(jump_idx, end);
}
TLExpr::Let { var, value, body } => {
self.compile_expr(value, depth + 1)?;
self.program.push(Instruction::StoreVar(var.clone()));
self.compile_expr(body, depth + 1)?;
}
TLExpr::TNorm { left, right, .. } => {
self.compile_expr(left, depth + 1)?;
self.compile_expr(right, depth + 1)?;
self.program.push(Instruction::TNorm);
}
TLExpr::TCoNorm { left, right, .. } => {
self.compile_expr(left, depth + 1)?;
self.compile_expr(right, depth + 1)?;
self.program.push(Instruction::TCoNorm);
}
TLExpr::FuzzyNot { expr: inner, .. } => {
self.compile_expr(inner, depth + 1)?;
self.program.push(Instruction::FuzzyNot);
}
TLExpr::SymbolLiteral(s) => {
self.program.push(Instruction::PushSym(s.clone()));
}
TLExpr::Match { scrutinee, arms } => {
if arms.is_empty() {
return Err(CompileError::UnsupportedExpr(
"Match with no arms".to_string(),
));
}
self.compile_expr(scrutinee, depth + 1)?;
let tmp = format!("__match_scrutinee_{depth}");
self.program.push(Instruction::StoreVar(tmp.clone()));
let (wildcard_body, non_wildcard) = arms
.split_last()
.ok_or_else(|| CompileError::UnsupportedExpr("Empty Match arms".into()))?;
self.emit_match_chain(&tmp, non_wildcard, &wildcard_body.1, depth)?;
}
other => {
return Err(CompileError::UnsupportedExpr(format!("{:?}", other)));
}
}
Ok(())
}
fn emit_match_chain(
&mut self,
scrutinee_var: &str,
arms: &[(tensorlogic_ir::MatchPattern, Box<TLExpr>)],
else_body: &TLExpr,
depth: usize,
) -> Result<(), CompileError> {
if arms.is_empty() {
return self.compile_expr(else_body, depth + 1);
}
let (pat, body) = &arms[0];
let remaining = &arms[1..];
self.program
.push(Instruction::LoadVar(scrutinee_var.to_string()));
match pat {
tensorlogic_ir::MatchPattern::ConstNumber(n) => {
self.program.push(Instruction::PushNum(*n));
}
tensorlogic_ir::MatchPattern::ConstSymbol(s) => {
self.program.push(Instruction::PushSym(s.clone()));
}
tensorlogic_ir::MatchPattern::Wildcard => {
return Err(CompileError::UnsupportedExpr(
"Wildcard in non-tail position".into(),
));
}
}
self.program.push(Instruction::Eq);
let jf_idx = self.program.push(Instruction::JumpIfFalse(0));
self.compile_expr(body, depth + 1)?;
let jump_idx = self.program.push(Instruction::Jump(0));
let else_start = self.program.len();
self.program.patch_jump(jf_idx, else_start);
self.emit_match_chain(scrutinee_var, remaining, else_body, depth)?;
let end = self.program.len();
self.program.patch_jump(jump_idx, end);
Ok(())
}
}
/// Maximum expression nesting depth that `compile` accepts.
pub const DEFAULT_MAX_DEPTH: usize = 512;
/// Compiles `expr` to bytecode using `DEFAULT_MAX_DEPTH` as the nesting limit.
///
/// # Errors
///
/// Returns whatever `compile_with_config` returns for the same expression.
pub fn compile(expr: &TLExpr) -> Result<BytecodeProgram, CompileError> {
compile_with_config(expr, DEFAULT_MAX_DEPTH)
}
pub fn compile_with_config(
expr: &TLExpr,
max_depth: usize,
) -> Result<BytecodeProgram, CompileError> {
let mut compiler = Compiler::new(max_depth);
compiler.compile_expr(expr, 0)?;
compiler.program.push(Instruction::Halt);
Ok(compiler.program)
}
/// Runs `program` against `env` and returns only the resulting value.
///
/// # Errors
///
/// Propagates any `VmError` from `execute_with_stats`.
pub fn execute(program: &BytecodeProgram, env: &VmEnv) -> Result<VmValue, VmError> {
    execute_with_stats(program, env).map(|(value, _stats)| value)
}
pub fn execute_with_stats(
program: &BytecodeProgram,
env: &VmEnv,
) -> Result<(VmValue, VmStats), VmError> {
if program.is_empty() {
return Err(VmError::ProgramEmpty);
}
let mut stack: Vec<VmValue> = Vec::with_capacity(16);
let mut local_env = env.clone();
let mut ip: usize = 0;
let mut stats = VmStats::default();
loop {
if ip >= program.instructions.len() {
return Err(VmError::InvalidInstruction(ip));
}
let instr = &program.instructions[ip];
stats.instructions_executed += 1;
match instr {
Instruction::PushNum(n) => {
stack.push(VmValue::Num(*n));
ip += 1;
}
Instruction::PushBool(b) => {
stack.push(VmValue::Bool(*b));
ip += 1;
}
Instruction::PushSym(s) => {
stack.push(VmValue::Sym(s.clone()));
ip += 1;
}
Instruction::Pop => {
stack.pop().ok_or(VmError::StackUnderflow)?;
ip += 1;
}
Instruction::Dup => {
let top = stack.last().ok_or(VmError::StackUnderflow)?.clone();
stack.push(top);
ip += 1;
}
Instruction::Add => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a + b));
ip += 1;
}
Instruction::Sub => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a - b));
ip += 1;
}
Instruction::Mul => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a * b));
ip += 1;
}
Instruction::Div => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
if b == 0.0 {
return Err(VmError::DivisionByZero);
}
stack.push(VmValue::Num(a / b));
ip += 1;
}
Instruction::Pow => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a.powf(b)));
ip += 1;
}
Instruction::Mod => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a % b));
ip += 1;
}
Instruction::Neg => {
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(-a));
ip += 1;
}
Instruction::Abs => {
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a.abs()));
ip += 1;
}
Instruction::Sqrt => {
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a.sqrt()));
ip += 1;
}
Instruction::Exp => {
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a.exp()));
ip += 1;
}
Instruction::Log => {
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a.ln()));
ip += 1;
}
Instruction::Min => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a.min(b)));
ip += 1;
}
Instruction::Max => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a.max(b)));
ip += 1;
}
Instruction::Eq => {
let b = pop_value(&mut stack)?;
let a = pop_value(&mut stack)?;
stack.push(VmValue::Bool(values_equal(&a, &b)));
ip += 1;
}
Instruction::Ne => {
let b = pop_value(&mut stack)?;
let a = pop_value(&mut stack)?;
stack.push(VmValue::Bool(!values_equal(&a, &b)));
ip += 1;
}
Instruction::Lt => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Bool(a < b));
ip += 1;
}
Instruction::Le => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Bool(a <= b));
ip += 1;
}
Instruction::Gt => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Bool(a > b));
ip += 1;
}
Instruction::Ge => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Bool(a >= b));
ip += 1;
}
Instruction::And => {
let b = pop_value(&mut stack)?;
let a = pop_value(&mut stack)?;
stack.push(VmValue::Bool(a.is_truthy() && b.is_truthy()));
ip += 1;
}
Instruction::Or => {
let b = pop_value(&mut stack)?;
let a = pop_value(&mut stack)?;
stack.push(VmValue::Bool(a.is_truthy() || b.is_truthy()));
ip += 1;
}
Instruction::Not => {
let a = pop_value(&mut stack)?;
stack.push(VmValue::Bool(!a.is_truthy()));
ip += 1;
}
Instruction::JumpIfFalse(target) => {
let target = *target;
let cond = pop_value(&mut stack)?;
if !cond.is_truthy() {
stack.push(VmValue::Bool(false));
ip = target;
stats.jumps_taken += 1;
} else {
ip += 1;
}
}
Instruction::JumpIfTrue(target) => {
let target = *target;
let cond = pop_value(&mut stack)?;
if cond.is_truthy() {
stack.push(VmValue::Bool(true));
ip = target;
stats.jumps_taken += 1;
} else {
ip += 1;
}
}
Instruction::Jump(target) => {
ip = *target;
stats.jumps_taken += 1;
}
Instruction::LoadVar(name) => {
let val = local_env
.get(name)
.ok_or_else(|| VmError::UnboundVariable(name.clone()))?
.clone();
stack.push(val);
ip += 1;
}
Instruction::StoreVar(name) => {
let val = pop_value(&mut stack)?;
local_env.set(name.clone(), val);
ip += 1;
}
Instruction::TNorm => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a * b));
ip += 1;
}
Instruction::TCoNorm => {
let b = pop_num(&mut stack)?;
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(a + b - a * b));
ip += 1;
}
Instruction::FuzzyNot => {
let a = pop_num(&mut stack)?;
stack.push(VmValue::Num(1.0 - a));
ip += 1;
}
Instruction::Halt => {
let result = stack.pop().ok_or(VmError::StackUnderflow)?;
if stats.max_stack_depth < stack.len() + 1 {
stats.max_stack_depth = stack.len() + 1;
}
return Ok((result, stats));
}
}
if stack.len() > stats.max_stack_depth {
stats.max_stack_depth = stack.len();
}
}
}
/// Removes and returns the top of the stack.
///
/// # Errors
///
/// Returns `VmError::StackUnderflow` when the stack is empty.
#[inline]
fn pop_value(stack: &mut Vec<VmValue>) -> Result<VmValue, VmError> {
    match stack.pop() {
        Some(top) => Ok(top),
        None => Err(VmError::StackUnderflow),
    }
}
/// Removes the top of the stack and returns it as a number.
///
/// The value is consumed even when it turns out not to be a number.
///
/// # Errors
///
/// Returns `VmError::StackUnderflow` when the stack is empty and
/// `VmError::TypeMismatch` (expected `"Num"`) when the value is a boolean or
/// a symbol.
#[inline]
fn pop_num(stack: &mut Vec<VmValue>) -> Result<f64, VmError> {
    let top = stack.pop().ok_or(VmError::StackUnderflow)?;
    // `as_num` reports the same mismatch errors this function always produced.
    top.as_num()
}
#[inline]
fn values_equal(a: &VmValue, b: &VmValue) -> bool {
match (a, b) {
(VmValue::Num(x), VmValue::Num(y)) => x == y,
(VmValue::Bool(x), VmValue::Bool(y)) => x == y,
(VmValue::Sym(x), VmValue::Sym(y)) => x == y,
_ => false,
}
}
// Unit tests for the bytecode compiler and VM: compiled program shape,
// arithmetic, comparison and logic, control flow, variables, fuzzy operators,
// execution statistics and the error paths.
#[cfg(test)]
mod tests {
use super::*;
use tensorlogic_ir::{FuzzyNegationKind, TCoNormKind, TLExpr, TNormKind};
// Compiles `expr` and runs it in an empty environment, panicking on any error.
fn eval(expr: TLExpr) -> VmValue {
let prog = compile(&expr).expect("compile failed");
let env = VmEnv::new();
execute(&prog, &env).expect("execute failed")
}
// Compiles `expr` and runs it against `env`, panicking on any error.
fn eval_env(expr: TLExpr, env: &VmEnv) -> VmValue {
let prog = compile(&expr).expect("compile failed");
execute(&prog, env).expect("execute failed")
}
// A constant compiles to exactly one push followed by `Halt`.
#[test]
fn test_compile_constant_shape() {
let val = std::f64::consts::PI;
let prog = compile(&TLExpr::Constant(val)).expect("compile failed");
assert_eq!(prog.len(), 2, "should be [PushNum(PI), Halt]");
assert_eq!(prog.instructions[0], Instruction::PushNum(val));
assert_eq!(prog.instructions[1], Instruction::Halt);
}
// A hand-written program returns the value it pushed.
#[test]
fn test_execute_push_num() {
let mut prog = BytecodeProgram::new();
prog.push(Instruction::PushNum(5.0));
prog.push(Instruction::Halt);
let env = VmEnv::new();
let result = execute(&prog, &env).expect("execute failed");
assert_eq!(result, VmValue::Num(5.0));
}
// Binary arithmetic on constants.
#[test]
fn test_add() {
let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
assert_eq!(eval(expr), VmValue::Num(5.0));
}
#[test]
fn test_sub() {
let expr = TLExpr::sub(TLExpr::Constant(10.0), TLExpr::Constant(4.0));
assert_eq!(eval(expr), VmValue::Num(6.0));
}
#[test]
fn test_mul() {
let expr = TLExpr::mul(TLExpr::Constant(3.0), TLExpr::Constant(4.0));
assert_eq!(eval(expr), VmValue::Num(12.0));
}
#[test]
fn test_div() {
let expr = TLExpr::div(TLExpr::Constant(10.0), TLExpr::Constant(2.0));
assert_eq!(eval(expr), VmValue::Num(5.0));
}
#[test]
fn test_pow() {
let expr = TLExpr::pow(TLExpr::Constant(2.0), TLExpr::Constant(8.0));
assert_eq!(eval(expr), VmValue::Num(256.0));
}
// Comparisons produce booleans.
#[test]
fn test_eq_true() {
let expr = TLExpr::eq(TLExpr::Constant(3.0), TLExpr::Constant(3.0));
assert_eq!(eval(expr), VmValue::Bool(true));
}
#[test]
fn test_lt_true() {
let expr = TLExpr::lt(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
assert_eq!(eval(expr), VmValue::Bool(true));
}
// `true AND false` is false.
#[test]
fn test_and_false() {
let expr = TLExpr::and(
TLExpr::eq(TLExpr::Constant(1.0), TLExpr::Constant(1.0)),
TLExpr::eq(TLExpr::Constant(1.0), TLExpr::Constant(2.0)),
);
assert_eq!(eval(expr), VmValue::Bool(false));
}
// `false OR true` is true.
#[test]
fn test_or_true() {
let expr = TLExpr::or(
TLExpr::eq(TLExpr::Constant(1.0), TLExpr::Constant(2.0)),
TLExpr::eq(TLExpr::Constant(1.0), TLExpr::Constant(1.0)),
);
assert_eq!(eval(expr), VmValue::Bool(true));
}
// Negating a false comparison gives true.
#[test]
fn test_not_false_to_true() {
let expr = TLExpr::negate(TLExpr::eq(TLExpr::Constant(1.0), TLExpr::Constant(2.0)));
assert_eq!(eval(expr), VmValue::Bool(true));
}
// A false left operand of `and` takes the short-circuit jump.
#[test]
fn test_short_circuit_and_jump() {
let expr = TLExpr::and(
TLExpr::eq(TLExpr::Constant(1.0), TLExpr::Constant(2.0)), TLExpr::eq(TLExpr::Constant(3.0), TLExpr::Constant(3.0)), );
let prog = compile(&expr).expect("compile failed");
let env = VmEnv::new();
let (result, stats) = execute_with_stats(&prog, &env).expect("execute failed");
assert_eq!(result, VmValue::Bool(false));
assert!(stats.jumps_taken > 0, "JumpIfFalse should have been taken");
}
// A true left operand of `or` takes the short-circuit jump.
#[test]
fn test_short_circuit_or_jump() {
let expr = TLExpr::or(
TLExpr::eq(TLExpr::Constant(1.0), TLExpr::Constant(1.0)), TLExpr::eq(TLExpr::Constant(3.0), TLExpr::Constant(4.0)), );
let prog = compile(&expr).expect("compile failed");
let env = VmEnv::new();
let (result, stats) = execute_with_stats(&prog, &env).expect("execute failed");
assert_eq!(result, VmValue::Bool(true));
assert!(stats.jumps_taken > 0, "JumpIfTrue should have been taken");
}
// If/then/else yields the then-value when the condition holds.
#[test]
fn test_ite_true_branch() {
let expr = TLExpr::if_then_else(
TLExpr::eq(TLExpr::Constant(1.0), TLExpr::Constant(1.0)),
TLExpr::Constant(1.0),
TLExpr::Constant(2.0),
);
assert_eq!(eval(expr), VmValue::Num(1.0));
}
// If/then/else yields the else-value when the condition fails.
#[test]
fn test_ite_false_branch() {
let expr = TLExpr::if_then_else(
TLExpr::eq(TLExpr::Constant(1.0), TLExpr::Constant(2.0)),
TLExpr::Constant(1.0),
TLExpr::Constant(2.0),
);
assert_eq!(eval(expr), VmValue::Num(2.0));
}
// A predicate without arguments reads the variable of the same name.
#[test]
fn test_load_var() {
let expr = TLExpr::pred("x", vec![]);
let mut env = VmEnv::new();
env.set_num("x", 42.0);
assert_eq!(eval_env(expr, &env), VmValue::Num(42.0));
}
// `let y = 7 in y * 2` evaluates to 14.
#[test]
fn test_let_binding() {
let expr = TLExpr::Let {
var: "y".to_string(),
value: Box::new(TLExpr::Constant(7.0)),
body: Box::new(TLExpr::mul(
TLExpr::pred("y", vec![]),
TLExpr::Constant(2.0),
)),
};
let env = VmEnv::new();
assert_eq!(eval_env(expr, &env), VmValue::Num(14.0));
}
// `Add` on an empty stack reports a stack underflow.
#[test]
fn test_stack_underflow() {
let mut prog = BytecodeProgram::new();
prog.push(Instruction::Add); prog.push(Instruction::Halt);
let env = VmEnv::new();
let err = execute(&prog, &env).unwrap_err();
assert!(
matches!(err, VmError::StackUnderflow),
"expected StackUnderflow, got {:?}",
err
);
}
// Loading a variable that was never bound reports an unbound-variable error.
#[test]
fn test_unbound_variable() {
let mut prog = BytecodeProgram::new();
prog.push(Instruction::LoadVar("missing".to_string()));
prog.push(Instruction::Halt);
let env = VmEnv::new();
let err = execute(&prog, &env).unwrap_err();
assert!(
matches!(err, VmError::UnboundVariable(_)),
"expected UnboundVariable, got {:?}",
err
);
}
// Running any program counts at least one executed instruction.
#[test]
fn test_stats_instructions_executed() {
let expr = TLExpr::Constant(1.0);
let prog = compile(&expr).expect("compile failed");
let env = VmEnv::new();
let (_val, stats) = execute_with_stats(&prog, &env).expect("execute failed");
assert!(stats.instructions_executed > 0);
}
// One push followed by `Halt` records a maximum stack depth of 1.
#[test]
fn test_stats_max_stack_depth_single_push() {
let mut prog = BytecodeProgram::new();
prog.push(Instruction::PushNum(99.0));
prog.push(Instruction::Halt);
let env = VmEnv::new();
let (_val, stats) = execute_with_stats(&prog, &env).expect("execute failed");
assert_eq!(stats.max_stack_depth, 1, "single push should give depth 1");
}
// T-norm of 0.5 and 0.5 is their product, 0.25.
#[test]
fn test_tnorm_product() {
let expr = TLExpr::TNorm {
kind: TNormKind::Product,
left: Box::new(TLExpr::Constant(0.5)),
right: Box::new(TLExpr::Constant(0.5)),
};
let result = eval(expr);
match result {
VmValue::Num(n) => {
assert!((n - 0.25).abs() < 1e-10, "expected 0.25, got {}", n);
}
_ => panic!("expected Num, got {:?}", result),
}
}
// Fuzzy negation of 0.3 is 0.7.
#[test]
fn test_fuzzy_not() {
let expr = TLExpr::FuzzyNot {
kind: FuzzyNegationKind::Standard,
expr: Box::new(TLExpr::Constant(0.3)),
};
let result = eval(expr);
match result {
VmValue::Num(n) => {
assert!((n - 0.7).abs() < 1e-10, "expected 0.7, got {}", n);
}
_ => panic!("expected Num, got {:?}", result),
}
}
// T-conorm of 0.5 and 0.5 is 0.5 + 0.5 - 0.25 = 0.75.
#[test]
fn test_tconorm() {
let expr = TLExpr::TCoNorm {
kind: TCoNormKind::ProbabilisticSum,
left: Box::new(TLExpr::Constant(0.5)),
right: Box::new(TLExpr::Constant(0.5)),
};
let result = eval(expr);
match result {
VmValue::Num(n) => {
assert!((n - 0.75).abs() < 1e-10, "expected 0.75, got {}", n);
}
_ => panic!("expected Num, got {:?}", result),
}
}
// (1 + 2) * (3 + 4) evaluates to 21.
#[test]
fn test_nested_arithmetic() {
let expr = TLExpr::mul(
TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0)),
TLExpr::add(TLExpr::Constant(3.0), TLExpr::Constant(4.0)),
);
assert_eq!(eval(expr), VmValue::Num(21.0));
}
// Dividing by a zero literal reports a division-by-zero error.
#[test]
fn test_division_by_zero() {
let mut prog = BytecodeProgram::new();
prog.push(Instruction::PushNum(1.0));
prog.push(Instruction::PushNum(0.0));
prog.push(Instruction::Div);
prog.push(Instruction::Halt);
let env = VmEnv::new();
let err = execute(&prog, &env).unwrap_err();
assert!(
matches!(err, VmError::DivisionByZero),
"expected DivisionByZero, got {:?}",
err
);
}
// Absolute value of -5 is 5.
#[test]
fn test_abs() {
let expr = TLExpr::Abs(Box::new(TLExpr::Constant(-5.0)));
assert_eq!(eval(expr), VmValue::Num(5.0));
}
// A quantified expression has no bytecode lowering and is rejected.
#[test]
fn test_compile_unsupported_forall() {
use tensorlogic_ir::Term;
let expr = TLExpr::forall("x", "D", TLExpr::pred("P", vec![Term::var("x")]));
let err = compile(&expr).unwrap_err();
assert!(
matches!(err, CompileError::UnsupportedExpr(_)),
"expected UnsupportedExpr, got {:?}",
err
);
}
// Three nested additions exceed a maximum depth of 1.
#[test]
fn test_max_depth_exceeded() {
let inner = TLExpr::add(
TLExpr::add(
TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(1.0)),
TLExpr::Constant(1.0),
),
TLExpr::Constant(1.0),
);
let err = compile_with_config(&inner, 1).unwrap_err();
assert!(
matches!(err, CompileError::MaxDepthExceeded),
"expected MaxDepthExceeded, got {:?}",
err
);
}
}