use anyhow::Result;
use tensorlogic_ir::{EinsumGraph, EinsumNode, TLExpr};
use crate::compile::compile_expr;
use crate::context::{CompileState, CompilerContext};
/// Lowers a counting existential ("at least `min_count` values of `var` in `domain`
/// satisfy `body`") into `graph`.
///
/// The body is compiled with `var` bound to `domain`, and the axis assigned to `var` is summed
/// away. The remaining per-index total is shifted by `min_count - 0.5` (a named scalar constant
/// broadcast over the kept axes) and then passed through `sigmoid`. Assuming the backend's
/// `sigmoid` is the standard logistic function, the output exceeds 0.5 exactly when the summed
/// body value is greater than `min_count - 0.5`.
///
/// NOTE(review): the threshold tensor is only *named* here (`const_<value>`); presumably the
/// backend materializes `const_*` tensors from their names. Confirm.
///
/// # Errors
///
/// Propagates failures from binding `var`, compiling `body`, or adding nodes to `graph`.
pub(crate) fn compile_counting_exists(
    var: &str,
    domain: &str,
    body: &TLExpr,
    min_count: usize,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    ctx.bind_var(var, domain)?;
    let reduced_axis = ctx.assign_axis(var);
    let CompileState {
        tensor_idx: body_idx,
        axes: body_axes,
    } = compile_expr(body, ctx, graph)?;

    // Every axis of the body except the quantified one survives the reduction.
    let axes = body_axes.replace(reduced_axis, "");

    // Step 1: total the body over the quantified axis.
    let count_idx = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode::new(
        format!("sum({body_axes}->{axes})"),
        vec![body_idx],
        vec![count_idx],
    ))?;

    // Step 2: subtract the scalar threshold (empty axis list = broadcast scalar).
    let threshold = min_count as f64 - 0.5;
    let threshold_idx = graph.add_tensor(format!("const_{threshold}"));
    let shifted_idx = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode::new(
        format!("subtract({axes},->{axes})"),
        vec![count_idx, threshold_idx],
        vec![shifted_idx],
    ))?;

    // Step 3: squash into a soft truth value.
    let tensor_idx = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode::new(
        format!("sigmoid({axes}->{axes})"),
        vec![shifted_idx],
        vec![tensor_idx],
    ))?;

    Ok(CompileState { tensor_idx, axes })
}
/// Lowers a counting universal over `var` in `domain` with threshold `min_count`.
///
/// There is no dedicated lowering for this form. It is compiled exactly like
/// `compile_counting_exists`, i.e. as a soft "summed body value exceeds `min_count - 0.5`"
/// test over the remaining axes.
///
/// NOTE(review): confirm that sharing the existential lowering matches the intended semantics
/// of the counting-universal quantifier.
///
/// # Errors
///
/// Any error from the delegated lowering is returned unchanged.
pub(crate) fn compile_counting_forall(
    var: &str,
    domain: &str,
    body: &TLExpr,
    min_count: usize,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let lowered = compile_counting_exists(var, domain, body, min_count, ctx, graph)?;
    Ok(lowered)
}
/// Lowers an exact-count quantifier ("exactly `count` values of `var` in `domain` satisfy
/// `body`") into `graph`.
///
/// The body is compiled with `var` bound to `domain`, and the axis of `var` is summed away. The
/// result is then scored as `exp(-(total - count)^2)` through the elementwise chain
/// subtract → multiply (square) → negate → exp. Assuming the backend interprets these ops in the
/// usual way, the score is 1 when the total equals `count` and decays as it moves away.
///
/// NOTE(review): the target tensor is only *named* here (`const_<value>`); presumably the
/// backend materializes `const_*` tensors from their names. Confirm.
///
/// # Errors
///
/// Propagates failures from binding `var`, compiling `body`, or adding nodes to `graph`.
pub(crate) fn compile_exact_count(
    var: &str,
    domain: &str,
    body: &TLExpr,
    count: usize,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    ctx.bind_var(var, domain)?;
    let reduced_axis = ctx.assign_axis(var);
    let CompileState {
        tensor_idx: body_idx,
        axes: body_axes,
    } = compile_expr(body, ctx, graph)?;
    let axes = body_axes.replace(reduced_axis, "");

    // Total over the quantified axis.
    let total_idx = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode::new(
        format!("sum({body_axes}->{axes})"),
        vec![body_idx],
        vec![total_idx],
    ))?;

    // Distance from the target count (scalar broadcast over the kept axes).
    let target_idx = graph.add_tensor(format!("const_{}", count as f64));
    let delta_idx = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode::new(
        format!("subtract({axes},->{axes})"),
        vec![total_idx, target_idx],
        vec![delta_idx],
    ))?;

    // Square the distance by multiplying it with itself.
    let squared_idx = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode::new(
        format!("multiply({axes},{axes}->{axes})"),
        vec![delta_idx, delta_idx],
        vec![squared_idx],
    ))?;

    // Gaussian-style bump: negate, then exponentiate.
    let mut current = squared_idx;
    for op in ["negate", "exp"] {
        let next = graph.add_tensor(ctx.fresh_temp());
        graph.add_node(EinsumNode::new(
            format!("{op}({axes}->{axes})"),
            vec![current],
            vec![next],
        ))?;
        current = next;
    }

    Ok(CompileState {
        tensor_idx: current,
        axes,
    })
}
/// Lowers a majority quantifier over `var` in `domain` into `graph`.
///
/// The cardinality of `domain` is looked up in `ctx.domains` before anything is bound or
/// emitted. The body is compiled with `var` bound to `domain`, and its axis is summed away. The
/// total is shifted by half the cardinality (a named scalar constant) and passed through
/// `sigmoid`. Assuming the standard logistic `sigmoid`, the output exceeds 0.5 exactly when the
/// summed body value is greater than half the domain size.
///
/// NOTE(review): the half-size tensor is only *named* here (`const_<value>`); presumably the
/// backend materializes `const_*` tensors from their names. Confirm.
///
/// # Errors
///
/// Returns `Domain '<domain>' not found` if `domain` is not registered in `ctx.domains`.
/// Otherwise propagates failures from binding `var`, compiling `body`, or adding nodes.
pub(crate) fn compile_majority(
    var: &str,
    domain: &str,
    body: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let Some(info) = ctx.domains.get(domain) else {
        return Err(anyhow::anyhow!("Domain '{}' not found", domain));
    };
    let cardinality = info.cardinality;

    ctx.bind_var(var, domain)?;
    let reduced_axis = ctx.assign_axis(var);
    let CompileState {
        tensor_idx: body_idx,
        axes: body_axes,
    } = compile_expr(body, ctx, graph)?;
    let axes = body_axes.replace(reduced_axis, "");

    // Number of satisfying values for each remaining index.
    let votes_idx = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode::new(
        format!("sum({body_axes}->{axes})"),
        vec![body_idx],
        vec![votes_idx],
    ))?;

    // Compare against half the domain size (scalar broadcast).
    let half = cardinality as f64 / 2.0;
    let half_idx = graph.add_tensor(format!("const_{half}"));
    let margin_idx = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode::new(
        format!("subtract({axes},->{axes})"),
        vec![votes_idx, half_idx],
        vec![margin_idx],
    ))?;

    // Soft "more than half" decision.
    let tensor_idx = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode::new(
        format!("sigmoid({axes}->{axes})"),
        vec![margin_idx],
        vec![tensor_idx],
    ))?;

    Ok(CompileState { tensor_idx, axes })
}
#[cfg(test)]
mod tests {
    //! Unit tests for the counting-quantifier lowerings.
    //!
    //! `.expect(..)` is used instead of `assert!(result.is_ok())` so that a failing compile
    //! reports the underlying `anyhow` error rather than just "assertion failed".
    use super::*;
    use tensorlogic_ir::Term;

