use std::collections::HashSet;
use std::fmt::{self, Display, Formatter, Write};
use std::hash::{Hash, Hasher};
use std::mem;
use std::sync::Arc;
use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
use crate::Volatility;
use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};
use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion,
};
use datafusion_common::{
Column, DFSchema, HashMap, Result, ScalarValue, Spans, TableReference,
};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use sqlparser::ast::{
display_comma_separated, ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem,
NullTreatment, RenameSelectItem, ReplaceSelectElement,
};
/// A node in an expression tree.
///
/// Each variant is one kind of expression. Most variants carry a dedicated
/// payload struct declared in this module (for example [`Alias`],
/// [`BinaryExpr`], [`Case`]); unary variants hold their single operand as a
/// `Box<Expr>`.
///
/// `PartialOrd` and `Hash` are derived, so the declaration order of the
/// variants is part of the observable ordering; do not reorder them casually.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub enum Expr {
/// An expression given an output name (and optional relation); see [`Alias`].
Alias(Alias),
/// A reference to a [`Column`].
Column(Column),
/// A scalar variable: its data type and its name as a list of strings.
ScalarVariable(DataType, Vec<String>),
/// A constant [`ScalarValue`].
Literal(ScalarValue),
/// Two operands combined by an [`Operator`]; see [`BinaryExpr`].
BinaryExpr(BinaryExpr),
/// Pattern match described by [`Like`]; built by `Expr::like`,
/// `Expr::not_like`, `Expr::ilike` and `Expr::not_ilike`.
Like(Like),
/// Shares the [`Like`] payload; `Expr::variant_name` reports it as `"RLike"`.
SimilarTo(Like),
/// Unary variant over a boxed operand.
Not(Box<Expr>),
/// Unary variant over a boxed operand; built by `Expr::is_not_null`.
IsNotNull(Box<Expr>),
/// Unary variant over a boxed operand; built by `Expr::is_null`.
IsNull(Box<Expr>),
/// Unary variant over a boxed operand; built by `Expr::is_true`.
IsTrue(Box<Expr>),
/// Unary variant over a boxed operand; built by `Expr::is_false`.
IsFalse(Box<Expr>),
/// Unary variant over a boxed operand; built by `Expr::is_unknown`.
IsUnknown(Box<Expr>),
/// Unary variant over a boxed operand; built by `Expr::is_not_true`.
IsNotTrue(Box<Expr>),
/// Unary variant over a boxed operand; built by `Expr::is_not_false`.
IsNotFalse(Box<Expr>),
/// Unary variant over a boxed operand; built by `Expr::is_not_unknown`.
IsNotUnknown(Box<Expr>),
/// Unary variant over a boxed operand.
Negative(Box<Expr>),
/// Range test with `low`/`high` bounds and a `negated` flag; see [`Between`].
Between(Between),
/// `CASE` expression with when/then pairs and an optional else; see [`Case`].
Case(Case),
/// Conversion of an expression to a target [`DataType`]; see [`Cast`].
Cast(Cast),
/// Same payload shape as [`Cast`], kept as a distinct variant; see [`TryCast`].
TryCast(TryCast),
/// Call of a scalar UDF with arguments; see [`ScalarFunction`].
ScalarFunction(ScalarFunction),
/// Call of an aggregate UDF with its parameters; see [`AggregateFunction`].
AggregateFunction(AggregateFunction),
/// Call of a window function with its parameters; see [`WindowFunction`].
WindowFunction(WindowFunction),
/// Membership test against an explicit list; see [`InList`].
InList(InList),
/// Subquery with a `negated` flag; see [`Exists`].
Exists(Exists),
/// Expression tested against a subquery; see [`InSubquery`].
InSubquery(InSubquery),
/// A [`Subquery`] used directly as an expression.
ScalarSubquery(Subquery),
/// Deprecated wildcard with an optional qualifier and boxed
/// [`WildcardOptions`]; see the deprecation note for the rationale.
#[deprecated(
since = "46.0.0",
note = "A wildcard needs to be resolved to concrete expressions when constructing the logical plan. See https://github.com/apache/datafusion/issues/7765"
)]
Wildcard {
qualifier: Option<TableReference>,
options: Box<WildcardOptions>,
},
/// Rollup, cube, or explicit grouping sets; see [`GroupingSet`].
GroupingSet(GroupingSet),
/// Placeholder with an id and optional data type; see [`Placeholder`].
Placeholder(Placeholder),
/// A [`Column`] together with a data type; detected by `Expr::contains_outer`.
OuterReferenceColumn(DataType, Column),
/// Wrapper around a single boxed expression; see [`Unnest`].
Unnest(Unnest),
}
/// The default expression is the literal [`ScalarValue::Null`].
impl Default for Expr {
    fn default() -> Self {
        Self::Literal(ScalarValue::Null)
    }
}
impl From<Column> for Expr {
fn from(value: Column) -> Self {
Expr::Column(value)
}
}
/// Builds a column expression from an optional table reference and a field
/// by first converting the pair into a [`Column`].
impl<'a> From<(Option<&'a TableReference>, &'a FieldRef)> for Expr {
    fn from(qualified_field: (Option<&'a TableReference>, &'a FieldRef)) -> Self {
        let column = Column::from(qualified_field);
        Self::from(column)
    }
}
/// An `Expr` is a container holding exactly one element: itself.
impl<'a> TreeNodeContainer<'a, Self> for Expr {
    /// Invokes `visit` once, on `self`.
    fn apply_elements<F>(&'a self, mut visit: F) -> Result<TreeNodeRecursion>
    where
        F: FnMut(&'a Self) -> Result<TreeNodeRecursion>,
    {
        visit(self)
    }

    /// Invokes `rewrite` once, on `self`, and returns its result.
    fn map_elements<F>(self, mut rewrite: F) -> Result<Transformed<Self>>
    where
        F: FnMut(Self) -> Result<Transformed<Self>>,
    {
        rewrite(self)
    }
}
/// Payload of [`Expr::Unnest`]: a single boxed inner expression.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Unnest {
    /// The wrapped expression.
    pub expr: Box<Expr>,
}

impl Unnest {
    /// Creates an `Unnest`, boxing `expr`.
    pub fn new(expr: Expr) -> Self {
        Self::new_boxed(Box::new(expr))
    }

    /// Creates an `Unnest` from an already boxed expression.
    pub fn new_boxed(boxed: Box<Expr>) -> Self {
        Self { expr: boxed }
    }
}
/// Payload of [`Expr::Alias`]: an expression with a name, an optional
/// relation and optional string metadata.
///
/// Equality (derived) takes `metadata` into account, while the hand-written
/// `Hash` and `PartialOrd` impls below only look at `expr`, `relation` and
/// `name`.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Alias {
    /// The aliased expression.
    pub expr: Box<Expr>,
    /// Optional relation qualifying the alias.
    pub relation: Option<TableReference>,
    /// The alias name.
    pub name: String,
    /// Optional key/value metadata attached to the alias.
    pub metadata: Option<std::collections::HashMap<String, String>>,
}

impl Hash for Alias {
    /// Feeds `expr`, `relation` and `name` (in that order) into the hasher;
    /// `metadata` is deliberately left out.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self {
            expr,
            relation,
            name,
            metadata: _,
        } = self;
        expr.hash(state);
        relation.hash(state);
        name.hash(state);
    }
}

impl PartialOrd for Alias {
    /// Lexicographic comparison over `expr`, then `relation`, then `name`.
    /// The first component that is not `Some(Equal)` decides the result.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        use std::cmp::Ordering;

        match self.expr.partial_cmp(&other.expr) {
            Some(Ordering::Equal) => {}
            decided => return decided,
        }
        match self.relation.partial_cmp(&other.relation) {
            Some(Ordering::Equal) => {}
            decided => return decided,
        }
        self.name.partial_cmp(&other.name)
    }
}

impl Alias {
    /// Creates an alias for `expr` with the given optional relation and name.
    /// The metadata starts out as `None`.
    pub fn new(
        expr: Expr,
        relation: Option<impl Into<TableReference>>,
        name: impl Into<String>,
    ) -> Self {
        let relation = relation.map(Into::into);
        Self {
            expr: Box::new(expr),
            relation,
            name: name.into(),
            metadata: None,
        }
    }

    /// Replaces the metadata and returns the updated alias.
    pub fn with_metadata(
        self,
        metadata: Option<std::collections::HashMap<String, String>>,
    ) -> Self {
        Self { metadata, ..self }
    }
}
/// Payload of [`Expr::BinaryExpr`]: `left op right`.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct BinaryExpr {
    /// Left-hand operand.
    pub left: Box<Expr>,
    /// The operator joining both operands.
    pub op: Operator,
    /// Right-hand operand.
    pub right: Box<Expr>,
}

impl BinaryExpr {
    /// Creates a binary expression from already boxed operands.
    pub fn new(left: Box<Expr>, op: Operator, right: Box<Expr>) -> Self {
        Self { left, op, right }
    }
}

impl Display for BinaryExpr {
    /// Writes `left op right`. A child that is itself a binary expression is
    /// parenthesized when its operator precedence is `0` or lower than this
    /// expression's precedence.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        /// Writes one operand, adding parentheses where the rule above asks for them.
        fn write_child(
            f: &mut Formatter<'_>,
            expr: &Expr,
            precedence: u8,
        ) -> fmt::Result {
            let Expr::BinaryExpr(child) = expr else {
                // Anything that is not a binary expression is written as is.
                return write!(f, "{expr}");
            };
            let child_precedence = child.op.precedence();
            let needs_parens = child_precedence == 0 || child_precedence < precedence;
            if needs_parens {
                write!(f, "({child})")
            } else {
                write!(f, "{child}")
            }
        }

        let own_precedence = self.op.precedence();
        write_child(f, self.left.as_ref(), own_precedence)?;
        write!(f, " {} ", self.op)?;
        write_child(f, self.right.as_ref(), own_precedence)
    }
}
/// Payload of [`Expr::Case`].
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)]
pub struct Case {
    /// Optional base expression of the `CASE`.
    pub expr: Option<Box<Expr>>,
    /// The `(when, then)` pairs, in order.
    pub when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
    /// Optional `else` expression.
    pub else_expr: Option<Box<Expr>>,
}

