use std::collections::HashMap;
use std::sync::Arc;
use relon_codegen_cranelift::{AotEvaluator, SandboxConfig};
use relon_codegen_llvm::LlvmAotEvaluator;
use relon_eval_api::{Evaluator, Value};
use relon_evaluator::{Context, Scope, TreeWalkEvaluator};
use relon_ir::ir::{Func, IrType, Module as IrModule, Op, TaggedOp};
use relon_parser::{parse_document, TokenRange};
/// Wraps `op` in a [`TaggedOp`] whose `range` is `TokenRange::default()`.
///
/// Keeps the hand-written IR bodies in this file short: the tests here never
/// supply a real source range for the ops they build.
fn t(op: Op) -> TaggedOp {
    let range = TokenRange::default();
    TaggedOp { op, range }
}
/// Builds an [`IrModule`] that contains exactly one function.
///
/// The function is named `run_main`, takes `params`, returns `IrType::I64`
/// and has `body` as its op list. It is registered as the module entry point
/// (index 0); the module has no imports and an empty closure table.
fn legacy_module(body: Vec<TaggedOp>, params: Vec<IrType>) -> IrModule {
    let entry = Func {
        name: String::from("run_main"),
        params,
        ret: IrType::I64,
        body,
        range: TokenRange::default(),
    };

    IrModule {
        imports: Vec::new(),
        funcs: vec![entry],
        entry_func_index: Some(0),
        closure_table: Vec::new(),
    }
}
/// Compiles `ir` with both the Cranelift and the LLVM AOT evaluators, runs the
/// entry point on each with `args`, and returns the (agreed-upon) result.
///
/// `params` are the parameter names handed to each backend's
/// `from_ir_direct` constructor.
///
/// # Panics
///
/// Panics if either backend fails to compile or run, or if the two backends
/// return different values.
fn run_both(ir: &IrModule, params: &[&str], args: &[i64]) -> i64 {
    // Both evaluators are built before either one is run, so a compile
    // failure in any backend is reported ahead of a run failure.
    let cranelift = AotEvaluator::from_ir_direct(
        ir.clone(),
        SandboxConfig::default(),
        params.iter().map(|name| (*name).to_owned()).collect(),
    )
    .expect("cranelift compile");
    let llvm = LlvmAotEvaluator::from_ir_direct(
        ir.clone(),
        params.iter().map(|name| (*name).to_owned()).collect(),
    )
    .expect("llvm compile");

    let cl_v = cranelift
        .run_main_legacy_i64(args)
        .expect("cranelift run");
    let llvm_v = llvm.run_main_legacy_i64(args).expect("llvm run");

    assert_eq!(
        cl_v, llvm_v,
        "cranelift / llvm divergence for args {args:?}: cranelift={cl_v}, llvm={llvm_v}"
    );

    // The backends agree at this point; the Cranelift value is the one returned.
    cl_v
}
fn min_select_body() -> Vec<TaggedOp> {
vec![
t(Op::LocalGet(0)),
t(Op::LocalGet(1)),
t(Op::LocalGet(0)),
t(Op::LocalGet(1)),
t(Op::Lt(IrType::I64)),
t(Op::Select { ty: IrType::I64 }),
t(Op::Return),
]
}
/// Runs the `Select`-based min body on both backends (via `run_both`) and
/// checks the shared result against a hand-written expectation, including
/// equal operands and the `i64::MIN` / `i64::MAX` extremes.
#[test]
fn select_min_three_way_consistent() {
    let signature = vec![IrType::I64, IrType::I64];
    let module = legacy_module(min_select_body(), signature);

    // (x, y, expected min)
    let table: [(i64, i64, i64); 6] = [
        (3, 10, 3),
        (10, 3, 3),
        (-5, -10, -10),
        (7, 7, 7),
        (i64::MIN, 0, i64::MIN),
        (0, i64::MAX, 0),
    ];

    for &(x, y, expected) in table.iter() {
        let got = run_both(&module, &["x", "y"], &[x, y]);
        assert_eq!(got, expected, "min({x}, {y})");
    }
}
/// `Select` over I64 operands where the condition comes from an I32 equality
/// of two constant `1`s. With arguments `[111, 222]` both backends are
/// expected to return `111`, the value of local 0.
#[test]
fn select_i32_discriminant_picks_arm() {
    let ops = vec![
        // Candidate values: local 0 and local 1.
        Op::LocalGet(0),
        Op::LocalGet(1),
        // Condition: 1 == 1, compared as I32.
        Op::ConstI32(1),
        Op::ConstI32(1),
        Op::Eq(IrType::I32),
        Op::Select { ty: IrType::I64 },
        Op::Return,
    ];
    let body: Vec<TaggedOp> = ops.into_iter().map(t).collect();

    let module = legacy_module(body, vec![IrType::I64, IrType::I64]);
    let picked = run_both(&module, &["a", "b"], &[111, 222]);
    assert_eq!(picked, 111);
}
/// Parses and analyzes `src`, then returns a [`TreeWalkEvaluator`] over the
/// prepared context together with a fresh default [`Scope`].
///
/// # Panics
///
/// Panics with the offending source and the parse error if `src` does not
/// parse.
fn build_tree_walker(src: &str) -> (TreeWalkEvaluator, Arc<Scope>) {
    let node = match parse_document(src) {
        Ok(node) => node,
        Err(e) => panic!("parse failed for source:\n{src}\nerror: {e:?}"),
    };

    let analysis = Arc::new(relon_analyzer::analyze(&node));
    let mut context = Context::new()
        .with_root(node)
        .with_analyzed(Arc::clone(&analysis));
    // Preparation happens on the context before it is frozen behind an `Arc`.
    TreeWalkEvaluator::prepare_in_place(&mut context);

    let walker = TreeWalkEvaluator::new(Arc::new(context));
    let scope = Arc::new(Scope::default());
    (walker, scope)
}
/// Compiles `src` with the LLVM evaluator's `from_source` and runs its main
/// entry with the integer arguments `a` and `b` bound to the names `"a"` and
/// `"b"`, returning the integer result.
///
/// # Panics
///
/// Panics if compilation or execution fails, or if the result is not a
/// `Value::Int`.
fn llvm_source_two_i64(src: &str, a: i64, b: i64) -> i64 {
    let evaluator = LlvmAotEvaluator::from_source(src).expect("LLVM from_source");

    let args: HashMap<_, _> = vec![
        ("a".to_string(), Value::Int(a)),
        ("b".to_string(), Value::Int(b)),
    ]
    .into_iter()
    .collect();

    let outcome = evaluator.run_main(args).expect("LLVM run_main");
    match outcome {
        Value::Int(n) => n,
        other => panic!("unexpected LLVM return {other:?}"),
    }
}
/// Evaluates `src` with the tree-walking evaluator, binding the integer
/// arguments `a` and `b` to the names `"a"` and `"b"`, and returns the
/// integer result. The tests use this value as the reference ("oracle").
///
/// # Panics
///
/// Panics if the source fails to parse, if evaluation fails, or if the result
/// is not a `Value::Int`.
fn oracle_source_two_i64(src: &str, a: i64, b: i64) -> i64 {
    let (walker, scope) = build_tree_walker(src);

    let args: HashMap<_, _> = vec![
        ("a".to_string(), Value::Int(a)),
        ("b".to_string(), Value::Int(b)),
    ]
    .into_iter()
    .collect();

    let outcome = walker.run_main(&scope, args).expect("tree-walk run_main");
    match outcome {
        Value::Int(n) => n,
        other => panic!("unexpected tree-walk return {other:?}"),
    }
}
/// Source-level check: for `min(a, b)` and `max(a, b)` programs, the LLVM
/// evaluator must return the same integer as the tree-walking evaluator for
/// every argument pair in the table.
#[test]
fn select_min_max_from_source_matches_tree_walker() {
    let min_src = "#main(Int a, Int b) -> Int\nmin(a, b)";
    let max_src = "#main(Int a, Int b) -> Int\nmax(a, b)";

    let pairs = [(3, 10), (10, 3), (-5, -10), (7, 7)];
    for &(a, b) in pairs.iter() {
        // In each assertion the LLVM result is computed first, then the oracle.
        assert_eq!(
            llvm_source_two_i64(min_src, a, b),
            oracle_source_two_i64(min_src, a, b),
            "min({a}, {b}) llvm vs tree-walk"
        );
        assert_eq!(
            llvm_source_two_i64(max_src, a, b),
            oracle_source_two_i64(max_src, a, b),
            "max({a}, {b}) llvm vs tree-walk"
        );
    }
}
/// IR body that pushes the constant `disc` and dispatches on it with a
/// `BrTable` nested inside three result-less blocks.
///
/// Each nesting level is followed by its own `ConstI64` + `Return`:
/// 200 after the innermost block, 100 after the middle block and 999 after
/// the outermost block.
///
/// NOTE(review): branch depths are presumably counted from the innermost
/// enclosing block (WebAssembly-style). That reading matches the expectations
/// in `br_table_three_way_consistent` (0 -> 100, 1 -> 200, anything else ->
/// 999) — confirm against the IR definition.
fn br_table_dispatch_body(disc: i32) -> Vec<TaggedOp> {
    // Built inside-out: the dispatch itself, then each enclosing block.
    let dispatch = t(Op::BrTable {
        default: 2,
        targets: vec![1, 0],
    });
    let innermost = t(Op::Block {
        result_ty: None,
        body: vec![dispatch],
    });
    let middle = t(Op::Block {
        result_ty: None,
        body: vec![innermost, t(Op::ConstI64(200)), t(Op::Return)],
    });
    let outermost = t(Op::Block {
        result_ty: None,
        body: vec![middle, t(Op::ConstI64(100)), t(Op::Return)],
    });

    vec![
        t(Op::ConstI32(disc)),
        outermost,
        t(Op::ConstI64(999)),
        t(Op::Return),
    ]
}
/// Runs the `BrTable` dispatch body on both backends for in-table,
/// out-of-range and negative discriminants and checks the shared result
/// against the expected constant.
#[test]
fn br_table_three_way_consistent() {
    // (discriminant, expected return value)
    let cases: [(i32, i64); 5] = [
        (0, 100),
        (1, 200),
        (2, 999),
        (5, 999),
        (-1, 999),
    ];

    for &(disc, expected) in cases.iter() {
        // The discriminant is baked into the body, so each case needs its own module.
        let module = legacy_module(br_table_dispatch_body(disc), vec![IrType::I64]);
        let got = run_both(&module, &["x"], &[0]);
        assert_eq!(got, expected, "BrTable discriminant {disc}");
    }
}
/// Counter loop driven by a `BrTable`: let-slot 0 starts at 0 and is
/// incremented once per iteration, then compared with 3 via `Ge` and the
/// comparison result is used as the `BrTable` discriminant.
///
/// Both backends are expected to leave the counter at 3.
///
/// NOTE(review): target depth 0 is presumably the enclosing `Loop` (a back
/// edge) and the default depth 1 the surrounding `Block` (loop exit) —
/// confirm against the IR definition.
#[test]
fn br_table_loop_back_edge_three_way_consistent() {
    // Accessors for the single I64 counter held in let-slot 0.
    let load_counter = || {
        t(Op::LetGet {
            idx: 0,
            ty: IrType::I64,
        })
    };
    let store_counter = || {
        t(Op::LetSet {
            idx: 0,
            ty: IrType::I64,
        })
    };

    let loop_body = vec![
        // counter = counter + 1
        load_counter(),
        t(Op::ConstI64(1)),
        t(Op::Add(IrType::I64)),
        store_counter(),
        // discriminant = counter >= 3
        load_counter(),
        t(Op::ConstI64(3)),
        t(Op::Ge(IrType::I64)),
        t(Op::BrTable {
            default: 1,
            targets: vec![0],
        }),
    ];
    let counting_loop = t(Op::Loop {
        result_ty: None,
        body: loop_body,
    });

    let body = vec![
        // counter = 0
        t(Op::ConstI64(0)),
        store_counter(),
        t(Op::Block {
            result_ty: None,
            body: vec![counting_loop],
        }),
        // return counter
        load_counter(),
        t(Op::Return),
    ];

    let module = legacy_module(body, vec![IrType::I64]);
    let got = run_both(&module, &["x"], &[0]);
    assert_eq!(got, 3, "BrTable loop back-edge final counter");
}