use std::fmt;
use std::sync::Arc;
use crate::optimized::dataframe::OptimizedDataFrame;
use crate::optimized::operations::JoinType;
/// An aggregation function applied to an inner expression.
#[derive(Debug, Clone, PartialEq)]
pub enum AggExpr {
    /// Rendered as `SUM(expr)`.
    Sum(Box<Expr>),
    /// Rendered as `MEAN(expr)`.
    Mean(Box<Expr>),
    /// Rendered as `MIN(expr)`.
    Min(Box<Expr>),
    /// Rendered as `MAX(expr)`.
    Max(Box<Expr>),
    /// Rendered as `COUNT(expr)`.
    Count(Box<Expr>),
    /// Rendered as `STDDEV(expr)`.
    StdDev(Box<Expr>),
}

impl fmt::Display for AggExpr {
    /// Writes `FUNCTION(argument)`, using the upper-case function name of the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (function, argument) = match self {
            AggExpr::Sum(arg) => ("SUM", arg),
            AggExpr::Mean(arg) => ("MEAN", arg),
            AggExpr::Min(arg) => ("MIN", arg),
            AggExpr::Max(arg) => ("MAX", arg),
            AggExpr::Count(arg) => ("COUNT", arg),
            AggExpr::StdDev(arg) => ("STDDEV", arg),
        };
        write!(f, "{}({})", function, argument)
    }
}
/// Infix operator joining a left and a right operand expression.
#[derive(Debug, Clone, PartialEq)]
pub enum BinaryOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `==`
    Eq,
    /// `!=`
    NotEq,
    /// `<`
    Lt,
    /// `<=`
    LtEq,
    /// `>`
    Gt,
    /// `>=`
    GtEq,
    /// `AND`
    And,
    /// `OR`
    Or,
}

impl fmt::Display for BinaryOp {
    /// Writes the operator's textual symbol (width/fill flags are ignored).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match self {
            // arithmetic
            BinaryOp::Add => "+",
            BinaryOp::Sub => "-",
            BinaryOp::Mul => "*",
            BinaryOp::Div => "/",
            // comparison
            BinaryOp::Eq => "==",
            BinaryOp::NotEq => "!=",
            BinaryOp::Lt => "<",
            BinaryOp::LtEq => "<=",
            BinaryOp::Gt => ">",
            BinaryOp::GtEq => ">=",
            // logical
            BinaryOp::And => "AND",
            BinaryOp::Or => "OR",
        };
        f.write_str(symbol)
    }
}
/// Prefix operator applied to a single operand expression.
#[derive(Debug, Clone, PartialEq)]
pub enum UnaryOp {
    /// Rendered as `NOT`.
    Not,
    /// Negation, rendered as `-`.
    Neg,
}

impl fmt::Display for UnaryOp {
    /// Writes the operator token (`NOT` or `-`); width/fill flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            UnaryOp::Not => "NOT",
            UnaryOp::Neg => "-",
        })
    }
}
/// A constant value that can appear in an expression.
#[derive(Debug, Clone, PartialEq)]
pub enum LiteralValue {
    /// 64-bit signed integer.
    Int64(i64),
    /// 64-bit floating-point number.
    Float64(f64),
    /// UTF-8 string.
    Utf8(String),
    /// Boolean.
    Boolean(bool),
    /// Null value; rendered as `NULL`.
    Null,
}