impl Case {
    /// Creates a `Case` from its three parts; no validation is performed.
    pub fn new(
        expr: Option<Box<Expr>>,
        when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
        else_expr: Option<Box<Expr>>,
    ) -> Self {
        Case {
            expr,
            when_then_expr,
            else_expr,
        }
    }
}
/// Payload shared by [`Expr::Like`] and [`Expr::SimilarTo`].
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Like {
    /// Whether the match is negated.
    pub negated: bool,
    /// The expression being matched.
    pub expr: Box<Expr>,
    /// The pattern expression.
    pub pattern: Box<Expr>,
    /// Optional escape character.
    pub escape_char: Option<char>,
    /// Whether the match ignores case.
    pub case_insensitive: bool,
}

impl Like {
    /// Creates a `Like` payload from its five parts; no validation is performed.
    pub fn new(
        negated: bool,
        expr: Box<Expr>,
        pattern: Box<Expr>,
        escape_char: Option<char>,
        case_insensitive: bool,
    ) -> Self {
        Like {
            negated,
            expr,
            pattern,
            escape_char,
            case_insensitive,
        }
    }
}
/// Payload of [`Expr::Between`].
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Between {
    /// The expression being tested.
    pub expr: Box<Expr>,
    /// Whether the test is negated.
    pub negated: bool,
    /// Lower bound expression.
    pub low: Box<Expr>,
    /// Upper bound expression.
    pub high: Box<Expr>,
}

impl Between {
    /// Creates a `Between` payload from already boxed parts.
    pub fn new(expr: Box<Expr>, negated: bool, low: Box<Expr>, high: Box<Expr>) -> Self {
        Between {
            expr,
            negated,
            low,
            high,
        }
    }
}
/// Payload of [`Expr::ScalarFunction`]: a scalar UDF and its arguments.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct ScalarFunction {
    /// The function being called.
    pub func: Arc<crate::ScalarUDF>,
    /// The call arguments, in order.
    pub args: Vec<Expr>,
}

impl ScalarFunction {
    /// Returns the name reported by the underlying UDF.
    pub fn name(&self) -> &str {
        self.func.name()
    }

    /// Creates a call of `udf` with `args`.
    pub fn new_udf(udf: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
        let func = udf;
        Self { func, args }
    }
}
/// Describes an access by struct field name, by list index, or by list range
/// (one variant each).
///
/// Unlike most types in this file, `PartialOrd` is not derived here.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum GetFieldAccess {
/// Access identified by a `name` held as a [`ScalarValue`].
NamedStructField { name: ScalarValue },
/// Access identified by a single `key` expression.
ListIndex { key: Box<Expr> },
/// Access identified by `start`, `stop` and `stride` expressions.
ListRange {
start: Box<Expr>,
stop: Box<Expr>,
stride: Box<Expr>,
},
}
/// Payload of [`Expr::Cast`]: an expression and the target data type.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Cast {
    /// The expression being cast.
    pub expr: Box<Expr>,
    /// The target type.
    pub data_type: DataType,
}

impl Cast {
    /// Creates a `Cast` of `expr` to `data_type`.
    pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
        Cast { expr, data_type }
    }
}
/// Payload of [`Expr::TryCast`]: an expression and the target data type.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct TryCast {
    /// The expression being cast.
    pub expr: Box<Expr>,
    /// The target type.
    pub data_type: DataType,
}

impl TryCast {
    /// Creates a `TryCast` of `expr` to `data_type`.
    pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
        TryCast { expr, data_type }
    }
}
/// A sort key: an expression plus direction and null placement.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Sort {
    /// The expression to sort by.
    pub expr: Expr,
    /// `true` for ascending, `false` for descending.
    pub asc: bool,
    /// `true` to place nulls first, `false` to place them last.
    pub nulls_first: bool,
}

impl Sort {
    /// Creates a sort key.
    pub fn new(expr: Expr, asc: bool, nulls_first: bool) -> Self {
        Sort {
            expr,
            asc,
            nulls_first,
        }
    }

    /// Returns a copy with both `asc` and `nulls_first` flipped; the
    /// expression is cloned.
    pub fn reverse(&self) -> Self {
        Self::new(self.expr.clone(), !self.asc, !self.nulls_first)
    }

    /// Returns a sort key with the same flags as `self` but using `expr`.
    pub fn with_expr(&self, expr: Expr) -> Self {
        Self::new(expr, self.asc, self.nulls_first)
    }
}

impl Display for Sort {
    /// Writes `<expr> ASC|DESC NULLS FIRST|NULLS LAST`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let direction = if self.asc { " ASC" } else { " DESC" };
        let null_placement = if self.nulls_first {
            " NULLS FIRST"
        } else {
            " NULLS LAST"
        };
        write!(f, "{}", self.expr)?;
        write!(f, "{direction}")?;
        write!(f, "{null_placement}")
    }
}

impl<'a> TreeNodeContainer<'a, Expr> for Sort {
    /// Delegates to the contained expression.
    fn apply_elements<F>(&'a self, f: F) -> Result<TreeNodeRecursion>
    where
        F: FnMut(&'a Expr) -> Result<TreeNodeRecursion>,
    {
        self.expr.apply_elements(f)
    }

    /// Rewrites the contained expression and rebuilds the sort key with the
    /// original `asc` / `nulls_first` flags.
    fn map_elements<F>(self, f: F) -> Result<Transformed<Self>>
    where
        F: FnMut(Expr) -> Result<Transformed<Expr>>,
    {
        let Self {
            expr,
            asc,
            nulls_first,
        } = self;
        expr.map_elements(f)?.map_data(|expr| {
            Ok(Self {
                expr,
                asc,
                nulls_first,
            })
        })
    }
}
/// Payload of [`Expr::AggregateFunction`]: an aggregate UDF and its call parameters.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct AggregateFunction {
    /// The aggregate function being called.
    pub func: Arc<crate::AggregateUDF>,
    /// Arguments and modifiers of the call.
    pub params: AggregateFunctionParams,
}

/// Arguments and modifiers of an aggregate function call.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct AggregateFunctionParams {
    /// The call arguments, in order.
    pub args: Vec<Expr>,
    /// Whether the call is `DISTINCT`.
    pub distinct: bool,
    /// Optional filter expression.
    pub filter: Option<Box<Expr>>,
    /// Optional ordering.
    pub order_by: Option<Vec<Sort>>,
    /// Optional null treatment.
    pub null_treatment: Option<NullTreatment>,
}

impl AggregateFunction {
    /// Creates an aggregate call of `func`, packing the remaining arguments
    /// into [`AggregateFunctionParams`].
    pub fn new_udf(
        func: Arc<crate::AggregateUDF>,
        args: Vec<Expr>,
        distinct: bool,
        filter: Option<Box<Expr>>,
        order_by: Option<Vec<Sort>>,
        null_treatment: Option<NullTreatment>,
    ) -> Self {
        let params = AggregateFunctionParams {
            args,
            distinct,
            filter,
            order_by,
            null_treatment,
        };
        Self { func, params }
    }
}
/// The function behind a window call: either an aggregate UDF or a window UDF.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum WindowFunctionDefinition {
    /// An aggregate function used as a window function.
    AggregateUDF(Arc<crate::AggregateUDF>),
    /// A dedicated window function.
    WindowUDF(Arc<WindowUDF>),
}

impl WindowFunctionDefinition {
    /// Returns the result data type for the given input types.
    ///
    /// Aggregate UDFs are asked directly via `return_type`; window UDFs are
    /// asked for their output field (built from the input types and
    /// `display_name`), whose data type is cloned. `_input_expr_nullable` is
    /// not used.
    pub fn return_type(
        &self,
        input_expr_types: &[DataType],
        _input_expr_nullable: &[bool],
        display_name: &str,
    ) -> Result<DataType> {
        match self {
            Self::AggregateUDF(fun) => fun.return_type(input_expr_types),
            Self::WindowUDF(fun) => {
                let field_args = WindowUDFFieldArgs::new(input_expr_types, display_name);
                fun.field(field_args)
                    .map(|field| field.data_type().clone())
            }
        }
    }

    /// Returns a clone of the underlying function's signature.
    pub fn signature(&self) -> Signature {
        let signature = match self {
            Self::AggregateUDF(fun) => fun.signature(),
            Self::WindowUDF(fun) => fun.signature(),
        };
        signature.clone()
    }

    /// Returns the underlying function's name.
    pub fn name(&self) -> &str {
        match self {
            Self::AggregateUDF(fun) => fun.name(),
            Self::WindowUDF(fun) => fun.name(),
        }
    }
}

impl Display for WindowFunctionDefinition {
    /// Delegates to the `Display` impl of the wrapped function.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        match self {
            Self::AggregateUDF(fun) => Display::fmt(fun, f),
            Self::WindowUDF(fun) => Display::fmt(fun, f),
        }
    }
}

impl From<Arc<crate::AggregateUDF>> for WindowFunctionDefinition {
    fn from(udaf: Arc<crate::AggregateUDF>) -> Self {
        WindowFunctionDefinition::AggregateUDF(udaf)
    }
}

impl From<Arc<WindowUDF>> for WindowFunctionDefinition {
    fn from(udwf: Arc<WindowUDF>) -> Self {
        WindowFunctionDefinition::WindowUDF(udwf)
    }
}
/// Payload of [`Expr::WindowFunction`]: the function and its call parameters.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct WindowFunction {
    /// The function being called.
    pub fun: WindowFunctionDefinition,
    /// Arguments and window specification of the call.
    pub params: WindowFunctionParams,
}

/// Arguments and window specification of a window function call.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct WindowFunctionParams {
    /// The call arguments, in order.
    pub args: Vec<Expr>,
    /// Partitioning expressions.
    pub partition_by: Vec<Expr>,
    /// Ordering within the window.
    pub order_by: Vec<Sort>,
    /// The window frame.
    pub window_frame: WindowFrame,
    /// Optional null treatment.
    pub null_treatment: Option<NullTreatment>,
}

impl WindowFunction {
    /// Creates a window call with the given arguments, empty `partition_by`
    /// and `order_by`, the frame returned by `WindowFrame::new(None)`, and no
    /// null treatment.
    pub fn new(fun: impl Into<WindowFunctionDefinition>, args: Vec<Expr>) -> Self {
        let params = WindowFunctionParams {
            args,
            partition_by: Vec::new(),
            order_by: Vec::new(),
            window_frame: WindowFrame::new(None),
            null_treatment: None,
        };
        Self {
            fun: fun.into(),
            params,
        }
    }
}
/// Payload of [`Expr::Exists`]: a subquery and a negation flag.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Exists {
    /// The subquery being tested.
    pub subquery: Subquery,
    /// Whether the test is negated.
    pub negated: bool,
}

