use crate::{ArrangeKey, Column, DExpr, DBinOp, DataFrame, TidyAgg, TidyError, TidyFrame};
use std::collections::BTreeSet;
use std::rc::Rc;
/// One node of a lazily-evaluated query plan.
///
/// A plan is a tree: every variant except `Scan` owns its input(s) through a
/// `Box`, and `Join` owns two. Nothing is computed until the tree is handed to
/// `execute` or `execute_batched`.
#[derive(Debug, Clone)]
pub enum ViewNode {
/// Leaf: the source frame. It is held in an `Rc`, so cloning a plan shares the
/// data instead of copying it.
Scan { df: Rc<DataFrame> },
/// Row filter over `input`, evaluated with the view's `filter`.
Filter {
input: Box<ViewNode>,
predicate: DExpr,
},
/// Projection of `input` onto `columns`, evaluated with the view's `select`.
Select {
input: Box<ViewNode>,
columns: Vec<String>,
},
/// Adds or overwrites columns. Each pair is (output column name, expression).
Mutate {
input: Box<ViewNode>,
assignments: Vec<(String, DExpr)>,
},
/// Sorts the rows of `input` by `keys`.
Arrange {
input: Box<ViewNode>,
keys: Vec<ArrangeKey>,
},
/// Groups by `group_keys` and produces one output column per
/// `(name, aggregation)` pair, next to the key columns.
GroupSummarise {
input: Box<ViewNode>,
group_keys: Vec<String>,
aggregations: Vec<(String, TidyAgg)>,
},
/// Same output as `GroupSummarise`, but computed with streaming aggregations.
/// The optimizer (`annotate_streamable_summarise`) produces it when every
/// aggregation has a streaming counterpart.
StreamingGroupSummarise {
input: Box<ViewNode>,
group_keys: Vec<String>,
aggregations: Vec<(String, crate::StreamingAgg)>,
},
/// Distinct rows over `columns`, evaluated with the view's `distinct`.
Distinct {
input: Box<ViewNode>,
columns: Vec<String>,
},
/// Joins `left` with `right`. `on` holds `(left column, right column)` key pairs.
Join {
left: Box<ViewNode>,
right: Box<ViewNode>,
on: Vec<(String, String)>,
kind: JoinType,
},
}
/// Join flavour used by `ViewNode::Join`. Each variant is executed with the
/// view method of the same name (`inner_join`, `left_join`, `semi_join`,
/// `anti_join`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
/// Rows with a match on both sides.
Inner,
/// Every left row, joined with its matches from the right.
Left,
/// Left rows that have a match on the right (filtering join).
Semi,
/// Left rows that have no match on the right (filtering join).
Anti,
}
/// A lazily-built query. Builder methods only extend `plan`; the work happens
/// in `collect` / `collect_batched`, after the plan has been optimized.
pub struct LazyView {
/// Root of the (unoptimized) plan tree.
plan: ViewNode,
}
impl LazyView {
    /// Makes `plan` the root of a new lazy view.
    fn wrap_plan(plan: ViewNode) -> Self {
        LazyView { plan }
    }

    /// Starts a lazy pipeline over an owned `DataFrame`.
    pub fn from_df(df: DataFrame) -> Self {
        Self::from_rc(Rc::new(df))
    }

    /// Starts a lazy pipeline over a shared `DataFrame` without copying it.
    pub fn from_rc(df: Rc<DataFrame>) -> Self {
        Self::wrap_plan(ViewNode::Scan { df })
    }

    /// Appends a row filter on `predicate`.
    pub fn filter(self, predicate: DExpr) -> Self {
        let input = Box::new(self.plan);
        Self::wrap_plan(ViewNode::Filter { input, predicate })
    }

    /// Appends a projection onto `columns`.
    pub fn select(self, columns: Vec<String>) -> Self {
        let input = Box::new(self.plan);
        Self::wrap_plan(ViewNode::Select { input, columns })
    }

    /// Appends column assignments, given as `(output name, expression)` pairs.
    pub fn mutate(self, assignments: Vec<(String, DExpr)>) -> Self {
        let input = Box::new(self.plan);
        Self::wrap_plan(ViewNode::Mutate { input, assignments })
    }

    /// Appends a sort by `keys`.
    pub fn arrange(self, keys: Vec<ArrangeKey>) -> Self {
        let input = Box::new(self.plan);
        Self::wrap_plan(ViewNode::Arrange { input, keys })
    }

    /// Appends a group-by on `group_keys` followed by the named aggregations.
    pub fn group_summarise(
        self,
        group_keys: Vec<String>,
        aggregations: Vec<(String, TidyAgg)>,
    ) -> Self {
        let input = Box::new(self.plan);
        Self::wrap_plan(ViewNode::GroupSummarise {
            input,
            group_keys,
            aggregations,
        })
    }

    /// Appends a distinct over `columns`.
    pub fn distinct(self, columns: Vec<String>) -> Self {
        let input = Box::new(self.plan);
        Self::wrap_plan(ViewNode::Distinct { input, columns })
    }

    /// Joins this view (left side) with `right`. `on` holds
    /// `(left column, right column)` key pairs.
    pub fn join(self, right: LazyView, on: Vec<(String, String)>, kind: JoinType) -> Self {
        let left = Box::new(self.plan);
        let right = Box::new(right.plan);
        Self::wrap_plan(ViewNode::Join {
            left,
            right,
            on,
            kind,
        })
    }

    /// Optimizes the plan and executes it eagerly (not batched).
    ///
    /// # Errors
    /// Returns the first `TidyError` raised while executing a node.
    pub fn collect(self) -> Result<TidyFrame, TidyError> {
        execute(optimize(self.plan))
    }

    /// Borrows the plan exactly as it was built (before optimization).
    pub fn plan(&self) -> &ViewNode {
        &self.plan
    }

