use tabulon::{Parser, PreparedExpr, Tabula, IdentityResolver, VarAccessStrategy};
/// Per-evaluation context passed to the `get_var_test` resolver.
///
/// `values` holds one input per variable index. The other vectors run
/// parallel to it: they hold the resolver's memo cache and per-index call
/// statistics, which the tests assert on.
///
/// NOTE(review): `#[repr(C)]` is presumably needed by how tabulon hands the
/// context to resolvers. Confirm before removing it or reordering fields.
#[repr(C)]
struct EvalCtx {
    /// Input value for each variable index.
    values: Vec<f64>,
    /// Copy of `values[i]` stored by the resolver on its first lookup of `i`.
    cached: Vec<f64>,
    /// Non-zero once `cached[i]` has been populated.
    hit: Vec<u8>,
    /// Total resolver invocations per index (cache hits and misses).
    call_counts: Vec<u32>,
    /// Resolver invocations per index that had to read from `values`.
    miss_counts: Vec<u32>,
}

impl EvalCtx {
    /// Creates a context over `values`.
    ///
    /// The cache starts empty and all counters start at zero. Every auxiliary
    /// vector is sized to `values.len()`.
    fn with_values(values: Vec<f64>) -> Self {
        let len = values.len();
        let cached = vec![0.0; len];
        let hit = vec![0u8; len];
        let call_counts = vec![0u32; len];
        let miss_counts = vec![0u32; len];
        Self {
            values,
            cached,
            hit,
            call_counts,
            miss_counts,
        }
    }
}
/// Test resolver: returns the value of variable `idx` from `ctx.values`,
/// memoising it in `ctx.cached` and recording per-index statistics.
///
/// Every invocation increments `call_counts[idx]` (saturating). Only the
/// first invocation for an index, while `hit[idx]` is still 0, increments
/// `miss_counts[idx]` and reads `values[idx]`. Later invocations return the
/// cached copy.
///
/// Panics (via slice indexing) if `idx` is out of range for the context's
/// vectors.
#[tabulon::resolver]
fn get_var_test(idx: u32, ctx: &mut EvalCtx) -> f64 {
    let slot = idx as usize;

    let calls = &mut ctx.call_counts[slot];
    *calls = calls.saturating_add(1);

    if ctx.hit[slot] == 0 {
        // First lookup for this slot: record the miss and populate the cache.
        let misses = &mut ctx.miss_counts[slot];
        *misses = misses.saturating_add(1);
        ctx.cached[slot] = ctx.values[slot];
        ctx.hit[slot] = 1;
    }

    ctx.cached[slot]
}
/// Evaluates `(A + B) * C` through the resolver-call strategy.
///
/// Checks the variable ordering, the arithmetic result, and that each
/// variable was resolved (and missed) exactly once.
#[test]
fn resolver_basic_eval() {
    let parser = Parser::new("(A + B) * C").unwrap();
    let prepared: PreparedExpr<String> =
        parser.parse_with_var_resolver(&IdentityResolver).unwrap();

    let expected_order: Vec<String> = ["A", "B", "C"].iter().map(|s| s.to_string()).collect();
    assert_eq!(prepared.ordered_vars, expected_order);

    let mut eng = Tabula::<EvalCtx>::new_ctx();
    tabulon::register_resolver_typed!(eng, __tabulon_resolver_marker_get_var_test).unwrap();
    let compiled = eng
        .compile_prepared_with(
            &prepared,
            VarAccessStrategy::ResolverCall { symbol: "get_var_test" },
        )
        .unwrap();

    // Values are positional against `ordered_vars` asserted above: A=2, B=3, C=4.
    let mut ctx = EvalCtx::with_values(vec![2.0, 3.0, 4.0]);
    let result = compiled.eval_resolver_ctx(&mut ctx).unwrap();
    assert!((result - 20.0).abs() < 1e-9);

    let once_each = vec![1u32; 3];
    assert_eq!(ctx.miss_counts, once_each);
    assert_eq!(ctx.call_counts, once_each);
}
/// Verifies that `&&` short-circuits and that `if` resolves only the branch
/// it takes.
///
/// Case 1: `A` is 0, so nothing after the first `&&` may be resolved.
/// Case 2: `A` and `B` are truthy and `(D + E) < 3` is false, so `G` is
/// resolved but `F` is not.
#[test]
fn resolver_short_circuit_and_if_skips_unneeded_vars() {
    let parser = Parser::new("A && B && if((D + E) < 3, F, G)").unwrap();
    let prepared: PreparedExpr<String> =
        parser.parse_with_var_resolver(&IdentityResolver).unwrap();
    let idx_of: std::collections::HashMap<&str, usize> = prepared
        .ordered_vars
        .iter()
        .enumerate()
        .map(|(i, k)| (k.as_str(), i))
        .collect();
    let idx = |name: &str| *idx_of.get(name).expect("var present");
    let n_vars = prepared.ordered_vars.len();

    let mut eng = Tabula::<EvalCtx>::new_ctx();
    tabulon::register_resolver_typed!(eng, __tabulon_resolver_marker_get_var_test).unwrap();
    let compiled = eng
        .compile_prepared_with(
            &prepared,
            VarAccessStrategy::ResolverCall { symbol: "get_var_test" },
        )
        .unwrap();

    // Case 1: A is falsy, so `&&` must stop before touching any other variable.
    let mut vals = vec![42.0; n_vars];
    vals[idx("A")] = 0.0;
    let mut ctx = EvalCtx::with_values(vals);
    let out = compiled.eval_resolver_ctx(&mut ctx).unwrap();
    assert!(out.abs() < 1e-12);
    let mut expected = vec![0u32; n_vars];
    expected[idx("A")] = 1;
    assert_eq!(ctx.call_counts, expected, "unexpected resolver call pattern: {:?}", ctx.call_counts);
    assert_eq!(ctx.miss_counts, expected, "unexpected resolver miss pattern: {:?}", ctx.miss_counts);

    // Case 2: everything is 42, so A && B hold and evaluation reaches the `if`.
    // D + E = 84 is not < 3, so only the else branch (G) may be resolved;
    // F must be skipped.
    let mut ctx = EvalCtx::with_values(vec![42.0; n_vars]);
    compiled.eval_resolver_ctx(&mut ctx).unwrap();
    let mut expected = vec![0u32; n_vars];
    for name in ["A", "B", "D", "E", "G"] {
        expected[idx(name)] = 1;
    }
    assert_eq!(ctx.call_counts, expected, "unexpected resolver call pattern: {:?}", ctx.call_counts);
    assert_eq!(ctx.miss_counts, expected, "unexpected resolver miss pattern: {:?}", ctx.miss_counts);
}
/// Evaluates a nested `if` in which `A`, `B` and `C` each appear more than
/// once.
///
/// Checks that every variable on the taken path is resolved exactly once, and
/// that `D`, which appears only on an untaken branch, is never resolved.
#[test]
fn resolver_nested_if_memoizes_across_multiple_uses() {
    let parser = Parser::new("if(A + B > 0, A, if(B + C, C, B + D))").unwrap();
    let prepared: PreparedExpr<String> =
        parser.parse_with_var_resolver(&IdentityResolver).unwrap();
    let name_to_idx: std::collections::HashMap<&str, usize> = prepared
        .ordered_vars
        .iter()
        .enumerate()
        .map(|(pos, name)| (name.as_str(), pos))
        .collect();
    let gi = |n: &str| *name_to_idx.get(n).expect("present");

    let mut eng = Tabula::<EvalCtx>::new_ctx();
    tabulon::register_resolver_typed!(eng, __tabulon_resolver_marker_get_var_test).unwrap();
    let compiled = eng
        .compile_prepared_with(
            &prepared,
            VarAccessStrategy::ResolverCall { symbol: "get_var_test" },
        )
        .unwrap();

    let var_count = prepared.ordered_vars.len();
    let (ia, ib, ic, id) = (gi("A"), gi("B"), gi("C"), gi("D"));

    // A + B = -1 fails the outer condition; B + C = 8 is truthy, so the result is C.
    let mut vals = vec![0.0; var_count];
    for (slot, value) in [(ia, -2.0), (ib, 1.0), (ic, 7.0), (id, 100.0)] {
        vals[slot] = value;
    }
    let mut ctx = EvalCtx::with_values(vals);
    let out = compiled.eval_resolver_ctx(&mut ctx).unwrap();
    assert!((out - 7.0).abs() < 1e-12);

    // A, B and C are each resolved (and missed) exactly once; D is never touched.
    let mut expected = vec![0u32; var_count];
    for slot in [ia, ib, ic] {
        expected[slot] = 1;
    }
    assert_eq!(ctx.call_counts, expected, "call counts: {:?}", ctx.call_counts);
    assert_eq!(ctx.miss_counts, expected, "miss counts: {:?}", ctx.miss_counts);
}