impl Exists {
    /// Creates an `Exists` payload.
    pub fn new(subquery: Subquery, negated: bool) -> Self {
        Exists { subquery, negated }
    }
}
/// A call of a user-defined aggregate ([`udaf::AggregateUDF`]) with its
/// arguments, optional filter and optional ordering.
///
/// `PartialOrd` is not derived for this type.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct AggregateUDF {
    /// The aggregate function being called.
    pub fun: Arc<udaf::AggregateUDF>,
    /// The call arguments, in order.
    pub args: Vec<Expr>,
    /// Optional filter expression.
    pub filter: Option<Box<Expr>>,
    /// Optional ordering expressions.
    pub order_by: Option<Vec<Expr>>,
}

impl AggregateUDF {
    /// Creates the call description from its four parts.
    pub fn new(
        fun: Arc<udaf::AggregateUDF>,
        args: Vec<Expr>,
        filter: Option<Box<Expr>>,
        order_by: Option<Vec<Expr>>,
    ) -> Self {
        AggregateUDF {
            fun,
            args,
            filter,
            order_by,
        }
    }
}
/// Payload of [`Expr::InList`].
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct InList {
    /// The expression being tested.
    pub expr: Box<Expr>,
    /// The candidate values.
    pub list: Vec<Expr>,
    /// Whether the test is negated.
    pub negated: bool,
}

impl InList {
    /// Creates an `InList` payload.
    pub fn new(expr: Box<Expr>, list: Vec<Expr>, negated: bool) -> Self {
        InList {
            expr,
            list,
            negated,
        }
    }
}
/// Payload of [`Expr::InSubquery`].
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct InSubquery {
    /// The expression being tested.
    pub expr: Box<Expr>,
    /// The subquery providing the candidate values.
    pub subquery: Subquery,
    /// Whether the test is negated.
    pub negated: bool,
}

impl InSubquery {
    /// Creates an `InSubquery` payload.
    pub fn new(expr: Box<Expr>, subquery: Subquery, negated: bool) -> Self {
        InSubquery {
            expr,
            subquery,
            negated,
        }
    }
}
/// Payload of [`Expr::Placeholder`]: an identifier and an optional data type.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Placeholder {
    /// The placeholder identifier.
    pub id: String,
    /// The data type, if known.
    pub data_type: Option<DataType>,
}

impl Placeholder {
    /// Creates a placeholder.
    pub fn new(id: String, data_type: Option<DataType>) -> Self {
        Placeholder { id, data_type }
    }
}
/// Payload of [`Expr::GroupingSet`].
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub enum GroupingSet {
    /// Rollup over a list of expressions.
    Rollup(Vec<Expr>),
    /// Cube over a list of expressions.
    Cube(Vec<Expr>),
    /// Explicit list of grouping sets.
    GroupingSets(Vec<Vec<Expr>>),
}

impl GroupingSet {
    /// Returns references to the expressions of this grouping set.
    ///
    /// For `Rollup` and `Cube` the expressions are returned as stored (no
    /// de-duplication). For `GroupingSets` the nested lists are flattened and
    /// each expression is kept only the first time it appears, preserving
    /// first-seen order.
    pub fn distinct_expr(&self) -> Vec<&Expr> {
        match self {
            Self::Rollup(exprs) | Self::Cube(exprs) => exprs.iter().collect(),
            Self::GroupingSets(groups) => {
                let mut unique: Vec<&Expr> = Vec::new();
                groups.iter().flatten().for_each(|candidate| {
                    let already_seen = unique.contains(&candidate);
                    if !already_seen {
                        unique.push(candidate);
                    }
                });
                unique
            }
        }
    }
}
/// Optional modifiers attached to a wildcard; every field defaults to `None`.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug, Default)]
pub struct WildcardOptions {
    /// Optional `ILIKE` item.
    pub ilike: Option<IlikeSelectItem>,
    /// Optional `EXCLUDE` item.
    pub exclude: Option<ExcludeSelectItem>,
    /// Optional `EXCEPT` item.
    pub except: Option<ExceptSelectItem>,
    /// Optional `REPLACE` item together with its planned expressions.
    pub replace: Option<PlannedReplaceSelectItem>,
    /// Optional `RENAME` item.
    pub rename: Option<RenameSelectItem>,
}

impl WildcardOptions {
    /// Returns the options with `replace` set to `Some(replace)`; all other
    /// fields are carried over unchanged.
    pub fn with_replace(self, replace: PlannedReplaceSelectItem) -> Self {
        Self {
            replace: Some(replace),
            ..self
        }
    }
}

impl Display for WildcardOptions {
    /// Writes each option that is present, preceded by a single space, in the
    /// order: ilike, exclude, except, replace, rename.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let Self {
            ilike,
            exclude,
            except,
            replace,
            rename,
        } = self;
        if let Some(item) = ilike {
            write!(f, " {item}")?;
        }
        if let Some(item) = exclude {
            write!(f, " {item}")?;
        }
        if let Some(item) = except {
            write!(f, " {item}")?;
        }
        if let Some(item) = replace {
            write!(f, " {item}")?;
        }
        if let Some(item) = rename {
            write!(f, " {item}")?;
        }
        Ok(())
    }
}
/// The elements of a `REPLACE` clause together with the expressions planned for them.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug, Default)]
pub struct PlannedReplaceSelectItem {
    /// The replace elements as parsed.
    pub items: Vec<ReplaceSelectElement>,
    /// The planned expressions.
    pub planned_expressions: Vec<Expr>,
}

impl Display for PlannedReplaceSelectItem {
    /// Writes `REPLACE (<items, comma separated>)`; the planned expressions
    /// are not part of the output.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let items = display_comma_separated(&self.items);
        write!(f, "REPLACE")?;
        write!(f, " ({items})")
    }
}

impl PlannedReplaceSelectItem {
    /// Returns the replace elements as a slice.
    pub fn items(&self) -> &[ReplaceSelectElement] {
        self.items.as_slice()
    }

