use crate::{compile_to_einsum_with_context, CompilerContext};
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use tensorlogic_ir::{EinsumGraph, IrError, TLExpr, Term};
/// The names an expression was found to reference, plus a hash of the compiler
/// configuration that was in effect when the analysis ran.
///
/// Instances are normally produced by `ExpressionDependencies::analyze`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExpressionDependencies {
/// Predicate names collected from `TLExpr::Pred` nodes.
pub predicates: HashSet<String>,
/// Variable names: `Term::Var` predicate arguments plus variables bound by
/// quantifiers, aggregates (including their group-by lists) and `Let`.
pub variables: HashSet<String>,
/// Domain names referenced by quantifiers and aggregates.
pub domains: HashSet<String>,
/// Hash of the `Debug` rendering of the compiler config; `0` in a record
/// created by `new` until `analyze` fills it in.
pub config_hash: u64,
}
impl ExpressionDependencies {
pub fn new() -> Self {
Self {
predicates: HashSet::new(),
variables: HashSet::new(),
domains: HashSet::new(),
config_hash: 0,
}
}
pub fn analyze(expr: &TLExpr, ctx: &CompilerContext) -> Self {
let mut deps = Self::new();
deps.analyze_recursive(expr);
deps.config_hash = Self::hash_config(ctx);
deps
}
fn analyze_recursive(&mut self, expr: &TLExpr) {
match expr {
TLExpr::Pred { name, args } => {
self.predicates.insert(name.clone());
for arg in args {
self.analyze_term(arg);
}
}
TLExpr::And(left, right) | TLExpr::Or(left, right) | TLExpr::Imply(left, right) => {
self.analyze_recursive(left);
self.analyze_recursive(right);
}
TLExpr::Not(inner) => {
self.analyze_recursive(inner);
}
TLExpr::Exists { var, domain, body } | TLExpr::ForAll { var, domain, body } => {
self.variables.insert(var.clone());
self.domains.insert(domain.clone());
self.analyze_recursive(body);
}
TLExpr::Score(inner) => {
self.analyze_recursive(inner);
}
TLExpr::Add(left, right)
| TLExpr::Sub(left, right)
| TLExpr::Mul(left, right)
| TLExpr::Div(left, right) => {
self.analyze_recursive(left);
self.analyze_recursive(right);
}
TLExpr::Eq(left, right)
| TLExpr::Lt(left, right)
| TLExpr::Gt(left, right)
| TLExpr::Lte(left, right)
| TLExpr::Gte(left, right) => {
self.analyze_recursive(left);
self.analyze_recursive(right);
}
TLExpr::IfThenElse {
condition,
then_branch,
else_branch,
} => {
self.analyze_recursive(condition);
self.analyze_recursive(then_branch);
self.analyze_recursive(else_branch);
}
TLExpr::Aggregate {
op: _,
var,
domain,
body,
group_by,
} => {
self.variables.insert(var.clone());
self.domains.insert(domain.clone());
self.analyze_recursive(body);
if let Some(gb_vars) = group_by {
for var_name in gb_vars {
self.variables.insert(var_name.clone());
}
}
}
TLExpr::TNorm {
kind: _,
left,
right,
}
| TLExpr::TCoNorm {
kind: _,
left,
right,
} => {
self.analyze_recursive(left);
self.analyze_recursive(right);
}
TLExpr::FuzzyNot {
kind: _,
expr: inner,
} => {
self.analyze_recursive(inner);
}
TLExpr::FuzzyImplication {
kind: _,
premise,
conclusion,
} => {
self.analyze_recursive(premise);
self.analyze_recursive(conclusion);
}
TLExpr::SoftExists {
var,
domain,
body,
temperature: _,
}
| TLExpr::SoftForAll {
var,
domain,
body,
temperature: _,
} => {
self.variables.insert(var.clone());
self.domains.insert(domain.clone());
self.analyze_recursive(body);
}
TLExpr::WeightedRule { weight: _, rule } => {
self.analyze_recursive(rule);
}
TLExpr::ProbabilisticChoice { alternatives } => {
for (_, alt) in alternatives {
self.analyze_recursive(alt);
}
}
TLExpr::Let { var, value, body } => {
self.variables.insert(var.clone());
self.analyze_recursive(value);
self.analyze_recursive(body);
}
TLExpr::Box(inner)
| TLExpr::Diamond(inner)
| TLExpr::Next(inner)
| TLExpr::Eventually(inner)
| TLExpr::Always(inner) => {
self.analyze_recursive(inner);
}
TLExpr::Until { before, after } | TLExpr::WeakUntil { before, after } => {
self.analyze_recursive(before);
self.analyze_recursive(after);
}
TLExpr::Release { released, releaser }
| TLExpr::StrongRelease { released, releaser } => {
self.analyze_recursive(released);
self.analyze_recursive(releaser);
}
TLExpr::Abs(inner)
| TLExpr::Sqrt(inner)
| TLExpr::Exp(inner)
| TLExpr::Log(inner)
| TLExpr::Sin(inner)
| TLExpr::Cos(inner)
| TLExpr::Tan(inner)
| TLExpr::Floor(inner)
| TLExpr::Ceil(inner)
| TLExpr::Round(inner) => {
self.analyze_recursive(inner);
}
TLExpr::Pow(left, right)
| TLExpr::Min(left, right)
| TLExpr::Max(left, right)
| TLExpr::Mod(left, right) => {
self.analyze_recursive(left);
self.analyze_recursive(right);
}
TLExpr::Constant(_) => {
}
_ => {
}
}
}
fn analyze_term(&mut self, term: &Term) {
if let Term::Var(name) = term {
self.variables.insert(name.clone());
}
}
fn hash_config(ctx: &CompilerContext) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
format!("{:?}", ctx.config).hash(&mut hasher);
hasher.finish()
}
}
impl Default for ExpressionDependencies {
fn default() -> Self {
Self::new()
}
}
/// Remembers a snapshot of a compiler context so a later context can be
/// compared against it.
#[derive(Debug, Clone)]
pub struct ChangeDetector {
/// Predicate snapshot. In the code visible here it is only ever cleared,
/// never populated, so no predicate changes are reported.
/// NOTE(review): the tuple presumably holds arity and argument domains — confirm.
previous_predicates: HashMap<String, (usize, Vec<String>)>,
/// Domain name mapped to the cardinality seen at the last `update`.
previous_domains: HashMap<String, usize>,
/// Configuration hash seen at the last `update`; `0` for a fresh detector.
previous_config_hash: u64,
}
impl ChangeDetector {
    /// Creates a detector with an empty snapshot and a configuration hash of `0`.
    pub fn new() -> Self {
        Self {
            previous_config_hash: 0,
            previous_domains: HashMap::new(),
            previous_predicates: HashMap::new(),
        }
    }

