use std::collections::HashMap;
use relon_codegen_llvm::LlvmAotEvaluator;
use relon_eval_api::{Evaluator, Value};
/// Native mirror of the `bs` helper bound in the `where` block of [`W17_SRC`].
///
/// Repeatedly bisects the window `[lo, hi)` at `(lo + hi) / 2`. It raises the
/// lower bound when the midpoint is `<= t` and lowers the upper bound otherwise.
/// Once the window is at most one wide, it returns the lower bound.
///
/// The Relon helper is tail-recursive, so this loop is a step-for-step
/// unrolling of it. The midpoint formula is kept identical on purpose so the
/// oracle computes exactly what the Relon program does.
///
/// For `lo <= t < hi` the result is `t`. If `t` falls outside that range, the
/// result is clamped to `lo` or `hi - 1`. If `hi - lo <= 1` on entry, `lo` is
/// returned unchanged.
fn bs_oracle(lo: i64, hi: i64, t: i64) -> i64 {
    let (mut low, mut high) = (lo, hi);
    while high - low > 1 {
        let mid = (low + high) / 2;
        if mid <= t {
            low = mid;
        } else {
            high = mid;
        }
    }
    low
}
/// Native reference result for [`W17_SRC`].
///
/// Folds over `i in 0..n`, summing `bs_oracle(0, n, (i * 31) % n)`. This
/// mirrors the Relon `range(n).reduce(0, (acc, i) => ...)` expression.
///
/// The multiply wraps instead of panicking on overflow. For `n <= 0` the range
/// is empty, so the modulo is never evaluated and the result is `0`.
fn w17_oracle(n: i64) -> i64 {
    (0..n).fold(0i64, |total, i| {
        let target = i.wrapping_mul(31) % n;
        total + bs_oracle(0, n, target)
    })
}
/// Unwraps the integer returned by W17's `main`.
///
/// # Panics
/// Panics, echoing the offending value, if the evaluator produced anything
/// other than `Value::Int`.
fn extract_int(v: Value) -> i64 {
    if let Value::Int(i) = v {
        i
    } else {
        let other = v;
        panic!("W17 return expected Int, got {other:?}")
    }
}
/// Relon source for the W17 case.
///
/// `main(Int n)` folds over `range(n)`, adding `bs(0, n, (i * 31) % n)` for
/// each `i`. `bs` is a recursive binary-search helper bound in the `where`
/// block; it is the construct this test checks the LLVM AOT backend can lower.
///
/// Each trailing `\n\` keeps an explicit newline in the string and continues
/// the literal on the next source line. Rust drops the continuation line's
/// leading whitespace, so any indentation added here does not reach the Relon
/// text.
///
/// Keep in sync with [`bs_oracle`] and [`w17_oracle`], which recompute the
/// expected result natively.
///
/// NOTE(review): the meaning of the `#unstrict` pragma is not visible here.
/// Presumably it relaxes checking so `bs`'s parameters can stay unannotated;
/// confirm against the Relon docs.
const W17_SRC: &str = "#unstrict\n\
#main(Int n) -> Int\n\
range(n).reduce(0, (acc, i) => acc + bs(0, n, (i * 31) % n))\n\
where {\n\
bs(lo, hi, t): hi - lo <= 1 ? lo : (\n\
(lo + hi) / 2 <= t\n\
? bs((lo + hi) / 2, hi, t)\n\
: bs(lo, (lo + hi) / 2, t)\n\
)\n\
}";
/// End-to-end check that the `where`-bound recursive helper `bs` in
/// [`W17_SRC`] compiles through the LLVM AOT backend. For each probed `n`, the
/// compiled `main` must return the same value as the native oracle
/// [`w17_oracle`].
#[test]
fn w17_where_bound_recursive_helper_lowers_and_evaluates() {
    let ev = LlvmAotEvaluator::from_source(W17_SRC)
        .expect("W17 where-bound recursive helper compiles via LLVM AOT");

    // If `n` is coprime to 31, `(i * 31) % n` is a permutation of `0..n`, so
    // the expected total collapses to n(n-1)/2. Multiples of 31 make the
    // modulo produce repeated targets: all zeros for 31, alternating 0/31 for
    // 62. That exercises a different result from the compiled program.
    for n in [1i64, 2, 3, 5, 8, 16, 31, 50, 62, 100] {
        let args = HashMap::from([("n".to_string(), Value::Int(n))]);
        let got = extract_int(ev.run_main(args).expect("run_main"));
        let want = w17_oracle(n);
        assert_eq!(
            got, want,
            "W17 binary-search LLVM AOT result mismatches tree-walker oracle for n={n}"
        );
    }
}