use anyhow::{bail, Result};
use tensorlogic_ir::{EinsumGraph, EinsumNode, TLExpr};
use crate::config::{ModalStrategy, TemporalStrategy};
use crate::context::{CompileState, CompilerContext};
use super::compile_expr;
/// Reserved pseudo-variable / domain name under which the modal "possible worlds" axis is
/// registered in the compiler context (`var_to_axis` and `domains`); see `ensure_world_axis`.
const WORLD_AXIS: &str = "__world__";
/// Reserved pseudo-variable / domain name under which the temporal time axis is registered
/// in the compiler context (`var_to_axis` and `domains`); see `ensure_time_axis`.
const TIME_AXIS: &str = "__time__";
/// Compiles the modal necessity operator (`□φ`, "in all worlds").
///
/// The world axis is registered in `ctx` first (even if `inner` never uses it), then `inner`
/// is compiled. If the compiled result does not carry the world axis it is returned as-is;
/// otherwise the world axis is reduced away with `min` (for `AllWorldsMin` and `Threshold`)
/// or `prod` (for `AllWorldsProduct`).
///
/// # Errors
/// Propagates errors from compiling `inner` and from inserting the reduction node.
pub(crate) fn compile_box(
    inner: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let axis = ensure_world_axis(ctx);
    let compiled = compile_expr(inner, ctx, graph)?;

    // Nothing varies across worlds, so there is nothing to quantify over.
    if !compiled.axes.contains(axis) {
        return Ok(compiled);
    }

    // NOTE(review): the `Threshold` payload is ignored here; it is treated exactly like min.
    let op = match ctx.config.modal_strategy {
        ModalStrategy::AllWorldsProduct => "prod",
        ModalStrategy::AllWorldsMin | ModalStrategy::Threshold { .. } => "min",
    };
    apply_reduction(&compiled, axis, op, ctx, graph)
}
/// Compiles the modal possibility operator (`◇φ`, "in some world").
///
/// The world axis is registered in `ctx` first (even if `inner` never uses it), then `inner`
/// is compiled. If the compiled result does not carry the world axis it is returned as-is;
/// otherwise the world axis is reduced away with `max` (for `AllWorldsMin` and `Threshold`)
/// or `sum` (for `AllWorldsProduct`).
///
/// # Errors
/// Propagates errors from compiling `inner` and from inserting the reduction node.
pub(crate) fn compile_diamond(
    inner: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let axis = ensure_world_axis(ctx);
    let compiled = compile_expr(inner, ctx, graph)?;

    // Nothing varies across worlds, so there is nothing to quantify over.
    if !compiled.axes.contains(axis) {
        return Ok(compiled);
    }

    let op = match ctx.config.modal_strategy {
        ModalStrategy::AllWorldsProduct => "sum",
        ModalStrategy::AllWorldsMin | ModalStrategy::Threshold { .. } => "max",
    };
    apply_reduction(&compiled, axis, op, ctx, graph)
}
/// Compiles the temporal "next" operator (`Xφ`).
///
/// Not supported: lowering `Xφ` needs a shift along the time axis, and a plain einsum
/// contraction cannot express that. All arguments are ignored.
///
/// # Errors
/// Always returns an error that explains the limitation and suggests alternatives.
pub(crate) fn compile_next(
    _inner: &TLExpr,
    _ctx: &mut CompilerContext,
    _graph: &mut EinsumGraph,
) -> Result<CompileState> {
    const UNSUPPORTED: &str = "Next (X) temporal operator requires shift operations which are not available in einsum. \
        Consider using Eventually or Always operators, or implement backend-specific shift support.";
    bail!(UNSUPPORTED)
}
/// Compiles the temporal "eventually" operator (`Fφ`).
///
/// The time axis is registered in `ctx` first (even if `inner` never uses it), then `inner`
/// is compiled. If the compiled result does not carry the time axis it is returned as-is;
/// otherwise the time axis is reduced away with `max` (for `Max` and `LogSumExp`) or `sum`
/// (for `Sum`).
///
/// # Errors
/// Propagates errors from compiling `inner` and from inserting the reduction node.
pub(crate) fn compile_eventually(
    inner: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let axis = ensure_time_axis(ctx);
    let compiled = compile_expr(inner, ctx, graph)?;

    // Result is time-invariant; no temporal reduction needed.
    if !compiled.axes.contains(axis) {
        return Ok(compiled);
    }

    // NOTE(review): `LogSumExp` currently lowers to a hard `max`, not an actual log-sum-exp.
    let op = match ctx.config.temporal_strategy {
        TemporalStrategy::Sum => "sum",
        TemporalStrategy::Max | TemporalStrategy::LogSumExp => "max",
    };
    apply_reduction(&compiled, axis, op, ctx, graph)
}
/// Compiles the temporal "always" operator (`Gφ`).
///
/// The time axis is registered in `ctx` first (even if `inner` never uses it), then `inner`
/// is compiled. If the compiled result does not carry the time axis it is returned as-is;
/// otherwise the time axis is reduced away with `min` (for `Max` and `LogSumExp`) or `prod`
/// (for `Sum`).
///
/// # Errors
/// Propagates errors from compiling `inner` and from inserting the reduction node.
pub(crate) fn compile_always(
    inner: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let axis = ensure_time_axis(ctx);
    let compiled = compile_expr(inner, ctx, graph)?;

    // Result is time-invariant; no temporal reduction needed.
    if !compiled.axes.contains(axis) {
        return Ok(compiled);
    }

    let op = match ctx.config.temporal_strategy {
        TemporalStrategy::Sum => "prod",
        TemporalStrategy::Max | TemporalStrategy::LogSumExp => "min",
    };
    apply_reduction(&compiled, axis, op, ctx, graph)
}
/// Compiles the temporal "until" operator (`φ U ψ`).
///
/// Not supported: lowering `U` needs a scan along the time axis, and a plain einsum
/// contraction cannot express that. All arguments are ignored.
///
/// # Errors
/// Always returns an error that explains the limitation and suggests approximations.
pub(crate) fn compile_until(
    _before: &TLExpr,
    _after: &TLExpr,
    _ctx: &mut CompilerContext,
    _graph: &mut EinsumGraph,
) -> Result<CompileState> {
    const UNSUPPORTED: &str = "Until (U) temporal operator requires scan operations which are not available in einsum. \
        Consider using Eventually or Always operators as approximations, or implement \
        backend-specific scan support.";
    bail!(UNSUPPORTED)
}
/// Compiles the temporal "release" operator (`p R q`) using the approximation
/// `q ∧ (p ∨ G q)`.
///
/// `q` and `p ∨ Always(q)` are compiled separately. They are then combined by one einsum
/// node whose output keeps the union of both operands' axes. Because no axis is summed
/// away, the node multiplies the operands element-wise, which acts as the conjunction.
///
/// NOTE(review): `q` is compiled twice (once here, once inside `Always(q)`). The outer
/// conjunction is always an einsum product and does not use the configured AND strategy.
///
/// # Errors
/// Propagates errors from compiling either operand or from inserting the combining node.
pub(crate) fn compile_release(
    p: &TLExpr,
    q: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let lhs = compile_expr(q, ctx, graph)?;
    let rhs_expr = TLExpr::or(p.clone(), TLExpr::Always(Box::new(q.clone())));
    let rhs = compile_expr(&rhs_expr, ctx, graph)?;

    let out_idx = graph.add_tensor(ctx.fresh_temp());
    let out_axes = merge_axes(&lhs.axes, &rhs.axes);
    graph.add_node(EinsumNode::new(
        format!("{},{}->{}", lhs.axes, rhs.axes, out_axes),
        vec![lhs.tensor_idx, rhs.tensor_idx],
        vec![out_idx],
    ))?;

    Ok(CompileState {
        tensor_idx: out_idx,
        axes: out_axes,
    })
}
/// Compiles the temporal "weak until" operator (`p W q`) using the approximation
/// `G p ∨ F q`.
///
/// The rewritten expression goes through the normal expression compiler, so the
/// configured OR / Always / Eventually lowering applies.
///
/// # Errors
/// Propagates any error from compiling the rewritten expression.
pub(crate) fn compile_weak_until(
    p: &TLExpr,
    q: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    compile_expr(
        &TLExpr::or(
            TLExpr::Always(Box::new(p.clone())),
            TLExpr::Eventually(Box::new(q.clone())),
        ),
        ctx,
        graph,
    )
}
/// Compiles the temporal "strong release" operator (`p M q`) as
/// `F p ∧ (q ∧ (p ∨ G q))`.
///
/// `Eventually(p)` is compiled first, then the release approximation
/// `q ∧ (p ∨ Always(q))` goes through the normal expression compiler. The two results are
/// combined by one einsum node over the union of their axes (an element-wise product,
/// since no axis is summed away).
///
/// # Errors
/// Propagates errors from compiling either operand or from inserting the combining node.
pub(crate) fn compile_strong_release(
    p: &TLExpr,
    q: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let finally_p = compile_expr(&TLExpr::Eventually(Box::new(p.clone())), ctx, graph)?;

    let weak_release = TLExpr::and(
        q.clone(),
        TLExpr::or(p.clone(), TLExpr::Always(Box::new(q.clone()))),
    );
    let released = compile_expr(&weak_release, ctx, graph)?;

    let out_idx = graph.add_tensor(ctx.fresh_temp());
    let out_axes = merge_axes(&finally_p.axes, &released.axes);
    graph.add_node(EinsumNode::new(
        format!("{},{}->{}", finally_p.axes, released.axes, out_axes),
        vec![finally_p.tensor_idx, released.tensor_idx],
        vec![out_idx],
    ))?;

    Ok(CompileState {
        tensor_idx: out_idx,
        axes: out_axes,
    })
}
/// Returns the sorted, duplicate-free union of the axis labels in `axes1` and `axes2`.
///
/// The result is used as the output subscript of a two-operand einsum product. Every label
/// therefore appears exactly once. The previous version de-duplicated labels coming from
/// `axes2` but kept repeats already present in `axes1`, which could yield an invalid output
/// subscript such as `"aab"`.
fn merge_axes(axes1: &str, axes2: &str) -> String {
    let mut labels: Vec<char> = axes1.chars().chain(axes2.chars()).collect();
    // Sorting first groups equal labels together so `dedup` removes every repeat.
    labels.sort_unstable();
    labels.dedup();
    labels.into_iter().collect()
}
/// Returns the einsum axis label reserved for the modal world dimension, creating it on
/// first use.
///
/// If `WORLD_AXIS` is not yet a known domain, it is added with size
/// `config.modal_world_size` (falling back to 10). An axis is then assigned to it.
/// Later calls return the already-assigned axis unchanged.
fn ensure_world_axis(ctx: &mut CompilerContext) -> char {
    if let Some(existing) = ctx.var_to_axis.get(WORLD_AXIS).copied() {
        return existing;
    }

    let domain_missing = !ctx.domains.contains_key(WORLD_AXIS);
    if domain_missing {
        let size = ctx.config.modal_world_size.unwrap_or(10);
        ctx.add_domain(WORLD_AXIS, size);
    }
    ctx.assign_axis(WORLD_AXIS)
}
/// Returns the einsum axis label reserved for the temporal (time-step) dimension, creating
/// it on first use.
///
/// If `TIME_AXIS` is not yet a known domain, it is added with size
/// `config.temporal_time_steps` (falling back to 100). An axis is then assigned to it.
/// Later calls return the already-assigned axis unchanged.
fn ensure_time_axis(ctx: &mut CompilerContext) -> char {
    if let Some(existing) = ctx.var_to_axis.get(TIME_AXIS).copied() {
        return existing;
    }

    let domain_missing = !ctx.domains.contains_key(TIME_AXIS);
    if domain_missing {
        let steps = ctx.config.temporal_time_steps.unwrap_or(100);
        ctx.add_domain(TIME_AXIS, steps);
    }
    ctx.assign_axis(TIME_AXIS)
}
/// Adds a graph node that reduces `axis_to_reduce` out of `state` using `reduction_op`.
///
/// The node spec has the form `"<op>(<input axes>-><output axes>)"`, e.g. `"min(ab->a)"`.
/// The output axes are the input axes with every occurrence of `axis_to_reduce` removed,
/// in their original order. A fresh temporary tensor receives the result.
///
/// # Errors
/// Propagates any error from inserting the node into `graph`.
fn apply_reduction(
    state: &CompileState,
    axis_to_reduce: char,
    reduction_op: &str,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let mut kept_axes = state.axes.clone();
    kept_axes.retain(|axis| axis != axis_to_reduce);
    let spec = format!("{}({}->{})", reduction_op, state.axes, kept_axes);

    let target = graph.add_tensor(ctx.fresh_temp());
    graph.add_node(EinsumNode::new(spec, vec![state.tensor_idx], vec![target]))?;

    Ok(CompileState {
        tensor_idx: target,
        axes: kept_axes,
    })
}
/// Unit tests for the modal / temporal lowering helpers in this module.
#[cfg(test)]
mod tests {
use super::*;
use crate::{CompilationConfig, CompilerContext};
use tensorlogic_ir::{TLExpr, Term};
// `ensure_world_axis` must be idempotent and register `WORLD_AXIS` both as a domain and
// as a variable-to-axis mapping.
#[test]
fn test_ensure_world_axis() {
let mut ctx = CompilerContext::new();
let axis1 = ensure_world_axis(&mut ctx);
let axis2 = ensure_world_axis(&mut ctx);
assert_eq!(axis1, axis2);
assert!(ctx.domains.contains_key(WORLD_AXIS));
assert!(ctx.var_to_axis.contains_key(WORLD_AXIS));
}
// Same idempotence/registration contract as above, for the time axis.
#[test]
fn test_ensure_time_axis() {
let mut ctx = CompilerContext::new();
let axis1 = ensure_time_axis(&mut ctx);
let axis2 = ensure_time_axis(&mut ctx);
assert_eq!(axis1, axis2);
assert!(ctx.domains.contains_key(TIME_AXIS));
assert!(ctx.var_to_axis.contains_key(TIME_AXIS));
}
// Only checks the side effect that compiling Box registers the world domain. The compile
// result is deliberately not asserted.
// NOTE(review): presumably because `x` is not bound to a domain here and compilation may
// fail. Confirm, and consider asserting on the result once the setup is complete.
#[test]
fn test_compile_box_simple() {
let mut ctx = CompilerContext::new();
ctx.add_domain("Person", 10);
let mut graph = EinsumGraph::new();
let pred = TLExpr::pred("happy", vec![Term::var("x")]);
let result = compile_box(&pred, &mut ctx, &mut graph);
assert!(ctx.domains.contains_key(WORLD_AXIS));
let _ = result;
}
// Diamond counterpart of the Box test; again only the world-domain side effect is checked.
#[test]
fn test_compile_diamond_simple() {
let mut ctx = CompilerContext::new();
ctx.add_domain("Person", 10);
let mut graph = EinsumGraph::new();
let pred = TLExpr::pred("possible", vec![Term::var("x")]);
let result = compile_diamond(&pred, &mut ctx, &mut graph);
assert!(ctx.domains.contains_key(WORLD_AXIS));
let _ = result;
}
// Eventually must register the time domain; the compile result is not asserted.
#[test]
fn test_compile_eventually_simple() {
let mut ctx = CompilerContext::new();
ctx.add_domain("Event", 5);
let mut graph = EinsumGraph::new();
let pred = TLExpr::pred("occurs", vec![Term::var("e")]);
let result = compile_eventually(&pred, &mut ctx, &mut graph);
assert!(ctx.domains.contains_key(TIME_AXIS));
let _ = result;
}
// Next is unsupported: it must fail with a message that mentions "shift".
#[test]
fn test_next_not_implemented() {
let mut ctx = CompilerContext::new();
let mut graph = EinsumGraph::new();
let pred = TLExpr::pred("p", vec![Term::var("x")]);
let result = compile_next(&pred, &mut ctx, &mut graph);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("shift"));
}
// Until is unsupported: it must fail with a message that mentions "scan".
#[test]
fn test_until_not_implemented() {
let mut ctx = CompilerContext::new();
let mut graph = EinsumGraph::new();
let pred1 = TLExpr::pred("p", vec![Term::var("x")]);
let pred2 = TLExpr::pred("q", vec![Term::var("x")]);
let result = compile_until(&pred1, &pred2, &mut ctx, &mut graph);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("scan"));
}
// Pins which modal strategy each configuration preset selects.
#[test]
fn test_modal_strategy_configuration() {
let ctx = CompilerContext::with_config(CompilationConfig::hard_boolean());
assert_eq!(ctx.config.modal_strategy, ModalStrategy::AllWorldsMin);
let ctx = CompilerContext::with_config(CompilationConfig::soft_differentiable());
assert_eq!(ctx.config.modal_strategy, ModalStrategy::AllWorldsProduct);
}
// Pins which temporal strategy each configuration preset selects.
#[test]
fn test_temporal_strategy_configuration() {
let ctx = CompilerContext::with_config(CompilationConfig::hard_boolean());
assert_eq!(ctx.config.temporal_strategy, TemporalStrategy::Max);
let ctx = CompilerContext::with_config(CompilationConfig::soft_differentiable());
assert_eq!(ctx.config.temporal_strategy, TemporalStrategy::Sum);
}
}