    /// Consumes the view and returns the plan that `collect` would execute.
    pub fn optimized_plan(self) -> ViewNode {
        optimize(self.plan)
    }
}
/// Rewrites `plan` into an equivalent plan that is cheaper to execute.
///
/// The passes run in this order:
/// 1. `merge_filters`: fold directly stacked filters into one conjunction.
/// 2. `push_predicates_down`: sink filters below the nodes they commute with.
/// 3. `merge_filters` again: pushdown can land a filter directly on top of
///    another one (e.g. `filter -> select -> filter` becomes
///    `select -> filter -> filter`), so merging runs a second time to collapse it.
/// 4. `eliminate_redundant_selects`: drop projections that change nothing.
/// 5. `annotate_streamable_summarise`: switch eligible group summaries to the
///    streaming implementation.
pub fn optimize(plan: ViewNode) -> ViewNode {
    let plan = merge_filters(plan);
    let plan = push_predicates_down(plan);
    let plan = merge_filters(plan);
    let plan = eliminate_redundant_selects(plan);
    annotate_streamable_summarise(plan)
}
fn try_streaming_agg(agg: &TidyAgg) -> Option<crate::StreamingAgg> {
use crate::StreamingAgg;
match agg {
TidyAgg::Count => Some(StreamingAgg::Count),
TidyAgg::Sum(c) => Some(StreamingAgg::Sum(c.clone())),
TidyAgg::Mean(c) => Some(StreamingAgg::Mean(c.clone())),
TidyAgg::Min(c) => Some(StreamingAgg::Min(c.clone())),
TidyAgg::Max(c) => Some(StreamingAgg::Max(c.clone())),
TidyAgg::Var(c) => Some(StreamingAgg::Var(c.clone())),
TidyAgg::Sd(c) => Some(StreamingAgg::Sd(c.clone())),
_ => None,
}
}
/// Replaces every `GroupSummarise` whose aggregations can *all* be streamed
/// with an equivalent `StreamingGroupSummarise`, recursing through the plan.
///
/// If even one aggregation has no streaming form, that node is left as it is.
/// Existing `StreamingGroupSummarise` nodes and `Scan` leaves are returned
/// unchanged; the pass does not descend below a streaming node.
fn annotate_streamable_summarise(plan: ViewNode) -> ViewNode {
    let recurse = |child: Box<ViewNode>| Box::new(annotate_streamable_summarise(*child));
    match plan {
        ViewNode::GroupSummarise {
            input,
            group_keys,
            aggregations,
        } => {
            let input = recurse(input);
            // `?` inside the closure aborts the whole collect on the first
            // aggregation that cannot stream.
            let streaming: Option<Vec<(String, crate::StreamingAgg)>> = aggregations
                .iter()
                .map(|(name, agg)| Some((name.clone(), try_streaming_agg(agg)?)))
                .collect();
            if let Some(streaming_aggs) = streaming {
                ViewNode::StreamingGroupSummarise {
                    input,
                    group_keys,
                    aggregations: streaming_aggs,
                }
            } else {
                ViewNode::GroupSummarise {
                    input,
                    group_keys,
                    aggregations,
                }
            }
        }
        ViewNode::Filter { input, predicate } => ViewNode::Filter {
            input: recurse(input),
            predicate,
        },
        ViewNode::Select { input, columns } => ViewNode::Select {
            input: recurse(input),
            columns,
        },
        ViewNode::Mutate { input, assignments } => ViewNode::Mutate {
            input: recurse(input),
            assignments,
        },
        ViewNode::Arrange { input, keys } => ViewNode::Arrange {
            input: recurse(input),
            keys,
        },
        ViewNode::Distinct { input, columns } => ViewNode::Distinct {
            input: recurse(input),
            columns,
        },
        ViewNode::Join {
            left,
            right,
            on,
            kind,
        } => ViewNode::Join {
            left: recurse(left),
            right: recurse(right),
            on,
            kind,
        },
        already_streaming @ ViewNode::StreamingGroupSummarise { .. } => already_streaming,
        scan @ ViewNode::Scan { .. } => scan,
    }
}
fn merge_filters(plan: ViewNode) -> ViewNode {
match plan {
ViewNode::Filter { input, predicate } => {
let merged_input = merge_filters(*input);
match merged_input {
ViewNode::Filter {
input: inner,
predicate: inner_pred,
} => {
let combined = DExpr::BinOp {
op: DBinOp::And,
left: Box::new(inner_pred),
right: Box::new(predicate),
};
ViewNode::Filter {
input: inner,
predicate: combined,
}
}
other => ViewNode::Filter {
input: Box::new(other),
predicate,
},
}
}
ViewNode::Select { input, columns } => ViewNode::Select {
input: Box::new(merge_filters(*input)),
columns,
},
ViewNode::Mutate {
input,
assignments,
} => ViewNode::Mutate {
input: Box::new(merge_filters(*input)),
assignments,
},
ViewNode::Arrange { input, keys } => ViewNode::Arrange {
input: Box::new(merge_filters(*input)),
keys,
},
ViewNode::GroupSummarise {
input,
group_keys,
aggregations,
} => ViewNode::GroupSummarise {
input: Box::new(merge_filters(*input)),
group_keys,
aggregations,
},
ViewNode::Distinct { input, columns } => ViewNode::Distinct {
input: Box::new(merge_filters(*input)),
columns,
},
ViewNode::Join {
left,
right,
on,
kind,
} => ViewNode::Join {
left: Box::new(merge_filters(*left)),
right: Box::new(merge_filters(*right)),
on,
kind,
},
other => other, }
}
fn push_predicates_down(plan: ViewNode) -> ViewNode {
match plan {
ViewNode::Filter { input, predicate } => {
let optimized_input = push_predicates_down(*input);
push_filter_into(optimized_input, predicate)
}
ViewNode::Select { input, columns } => ViewNode::Select {
input: Box::new(push_predicates_down(*input)),
columns,
},
ViewNode::Mutate {
input,
assignments,
} => ViewNode::Mutate {
input: Box::new(push_predicates_down(*input)),
assignments,
},
ViewNode::Arrange { input, keys } => ViewNode::Arrange {
input: Box::new(push_predicates_down(*input)),
keys,
},
ViewNode::GroupSummarise {
input,
group_keys,
aggregations,
} => ViewNode::GroupSummarise {
input: Box::new(push_predicates_down(*input)),
group_keys,
aggregations,
},
ViewNode::Distinct { input, columns } => ViewNode::Distinct {
input: Box::new(push_predicates_down(*input)),
columns,
},
ViewNode::Join {
left,
right,
on,
kind,
} => ViewNode::Join {
left: Box::new(push_predicates_down(*left)),
right: Box::new(push_predicates_down(*right)),
on,
kind,
},
other => other,
}
}
fn push_filter_into(node: ViewNode, predicate: DExpr) -> ViewNode {
match node {
ViewNode::Select { input, columns } => ViewNode::Select {
input: Box::new(push_filter_into(*input, predicate)),
columns,
},
ViewNode::Arrange { input, keys } => ViewNode::Arrange {
input: Box::new(push_filter_into(*input, predicate)),
keys,
},
ViewNode::Mutate {
input,
assignments,
} => {
let pred_cols = expr_columns(&predicate);
let mutated_cols: BTreeSet<String> =
assignments.iter().map(|(name, _)| name.clone()).collect();
let references_mutated = pred_cols.iter().any(|c| mutated_cols.contains(c));
if references_mutated {
ViewNode::Filter {
input: Box::new(ViewNode::Mutate {
input,
assignments,
}),
predicate,
}
} else {
ViewNode::Mutate {
input: Box::new(push_filter_into(*input, predicate)),
assignments,
}
}
}
ViewNode::Join {
left,
right,
on,
kind,
} => {
let pred_cols = expr_columns(&predicate);
let left_cols = node_output_columns(&left);
let right_cols = node_output_columns(&right);
let all_in_left = pred_cols.iter().all(|c| left_cols.contains(c));
let all_in_right = pred_cols.iter().all(|c| right_cols.contains(c));
if all_in_left {
ViewNode::Join {
left: Box::new(push_filter_into(*left, predicate)),
right,
on,
kind,
}
} else if all_in_right {
ViewNode::Join {
left,
right: Box::new(push_filter_into(*right, predicate)),
on,
kind,
}
} else {
ViewNode::Filter {
input: Box::new(ViewNode::Join {
left,
right,
on,
kind,
}),
predicate,
}
}
}
other => ViewNode::Filter {
input: Box::new(other),
predicate,
},
}
}
/// Returns the exact, ordered list of columns `node` produces, when this can
/// be known without executing the plan. Returns `None` otherwise.
///
/// `Filter` and `Arrange` keep their input's columns. Nodes whose column order
/// depends on the executor (mutate, summaries, distinct, joins) report `None`.
fn known_column_order(node: &ViewNode) -> Option<Vec<String>> {
    match node {
        ViewNode::Scan { df } => Some(
            df.column_names()
                .into_iter()
                .map(|s| s.to_string())
                .collect(),
        ),
        ViewNode::Select { columns, .. } => Some(columns.clone()),
        ViewNode::Filter { input, .. } | ViewNode::Arrange { input, .. } => {
            known_column_order(input)
        }
        ViewNode::Mutate { .. }
        | ViewNode::GroupSummarise { .. }
        | ViewNode::StreamingGroupSummarise { .. }
        | ViewNode::Distinct { .. }
        | ViewNode::Join { .. } => None,
    }
}