    /// Returns the planned expressions as a slice.
    pub fn expressions(&self) -> &[Expr] {
        self.planned_expressions.as_slice()
    }
}
impl Expr {
pub fn schema_name(&self) -> impl Display + '_ {
SchemaDisplay(self)
}
pub fn human_display(&self) -> impl Display + '_ {
SqlDisplay(self)
}
pub fn qualified_name(&self) -> (Option<TableReference>, String) {
match self {
Expr::Column(Column {
relation,
name,
spans: _,
}) => (relation.clone(), name.clone()),
Expr::Alias(Alias { relation, name, .. }) => (relation.clone(), name.clone()),
_ => (None, self.schema_name().to_string()),
}
}
#[deprecated(since = "42.0.0", note = "use format! instead")]
pub fn canonical_name(&self) -> String {
format!("{self}")
}
pub fn variant_name(&self) -> &str {
match self {
Expr::AggregateFunction { .. } => "AggregateFunction",
Expr::Alias(..) => "Alias",
Expr::Between { .. } => "Between",
Expr::BinaryExpr { .. } => "BinaryExpr",
Expr::Case { .. } => "Case",
Expr::Cast { .. } => "Cast",
Expr::Column(..) => "Column",
Expr::OuterReferenceColumn(_, _) => "Outer",
Expr::Exists { .. } => "Exists",
Expr::GroupingSet(..) => "GroupingSet",
Expr::InList { .. } => "InList",
Expr::InSubquery(..) => "InSubquery",
Expr::IsNotNull(..) => "IsNotNull",
Expr::IsNull(..) => "IsNull",
Expr::Like { .. } => "Like",
Expr::SimilarTo { .. } => "RLike",
Expr::IsTrue(..) => "IsTrue",
Expr::IsFalse(..) => "IsFalse",
Expr::IsUnknown(..) => "IsUnknown",
Expr::IsNotTrue(..) => "IsNotTrue",
Expr::IsNotFalse(..) => "IsNotFalse",
Expr::IsNotUnknown(..) => "IsNotUnknown",
Expr::Literal(..) => "Literal",
Expr::Negative(..) => "Negative",
Expr::Not(..) => "Not",
Expr::Placeholder(_) => "Placeholder",
Expr::ScalarFunction(..) => "ScalarFunction",
Expr::ScalarSubquery { .. } => "ScalarSubquery",
Expr::ScalarVariable(..) => "ScalarVariable",
Expr::TryCast { .. } => "TryCast",
Expr::WindowFunction { .. } => "WindowFunction",
#[expect(deprecated)]
Expr::Wildcard { .. } => "Wildcard",
Expr::Unnest { .. } => "Unnest",
}
}
pub fn eq(self, other: Expr) -> Expr {
binary_expr(self, Operator::Eq, other)
}
pub fn not_eq(self, other: Expr) -> Expr {
binary_expr(self, Operator::NotEq, other)
}
pub fn gt(self, other: Expr) -> Expr {
binary_expr(self, Operator::Gt, other)
}
pub fn gt_eq(self, other: Expr) -> Expr {
binary_expr(self, Operator::GtEq, other)
}
pub fn lt(self, other: Expr) -> Expr {
binary_expr(self, Operator::Lt, other)
}
pub fn lt_eq(self, other: Expr) -> Expr {
binary_expr(self, Operator::LtEq, other)
}
pub fn and(self, other: Expr) -> Expr {
binary_expr(self, Operator::And, other)
}
pub fn or(self, other: Expr) -> Expr {
binary_expr(self, Operator::Or, other)
}
pub fn like(self, other: Expr) -> Expr {
Expr::Like(Like::new(
false,
Box::new(self),
Box::new(other),
None,
false,
))
}
pub fn not_like(self, other: Expr) -> Expr {
Expr::Like(Like::new(
true,
Box::new(self),
Box::new(other),
None,
false,
))
}
pub fn ilike(self, other: Expr) -> Expr {
Expr::Like(Like::new(
false,
Box::new(self),
Box::new(other),
None,
true,
))
}
pub fn not_ilike(self, other: Expr) -> Expr {
Expr::Like(Like::new(true, Box::new(self), Box::new(other), None, true))
}
pub fn name_for_alias(&self) -> Result<String> {
Ok(self.schema_name().to_string())
}
pub fn alias_if_changed(self, original_name: String) -> Result<Expr> {
let new_name = self.name_for_alias()?;
if new_name == original_name {
return Ok(self);
}
Ok(self.alias(original_name))
}
pub fn alias(self, name: impl Into<String>) -> Expr {
Expr::Alias(Alias::new(self, None::<&str>, name.into()))
}
pub fn alias_with_metadata(
self,
name: impl Into<String>,
metadata: Option<std::collections::HashMap<String, String>>,
) -> Expr {
Expr::Alias(Alias::new(self, None::<&str>, name.into()).with_metadata(metadata))
}
pub fn alias_qualified(
self,
relation: Option<impl Into<TableReference>>,
name: impl Into<String>,
) -> Expr {
Expr::Alias(Alias::new(self, relation, name.into()))
}
pub fn alias_qualified_with_metadata(
self,
relation: Option<impl Into<TableReference>>,
name: impl Into<String>,
metadata: Option<std::collections::HashMap<String, String>>,
) -> Expr {
Expr::Alias(Alias::new(self, relation, name.into()).with_metadata(metadata))
}
pub fn unalias(self) -> Expr {
match self {
Expr::Alias(alias) => *alias.expr,
_ => self,
}
}
pub fn unalias_nested(self) -> Transformed<Expr> {
self.transform_down_up(
|expr| {
let recursion = if matches!(
expr,
Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_)
) {
TreeNodeRecursion::Jump
} else {
TreeNodeRecursion::Continue
};
Ok(Transformed::new(expr, false, recursion))
},
|expr| {
if let Expr::Alias(Alias { expr, .. }) = expr {
Ok(Transformed::yes(*expr))
} else {
Ok(Transformed::no(expr))
}
},
)
.unwrap()
}
pub fn in_list(self, list: Vec<Expr>, negated: bool) -> Expr {
Expr::InList(InList::new(Box::new(self), list, negated))
}
pub fn is_null(self) -> Expr {
Expr::IsNull(Box::new(self))
}
pub fn is_not_null(self) -> Expr {
Expr::IsNotNull(Box::new(self))
}
pub fn sort(self, asc: bool, nulls_first: bool) -> Sort {
Sort::new(self, asc, nulls_first)
}
pub fn is_true(self) -> Expr {
Expr::IsTrue(Box::new(self))
}
pub fn is_not_true(self) -> Expr {
Expr::IsNotTrue(Box::new(self))
}
pub fn is_false(self) -> Expr {
Expr::IsFalse(Box::new(self))
}
pub fn is_not_false(self) -> Expr {
Expr::IsNotFalse(Box::new(self))
}
pub fn is_unknown(self) -> Expr {
Expr::IsUnknown(Box::new(self))
}
pub fn is_not_unknown(self) -> Expr {
Expr::IsNotUnknown(Box::new(self))
}
pub fn between(self, low: Expr, high: Expr) -> Expr {
Expr::Between(Between::new(
Box::new(self),
false,
Box::new(low),
Box::new(high),
))
}
pub fn not_between(self, low: Expr, high: Expr) -> Expr {
Expr::Between(Between::new(
Box::new(self),
true,
Box::new(low),
Box::new(high),
))
}
pub fn try_as_col(&self) -> Option<&Column> {
if let Expr::Column(it) = self {
Some(it)
} else {
None
}
}
pub fn get_as_join_column(&self) -> Option<&Column> {
match self {
Expr::Column(c) => Some(c),
Expr::Cast(Cast { expr, .. }) => match &**expr {
Expr::Column(c) => Some(c),
_ => None,
},
_ => None,
}
}
pub fn column_refs(&self) -> HashSet<&Column> {
let mut using_columns = HashSet::new();
self.add_column_refs(&mut using_columns);
using_columns
}
pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) {
self.apply(|expr| {
if let Expr::Column(col) = expr {
set.insert(col);
}
Ok(TreeNodeRecursion::Continue)
})
.expect("traversal is infallible");
}
pub fn column_refs_counts(&self) -> HashMap<&Column, usize> {
let mut map = HashMap::new();
self.add_column_ref_counts(&mut map);
map
}
pub fn add_column_ref_counts<'a>(&'a self, map: &mut HashMap<&'a Column, usize>) {
self.apply(|expr| {
if let Expr::Column(col) = expr {
*map.entry(col).or_default() += 1;
}
Ok(TreeNodeRecursion::Continue)
})
.expect("traversal is infallible");
}
pub fn any_column_refs(&self) -> bool {
self.exists(|expr| Ok(matches!(expr, Expr::Column(_))))
.expect("exists closure is infallible")
}
pub fn contains_outer(&self) -> bool {
self.exists(|expr| Ok(matches!(expr, Expr::OuterReferenceColumn { .. })))
.expect("exists closure is infallible")
}
pub fn is_volatile_node(&self) -> bool {
matches!(self, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile)
}
pub fn is_volatile(&self) -> bool {
self.exists(|expr| Ok(expr.is_volatile_node()))
.expect("exists closure is infallible")
}
pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> {
let mut has_placeholder = false;
self.transform(|mut expr| {
if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr {
rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?;
rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?;
};
if let Expr::Between(Between {
expr,
negated: _,
low,
high,
}) = &mut expr
{
rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?;
rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?;
}
if let Expr::Placeholder(_) = &expr {
has_placeholder = true;
}
Ok(Transformed::yes(expr))
})
.data()
.map(|data| (data, has_placeholder))
}
pub fn short_circuits(&self) -> bool {
match self {
Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(),
Expr::BinaryExpr(BinaryExpr { op, .. }) => {
matches!(op, Operator::And | Operator::Or)
}
Expr::Case { .. } => true,
#[expect(deprecated)]
Expr::AggregateFunction(..)
| Expr::Alias(..)
| Expr::Between(..)
| Expr::Cast(..)
| Expr::Column(..)
| Expr::Exists(..)
| Expr::GroupingSet(..)
| Expr::InList(..)
| Expr::InSubquery(..)
| Expr::IsFalse(..)
| Expr::IsNotFalse(..)
| Expr::IsNotNull(..)
| Expr::IsNotTrue(..)
| Expr::IsNotUnknown(..)
| Expr::IsNull(..)
| Expr::IsTrue(..)
| Expr::IsUnknown(..)
| Expr::Like(..)
| Expr::ScalarSubquery(..)
| Expr::ScalarVariable(_, _)
| Expr::SimilarTo(..)
| Expr::Not(..)
| Expr::Negative(..)
| Expr::OuterReferenceColumn(_, _)
| Expr::TryCast(..)
| Expr::Unnest(..)
| Expr::Wildcard { .. }
| Expr::WindowFunction(..)
| Expr::Literal(..)
| Expr::Placeholder(..) => false,
}
}
pub fn spans(&self) -> Option<&Spans> {
match self {
Expr::Column(col) => Some(&col.spans),
_ => None,
}
}
}
impl Normalizeable for Expr {
    /// Returns `true` only for binary expressions whose operator is one of
    /// `Plus`, `Multiply`, `BitwiseAnd`, `BitwiseOr`, `BitwiseXor`, `Eq` or
    /// `NotEq` — the operators for which `normalize_eq` also accepts swapped
    /// operands.
    fn can_normalize(&self) -> bool {
        matches!(
            self,
            Expr::BinaryExpr(BinaryExpr {
                op: Operator::Plus
                    | Operator::Multiply
                    | Operator::BitwiseAnd
                    | Operator::BitwiseOr
                    | Operator::BitwiseXor
                    | Operator::Eq
                    | Operator::NotEq,
                ..
            })
        )
    }
}
impl NormalizeEq for Expr {
fn normalize_eq(&self, other: &Self) -> bool {
match (self, other) {
(
Expr::BinaryExpr(BinaryExpr {
left: self_left,
op: self_op,
right: self_right,
}),
Expr::BinaryExpr(BinaryExpr {
left: other_left,
op: other_op,
right: other_right,
}),
) => {
if self_op != other_op {
return false;
}
if matches!(
self_op,
Operator::Plus
| Operator::Multiply
| Operator::BitwiseAnd
| Operator::BitwiseOr
| Operator::BitwiseXor
| Operator::Eq
| Operator::NotEq
) {
(self_left.normalize_eq(other_left)
&& self_right.normalize_eq(other_right))
|| (self_left.normalize_eq(other_right)
&& self_right.normalize_eq(other_left))
} else {
self_left.normalize_eq(other_left)
&& self_right.normalize_eq(other_right)
}
}
(
Expr::Alias(Alias {
expr: self_expr,
relation: self_relation,
name: self_name,
..
}),
Expr::Alias(Alias {
expr: other_expr,
relation: other_relation,
name: other_name,
..
}),
) => {
self_name == other_name
&& self_relation == other_relation
&& self_expr.normalize_eq(other_expr)
}
(
Expr::Like(Like {
negated: self_negated,
expr: self_expr,
pattern: self_pattern,
escape_char: self_escape_char,
case_insensitive: self_case_insensitive,
}),
Expr::Like(Like {
negated: other_negated,
expr: other_expr,
pattern: other_pattern,
escape_char: other_escape_char,
case_insensitive: other_case_insensitive,
}),
)
| (
Expr::SimilarTo(Like {
negated: self_negated,
expr: self_expr,
pattern: self_pattern,
escape_char: self_escape_char,
case_insensitive: self_case_insensitive,
}),
Expr::SimilarTo(Like {
negated: other_negated,
expr: other_expr,
pattern: other_pattern,
escape_char: other_escape_char,
case_insensitive: other_case_insensitive,
}),
) => {
self_negated == other_negated
&& self_escape_char == other_escape_char
&& self_case_insensitive == other_case_insensitive
&& self_expr.normalize_eq(other_expr)
&& self_pattern.normalize_eq(other_pattern)
}
(Expr::Not(self_expr), Expr::Not(other_expr))
| (Expr::IsNull(self_expr), Expr::IsNull(other_expr))
| (Expr::IsTrue(self_expr), Expr::IsTrue(other_expr))
| (Expr::IsFalse(self_expr), Expr::IsFalse(other_expr))
| (Expr::IsUnknown(self_expr), Expr::IsUnknown(other_expr))
| (Expr::IsNotNull(self_expr), Expr::IsNotNull(other_expr))
| (Expr::IsNotTrue(self_expr), Expr::IsNotTrue(other_expr))
| (Expr::IsNotFalse(self_expr), Expr::IsNotFalse(other_expr))
| (Expr::IsNotUnknown(self_expr), Expr::IsNotUnknown(other_expr))
| (Expr::Negative(self_expr), Expr::Negative(other_expr))
| (
Expr::Unnest(Unnest { expr: self_expr }),
Expr::Unnest(Unnest { expr: other_expr }),
) => self_expr.normalize_eq(other_expr),
(
Expr::Between(Between {
expr: self_expr,
negated: self_negated,
low: self_low,
high: self_high,
}),
Expr::Between(Between {
expr: other_expr,
negated: other_negated,
low: other_low,
high: other_high,
}),
) => {
self_negated == other_negated
&& self_expr.normalize_eq(other_expr)
&& self_low.normalize_eq(other_low)
&& self_high.normalize_eq(other_high)
}
(
Expr::Cast(Cast {
expr: self_expr,
data_type: self_data_type,
}),
Expr::Cast(Cast {
expr: other_expr,
data_type: other_data_type,
}),
)
| (
Expr::TryCast(TryCast {
expr: self_expr,
data_type: self_data_type,
}),
Expr::TryCast(TryCast {
expr: other_expr,
data_type: other_data_type,
}),
) => self_data_type == other_data_type && self_expr.normalize_eq(other_expr),
(
Expr::ScalarFunction(ScalarFunction {
func: self_func,
args: self_args,
}),
Expr::ScalarFunction(ScalarFunction {
func: other_func,
args: other_args,
}),
) => {
self_func.name() == other_func.name()
&& self_args.len() == other_args.len()
&& self_args
.iter()
.zip(other_args.iter())
.all(|(a, b)| a.normalize_eq(b))
}
(
Expr::AggregateFunction(AggregateFunction {
func: self_func,
params:
AggregateFunctionParams {
args: self_args,
distinct: self_distinct,
filter: self_filter,
order_by: self_order_by,
null_treatment: self_null_treatment,
},
}),
Expr::AggregateFunction(AggregateFunction {
func: other_func,
params:
AggregateFunctionParams {
args: other_args,
distinct: other_distinct,
filter: other_filter,
order_by: other_order_by,
null_treatment: other_null_treatment,
},
}),
) => {
self_func.name() == other_func.name()
&& self_distinct == other_distinct
&& self_null_treatment == other_null_treatment
&& self_args.len() == other_args.len()
&& self_args
.iter()
.zip(other_args.iter())
.all(|(a, b)| a.normalize_eq(b))
&& match (self_filter, other_filter) {
(Some(self_filter), Some(other_filter)) => {
self_filter.normalize_eq(other_filter)
}
(None, None) => true,
_ => false,
}
&& match (self_order_by, other_order_by) {
(Some(self_order_by), Some(other_order_by)) => self_order_by
.iter()
.zip(other_order_by.iter())
.all(|(a, b)| {
a.asc == b.asc
&& a.nulls_first == b.nulls_first
&& a.expr.normalize_eq(&b.expr)
}),
(None, None) => true,
_ => false,
}
}
(
Expr::WindowFunction(WindowFunction {
fun: self_fun,
params: self_params,
}),
Expr::WindowFunction(WindowFunction {
fun: other_fun,
params: other_params,
}),
) => {
let (
WindowFunctionParams {
args: self_args,
window_frame: self_window_frame,
partition_by: self_partition_by,
order_by: self_order_by,
null_treatment: self_null_treatment,
},
WindowFunctionParams {
args: other_args,
window_frame: other_window_frame,
partition_by: other_partition_by,
order_by: other_order_by,
null_treatment: other_null_treatment,
},
) = (self_params, other_params);
self_fun.name() == other_fun.name()
&& self_window_frame == other_window_frame
&& self_null_treatment == other_null_treatment
&& self_args.len() == other_args.len()
&& self_args
.iter()
.zip(other_args.iter())
.all(|(a, b)| a.normalize_eq(b))
&& self_partition_by
.iter()
.zip(other_partition_by.iter())
.all(|(a, b)| a.normalize_eq(b))
&& self_order_by
.iter()
.zip(other_order_by.iter())
.all(|(a, b)| {
a.asc == b.asc
&& a.nulls_first == b.nulls_first
&& a.expr.normalize_eq(&b.expr)
})
}
(
Expr::Exists(Exists {
subquery: self_subquery,
negated: self_negated,
}),
Expr::Exists(Exists {
subquery: other_subquery,
negated: other_negated,
}),
) => {
self_negated == other_negated
&& self_subquery.normalize_eq(other_subquery)
}
(
Expr::InSubquery(InSubquery {
expr: self_expr,
subquery: self_subquery,
negated: self_negated,
}),
Expr::InSubquery(InSubquery {
expr: other_expr,
subquery: other_subquery,
negated: other_negated,
}),
) => {
self_negated == other_negated
&& self_expr.normalize_eq(other_expr)
&& self_subquery.normalize_eq(other_subquery)
}
(
Expr::ScalarSubquery(self_subquery),
Expr::ScalarSubquery(other_subquery),
) => self_subquery.normalize_eq(other_subquery),
(
Expr::GroupingSet(GroupingSet::Rollup(self_exprs)),
Expr::GroupingSet(GroupingSet::Rollup(other_exprs)),
)
| (
Expr::GroupingSet(GroupingSet::Cube(self_exprs)),
Expr::GroupingSet(GroupingSet::Cube(other_exprs)),
) => {
self_exprs.len() == other_exprs.len()
&& self_exprs
.iter()
.zip(other_exprs.iter())
.all(|(a, b)| a.normalize_eq(b))
}
(
Expr::GroupingSet(GroupingSet::GroupingSets(self_exprs)),
Expr::GroupingSet(GroupingSet::GroupingSets(other_exprs)),
) => {
self_exprs.len() == other_exprs.len()
&& self_exprs.iter().zip(other_exprs.iter()).all(|(a, b)| {
a.len() == b.len()
&& a.iter().zip(b.iter()).all(|(x, y)| x.normalize_eq(y))
})
}
(
Expr::InList(InList {
expr: self_expr,
list: self_list,
negated: self_negated,
}),
Expr::InList(InList {
expr: other_expr,
list: other_list,
negated: other_negated,
}),
) => {
self_negated == other_negated
&& self_expr.normalize_eq(other_expr)
&& self_list.len() == other_list.len()
&& self_list
.iter()
.zip(other_list.iter())
.all(|(a, b)| a.normalize_eq(b))
}
(
Expr::Case(Case {
expr: self_expr,
when_then_expr: self_when_then_expr,
else_expr: self_else_expr,
}),
Expr::Case(Case {
expr: other_expr,
when_then_expr: other_when_then_expr,
else_expr: other_else_expr,
}),
) => {
self_when_then_expr.len() == other_when_then_expr.len()
&& self_when_then_expr
.iter()
.zip(other_when_then_expr.iter())
.all(|((self_when, self_then), (other_when, other_then))| {
self_when.normalize_eq(other_when)
&& self_then.normalize_eq(other_then)
})
&& match (self_expr, other_expr) {
(Some(self_expr), Some(other_expr)) => {
self_expr.normalize_eq(other_expr)
}
(None, None) => true,
(_, _) => false,
}
&& match (self_else_expr, other_else_expr) {
(Some(self_else_expr), Some(other_else_expr)) => {
self_else_expr.normalize_eq(other_else_expr)
}
(None, None) => true,
(_, _) => false,
}
}
(_, _) => self == other,
}
}
}
// Node-local hashing for `Expr`.
//
// Only data that belongs to the node itself is fed to the hasher; every child
// expression field is bound to an underscore-prefixed name (or `_`) and left
// unhashed.
// NOTE(review): presumably the `HashNode` consumer combines child hashes
// itself -- confirm against the trait's contract in `datafusion_common::cse`.
impl HashNode for Expr {
fn hash_node<H: Hasher>(&self, state: &mut H) {
// The variant discriminant is always hashed first, so variants that share
// an arm below (e.g. `Like` / `SimilarTo`, `Cast` / `TryCast`) still differ.
mem::discriminant(self).hash(state);
match self {
Expr::Alias(Alias {
expr: _expr,
relation,
name,
..
}) => {
relation.hash(state);
name.hash(state);
}
Expr::Column(column) => {
column.hash(state);
}
Expr::ScalarVariable(data_type, name) => {
data_type.hash(state);
name.hash(state);
}
Expr::Literal(scalar_value) => {
scalar_value.hash(state);
}
Expr::BinaryExpr(BinaryExpr {
left: _left,
op,
right: _right,
}) => {
op.hash(state);
}
Expr::Like(Like {
negated,
expr: _expr,
pattern: _pattern,
escape_char,
case_insensitive,
})
| Expr::SimilarTo(Like {
negated,
expr: _expr,
pattern: _pattern,
escape_char,
case_insensitive,
}) => {
negated.hash(state);
escape_char.hash(state);
case_insensitive.hash(state);
}
// Unary wrappers carry nothing besides their child, so the discriminant
// hashed above is their whole contribution.
Expr::Not(_expr)
| Expr::IsNotNull(_expr)
| Expr::IsNull(_expr)
| Expr::IsTrue(_expr)
| Expr::IsFalse(_expr)
| Expr::IsUnknown(_expr)
| Expr::IsNotTrue(_expr)
| Expr::IsNotFalse(_expr)
| Expr::IsNotUnknown(_expr)
| Expr::Negative(_expr) => {}
Expr::Between(Between {
expr: _expr,
negated,
low: _low,
high: _high,
}) => {
negated.hash(state);
}
// All fields of `Case` are child expressions; nothing extra to hash.
Expr::Case(Case {
expr: _expr,
when_then_expr: _when_then_expr,
else_expr: _else_expr,
}) => {}
Expr::Cast(Cast {
expr: _expr,
data_type,
})
| Expr::TryCast(TryCast {
expr: _expr,
data_type,
}) => {
data_type.hash(state);
}
Expr::ScalarFunction(ScalarFunction { func, args: _args }) => {
func.hash(state);
}
// `filter` and `order_by` are ignored here along with `args`.
Expr::AggregateFunction(AggregateFunction {
func,
params:
AggregateFunctionParams {
args: _args,
distinct,
filter: _,
order_by: _,
null_treatment,
},
}) => {
func.hash(state);
distinct.hash(state);
null_treatment.hash(state);
}
// `partition_by` and `order_by` are ignored here along with `args`.
Expr::WindowFunction(WindowFunction { fun, params }) => {
let WindowFunctionParams {
args: _args,
partition_by: _,
order_by: _,
window_frame,
null_treatment,
} = params;
fun.hash(state);
window_frame.hash(state);
null_treatment.hash(state);
}
Expr::InList(InList {
expr: _expr,
list: _list,
negated,
}) => {
negated.hash(state);
}
// Subqueries are hashed in full via their own `Hash` implementation.
Expr::Exists(Exists { subquery, negated }) => {
subquery.hash(state);
negated.hash(state);
}
Expr::InSubquery(InSubquery {
expr: _expr,
subquery,
negated,
}) => {
subquery.hash(state);
negated.hash(state);
}
Expr::ScalarSubquery(subquery) => {
subquery.hash(state);
}
#[expect(deprecated)]
Expr::Wildcard { qualifier, options } => {
qualifier.hash(state);
options.hash(state);
}
// Only the kind of grouping set (rollup / cube / grouping sets) is hashed;
// the contained expression lists are not.
Expr::GroupingSet(grouping_set) => {
mem::discriminant(grouping_set).hash(state);
match grouping_set {
GroupingSet::Rollup(_exprs) | GroupingSet::Cube(_exprs) => {}
GroupingSet::GroupingSets(_exprs) => {}
}
}
Expr::Placeholder(place_holder) => {
place_holder.hash(state);
}
Expr::OuterReferenceColumn(data_type, column) => {
data_type.hash(state);
column.hash(state);
}
Expr::Unnest(Unnest { expr: _expr }) => {}
};
}
}
/// Fills in the data type of an untyped placeholder from a sibling expression.
///
/// If `expr` is an [`Expr::Placeholder`] whose `data_type` is `None`, the type
/// of `other` is resolved against `schema` and stored in the placeholder.
/// Anything else (a non-placeholder, or a placeholder that already has a type)
/// is left untouched.
///
/// # Errors
///
/// Returns the error from `other.get_type(schema)`, wrapped with a context
/// message naming both expressions, when the type of `other` cannot be
/// determined.
fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> {
    let Expr::Placeholder(Placeholder { id: _, data_type }) = expr else {
        return Ok(());
    };
    if data_type.is_some() {
        // Already typed: nothing to infer.
        return Ok(());
    }
    match other.get_type(schema) {
        Ok(inferred) => {
            *data_type = Some(inferred);
            Ok(())
        }
        Err(cause) => Err(cause.context(format!(
            "Can not find type of {other} needed to infer type of {expr}"
        ))),
    }
}
/// Renders every element of an iterable collection with its `Display`
/// implementation and joins the results with `", "`.
///
/// The argument must expose `.iter()`; an empty collection yields an empty
/// string.
#[macro_export]
macro_rules! expr_vec_fmt {
    ( $ARRAY:expr ) => {{
        let rendered: Vec<String> = $ARRAY.iter().map(|item| item.to_string()).collect();
        rendered.join(", ")
    }};
}
/// Formats an [`Expr`] as the name it contributes to a schema.
///
/// Compared with `Display for Expr`, this rendering:
/// * prints only the inner expression of `Cast` / `TryCast`,
/// * prints an alias as just its (optionally relation-qualified) name,
/// * reduces `EXISTS` / `IN (subquery)` to a keyword and a scalar subquery to
///   the name of the first field of the subquery's schema,
/// * delegates scalar, aggregate and aggregate-window functions to the
///   function's own schema-name hook.
struct SchemaDisplay<'a>(&'a Expr);

