use lex_ast::canonicalize_program;
use lex_bytecode::{compile_program, vm::Vm, Value};
use lex_runtime::{DefaultHandler, Policy};
use lex_syntax::parse_source;
use std::sync::Arc;
fn run(src: &str, fn_name: &str, args: Vec<Value>) -> Value {
let prog = parse_source(src).expect("parse");
let stages = canonicalize_program(&prog);
if let Err(errs) = lex_types::check_program(&stages) {
panic!("type errors:\n{errs:#?}");
}
let bc = Arc::new(compile_program(&stages));
let handler = DefaultHandler::new(Policy::pure()).with_program(Arc::clone(&bc));
let mut vm = Vm::with_handler(&bc, Box::new(handler));
vm.call(fn_name, args).unwrap_or_else(|e| panic!("call {fn_name}: {e}"))
}
fn unwrap_ok(v: Value) -> Value {
match v {
Value::Variant { name, args } if name == "Ok" && args.len() == 1
=> args.into_iter().next().unwrap(),
other => panic!("expected Ok(_), got {other:?}"),
}
}
const SRC: &str = r#"
import "std.list" as list
import "std.arrow" as arrow
import "std.df" as df
# 6 rows; g cycles ["a","b","a","b","a","b"]; x = [1..6], y = [10..60].
fn build() -> Result[Table, Str] {
let xs := list.cons(1, list.cons(2, list.cons(3, list.cons(4, list.cons(5, list.cons(6, []))))))
let ys := list.cons(10, list.cons(20, list.cons(30, list.cons(40, list.cons(50, list.cons(60, []))))))
let cols := list.cons(("x", xs), list.cons(("y", ys), []))
arrow.from_int_columns(cols)
}
fn build_g() -> Result[Table, Str] {
let xs := list.cons(1, list.cons(2, list.cons(3, list.cons(4, list.cons(5, list.cons(6, []))))))
let ys := list.cons(10, list.cons(20, list.cons(30, list.cons(40, list.cons(50, list.cons(60, []))))))
let gs := list.cons("a", list.cons("b", list.cons("a", list.cons("b", list.cons("a", list.cons("b", []))))))
match arrow.from_int_columns(list.cons(("x", xs), list.cons(("y", ys), []))) {
Err(e) => Err(e),
Ok(t) => match arrow.from_str_columns(list.cons(("g", gs), [])) {
Err(e) => Err(e),
Ok(_) => Err("placeholder"),
},
}
}
fn filter_eq_3_nrows() -> Int {
match build() {
Err(_) => -1,
Ok(t) => match df.filter_eq_int(t, "x", 3) {
Err(_) => -2,
Ok(t2) => arrow.nrows(t2),
},
}
}
fn filter_gt_3_nrows() -> Int {
match build() {
Err(_) => -1,
Ok(t) => match df.filter_gt_int(t, "x", 3) {
Err(_) => -2,
Ok(t2) => arrow.nrows(t2),
},
}
}
fn sort_first_x_desc() -> Int {
match build() {
Err(_) => -1,
Ok(t) => match df.sort_by(t, "x", false) {
Err(_) => -2,
Ok(t2) => match arrow.col_sum_int(arrow.head(t2, 1), "x") {
Ok(s) => s,
Err(_) => -3,
},
},
}
}
# group_by single-key (the g column) with sum(x) + mean(y).
# Need a Table with three columns — use from_str_columns isn't ideal because
# we'd need a mixed builder. Instead, construct the g column separately by
# read_csv in a real test; for the unit test we sort-by-x then verify
# rather than build mixed-type. Keep this simple.
fn group_by_x_self() -> Int {
# group by "x" (each row distinct), sum(y). 6 distinct x values => 6 rows.
match build() {
Err(_) => -1,
Ok(t) => match df.group_by_agg(
t,
list.cons("x", []),
list.cons(("sum_y", "y", "sum"), [])
) {
Err(_) => -2,
Ok(t2) => arrow.nrows(t2),
},
}
}
"#;
#[test]
fn df_filter_eq_int() {
assert_eq!(run(SRC, "filter_eq_3_nrows", vec![]), Value::Int(1));
}
#[test]
fn df_filter_gt_int() {
assert_eq!(run(SRC, "filter_gt_3_nrows", vec![]), Value::Int(3));
}
#[test]
fn df_sort_by_desc() {
assert_eq!(run(SRC, "sort_first_x_desc", vec![]), Value::Int(6));
}
#[test]
fn df_group_by_agg() {
assert_eq!(run(SRC, "group_by_x_self", vec![]), Value::Int(6));
}
#[test]
fn df_kernels_via_direct_dispatch() {
use arrow_array::{Int64Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
let schema = Schema::new(vec![
Field::new("x", DataType::Int64, false),
Field::new("y", DataType::Int64, false),
]);
let xs = Int64Array::from(vec![1, 2, 3, 4, 5, 6]);
let ys = Int64Array::from(vec![10, 20, 30, 40, 50, 60]);
let batch = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(xs), Arc::new(ys)],
).unwrap();
let table = Value::ArrowTable(Arc::new(batch));
let r = lex_runtime::df::dispatch(
"filter_gt_int",
&[table.clone(), Value::Str("x".into()), Value::Int(4)],
).unwrap().unwrap();
let out = unwrap_ok(r);
if let Value::ArrowTable(t) = out {
assert_eq!(t.num_rows(), 2);
} else {
panic!("expected ArrowTable, got {out:?}");
}
}