use anyhow::{bail, Result};
use tensorlogic_ir::{EinsumGraph, TLExpr};
use crate::compile::compile_expr;
use crate::context::{CompileState, CompilerContext};
pub(crate) fn compile_lambda(
var: &str,
var_type: &Option<String>,
body: &TLExpr,
ctx: &mut CompilerContext,
graph: &mut EinsumGraph,
) -> Result<CompileState> {
let type_name = match var_type {
Some(t) => t.as_str(),
None => {
bail!(
"Lambda variable '{}' requires a type annotation. \
Please specify the domain type (e.g., λx:Node.φ(x)).",
var
);
}
};
if !ctx.domains.contains_key(type_name) {
bail!(
"Lambda variable '{}' has unknown type '{}'. \
Please register the domain before using in lambda.",
var,
type_name
);
}
let prev_binding = ctx.var_to_domain.get(var).cloned();
ctx.bind_var(var, type_name)?;
let body_result = compile_expr(body, ctx, graph)?;
if let Some(domain) = prev_binding {
ctx.var_to_domain.insert(var.to_string(), domain);
} else {
ctx.var_to_domain.remove(var);
}
Ok(body_result)
}
/// Compiles a function application `function(argument)`.
///
/// # Beta reduction
///
/// When `function` is a typed lambda literal `λvar:Type. body`, the application
/// is beta-reduced at compile time:
///
/// - `argument` is compiled first, in the enclosing scope.
/// - `var` is recorded in `ctx.let_bindings` as the argument's tensor.
/// - `var` is bound to `Type`.
/// - If the argument has axes, `var` is also mapped to the argument's first axis.
/// - `body` is then compiled.
///
/// Every shadowed entry in `let_bindings`, `var_to_domain` and `var_to_axis` is
/// restored afterwards, whether or not the body compiled successfully.
///
/// # Generic application
///
/// For any other `function`, both sides are compiled independently. They are
/// combined by a single einsum node whose output axes are the sorted union of
/// both operands' axes.
///
/// # Errors
///
/// In the lambda case, fails if the lambda's variable has no type annotation
/// or names an unregistered domain. In either case, fails if compiling any
/// sub-expression, binding the variable, or adding the graph node fails.
pub(crate) fn compile_apply(
    function: &TLExpr,
    argument: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    match function {
        TLExpr::Lambda {
            var,
            var_type,
            body,
        } => {
            let type_name = match var_type {
                Some(t) => t.as_str(),
                None => {
                    bail!(
                        "Lambda variable '{}' requires a type annotation for beta reduction.",
                        var
                    );
                }
            };
            if !ctx.domains.contains_key(type_name) {
                bail!(
                    "Lambda variable '{}' has unknown type '{}'. \
                     Domain must be registered before beta reduction.",
                    var,
                    type_name
                );
            }

            // The argument is compiled before `var` is bound, so it sees the outer scope.
            let arg_result = compile_expr(argument, ctx, graph)?;

            // Snapshot everything the binder is about to shadow.
            let prev_domain_binding = ctx.var_to_domain.get(var).cloned();
            let prev_axis_binding = ctx.var_to_axis.get(var).copied();
            let prev_let_binding = ctx.let_bindings.get(var).copied();

            ctx.let_bindings
                .insert(var.to_string(), arg_result.tensor_idx);
            // NOTE(review): for a scalar argument (no axes) an outer `var_to_axis`
            // entry for `var` is left in place and is visible to the body — confirm intended.
            let first_axis = arg_result.axes.chars().next();

            // Not propagated with `?` yet: the snapshots below must be restored on
            // every exit path, including a failed `bind_var` or body compilation.
            let body_result =
                bind_and_compile_body(var, type_name, first_axis, body, ctx, graph);

            if let Some(tensor_idx) = prev_let_binding {
                ctx.let_bindings.insert(var.to_string(), tensor_idx);
            } else {
                ctx.let_bindings.remove(var);
            }
            if let Some(domain) = prev_domain_binding {
                ctx.var_to_domain.insert(var.to_string(), domain);
            } else {
                ctx.var_to_domain.remove(var);
            }
            if let Some(axis) = prev_axis_binding {
                ctx.var_to_axis.insert(var.to_string(), axis);
            } else {
                ctx.var_to_axis.remove(var);
            }
            body_result
        }
        _ => {
            let func_result = compile_expr(function, ctx, graph)?;
            let arg_result = compile_expr(argument, ctx, graph)?;
            let result_name = ctx.fresh_temp();
            let result_idx = graph.add_tensor(result_name);
            let output_axes = merge_axes(&func_result.axes, &arg_result.axes);
            // A scalar operand simply contributes an empty subscript, which yields the
            // ",x->y", "x,->y" and ",->" forms without special-casing.
            let spec = format!("{},{}->{}", func_result.axes, arg_result.axes, output_axes);
            let node = tensorlogic_ir::EinsumNode::new(
                spec,
                vec![func_result.tensor_idx, arg_result.tensor_idx],
                vec![result_idx],
            );
            graph.add_node(node)?;
            Ok(CompileState {
                tensor_idx: result_idx,
                axes: output_axes,
            })
        }
    }
}

