use std::collections::HashMap;
use std::sync::Arc;
use relon_codegen_cranelift::AotEvaluator;
use relon_codegen_llvm::{EmittedEntryShape, EmittedFieldType, LlvmAotEvaluator};
use relon_eval_api::{Evaluator, Value};
use relon_evaluator::{Context, Scope, TreeWalkEvaluator};
use relon_parser::parse_document;
/// `#main` with no parameters whose body is the literal list `[1, 2, 3]`.
const CONST_LIST_SRC: &str = concat!("#main() -> List<Int>\n", "[1, 2, 3]\n");
/// `#main` taking one `Int` parameter `n` whose body is the list `[n, n + 1, 7]`.
const PARAM_LIST_SRC: &str = concat!("#main(Int n) -> List<Int>\n", "[n, n + 1, 7]\n");
/// `#main` taking a `List<Int>` parameter `xs` whose body returns `xs` as-is.
const LIST_PARAM_SRC: &str = concat!("#main(List<Int> xs) -> List<Int>\n", "xs\n");
/// Decodes a `Value::List` whose elements are all `Value::Int` into a `Vec<i64>`.
///
/// # Panics
/// Panics if `v` is not a `Value::List`, or if any element is not a `Value::Int`
/// (the first offending element is reported).
fn as_int_list(v: &Value) -> Vec<i64> {
    let Value::List(items) = v else {
        panic!("expected List result, got {v:?}");
    };
    items
        .iter()
        .map(|elem| {
            if let Value::Int(n) = elem {
                *n
            } else {
                panic!("expected Int list element, got {elem:?}")
            }
        })
        .collect()
}
/// Runs `src` through the tree-walking evaluator and decodes its result as `List<Int>`.
///
/// This is the reference result the compiled (LLVM / Cranelift) backends are
/// compared against. `args` are passed to `#main` by name.
///
/// # Panics
/// Panics if `src` fails to parse, if `run_main` returns an error, or if the
/// result is not a list of `Int`s.
fn oracle(src: &str, args: HashMap<String, Value>) -> Vec<i64> {
    let root = parse_document(src).expect("parse src");
    // Analysis borrows the tree before it is moved into the context.
    let analysis = Arc::new(relon_analyzer::analyze(&root));
    let mut context = Context::new().with_root(root).with_analyzed(Arc::clone(&analysis));
    TreeWalkEvaluator::prepare_in_place(&mut context);

    let evaluator = TreeWalkEvaluator::new(Arc::new(context));
    let root_scope = Arc::new(Scope::default());
    let result = evaluator
        .run_main(&root_scope, args)
        .expect("tree-walker run_main");
    as_int_list(&result)
}
/// Emits `src` as a native object file via `LlvmAotEvaluator::emit_object` and
/// returns the entry-point descriptors it reports.
///
/// The object is written to `{temp_dir}/relon_aot_list_{name}_{pid}/{name}.o` under
/// the symbol `__test_aot_list_{name}`. The function checks that the file exists and
/// is non-empty, then removes the scratch directory on a best-effort basis. Callers
/// only receive the descriptors.
///
/// # Errors
/// Returns a message if any of these happen:
/// - the scratch directory cannot be created;
/// - a stale object file cannot be removed;
/// - `emit_object` fails;
/// - the object file is missing or empty afterwards.
fn emit_to_tmp(name: &str, src: &str) -> Result<relon_codegen_llvm::EmitObjectInfo, String> {
    let tmp_dir =
        std::env::temp_dir().join(format!("relon_aot_list_{name}_{}", std::process::id()));
    std::fs::create_dir_all(&tmp_dir).map_err(|e| format!("create tmp dir: {e}"))?;
    let result = emit_checked_object(&tmp_dir, name, src);
    // Best-effort cleanup so repeated test runs don't accumulate objects in the
    // system temp dir. A failed removal is ignored so it cannot mask `result`.
    let _ = std::fs::remove_dir_all(&tmp_dir);
    result
}