impl fmt::Display for LiteralValue {
    /// Renders the literal in the form used inside `lit(...)`.
    ///
    /// * `Float64` is written with `{:?}` so integral values keep their decimal point
    ///   (`1.0`, not `1`) and cannot be confused with an `Int64` of the same value.
    /// * `Utf8` is written with `{:?}`, which wraps the text in double quotes and escapes
    ///   embedded `"` and `\`, so the quoted form is unambiguous.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LiteralValue::Int64(v) => write!(f, "{}", v),
            LiteralValue::Float64(v) => write!(f, "{:?}", v),
            LiteralValue::Utf8(v) => write!(f, "{:?}", v),
            LiteralValue::Boolean(v) => write!(f, "{}", v),
            LiteralValue::Null => write!(f, "NULL"),
        }
    }
}
/// A node in an expression tree. Trees are normally built with the constructors and
/// combinators in `impl Expr` (`Expr::col`, `Expr::lit_int`, `.gt(..)`, `.sum()`, ...).
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
/// Reference to a column by name; displayed as `col("name")`.
Column(String),
/// A constant value; displayed as `lit(value)`.
Literal(LiteralValue),
/// `left op right`; displayed parenthesised as `(left op right)`.
BinaryOp {
left: Box<Expr>,
op: BinaryOp,
right: Box<Expr>,
},
/// `op` applied to a single operand `expr`; displayed as `(op expr)`.
UnaryOp { op: UnaryOp, expr: Box<Expr> },
/// An aggregation (`SUM`, `MEAN`, `MIN`, `MAX`, `COUNT`, `STDDEV`) over an inner expression.
Agg(Box<AggExpr>),
/// Cast of `expr` to `data_type`; displayed as `CAST(expr AS type)`.
Cast {
expr: Box<Expr>,
data_type: CastType,
},
/// Null test on the operand; displayed as `IS_NULL(expr)`.
IsNull(Box<Expr>),
/// Non-null test on the operand; displayed as `IS_NOT_NULL(expr)`.
IsNotNull(Box<Expr>),
/// Conditional expression; displayed as `IF(condition, then_expr, else_expr)`.
If {
condition: Box<Expr>,
then_expr: Box<Expr>,
else_expr: Box<Expr>,
},
/// Gives `expr` the output name `name` (this is what `Expr::output_name` reports).
Alias { expr: Box<Expr>, name: String },
/// Displayed as `*`; presumably means "all columns" — confirm against the evaluator.
Wildcard,
}
/// Target type of an `Expr::Cast`.
#[derive(Debug, Clone, PartialEq)]
pub enum CastType {
    /// 64-bit signed integer.
    Int64,
    /// 64-bit floating-point number.
    Float64,
    /// UTF-8 string.
    Utf8,
    /// Boolean.
    Boolean,
}

