use anyhow::{bail, Result};
use tensorlogic_ir::{EinsumGraph, EinsumNode, TLExpr};
use crate::context::{CompileState, CompilerContext};
use super::compile_expr;
/// Compiles a weighted rule by scaling the compiled rule tensor with a constant weight.
///
/// The rule expression is compiled first. The weight is represented by a graph tensor
/// named `const_<weight>`: an existing tensor with that name is reused, otherwise a new
/// one is registered. The product is written to a freshly added tensor named
/// `weighted_rule_<n>`, where `<n>` is the graph's tensor count at the time of creation,
/// through an element-wise `"multiply"` node.
///
/// # Returns
///
/// A [`CompileState`] pointing at the product tensor and carrying the rule's axes unchanged.
///
/// # Errors
///
/// Propagates any error from compiling `rule` or from adding the node to `graph`.
pub(crate) fn compile_weighted_rule(
    weight: f64,
    rule: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let compiled_rule = compile_expr(rule, ctx, graph)?;

    // Look the constant up by name once; only register it when no tensor matches.
    let const_name = format!("const_{}", weight);
    let existing = graph.tensors.iter().position(|name| name == &const_name);
    let weight_idx = match existing {
        Some(idx) => idx,
        None => graph.add_tensor(const_name),
    };

    // The output name embeds the tensor count as it stands after the constant lookup.
    let output_name = format!("weighted_rule_{}", graph.tensors.len());
    let output_idx = graph.add_tensor(output_name);

    let multiply = EinsumNode::elem_binary(
        "multiply",
        compiled_rule.tensor_idx,
        weight_idx,
        output_idx,
    );
    graph.add_node(multiply)?;

    Ok(CompileState {
        tensor_idx: output_idx,
        axes: compiled_rule.axes,
    })
}
/// Compiles a probabilistic choice into a probability-weighted sum of its alternatives.
///
/// Every alternative expression is compiled first, in order. Each compiled tensor is then
/// multiplied element-wise by a constant tensor named `const_<probability>` (reused when a
/// tensor of that name already exists in the graph), and the weighted tensors are combined
/// left to right:
///
/// * when the running axes and the next alternative's axes are equal, or either side has
///   no axes, an element-wise `"add"` node is emitted;
/// * otherwise an einsum node `<accum>,<next>-><merged>` is emitted, where `<merged>` is
///   the running axes followed by any axes of the next alternative not yet present.
///
/// A single alternative produces just one `"multiply"` node writing to `prob_choice_<n>`.
///
/// # Returns
///
/// A [`CompileState`] for the final accumulated tensor together with the axes accumulated
/// while combining the alternatives.
///
/// # Errors
///
/// Fails when `alternatives` is empty, when any probability is negative or NaN, or when
/// compiling an alternative or adding a node to `graph` fails.
pub(crate) fn compile_probabilistic_choice(
    alternatives: &[(f64, TLExpr)],
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let (first_alt, other_alts) = match alternatives.split_first() {
        Some(parts) => parts,
        None => bail!("ProbabilisticChoice requires at least one alternative"),
    };

    for (prob, _) in alternatives {
        // NaN compares false against every value, so `< 0.0` alone would let it through.
        if prob.is_nan() || *prob < 0.0 {
            bail!(
                "Probabilities in ProbabilisticChoice must be non-negative, got {}",
                prob
            );
        }
    }

    // Compile every alternative before emitting any weighting node so that the tensors
    // created by the sub-expressions precede the weighting tensors in the graph.
    let first_state = compile_expr(&first_alt.1, ctx, graph)?;
    let mut other_states = Vec::with_capacity(other_alts.len());
    for (prob, expr) in other_alts {
        other_states.push((*prob, compile_expr(expr, ctx, graph)?));
    }

    if other_states.is_empty() {
        let result_idx =
            scale_by_probability(graph, first_state.tensor_idx, first_alt.0, "prob_choice")?;
        return Ok(CompileState {
            tensor_idx: result_idx,
            axes: first_state.axes,
        });
    }

    let mut accum_idx =
        scale_by_probability(graph, first_state.tensor_idx, first_alt.0, "weighted_0")?;
    let mut accum_axes = first_state.axes;

    for (offset, (prob, state)) in other_states.iter().enumerate() {
        // `offset` counts from the second alternative, so the alternative's own index is +1.
        let alt_no = offset + 1;
        let weighted_idx = scale_by_probability(
            graph,
            state.tensor_idx,
            *prob,
            &format!("weighted_{}", alt_no),
        )?;

        let accum_name = format!("accum_{}_{}", alt_no, graph.tensors.len());
        let new_accum_idx = graph.add_tensor(accum_name);

        if accum_axes == state.axes || accum_axes.is_empty() || state.axes.is_empty() {
            let add_node =
                EinsumNode::elem_binary("add", accum_idx, weighted_idx, new_accum_idx);
            graph.add_node(add_node)?;
            // An axis-less accumulator takes over the axes of the tensor it was added to;
            // in every other case reaching this branch the running axes stay as they are.
            if accum_axes.is_empty() {
                accum_axes = state.axes.clone();
            }
        } else {
            // Only axes that one of the two operands actually carries may appear in the
            // output; using the union over *all* alternatives here would name axes that
            // neither operand has when later alternatives introduce new ones.
            let merged_axes = merge_axes(&accum_axes, &state.axes);
            // NOTE(review): an einsum spec conventionally multiplies its operands; confirm
            // that the backend evaluates this node as a broadcasting addition.
            let add_spec = format!("{},{}->{}", accum_axes, state.axes, merged_axes);
            let add_node = EinsumNode::new(
                add_spec,
                vec![accum_idx, weighted_idx],
                vec![new_accum_idx],
            );
            graph.add_node(add_node)?;
            accum_axes = merged_axes;
        }

        accum_idx = new_accum_idx;
    }

    Ok(CompileState {
        tensor_idx: accum_idx,
        axes: accum_axes,
    })
}