    #[test]
    fn test_counting_exists_compilation() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 10);
        let body = TLExpr::pred("happy", vec![Term::var("x")]);
        let mut graph = EinsumGraph::default();
        let state = compile_counting_exists("x", "Person", &body, 3, &mut ctx, &mut graph)
            .expect("counting exists should compile");
        // The only free variable is quantified, so the result is a scalar.
        assert!(state.axes.is_empty());
        // At least sum + subtract + sigmoid.
        assert!(graph.nodes.len() >= 3);
    }

    #[test]
    fn test_exact_count_compilation() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Item", 5);
        let body = TLExpr::pred("selected", vec![Term::var("i")]);
        let mut graph = EinsumGraph::default();
        let state = compile_exact_count("i", "Item", &body, 2, &mut ctx, &mut graph)
            .expect("exact count should compile");
        assert!(state.axes.is_empty());
        // At least sum + subtract + multiply + negate + exp.
        assert!(graph.nodes.len() >= 5);
    }

    #[test]
    fn test_majority_compilation() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Voter", 100);
        let body = TLExpr::pred("votes_yes", vec![Term::var("v")]);
        let mut graph = EinsumGraph::default();
        let state = compile_majority("v", "Voter", &body, &mut ctx, &mut graph)
            .expect("majority should compile");
        assert!(state.axes.is_empty());
        // At least sum + subtract + sigmoid.
        assert!(graph.nodes.len() >= 3);
    }

    #[test]
    fn test_majority_unknown_domain_is_error() {
        let mut ctx = CompilerContext::new();
        let body = TLExpr::pred("votes_yes", vec![Term::var("v")]);
        let mut graph = EinsumGraph::default();
        let Err(err) = compile_majority("v", "Ghost", &body, &mut ctx, &mut graph) else {
            panic!("majority over an undeclared domain must fail");
        };
        assert!(err.to_string().contains("Ghost"));
        // The domain lookup happens before anything is emitted.
        assert!(graph.nodes.is_empty());
    }

    #[test]
    fn test_counting_forall_compilation() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Student", 20);
        let body = TLExpr::pred("passed", vec![Term::var("s")]);
        let mut graph = EinsumGraph::default();
        let state = compile_counting_forall("s", "Student", &body, 15, &mut ctx, &mut graph)
            .expect("counting forall should compile");
        assert!(state.axes.is_empty());
        // Shares the existential lowering: at least sum + subtract + sigmoid.
        assert!(graph.nodes.len() >= 3);
    }
}