use anyhow::Result;
use tensorlogic_ir::{EinsumGraph, EinsumNode, OpType};
use crate::compile::compile_expr;
use crate::context::{CompileState, CompilerContext};
/// Cardinality given to the state domain when the compiler context does not already define
/// one (see `ensure_state_domain`); also the fallback cardinality used by
/// `get_nominal_index` when the domain lookup finds nothing.
const DEFAULT_STATE_SPACE_SIZE: usize = 10;
/// Name used both for the state domain registered in the context and for the axis obtained
/// from `CompilerContext::assign_axis` that indexes states in compiled tensors.
const STATE_AXIS: &str = "__state__";
/// Compiles the nominal `name` into a graph tensor indexed by the state axis.
///
/// A tensor named `nominal_<name>` is registered. Its name is then suffixed with
/// `#nominal:<name>:index:<i>`, where `i` is the state index chosen by `get_nominal_index`.
/// The tensor has no producing node here. Presumably the runtime supplies its contents
/// (e.g. a one-hot state indicator); confirm against the executor.
///
/// # Errors
/// Propagates errors from state-domain setup and nominal index resolution.
pub(crate) fn compile_nominal(
    name: &str,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    ensure_state_domain(ctx)?;
    let index = get_nominal_index(name, ctx)?;

    let tensor_idx = graph.add_tensor(format!("nominal_{}", name));
    let axis = ctx.assign_axis(STATE_AXIS);

    // Tag the tensor name with which nominal / state index it represents.
    let tag = format!("nominal:{}:index:{}", name, index);
    let entry = graph.tensors.get_mut(tensor_idx).unwrap();
    entry.push('#');
    entry.push_str(&tag);

    Ok(CompileState {
        tensor_idx,
        axes: axis.to_string(),
    })
}
/// Compiles the satisfaction operator `@_nominal formula`: the value of `formula` at the
/// state named by `nominal`.
///
/// If the compiled formula carries no state axis, it does not vary with the state and is
/// returned unchanged. Otherwise the formula tensor is contracted with the nominal tensor
/// produced by `compile_nominal` over the shared state axis. That axis is removed from the
/// result.
///
/// # Errors
/// Propagates errors from state-domain setup, nominal resolution, formula compilation and
/// graph node insertion.
pub(crate) fn compile_at(
    nominal: &str,
    formula: &tensorlogic_ir::TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    ensure_state_domain(ctx)?;
    // Resolve the nominal up front so any failure surfaces before the formula is compiled
    // into the graph; the index itself is recomputed inside `compile_nominal`.
    get_nominal_index(nominal, ctx)?;

    let formula_result = compile_expr(formula, ctx, graph)?;
    let state_axis = ctx.assign_axis(STATE_AXIS);
    if !formula_result.axes.contains(state_axis) {
        return Ok(formula_result);
    }

    let nominal_result = compile_nominal(nominal, ctx, graph)?;
    let selected_idx = graph.add_tensor(ctx.fresh_temp());
    let output_axes: String = formula_result
        .axes
        .chars()
        .filter(|&c| c != state_axis)
        .collect();

    // "<formula axes>,<state axis>-><formula axes without state>". The nominal tensor's axes
    // are exactly the state axis, so a state-only formula yields e.g. "s,s->" (a scalar).
    // A single format therefore covers both cases the old if/else distinguished.
    let spec = format!(
        "{},{}->{}",
        formula_result.axes, nominal_result.axes, output_axes
    );
    let node = EinsumNode::new(
        spec,
        vec![formula_result.tensor_idx, nominal_result.tensor_idx],
        vec![selected_idx],
    );
    graph.add_node(node)?;

    Ok(CompileState {
        tensor_idx: selected_idx,
        axes: output_axes,
    })
}
/// Compiles the global "somewhere" modality: `formula` holds in at least one state.
///
/// The state axis of the compiled formula is reduced with `max` (a soft existential over
/// states). A formula without a state axis is returned unchanged.
///
/// # Errors
/// Propagates errors from state-domain setup, formula compilation and node insertion.
pub(crate) fn compile_somewhere(
    formula: &tensorlogic_ir::TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    compile_state_reduction("max", formula, ctx, graph)
}

/// Compiles the global "everywhere" modality: `formula` holds in every state.
///
/// The state axis of the compiled formula is reduced with `min` (a soft universal over
/// states). A formula without a state axis is returned unchanged.
///
/// # Errors
/// Propagates errors from state-domain setup, formula compilation and node insertion.
pub(crate) fn compile_everywhere(
    formula: &tensorlogic_ir::TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    compile_state_reduction("min", formula, ctx, graph)
}