/// Binds `var` to `type_name`, maps it to `axis` when one is given, then
/// compiles `body`.
///
/// This is split out of [`compile_apply`] so that a failure anywhere in this
/// sequence still returns control to the caller, which restores its scope.
fn bind_and_compile_body(
    var: &str,
    type_name: &str,
    axis: Option<char>,
    body: &TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    ctx.bind_var(var, type_name)?;
    if let Some(axis) = axis {
        ctx.var_to_axis.insert(var.to_string(), axis);
    }
    compile_expr(body, ctx, graph)
}
/// Returns the sorted union of the axis labels in `axes1` and `axes2`.
///
/// Each label appears at most once in the result, whichever input(s) it came
/// from and however often it repeats there. The result is therefore usable as
/// an einsum output subscript, which must not repeat a label.
fn merge_axes(axes1: &str, axes2: &str) -> String {
    let mut labels: Vec<char> = axes1.chars().chain(axes2.chars()).collect();
    // Sorting groups equal labels together so `dedup` removes every repeat.
    labels.sort_unstable();
    labels.dedup();
    labels.into_iter().collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Fresh context with the `Node` domain (size 10) registered.
    fn node_ctx() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Node", 10);
        ctx
    }

    /// The `Node` type annotation shared by the lambda tests.
    fn node_type() -> Option<String> {
        Some("Node".to_string())
    }

    #[test]
    fn test_lambda_simple_body() {
        let (mut ctx, mut graph) = (node_ctx(), EinsumGraph::new());
        let body = TLExpr::pred("P", vec![Term::var("x")]);

        let state = compile_lambda("x", &node_type(), &body, &mut ctx, &mut graph)
            .expect("lambda over a registered domain should compile");

        assert!(!graph.tensors.is_empty());
        assert!(!state.axes.is_empty());
    }

    #[test]
    fn test_beta_reduction_simple() {
        let (mut ctx, mut graph) = (node_ctx(), EinsumGraph::new());
        let lambda = TLExpr::lambda(
            "x",
            node_type(),
            TLExpr::pred("P", vec![Term::var("x")]),
        );
        let arg = TLExpr::pred("a", vec![]);

        compile_apply(&lambda, &arg, &mut ctx, &mut graph)
            .expect("beta reduction should compile");

        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_beta_reduction_with_free_variable() {
        let (mut ctx, mut graph) = (node_ctx(), EinsumGraph::new());
        let lambda = TLExpr::lambda(
            "x",
            node_type(),
            TLExpr::pred("Connected", vec![Term::var("x"), Term::var("y")]),
        );
        let arg = TLExpr::pred("source", vec![]);
        // `y` stays free inside the lambda, so it has to be bound by the caller.
        ctx.bind_var("y", "Node").expect("binding y should succeed");

        compile_apply(&lambda, &arg, &mut ctx, &mut graph)
            .expect("beta reduction with a free variable should compile");

        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_lambda_with_unbound_type_fails() {
        // Deliberately no domain registered: `Node` is unknown here.
        let (mut ctx, mut graph) = (CompilerContext::new(), EinsumGraph::new());
        let body = TLExpr::pred("P", vec![Term::var("x")]);

        let outcome = compile_lambda("x", &node_type(), &body, &mut ctx, &mut graph);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_apply_non_lambda() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 5);
        let mut graph = EinsumGraph::new();
        let (func, arg) = (TLExpr::pred("P", vec![]), TLExpr::pred("x", vec![]));

        compile_apply(&func, &arg, &mut ctx, &mut graph)
            .expect("generic application should compile");

        assert!(!graph.tensors.is_empty());
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_nested_lambda_application() {
        let (mut ctx, mut graph) = (node_ctx(), EinsumGraph::new());
        let inner = TLExpr::lambda(
            "y",
            node_type(),
            TLExpr::pred("Connected", vec![Term::var("x"), Term::var("y")]),
        );
        let outer = TLExpr::lambda("x", node_type(), inner);
        // ((λx. λy. Connected(x, y)) a) b
        let expr = TLExpr::apply(
            TLExpr::apply(outer, TLExpr::pred("a", vec![])),
            TLExpr::pred("b", vec![]),
        );

        compile_expr(&expr, &mut ctx, &mut graph)
            .expect("curried application should compile");

        assert!(!graph.tensors.is_empty());
    }
}