/// Returns the index of the tensor named `const_<value>`, registering it when absent.
fn probability_const_idx(graph: &mut EinsumGraph, value: f64) -> usize {
    let name = format!("const_{}", value);
    let existing = graph.tensors.iter().position(|t| t == &name);
    match existing {
        Some(idx) => idx,
        None => graph.add_tensor(name),
    }
}

/// Emits a `"multiply"` node scaling `tensor_idx` by the constant `prob` and returns the
/// index of the new `<label>_<n>` output tensor (`<n>` = tensor count at creation time).
fn scale_by_probability(
    graph: &mut EinsumGraph,
    tensor_idx: usize,
    prob: f64,
    label: &str,
) -> Result<usize> {
    let prob_idx = probability_const_idx(graph, prob);
    let scaled_name = format!("{}_{}", label, graph.tensors.len());
    let scaled_idx = graph.add_tensor(scaled_name);
    let node = EinsumNode::elem_binary("multiply", tensor_idx, prob_idx, scaled_idx);
    graph.add_node(node)?;
    Ok(scaled_idx)
}

/// Returns `base` followed by every axis label of `extra` that `base` does not contain.
fn merge_axes(base: &str, extra: &str) -> String {
    let mut merged = String::with_capacity(base.len() + extra.len());
    merged.push_str(base);
    for axis in extra.chars() {
        if !merged.contains(axis) {
            merged.push(axis);
        }
    }
    merged
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    /// Builds a compiler context with domain `D` (size 10) and an empty graph.
    fn ctx_with_domain() -> (CompilerContext, EinsumGraph) {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 10);
        (ctx, EinsumGraph::new())
    }

    /// Builds a compiler context without any domain and an empty graph.
    fn bare_ctx() -> (CompilerContext, EinsumGraph) {
        (CompilerContext::new(), EinsumGraph::new())
    }

    /// Builds the unary predicate `name(x)`.
    fn unary_pred(name: &'static str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var("x")])
    }

    /// A weighted rule over a single predicate compiles successfully.
    #[test]
    fn test_compile_weighted_rule() {
        let (mut ctx, mut graph) = ctx_with_domain();
        let rule = unary_pred("P");

        let outcome = compile_weighted_rule(0.8, &rule, &mut ctx, &mut graph);

        assert!(outcome.is_ok());
    }

    /// A choice with exactly one alternative compiles successfully.
    #[test]
    fn test_compile_probabilistic_choice_single() {
        let (mut ctx, mut graph) = ctx_with_domain();
        let alternatives = vec![(1.0, unary_pred("P"))];

        let outcome = compile_probabilistic_choice(&alternatives, &mut ctx, &mut graph);

        assert!(outcome.is_ok());
    }

    /// A choice between two predicates over the same variable compiles successfully.
    #[test]
    fn test_compile_probabilistic_choice_multiple() {
        let (mut ctx, mut graph) = ctx_with_domain();
        let first = unary_pred("P");
        let second = unary_pred("Q");
        let alternatives = vec![(0.6, first), (0.4, second)];

        let outcome = compile_probabilistic_choice(&alternatives, &mut ctx, &mut graph);

        assert!(outcome.is_ok());
    }

    /// An empty list of alternatives is rejected.
    #[test]
    fn test_probabilistic_choice_empty_fails() {
        let (mut ctx, mut graph) = bare_ctx();
        let alternatives: Vec<(f64, TLExpr)> = Vec::new();

        let outcome = compile_probabilistic_choice(&alternatives, &mut ctx, &mut graph);

        assert!(outcome.is_err());
    }

    /// A negative probability is rejected.
    #[test]
    fn test_probabilistic_choice_negative_prob_fails() {
        let (mut ctx, mut graph) = bare_ctx();
        let alternatives = vec![(-0.5, unary_pred("P"))];

        let outcome = compile_probabilistic_choice(&alternatives, &mut ctx, &mut graph);

        assert!(outcome.is_err());
    }
}