impl Display for SchemaDisplay<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self.0 {
            // Leaf-like expressions reuse the regular `Display` output.
            #[expect(deprecated)]
            Expr::Column(_)
            | Expr::Literal(_)
            | Expr::ScalarVariable(..)
            | Expr::OuterReferenceColumn(..)
            | Expr::Placeholder(_)
            | Expr::Wildcard { .. } => write!(f, "{}", self.0),
            Expr::AggregateFunction(AggregateFunction { func, params }) => {
                // `fmt` can only fail with `fmt::Error`, so a naming failure is
                // rendered into the output instead of being propagated.
                match func.schema_name(params) {
                    Ok(name) => {
                        write!(f, "{name}")
                    }
                    Err(e) => {
                        write!(f, "got error from schema_name {}", e)
                    }
                }
            }
            // Qualified alias first: this arm must precede the catch-all alias arm.
            Expr::Alias(Alias {
                name,
                relation: Some(relation),
                ..
            }) => write!(f, "{relation}.{name}"),
            Expr::Alias(Alias { name, .. }) => write!(f, "{name}"),
            Expr::Between(Between {
                expr,
                negated,
                low,
                high,
            }) => {
                if *negated {
                    write!(
                        f,
                        "{} NOT BETWEEN {} AND {}",
                        SchemaDisplay(expr),
                        SchemaDisplay(low),
                        SchemaDisplay(high),
                    )
                } else {
                    write!(
                        f,
                        "{} BETWEEN {} AND {}",
                        SchemaDisplay(expr),
                        SchemaDisplay(low),
                        SchemaDisplay(high),
                    )
                }
            }
            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                write!(f, "{} {op} {}", SchemaDisplay(left), SchemaDisplay(right),)
            }
            Expr::Case(Case {
                expr,
                when_then_expr,
                else_expr,
            }) => {
                write!(f, "CASE ")?;
                if let Some(e) = expr {
                    write!(f, "{} ", SchemaDisplay(e))?;
                }
                for (when, then) in when_then_expr {
                    write!(
                        f,
                        "WHEN {} THEN {} ",
                        SchemaDisplay(when),
                        SchemaDisplay(then),
                    )?;
                }
                if let Some(e) = else_expr {
                    write!(f, "ELSE {} ", SchemaDisplay(e))?;
                }
                write!(f, "END")
            }
            // Casts are transparent in schema names.
            Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) => {
                write!(f, "{}", SchemaDisplay(expr))
            }
            Expr::InList(InList {
                expr,
                list,
                negated,
            }) => {
                let inlist_name = schema_name_from_exprs(list)?;
                if *negated {
                    write!(f, "{} NOT IN {}", SchemaDisplay(expr), inlist_name)
                } else {
                    write!(f, "{} IN {}", SchemaDisplay(expr), inlist_name)
                }
            }
            Expr::Exists(Exists { negated: true, .. }) => write!(f, "NOT EXISTS"),
            Expr::Exists(Exists { negated: false, .. }) => write!(f, "EXISTS"),
            // Fixed: this arm used to print "ROLLUP", which made a CUBE
            // indistinguishable from a ROLLUP over the same expressions.
            Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
                write!(f, "CUBE ({})", schema_name_from_exprs(exprs)?)
            }
            // NOTE(review): consecutive sets are written back-to-back with no
            // separator, unlike `Display for Expr`, which joins them with ", ".
            // Left unchanged to avoid altering existing schema names.
            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
                write!(f, "GROUPING SETS (")?;
                for exprs in lists_of_exprs.iter() {
                    write!(f, "({})", schema_name_from_exprs(exprs)?)?;
                }
                write!(f, ")")
            }
            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
                write!(f, "ROLLUP ({})", schema_name_from_exprs(exprs)?)
            }
            Expr::IsNull(expr) => write!(f, "{} IS NULL", SchemaDisplay(expr)),
            Expr::IsNotNull(expr) => {
                write!(f, "{} IS NOT NULL", SchemaDisplay(expr))
            }
            Expr::IsUnknown(expr) => {
                write!(f, "{} IS UNKNOWN", SchemaDisplay(expr))
            }
            Expr::IsNotUnknown(expr) => {
                write!(f, "{} IS NOT UNKNOWN", SchemaDisplay(expr))
            }
            Expr::InSubquery(InSubquery { negated: true, .. }) => {
                write!(f, "NOT IN")
            }
            Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"),
            Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)),
            Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)),
            Expr::IsNotTrue(expr) => {
                write!(f, "{} IS NOT TRUE", SchemaDisplay(expr))
            }
            Expr::IsNotFalse(expr) => {
                write!(f, "{} IS NOT FALSE", SchemaDisplay(expr))
            }
            Expr::Like(Like {
                negated,
                expr,
                pattern,
                escape_char,
                case_insensitive,
            }) => {
                write!(
                    f,
                    "{} {}{} {}",
                    SchemaDisplay(expr),
                    if *negated { "NOT " } else { "" },
                    if *case_insensitive { "ILIKE" } else { "LIKE" },
                    SchemaDisplay(pattern),
                )?;
                // NOTE(review): the escape character is introduced with `CHAR`
                // here but with `ESCAPE` in `Display for Expr`; kept as-is
                // because it is part of existing schema names.
                if let Some(char) = escape_char {
                    write!(f, " CHAR '{char}'")?;
                }
                Ok(())
            }
            Expr::Negative(expr) => write!(f, "(- {})", SchemaDisplay(expr)),
            Expr::Not(expr) => write!(f, "NOT {}", SchemaDisplay(expr)),
            Expr::Unnest(Unnest { expr }) => {
                write!(f, "UNNEST({})", SchemaDisplay(expr))
            }
            Expr::ScalarFunction(ScalarFunction { func, args }) => {
                match func.schema_name(args) {
                    Ok(name) => {
                        write!(f, "{name}")
                    }
                    Err(e) => {
                        write!(f, "got error from schema_name {}", e)
                    }
                }
            }
            // A scalar subquery is named after the first field of its schema.
            Expr::ScalarSubquery(Subquery { subquery, .. }) => {
                write!(f, "{}", subquery.schema().field(0).name())
            }
            Expr::SimilarTo(Like {
                negated,
                expr,
                pattern,
                escape_char,
                ..
            }) => {
                write!(
                    f,
                    "{} {} {}",
                    SchemaDisplay(expr),
                    if *negated {
                        "NOT SIMILAR TO"
                    } else {
                        "SIMILAR TO"
                    },
                    SchemaDisplay(pattern),
                )?;
                if let Some(char) = escape_char {
                    write!(f, " CHAR '{char}'")?;
                }
                Ok(())
            }
            Expr::WindowFunction(WindowFunction { fun, params }) => match fun {
                // Aggregate UDFs used as window functions name themselves.
                WindowFunctionDefinition::AggregateUDF(fun) => {
                    match fun.window_function_schema_name(params) {
                        Ok(name) => {
                            write!(f, "{name}")
                        }
                        Err(e) => {
                            write!(f, "got error from window_function_schema_name {}", e)
                        }
                    }
                }
                _ => {
                    let WindowFunctionParams {
                        args,
                        partition_by,
                        order_by,
                        window_frame,
                        null_treatment,
                    } = params;
                    // Arguments are joined with a bare comma (no space).
                    write!(
                        f,
                        "{}({})",
                        fun,
                        schema_name_from_exprs_comma_separated_without_space(args)?
                    )?;
                    if let Some(null_treatment) = null_treatment {
                        write!(f, " {}", null_treatment)?;
                    }
                    if !partition_by.is_empty() {
                        write!(
                            f,
                            " PARTITION BY [{}]",
                            schema_name_from_exprs(partition_by)?
                        )?;
                    }
                    if !order_by.is_empty() {
                        write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?;
                    };
                    write!(f, " {window_frame}")
                }
            },
        }
    }
}
struct SqlDisplay<'a>(&'a Expr);
impl Display for SqlDisplay<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self.0 {
Expr::Literal(scalar) => scalar.fmt(f),
Expr::Alias(Alias { name, .. }) => write!(f, "{name}"),
Expr::Between(Between {
expr,
negated,
low,
high,
}) => {
if *negated {
write!(
f,
"{} NOT BETWEEN {} AND {}",
SqlDisplay(expr),
SqlDisplay(low),
SqlDisplay(high),
)
} else {
write!(
f,
"{} BETWEEN {} AND {}",
SqlDisplay(expr),
SqlDisplay(low),
SqlDisplay(high),
)
}
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
write!(f, "{} {op} {}", SqlDisplay(left), SqlDisplay(right),)
}
Expr::Case(Case {
expr,
when_then_expr,
else_expr,
}) => {
write!(f, "CASE ")?;
if let Some(e) = expr {
write!(f, "{} ", SqlDisplay(e))?;
}
for (when, then) in when_then_expr {
write!(f, "WHEN {} THEN {} ", SqlDisplay(when), SqlDisplay(then),)?;
}
if let Some(e) = else_expr {
write!(f, "ELSE {} ", SqlDisplay(e))?;
}
write!(f, "END")
}
Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) => {
write!(f, "{}", SqlDisplay(expr))
}
Expr::InList(InList {
expr,
list,
negated,
}) => {
write!(
f,
"{}{} IN {}",
SqlDisplay(expr),
if *negated { " NOT" } else { "" },
ExprListDisplay::comma_separated(list.as_slice())
)
}
Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
write!(
f,
"ROLLUP ({})",
ExprListDisplay::comma_separated(exprs.as_slice())
)
}
Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
write!(f, "GROUPING SETS (")?;
for exprs in lists_of_exprs.iter() {
write!(
f,
"({})",
ExprListDisplay::comma_separated(exprs.as_slice())
)?;
}
write!(f, ")")
}
Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
write!(
f,
"ROLLUP ({})",
ExprListDisplay::comma_separated(exprs.as_slice())
)
}
Expr::IsNull(expr) => write!(f, "{} IS NULL", SqlDisplay(expr)),
Expr::IsNotNull(expr) => {
write!(f, "{} IS NOT NULL", SqlDisplay(expr))
}
Expr::IsUnknown(expr) => {
write!(f, "{} IS UNKNOWN", SqlDisplay(expr))
}
Expr::IsNotUnknown(expr) => {
write!(f, "{} IS NOT UNKNOWN", SqlDisplay(expr))
}
Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SqlDisplay(expr)),
Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SqlDisplay(expr)),
Expr::IsNotTrue(expr) => {
write!(f, "{} IS NOT TRUE", SqlDisplay(expr))
}
Expr::IsNotFalse(expr) => {
write!(f, "{} IS NOT FALSE", SqlDisplay(expr))
}
Expr::Like(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}) => {
write!(
f,
"{} {}{} {}",
SqlDisplay(expr),
if *negated { "NOT " } else { "" },
if *case_insensitive { "ILIKE" } else { "LIKE" },
SqlDisplay(pattern),
)?;
if let Some(char) = escape_char {
write!(f, " CHAR '{char}'")?;
}
Ok(())
}
Expr::Negative(expr) => write!(f, "(- {})", SqlDisplay(expr)),
Expr::Not(expr) => write!(f, "NOT {}", SqlDisplay(expr)),
Expr::Unnest(Unnest { expr }) => {
write!(f, "UNNEST({})", SqlDisplay(expr))
}
Expr::SimilarTo(Like {
negated,
expr,
pattern,
escape_char,
..
}) => {
write!(
f,
"{} {} {}",
SqlDisplay(expr),
if *negated {
"NOT SIMILAR TO"
} else {
"SIMILAR TO"
},
SqlDisplay(pattern),
)?;
if let Some(char) = escape_char {
write!(f, " CHAR '{char}'")?;
}
Ok(())
}
Expr::AggregateFunction(AggregateFunction { func, params }) => {
match func.human_display(params) {
Ok(name) => {
write!(f, "{name}")
}
Err(e) => {
write!(f, "got error from schema_name {}", e)
}
}
}
_ => write!(f, "{}", self.0),
}
}
}
/// Builds one string from the schema names of `exprs`, separated by a bare
/// comma (`","`, no trailing space).
///
/// # Errors
///
/// Propagates any `fmt::Error` raised while formatting an expression.
pub(crate) fn schema_name_from_exprs_comma_separated_without_space(
    exprs: &[Expr],
) -> Result<String, fmt::Error> {
    const BARE_COMMA: &str = ",";
    schema_name_from_exprs_inner(exprs, BARE_COMMA)
}
/// Display adapter that writes a slice of expressions with a caller-chosen
/// separator between consecutive items.
pub struct ExprListDisplay<'a> {
// Expressions to render, in order.
exprs: &'a [Expr],
// Text written between two consecutive expressions.
sep: &'a str,
}
impl<'a> ExprListDisplay<'a> {
pub fn new(exprs: &'a [Expr], sep: &'a str) -> Self {
Self { exprs, sep }
}
pub fn comma_separated(exprs: &'a [Expr]) -> Self {
Self::new(exprs, ", ")
}
}
impl Display for ExprListDisplay<'_> {
    /// Writes each expression through `SqlDisplay`, inserting the separator
    /// before every item except the first.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        for (position, expr) in self.exprs.iter().enumerate() {
            if position > 0 {
                f.write_str(self.sep)?;
            }
            write!(f, "{}", SqlDisplay(expr))?;
        }
        Ok(())
    }
}
/// Builds one string from the schema names of `exprs`, separated by `", "`.
///
/// # Errors
///
/// Propagates any `fmt::Error` raised while formatting an expression.
pub fn schema_name_from_exprs(exprs: &[Expr]) -> Result<String, fmt::Error> {
    const COMMA_SPACE: &str = ", ";
    schema_name_from_exprs_inner(exprs, COMMA_SPACE)
}
fn schema_name_from_exprs_inner(exprs: &[Expr], sep: &str) -> Result<String, fmt::Error> {
let mut s = String::new();
for (i, e) in exprs.iter().enumerate() {
if i > 0 {
write!(&mut s, "{sep}")?;
}
write!(&mut s, "{}", SchemaDisplay(e))?;
}
Ok(s)
}
/// Renders sort expressions as `"<expr> ASC|DESC NULLS FIRST|NULLS LAST"`,
/// joined with `", "`.
///
/// The sort's inner expression is written with its regular `Display`
/// implementation.
///
/// # Errors
///
/// Propagates any `fmt::Error` raised while formatting an entry.
pub fn schema_name_from_sorts(sorts: &[Sort]) -> Result<String, fmt::Error> {
    let mut out = String::new();
    // Empty before the first entry, ", " before every later one.
    let mut separator = "";
    for sort in sorts {
        let direction = match sort.asc {
            true => "ASC",
            false => "DESC",
        };
        let nulls = match sort.nulls_first {
            true => "NULLS FIRST",
            false => "NULLS LAST",
        };
        write!(&mut out, "{separator}{} {direction} {nulls}", sort.expr)?;
        separator = ", ";
    }
    Ok(out)
}
/// Prefix that `Display for Expr` writes before an outer-reference column,
/// producing `outer_ref(<column>)`.
pub const OUTER_REFERENCE_COLUMN_PREFIX: &str = "outer_ref";
/// Prefix that `Display for Expr` writes before an unnest expression,
/// producing `UNNEST(<expr>)`.
pub const UNNEST_COLUMN_PREFIX: &str = "UNNEST";
impl Display for Expr {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"),
Expr::Column(c) => write!(f, "{c}"),
Expr::OuterReferenceColumn(_, c) => {
write!(f, "{OUTER_REFERENCE_COLUMN_PREFIX}({c})")
}
Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")),
Expr::Literal(v) => write!(f, "{v:?}"),
Expr::Case(case) => {
write!(f, "CASE ")?;
if let Some(e) = &case.expr {
write!(f, "{e} ")?;
}
for (w, t) in &case.when_then_expr {
write!(f, "WHEN {w} THEN {t} ")?;
}
if let Some(e) = &case.else_expr {
write!(f, "ELSE {e} ")?;
}
write!(f, "END")
}
Expr::Cast(Cast { expr, data_type }) => {
write!(f, "CAST({expr} AS {data_type:?})")
}
Expr::TryCast(TryCast { expr, data_type }) => {
write!(f, "TRY_CAST({expr} AS {data_type:?})")
}
Expr::Not(expr) => write!(f, "NOT {expr}"),
Expr::Negative(expr) => write!(f, "(- {expr})"),
Expr::IsNull(expr) => write!(f, "{expr} IS NULL"),
Expr::IsNotNull(expr) => write!(f, "{expr} IS NOT NULL"),
Expr::IsTrue(expr) => write!(f, "{expr} IS TRUE"),
Expr::IsFalse(expr) => write!(f, "{expr} IS FALSE"),
Expr::IsUnknown(expr) => write!(f, "{expr} IS UNKNOWN"),
Expr::IsNotTrue(expr) => write!(f, "{expr} IS NOT TRUE"),
Expr::IsNotFalse(expr) => write!(f, "{expr} IS NOT FALSE"),
Expr::IsNotUnknown(expr) => write!(f, "{expr} IS NOT UNKNOWN"),
Expr::Exists(Exists {
subquery,
negated: true,
}) => write!(f, "NOT EXISTS ({subquery:?})"),
Expr::Exists(Exists {
subquery,
negated: false,
}) => write!(f, "EXISTS ({subquery:?})"),
Expr::InSubquery(InSubquery {
expr,
subquery,
negated: true,
}) => write!(f, "{expr} NOT IN ({subquery:?})"),
Expr::InSubquery(InSubquery {
expr,
subquery,
negated: false,
}) => write!(f, "{expr} IN ({subquery:?})"),
Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"),
Expr::BinaryExpr(expr) => write!(f, "{expr}"),
Expr::ScalarFunction(fun) => {
fmt_function(f, fun.name(), false, &fun.args, true)
}
Expr::WindowFunction(WindowFunction { fun, params }) => match fun {
WindowFunctionDefinition::AggregateUDF(fun) => {
match fun.window_function_display_name(params) {
Ok(name) => {
write!(f, "{}", name)
}
Err(e) => {
write!(f, "got error from window_function_display_name {}", e)
}
}
}
WindowFunctionDefinition::WindowUDF(fun) => {
let WindowFunctionParams {
args,
partition_by,
order_by,
window_frame,
null_treatment,
} = params;
fmt_function(f, &fun.to_string(), false, args, true)?;
if let Some(nt) = null_treatment {
write!(f, "{}", nt)?;
}
if !partition_by.is_empty() {
write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?;
}
if !order_by.is_empty() {
write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?;
}
write!(
f,
" {} BETWEEN {} AND {}",
window_frame.units,
window_frame.start_bound,
window_frame.end_bound
)
}
},
Expr::AggregateFunction(AggregateFunction { func, params }) => {
match func.display_name(params) {
Ok(name) => {
write!(f, "{}", name)
}
Err(e) => {
write!(f, "got error from display_name {}", e)
}
}
}
Expr::Between(Between {
expr,
negated,
low,
high,
}) => {
if *negated {
write!(f, "{expr} NOT BETWEEN {low} AND {high}")
} else {
write!(f, "{expr} BETWEEN {low} AND {high}")
}
}
Expr::Like(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}) => {
write!(f, "{expr}")?;
let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" };
if *negated {
write!(f, " NOT")?;
}
if let Some(char) = escape_char {
write!(f, " {op_name} {pattern} ESCAPE '{char}'")
} else {
write!(f, " {op_name} {pattern}")
}
}
Expr::SimilarTo(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive: _,
}) => {
write!(f, "{expr}")?;
if *negated {
write!(f, " NOT")?;
}
if let Some(char) = escape_char {
write!(f, " SIMILAR TO {pattern} ESCAPE '{char}'")
} else {
write!(f, " SIMILAR TO {pattern}")
}
}
Expr::InList(InList {
expr,
list,
negated,
}) => {
if *negated {
write!(f, "{expr} NOT IN ([{}])", expr_vec_fmt!(list))
} else {
write!(f, "{expr} IN ([{}])", expr_vec_fmt!(list))
}
}
#[expect(deprecated)]
Expr::Wildcard { qualifier, options } => match qualifier {
Some(qualifier) => write!(f, "{qualifier}.*{options}"),
None => write!(f, "*{options}"),
},
Expr::GroupingSet(grouping_sets) => match grouping_sets {
GroupingSet::Rollup(exprs) => {
write!(f, "ROLLUP ({})", expr_vec_fmt!(exprs))
}
GroupingSet::Cube(exprs) => {
write!(f, "CUBE ({})", expr_vec_fmt!(exprs))
}
GroupingSet::GroupingSets(lists_of_exprs) => {
write!(
f,
"GROUPING SETS ({})",
lists_of_exprs
.iter()
.map(|exprs| format!("({})", expr_vec_fmt!(exprs)))
.collect::<Vec<String>>()
.join(", ")
)
}
},
Expr::Placeholder(Placeholder { id, .. }) => write!(f, "{id}"),
Expr::Unnest(Unnest { expr }) => {
write!(f, "{UNNEST_COLUMN_PREFIX}({expr})")
}
}
}
}
/// Writes a function call as `name([DISTINCT ]arg1, arg2, ...)`.
///
/// * `fun` - function name written before the opening parenthesis.
/// * `distinct` - when `true`, `DISTINCT ` is written before the arguments.
/// * `args` - argument expressions, joined with `", "`.
/// * `display` - `true` formats each argument with `Display`, `false` with
///   `Debug`.
fn fmt_function(
    f: &mut Formatter,
    fun: &str,
    distinct: bool,
    args: &[Expr],
    display: bool,
) -> fmt::Result {
    let rendered_args: Vec<String> = args
        .iter()
        .map(|arg| {
            if display {
                format!("{arg}")
            } else {
                format!("{arg:?}")
            }
        })
        .collect();
    let distinct_prefix = if distinct { "DISTINCT " } else { "" };
    write!(f, "{fun}({distinct_prefix}{})", rendered_args.join(", "))
}
/// Returns the name used for `expr` in a physical plan.
///
/// Columns and aliases yield their bare `name` field; every other expression
/// yields the string form of its `schema_name()`. The function always returns
/// `Ok`.
pub fn physical_name(expr: &Expr) -> Result<String> {
    let name = match expr {
        Expr::Column(column) => column.name.clone(),
        Expr::Alias(alias) => alias.name.clone(),
        other => other.schema_name().to_string(),
    };
    Ok(name)
}
// Unit tests for expression formatting, ordering, column collection and
// wildcard display.
#[cfg(test)]
mod test {
use crate::expr_fn::col;
use crate::{
case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue,
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility,
};
use sqlparser::ast;
use sqlparser::ast::{Ident, IdentWithAlias};
use std::any::Any;
// A CASE expression renders identically through `canonical_name` and `Display`.
#[test]
#[allow(deprecated)]
fn format_case_when() -> Result<()> {
let expr = case(col("a"))
.when(lit(1), lit(true))
.when(lit(0), lit(false))
.otherwise(lit(ScalarValue::Null))?;
let expected = "CASE a WHEN Int32(1) THEN Boolean(true) WHEN Int32(0) THEN Boolean(false) ELSE NULL END";
assert_eq!(expected, expr.canonical_name());
assert_eq!(expected, format!("{expr}"));
Ok(())
}
// A cast shows up in `canonical_name` / `Display` but not in the schema name.
#[test]
#[allow(deprecated)]
fn format_cast() -> Result<()> {
let expr = Expr::Cast(Cast {
expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))),
data_type: DataType::Utf8,
});
let expected_canonical = "CAST(Float32(1.23) AS Utf8)";
assert_eq!(expected_canonical, expr.canonical_name());
assert_eq!(expected_canonical, format!("{expr}"));
assert_eq!("Float32(1.23)", expr.schema_name().to_string());
Ok(())
}
// Pins the relative order of three expressions under the derived `PartialOrd`.
#[test]
fn test_partial_ord() {
let exp1 = col("a") + lit(1);
let exp2 = col("a") + lit(2);
let exp3 = !(col("a") + lit(2));
assert!(exp1 < exp2);
assert!(exp3 > exp2);
assert!(exp1 < exp3)
}
// `column_refs` finds columns under a cast and inside nested binary expressions.
#[test]
fn test_collect_expr() -> Result<()> {
{
let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
let columns = expr.column_refs();
assert_eq!(1, columns.len());
assert!(columns.contains(&Column::from_name("a")));
}
{
let expr = col("a") + col("b") + lit(1);
let columns = expr.column_refs();
assert_eq!(2, columns.len());
assert!(columns.contains(&Column::from_name("a")));
assert!(columns.contains(&Column::from_name("b")));
}
Ok(())
}
// Comparison and boolean builder methods display with the expected operator.
#[test]
fn test_logical_ops() {
assert_eq!(
format!("{}", lit(1u32).eq(lit(2u32))),
"UInt32(1) = UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).not_eq(lit(2u32))),
"UInt32(1) != UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).gt(lit(2u32))),
"UInt32(1) > UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).gt_eq(lit(2u32))),
"UInt32(1) >= UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).lt(lit(2u32))),
"UInt32(1) < UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).lt_eq(lit(2u32))),
"UInt32(1) <= UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).and(lit(2u32))),
"UInt32(1) AND UInt32(2)"
);
assert_eq!(
format!("{}", lit(1u32).or(lit(2u32))),
"UInt32(1) OR UInt32(2)"
);
}
// The volatility given to a UDF's signature is what the wrapped `ScalarUDF` reports.
#[test]
fn test_is_volatile_scalar_func() {
// Minimal UDF whose only configurable part is its signature.
#[derive(Debug)]
struct TestScalarUDF {
signature: Signature,
}
impl ScalarUDFImpl for TestScalarUDF {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"TestScalarUDF"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
}
fn invoke_with_args(
&self,
_args: ScalarFunctionArgs,
) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
}
}
let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
}));
assert_ne!(udf.signature().volatility, Volatility::Volatile);
let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
signature: Signature::uniform(
1,
vec![DataType::Float32],
Volatility::Volatile,
),
}));
assert_eq!(udf.signature().volatility, Volatility::Volatile);
}
// Brings the parent module's items into scope for every test in this module
// (`use` declarations apply module-wide regardless of position).
use super::*;
// Each wildcard option (ILIKE, EXCLUDE, EXCEPT, REPLACE, RENAME) is displayed
// after the `*`.
#[test]
fn test_display_wildcard() {
assert_eq!(format!("{}", wildcard()), "*");
assert_eq!(format!("{}", qualified_wildcard("t1")), "t1.*");
assert_eq!(
format!(
"{}",
wildcard_with_options(wildcard_options(
Some(IlikeSelectItem {
pattern: "c1".to_string()
}),
None,
None,
None,
None
))
),
"* ILIKE 'c1'"
);
assert_eq!(
format!(
"{}",
wildcard_with_options(wildcard_options(
None,
Some(ExcludeSelectItem::Multiple(vec![
Ident::from("c1"),
Ident::from("c2")
])),
None,
None,
None
))
),
"* EXCLUDE (c1, c2)"
);
assert_eq!(
format!(
"{}",
wildcard_with_options(wildcard_options(
None,
None,
Some(ExceptSelectItem {
first_element: Ident::from("c1"),
additional_elements: vec![Ident::from("c2")]
}),
None,
None
))
),
"* EXCEPT (c1, c2)"
);
assert_eq!(
format!(
"{}",
wildcard_with_options(wildcard_options(
None,
None,
None,
Some(PlannedReplaceSelectItem {
items: vec![ReplaceSelectElement {
expr: ast::Expr::Identifier(Ident::from("c1")),
column_name: Ident::from("a1"),
as_keyword: false
}],
planned_expressions: vec![]
}),
None
))
),
"* REPLACE (c1 a1)"
);
assert_eq!(
format!(
"{}",
wildcard_with_options(wildcard_options(
None,
None,
None,
None,
Some(RenameSelectItem::Multiple(vec![IdentWithAlias {
ident: Ident::from("c1"),
alias: Ident::from("a1")
}]))
))
),
"* RENAME (c1 AS a1)"
)
}
// A relation-qualified alias is displayed as `relation.name`.
#[test]
fn test_schema_display_alias_with_relation() {
assert_eq!(
format!(
"{}",
SchemaDisplay(
&lit(1).alias_qualified("table_name".into(), "column_name")
)
),
"table_name.column_name"
);
}
// An alias without a relation is displayed as just its name.
#[test]
fn test_schema_display_alias_without_relation() {
assert_eq!(
format!(
"{}",
SchemaDisplay(&lit(1).alias_qualified(None::<&str>, "column_name"))
),
"column_name"
);
}
// Test helper: assembles a `WildcardOptions` from its five optional parts.
fn wildcard_options(
opt_ilike: Option<IlikeSelectItem>,
opt_exclude: Option<ExcludeSelectItem>,
opt_except: Option<ExceptSelectItem>,
opt_replace: Option<PlannedReplaceSelectItem>,
opt_rename: Option<RenameSelectItem>,
) -> WildcardOptions {
WildcardOptions {
ilike: opt_ilike,
exclude: opt_exclude,
except: opt_except,
replace: opt_replace,
rename: opt_rename,
}
}
}