use anyhow::{bail, Result};
use tensorlogic_ir::{EinsumGraph, TLExpr};
use crate::compile::compile_expr;
use crate::context::{CompileState, CompilerContext};
/// Number of times a fixpoint body is unrolled when approximating the
/// fixpoint (returned by `get_unroll_depth`).
const DEFAULT_UNROLL_DEPTH: usize = 5;
/// Compiles a least fixpoint `μ var. body`.
///
/// The unrolled iteration is seeded with [`InitValue::Zero`]. See
/// `compile_fixpoint_internal` for the unrolling scheme and error cases.
pub(crate) fn compile_least_fixpoint(
    var: &str,
    body: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let bottom = InitValue::Zero;
    compile_fixpoint_internal(var, body, ctx, graph, bottom)
}
/// Compiles a greatest fixpoint `ν var. body`.
///
/// The unrolled iteration is seeded with [`InitValue::One`]. See
/// `compile_fixpoint_internal` for the unrolling scheme and error cases.
pub(crate) fn compile_greatest_fixpoint(
    var: &str,
    body: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let top = InitValue::One;
    compile_fixpoint_internal(var, body, ctx, graph, top)
}
/// Seed used for the first unrolled iteration of a fixpoint.
#[derive(Clone, Copy, Debug)]
enum InitValue {
    /// Seed for least fixpoints; corresponds to the value `0.0`.
    Zero,
    /// Seed for greatest fixpoints; corresponds to the value `1.0`.
    One,
}
/// Shared implementation behind `compile_least_fixpoint` and
/// `compile_greatest_fixpoint`.
///
/// If `var` does not occur free in `body`, the body is simply compiled once.
/// Otherwise the fixpoint is approximated by unrolling `body`
/// `get_unroll_depth()` times.
///
/// - Before the first iteration, `var` is bound in `ctx.let_bindings` to a
///   newly added graph tensor named `fixpoint_init_0` or `fixpoint_init_1`,
///   depending on `init_value`.
/// - Before each later iteration, `var` is bound to the tensor produced by
///   the previous iteration.
///
/// The returned axes are the sorted axis letters of every free variable of
/// `body` (which includes `var`).
///
/// If `var` has no domain yet, one is inferred with `infer_fixpoint_domain`
/// and bound in `ctx`. That domain binding is kept after the call.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `var` is free in `body`, has no domain in `ctx`, and none can be
///   inferred.
/// - Binding the inferred domain fails.
/// - Compiling any unrolled iteration of `body` fails.
///
/// In every case, including a failed iteration, the `let` binding `var` had
/// before the call is restored, or the temporary one is removed.
fn compile_fixpoint_internal(
    var: &str,
    body: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
    init_value: InitValue,
) -> Result<CompileState> {
    let free_vars = body.free_vars();
    if !free_vars.contains(var) {
        // `var` never occurs in the body, so there is nothing to iterate.
        return compile_expr(body, ctx, graph);
    }

    if !ctx.var_to_domain.contains_key(var) {
        if let Some(domain) = infer_fixpoint_domain(body, var) {
            ctx.bind_var(var, &domain)?;
        } else {
            bail!(
                "Cannot infer domain for fixpoint variable '{}'. \
                 Please bind the variable to a domain before using in fixpoint.",
                var
            );
        }
    }

    // Give `var` its axis before the remaining free variables get theirs.
    let _axis = ctx.assign_axis(var);
    let mut axes_vec: Vec<char> = free_vars.iter().map(|v| ctx.assign_axis(v)).collect();
    axes_vec.sort();
    let axes: String = axes_vec.into_iter().collect();

    let init_float = match init_value {
        InitValue::Zero => 0.0,
        InitValue::One => 1.0,
    };
    let init_name = format!("fixpoint_init_{}", init_float);
    let mut current_tensor_idx = graph.add_tensor(init_name);

    let saved_binding = ctx.let_bindings.get(var).copied();
    // Don't propagate iteration errors with `?` here: the temporary binding
    // of `var` must be undone first, or it would leak into the caller's ctx.
    let mut failure = None;
    for _iteration in 0..get_unroll_depth() {
        ctx.let_bindings.insert(var.to_string(), current_tensor_idx);
        match compile_expr(body, ctx, graph) {
            Ok(iteration_result) => current_tensor_idx = iteration_result.tensor_idx,
            Err(err) => {
                failure = Some(err);
                break;
            }
        }
    }

    if let Some(prev_binding) = saved_binding {
        ctx.let_bindings.insert(var.to_string(), prev_binding);
    } else {
        ctx.let_bindings.remove(var);
    }

    if let Some(err) = failure {
        return Err(err);
    }

    Ok(CompileState {
        tensor_idx: current_tensor_idx,
        axes,
    })
}
/// Number of iterations used to unroll a fixpoint body.
///
/// Currently always returns `DEFAULT_UNROLL_DEPTH`.
fn get_unroll_depth() -> usize {
    DEFAULT_UNROLL_DEPTH
}
/// Best-effort guess at a domain for a fixpoint variable.
///
/// The search walks `body` depth-first, exploring the left operand before the
/// right. It descends through:
/// - `And`/`Or` and `Until`;
/// - the unary wrappers `Not`, `Box`, `Diamond`, `Next`, `Eventually`,
///   `Always`, and `WeightedRule`.
///
/// It returns the `domain` of the first quantifier-like node it reaches
/// (`Exists`, `ForAll`, `Aggregate`, soft/counting variants,
/// `SetComprehension`, `ExactCount`, `Majority`). It does not look inside
/// that node. If no such node is reachable, it returns `None`.
///
/// NOTE(review): `_var` is never inspected, so the result is the domain of
/// whatever variable the first quantifier binds, not necessarily `_var`'s.
fn infer_fixpoint_domain(body: &TLExpr, _var: &str) -> Option<String> {
    // Explicit DFS stack. Children are pushed right-first so the left child is
    // popped, and fully explored, before its right sibling.
    let mut pending: Vec<&TLExpr> = vec![body];
    while let Some(node) = pending.pop() {
        match node {
            TLExpr::Exists { domain, .. }
            | TLExpr::ForAll { domain, .. }
            | TLExpr::Aggregate { domain, .. }
            | TLExpr::SoftExists { domain, .. }
            | TLExpr::SoftForAll { domain, .. }
            | TLExpr::SetComprehension { domain, .. }
            | TLExpr::CountingExists { domain, .. }
            | TLExpr::CountingForAll { domain, .. }
            | TLExpr::ExactCount { domain, .. }
            | TLExpr::Majority { domain, .. } => return Some(domain.clone()),
            TLExpr::And(lhs, rhs) | TLExpr::Or(lhs, rhs) => {
                pending.push(rhs);
                pending.push(lhs);
            }
            TLExpr::Until { before, after } => {
                pending.push(after);
                pending.push(before);
            }
            TLExpr::Not(child)
            | TLExpr::Box(child)
            | TLExpr::Diamond(child)
            | TLExpr::Next(child)
            | TLExpr::Eventually(child)
            | TLExpr::Always(child)
            | TLExpr::WeightedRule {
                rule: child,
                weight: _,
            } => pending.push(child),
            _ => {}
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    #[test]
    fn test_least_fixpoint_simple() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Node", 10);
        let mut graph = EinsumGraph::new();
        let body = TLExpr::pred("P", vec![Term::var("x")]);
        ctx.bind_var("x", "Node").unwrap();
        let result = compile_least_fixpoint("X", &body, &mut ctx, &mut graph).unwrap();
        assert!(!result.axes.is_empty());
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_greatest_fixpoint_simple() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("State", 5);
        let mut graph = EinsumGraph::new();
        let body = TLExpr::pred("Safe", vec![Term::var("s")]);
        ctx.bind_var("s", "State").unwrap();
        let result = compile_greatest_fixpoint("X", &body, &mut ctx, &mut graph).unwrap();
        assert!(!result.axes.is_empty());
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_fixpoint_with_recursion() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Node", 10);
        let mut graph = EinsumGraph::new();
        let r = TLExpr::pred("R", vec![Term::var("x"), Term::var("y")]);
        let x = TLExpr::pred("X", vec![Term::var("x"), Term::var("y")]);
        let body = TLExpr::or(r, x);
        ctx.bind_var("x", "Node").unwrap();
        ctx.bind_var("y", "Node").unwrap();
        let _result = compile_least_fixpoint("X", &body, &mut ctx, &mut graph).unwrap();
        assert!(!graph.nodes.is_empty());
        assert!(!graph.tensors.is_empty());
    }

    /// "X" is a predicate name here, not a free variable of the body, so the
    /// body is compiled directly and the fixpoint machinery is skipped, even
    /// though `x` has no domain.
    #[test]
    fn test_fixpoint_var_not_free_compiles_body_directly() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let body = TLExpr::pred("P", vec![Term::var("x")]);
        let result = compile_least_fixpoint("X", &body, &mut ctx, &mut graph);
        assert!(result.is_ok());
    }

    /// `x` is free in the body and has no domain, and a bare predicate offers
    /// nothing for domain inference, so compilation must fail.
    #[test]
    fn test_fixpoint_unbound_variable_fails() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let body = TLExpr::pred("P", vec![Term::var("x")]);
        let result = compile_least_fixpoint("x", &body, &mut ctx, &mut graph);
        assert!(result.is_err());
    }

    /// The temporary let-binding of the fixpoint variable must not outlive
    /// the call.
    #[test]
    fn test_fixpoint_temporary_binding_removed() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Node", 10);
        let mut graph = EinsumGraph::new();
        let body = TLExpr::pred("P", vec![Term::var("x")]);
        ctx.bind_var("x", "Node").unwrap();
        let result = compile_least_fixpoint("x", &body, &mut ctx, &mut graph).unwrap();
        assert_eq!(result.axes.len(), 1);
        assert!(!ctx.let_bindings.contains_key("x"));
    }

    #[test]
    fn test_fixpoint_with_quantifier_inference() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Node", 10);
        let mut graph = EinsumGraph::new();
        let body = TLExpr::exists(
            "y",
            "Node",
            TLExpr::and(
                TLExpr::pred("R", vec![Term::var("x"), Term::var("y")]),
                TLExpr::pred("X", vec![Term::var("y"), Term::var("z")]),
            ),
        );
        ctx.bind_var("x", "Node").unwrap();
        ctx.bind_var("z", "Node").unwrap();
        let _result = compile_least_fixpoint("X", &body, &mut ctx, &mut graph).unwrap();
        assert!(!graph.nodes.is_empty());
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_least_vs_greatest_both_compile() {
        let mut ctx1 = CompilerContext::new();
        let mut ctx2 = CompilerContext::new();
        ctx1.add_domain("D", 5);
        ctx2.add_domain("D", 5);
        let mut graph1 = EinsumGraph::new();
        let mut graph2 = EinsumGraph::new();
        let body = TLExpr::pred("P", vec![Term::var("x")]);
        ctx1.bind_var("x", "D").unwrap();
        ctx2.bind_var("x", "D").unwrap();
        let _least_result = compile_least_fixpoint("X", &body, &mut ctx1, &mut graph1).unwrap();
        let _greatest_result =
            compile_greatest_fixpoint("X", &body, &mut ctx2, &mut graph2).unwrap();
        assert!(!graph1.tensors.is_empty());
        assert!(!graph2.tensors.is_empty());
    }
}