/// Shared lowering for the global modalities: compiles `formula` and, when its result carries
/// the state axis, emits a `Reduce` node applying `reduce_op` over that axis.
fn compile_state_reduction(
    reduce_op: &str,
    formula: &tensorlogic_ir::TLExpr,
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    ensure_state_domain(ctx)?;
    let formula_result = compile_expr(formula, ctx, graph)?;
    let state_axis = ctx.assign_axis(STATE_AXIS);

    // Dimension indices of the state axis in the formula tensor. Axis labels map
    // positionally to tensor dimensions (the same convention the einsum specs rely on).
    // The Reduce node must name these dimensions explicitly, otherwise the reduced axis is
    // unspecified while the returned axes already omit it.
    let state_dims: Vec<usize> = formula_result
        .axes
        .chars()
        .enumerate()
        .filter(|&(_, axis)| axis == state_axis)
        .map(|(dim, _)| dim)
        .collect();

    // A formula without the state axis has the same value in every state.
    if state_dims.is_empty() {
        return Ok(formula_result);
    }

    let result_idx = graph.add_tensor(ctx.fresh_temp());
    let output_axes: String = formula_result
        .axes
        .chars()
        .filter(|&axis| axis != state_axis)
        .collect();

    let node = EinsumNode {
        op: OpType::Reduce {
            op: reduce_op.to_string(),
            axes: state_dims,
        },
        inputs: vec![formula_result.tensor_idx],
        outputs: vec![result_idx],
        metadata: None,
    };
    graph.add_node(node)?;

    Ok(CompileState {
        tensor_idx: result_idx,
        axes: output_axes,
    })
}
/// Registers the state domain (`STATE_AXIS`) with `DEFAULT_STATE_SPACE_SIZE` states, unless
/// the context already has a domain under that name.
///
/// Currently never returns an error.
fn ensure_state_domain(ctx: &mut CompilerContext) -> Result<()> {
    if ctx.domains.contains_key(STATE_AXIS) {
        // Leave an existing state domain (and its cardinality) untouched.
        return Ok(());
    }
    ctx.add_domain(STATE_AXIS, DEFAULT_STATE_SPACE_SIZE);
    Ok(())
}
fn get_nominal_index(name: &str, ctx: &mut CompilerContext) -> Result<usize> {
let state_size = ctx
.domains
.get(STATE_AXIS)
.map(|d| d.cardinality)
.unwrap_or(DEFAULT_STATE_SPACE_SIZE);
let mut hash: usize = 0;
for byte in name.bytes() {
hash = hash.wrapping_mul(31).wrapping_add(byte as usize);
}
let index = hash % state_size;
Ok(index)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    #[test]
    fn test_nominal_compilation() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let result = compile_nominal("home", &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
        let state_axis = ctx.assign_axis(STATE_AXIS);
        assert!(result.axes.contains(state_axis));
        assert_eq!(result.axes, state_axis.to_string());
        // With the default 10-state domain, "home" hashes to state 5.
        assert_eq!(
            graph.tensors[result.tensor_idx],
            "nominal_home#nominal:home:index:5"
        );
    }

    #[test]
    fn test_at_operator_simple() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 10);
        let mut graph = EinsumGraph::new();
        let safe = TLExpr::pred("Safe", vec![Term::var("x")]);
        ctx.bind_var("x", "Person").unwrap();
        let _result = compile_at("home", &safe, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_somewhere_operator() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let safe = TLExpr::pred("Safe", vec![]);
        let _result = compile_somewhere(&safe, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_everywhere_operator() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let safe = TLExpr::pred("Safe", vec![]);
        let _result = compile_everywhere(&safe, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_somewhere_with_free_variable() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 10);
        let mut graph = EinsumGraph::new();
        let knows = TLExpr::pred("Knows", vec![Term::var("x"), Term::var("y")]);
        ctx.bind_var("x", "Person").unwrap();
        ctx.bind_var("y", "Person").unwrap();
        let _result = compile_somewhere(&knows, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_everywhere_with_free_variable() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 10);
        let mut graph = EinsumGraph::new();
        let safe = TLExpr::pred("Safe", vec![Term::var("x")]);
        ctx.bind_var("x", "Person").unwrap();
        let _result = compile_everywhere(&safe, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_nested_somewhere_everywhere() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let safe = TLExpr::pred("Safe", vec![]);
        let everywhere_safe = TLExpr::Everywhere {
            formula: Box::new(safe),
        };
        let _result = compile_somewhere(&everywhere_safe, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
    }

    #[test]
    fn test_multiple_nominals_distinct_indices() {
        let mut ctx = CompilerContext::new();
        ensure_state_domain(&mut ctx).unwrap();
        let idx1 = get_nominal_index("home", &mut ctx).unwrap();
        let idx2 = get_nominal_index("office", &mut ctx).unwrap();
        let idx3 = get_nominal_index("home", &mut ctx).unwrap();
        // The same name always maps to the same state.
        assert_eq!(idx1, idx3);
        // "home" (-> 5) and "office" (-> 0) do not collide in the default 10-state domain.
        assert_ne!(idx1, idx2);
        assert!(idx1 < DEFAULT_STATE_SPACE_SIZE);
        assert!(idx2 < DEFAULT_STATE_SPACE_SIZE);
    }

    #[test]
    fn test_at_with_constant_formula() {
        let mut ctx = CompilerContext::new();
        let mut graph = EinsumGraph::new();
        let constant = TLExpr::Constant(1.0);
        let result = compile_at("home", &constant, &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
        let state_axis = ctx.assign_axis(STATE_AXIS);
        assert!(!result.axes.contains(state_axis));
    }
}