use anyhow::{bail, Result};
use tensorlogic_ir::{EinsumGraph, EinsumNode, TLExpr};
use crate::compile::compile_expr;
use crate::context::{CompileState, CompilerContext};
/// Compiles a set-membership test `element ∈ set` into `graph`.
///
/// The set operand is compiled first, then the element operand. The output axes
/// are the order-preserving union of the element's axes followed by any of the
/// set's axes not already present.
///
/// When both operands carry the same non-empty axis string, one elementwise
/// `multiply` node is emitted. Otherwise a single einsum node
/// `"<elem>,<set>-><output>"` combines them; for two scalars this is `",->"`.
///
/// # Errors
/// Propagates any error from compiling either operand or from adding the node.
pub(crate) fn compile_set_membership(
    element: &TLExpr,
    set: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let set_compiled = compile_expr(set, ctx, graph)?;
    let elem_compiled = compile_expr(element, ctx, graph)?;

    // Order-preserving union of both axis strings, element axes first.
    let mut seen_axes = std::collections::HashSet::new();
    let merged_axes: String = elem_compiled
        .axes
        .chars()
        .chain(set_compiled.axes.chars())
        .filter(|axis| seen_axes.insert(*axis))
        .collect();

    let out_name = ctx.fresh_temp();
    let out_idx = graph.add_tensor(out_name);

    let same_layout = !merged_axes.is_empty() && elem_compiled.axes == set_compiled.axes;
    if same_layout {
        // Identical layouts line up element-for-element: a plain product suffices.
        graph.add_node(EinsumNode::elem_binary(
            "multiply",
            elem_compiled.tensor_idx,
            set_compiled.tensor_idx,
            out_idx,
        ))?;
    } else {
        // With two scalar operands this renders as ",->".
        let einsum_spec = format!(
            "{},{}->{}",
            elem_compiled.axes, set_compiled.axes, merged_axes
        );
        graph.add_node(EinsumNode::einsum(
            einsum_spec,
            vec![elem_compiled.tensor_idx, set_compiled.tensor_idx],
            vec![out_idx],
        ))?;
    }

    Ok(CompileState {
        tensor_idx: out_idx,
        axes: merged_axes,
    })
}
/// Merges the axes of two compiled operands and broadcasts each non-scalar
/// operand onto the merged layout.
///
/// The merged axes are the order-preserving union of `left.axes` followed by
/// `right.axes`. The left operand is processed before the right one, so the
/// temporaries and nodes are created in that order.
///
/// Returns `(left, right, merged_axes)`, where each returned operand is either
/// the input unchanged or a freshly broadcast temporary.
///
/// # Errors
/// Propagates any error from adding a broadcast node.
fn broadcast_to_common_axes(
    left: CompileState,
    right: CompileState,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<(CompileState, CompileState, String)> {
    let mut seen = std::collections::HashSet::new();
    let output_axes: String = left
        .axes
        .chars()
        .chain(right.axes.chars())
        .filter(|c| seen.insert(*c))
        .collect();
    let left = broadcast_operand(left, &output_axes, ctx, graph)?;
    let right = broadcast_operand(right, &output_axes, ctx, graph)?;
    Ok((left, right, output_axes))
}

/// Broadcasts a single operand onto `output_axes` with an einsum node
/// `"<axes>-><output_axes>"` written into a fresh temporary.
///
/// Scalars (empty axes) and operands already laid out as `output_axes` are
/// returned unchanged.
///
/// NOTE(review): the spec may name output axes that do not appear in the
/// input (e.g. `"a->ab"`). Standard einsum rejects that, so this relies on the
/// graph executor treating such axes as broadcasts — confirm.
fn broadcast_operand(
    state: CompileState,
    output_axes: &str,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    if state.axes.is_empty() || state.axes == output_axes {
        return Ok(state);
    }
    let broadcast_spec = format!("{}->{}", state.axes, output_axes);
    let broadcast_name = ctx.fresh_temp();
    let broadcast_idx = graph.add_tensor(broadcast_name);
    graph.add_node(EinsumNode::einsum(
        broadcast_spec,
        vec![state.tensor_idx],
        vec![broadcast_idx],
    ))?;
    Ok(CompileState {
        tensor_idx: broadcast_idx,
        axes: output_axes.to_string(),
    })
}

/// Compiles `left ∪ right` as the elementwise `max` of the two operands.
///
/// Both operands are aligned to their merged axes first (see
/// `broadcast_to_common_axes`), and the result carries those merged axes.
///
/// # Errors
/// Propagates any error from compiling an operand or adding a node.
pub(crate) fn compile_set_union(
    left: &TLExpr,
    right: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let left_state = compile_expr(left, ctx, graph)?;
    let right_state = compile_expr(right, ctx, graph)?;
    let (left_state, right_state, output_axes) =
        broadcast_to_common_axes(left_state, right_state, ctx, graph)?;
    let result_name = ctx.fresh_temp();
    let result_idx = graph.add_tensor(result_name);
    graph.add_node(EinsumNode::elem_binary(
        "max",
        left_state.tensor_idx,
        right_state.tensor_idx,
        result_idx,
    ))?;
    Ok(CompileState {
        tensor_idx: result_idx,
        axes: output_axes,
    })
}

/// Compiles `left ∩ right` as the elementwise `min` of the two operands.
///
/// Both operands are aligned to their merged axes first (see
/// `broadcast_to_common_axes`), and the result carries those merged axes.
///
/// # Errors
/// Propagates any error from compiling an operand or adding a node.
pub(crate) fn compile_set_intersection(
    left: &TLExpr,
    right: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let left_state = compile_expr(left, ctx, graph)?;
    let right_state = compile_expr(right, ctx, graph)?;
    let (left_state, right_state, output_axes) =
        broadcast_to_common_axes(left_state, right_state, ctx, graph)?;
    let result_name = ctx.fresh_temp();
    let result_idx = graph.add_tensor(result_name);
    graph.add_node(EinsumNode::elem_binary(
        "min",
        left_state.tensor_idx,
        right_state.tensor_idx,
        result_idx,
    ))?;
    Ok(CompileState {
        tensor_idx: result_idx,
        axes: output_axes,
    })
}

/// Compiles `left \ right` as `left * one_minus(right)`, elementwise.
///
/// Both operands are aligned to their merged axes first (see
/// `broadcast_to_common_axes`). A `one_minus` node complements the right
/// operand, and a `multiply` node combines it with the left operand.
///
/// # Errors
/// Propagates any error from compiling an operand or adding a node.
pub(crate) fn compile_set_difference(
    left: &TLExpr,
    right: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let left_state = compile_expr(left, ctx, graph)?;
    let right_state = compile_expr(right, ctx, graph)?;
    let (left_state, right_state, output_axes) =
        broadcast_to_common_axes(left_state, right_state, ctx, graph)?;
    let not_right_name = ctx.fresh_temp();
    let not_right_idx = graph.add_tensor(not_right_name);
    graph.add_node(EinsumNode::elem_unary(
        "one_minus",
        right_state.tensor_idx,
        not_right_idx,
    ))?;
    let result_name = ctx.fresh_temp();
    let result_idx = graph.add_tensor(result_name);
    graph.add_node(EinsumNode::elem_binary(
        "multiply",
        left_state.tensor_idx,
        not_right_idx,
        result_idx,
    ))?;
    Ok(CompileState {
        tensor_idx: result_idx,
        axes: output_axes,
    })
}
/// Compiles the cardinality `|set|` of a set expression into a scalar.
///
/// The set is compiled to a tensor, then a single `"sum"` reduce node collapses
/// every one of its dimensions, so the returned state has no axes.
///
/// # Errors
/// Propagates any error from compiling `set` or adding the reduce node.
pub(crate) fn compile_set_cardinality(
    set: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let set_state = compile_expr(set, ctx, graph)?;
    let result_name = ctx.fresh_temp();
    let result_idx = graph.add_tensor(result_name);
    // Every dimension is summed, so the reduce axes are the positional indices
    // 0..rank, independent of which letters name the axes. Letters such as "b"
    // or "ba" must not be turned into indices via `c - 'a'`: that overshoots
    // the rank and underflows for characters below 'a'.
    // (Assumes `EinsumNode::reduce` takes positional dimension indices — confirm.)
    let rank = set_state.axes.chars().count();
    let axes_to_reduce: Vec<usize> = (0..rank).collect();
    graph.add_node(EinsumNode::reduce(
        "sum",
        axes_to_reduce,
        set_state.tensor_idx,
        result_idx,
    ))?;
    Ok(CompileState {
        tensor_idx: result_idx,
        axes: String::new(),
    })
}
/// Compiles the empty set `∅` by registering a tensor named `"const_0.0"`.
///
/// No nodes are emitted, and the returned state has empty axes (a scalar). The
/// `"const_0.0"` name is presumably resolved to the constant 0.0 by a later
/// stage — confirm. The compiler context is accepted for signature parity but
/// is not used.
pub(crate) fn compile_empty_set(
    _ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    Ok(CompileState {
        tensor_idx: graph.add_tensor("const_0.0"),
        axes: String::new(),
    })
}
/// Compiles a set comprehension `{ var ∈ domain | condition }`.
///
/// The domain must already be registered in `ctx.domains`. `var` is bound to
/// that domain, and the compiled `condition` is returned directly as the set's
/// membership tensor.
///
/// # Errors
/// Fails if `domain` is unknown, or if binding `var` or compiling `condition`
/// fails.
pub(crate) fn compile_set_comprehension(
    var: &str,
    domain: &str,
    condition: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let domain_known = ctx.domains.contains_key(domain);
    if !domain_known {
        bail!(
            "Domain '{}' not found for set comprehension variable '{}'",
            domain,
            var
        );
    }
    ctx.bind_var(var, domain)?;
    let membership = compile_expr(condition, ctx, graph)?;
    Ok(membership)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    /// A fresh context with a single `Person` domain of size 100.
    fn person_ctx() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx
    }

    /// The comprehension `{ x ∈ Person | pred(x) }`.
    fn person_set(pred: &str) -> TLExpr {
        TLExpr::SetComprehension {
            var: "x".to_string(),
            domain: "Person".to_string(),
            condition: Box::new(TLExpr::pred(pred, vec![Term::var("x")])),
        }
    }

    #[test]
    fn test_empty_set_compilation() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let state = compile_empty_set(&mut ctx, &mut graph).unwrap();
        assert!(state.axes.is_empty());
    }

    #[test]
    fn test_set_comprehension_compilation() {
        let mut ctx = person_ctx();
        let mut graph = EinsumGraph::new();
        let condition = TLExpr::pred("P", vec![Term::var("x")]);
        let state =
            compile_set_comprehension("x", "Person", &condition, &mut ctx, &mut graph).unwrap();
        assert!(!state.axes.is_empty());
    }

    #[test]
    fn test_set_union_compilation() {
        let mut ctx = person_ctx();
        let mut graph = EinsumGraph::new();
        let (a, b) = (person_set("A"), person_set("B"));
        let state = compile_set_union(&a, &b, &mut ctx, &mut graph).unwrap();
        assert!(!state.axes.is_empty());
    }

    #[test]
    fn test_set_intersection_compilation() {
        let mut ctx = person_ctx();
        let mut graph = EinsumGraph::new();
        let (a, b) = (person_set("A"), person_set("B"));
        let state = compile_set_intersection(&a, &b, &mut ctx, &mut graph).unwrap();
        assert!(!state.axes.is_empty());
    }

    #[test]
    fn test_set_difference_compilation() {
        let mut ctx = person_ctx();
        let mut graph = EinsumGraph::new();
        let (a, b) = (person_set("A"), person_set("B"));
        let state = compile_set_difference(&a, &b, &mut ctx, &mut graph).unwrap();
        assert!(!state.axes.is_empty());
    }

    #[test]
    fn test_set_cardinality_compilation() {
        let mut ctx = person_ctx();
        let mut graph = EinsumGraph::new();
        let adults = person_set("Adult");
        let state = compile_set_cardinality(&adults, &mut ctx, &mut graph).unwrap();
        // Cardinality collapses the set to a scalar.
        assert!(state.axes.is_empty());
    }

    #[test]
    fn test_set_membership_compilation() {
        let mut ctx = person_ctx();
        let mut graph = EinsumGraph::new();
        let adults = person_set("Adult");
        let elem = TLExpr::pred("IsAlice", vec![Term::var("y")]);
        let state = compile_set_membership(&elem, &adults, &mut ctx, &mut graph).unwrap();
        assert!(!state.axes.is_empty());
    }
}