use polars::prelude::{Expr, LazyFrame};
use std::collections::HashSet;
/// A small logical query plan layered on top of a Polars [`LazyFrame`].
///
/// The plan is a tree: `Base` is the leaf, and `Project` / `Filter` each wrap
/// exactly one child plan held in `input`.
#[derive(Clone)]
pub enum LogicalPlan {
/// Leaf node holding the underlying lazy frame.
Base(LazyFrame),
/// Column selection applied on top of `input`.
Project {
/// Expressions handed to `LazyFrame::select` when the plan is lowered.
exprs: Vec<Expr>,
/// Child plan whose output is projected.
input: Box<LogicalPlan>,
},
/// Row filter applied on top of `input`.
Filter {
/// Expression handed to `LazyFrame::filter` when the plan is lowered.
predicate: Expr,
/// Child plan whose output is filtered.
input: Box<LogicalPlan>,
},
}
impl LogicalPlan {
    /// Wraps an existing [`LazyFrame`] as the leaf of a new plan.
    pub fn from_lazy(lf: LazyFrame) -> Self {
        Self::Base(lf)
    }

    /// Lowers the plan tree into a single [`LazyFrame`].
    ///
    /// The child is lowered first; `Project` then applies `select` and
    /// `Filter` applies `filter` on the resulting frame.
    pub fn to_lazy(&self) -> LazyFrame {
        match self {
            Self::Base(frame) => frame.clone(),
            Self::Project { exprs, input } => input.to_lazy().select(exprs),
            Self::Filter { predicate, input } => input.to_lazy().filter(predicate.clone()),
        }
    }

    /// Returns an optimized copy of the plan; `self` is left untouched.
    ///
    /// Children are optimized first. The only rewrite performed here is on a
    /// `Filter` whose optimized input is a `Project`: the pair is offered to
    /// `try_rewrite_filter_over_project`, and its result is returned when it
    /// produces one. Every other shape is rebuilt as-is around the optimized
    /// child.
    pub fn optimize(&self) -> LogicalPlan {
        match self {
            Self::Base(frame) => Self::Base(frame.clone()),
            Self::Project { exprs, input } => Self::Project {
                exprs: exprs.clone(),
                input: Box::new(input.optimize()),
            },
            Self::Filter { predicate, input } => {
                let rebuilt_input = match input.optimize() {
                    Self::Project {
                        exprs: kept,
                        input: below,
                    } => {
                        if let Some(rewritten) =
                            try_rewrite_filter_over_project(predicate, &kept, &below)
                        {
                            return rewritten;
                        }
                        // No rewrite applied: put the projection back together unchanged.
                        Self::Project {
                            exprs: kept,
                            input: below,
                        }
                    }
                    other => other,
                };
                Self::Filter {
                    predicate: predicate.clone(),
                    input: Box::new(rebuilt_input),
                }
            }
        }
    }
}
/// Tries to move a filter beneath the projection it currently sits on.
///
/// The rewrite turns `Filter(Project(child))` into `Project(Filter(child))`.
/// It applies only when both of the following hold:
///
/// * every projection expression is a bare `Expr::Column`, so the projection
///   neither renames nor computes anything and the predicate means the same
///   thing below it;
/// * the predicate references at least one column name that the projection
///   does not keep.
///
/// When `project_child` is itself a projection that satisfies the same two
/// conditions, the filter keeps moving down through it, so a stack of plain
/// column selections does not leave the filter stranded above a level that
/// has already dropped one of its columns.
///
/// # Parameters
/// * `predicate` - the filter expression sitting above the projection.
/// * `project_exprs` - the projection's expression list.
/// * `project_child` - the plan the projection reads from.
///
/// # Returns
/// `Some(rewritten_plan)` when the filter was pushed down, `None` when the
/// plan should be left as it is (non-column projection expressions, a
/// predicate with no column references, or a predicate that only uses
/// columns the projection keeps).
fn try_rewrite_filter_over_project(
    predicate: &Expr,
    project_exprs: &[Expr],
    project_child: &LogicalPlan,
) -> Option<LogicalPlan> {
    // Borrow the projected names; bail out on the first non-column expression.
    let projected: HashSet<&str> = project_exprs
        .iter()
        .map(|e| match e {
            Expr::Column(name) => Some(name.as_str()),
            _ => None,
        })
        .collect::<Option<_>>()?;

    // `any` is false for an empty set, so a predicate without column
    // references is rejected here as well.
    let references_dropped = expr_referenced_columns(predicate)
        .iter()
        .any(|c| !projected.contains(c.as_str()));
    if !references_dropped {
        return None;
    }

    // Keep pushing through directly stacked plain-column projections;
    // otherwise the filter lands immediately above `project_child`.
    let pushed_down = match project_child {
        LogicalPlan::Project { exprs, input } => {
            try_rewrite_filter_over_project(predicate, exprs, input)
        }
        LogicalPlan::Base(_) | LogicalPlan::Filter { .. } => None,
    }
    .unwrap_or_else(|| LogicalPlan::Filter {
        predicate: predicate.clone(),
        input: Box::new(project_child.clone()),
    });

    Some(LogicalPlan::Project {
        exprs: project_exprs.to_vec(),
        input: Box::new(pushed_down),
    })
}
/// Gathers the names of the `Expr::Column` nodes found inside `expr`.
///
/// The expression is walked with `try_map_expr` using a visitor that records
/// column names and hands every node back unchanged.
/// NOTE(review): this relies on `try_map_expr` invoking the closure for each
/// nested sub-expression — confirm against the Polars version in use.
fn expr_referenced_columns(expr: &Expr) -> HashSet<String> {
    let mut found: HashSet<String> = HashSet::new();
    // `try_map_expr` consumes its receiver, so the walk runs on a clone.
    // The visitor never fails and the rebuilt expression is not needed,
    // hence the result is discarded.
    let _ = expr.clone().try_map_expr(|node| {
        if let Expr::Column(col_name) = &node {
            found.insert(col_name.as_str().to_owned());
        }
        Ok(node)
    });
    found
}