    /// Replaces the stored snapshot with the current state of `ctx`.
    ///
    /// The domain cardinalities and the configuration hash are recorded; the
    /// predicate snapshot is emptied and left empty.
    pub fn update(&mut self, ctx: &CompilerContext) {
        self.previous_predicates.clear();

        self.previous_domains.clear();
        self.previous_domains.extend(
            ctx.domains
                .iter()
                .map(|(name, info)| (name.clone(), info.cardinality)),
        );

        self.previous_config_hash = ExpressionDependencies::hash_config(ctx);
    }

    /// Compares `ctx` with the stored snapshot and reports what differs.
    ///
    /// Domains are classified as new, changed (different cardinality) or
    /// removed, and `config_changed` is set when the configuration hash
    /// differs. The predicate sets of the result are always empty.
    pub fn detect_changes(&self, ctx: &CompilerContext) -> ChangeSet {
        let mut delta = ChangeSet::new();

        for (name, info) in &ctx.domains {
            match self.previous_domains.get(name.as_str()) {
                None => {
                    delta.new_domains.insert(name.clone());
                }
                Some(&old_size) if old_size != info.cardinality => {
                    delta.changed_domains.insert(name.clone());
                }
                Some(_) => {}
            }
        }

        // Anything in the snapshot that the context no longer knows was removed.
        delta.removed_domains.extend(
            self.previous_domains
                .keys()
                .filter(|name| !ctx.domains.contains_key(*name))
                .cloned(),
        );

        delta.config_changed =
            ExpressionDependencies::hash_config(ctx) != self.previous_config_hash;
        delta
    }
}
impl Default for ChangeDetector {
fn default() -> Self {
Self::new()
}
}
/// The differences found between two compiler-context snapshots.
#[derive(Debug, Clone, Default)]
pub struct ChangeSet {
    /// Predicates present now but not in the snapshot.
    pub new_predicates: HashSet<String>,
    /// Predicates present in both whose definition differs.
    pub changed_predicates: HashSet<String>,
    /// Predicates present in the snapshot but not any more.
    pub removed_predicates: HashSet<String>,
    /// Domains present now but not in the snapshot.
    pub new_domains: HashSet<String>,
    /// Domains present in both whose cardinality differs.
    pub changed_domains: HashSet<String>,
    /// Domains present in the snapshot but not any more.
    pub removed_domains: HashSet<String>,
    /// Whether the compiler configuration differs.
    pub config_changed: bool,
}

