use std::collections::HashMap;
use tensorlogic_ir::{EinsumGraph, EinsumNode, Metadata, TLExpr};
use crate::CompilerContext;
/// Builds [`Metadata`] records for compiled TensorLogic expressions.
///
/// Every record produced by the builder can be tagged with optional provenance
/// information: the source file and the rule id currently being compiled.
#[derive(Debug, Clone)]
pub struct MetadataBuilder {
    /// Source file attached as the `source_file` attribute, if set.
    source_file: Option<String>,
    /// Rule id attached as the `rule_id` attribute, if set.
    rule_id: Option<String>,
    /// Counter backing `fresh_rule_id`; incremented on every call.
    rule_counter: usize,
}
impl MetadataBuilder {
pub fn new() -> Self {
Self {
source_file: None,
rule_id: None,
rule_counter: 0,
}
}
pub fn with_source_file(mut self, file: impl Into<String>) -> Self {
self.source_file = Some(file.into());
self
}
pub fn with_rule_id(mut self, rule_id: impl Into<String>) -> Self {
self.rule_id = Some(rule_id.into());
self
}
pub fn fresh_rule_id(&mut self) -> String {
let id = format!("rule_{}", self.rule_counter);
self.rule_counter += 1;
id
}
pub fn predicate_metadata(&mut self, name: &str, args: &[String]) -> Metadata {
let mut meta = Metadata::new().with_name(format!("predicate:{}", name));
if let Some(ref file) = self.source_file {
meta = meta.with_attribute("source_file", file.clone());
}
if let Some(ref rule) = self.rule_id {
meta = meta.with_attribute("rule_id", rule.clone());
}
meta = meta.with_attribute("predicate_name", name.to_string());
meta = meta.with_attribute("arity", args.len().to_string());
for (i, arg) in args.iter().enumerate() {
meta = meta.with_attribute(format!("arg_{}", i), arg.clone());
}
meta
}
pub fn logic_op_metadata(&mut self, op_type: &str, operand_count: usize) -> Metadata {
let mut meta = Metadata::new().with_name(format!("logic_op:{}", op_type));
if let Some(ref file) = self.source_file {
meta = meta.with_attribute("source_file", file.clone());
}
if let Some(ref rule) = self.rule_id {
meta = meta.with_attribute("rule_id", rule.clone());
}
meta = meta.with_attribute("operation", op_type.to_string());
meta = meta.with_attribute("operand_count", operand_count.to_string());
meta
}
pub fn quantifier_metadata(
&mut self,
quantifier_type: &str,
var: &str,
domain: &str,
) -> Metadata {
let mut meta = Metadata::new().with_name(format!("quantifier:{}", quantifier_type));
if let Some(ref file) = self.source_file {
meta = meta.with_attribute("source_file", file.clone());
}
if let Some(ref rule) = self.rule_id {
meta = meta.with_attribute("rule_id", rule.clone());
}
meta = meta.with_attribute("quantifier", quantifier_type.to_string());
meta = meta.with_attribute("variable", var.to_string());
meta = meta.with_attribute("domain", domain.to_string());
meta
}
pub fn from_expr(&mut self, expr: &TLExpr) -> Metadata {
match expr {
TLExpr::Pred { name, args } => {
let arg_names: Vec<String> = args.iter().map(|t| format!("{:?}", t)).collect();
self.predicate_metadata(name, &arg_names)
}
TLExpr::And(_, _) => self.logic_op_metadata("AND", 2),
TLExpr::Or(_, _) => self.logic_op_metadata("OR", 2),
TLExpr::Not(_) => self.logic_op_metadata("NOT", 1),
TLExpr::Imply(_, _) => self.logic_op_metadata("IMPLY", 2),
TLExpr::Exists { var, domain, .. } => self.quantifier_metadata("EXISTS", var, domain),
TLExpr::ForAll { var, domain, .. } => self.quantifier_metadata("FORALL", var, domain),
TLExpr::Constant(_) => Metadata::new()
.with_name("constant")
.with_attribute("type", "constant"),
_ => Metadata::new().with_name("expression"),
}
}
}
impl Default for MetadataBuilder {
fn default() -> Self {
Self::new()
}
}
/// Attaches domain metadata to a graph's input and output tensors.
///
/// * Input tensors are matched against the domains bound to variables in
///   `ctx.var_to_domain`. A tensor matches domain `D` when its name starts
///   with `D_` or contains `_D_`. The `cardinality` attribute is added when
///   `ctx.domains` knows the domain.
/// * Output tensors are matched against every domain in `ctx.domains` by
///   plain substring containment.
///
/// When several domains match the same tensor, exactly one is used: the
/// longest domain name, with ties going to the lexicographically smallest
/// name. The result therefore does not depend on map iteration order.
///
/// Input metadata is added before output metadata, so a tensor that is both
/// an input and an output is passed to `add_tensor_metadata` twice, input first.
///
/// `_builder` is currently unused.
pub fn propagate_metadata(
    graph: &mut EinsumGraph,
    ctx: &CompilerContext,
    _builder: &mut MetadataBuilder,
) {
    let mut metadata_to_add: Vec<(usize, Metadata)> = Vec::new();

    // Visit each distinct input once, in ascending index order, instead of
    // scanning `graph.inputs` for every tensor in the graph.
    let input_indices: std::collections::BTreeSet<usize> =
        graph.inputs.iter().copied().collect();
    for tensor_idx in input_indices {
        // Indices past the end of `graph.tensors` are skipped.
        let tensor_name = match graph.tensors.get(tensor_idx) {
            Some(name) => name,
            None => continue,
        };
        let domain_name = ctx
            .var_to_domain
            .values()
            .filter(|d| {
                tensor_name.starts_with(&format!("{}_", d))
                    || tensor_name.contains(&format!("_{}_", d))
            })
            // Most specific (longest) domain wins; deterministic tie-break.
            .max_by(|a, b| a.len().cmp(&b.len()).then_with(|| b.cmp(a)));
        if let Some(domain_name) = domain_name {
            let mut meta = Metadata::new()
                .with_name(format!("input_tensor:{}", tensor_name))
                .with_attribute("domain", domain_name.clone())
                .with_attribute("tensor_type", "input");
            if let Some(domain_info) = ctx.domains.get(domain_name) {
                meta = meta.with_attribute("cardinality", domain_info.cardinality.to_string());
            }
            metadata_to_add.push((tensor_idx, meta));
        }
    }

    for &output_idx in &graph.outputs {
        let tensor_name = match graph.tensors.get(output_idx) {
            Some(name) => name,
            None => continue,
        };
        // NOTE(review): outputs use plain substring matching, unlike the
        // `_`-delimited matching used for inputs above.
        let best = ctx
            .domains
            .iter()
            .filter(|(domain_name, _)| tensor_name.contains(*domain_name))
            .max_by(|(a, _), (b, _)| a.len().cmp(&b.len()).then_with(|| b.cmp(a)));
        if let Some((domain_name, domain_info)) = best {
            let meta = Metadata::new()
                .with_name(format!("output_tensor:{}", tensor_name))
                .with_attribute("domain", domain_name.clone())
                .with_attribute("cardinality", domain_info.cardinality.to_string())
                .with_attribute("tensor_type", "output");
            metadata_to_add.push((output_idx, meta));
        }
    }

    for (idx, meta) in metadata_to_add {
        graph.add_tensor_metadata(idx, meta);
    }
}
/// Bundles a compiled graph with the metadata builder carried alongside it
/// and a map from expression ids to graph node indices.
pub struct MetadataCompilationResult {
    /// The compiled einsum graph.
    pub graph: EinsumGraph,
    /// The metadata builder carried alongside the graph.
    pub builder: MetadataBuilder,
    /// Node indices recorded per expression id via `record_expression`.
    pub expr_to_nodes: HashMap<String, Vec<usize>>,
}

