use anyhow::Result;
use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
use crate::compile::compile_expr;
use crate::context::{CompileState, CompilerContext};
/// Weight applied to the summed abducible cost when compiling `Explain`:
/// the emitted score is `formula - DEFAULT_LAMBDA * total_cost`.
const DEFAULT_LAMBDA: f64 = 1.0_f64;
pub(crate) fn compile_abducible(
name: &str,
cost: f64,
ctx: &mut CompilerContext,
graph: &mut EinsumGraph,
) -> Result<CompileState> {
let abducible_name = format!("abducible_{}", name);
let abducible_idx = graph.add_tensor(abducible_name.clone());
use tensorlogic_ir::Metadata;
let metadata = Metadata::new()
.with_name(format!("Abducible: {}", name))
.with_attribute("abducible_cost", cost.to_string());
graph.tensor_metadata.insert(abducible_idx, metadata);
register_abducible(ctx, name, cost, abducible_idx)?;
Ok(CompileState {
tensor_idx: abducible_idx,
axes: String::new(),
})
}
/// Compiles `Explain(formula)` into a cost-penalized score tensor.
///
/// The emitted computation is `formula - DEFAULT_LAMBDA * total_cost`.
/// `total_cost` is the sum of `cost_i * abducible_i` over every abducible
/// registered in `ctx` at this point, not only those mentioned in `formula`.
/// If no abducibles are registered, the formula's compiled state is returned
/// unchanged. The result keeps the formula's axes; the penalty is built from
/// the abducible tensors and scalar constants.
pub(crate) fn compile_explain(
    formula: &tensorlogic_ir::TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let formula_state = compile_expr(formula, ctx, graph)?;

    let hypotheses = get_registered_abducibles(ctx, graph)?;
    if hypotheses.is_empty() {
        // Nothing to penalize: the explanation is just the formula.
        return Ok(formula_state);
    }

    let total_cost = compute_total_cost(ctx, graph, &hypotheses)?;
    let lambda = create_constant_tensor(DEFAULT_LAMBDA, ctx, graph)?;

    // penalty = total_cost * lambda
    let penalty = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode {
        op: OpType::ElemBinary {
            op: String::from("mul"),
        },
        inputs: vec![total_cost, lambda],
        outputs: vec![penalty],
        metadata: None,
    })?;

    // score = formula - penalty
    let score = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode {
        op: OpType::ElemBinary {
            op: String::from("sub"),
        },
        inputs: vec![formula_state.tensor_idx, penalty],
        outputs: vec![score],
        metadata: None,
    })?;

    Ok(CompileState {
        tensor_idx: score,
        axes: formula_state.axes,
    })
}
/// Records an abducible tensor in the compiler context so `Explain` can find it.
///
/// The entry is stored in `ctx.let_bindings` under the key `abd_<name>`. That
/// map is shared with ordinary let bindings, and any key carrying this prefix
/// is later read back as an abducible by `get_registered_abducibles`.
/// `_cost` is not stored here. In this file the cost travels via the tensor's
/// `abducible_cost` metadata, which `compile_abducible` sets.
fn register_abducible(
    ctx: &mut CompilerContext,
    name: &str,
    _cost: f64,
    tensor_idx: usize,
) -> Result<()> {
    let binding_key: String = ["abd_", name].concat();
    ctx.let_bindings.insert(binding_key, tensor_idx);
    Ok(())
}
/// Collects every registered abducible as `(name, cost, tensor_idx)`.
///
/// Abducibles are the `ctx.let_bindings` entries whose key starts with `abd_`;
/// the prefix is stripped to recover the name. Any binding with that prefix is
/// treated as an abducible.
///
/// The cost is read from the tensor's `abducible_cost` metadata attribute. It
/// falls back to `1.0` when the metadata or the attribute is missing, or when
/// the value does not parse as `f64`.
///
/// The result is sorted by name. The order of this list drives the order of
/// nodes (and of the float additions) emitted by `compute_total_cost`, and
/// `let_bindings` may iterate in hash order (NOTE(review): confirm its map
/// type), so sorting keeps the compiled graph deterministic.
///
/// Currently never returns `Err`.
fn get_registered_abducibles(
    ctx: &CompilerContext,
    graph: &EinsumGraph,
) -> Result<Vec<(String, f64, usize)>> {
    const FALLBACK_COST: f64 = 1.0;

    let mut abducibles: Vec<(String, f64, usize)> = ctx
        .let_bindings
        .iter()
        .filter_map(|(key, &tensor_idx)| {
            let name = key.strip_prefix("abd_")?;
            let cost = graph
                .tensor_metadata
                .get(&tensor_idx)
                .and_then(|metadata| metadata.get_attribute("abducible_cost"))
                .and_then(|cost_str| cost_str.parse::<f64>().ok())
                .unwrap_or(FALLBACK_COST);
            Some((name.to_string(), cost, tensor_idx))
        })
        .collect();

    // Keys are unique in the map, so names are unique and an unstable sort is fine.
    abducibles.sort_unstable_by(|a, b| a.0.cmp(&b.0));
    Ok(abducibles)
}
/// Builds the tensor `sum_i(abducible_i * cost_i)` and returns its index.
///
/// For each abducible, a `mul` node weights its tensor by a scalar constant
/// holding its cost. The weighted terms are then chained together with `add`
/// nodes in slice order. An empty slice yields a single constant-0 tensor and
/// emits no nodes.
fn compute_total_cost(
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
    abducibles: &[(String, f64, usize)],
) -> Result<usize> {
    /// Emits `tensor_idx * cost` into a fresh temporary and returns that temporary.
    fn emit_weighted(
        ctx: &mut CompilerContext,
        graph: &mut EinsumGraph,
        tensor_idx: usize,
        cost: f64,
    ) -> Result<usize> {
        let cost_tensor = create_constant_tensor(cost, ctx, graph)?;
        let product_idx = graph.add_tensor(ctx.fresh_temp());
        graph.add_node(EinsumNode {
            op: OpType::ElemBinary {
                op: "mul".to_string(),
            },
            inputs: vec![tensor_idx, cost_tensor],
            outputs: vec![product_idx],
            metadata: None,
        })?;
        Ok(product_idx)
    }

    let (first, rest) = match abducibles.split_first() {
        Some(split) => split,
        None => return create_constant_tensor(0.0, ctx, graph),
    };

    let mut running_total = emit_weighted(ctx, graph, first.2, first.1)?;
    for (_, cost, tensor_idx) in rest {
        let term = emit_weighted(ctx, graph, *tensor_idx, *cost)?;
        let sum_idx = graph.add_tensor(ctx.fresh_temp());
        graph.add_node(EinsumNode {
            op: OpType::ElemBinary {
                op: "add".to_string(),
            },
            inputs: vec![running_total, term],
            outputs: vec![sum_idx],
            metadata: None,
        })?;
        running_total = sum_idx;
    }
    Ok(running_total)
}
fn create_constant_tensor(
value: f64,
_ctx: &mut CompilerContext,
graph: &mut EinsumGraph,
) -> Result<usize> {
let const_name = format!("const_{}", value);
let const_idx = graph.add_tensor(const_name.clone());
let metadata = format!("constant:{}", value);
graph
.tensors
.get_mut(const_idx)
.unwrap()
.push_str(&format!("#{}", metadata));
Ok(const_idx)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    /// True if `graph` contains an element-wise binary node whose operator is `op_name`.
    fn has_elem_binary(graph: &EinsumGraph, op_name: &str) -> bool {
        graph
            .nodes
            .iter()
            .any(|node| matches!(&node.op, OpType::ElemBinary { op } if op == op_name))
    }

    #[test]
    fn test_abducible_compilation() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let result = compile_abducible("has_flu", 0.3, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
        assert!(result.axes.is_empty());
        assert!(graph.tensors[result.tensor_idx].contains("abducible"));
    }

    #[test]
    fn test_explain_without_abducibles() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let safe = TLExpr::pred("Safe", vec![]);
        let _result = compile_explain(&safe, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_explain_with_single_abducible() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        compile_abducible("has_flu", 0.3, &mut ctx, &mut graph).unwrap();
        let fever = TLExpr::pred("Fever", vec![]);
        let _result = compile_explain(&fever, &mut ctx, &mut graph).unwrap();
        // One cost `mul`, then the lambda `mul` and the penalty `sub`.
        assert!(graph.nodes.len() >= 3);
        assert!(has_elem_binary(&graph, "sub"));
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_explain_with_multiple_abducibles() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        compile_abducible("has_flu", 0.3, &mut ctx, &mut graph).unwrap();
        compile_abducible("has_cold", 0.2, &mut ctx, &mut graph).unwrap();
        compile_abducible("has_covid", 0.5, &mut ctx, &mut graph).unwrap();
        let fever = TLExpr::pred("Fever", vec![]);
        let cough = TLExpr::pred("Cough", vec![]);
        let symptoms = TLExpr::and(fever, cough);
        let _result = compile_explain(&symptoms, &mut ctx, &mut graph).unwrap();
        // Total cost: 3 `mul` + 2 `add`; penalty: 1 `mul` + 1 `sub`.
        assert!(graph.nodes.len() >= 7);
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_abducible_with_zero_cost() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let result = compile_abducible("free_assumption", 0.0, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
        assert!(graph.tensors[result.tensor_idx].contains("abducible"));
    }

    #[test]
    fn test_abducible_with_high_cost() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let result =
            compile_abducible("expensive_hypothesis", 100.0, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
        assert!(graph.tensors[result.tensor_idx].contains("abducible"));
    }

    #[test]
    fn test_multiple_explain_calls() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        compile_abducible("H1", 1.0, &mut ctx, &mut graph).unwrap();
        compile_abducible("H2", 2.0, &mut ctx, &mut graph).unwrap();
        let formula1 = TLExpr::pred("P", vec![]);
        let formula2 = TLExpr::pred("Q", vec![]);
        let _result1 = compile_explain(&formula1, &mut ctx, &mut graph).unwrap();
        let _result2 = compile_explain(&formula2, &mut ctx, &mut graph).unwrap();
        // Each call over two abducibles emits 2 `mul` + 1 `add` + 1 `mul` + 1 `sub`.
        assert!(graph.nodes.len() >= 10);
    }

    #[test]
    fn test_explain_with_free_variables() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 10);
        let mut graph = EinsumGraph::new();
        compile_abducible("knows_someone", 1.0, &mut ctx, &mut graph).unwrap();
        let knows = TLExpr::pred("Knows", vec![Term::var("x"), Term::var("y")]);
        ctx.bind_var("x", "Person").unwrap();
        ctx.bind_var("y", "Person").unwrap();
        let result = compile_explain(&knows, &mut ctx, &mut graph).unwrap();
        assert!(!result.axes.is_empty());
        assert!(!graph.nodes.is_empty());
    }

    /// Nested `Explain` is only checked not to panic; the result (Ok or Err) is ignored.
    #[test]
    fn test_nested_explain_not_recommended() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        compile_abducible("H", 1.0, &mut ctx, &mut graph).unwrap();
        let p = TLExpr::pred("P", vec![]);
        let inner_explain = TLExpr::Explain {
            formula: Box::new(p),
        };
        let _result = compile_explain(&inner_explain, &mut ctx, &mut graph);
    }

    #[test]
    fn test_abducible_name_uniqueness() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let result1 = compile_abducible("H", 1.0, &mut ctx, &mut graph).unwrap();
        let result2 = compile_abducible("H", 1.0, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
        assert_ne!(result1.tensor_idx, result2.tensor_idx);
    }

    #[test]
    fn test_abducible_cost_metadata_storage() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let result1 = compile_abducible("cheap", 0.5, &mut ctx, &mut graph).unwrap();
        let result2 = compile_abducible("expensive", 10.0, &mut ctx, &mut graph).unwrap();
        let result3 = compile_abducible("moderate", 2.5, &mut ctx, &mut graph).unwrap();
        let meta1 = graph.tensor_metadata.get(&result1.tensor_idx).unwrap();
        assert_eq!(meta1.get_attribute("abducible_cost"), Some("0.5"));
        let meta2 = graph.tensor_metadata.get(&result2.tensor_idx).unwrap();
        assert_eq!(meta2.get_attribute("abducible_cost"), Some("10"));
        let meta3 = graph.tensor_metadata.get(&result3.tensor_idx).unwrap();
        assert_eq!(meta3.get_attribute("abducible_cost"), Some("2.5"));
    }

    #[test]
    fn test_get_registered_abducibles_extracts_costs() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        compile_abducible("H1", 1.0, &mut ctx, &mut graph).unwrap();
        compile_abducible("H2", 2.5, &mut ctx, &mut graph).unwrap();
        compile_abducible("H3", 0.3, &mut ctx, &mut graph).unwrap();
        let abducibles = get_registered_abducibles(&ctx, &graph).unwrap();
        assert_eq!(abducibles.len(), 3);
        for (name, cost, _idx) in &abducibles {
            match name.as_str() {
                "H1" => assert_eq!(*cost, 1.0),
                "H2" => assert_eq!(*cost, 2.5),
                "H3" => assert_eq!(*cost, 0.3),
                _ => panic!("Unexpected abducible name: {}", name),
            }
        }
    }

    #[test]
    fn test_explain_uses_correct_costs() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        compile_abducible("H1", 1.0, &mut ctx, &mut graph).unwrap();
        compile_abducible("H2", 5.0, &mut ctx, &mut graph).unwrap();
        let formula = TLExpr::pred("Safe", vec![]);
        compile_explain(&formula, &mut ctx, &mut graph).unwrap();
        assert!(!graph.nodes.is_empty());
        let abducibles = get_registered_abducibles(&ctx, &graph).unwrap();
        assert_eq!(abducibles.len(), 2);
        // Exhaustive match: an unexpected name is a failure, not silently skipped.
        for (name, cost, _) in &abducibles {
            match name.as_str() {
                "H1" => assert_eq!(*cost, 1.0),
                "H2" => assert_eq!(*cost, 5.0),
                other => panic!("Unexpected abducible name: {}", other),
            }
        }
    }
}