use tensorlogic_compiler::{compile_to_einsum, JitCompiler};
use tensorlogic_ir::{TLExpr, Term};
/// Builds the binary predicate atom `knows(x, y)` used as the basic fixture.
fn knows_expr() -> TLExpr {
    let args = vec![Term::var("x"), Term::var("y")];
    TLExpr::pred("knows", args)
}
/// Builds the conjunction `knows(x, y) AND likes(y, z)`; the two atoms share `y`.
fn complex_expr() -> TLExpr {
    let knows = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
    let likes = TLExpr::pred("likes", vec![Term::var("y"), Term::var("z")]);
    TLExpr::and(knows, likes)
}
/// Compiles the same expression directly and through the JIT hot path.
///
/// With a promotion threshold of 3, four warm-up calls promote the expression
/// into the hot cache, so the final `compile` must be served as a JIT hit.
///
/// NOTE(review): despite its name, this test does not compare the direct graph
/// with the JIT graph, because the graph type's equality/Debug guarantees (and
/// whether `JitCompiler::compile` returns the graph by value or wrapped) are not
/// visible here. TODO: add `assert_eq!(direct, hot)` once confirmed comparable.
#[test]
fn test_jit_output_matches_direct_compile() {
    let expr = knows_expr();
    let direct = compile_to_einsum(&expr).expect("direct compile");
    let jit = JitCompiler::new(3);
    for _ in 0..4 {
        jit.compile(&expr).expect("jit compile");
    }
    // Four calls exceed the threshold of 3, so the expression must now be cached.
    assert_eq!(
        jit.hot_path_count(),
        1,
        "expr should be promoted after 4 calls with threshold=3"
    );
    let hits_before = jit.stats().jit_hits;
    let hot = jit.compile(&expr).expect("hot path");
    // Without this check the "hot path" call could silently be a cold compile.
    assert_eq!(
        jit.stats().jit_hits - hits_before,
        1,
        "final compile should be served from the hot cache"
    );
    let _ = (direct, hot);
}
/// Repeatedly compiling one expression well past the threshold must register
/// at least one JIT hit in the compiler statistics.
#[test]
fn test_jit_stats_after_many_calls() {
    let expr = TLExpr::pred("p", vec![Term::var("a"), Term::constant("1")]);
    let jit = JitCompiler::new(2);
    (0..10).for_each(|_| {
        jit.compile(&expr).expect("compile");
    });
    let stats = jit.stats();
    let saw_hits = stats.jit_hits > 0;
    assert!(
        saw_hits,
        "should have jit hits after 10 calls with threshold=2; stats={stats:?}"
    );
}
/// Distinct expressions are counted and promoted independently: the two that
/// are called past the threshold are cached, while the one called once is not.
#[test]
fn test_jit_multiple_distinct_expressions() {
    let jit = JitCompiler::new(2);
    let (e1, e2, e3) = (
        knows_expr(),
        complex_expr(),
        TLExpr::pred("foo", vec![Term::var("a")]),
    );
    for _round in 0..3 {
        jit.compile(&e1).expect("e1");
        jit.compile(&e2).expect("e2");
    }
    // One call stays below the threshold of 2, so e3 must remain uncached.
    jit.compile(&e3).expect("e3 cold");
    assert_eq!(jit.hot_path_count(), 2, "e1 and e2 should be in hot cache");
    for (expr, expected_calls) in [(&e1, 3), (&e2, 3), (&e3, 1)] {
        assert_eq!(jit.call_count(expr), expected_calls);
    }
}
/// Once an expression is promoted, every further compile of it must be a
/// hot-cache hit, observed via the delta in `stats().jit_hits`.
#[test]
fn test_jit_hot_graph_is_reused_across_calls() {
    let jit = JitCompiler::new(2);
    let expr = knows_expr();
    // Two calls reach the threshold of 2 and promote `expr`.
    jit.compile(&expr).expect("cold call");
    jit.compile(&expr).expect("cold call");
    assert_eq!(jit.hot_path_count(), 1);

    let hits_before = jit.stats().jit_hits;
    let extra = 8usize;
    (0..extra).for_each(|_| {
        jit.compile(&expr).expect("hot call");
    });
    let hits_after = jit.stats().jit_hits;
    let gained = hits_after - hits_before;
    assert_eq!(
        gained, extra,
        "all {extra} extra calls should have been hot-cache hits"
    );
}
/// `clear_cache` must reset both the hot cache and per-expression call counts,
/// so that the same expression can be promoted again afterwards.
#[test]
fn test_jit_clear_cache_allows_repromotion() {
    // Compiles `expr` twice, matching the threshold passed to `JitCompiler::new`.
    fn compile_twice(jit: &JitCompiler, expr: &TLExpr, label: &str) {
        for _ in 0..2 {
            jit.compile(expr).expect(label);
        }
    }

    let mut jit = JitCompiler::new(2);
    let expr = knows_expr();

    compile_twice(&jit, &expr, "compile before clear");
    assert_eq!(jit.hot_path_count(), 1);

    jit.clear_cache();
    assert_eq!(jit.hot_path_count(), 0);
    assert_eq!(jit.call_count(&expr), 0);

    compile_twice(&jit, &expr, "compile after clear");
    assert_eq!(jit.hot_path_count(), 1, "should be re-promoted after clear");
}