impl fmt::Display for CastType {
    /// Writes the type's name exactly as spelled by its variant (e.g. `Float64`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_name = match self {
            CastType::Int64 => "Int64",
            CastType::Float64 => "Float64",
            CastType::Utf8 => "Utf8",
            CastType::Boolean => "Boolean",
        };
        f.write_str(type_name)
    }
}
impl fmt::Display for Expr {
    /// Renders the expression in the builder-call style of the constructors, e.g.
    /// `(col("a") > lit(1)).alias("flag")`.
    ///
    /// Column and alias names are written with `{:?}`, so a `"` or `\` inside a name is
    /// escaped instead of producing unbalanced quotes. Names without such characters
    /// render exactly as before (`col("a")`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Expr::Column(name) => write!(f, "col({:?})", name),
            Expr::Literal(lit) => write!(f, "lit({})", lit),
            Expr::BinaryOp { left, op, right } => write!(f, "({} {} {})", left, op, right),
            Expr::UnaryOp { op, expr } => write!(f, "({} {})", op, expr),
            Expr::Agg(agg) => write!(f, "{}", agg),
            Expr::Cast { expr, data_type } => write!(f, "CAST({} AS {})", expr, data_type),
            Expr::IsNull(expr) => write!(f, "IS_NULL({})", expr),
            Expr::IsNotNull(expr) => write!(f, "IS_NOT_NULL({})", expr),
            Expr::If {
                condition,
                then_expr,
                else_expr,
            } => write!(f, "IF({}, {}, {})", condition, then_expr, else_expr),
            Expr::Alias { expr, name } => write!(f, "{}.alias({:?})", expr, name),
            Expr::Wildcard => write!(f, "*"),
        }
    }
}
impl Expr {
pub fn col(name: impl Into<String>) -> Self {
Expr::Column(name.into())
}
pub fn lit_int(v: i64) -> Self {
Expr::Literal(LiteralValue::Int64(v))
}
pub fn lit_float(v: f64) -> Self {
Expr::Literal(LiteralValue::Float64(v))
}
pub fn lit_str(v: impl Into<String>) -> Self {
Expr::Literal(LiteralValue::Utf8(v.into()))
}
pub fn lit_bool(v: bool) -> Self {
Expr::Literal(LiteralValue::Boolean(v))
}
pub fn alias(self, name: impl Into<String>) -> Self {
Expr::Alias {
expr: Box::new(self),
name: name.into(),
}
}
pub fn binary_op(self, op: BinaryOp, other: Expr) -> Self {
Expr::BinaryOp {
left: Box::new(self),
op,
right: Box::new(other),
}
}
pub fn eq(self, other: Expr) -> Self {
self.binary_op(BinaryOp::Eq, other)
}
pub fn neq(self, other: Expr) -> Self {
self.binary_op(BinaryOp::NotEq, other)
}
pub fn gt(self, other: Expr) -> Self {
self.binary_op(BinaryOp::Gt, other)
}
pub fn gt_eq(self, other: Expr) -> Self {
self.binary_op(BinaryOp::GtEq, other)
}
pub fn lt(self, other: Expr) -> Self {
self.binary_op(BinaryOp::Lt, other)
}
pub fn lt_eq(self, other: Expr) -> Self {
self.binary_op(BinaryOp::LtEq, other)
}
pub fn and(self, other: Expr) -> Self {
self.binary_op(BinaryOp::And, other)
}
pub fn or(self, other: Expr) -> Self {
self.binary_op(BinaryOp::Or, other)
}
pub fn is_null(self) -> Self {
Expr::IsNull(Box::new(self))
}
pub fn is_not_null(self) -> Self {
Expr::IsNotNull(Box::new(self))
}
pub fn sum(self) -> Self {
Expr::Agg(Box::new(AggExpr::Sum(Box::new(self))))
}
pub fn mean(self) -> Self {
Expr::Agg(Box::new(AggExpr::Mean(Box::new(self))))
}
pub fn min(self) -> Self {
Expr::Agg(Box::new(AggExpr::Min(Box::new(self))))
}
pub fn max(self) -> Self {
Expr::Agg(Box::new(AggExpr::Max(Box::new(self))))
}
pub fn count(self) -> Self {
Expr::Agg(Box::new(AggExpr::Count(Box::new(self))))
}
pub fn std_dev(self) -> Self {
Expr::Agg(Box::new(AggExpr::StdDev(Box::new(self))))
}
pub fn referenced_columns(&self) -> Vec<String> {
let mut cols = Vec::new();
self.collect_columns(&mut cols);
cols
}
fn collect_columns(&self, cols: &mut Vec<String>) {
match self {
Expr::Column(name) => cols.push(name.clone()),
Expr::BinaryOp { left, right, .. } => {
left.collect_columns(cols);
right.collect_columns(cols);
}
Expr::UnaryOp { expr, .. } => expr.collect_columns(cols),
Expr::Agg(agg) => match agg.as_ref() {
AggExpr::Sum(e)
| AggExpr::Mean(e)
| AggExpr::Min(e)
| AggExpr::Max(e)
| AggExpr::Count(e)
| AggExpr::StdDev(e) => e.collect_columns(cols),
},
Expr::Cast { expr, .. } => expr.collect_columns(cols),
Expr::IsNull(expr) | Expr::IsNotNull(expr) => expr.collect_columns(cols),
Expr::If {
condition,
then_expr,
else_expr,
} => {
condition.collect_columns(cols);
then_expr.collect_columns(cols);
else_expr.collect_columns(cols);
}
Expr::Alias { expr, .. } => expr.collect_columns(cols),
Expr::Literal(_) | Expr::Wildcard => {}
}
}
pub fn output_name(&self) -> Option<String> {
match self {
Expr::Column(name) => Some(name.clone()),
Expr::Alias { name, .. } => Some(name.clone()),
_ => None,
}
}
pub fn is_constant(&self) -> bool {
match self {
Expr::Literal(_) => true,
Expr::BinaryOp { left, right, .. } => left.is_constant() && right.is_constant(),
Expr::UnaryOp { expr, .. } => expr.is_constant(),
Expr::Cast { expr, .. } => expr.is_constant(),
Expr::If {
condition,
then_expr,
else_expr,
} => condition.is_constant() && then_expr.is_constant() && else_expr.is_constant(),
_ => false,
}
}
}
/// A node of a logical query plan. Single-input nodes own their input as a
/// `Box<LogicalPlan>`; `Join` and `Union` have two inputs and `Scan` has none.
#[derive(Debug, Clone)]
pub enum LogicalPlan {
/// Leaf node reading from an `OptimizedDataFrame` shared via `Arc`.
/// `projection` is an optional list of column names; `None` is displayed as `*`
/// (presumably "all columns" — confirm against the executor).
Scan {
source: Arc<OptimizedDataFrame>,
projection: Option<Vec<String>>,
},
/// Applies `predicate` to `input` (presumably keeping rows where it evaluates true).
Filter {
predicate: Expr,
input: Box<LogicalPlan>,
},
/// Evaluates `exprs` over the rows of `input`.
Project {
exprs: Vec<Expr>,
input: Box<LogicalPlan>,
},
/// Aggregation over `input` with grouping expressions `keys` and aggregate
/// expressions `aggs`.
Aggregate {
keys: Vec<Expr>,
aggs: Vec<Expr>,
input: Box<LogicalPlan>,
},
/// Joins `left` and `right` where `left_on = right_on`, using `join_type`.
Join {
left: Box<LogicalPlan>,
right: Box<LogicalPlan>,
left_on: Expr,
right_on: Expr,
join_type: JoinType,
},
/// Sorts `input` by the expressions in `by`; `ascending[i]` is the direction for `by[i]`.
/// NOTE(review): the two vectors are expected to have the same length, but nothing
/// here enforces it.
Sort {
by: Vec<Expr>,
ascending: Vec<bool>,
input: Box<LogicalPlan>,
},
/// Restricts `input` to `n` rows (presumably the first `n`).
Limit {
n: usize,
input: Box<LogicalPlan>,
},
/// Combines the rows of `left` and `right`.
Union {
left: Box<LogicalPlan>,
right: Box<LogicalPlan>,
},
}
impl LogicalPlan {
pub fn display(&self) -> String {
self.display_indent(0)
}
fn display_indent(&self, indent: usize) -> String {
let pad = " ".repeat(indent);
match self {
LogicalPlan::Scan { source, projection } => {
let proj = match projection {
Some(cols) => format!("[{}]", cols.join(", ")),
None => "*".to_string(),
};
format!(
"{}Scan: {} rows, projection={}\n",
pad,
source.row_count(),
proj
)
}
LogicalPlan::Filter { predicate, input } => {
let mut s = format!("{}Filter: {}\n", pad, predicate);
s.push_str(&input.display_indent(indent + 1));
s
}
LogicalPlan::Project { exprs, input } => {
let expr_str = exprs
.iter()
.map(|e| e.to_string())
.collect::<Vec<_>>()
.join(", ");
let mut s = format!("{}Project: {}\n", pad, expr_str);
s.push_str(&input.display_indent(indent + 1));
s
}
LogicalPlan::Aggregate { keys, aggs, input } => {
let key_str = keys
.iter()
.map(|e| e.to_string())
.collect::<Vec<_>>()
.join(", ");
let agg_str = aggs
.iter()
.map(|e| e.to_string())
.collect::<Vec<_>>()
.join(", ");
let mut s = format!("{}Aggregate: keys=[{}], aggs=[{}]\n", pad, key_str, agg_str);
s.push_str(&input.display_indent(indent + 1));
s
}
LogicalPlan::Join {
left,
right,
left_on,
right_on,
join_type,
} => {
let mut s = format!(
"{}Join ({:?}): {} = {}\n",
pad, join_type, left_on, right_on
);
s.push_str(&left.display_indent(indent + 1));
s.push_str(&right.display_indent(indent + 1));
s
}
LogicalPlan::Sort {
by,
ascending,
input,
} => {
let sort_str = by
.iter()
.zip(ascending.iter())
.map(|(e, asc)| format!("{} {}", e, if *asc { "ASC" } else { "DESC" }))
.collect::<Vec<_>>()
.join(", ");
let mut s = format!("{}Sort: {}\n", pad, sort_str);
s.push_str(&input.display_indent(indent + 1));
s
}
LogicalPlan::Limit { n, input } => {
let mut s = format!("{}Limit: {}\n", pad, n);
s.push_str(&input.display_indent(indent + 1));
s
}
LogicalPlan::Union { left, right } => {
let mut s = format!("{}Union:\n", pad);
s.push_str(&left.display_indent(indent + 1));
s.push_str(&right.display_indent(indent + 1));
s
}
}
}
pub fn referenced_columns_shallow(&self) -> Vec<String> {
match self {
LogicalPlan::Filter { predicate, .. } => predicate.referenced_columns(),
LogicalPlan::Project { exprs, .. } => {
exprs.iter().flat_map(|e| e.referenced_columns()).collect()
}
LogicalPlan::Aggregate { keys, aggs, .. } => keys
.iter()
.chain(aggs.iter())
.flat_map(|e| e.referenced_columns())
.collect(),
LogicalPlan::Sort { by, .. } => {
by.iter().flat_map(|e| e.referenced_columns()).collect()
}
LogicalPlan::Join {
left_on, right_on, ..
} => {
let mut cols = left_on.referenced_columns();
cols.extend(right_on.referenced_columns());
cols
}
LogicalPlan::Scan { projection, .. } => projection.clone().unwrap_or_default(),
LogicalPlan::Limit { .. } | LogicalPlan::Union { .. } => vec![],
}
}
pub fn input(&self) -> Option<&LogicalPlan> {
match self {
LogicalPlan::Filter { input, .. }
| LogicalPlan::Project { input, .. }
| LogicalPlan::Aggregate { input, .. }
| LogicalPlan::Sort { input, .. }
| LogicalPlan::Limit { input, .. } => Some(input),
_ => None,
}
}
}