use lex_ast::canonicalize_program;
use lex_bytecode::{compile_program, vm::Vm, Value};
use lex_runtime::{DefaultHandler, Policy};
use lex_syntax::parse_source;
use std::sync::Arc;
/// Drive one function of a Lex source string end-to-end: parse, canonicalize,
/// type-check, compile to bytecode, then invoke `fn_name` on a fresh VM with
/// a pure (no-effects) handler. Panics with a stage-specific message on any
/// failure so a test run points directly at the broken stage.
fn run(src: &str, fn_name: &str, args: Vec<Value>) -> Value {
    let parsed = parse_source(src).expect("parse");
    let stages = canonicalize_program(&parsed);
    if let Err(errs) = lex_types::check_program(&stages) {
        panic!("type errors:\n{errs:#?}");
    }
    let bytecode = Arc::new(compile_program(&stages));
    // The handler needs its own handle on the program for effect dispatch.
    let handler = DefaultHandler::new(Policy::pure()).with_program(Arc::clone(&bytecode));
    let mut vm = Vm::with_handler(&bytecode, Box::new(handler));
    match vm.call(fn_name, args) {
        Ok(value) => value,
        Err(e) => panic!("call {fn_name}: {e}"),
    }
}
/// Unwrap the payload of an `Ok(x)` variant `Value`; panic on anything else.
fn unwrap_ok(v: Value) -> Value {
    match v {
        Value::Variant { name, args } if name == "Ok" && args.len() == 1 => {
            // Single-element container: take ownership of the payload.
            let mut payloads = args.into_iter();
            payloads.next().unwrap()
        }
        other => panic!("expected Ok(_), got {other:?}"),
    }
}
/// Embedded Lex program exercised by the VM-level tests below: builds small
/// 6-row integer tables via `std.list`/`std.arrow` and runs `std.df` kernels
/// (filter, sort, group-by) over them, returning plain `Int` results that the
/// Rust tests can assert on (negative values encode which step failed).
///
/// NOTE(review): `build_g` is never called by any test function and its final
/// `Ok(_)` arm returns `Err("placeholder")` — it looks like an unfinished
/// mixed-type (int + string column) builder, per the comment block above
/// `group_by_x_self`. Confirm whether it should be completed or removed.
const SRC: &str = r#"
import "std.list" as list
import "std.arrow" as arrow
import "std.df" as df
# 6 rows; g cycles ["a","b","a","b","a","b"]; x = [1..6], y = [10..60].
fn build() -> Result[Table, Str] {
let xs := list.cons(1, list.cons(2, list.cons(3, list.cons(4, list.cons(5, list.cons(6, []))))))
let ys := list.cons(10, list.cons(20, list.cons(30, list.cons(40, list.cons(50, list.cons(60, []))))))
let cols := list.cons(("x", xs), list.cons(("y", ys), []))
arrow.from_int_columns(cols)
}
fn build_g() -> Result[Table, Str] {
let xs := list.cons(1, list.cons(2, list.cons(3, list.cons(4, list.cons(5, list.cons(6, []))))))
let ys := list.cons(10, list.cons(20, list.cons(30, list.cons(40, list.cons(50, list.cons(60, []))))))
let gs := list.cons("a", list.cons("b", list.cons("a", list.cons("b", list.cons("a", list.cons("b", []))))))
match arrow.from_int_columns(list.cons(("x", xs), list.cons(("y", ys), []))) {
Err(e) => Err(e),
Ok(t) => match arrow.from_str_columns(list.cons(("g", gs), [])) {
Err(e) => Err(e),
Ok(_) => Err("placeholder"),
},
}
}
fn filter_eq_3_nrows() -> Int {
match build() {
Err(_) => -1,
Ok(t) => match df.filter_eq_int(t, "x", 3) {
Err(_) => -2,
Ok(t2) => arrow.nrows(t2),
},
}
}
fn filter_gt_3_nrows() -> Int {
match build() {
Err(_) => -1,
Ok(t) => match df.filter_gt_int(t, "x", 3) {
Err(_) => -2,
Ok(t2) => arrow.nrows(t2),
},
}
}
fn sort_first_x_desc() -> Int {
match build() {
Err(_) => -1,
Ok(t) => match df.sort_by(t, "x", false) {
Err(_) => -2,
Ok(t2) => match arrow.col_sum_int(arrow.head(t2, 1), "x") {
Ok(s) => s,
Err(_) => -3,
},
},
}
}
# group_by single-key (the g column) with sum(x) + mean(y).
# Need a Table with three columns — use from_str_columns isn't ideal because
# we'd need a mixed builder. Instead, construct the g column separately by
# read_csv in a real test; for the unit test we sort-by-x then verify
# rather than build mixed-type. Keep this simple.
fn group_by_x_self() -> Int {
# group by "x" (each row distinct), sum(y). 6 distinct x values => 6 rows.
match build() {
Err(_) => -1,
Ok(t) => match df.group_by_agg(
t,
list.cons("x", []),
list.cons(("sum_y", "y", "sum"), [])
) {
Err(_) => -2,
Ok(t2) => arrow.nrows(t2),
},
}
}
"#;
/// x == 3 matches exactly one of the six rows.
#[test]
fn df_filter_eq_int() {
    let got = run(SRC, "filter_eq_3_nrows", Vec::new());
    assert_eq!(got, Value::Int(1));
}
/// x > 3 keeps rows {4, 5, 6}.
#[test]
fn df_filter_gt_int() {
    let got = run(SRC, "filter_gt_3_nrows", Vec::new());
    assert_eq!(got, Value::Int(3));
}
/// Descending sort on x puts 6 first; head(1) then sums just that row.
#[test]
fn df_sort_by_desc() {
    let got = run(SRC, "sort_first_x_desc", Vec::new());
    assert_eq!(got, Value::Int(6));
}
/// Grouping on the all-distinct x column yields one group per row.
#[test]
fn df_group_by_agg() {
    let got = run(SRC, "group_by_x_self", Vec::new());
    assert_eq!(got, Value::Int(6));
}
/// Bypass the VM entirely and hit the df kernel dispatcher with a hand-built
/// Arrow RecordBatch: filter_gt_int on x > 4 must keep exactly rows {5, 6}.
#[test]
fn df_kernels_via_direct_dispatch() {
    use arrow_array::{Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    let schema = Schema::new(vec![
        Field::new("x", DataType::Int64, false),
        Field::new("y", DataType::Int64, false),
    ]);
    let xs = Int64Array::from(vec![1, 2, 3, 4, 5, 6]);
    let ys = Int64Array::from(vec![10, 20, 30, 40, 50, 60]);
    let batch = RecordBatch::try_new(
        Arc::new(schema),
        vec![Arc::new(xs), Arc::new(ys)],
    ).unwrap();
    let table = Value::ArrowTable(Arc::new(batch));
    // `table` is not used again below, so move it into the argument slice
    // instead of cloning it (the clone here was redundant).
    let r = lex_runtime::df::dispatch(
        "filter_gt_int",
        &[table, Value::Str("x".into()), Value::Int(4)],
    ).unwrap().unwrap();
    // Use the shared helpers for consistency with the other dispatch tests,
    // rather than re-implementing the ArrowTable match inline.
    assert_eq!(nrows_of(unwrap_ok(r)), 2);
}
/// Build a 6-row mixed-type table wrapped as a `Value::ArrowTable`:
/// non-null Int64 `x`, Float64 `y`, Utf8 `g` (cycling "a"/"b"), and a
/// nullable Int64 `z` with nulls at rows 0 and 2.
fn make_mixed_batch() -> Value {
    use arrow_array::{Float64Array, Int64Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    let fields = vec![
        Field::new("x", DataType::Int64, false),
        Field::new("y", DataType::Float64, false),
        Field::new("g", DataType::Utf8, false),
        Field::new("z", DataType::Int64, true),
    ];
    let x_col = Int64Array::from(vec![1_i64, 2, 3, 4, 5, 6]);
    let y_col = Float64Array::from(vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0]);
    let g_col = StringArray::from(vec!["a", "b", "a", "b", "a", "b"]);
    let z_col = Int64Array::from(vec![None, Some(20), None, Some(40), Some(50), Some(60)]);
    let batch = RecordBatch::try_new(
        Arc::new(Schema::new(fields)),
        vec![Arc::new(x_col), Arc::new(y_col), Arc::new(g_col), Arc::new(z_col)],
    )
    .unwrap();
    Value::ArrowTable(Arc::new(batch))
}
/// Row count of a `Value::ArrowTable`; panics on any other variant so a
/// wrong-shaped dispatch result fails the test loudly.
fn nrows_of(v: Value) -> usize {
    match v {
        Value::ArrowTable(batch) => batch.num_rows(),
        other => panic!("expected ArrowTable, got {other:?}"),
    }
}
/// Unwrap the string payload of an `Err(s)` variant `Value`; panic if the
/// value is not an `Err` or its payload is not a `Str`.
fn unwrap_err(v: Value) -> String {
    match v {
        Value::Variant { name, args } if name == "Err" && args.len() == 1 => {
            let mut payloads = args.into_iter();
            match payloads.next().unwrap() {
                Value::Str(s) => s.to_string(),
                other => panic!("Err payload not Str: {other:?}"),
            }
        }
        other => panic!("expected Err(_), got {other:?}"),
    }
}
/// g == "a" matches rows 0, 2, 4 of the mixed batch.
#[test]
fn df_filter_eq_str() {
    let args = [
        make_mixed_batch(),
        Value::Str("g".into()),
        Value::Str("a".into()),
    ];
    let r = lex_runtime::df::dispatch("filter_eq_str", &args).unwrap().unwrap();
    assert_eq!(nrows_of(unwrap_ok(r)), 3);
}
/// Membership filter: "a" matches three rows, "z" matches none.
#[test]
fn df_filter_in_str() {
    let needles = Value::List(
        vec![Value::Str("a".into()), Value::Str("z".into())]
            .into_iter()
            .collect(),
    );
    let r = lex_runtime::df::dispatch(
        "filter_in_str",
        &[make_mixed_batch(), Value::Str("g".into()), needles],
    ).unwrap().unwrap();
    assert_eq!(nrows_of(unwrap_ok(r)), 3);
}
/// An empty needle list can never match a row, so the result is empty.
#[test]
fn df_filter_in_str_empty_list_is_empty_result() {
    let no_needles = Value::List(std::collections::VecDeque::new());
    let r = lex_runtime::df::dispatch(
        "filter_in_str",
        &[make_mixed_batch(), Value::Str("g".into()), no_needles],
    ).unwrap().unwrap();
    assert_eq!(nrows_of(unwrap_ok(r)), 0);
}
/// A string comparison against the Int64 `x` column must surface a
/// type-mismatch error that names the expected Utf8 type.
#[test]
fn df_filter_eq_str_on_int_column_is_err() {
    let r = lex_runtime::df::dispatch(
        "filter_eq_str",
        &[make_mixed_batch(), Value::Str("x".into()), Value::Str("1".into())],
    ).unwrap().unwrap();
    let msg = unwrap_err(r);
    assert!(
        msg.contains("expected") && msg.contains("Utf8"),
        "expected type-mismatch with Utf8, got: {msg}"
    );
}
/// y < 3.5 keeps {1.0, 2.0, 3.0}.
#[test]
fn df_filter_lt_float() {
    let r = lex_runtime::df::dispatch(
        "filter_lt_float",
        &[make_mixed_batch(), Value::Str("y".into()), Value::Float(3.5)],
    ).unwrap().unwrap();
    assert_eq!(nrows_of(unwrap_ok(r)), 3);
}
/// y > 3.5 keeps {4.0, 5.0, 6.0}.
#[test]
fn df_filter_gt_float() {
    let r = lex_runtime::df::dispatch(
        "filter_gt_float",
        &[make_mixed_batch(), Value::Str("y".into()), Value::Float(3.5)],
    ).unwrap().unwrap();
    assert_eq!(nrows_of(unwrap_ok(r)), 3);
}
/// Exact equality on y == 4.0 matches a single row.
#[test]
fn df_filter_eq_float() {
    let r = lex_runtime::df::dispatch(
        "filter_eq_float",
        &[make_mixed_batch(), Value::Str("y".into()), Value::Float(4.0)],
    ).unwrap().unwrap();
    assert_eq!(nrows_of(unwrap_ok(r)), 1);
}
/// isnull/notnull on the nullable `z` column partition the 6 rows as 2 + 4.
#[test]
fn df_filter_isnull_and_notnull() {
    let t = make_mixed_batch();
    let null_rows = lex_runtime::df::dispatch(
        "filter_isnull", &[t.clone(), Value::Str("z".into())],
    ).unwrap().unwrap();
    assert_eq!(nrows_of(unwrap_ok(null_rows)), 2, "z has 2 nulls");
    let present_rows = lex_runtime::df::dispatch(
        "filter_notnull", &[t, Value::Str("z".into())],
    ).unwrap().unwrap();
    assert_eq!(nrows_of(unwrap_ok(present_rows)), 4, "z has 4 non-null values");
}
/// Referencing a column that does not exist must produce an error that
/// names the missing column.
#[test]
fn df_filter_isnull_unknown_column_is_err() {
    let r = lex_runtime::df::dispatch(
        "filter_isnull",
        &[make_mixed_batch(), Value::Str("nope".into())],
    ).unwrap().unwrap();
    let msg = unwrap_err(r);
    assert!(msg.contains("nope"), "error should name missing column: {msg}");
}
/// Dropping rows where `z` is null removes the 2 null rows, leaving 4.
#[test]
fn df_drop_nulls() {
    let subset = Value::List(vec![Value::Str("z".into())].into_iter().collect());
    let r = lex_runtime::df::dispatch(
        "drop_nulls",
        &[make_mixed_batch(), subset],
    ).unwrap().unwrap();
    assert_eq!(nrows_of(unwrap_ok(r)), 4, "drop_nulls on z removes 2 rows");
}
/// With no columns selected there is nothing to check, so all 6 rows stay.
#[test]
fn df_drop_nulls_empty_col_list_is_noop() {
    let no_cols = Value::List(std::collections::VecDeque::new());
    let r = lex_runtime::df::dispatch(
        "drop_nulls",
        &[make_mixed_batch(), no_cols],
    ).unwrap().unwrap();
    assert_eq!(nrows_of(unwrap_ok(r)), 6, "empty col list is a no-op");
}