impl ChangeSet {
    /// Creates a change set that records no differences.
    fn new() -> Self {
        Default::default()
    }

    /// Returns `true` when at least one difference of any kind is recorded.
    pub fn has_changes(&self) -> bool {
        if self.config_changed {
            return true;
        }
        [
            &self.new_predicates,
            &self.changed_predicates,
            &self.removed_predicates,
            &self.new_domains,
            &self.changed_domains,
            &self.removed_domains,
        ]
        .iter()
        .any(|names| !names.is_empty())
    }

    /// Returns `true` when these changes could make a result with the given
    /// dependencies stale.
    ///
    /// A configuration change affects everything. Otherwise only changed or
    /// removed predicates and domains that `deps` references count; newly
    /// added names never do.
    pub fn affects(&self, deps: &ExpressionDependencies) -> bool {
        let predicate_touched = |name: &String| {
            self.changed_predicates.contains(name) || self.removed_predicates.contains(name)
        };
        let domain_touched = |name: &String| {
            self.changed_domains.contains(name) || self.removed_domains.contains(name)
        };

        self.config_changed
            || deps.predicates.iter().any(predicate_touched)
            || deps.domains.iter().any(domain_touched)
    }
}
/// One cached compilation result together with the data used to decide
/// whether it is still valid.
#[derive(Debug, Clone)]
struct CacheEntry {
/// The compiled graph that is cloned and returned on a cache hit.
graph: EinsumGraph,
/// Dependencies recorded when the graph was compiled; checked against a
/// `ChangeSet` during invalidation.
dependencies: ExpressionDependencies,
/// Sequence number taken from the compiler's counter when the entry was inserted.
// Written on insertion but never read in this file, hence the lint allowance.
#[allow(dead_code)]
timestamp: u64,
}
/// A compiler front end that caches compiled graphs per expression and
/// discards cached entries when the context they depended on changes.
pub struct IncrementalCompiler {
/// The compiler context passed to every compilation.
context: CompilerContext,
/// Compiled graphs keyed by the `Debug` rendering of the source expression.
cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
/// Snapshot of the context used to detect changes between `compile` calls.
change_detector: ChangeDetector,
/// Running hit/miss/invalidation counters.
stats: Arc<Mutex<IncrementalStats>>,
/// Counter that supplies the `timestamp` of the next inserted cache entry.
next_timestamp: Arc<Mutex<u64>>,
}
impl IncrementalCompiler {
    /// Creates a compiler around `context` with an empty cache, zeroed
    /// statistics and a change detector primed with the context's current state.
    pub fn new(context: CompilerContext) -> Self {
        let mut change_detector = ChangeDetector::new();
        change_detector.update(&context);

        let cache = Arc::new(Mutex::new(HashMap::new()));
        let stats = Arc::new(Mutex::new(IncrementalStats::default()));
        let next_timestamp = Arc::new(Mutex::new(0));

        Self {
            context,
            cache,
            change_detector,
            stats,
            next_timestamp,
        }
    }

    /// Shared access to the compiler context.
    pub fn context(&self) -> &CompilerContext {
        &self.context
    }

    /// Mutable access to the compiler context. Changes made through it are
    /// picked up by the next call to [`compile`](Self::compile).
    pub fn context_mut(&mut self) -> &mut CompilerContext {
        &mut self.context
    }