/// Removes `Select` nodes that would return their input unchanged, recursing
/// through the whole plan.
///
/// A select is dropped only when its column list equals its input's output
/// columns *in the same order*. Comparing sets is not enough, because a select
/// that merely reorders columns still changes the result.
fn eliminate_redundant_selects(plan: ViewNode) -> ViewNode {
    match plan {
        ViewNode::Select { input, columns } => {
            let optimized_input = eliminate_redundant_selects(*input);
            if known_column_order(&optimized_input).as_ref() == Some(&columns) {
                optimized_input
            } else {
                ViewNode::Select {
                    input: Box::new(optimized_input),
                    columns,
                }
            }
        }
        ViewNode::Filter { input, predicate } => ViewNode::Filter {
            input: Box::new(eliminate_redundant_selects(*input)),
            predicate,
        },
        ViewNode::Mutate { input, assignments } => ViewNode::Mutate {
            input: Box::new(eliminate_redundant_selects(*input)),
            assignments,
        },
        ViewNode::Arrange { input, keys } => ViewNode::Arrange {
            input: Box::new(eliminate_redundant_selects(*input)),
            keys,
        },
        ViewNode::GroupSummarise {
            input,
            group_keys,
            aggregations,
        } => ViewNode::GroupSummarise {
            input: Box::new(eliminate_redundant_selects(*input)),
            group_keys,
            aggregations,
        },
        ViewNode::StreamingGroupSummarise {
            input,
            group_keys,
            aggregations,
        } => ViewNode::StreamingGroupSummarise {
            input: Box::new(eliminate_redundant_selects(*input)),
            group_keys,
            aggregations,
        },
        ViewNode::Distinct { input, columns } => ViewNode::Distinct {
            input: Box::new(eliminate_redundant_selects(*input)),
            columns,
        },
        ViewNode::Join {
            left,
            right,
            on,
            kind,
        } => ViewNode::Join {
            left: Box::new(eliminate_redundant_selects(*left)),
            right: Box::new(eliminate_redundant_selects(*right)),
            on,
            kind,
        },
        scan @ ViewNode::Scan { .. } => scan,
    }
}
fn execute(node: ViewNode) -> Result<TidyFrame, TidyError> {
match node {
ViewNode::Scan { df } => Ok(TidyFrame::from_df((*df).clone())),
ViewNode::Filter { input, predicate } => {
let frame = execute(*input)?;
let view = frame.view();
let filtered = view.filter(&predicate)?;
let df = filtered.materialize()?;
Ok(TidyFrame::from_df(df))
}
ViewNode::Select { input, columns } => {
let frame = execute(*input)?;
let view = frame.view();
let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
let selected = view.select(&col_refs)?;
let df = selected.materialize()?;
Ok(TidyFrame::from_df(df))
}
ViewNode::Mutate {
input,
assignments,
} => {
let frame = execute(*input)?;
let view = frame.view();
let assign_refs: Vec<(&str, DExpr)> = assignments
.into_iter()
.map(|(name, expr)| (leaked_str(&name), expr))
.collect();
let result = view.mutate(&assign_refs.iter().map(|(n, e)| (*n, e.clone())).collect::<Vec<_>>())?;
Ok(result)
}
ViewNode::Arrange { input, keys } => {
let frame = execute(*input)?;
let view = frame.view();
let arranged = view.arrange(&keys)?;
let df = arranged.materialize()?;
Ok(TidyFrame::from_df(df))
}
ViewNode::GroupSummarise {
input,
group_keys,
aggregations,
} => {
let frame = execute(*input)?;
let view = frame.view();
let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
let grouped = view.group_by(&key_refs)?;
let agg_refs: Vec<(&str, TidyAgg)> = aggregations
.into_iter()
.map(|(name, agg)| (leaked_str(&name), agg))
.collect();
let result = grouped.summarise(
&agg_refs.iter().map(|(n, a)| (*n, a.clone())).collect::<Vec<_>>(),
)?;
Ok(result)
}
ViewNode::StreamingGroupSummarise {
input,
group_keys,
aggregations,
} => {
let frame = execute(*input)?;
let view = frame.view();
let key_refs: Vec<&str> = group_keys.iter().map(|s| s.as_str()).collect();
let agg_owned: Vec<(String, crate::StreamingAgg)> = aggregations;
let agg_refs: Vec<(&str, crate::StreamingAgg)> = agg_owned
.iter()
.map(|(name, sa)| (leaked_str(name), sa.clone()))
.collect();
view.summarise_streaming(&key_refs, &agg_refs)
}
ViewNode::Distinct { input, columns } => {
let frame = execute(*input)?;
let view = frame.view();
let col_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
let distinct = view.distinct(&col_refs)?;
let df = distinct.materialize()?;
Ok(TidyFrame::from_df(df))
}
ViewNode::Join {
left,
right,
on,
kind,
} => {
let left_frame = execute(*left)?;
let right_frame = execute(*right)?;
let left_view = left_frame.view();
let right_view = right_frame.view();
let on_refs: Vec<(&str, &str)> = on
.iter()
.map(|(l, r)| (l.as_str(), r.as_str()))
.collect();
match kind {
JoinType::Inner => left_view.inner_join(&right_view, &on_refs),
JoinType::Left => left_view.left_join(&right_view, &on_refs),
JoinType::Semi => {
let result = left_view.semi_join(&right_view, &on_refs)?;
let df = result.materialize()?;
Ok(TidyFrame::from_df(df))
}
JoinType::Anti => {
let result = left_view.anti_join(&right_view, &on_refs)?;
let df = result.materialize()?;
Ok(TidyFrame::from_df(df))
}
}
}
}
}
/// Returns the names of all columns referenced anywhere in `expr`, sorted and
/// without duplicates.
fn expr_columns(expr: &DExpr) -> BTreeSet<String> {
    let mut referenced: BTreeSet<String> = BTreeSet::new();
    collect_expr_cols(expr, &mut referenced);
    referenced
}
/// Walks `expr` recursively and inserts every referenced column name into `cols`.
///
/// Rolling functions name their column directly. Literals, `Count` and
/// `RowNumber` reference no column.
fn collect_expr_cols(expr: &DExpr, cols: &mut BTreeSet<String>) {
    match expr {
        DExpr::LitInt(_)
        | DExpr::LitFloat(_)
        | DExpr::LitBool(_)
        | DExpr::LitStr(_)
        | DExpr::Count
        | DExpr::RowNumber => {}
        DExpr::Col(name)
        | DExpr::RollingSum(name, _)
        | DExpr::RollingMean(name, _)
        | DExpr::RollingMin(name, _)
        | DExpr::RollingMax(name, _)
        | DExpr::RollingVar(name, _)
        | DExpr::RollingSd(name, _) => {
            cols.insert(name.clone());
        }
        DExpr::BinOp { left, right, .. } => {
            collect_expr_cols(left, cols);
            collect_expr_cols(right, cols);
        }
        DExpr::FnCall(_, args) => args.iter().for_each(|arg| collect_expr_cols(arg, cols)),
        DExpr::Agg(_, inner) => collect_expr_cols(inner, cols),
        DExpr::CumSum(inner)
        | DExpr::CumProd(inner)
        | DExpr::CumMax(inner)
        | DExpr::CumMin(inner)
        | DExpr::Lag(inner, _)
        | DExpr::Lead(inner, _)
        | DExpr::Rank(inner)
        | DExpr::DenseRank(inner) => collect_expr_cols(inner, cols),
    }
}
fn node_output_columns(node: &ViewNode) -> BTreeSet<String> {
match node {
ViewNode::Scan { df } => df.column_names().into_iter().map(|s| s.to_string()).collect(),
ViewNode::Filter { input, .. } => node_output_columns(input),
ViewNode::Select { columns, .. } => columns.iter().cloned().collect(),
ViewNode::Mutate {
input,
assignments,
} => {
let mut cols = node_output_columns(input);
for (name, _) in assignments {
cols.insert(name.clone());
}
cols
}
ViewNode::Arrange { input, .. } => node_output_columns(input),
ViewNode::GroupSummarise {
group_keys,
aggregations,
..
} => {
let mut cols: BTreeSet<String> = group_keys.iter().cloned().collect();
for (name, _) in aggregations {
cols.insert(name.clone());
}
cols
}
ViewNode::StreamingGroupSummarise {
group_keys,
aggregations,
..
} => {
let mut cols: BTreeSet<String> = group_keys.iter().cloned().collect();
for (name, _) in aggregations {
cols.insert(name.clone());
}
cols
}
ViewNode::Distinct { input, .. } => node_output_columns(input),
ViewNode::Join {
left, right, on, ..
} => {
let mut cols = node_output_columns(left);
let right_cols = node_output_columns(right);
let left_keys: BTreeSet<&str> = on.iter().map(|(l, _)| l.as_str()).collect();
let right_keys: BTreeSet<&str> = on.iter().map(|(_, r)| r.as_str()).collect();
for c in &right_cols {
if !right_keys.contains(c.as_str()) || !left_keys.contains(c.as_str()) {
cols.insert(c.clone());
}
}
cols
}
}
}
/// Copies `s` into a heap allocation that is deliberately never freed and
/// returns it as a `'static` string slice.
///
/// Every call leaks `s.len()` bytes. This module calls it once per output
/// name on each execution, and once per batch in batched `Mutate`, so the
/// total leak grows with use.
fn leaked_str(s: &str) -> &'static str {
    let owned: Box<str> = Box::<str>::from(s);
    Box::leak(owned)
}
impl ViewNode {
pub fn count_filters(&self) -> usize {
match self {
ViewNode::Filter { input, .. } => 1 + input.count_filters(),
ViewNode::Select { input, .. } => input.count_filters(),
ViewNode::Mutate { input, .. } => input.count_filters(),
ViewNode::Arrange { input, .. } => input.count_filters(),
ViewNode::GroupSummarise { input, .. } => input.count_filters(),
ViewNode::StreamingGroupSummarise { input, .. } => input.count_filters(),
ViewNode::Distinct { input, .. } => input.count_filters(),
ViewNode::Join { left, right, .. } => {
left.count_filters() + right.count_filters()
}
ViewNode::Scan { .. } => 0,
}
}
pub fn is_filter_on_scan(&self) -> bool {
match self {
ViewNode::Filter { input, .. } => matches!(input.as_ref(), ViewNode::Scan { .. }),
_ => false,
}
}
pub fn innermost(&self) -> &ViewNode {
match self {
ViewNode::Filter { input, .. }
| ViewNode::Select { input, .. }
| ViewNode::Mutate { input, .. }
| ViewNode::Arrange { input, .. }
| ViewNode::GroupSummarise { input, .. }
| ViewNode::StreamingGroupSummarise { input, .. }
| ViewNode::Distinct { input, .. } => input.innermost(),
ViewNode::Join { left, .. } => left.innermost(),
ViewNode::Scan { .. } => self,
}
}
pub fn kind(&self) -> &'static str {
match self {
ViewNode::Scan { .. } => "Scan",
ViewNode::Filter { .. } => "Filter",
ViewNode::Select { .. } => "Select",
ViewNode::Mutate { .. } => "Mutate",
ViewNode::Arrange { .. } => "Arrange",
ViewNode::GroupSummarise { .. } => "GroupSummarise",
ViewNode::StreamingGroupSummarise { .. } => "StreamingGroupSummarise",
ViewNode::Distinct { .. } => "Distinct",
ViewNode::Join { .. } => "Join",
}
}
pub fn node_kinds(&self) -> Vec<&'static str> {
let mut out = vec![self.kind()];
match self {
ViewNode::Filter { input, .. }
| ViewNode::Select { input, .. }
| ViewNode::Mutate { input, .. }
| ViewNode::Arrange { input, .. }
| ViewNode::GroupSummarise { input, .. }
| ViewNode::StreamingGroupSummarise { input, .. }
| ViewNode::Distinct { input, .. } => {
out.extend(input.node_kinds());
}
ViewNode::Join { left, right, .. } => {
out.extend(left.node_kinds());
out.extend(right.node_kinds());
}
ViewNode::Scan { .. } => {}
}
out
}
}
/// Maximum number of rows per batch produced by `split_batches`.
const BATCH_SIZE: usize = 2048;
/// A horizontal slice of a `DataFrame`: named columns plus their row count.
#[derive(Debug, Clone)]
pub struct Batch {
/// `(name, column)` pairs, in the same order as the frame they came from.
pub columns: Vec<(String, Column)>,
/// Number of rows in this batch. Assumed equal to every column's length.
pub nrows: usize,
}
impl Batch {
fn into_dataframe(self) -> DataFrame {
DataFrame {
columns: self.columns,
}
}
fn get_column(&self, name: &str) -> Option<&Column> {
self.columns.iter().find(|(n, _)| n == name).map(|(_, c)| c)
}
fn column_names(&self) -> Vec<&str> {
self.columns.iter().map(|(n, _)| n.as_str()).collect()
}
}
/// Copies rows `start..end` of `col` into a new column of the same kind.
///
/// `CategoricalAdaptive` columns are first converted with
/// `to_legacy_categorical` and the converted column is sliced. Categorical
/// levels are cloned as they are; only the codes are sliced.
///
/// # Panics
/// Panics if `start..end` is out of bounds for the column.
fn slice_column(col: &Column, start: usize, end: usize) -> Column {
    let rows = start..end;
    match col {
        Column::CategoricalAdaptive(_) => slice_column(&col.to_legacy_categorical(), start, end),
        Column::Float(values) => Column::Float(values[rows].to_vec()),
        Column::Int(values) => Column::Int(values[rows].to_vec()),
        Column::Str(values) => Column::Str(values[rows].to_vec()),
        Column::Bool(values) => Column::Bool(values[rows].to_vec()),
        Column::DateTime(values) => Column::DateTime(values[rows].to_vec()),
        Column::Categorical { levels, codes } => Column::Categorical {
            levels: levels.clone(),
            codes: codes[rows].to_vec(),
        },
    }
}
fn split_batches(df: &DataFrame) -> Vec<Batch> {
let nrows = df.nrows();
if nrows == 0 {
return vec![Batch {
columns: df.columns.iter().map(|(n, c)| {
(n.clone(), slice_column(c, 0, 0))
}).collect(),
nrows: 0,
}];
}
let mut batches = Vec::new();
let mut offset = 0;
while offset < nrows {
let end = (offset + BATCH_SIZE).min(nrows);
let batch_cols = df
.columns
.iter()
.map(|(name, col)| (name.clone(), slice_column(col, offset, end)))
.collect();
batches.push(Batch {
columns: batch_cols,
nrows: end - offset,
});
offset = end;
}
batches
}
/// Concatenates batches that share a schema back into one `DataFrame`.
///
/// The column names and kinds come from the first batch; zero-row batches are
/// skipped. A `CategoricalAdaptive` template column is merged as a legacy
/// categorical that uses the first batch's levels.
///
/// # Errors
/// Returns `TidyError::ColumnNotFound` when a non-empty batch lacks one of the
/// first batch's columns.
fn merge_batches(batches: Vec<Batch>) -> Result<DataFrame, TidyError> {
    /// Empty column of the same kind as `template`, with room for `capacity` rows.
    fn empty_like(template: &Column, capacity: usize) -> Column {
        match template {
            Column::Float(_) => Column::Float(Vec::with_capacity(capacity)),
            Column::Int(_) => Column::Int(Vec::with_capacity(capacity)),
            Column::Str(_) => Column::Str(Vec::with_capacity(capacity)),
            Column::Bool(_) => Column::Bool(Vec::with_capacity(capacity)),
            Column::DateTime(_) => Column::DateTime(Vec::with_capacity(capacity)),
            Column::Categorical { levels, .. } => Column::Categorical {
                levels: levels.clone(),
                codes: Vec::with_capacity(capacity),
            },
            Column::CategoricalAdaptive(_) => match template.to_legacy_categorical() {
                Column::Categorical { levels, .. } => Column::Categorical {
                    levels,
                    codes: Vec::with_capacity(capacity),
                },
                _ => Column::Str(Vec::with_capacity(capacity)),
            },
        }
    }

    let first = match batches.first() {
        Some(batch) => batch,
        None => return Ok(DataFrame::new()),
    };
    let names: Vec<String> = first
        .column_names()
        .into_iter()
        .map(str::to_owned)
        .collect();
    if names.is_empty() {
        return Ok(DataFrame::new());
    }

    let total_rows = batches.iter().fold(0usize, |acc, batch| acc + batch.nrows);
    let mut merged: Vec<(String, Column)> = names
        .into_iter()
        .map(|name| {
            let template = first.get_column(&name).unwrap();
            let column = empty_like(template, total_rows);
            (name, column)
        })
        .collect();

    for batch in batches.iter().filter(|batch| batch.nrows != 0) {
        for (idx, (name, dst)) in merged.iter_mut().enumerate() {
            let src = batch.get_column(name).ok_or_else(|| {
                TidyError::ColumnNotFound(format!(
                    "batch merge: column '{}' missing in batch (index {})",
                    name, idx
                ))
            })?;
            append_column(dst, src);
        }
    }
    Ok(DataFrame { columns: merged })
}
/// Appends the values of `src` to the end of `dst`, in place.
///
/// Only columns of the same kind are appended. For `Categorical`, only the
/// codes are copied, so `src` is assumed to use `dst`'s levels. Any other
/// combination, including a `CategoricalAdaptive` source, is silently ignored,
/// which leaves `dst` shorter than its sibling columns.
/// NOTE(review): consider surfacing a kind mismatch as an error.
fn append_column(dst: &mut Column, src: &Column) {
    match (dst, src) {
        (Column::Float(out), Column::Float(more)) => out.extend_from_slice(more),
        (Column::Int(out), Column::Int(more)) => out.extend_from_slice(more),
        (Column::Str(out), Column::Str(more)) => out.extend_from_slice(more),
        (Column::Bool(out), Column::Bool(more)) => out.extend_from_slice(more),
        (Column::DateTime(out), Column::DateTime(more)) => out.extend_from_slice(more),
        (Column::Categorical { codes: out, .. }, Column::Categorical { codes: more, .. }) => {
            out.extend_from_slice(more)
        }
        // Kind mismatch: deliberately a no-op.
        _ => {}
    }
}
/// A non-breaking plan operation that the batched executor can apply to one
/// batch at a time. It is extracted from a run of `Filter` / `Select` /
/// `Mutate` nodes by `collect_streamable_chain`.
#[derive(Debug, Clone)]
enum StreamableOp {
/// Keep rows matching `predicate`.
Filter { predicate: DExpr },
/// Keep `columns`, in the listed order; names missing from the batch are skipped.
Select { columns: Vec<String> },
/// Add or overwrite columns from `(name, expression)` pairs.
Mutate { assignments: Vec<(String, DExpr)> },
}
/// Returns `true` for nodes that need their whole input at once (sorting,
/// grouping, distinct, joins). Such nodes end a run of batch-at-a-time
/// operators.
fn is_pipeline_breaker(node: &ViewNode) -> bool {
    match node {
        ViewNode::Scan { .. }
        | ViewNode::Filter { .. }
        | ViewNode::Select { .. }
        | ViewNode::Mutate { .. } => false,
        ViewNode::Arrange { .. }
        | ViewNode::GroupSummarise { .. }
        | ViewNode::StreamingGroupSummarise { .. }
        | ViewNode::Distinct { .. }
        | ViewNode::Join { .. } => true,
    }
}
/// Peels the run of `Filter` / `Select` / `Mutate` nodes off the top of `node`.
///
/// Returns the peeled operations in execution order (innermost first),
/// together with the first node that is not one of those three kinds.
fn collect_streamable_chain(node: ViewNode) -> (Vec<StreamableOp>, Box<ViewNode>) {
    let mut ops = Vec::new();
    let mut cursor = node;
    let base = loop {
        let (op, next) = match cursor {
            ViewNode::Filter { input, predicate } => (StreamableOp::Filter { predicate }, input),
            ViewNode::Select { input, columns } => (StreamableOp::Select { columns }, input),
            ViewNode::Mutate { input, assignments } => {
                (StreamableOp::Mutate { assignments }, input)
            }
            other => break other,
        };
        ops.push(op);
        cursor = *next;
    };
    // Ops were collected from the root downwards; execution goes bottom-up.
    ops.reverse();
    (ops, Box::new(base))
}
fn apply_op_to_batch(batch: Batch, op: &StreamableOp) -> Result<Batch, TidyError> {
match op {
StreamableOp::Filter { predicate } => {
let df = batch.into_dataframe();
if df.nrows() == 0 {
return Ok(Batch {
nrows: 0,
columns: df.columns,
});
}
let frame = TidyFrame::from_df(df);
let view = frame.view();
let filtered = view.filter(predicate)?;
let result_df = filtered.materialize()?;
let nrows = result_df.nrows();
Ok(Batch {
columns: result_df.columns,
nrows,
})
}
StreamableOp::Select { columns } => {
let selected: Vec<(String, Column)> = columns
.iter()
.filter_map(|name| {
batch
.columns
.iter()
.find(|(n, _)| n == name)
.cloned()
})
.collect();
Ok(Batch {
nrows: batch.nrows,
columns: selected,
})
}
StreamableOp::Mutate { assignments } => {
let df = batch.into_dataframe();
let frame = TidyFrame::from_df(df);
let view = frame.view();
let assign_refs: Vec<(&str, DExpr)> = assignments
.iter()
.map(|(name, expr)| (leaked_str(name), expr.clone()))
.collect();
let result = view.mutate(
&assign_refs
.iter()
.map(|(n, e)| (*n, e.clone()))
.collect::<Vec<_>>(),
)?;
let result_df = result.borrow().clone();
let nrows = result_df.nrows();
Ok(Batch {
columns: result_df.columns,
nrows,
})
}
}
}
fn apply_chain_batched(
frame: &TidyFrame,
chain: &[StreamableOp],
) -> Result<TidyFrame, TidyError> {
let df = frame.borrow().clone();
let batches = split_batches(&df);
let mut result_batches: Vec<Batch> = Vec::new();
for batch in batches {
let mut current = batch;
for op in chain {
current = apply_op_to_batch(current, op)?;
}
if current.nrows > 0 {
result_batches.push(current);
}
}
if result_batches.is_empty() {
let empty_df = DataFrame {
columns: df
.columns
.iter()
.map(|(name, col)| {
(name.clone(), slice_column(col, 0, 0))
})
.collect(),
};
let mut result_cols: Option<Vec<String>> = None;
for op in chain {
if let StreamableOp::Select { columns } = op {
result_cols = Some(columns.clone());
}
}
if let Some(cols) = result_cols {
let pruned: Vec<(String, Column)> = cols
.iter()
.filter_map(|name| {
empty_df
.columns
.iter()
.find(|(n, _)| n == name)
.cloned()
})
.collect();
return Ok(TidyFrame::from_df(DataFrame { columns: pruned }));
}
return Ok(TidyFrame::from_df(empty_df));
}
let merged = merge_batches(result_batches)?;
Ok(TidyFrame::from_df(merged))
}
pub fn execute_batched(node: ViewNode) -> Result<TidyFrame, TidyError> {
match &node {
ViewNode::Scan { .. } => execute(node),
_ if !is_pipeline_breaker(&node) => {
let (chain, base) = collect_streamable_chain(node);
if chain.is_empty() {
return execute_batched(*base);
}
let base_frame = execute_batched(*base)?;
apply_chain_batched(&base_frame, &chain)
}
_ => execute_breaker_batched(node),
}
}
/// Executes one pipeline-breaking node.
///
/// Its input(s) are evaluated with the batched executor first, then the
/// node's own operation runs on the whole resulting frame. Any other node
/// kind falls back to the eager `execute`. Output names for summaries are
/// passed through `leaked_str` (one leaked string per name per call).
///
/// # Errors
/// Propagates the first `TidyError` from the inputs or from the operation.
fn execute_breaker_batched(node: ViewNode) -> Result<TidyFrame, TidyError> {
    match node {
        ViewNode::Arrange { input, keys } => {
            let upstream = execute_batched(*input)?;
            let sorted = upstream.view().arrange(&keys)?.materialize()?;
            Ok(TidyFrame::from_df(sorted))
        }
        ViewNode::GroupSummarise {
            input,
            group_keys,
            aggregations,
        } => {
            let upstream = execute_batched(*input)?;
            let keys: Vec<&str> = group_keys.iter().map(String::as_str).collect();
            let aggs: Vec<(&str, TidyAgg)> = aggregations
                .into_iter()
                .map(|(name, agg)| (leaked_str(&name), agg))
                .collect();
            let summary = upstream.view().group_by(&keys)?.summarise(&aggs)?;
            Ok(summary)
        }
        ViewNode::StreamingGroupSummarise {
            input,
            group_keys,
            aggregations,
        } => {
            let upstream = execute_batched(*input)?;
            let keys: Vec<&str> = group_keys.iter().map(String::as_str).collect();
            let aggs: Vec<(&str, crate::StreamingAgg)> = aggregations
                .into_iter()
                .map(|(name, agg)| (leaked_str(&name), agg))
                .collect();
            let view = upstream.view();
            view.summarise_streaming(&keys, &aggs)
        }
        ViewNode::Distinct { input, columns } => {
            let upstream = execute_batched(*input)?;
            let wanted: Vec<&str> = columns.iter().map(String::as_str).collect();
            let unique = upstream.view().distinct(&wanted)?.materialize()?;
            Ok(TidyFrame::from_df(unique))
        }
        ViewNode::Join {
            left,
            right,
            on,
            kind,
        } => {
            let lhs = execute_batched(*left)?;
            let rhs = execute_batched(*right)?;
            let (lv, rv) = (lhs.view(), rhs.view());
            let keys: Vec<(&str, &str)> = on
                .iter()
                .map(|(l, r)| (l.as_str(), r.as_str()))
                .collect();
            match kind {
                JoinType::Inner => lv.inner_join(&rv, &keys),
                JoinType::Left => lv.left_join(&rv, &keys),
                JoinType::Semi => {
                    let matched = lv.semi_join(&rv, &keys)?.materialize()?;
                    Ok(TidyFrame::from_df(matched))
                }
                JoinType::Anti => {
                    let unmatched = lv.anti_join(&rv, &keys)?.materialize()?;
                    Ok(TidyFrame::from_df(unmatched))
                }
            }
        }
        non_breaker => execute(non_breaker),
    }
}
impl LazyView {
    /// Optimizes the plan and runs it with the batch-at-a-time executor
    /// (`execute_batched`) instead of the eager one.
    ///
    /// # Errors
    /// Returns the first `TidyError` raised while executing a node.
    pub fn collect_batched(self) -> Result<TidyFrame, TidyError> {
        execute_batched(optimize(self.plan))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Column, DExpr, DBinOp, DataFrame, TidyAgg, ArrangeKey, TidyView};
fn test_df() -> DataFrame {
DataFrame {
columns: vec![
(
"name".to_string(),
Column::Str(vec![
"Alice".into(),
"Bob".into(),
"Carol".into(),
"Dave".into(),
]),
),
("age".to_string(), Column::Int(vec![30, 25, 35, 25])),
(
"score".to_string(),
Column::Float(vec![90.0, 85.0, 95.0, 80.0]),
),
],
}
}
fn dept_df() -> DataFrame {
DataFrame {
columns: vec![
(
"name".to_string(),
Column::Str(vec!["Alice".into(), "Bob".into(), "Eve".into()]),
),
(
"dept".to_string(),
Column::Str(vec!["Eng".into(), "Sales".into(), "Eng".into()]),
),
],
}
}
#[test]
fn lazy_filter_matches_eager() {
let df = test_df();
let predicate = DExpr::BinOp {
op: DBinOp::Gt,
left: Box::new(DExpr::Col("age".into())),
right: Box::new(DExpr::LitInt(25)),
};
let eager_view = TidyView::from_df(df.clone());
let eager_filtered = eager_view.filter(&predicate).unwrap();
let eager_df = eager_filtered.materialize().unwrap();
let lazy_frame = LazyView::from_df(df)
.filter(predicate)
.collect()
.unwrap();
let lazy_df = lazy_frame.borrow();
assert_eq!(eager_df.nrows(), lazy_df.nrows());
assert_eq!(eager_df.nrows(), 2);
let eager_names: Vec<String> = match eager_df.get_column("name").unwrap() {
Column::Str(v) => v.clone(),
_ => panic!("expected Str"),
};
let lazy_names: Vec<String> = match lazy_df.get_column("name").unwrap() {
Column::Str(v) => v.clone(),
_ => panic!("expected Str"),
};
assert_eq!(eager_names, lazy_names);
}
#[test]
fn lazy_select_matches_eager() {
let df = test_df();
let eager_view = TidyView::from_df(df.clone());
let eager_selected = eager_view.select(&["name", "age"]).unwrap();
let eager_df = eager_selected.materialize().unwrap();
let lazy_frame = LazyView::from_df(df)
.select(vec!["name".into(), "age".into()])
.collect()
.unwrap();
let lazy_df = lazy_frame.borrow();
assert_eq!(eager_df.ncols(), 2);
assert_eq!(lazy_df.ncols(), 2);
assert_eq!(eager_df.column_names(), lazy_df.column_names());
}
#[test]
fn lazy_arrange_matches_eager() {
let df = test_df();
let keys = vec![ArrangeKey::asc("age")];
let eager_view = TidyView::from_df(df.clone());
let eager_arranged = eager_view.arrange(&keys).unwrap();
let eager_df = eager_arranged.materialize().unwrap();
let lazy_frame = LazyView::from_df(df)
.arrange(keys)
.collect()
.unwrap();
let lazy_df = lazy_frame.borrow();
let eager_ages = match eager_df.get_column("age").unwrap() {
Column::Int(v) => v.clone(),
_ => panic!("expected Int"),
};
let lazy_ages = match lazy_df.get_column("age").unwrap() {
Column::Int(v) => v.clone(),
_ => panic!("expected Int"),
};
assert_eq!(eager_ages, lazy_ages);
assert_eq!(eager_ages, vec![25, 25, 30, 35]);
}
#[test]
fn lazy_group_summarise_matches_eager() {
let df = test_df();
let eager_view = TidyView::from_df(df.clone());
let grouped = eager_view.group_by(&["age"]).unwrap();
let eager_frame = grouped
.summarise(&[("count", TidyAgg::Count)])
.unwrap();
let eager_df = eager_frame.borrow();
let lazy_frame = LazyView::from_df(df)
.group_summarise(
vec!["age".into()],
vec![("count".into(), TidyAgg::Count)],
)
.collect()
.unwrap();
let lazy_df = lazy_frame.borrow();
assert_eq!(eager_df.nrows(), lazy_df.nrows());
assert_eq!(eager_df.column_names(), lazy_df.column_names());
}
#[test]
fn predicate_pushdown_past_select() {
let df = test_df();
let predicate = DExpr::BinOp {
op: DBinOp::Gt,
left: Box::new(DExpr::Col("age".into())),
right: Box::new(DExpr::LitInt(25)),
};
let lazy = LazyView::from_df(df)
.select(vec!["name".into(), "age".into()])
.filter(predicate);
let optimized = lazy.optimized_plan();
let kinds = optimized.node_kinds();
assert_eq!(kinds, vec!["Select", "Filter", "Scan"]);
}
#[test]
fn predicate_pushdown_past_arrange() {
let df = test_df();
let predicate = DExpr::BinOp {
op: DBinOp::Gt,
left: Box::new(DExpr::Col("age".into())),
right: Box::new(DExpr::LitInt(25)),
};
let lazy = LazyView::from_df(df)
.arrange(vec![ArrangeKey::asc("age")])
.filter(predicate);
let optimized = lazy.optimized_plan();
let kinds = optimized.node_kinds();
assert_eq!(kinds, vec!["Arrange", "Filter", "Scan"]);
}
/// A filter on a column produced by `mutate` (`doubled_age`) cannot move
/// below that `mutate`, so the plan order must stay Filter, Mutate, Scan.
#[test]
fn predicate_not_pushed_past_mutate_when_dependent() {
    let doubled = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(2)),
    };
    let doubled_gt_50 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("doubled_age".into())),
        right: Box::new(DExpr::LitInt(50)),
    };

    let lazy = LazyView::from_df(test_df())
        .mutate(vec![("doubled_age".into(), doubled)])
        .filter(doubled_gt_50);
    let plan = lazy.optimized_plan();

    assert_eq!(plan.node_kinds(), vec!["Filter", "Mutate", "Scan"]);
}
/// A filter on `score`, a column the `mutate` does not touch, may move below
/// the `mutate`, giving Mutate, Filter, Scan.
#[test]
fn predicate_pushed_past_mutate_when_independent() {
    let doubled = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(2)),
    };
    let score_gt_85 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("score".into())),
        right: Box::new(DExpr::LitFloat(85.0)),
    };

    let lazy = LazyView::from_df(test_df())
        .mutate(vec![("doubled_age".into(), doubled)])
        .filter(score_gt_85);
    let plan = lazy.optimized_plan();

    assert_eq!(plan.node_kinds(), vec!["Mutate", "Filter", "Scan"]);
}
/// A filter on an aggregate output (`count`) must stay above the grouping
/// node. The optimizer may choose either the regular or the streaming group
/// node, so both are accepted in the middle position.
#[test]
fn predicate_not_pushed_past_group_summarise() {
    let count_gt_1 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("count".into())),
        right: Box::new(DExpr::LitInt(1)),
    };
    let lazy = LazyView::from_df(test_df())
        .group_summarise(vec!["age".into()], vec![("count".into(), TidyAgg::Count)])
        .filter(count_gt_1);
    let plan = lazy.optimized_plan();
    let kinds = plan.node_kinds();

    let filter_stays_on_top = kinds.len() == 3
        && kinds[0] == "Filter"
        && (kinds[1] == "GroupSummarise" || kinds[1] == "StreamingGroupSummarise")
        && kinds[2] == "Scan";
    assert!(
        filter_stays_on_top,
        "filter must stay above the group node, got {:?}",
        kinds
    );
}
/// Two stacked filters must be fused into one `Filter` node by the optimizer,
/// and executing the stacked filters must still yield the 3 expected rows.
#[test]
fn consecutive_filters_merged() {
    let age_gt_20 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(20)),
    };
    let score_lt_95 = DExpr::BinOp {
        op: DBinOp::Lt,
        left: Box::new(DExpr::Col("score".into())),
        right: Box::new(DExpr::LitFloat(95.0)),
    };

    // Structural check on the optimized plan.
    let stacked = LazyView::from_df(test_df())
        .filter(age_gt_20.clone())
        .filter(score_lt_95.clone());
    let plan = stacked.optimized_plan();
    assert_eq!(plan.count_filters(), 1);

    // Semantic check: executing the same pipeline gives the expected rows.
    let executed = LazyView::from_df(test_df())
        .filter(age_gt_20)
        .filter(score_lt_95)
        .collect()
        .unwrap();
    let frame = executed.borrow();
    assert_eq!(frame.nrows(), 3);
}
/// Selecting every column of the test frame is a no-op. The optimizer must
/// drop the `Select` and leave a bare `Scan`.
#[test]
fn redundant_select_eliminated() {
    let every_column: Vec<String> = vec!["name".into(), "age".into(), "score".into()];
    let lazy = LazyView::from_df(test_df()).select(every_column);
    let plan = lazy.optimized_plan();
    assert_eq!(plan.kind(), "Scan");
}
/// A projection that drops a column is meaningful and must survive optimization.
#[test]
fn non_redundant_select_kept() {
    let subset: Vec<String> = vec!["name".into(), "age".into()];
    let lazy = LazyView::from_df(test_df()).select(subset);
    let plan = lazy.optimized_plan();
    assert_eq!(plan.kind(), "Select");
}
/// Runs the same filter, select, arrange (descending by age) pipeline three
/// times. Every run must return exactly the same ordered rows.
#[test]
fn determinism_3_runs_identical() {
    for _run in 0..3 {
        let age_gt_20 = DExpr::BinOp {
            op: DBinOp::Gt,
            left: Box::new(DExpr::Col("age".into())),
            right: Box::new(DExpr::LitInt(20)),
        };
        let output = LazyView::from_df(test_df())
            .filter(age_gt_20)
            .select(vec!["name".into(), "age".into()])
            .arrange(vec![ArrangeKey::desc("age")])
            .collect()
            .unwrap();
        let frame = output.borrow();
        assert_eq!(frame.nrows(), 4);

        match frame.get_column("age").unwrap() {
            Column::Int(ages) => assert_eq!(ages, &vec![35, 30, 25, 25]),
            _ => panic!("expected Int"),
        }
        match frame.get_column("name").unwrap() {
            Column::Str(names) => assert_eq!(names, &vec!["Carol", "Alice", "Bob", "Dave"]),
            _ => panic!("expected Str"),
        }
    }
}
/// Inner join on `name`. Expects 2 matched rows, and the right side's `dept`
/// column must appear in the output.
#[test]
fn lazy_inner_join() {
    let keys = vec![("name".into(), "name".into())];
    let joined = LazyView::from_df(test_df())
        .join(LazyView::from_df(dept_df()), keys, JoinType::Inner)
        .collect()
        .unwrap();
    let frame = joined.borrow();
    assert_eq!(frame.nrows(), 2);
    assert!(frame.get_column("dept").is_some());
}
/// Semi join on `name` keeps the 2 left rows that have a match. It must not
/// bring in right-side columns such as `dept`.
#[test]
fn lazy_semi_join() {
    let keys = vec![("name".into(), "name".into())];
    let joined = LazyView::from_df(test_df())
        .join(LazyView::from_df(dept_df()), keys, JoinType::Semi)
        .collect()
        .unwrap();
    let frame = joined.borrow();
    assert_eq!(frame.nrows(), 2);
    assert!(frame.get_column("dept").is_none());
}
/// Anti join on `name` keeps the left rows without a match. Expects 2 rows.
#[test]
fn lazy_anti_join() {
    let keys = vec![("name".into(), "name".into())];
    let joined = LazyView::from_df(test_df())
        .join(LazyView::from_df(dept_df()), keys, JoinType::Anti)
        .collect()
        .unwrap();
    let frame = joined.borrow();
    assert_eq!(frame.nrows(), 2);
}
/// `distinct` on `age` collapses duplicate ages. Expects 3 unique values.
#[test]
fn lazy_distinct() {
    let unique_ages = LazyView::from_df(test_df())
        .distinct(vec!["age".into()])
        .collect()
        .unwrap();
    let frame = unique_ages.borrow();
    assert_eq!(frame.nrows(), 3);
}
/// End-to-end lazy pipeline: filter, mutate (derive `bonus`), select, arrange.
/// Expects 4 rows with exactly the columns `name`, `bonus` in that order.
#[test]
fn complex_lazy_chain() {
    let age_gt_20 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitFloat(0.0)).and_then_int_placeholder(),
    };
    unreachable!()
}
/// A predicate on `age`, a column of the left input, must be pushed into the
/// left branch of the join. The join stays at the root and the right branch
/// remains a plain scan.
#[test]
fn predicate_pushdown_into_join_left_side() {
    let age_gt_25 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(25)),
    };
    let lazy = LazyView::from_df(test_df())
        .join(
            LazyView::from_df(dept_df()),
            vec![("name".into(), "name".into())],
            JoinType::Inner,
        )
        .filter(age_gt_25);
    let plan = lazy.optimized_plan();

    assert_eq!(plan.node_kinds()[0], "Join");
    match &plan {
        ViewNode::Join { left, right, .. } => {
            assert_eq!(left.kind(), "Filter");
            assert_eq!(right.kind(), "Scan");
        }
        _ => panic!("expected Join at top"),
    }
}
/// Asserts that two frames have the same row count, the same column names in
/// the same order, and equal contents for every column.
///
/// Columns of `a` are looked up in `b` by name and compared with
/// `assert_col_eq`. `context` prefixes every failure message.
///
/// # Panics
/// Panics on the first mismatch found.
fn assert_df_eq(a: &DataFrame, b: &DataFrame, context: &str) {
    let (rows_a, rows_b) = (a.nrows(), b.nrows());
    assert_eq!(
        rows_a,
        rows_b,
        "{}: nrows differ ({} vs {})",
        context,
        rows_a,
        rows_b
    );

    let (names_a, names_b) = (a.column_names(), b.column_names());
    assert_eq!(names_a, names_b, "{}: column names differ", context);

    for (column_name, left_col) in a.columns.iter() {
        let right_col = match b.get_column(column_name) {
            Some(col) => col,
            None => panic!("{}: column '{}' missing in b", context, column_name),
        };
        let label = format!("{} col '{}'", context, column_name);
        assert_col_eq(left_col, right_col, &label);
    }
}
/// Asserts that two columns have the same variant and element-wise equal contents.
///
/// `Int`, `Str` and `Bool` columns are compared exactly. `Float` columns are
/// compared element by element, and two values count as equal when:
/// - they are exactly equal, which also covers matching infinities (where
///   `x - y` would be NaN), or
/// - both are NaN (NaN never compares equal to itself), or
/// - their absolute difference is below `1e-12`.
///
/// # Panics
/// Panics with `context` in the message on a length, value, or variant mismatch.
fn assert_col_eq(a: &Column, b: &Column, context: &str) {
    match (a, b) {
        (Column::Int(va), Column::Int(vb)) => assert_eq!(va, vb, "{}", context),
        (Column::Float(va), Column::Float(vb)) => {
            assert_eq!(va.len(), vb.len(), "{}: float len", context);
            for (i, (x, y)) in va.iter().zip(vb.iter()).enumerate() {
                // Without the first two checks, identical non-finite values
                // (inf/inf, NaN/NaN) would be reported as mismatches.
                let same = x == y || (x.is_nan() && y.is_nan()) || (x - y).abs() < 1e-12;
                assert!(same, "{}: float[{}] {} != {}", context, i, x, y);
            }
        }
        (Column::Str(va), Column::Str(vb)) => assert_eq!(va, vb, "{}", context),
        (Column::Bool(va), Column::Bool(vb)) => assert_eq!(va, vb, "{}", context),
        _ => panic!("{}: column type mismatch", context),
    }
}
/// `collect_batched` must return exactly what `collect` returns for a single filter.
#[test]
fn batched_filter_parity() {
    let age_gt_25 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(25)),
    };
    let build = || LazyView::from_df(test_df()).filter(age_gt_25.clone());

    let eager = build().collect().unwrap();
    let batched = build().collect_batched().unwrap();
    assert_df_eq(&eager.borrow(), &batched.borrow(), "filter parity");
}
/// `collect_batched` must return exactly what `collect` returns for a projection.
#[test]
fn batched_select_parity() {
    let build = || LazyView::from_df(test_df()).select(vec!["name".into(), "score".into()]);

    let eager = build().collect().unwrap();
    let batched = build().collect_batched().unwrap();
    assert_df_eq(&eager.borrow(), &batched.borrow(), "select parity");
}
/// `collect_batched` must return exactly what `collect` returns for a `mutate`
/// that derives `doubled = age * 2`.
#[test]
fn batched_mutate_parity() {
    let age_times_two = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(2)),
    };
    let build = || LazyView::from_df(test_df()).mutate(vec![("doubled".into(), age_times_two.clone())]);

    let eager = build().collect().unwrap();
    let batched = build().collect_batched().unwrap();
    assert_df_eq(&eager.borrow(), &batched.borrow(), "mutate parity");
}
/// Batched and eager execution must agree on a filter, mutate, select chain.
#[test]
fn batched_filter_select_mutate_chain_parity() {
    let age_gt_20 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(20)),
    };
    let bonus = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("score".into())),
        right: Box::new(DExpr::LitFloat(1.1)),
    };
    let build = || {
        LazyView::from_df(test_df())
            .filter(age_gt_20.clone())
            .mutate(vec![("bonus".into(), bonus.clone())])
            .select(vec!["name".into(), "bonus".into()])
    };

    let eager = build().collect().unwrap();
    let batched = build().collect_batched().unwrap();
    assert_df_eq(
        &eager.borrow(),
        &batched.borrow(),
        "filter+mutate+select chain parity",
    );
}
/// Batched and eager execution must agree on a grouped count by `age`.
#[test]
fn batched_group_summarise_parity() {
    let build = || {
        LazyView::from_df(test_df())
            .group_summarise(vec!["age".into()], vec![("count".into(), TidyAgg::Count)])
    };

    let eager = build().collect().unwrap();
    let batched = build().collect_batched().unwrap();
    assert_df_eq(
        &eager.borrow(),
        &batched.borrow(),
        "group_summarise parity",
    );
}
/// Batched and eager execution must agree on an ascending sort by `age`.
#[test]
fn batched_arrange_parity() {
    let build = || LazyView::from_df(test_df()).arrange(vec![ArrangeKey::asc("age")]);

    let eager = build().collect().unwrap();
    let batched = build().collect_batched().unwrap();
    assert_df_eq(&eager.borrow(), &batched.borrow(), "arrange parity");
}
/// Batched and eager execution must agree on `distinct` over `age`.
#[test]
fn batched_distinct_parity() {
    let build = || LazyView::from_df(test_df()).distinct(vec!["age".into()]);

    let eager = build().collect().unwrap();
    let batched = build().collect_batched().unwrap();
    assert_df_eq(&eager.borrow(), &batched.borrow(), "distinct parity");
}
/// Batched and eager execution must agree on an inner join on `name`.
#[test]
fn batched_join_parity() {
    let build = || {
        LazyView::from_df(test_df()).join(
            LazyView::from_df(dept_df()),
            vec![("name".into(), "name".into())],
            JoinType::Inner,
        )
    };

    let eager = build().collect().unwrap();
    let batched = build().collect_batched().unwrap();
    assert_df_eq(&eager.borrow(), &batched.borrow(), "join parity");
}
/// Batched and eager execution must agree on a four-stage pipeline:
/// filter, mutate, select, arrange (descending `bonus`).
#[test]
fn batched_complex_pipeline_parity() {
    let age_gt_20 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(20)),
    };
    let bonus = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("score".into())),
        right: Box::new(DExpr::LitFloat(1.1)),
    };
    let build = || {
        LazyView::from_df(test_df())
            .filter(age_gt_20.clone())
            .mutate(vec![("bonus".into(), bonus.clone())])
            .select(vec!["name".into(), "bonus".into()])
            .arrange(vec![ArrangeKey::desc("bonus")])
    };

    let eager = build().collect().unwrap();
    let batched = build().collect_batched().unwrap();
    assert_df_eq(
        &eager.borrow(),
        &batched.borrow(),
        "complex pipeline parity",
    );
}
/// Three batched runs of filter, select, arrange must produce identical
/// `age`/`name` columns, and those must equal the known expected ordering.
#[test]
fn batched_determinism_3_runs() {
    let runs: Vec<(Vec<i64>, Vec<String>)> = (0..3)
        .map(|_| {
            let output = LazyView::from_df(test_df())
                .filter(DExpr::BinOp {
                    op: DBinOp::Gt,
                    left: Box::new(DExpr::Col("age".into())),
                    right: Box::new(DExpr::LitInt(20)),
                })
                .select(vec!["name".into(), "age".into()])
                .arrange(vec![ArrangeKey::desc("age")])
                .collect_batched()
                .unwrap();
            let frame = output.borrow();
            let ages = if let Column::Int(v) = frame.get_column("age").unwrap() {
                v.clone()
            } else {
                panic!("expected Int")
            };
            let names = if let Column::Str(v) = frame.get_column("name").unwrap() {
                v.clone()
            } else {
                panic!("expected Str")
            };
            (ages, names)
        })
        .collect();

    assert_eq!(runs[0].0, runs[1].0);
    assert_eq!(runs[1].0, runs[2].0);
    assert_eq!(runs[0].1, runs[1].1);
    assert_eq!(runs[1].1, runs[2].1);
    assert_eq!(runs[0].0, vec![35, 30, 25, 25]);
    assert_eq!(runs[0].1, vec!["Carol", "Alice", "Bob", "Dave"]);
}
/// Builds a deterministic 10 000-row fixture with three columns:
/// - `name` (Str): `user_<i>`
/// - `age` (Int): cycles through 18..=97 (`i % 80 + 18`)
/// - `score` (Float): cycles through 50.0..=99.0 (`50 + i % 50`)
fn large_df() -> DataFrame {
    const ROWS: usize = 10_000;
    let mut names = Vec::with_capacity(ROWS);
    let mut ages = Vec::with_capacity(ROWS);
    let mut scores = Vec::with_capacity(ROWS);
    for row in 0..ROWS {
        names.push(format!("user_{}", row));
        ages.push((row % 80) as i64 + 18);
        scores.push(50.0 + (row % 50) as f64);
    }
    DataFrame {
        columns: vec![
            (String::from("name"), Column::Str(names)),
            (String::from("age"), Column::Int(ages)),
            (String::from("score"), Column::Float(scores)),
        ],
    }
}
/// On the 10 000-row fixture (multiple batches), a batched filter must match
/// the eager result, and the filter must keep at least one row.
#[test]
fn batched_large_data_filter_parity() {
    let age_gt_50 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(50)),
    };
    let build = || LazyView::from_df(large_df()).filter(age_gt_50.clone());

    let eager = build().collect().unwrap();
    let batched = build().collect_batched().unwrap();
    assert_df_eq(
        &eager.borrow(),
        &batched.borrow(),
        "large data filter parity",
    );
    assert!(eager.borrow().nrows() > 0);
}
/// On the 10 000-row fixture, batched execution of filter, mutate, select
/// must match eager execution.
#[test]
fn batched_large_data_chain_parity() {
    let age_gt_50 = DExpr::BinOp {
        op: DBinOp::Gt,
        left: Box::new(DExpr::Col("age".into())),
        right: Box::new(DExpr::LitInt(50)),
    };
    let bonus = DExpr::BinOp {
        op: DBinOp::Mul,
        left: Box::new(DExpr::Col("score".into())),
        right: Box::new(DExpr::LitFloat(1.5)),
    };
    let build = || {
        LazyView::from_df(large_df())
            .filter(age_gt_50.clone())
            .mutate(vec![("bonus".into(), bonus.clone())])
            .select(vec!["name".into(), "bonus".into()])
    };

    let eager = build().collect().unwrap();
    let batched = build().collect_batched().unwrap();
    assert_df_eq(
        &eager.borrow(),
        &batched.borrow(),
        "large data chain parity",
    );
}
/// Three batched runs of filter(age > 90) plus a `double_age` mutate over the
/// 10 000-row fixture must yield identical `age` columns.
///
/// Each run also checks the result itself. Without these checks, an empty or
/// unfiltered result would make the cross-run comparison pass vacuously.
/// `large_df` ages are `i % 80 + 18`, so ages 91..=97 (7 of every 80 rows)
/// pass the filter: 7 * 10 000 / 80 = 875 rows.
#[test]
fn batched_large_data_determinism() {
    let mut prev_ages: Option<Vec<i64>> = None;
    for _ in 0..3 {
        let result = LazyView::from_df(large_df())
            .filter(DExpr::BinOp {
                op: DBinOp::Gt,
                left: Box::new(DExpr::Col("age".into())),
                right: Box::new(DExpr::LitInt(90)),
            })
            .mutate(vec![(
                "double_age".into(),
                DExpr::BinOp {
                    op: DBinOp::Mul,
                    left: Box::new(DExpr::Col("age".into())),
                    right: Box::new(DExpr::LitInt(2)),
                },
            )])
            .collect_batched()
            .unwrap();
        let df = result.borrow();
        let ages = match df.get_column("age").unwrap() {
            Column::Int(v) => v.clone(),
            _ => panic!("expected Int"),
        };
        assert_eq!(ages.len(), 875, "determinism: unexpected filtered row count");
        assert!(
            ages.iter().all(|&a| a > 90),
            "determinism: filter let through age <= 90"
        );
        assert!(
            df.get_column("double_age").is_some(),
            "determinism: mutated column missing"
        );
        if let Some(ref prev) = prev_ages {
            assert_eq!(prev, &ages, "determinism: ages differ across runs");
        }
        prev_ages = Some(ages);
    }
}
/// `split_batches` cuts the 10 000-row fixture into four full batches of 2048
/// rows plus a remainder batch, and loses no rows.
#[test]
fn split_batches_correct_count() {
    let frame = large_df();
    let batches = split_batches(&frame);
    assert_eq!(batches.len(), 5);
    for full in &batches[..4] {
        assert_eq!(full.nrows, 2048);
    }
    assert_eq!(batches[4].nrows, 10000 - 4 * 2048);
    let total: usize = batches.iter().map(|b| b.nrows).sum();
    assert_eq!(total, 10000);
}
/// A frame smaller than one batch is returned as a single batch holding all 4 rows.
#[test]
fn split_batches_small_df() {
    let frame = test_df();
    let batches = split_batches(&frame);
    assert_eq!(batches.len(), 1);
    assert_eq!(batches[0].nrows, 4);
}
/// Splitting the large fixture into batches and merging them back must
/// reproduce the original frame exactly.
#[test]
fn merge_batches_roundtrip() {
    let original = large_df();
    let merged = merge_batches(split_batches(&original)).unwrap();
    assert_df_eq(&original, &merged, "merge roundtrip");
}
}