impl MetadataCompilationResult {
    /// Wraps `graph` and `builder`, starting with an empty expression-to-node map.
    pub fn new(graph: EinsumGraph, builder: MetadataBuilder) -> Self {
        let expr_to_nodes = HashMap::new();
        Self {
            graph,
            builder,
            expr_to_nodes,
        }
    }

    /// Records `node_indices` for `expr_id`, replacing any earlier entry for that id.
    pub fn record_expression(&mut self, expr_id: impl Into<String>, node_indices: Vec<usize>) {
        let key: String = expr_id.into();
        self.expr_to_nodes.insert(key, node_indices);
    }

    /// Looks up the node indices recorded for `expr_id`; `None` if it was never recorded.
    pub fn get_nodes_for_expr(&self, expr_id: &str) -> Option<&[usize]> {
        self.expr_to_nodes.get(expr_id).map(Vec::as_slice)
    }
}
/// Builds metadata for `expr` with `builder` and stores it on `node`
/// through `EinsumNode::set_metadata`.
pub fn attach_expr_metadata(node: &mut EinsumNode, expr: &TLExpr, builder: &mut MetadataBuilder) {
    node.set_metadata(builder.from_expr(expr));
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    #[test]
    fn test_metadata_builder_new() {
        let builder = MetadataBuilder::new();
        assert!(builder.source_file.is_none());
        assert!(builder.rule_id.is_none());
        assert_eq!(builder.rule_counter, 0);
    }

    #[test]
    fn test_metadata_builder_default_matches_new() {
        let builder = MetadataBuilder::default();
        assert!(builder.source_file.is_none());
        assert!(builder.rule_id.is_none());
        assert_eq!(builder.rule_counter, 0);
    }

    #[test]
    fn test_metadata_builder_with_source_file() {
        let builder = MetadataBuilder::new().with_source_file("test.tl");
        assert_eq!(builder.source_file, Some("test.tl".to_string()));
    }

    #[test]
    fn test_metadata_builder_with_rule_id() {
        let builder = MetadataBuilder::new().with_rule_id("r42");
        assert_eq!(builder.rule_id, Some("r42".to_string()));
    }

    #[test]
    fn test_metadata_builder_fresh_rule_id() {
        let mut builder = MetadataBuilder::new();
        let id1 = builder.fresh_rule_id();
        let id2 = builder.fresh_rule_id();
        assert_eq!(id1, "rule_0");
        assert_eq!(id2, "rule_1");
    }

    #[test]
    fn test_predicate_metadata() {
        let mut builder = MetadataBuilder::new()
            .with_source_file("test.tl")
            .with_rule_id("rule_1");
        let meta = builder.predicate_metadata("knows", &["x".to_string(), "y".to_string()]);
        assert_eq!(meta.name, Some("predicate:knows".to_string()));
        assert_eq!(meta.get_attribute("predicate_name"), Some("knows"));
        assert_eq!(meta.get_attribute("arity"), Some("2"));
        assert_eq!(meta.get_attribute("arg_0"), Some("x"));
        assert_eq!(meta.get_attribute("arg_1"), Some("y"));
        assert_eq!(meta.get_attribute("source_file"), Some("test.tl"));
        assert_eq!(meta.get_attribute("rule_id"), Some("rule_1"));
    }

    #[test]
    fn test_logic_op_metadata() {
        let mut builder = MetadataBuilder::new();
        let meta = builder.logic_op_metadata("AND", 2);
        assert_eq!(meta.name, Some("logic_op:AND".to_string()));
        assert_eq!(meta.get_attribute("operation"), Some("AND"));
        assert_eq!(meta.get_attribute("operand_count"), Some("2"));
    }

    #[test]
    fn test_quantifier_metadata() {
        let mut builder = MetadataBuilder::new();
        let meta = builder.quantifier_metadata("EXISTS", "x", "Person");
        assert_eq!(meta.name, Some("quantifier:EXISTS".to_string()));
        assert_eq!(meta.get_attribute("quantifier"), Some("EXISTS"));
        assert_eq!(meta.get_attribute("variable"), Some("x"));
        assert_eq!(meta.get_attribute("domain"), Some("Person"));
    }

    #[test]
    fn test_from_expr_predicate() {
        let mut builder = MetadataBuilder::new();
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let meta = builder.from_expr(&expr);
        assert_eq!(meta.name, Some("predicate:knows".to_string()));
        assert_eq!(meta.get_attribute("predicate_name"), Some("knows"));
        assert_eq!(meta.get_attribute("arity"), Some("2"));
    }

    #[test]
    fn test_from_expr_and() {
        let mut builder = MetadataBuilder::new();
        let expr = TLExpr::And(
            Box::new(TLExpr::pred("p", vec![Term::var("x")])),
            Box::new(TLExpr::pred("q", vec![Term::var("y")])),
        );
        let meta = builder.from_expr(&expr);
        assert_eq!(meta.name, Some("logic_op:AND".to_string()));
        assert_eq!(meta.get_attribute("operation"), Some("AND"));
        assert_eq!(meta.get_attribute("operand_count"), Some("2"));
    }

    #[test]
    fn test_from_expr_exists() {
        let mut builder = MetadataBuilder::new();
        let expr = TLExpr::exists(
            "x",
            "Person",
            TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
        );
        let meta = builder.from_expr(&expr);
        assert_eq!(meta.name, Some("quantifier:EXISTS".to_string()));
        assert_eq!(meta.get_attribute("quantifier"), Some("EXISTS"));
        assert_eq!(meta.get_attribute("variable"), Some("x"));
        assert_eq!(meta.get_attribute("domain"), Some("Person"));
    }

    #[test]
    fn test_propagate_metadata_with_domains() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx.bind_var("x", "Person").unwrap();
        let mut graph = EinsumGraph::new();
        let tensor_idx = graph.add_tensor("Person_x");
        graph.inputs.push(tensor_idx);
        let mut builder = MetadataBuilder::new();
        propagate_metadata(&mut graph, &ctx, &mut builder);
        let meta = graph
            .get_tensor_metadata(tensor_idx)
            .expect("input tensor should carry domain metadata");
        assert_eq!(meta.get_attribute("domain"), Some("Person"));
        assert_eq!(meta.get_attribute("cardinality"), Some("100"));
        assert_eq!(meta.get_attribute("tensor_type"), Some("input"));
    }

    #[test]
    fn test_propagate_metadata_output_tensor() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        let mut graph = EinsumGraph::new();
        let tensor_idx = graph.add_tensor("Person_scores");
        graph.outputs.push(tensor_idx);
        let mut builder = MetadataBuilder::new();
        propagate_metadata(&mut graph, &ctx, &mut builder);
        let meta = graph
            .get_tensor_metadata(tensor_idx)
            .expect("output tensor should carry domain metadata");
        assert_eq!(meta.get_attribute("domain"), Some("Person"));
        assert_eq!(meta.get_attribute("cardinality"), Some("100"));
        assert_eq!(meta.get_attribute("tensor_type"), Some("output"));
    }

    #[test]
    fn test_metadata_compilation_result() {
        let graph = EinsumGraph::new();
        let builder = MetadataBuilder::new();
        let mut result = MetadataCompilationResult::new(graph, builder);
        result.record_expression("expr_1", vec![0, 1, 2]);
        assert_eq!(result.get_nodes_for_expr("expr_1"), Some(&[0, 1, 2][..]));
        assert_eq!(result.get_nodes_for_expr("expr_2"), None);
    }

    #[test]
    fn test_attach_expr_metadata() {
        let mut builder = MetadataBuilder::new();
        let mut node = EinsumNode::new("ab->a", vec![0], vec![1]);
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        attach_expr_metadata(&mut node, &expr, &mut builder);
        let meta = node.get_metadata();
        assert!(meta.is_some());
        assert_eq!(meta.unwrap().get_attribute("predicate_name"), Some("knows"));
    }
}