    /// Compiles `expr`, returning a cached graph when one is available.
    ///
    /// Context changes since the previous call are detected first and the
    /// affected cache entries dropped. The cache key is the `Debug` rendering
    /// of `expr`.
    ///
    /// # Errors
    ///
    /// Returns `IrError::InvalidEinsumSpec` wrapping the underlying message
    /// when compilation fails.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal mutexes is poisoned.
    pub fn compile(&mut self, expr: &TLExpr) -> Result<EinsumGraph, IrError> {
        // Bring the cache in line with the current context before consulting it.
        let pending = self.change_detector.detect_changes(&self.context);
        if pending.has_changes() {
            self.invalidate_affected(&pending);
            self.change_detector.update(&self.context);
        }

        let expr_key = format!("{:?}", expr);

        // Fast path: serve the graph from the cache.
        {
            let cache = self.cache.lock().expect("lock should not be poisoned");
            if let Some(hit) = cache.get(&expr_key) {
                {
                    let mut stats = self.stats.lock().expect("lock should not be poisoned");
                    stats.cache_hits += 1;
                    stats.nodes_reused += hit.graph.nodes.len();
                }
                return Ok(hit.graph.clone());
            }
        }

        // Slow path: analyze, compile, record statistics, then cache the result.
        let dependencies = ExpressionDependencies::analyze(expr, &self.context);
        let graph = match compile_to_einsum_with_context(expr, &mut self.context) {
            Ok(compiled) => compiled,
            Err(e) => {
                return Err(IrError::InvalidEinsumSpec {
                    spec: format!("{:?}", expr),
                    reason: format!("Compilation failed: {}", e),
                });
            }
        };

        {
            let mut stats = self.stats.lock().expect("lock should not be poisoned");
            stats.cache_misses += 1;
            stats.nodes_compiled += graph.nodes.len();
        }

        let timestamp = {
            let mut counter = self
                .next_timestamp
                .lock()
                .expect("lock should not be poisoned");
            let issued = *counter;
            *counter += 1;
            issued
        };

        self.cache
            .lock()
            .expect("lock should not be poisoned")
            .insert(
                expr_key,
                CacheEntry {
                    graph: graph.clone(),
                    dependencies,
                    timestamp,
                },
            );

        Ok(graph)
    }

    /// Drops every cache entry that `changes` affects and counts one
    /// invalidation, whether or not any entry was actually removed.
    fn invalidate_affected(&mut self, changes: &ChangeSet) {
        self.cache
            .lock()
            .expect("lock should not be poisoned")
            .retain(|_, entry| !changes.affects(&entry.dependencies));

        self.stats
            .lock()
            .expect("lock should not be poisoned")
            .invalidations += 1;
    }

    /// Removes every cached graph. Statistics are left untouched.
    pub fn clear_cache(&mut self) {
        self.cache
            .lock()
            .expect("lock should not be poisoned")
            .clear();
    }

    /// Returns a copy of the current statistics.
    pub fn stats(&self) -> IncrementalStats {
        let guard = self.stats.lock().expect("lock should not be poisoned");
        guard.clone()
    }

    /// Resets all statistics to zero. The cache is left untouched.
    pub fn reset_stats(&mut self) {
        *self.stats.lock().expect("lock should not be poisoned") = IncrementalStats::default();
    }
}
/// Counters describing how well the incremental cache is performing.
#[derive(Debug, Clone, Default)]
pub struct IncrementalStats {
    /// Number of compile requests served from the cache.
    pub cache_hits: usize,
    /// Number of compile requests that required a fresh compilation.
    pub cache_misses: usize,
    /// Number of invalidation passes that were run.
    pub invalidations: usize,
    /// Total graph nodes returned from cached graphs.
    pub nodes_reused: usize,
    /// Total graph nodes produced by fresh compilations.
    pub nodes_compiled: usize,
}

