use std::collections::HashMap;
use std::sync::Arc;
use polars::prelude::*;
use polars_plan::dsl::PlanSerializationContext;
use taxa_core::source::{FrameSource, Source};
use taxa_core::TransformSource;
/// Returns a lazy parquet scan whose path is the URI `taxa://{name}`.
///
/// The scan is only a named stand-in in the logical plan. It is meant to be
/// resolved by `TransformSource` against its map of named sources. The
/// `expect` message records the assumption that building the scan does not
/// read anything.
///
/// # Panics
/// Panics if polars refuses to build the scan node.
fn placeholder(name: &str) -> LazyFrame {
    let uri = format!("taxa://{name}");
    let path = PlPath::new(&uri);
    let args = ScanArgsParquet::default();
    LazyFrame::scan_parquet(path, args)
        .expect("scan_parquet should build a lazy placeholder without reading")
}
/// Serializes the logical plan of `lf` using polars' versioned plan format.
///
/// A default `PlanSerializationContext` is used, and the raw bytes are
/// returned. Those bytes are what `TransformSource::new` consumes in the tests.
///
/// # Panics
/// Panics if `serialize_versioned` returns an error.
fn serialize(lf: LazyFrame) -> Vec<u8> {
    let plan = &lf.logical_plan;
    let mut bytes: Vec<u8> = Vec::new();
    plan.serialize_versioned(&mut bytes, PlanSerializationContext::default())
        .expect("serialize_versioned");
    bytes
}
/// An inner join of two `taxa://` placeholders must have both leaves rebound.
///
/// The `left` and `right` placeholders are joined on `key`. Both must be
/// rebound to their in-memory frames, and the plan must execute to the joined
/// result.
#[test]
fn join_rebinds_both_sources() {
    let left_df = df!(
        "key" => ["a", "b", "c"],
        "lval" => [1_i64, 2, 3],
    )
    .unwrap();
    let right_df = df!(
        "key" => ["a", "b", "c"],
        "rval" => [10_i64, 20, 30],
    )
    .unwrap();

    // One `FrameSource` per placeholder name, coerced to the trait object the
    // map stores.
    let sources: HashMap<String, Arc<dyn Source>> = [("left", left_df), ("right", right_df)]
        .into_iter()
        .map(|(name, frame)| {
            let source = Arc::new(FrameSource::new(frame)) as Arc<dyn Source>;
            (name.to_string(), source)
        })
        .collect();

    let joined = placeholder("left").join(
        placeholder("right"),
        [col("key")],
        [col("key")],
        JoinArgs::new(JoinType::Inner),
    );
    // Sort by key so the comparison below does not depend on join output order.
    let plan = joined
        .select([col("key"), col("lval"), col("rval")])
        .sort(["key"], Default::default());

    let src = TransformSource::new(sources, serialize(plan));
    let actual = src.frame().unwrap().collect().unwrap();

    let wanted = df!(
        "key" => ["a", "b", "c"],
        "lval" => [1_i64, 2, 3],
        "rval" => [10_i64, 20, 30],
    )
    .unwrap();
    assert_eq!(actual, wanted, "both Join leaves must rebind + execute");
}
/// A filter then group-by/sum over one placeholder must run on the rebound frame.
///
/// Rows with `px <= 1.0` are dropped before summing per `sym`. The output is
/// sorted by `sym` so the expected frame has a fixed row order.
#[test]
fn single_source_filter_groupby() {
    let quotes = df!(
        "sym" => ["AAPL", "MSFT", "AAPL", "GOOG", "MSFT"],
        "px" => [1.0_f64, 2.0, 3.0, 4.0, 5.0],
    )
    .unwrap();

    let above_one = col("px").gt(lit(1.0_f64));
    let totals = col("px").sum().alias("total");
    let plan = placeholder("x")
        .filter(above_one)
        .group_by([col("sym")])
        .agg([totals])
        .sort(["sym"], Default::default());

    let mut sources: HashMap<String, Arc<dyn Source>> = HashMap::new();
    sources.insert(String::from("x"), Arc::new(FrameSource::new(quotes)));

    let result = TransformSource::new(sources, serialize(plan))
        .frame()
        .unwrap()
        .collect()
        .unwrap();

    // AAPL keeps only 3.0 (1.0 is filtered out); MSFT is 2.0 + 5.0.
    let wanted = df!(
        "sym" => ["AAPL", "GOOG", "MSFT"],
        "total" => [3.0_f64, 4.0, 7.0],
    )
    .unwrap();
    assert_eq!(result, wanted);
}
/// A plan that scans a real filesystem path must be refused.
///
/// The refusal must happen at `frame()`, so the scan is never executed. The
/// error must name the offending path and mention the `taxa://` naming
/// requirement.
#[test]
fn real_path_scan_is_rejected() {
    // An unrelated source is registered so the source map is non-empty.
    // Presumably this ensures the rejection is about the path itself.
    let sources: HashMap<String, Arc<dyn Source>> = HashMap::from([(
        "present".to_string(),
        Arc::new(FrameSource::new(df!["k" => [1_i64]].unwrap())) as Arc<dyn Source>,
    )]);

    let real_scan = LazyFrame::scan_parquet(PlPath::new("/etc/passwd"), ScanArgsParquet::default())
        .expect("lazy scan builds without reading");
    let bytes = serialize(real_scan.select([col("*")]));

    let src = TransformSource::new(sources, bytes);
    let Err(err) = src.frame() else {
        panic!("a real-path file scan must be rejected, not executed");
    };
    let msg = err.to_string();

    assert!(
        msg.contains("/etc/passwd"),
        "error should name the disallowed path, got: {msg}"
    );
    assert!(
        msg.contains("taxa://"),
        "error should say sources must be taxa:// named, got: {msg}"
    );
}
/// A placeholder with no matching source must make `frame()` fail.
///
/// The plan references a `taxa://` placeholder that has no entry in the
/// source map. The error must name the unresolved source so the caller can
/// tell which binding is absent.
#[test]
fn unknown_source_errors() {
    // Use a distinctive name. A generic word such as "missing" could show up
    // in the error's own wording (e.g. "missing source ..."). The assertion
    // would then pass even if the undeclared name were never reported.
    const UNDECLARED: &str = "undeclared_src_7f3a";

    let plan = placeholder(UNDECLARED).select([col("*")]);
    let bytes = serialize(plan);

    // Register a different source so the failure comes from the unresolved
    // name, not from an empty map.
    let mut sources: HashMap<String, Arc<dyn Source>> = HashMap::new();
    sources.insert(
        "present".into(),
        Arc::new(FrameSource::new(df!["k" => [1_i64]].unwrap())),
    );

    let src = TransformSource::new(sources, bytes);
    let Err(err) = src.frame() else {
        panic!("undeclared placeholder must error");
    };
    let msg = err.to_string();

    assert!(
        msg.contains(UNDECLARED),
        "error should name the undeclared source, got: {msg}"
    );
}