use anyhow::Result;
use tensorlogic_ir::{EinsumGraph, EinsumNode, TLExpr};
use crate::context::{CompileState, CompilerContext};
use super::{compile_expr, strategy_mapping};
pub(crate) fn compile_and(
left: &TLExpr,
right: &TLExpr,
ctx: &mut CompilerContext,
graph: &mut EinsumGraph,
) -> Result<CompileState> {
use crate::config::AndStrategy;
let left_state = compile_expr(left, ctx, graph)?;
let right_state = compile_expr(right, ctx, graph)?;
let mut output_axes = String::new();
let mut seen = std::collections::HashSet::new();
for c in left_state.axes.chars() {
if seen.insert(c) {
output_axes.push(c);
}
}
for c in right_state.axes.chars() {
if seen.insert(c) {
output_axes.push(c);
}
}
if matches!(
ctx.config.and_strategy,
AndStrategy::Product | AndStrategy::ProductTNorm
) && !left_state.axes.is_empty()
&& !right_state.axes.is_empty()
{
let spec = format!("{},{}->{}", left_state.axes, right_state.axes, output_axes);
let result_name = ctx.fresh_temp();
let result_idx = graph.add_tensor(result_name);
let node = EinsumNode::new(
spec,
vec![left_state.tensor_idx, right_state.tensor_idx],
vec![result_idx],
);
graph.add_node(node)?;
return Ok(CompileState {
tensor_idx: result_idx,
axes: output_axes,
});
}
let mut left_state = left_state;
let mut right_state = right_state;
if !left_state.axes.is_empty()
&& !right_state.axes.is_empty()
&& left_state.axes != right_state.axes
{
if left_state.axes != output_axes {
let broadcast_spec = format!("{}->{}", left_state.axes, output_axes);
let broadcast_name = ctx.fresh_temp();
let broadcast_idx = graph.add_tensor(broadcast_name);
let broadcast_node = EinsumNode::new(
broadcast_spec,
vec![left_state.tensor_idx],
vec![broadcast_idx],
);
graph.add_node(broadcast_node)?;
left_state = CompileState {
tensor_idx: broadcast_idx,
axes: output_axes.clone(),
};
}
if right_state.axes != output_axes {
let broadcast_spec = format!("{}->{}", right_state.axes, output_axes);
let broadcast_name = ctx.fresh_temp();
let broadcast_idx = graph.add_tensor(broadcast_name);
let broadcast_node = EinsumNode::new(
broadcast_spec,
vec![right_state.tensor_idx],
vec![broadcast_idx],
);
graph.add_node(broadcast_node)?;
right_state = CompileState {
tensor_idx: broadcast_idx,
axes: output_axes.clone(),
};
}
}
let result_idx = strategy_mapping::compile_and_with_strategy(
left_state.tensor_idx,
right_state.tensor_idx,
ctx,
graph,
)?;
Ok(CompileState {
tensor_idx: result_idx,
axes: output_axes,
})
}
pub(crate) fn compile_or(
left: &TLExpr,
right: &TLExpr,
ctx: &mut CompilerContext,
graph: &mut EinsumGraph,
) -> Result<CompileState> {
let mut left_state = compile_expr(left, ctx, graph)?;
let mut right_state = compile_expr(right, ctx, graph)?;
if !left_state.axes.is_empty()
&& !right_state.axes.is_empty()
&& left_state.axes != right_state.axes
{
let mut output_axes = String::new();
let mut seen = std::collections::HashSet::new();
for c in left_state.axes.chars() {
if seen.insert(c) {
output_axes.push(c);
}
}
for c in right_state.axes.chars() {
if seen.insert(c) {
output_axes.push(c);
}
}
if left_state.axes != output_axes {
let broadcast_spec = format!("{}->{}", left_state.axes, output_axes);
let broadcast_name = ctx.fresh_temp();
let broadcast_idx = graph.add_tensor(broadcast_name);
let broadcast_node = EinsumNode::new(
broadcast_spec,
vec![left_state.tensor_idx],
vec![broadcast_idx],
);
graph.add_node(broadcast_node)?;
left_state = CompileState {
tensor_idx: broadcast_idx,
axes: output_axes.clone(),
};
}
if right_state.axes != output_axes {
let broadcast_spec = format!("{}->{}", right_state.axes, output_axes);
let broadcast_name = ctx.fresh_temp();
let broadcast_idx = graph.add_tensor(broadcast_name);
let broadcast_node = EinsumNode::new(
broadcast_spec,
vec![right_state.tensor_idx],
vec![broadcast_idx],
);
graph.add_node(broadcast_node)?;
right_state = CompileState {
tensor_idx: broadcast_idx,
axes: output_axes.clone(),
};
}
}
let mut output_axes = String::new();
let mut seen = std::collections::HashSet::new();
for c in left_state.axes.chars() {
if seen.insert(c) {
output_axes.push(c);
}
}
for c in right_state.axes.chars() {
if seen.insert(c) {
output_axes.push(c);
}
}
let result_idx = strategy_mapping::compile_or_with_strategy(
left_state.tensor_idx,
right_state.tensor_idx,
ctx,
graph,
)?;
Ok(CompileState {
tensor_idx: result_idx,
axes: output_axes,
})
}
/// Compiles a logical negation `NOT inner` into `graph`.
///
/// The operand is compiled first. The negation itself is then emitted by
/// `strategy_mapping::compile_not_with_strategy`. The result keeps exactly the
/// operand's axes.
///
/// # Errors
///
/// Propagates any error from compiling `inner` or from the strategy mapping.
pub(crate) fn compile_not(
    inner: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let CompileState { tensor_idx, axes } = compile_expr(inner, ctx, graph)?;
    let negated_idx = strategy_mapping::compile_not_with_strategy(tensor_idx, ctx, graph)?;
    Ok(CompileState {
        tensor_idx: negated_idx,
        axes,
    })
}