impl IncrementalStats {
    /// Fraction of compile requests served from the cache, or `0.0` when no
    /// request has been recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.cache_hits + self.cache_misses {
            0 => 0.0,
            requests => self.cache_hits as f64 / requests as f64,
        }
    }

    /// Fraction of graph nodes that came from the cache, or `0.0` when no
    /// node has been recorded.
    pub fn reuse_rate(&self) -> f64 {
        match self.nodes_reused + self.nodes_compiled {
            0 => 0.0,
            nodes => self.nodes_reused as f64 / nodes as f64,
        }
    }

    /// Total number of compile requests, hits and misses combined.
    pub fn total_compilations(&self) -> usize {
        let Self {
            cache_hits,
            cache_misses,
            ..
        } = self;
        cache_hits + cache_misses
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Context containing a single `Person` domain of cardinality 100.
    fn person_context() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx
    }

    // The predicate `knows(x, y)`.
    fn knows_xy() -> TLExpr {
        TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")])
    }

    // The predicate `likes(x, z)`.
    fn likes_xz() -> TLExpr {
        TLExpr::pred("likes", vec![Term::var("x"), Term::var("z")])
    }

    // A bare predicate contributes its name and its variable arguments.
    #[test]
    fn test_dependency_tracking() {
        let ctx = person_context();
        let deps = ExpressionDependencies::analyze(&knows_xy(), &ctx);

        assert!(deps.predicates.contains("knows"));
        assert!(deps.variables.contains("x"));
        assert!(deps.variables.contains("y"));
    }

    // Compiling the same expression twice yields one miss followed by one hit.
    #[test]
    fn test_incremental_compilation_reuse() {
        let mut compiler = IncrementalCompiler::new(person_context());
        let expr = knows_xy();

        let _first = compiler.compile(&expr).expect("unwrap");
        let after_first = compiler.stats();
        assert_eq!(after_first.cache_misses, 1);
        assert_eq!(after_first.cache_hits, 0);

        let _second = compiler.compile(&expr).expect("unwrap");
        let after_second = compiler.stats();
        assert_eq!(after_second.cache_misses, 1);
        assert_eq!(after_second.cache_hits, 1);
        assert_eq!(after_second.hit_rate(), 0.5);
    }

    // Re-registering a domain with a different cardinality is reported as a change.
    #[test]
    fn test_change_detection_domain() {
        let mut ctx = person_context();
        let mut detector = ChangeDetector::new();
        detector.update(&ctx);

        let unchanged = detector.detect_changes(&ctx);
        assert!(!unchanged.has_changes());

        ctx.add_domain("Person", 200);
        let resized = detector.detect_changes(&ctx);
        assert!(resized.has_changes());
        assert!(resized.changed_domains.contains("Person"));
    }

    // A domain change between compilations triggers an invalidation pass.
    #[test]
    fn test_invalidation_on_domain_change() {
        let mut compiler = IncrementalCompiler::new(person_context());
        let expr = knows_xy();

        let _first = compiler.compile(&expr).expect("unwrap");
        assert_eq!(compiler.stats().cache_misses, 1);

        compiler.context_mut().add_domain("Person", 200);
        let _second = compiler.compile(&expr).expect("unwrap");

        let stats = compiler.stats();
        assert!(stats.cache_misses >= 1);
        assert!(stats.invalidations >= 1);
    }

    // Three requests over two distinct expressions produce at least one hit.
    #[test]
    fn test_incremental_stats() {
        let mut compiler = IncrementalCompiler::new(person_context());
        let knows = knows_xy();
        let likes = likes_xz();

        compiler.compile(&knows).expect("unwrap");
        compiler.compile(&knows).expect("unwrap");
        compiler.compile(&likes).expect("unwrap");

        let stats = compiler.stats();
        assert_eq!(stats.total_compilations(), 3);
        assert!(
            stats.cache_hits >= 1,
            "Expected at least 1 cache hit, got {}",
            stats.cache_hits
        );
        assert!(
            stats.hit_rate() > 0.0,
            "Expected positive hit rate, got {}",
            stats.hit_rate()
        );
    }

    // A quantified conjunction reports every predicate, variable and domain it mentions.
    #[test]
    fn test_complex_expression_dependencies() {
        let ctx = person_context();
        let body = TLExpr::and(knows_xy(), likes_xz());
        let expr = TLExpr::exists("x", "Person", body);

        let deps = ExpressionDependencies::analyze(&expr, &ctx);

        for predicate in ["knows", "likes"] {
            assert!(deps.predicates.contains(predicate));
        }
        for variable in ["x", "y", "z"] {
            assert!(deps.variables.contains(variable));
        }
        assert!(deps.domains.contains("Person"));
    }
}