/// Emits `src` into `dir/{name}.o` and verifies that this call wrote a non-empty file.
fn emit_checked_object(
    dir: &std::path::Path,
    name: &str,
    src: &str,
) -> Result<relon_codegen_llvm::EmitObjectInfo, String> {
    let out = dir.join(format!("{name}.o"));
    // The scratch dir is keyed only by PID. A file left by an earlier run that
    // reused this PID would otherwise pass the size check below even if
    // `emit_object` wrote nothing this time.
    match std::fs::remove_file(&out) {
        Ok(()) => {}
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
        Err(e) => return Err(format!("remove stale .o: {e}")),
    }
    let symbol = format!("__test_aot_list_{name}");
    let info = LlvmAotEvaluator::emit_object(src, &symbol, &out).map_err(|e| format!("{e:?}"))?;
    let bytes = std::fs::metadata(&out)
        .map_err(|e| format!("stat .o: {e}"))?
        .len();
    if bytes == 0 {
        return Err("emit_object produced an empty .o".to_string());
    }
    Ok(info)
}
/// `#main() -> List<Int>` must produce a `Buffer`-shaped entry with no parameter
/// fields and a single `ListInt` return field that is pointer-indirect.
#[test]
fn const_list_emit_object_native_descriptors() {
    let info = emit_to_tmp("const_list", CONST_LIST_SRC).expect("emit_object accepts List<Int>");
    let returns = &info.return_fields;

    assert_eq!(info.shape, EmittedEntryShape::Buffer);
    assert!(info.main_fields.is_empty(), "no #main params");
    assert_eq!(returns.len(), 1);
    assert_eq!(returns[0].ty, EmittedFieldType::ListInt);
    assert!(info.return_has_tail, "List<Int> return is pointer-indirect");
}
/// `#main(Int n) -> List<Int>`: one scalar `Int` parameter and a `List<Int>` return.
///
/// Besides the parameter descriptor and its name, this checks the same return-side
/// invariants as the constant-list case: a single `ListInt` return field that is
/// pointer-indirect.
#[test]
fn param_list_emit_object_native_descriptors() {
    let info = emit_to_tmp("param_list", PARAM_LIST_SRC).expect("emit_object accepts Int->List");
    assert_eq!(info.shape, EmittedEntryShape::Buffer);
    assert_eq!(info.main_fields.len(), 1);
    assert_eq!(info.main_fields[0].ty, EmittedFieldType::Int);
    assert_eq!(info.param_names, vec!["n".to_string()]);
    assert_eq!(info.return_fields.len(), 1);
    assert_eq!(info.return_fields[0].ty, EmittedFieldType::ListInt);
    assert!(info.return_has_tail, "List<Int> return is pointer-indirect");
}
/// `#main(List<Int> xs) -> List<Int>`: checks descriptors only (no end-to-end run).
///
/// The parameter must be described as a single `ListInt` field named `xs`. The
/// return, like every `List<Int>` return, must be a single pointer-indirect
/// `ListInt` field.
#[test]
fn list_param_emit_object_descriptor_only() {
    let info = emit_to_tmp("list_param", LIST_PARAM_SRC)
        .expect("emit_object accepts List<Int> param descriptor");
    assert_eq!(info.shape, EmittedEntryShape::Buffer);
    assert_eq!(info.main_fields.len(), 1);
    assert_eq!(info.main_fields[0].ty, EmittedFieldType::ListInt);
    assert_eq!(info.param_names, vec!["xs".to_string()]);
    assert_eq!(info.return_fields.len(), 1);
    assert_eq!(info.return_fields[0].ty, EmittedFieldType::ListInt);
    assert!(info.return_has_tail, "List<Int> return is pointer-indirect");
}
/// End-to-end check for `[1, 2, 3]`: the LLVM and Cranelift AOT backends must both
/// decode the same `List<Int>` that the tree-walking oracle produces.
#[test]
fn const_list_value_e2e_three_way() {
    let llvm = match LlvmAotEvaluator::from_source(CONST_LIST_SRC) {
        Ok(evaluator) => evaluator,
        Err(e) => panic!("LLVM from_source: {e:?}"),
    };
    let cranelift = match AotEvaluator::from_source(CONST_LIST_SRC) {
        Ok(evaluator) => evaluator,
        Err(e) => panic!("cranelift from_source: {e:?}"),
    };

    let expected = oracle(CONST_LIST_SRC, HashMap::new());
    assert_eq!(expected, vec![1, 2, 3], "oracle sanity");

    // Decode each backend's result right after running it.
    let llvm_value = llvm.run_main(HashMap::new()).expect("llvm run_main");
    let llvm_list = as_int_list(&llvm_value);
    let cranelift_value = cranelift.run_main(HashMap::new()).expect("cranelift run_main");
    let cranelift_list = as_int_list(&cranelift_value);

    assert_eq!(llvm_list, expected, "LLVM List<Int> return decode diverged");
    assert_eq!(cranelift_list, expected, "cranelift List<Int> return decode diverged");
}
/// End-to-end check for `[n, n + 1, 7]` over several `n` (including 0 and a negative
/// value). The LLVM and Cranelift backends must match the tree-walking oracle.
#[test]
fn param_list_value_e2e_three_way() {
    let llvm = match LlvmAotEvaluator::from_source(PARAM_LIST_SRC) {
        Ok(evaluator) => evaluator,
        Err(e) => panic!("LLVM from_source: {e:?}"),
    };
    let cranelift = match AotEvaluator::from_source(PARAM_LIST_SRC) {
        Ok(evaluator) => evaluator,
        Err(e) => panic!("cranelift from_source: {e:?}"),
    };

    let samples: [i64; 5] = [0, 1, 5, -3, 100];
    for n in samples {
        let args: HashMap<String, Value> = HashMap::from([("n".to_string(), Value::Int(n))]);

        let expected = oracle(PARAM_LIST_SRC, args.clone());
        assert_eq!(expected, vec![n, n + 1, 7], "oracle sanity at n={n}");

        let llvm_value = llvm.run_main(args.clone()).expect("llvm run_main");
        let llvm_list = as_int_list(&llvm_value);
        let cranelift_value = cranelift.run_main(args.clone()).expect("cranelift run_main");
        let cranelift_list = as_int_list(&cranelift_value);

        assert_eq!(llvm_list, expected, "LLVM Int->List<Int> diverged at n={n}");
        assert_eq!(cranelift_list, expected, "cranelift Int->List<Int> diverged at n={n}");
    }
}