use std::{cmp::Ordering, sync::Arc, vec};
use super::{
dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle,
rewrite::TableAliasRewriter, Unparser,
};
use datafusion_common::{
internal_err,
tree_node::{Transformed, TransformedResult, TreeNode},
Column, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::{
expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan,
LogicalPlanBuilder, Projection, SortExpr, Unnest, Window,
};
use indexmap::IndexSet;
use sqlparser::ast;
use sqlparser::tokenizer::Span;
/// Walks down the single-input chain beneath `plan` looking for the
/// [`Aggregate`] that belongs to the current SELECT.
///
/// The search gives up (returns `None`) when it meets:
/// * a node with more than one input (or no input at all),
/// * a `TableScan`,
/// * a `Projection` while `already_projected` is `true`.
///
/// The first `Projection` met while `already_projected` is `false` is walked
/// through, with the flag switched on for the rest of the descent; a second
/// projection presumably starts a different (nested) SELECT — TODO confirm.
pub(crate) fn find_agg_node_within_select(
    plan: &LogicalPlan,
    already_projected: bool,
) -> Option<&Aggregate> {
    let inputs = plan.inputs();
    if inputs.len() > 1 {
        return None;
    }
    let child = *inputs.first()?;

    match child {
        LogicalPlan::Aggregate(agg) => Some(agg),
        LogicalPlan::TableScan(_) => None,
        LogicalPlan::Projection(_) if already_projected => None,
        LogicalPlan::Projection(_) => find_agg_node_within_select(child, true),
        _ => find_agg_node_within_select(child, already_projected),
    }
}
/// Walks down the single-input chain beneath `plan` looking for an
/// [`Unnest`] node within the current SELECT.
///
/// Returns `None` as soon as the walk reaches a node with more than one input
/// (or none), a `TableScan`, or a `Projection`.
pub(crate) fn find_unnest_node_within_select(plan: &LogicalPlan) -> Option<&Unnest> {
    let inputs = plan.inputs();
    if inputs.len() > 1 {
        return None;
    }
    let child = *inputs.first()?;

    match child {
        LogicalPlan::Unnest(unnest) => Some(unnest),
        LogicalPlan::TableScan(_) | LogicalPlan::Projection(_) => None,
        _ => find_unnest_node_within_select(child),
    }
}
/// Walks down the single-input chain beneath `plan` looking for an
/// [`Unnest`] node, stopping at the first relation boundary.
///
/// Returns `None` when the walk reaches, before any `Unnest`:
/// * a node with more than one input (or no input at all),
/// * a `TableScan`,
/// * a `Subquery` or `SubqueryAlias`.
///
/// Unlike [`find_unnest_node_within_select`], `Projection` nodes do not end
/// the search.
pub(crate) fn find_unnest_node_until_relation(plan: &LogicalPlan) -> Option<&Unnest> {
    let inputs = plan.inputs();
    if inputs.len() > 1 {
        return None;
    }
    let child = *inputs.first()?;

    match child {
        LogicalPlan::Unnest(unnest) => Some(unnest),
        LogicalPlan::TableScan(_)
        | LogicalPlan::Subquery(_)
        | LogicalPlan::SubqueryAlias(_) => None,
        // Recurse into *this* function so the relation-boundary checks above
        // apply at every depth, not only to the immediate child.
        _ => find_unnest_node_until_relation(child),
    }
}
/// Collects every [`Window`] node in the single-input chain beneath `plan`
/// that belongs to the current SELECT, appending them to `prev_windows` in
/// top-down order.
///
/// The walk ends, returning the windows gathered so far, when it reaches:
/// * a node that does not have exactly one input (joins/unions, or leaves
///   such as a `TableScan`, `EmptyRelation` or `Values`),
/// * a `TableScan`,
/// * a `Projection` while `already_projected` is `true`.
///
/// The first `Projection` met while `already_projected` is `false` is walked
/// through with the flag switched on.
pub(crate) fn find_window_nodes_within_select<'a>(
    plan: &'a LogicalPlan,
    mut prev_windows: Option<Vec<&'a Window>>,
    already_projected: bool,
) -> Option<Vec<&'a Window>> {
    let inputs = plan.inputs();
    // Leaves must not discard what was collected: a Window directly above an
    // input-less node (e.g. `SELECT row_number() OVER ()` with no FROM) would
    // otherwise be lost, so every "not exactly one input" case returns
    // the accumulated windows.
    let input = match inputs.as_slice() {
        [input] => *input,
        _ => return prev_windows,
    };

    match input {
        LogicalPlan::Window(window) => {
            prev_windows.get_or_insert_with(Vec::new).push(window);
            find_window_nodes_within_select(input, prev_windows, already_projected)
        }
        LogicalPlan::Projection(_) if already_projected => prev_windows,
        LogicalPlan::Projection(_) => {
            find_window_nodes_within_select(input, prev_windows, true)
        }
        LogicalPlan::TableScan(_) => prev_windows,
        _ => find_window_nodes_within_select(input, prev_windows, already_projected),
    }
}
/// Rewrites column references in `expr` that name an unnested list column of
/// `unnest` back into `UNNEST(<source expression>)`.
///
/// The source expression is looked up by the column's index in
/// `unnest.schema`, within the expressions of the `Projection` that feeds
/// the Unnest. Columns that are not unnested list outputs are left untouched.
///
/// # Errors
/// Returns an internal error when a column matches an unnested list output
/// but its source expression cannot be located (column missing from the
/// schema, input not a `Projection`, or index out of range).
pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result<Expr> {
    expr.transform(|sub_expr| {
        if let Expr::Column(col_ref) = &sub_expr {
            // Only List/Array columns listed in `list_type_columns` are handled.
            let names_unnested_list = unnest
                .list_type_columns
                .iter()
                .any(|e| e.1.output_column.name == col_ref.name);

            if names_unnested_list {
                let source = unnest
                    .schema
                    .index_of_column(col_ref)
                    .ok()
                    .and_then(|idx| match unnest.input.as_ref() {
                        LogicalPlan::Projection(Projection {
                            expr: projected, ..
                        }) => projected.get(idx),
                        _ => None,
                    });

                return match source {
                    Some(inner) => Ok(Transformed::yes(Expr::Unnest(
                        expr::Unnest::new(inner.clone()),
                    ))),
                    None => internal_err!(
                        "Tried to unproject unnest expr for column '{}' that was not found in the provided Unnest!", &col_ref.name
                    ),
                };
            }
        }
        Ok(Transformed::no(sub_expr))
    })
    .map(|e| e.data)
}
/// Replaces every column reference in `expr` with the aggregate or grouping
/// expression of `agg` that produced it.
///
/// When a column is not an output of `agg` but `windows` is given, it is
/// matched against the window expressions by schema name. A matched window
/// expression is itself unprojected against `agg` (without windows), because
/// it can wrap aggregate outputs, e.g. `avg(sum(x)) OVER (...)`.
///
/// # Errors
/// Returns an internal error if a column is found in neither `agg` nor
/// `windows`, or if [`find_agg_expr`] fails.
pub(crate) fn unproject_agg_exprs(
    expr: Expr,
    agg: &Aggregate,
    windows: Option<&[&Window]>,
) -> Result<Expr> {
    expr.transform(|sub_expr| match sub_expr {
        Expr::Column(c) => {
            if let Some(agg_expr) = find_agg_expr(agg, &c)? {
                return Ok(Transformed::yes(agg_expr.clone()));
            }

            match windows.and_then(|w| find_window_expr(w, &c.name).cloned()) {
                Some(window_expr) => Ok(Transformed::yes(unproject_agg_exprs(
                    window_expr,
                    agg,
                    None,
                )?)),
                None => internal_err!(
                    "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
                ),
            }
        }
        other => Ok(Transformed::no(other)),
    })
    .map(|e| e.data)
}
/// Replaces column references in `expr` whose name equals the schema name of
/// a window expression in `windows` with a clone of that window expression.
/// All other sub-expressions are returned unchanged.
pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result<Expr> {
    expr.transform(|sub_expr| match sub_expr {
        Expr::Column(c) => Ok(match find_window_expr(windows, &c.name) {
            Some(window_expr) => Transformed::yes(window_expr.clone()),
            None => Transformed::no(Expr::Column(c)),
        }),
        other => Ok(Transformed::no(other)),
    })
    .map(|e| e.data)
}
/// Returns the expression of `agg` that produces `column` in `agg.schema`.
///
/// Without a grouping set, the output schema is laid out as the group
/// expressions followed by the aggregate expressions. With a single
/// `GroupingSet` group expression, the layout is the flattened grouping
/// expressions, then one internal grouping-id column, then the aggregate
/// expressions.
///
/// Returns `Ok(None)` when `column` is not part of `agg.schema` or the index
/// lies past the available expressions.
///
/// # Errors
/// Returns an internal error if `column` refers to the internal grouping-id
/// column, or if the grouping set cannot be flattened.
fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a Expr>> {
    let index = match agg.schema.index_of_column(column) {
        Ok(index) => index,
        Err(_) => return Ok(None),
    };

    if !matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) {
        return Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index));
    }

    // Grouping sets must be resolved against their flattened expression list.
    let grouping_exprs = grouping_set_to_exprlist(agg.group_expr.as_slice())?;
    let group_count = grouping_exprs.len();
    match index.cmp(&group_count) {
        Ordering::Less => Ok(grouping_exprs.get(index).copied()),
        Ordering::Equal => {
            internal_err!("Tried to unproject column referring to internal grouping id")
        }
        // Skip the grouping-id column that sits between groups and aggregates.
        Ordering::Greater => Ok(agg.aggr_expr.get(index - group_count - 1)),
    }
}
/// Returns the first window expression, scanning `windows` in order, whose
/// schema name equals `column_name`.
fn find_window_expr<'a>(
    windows: &'a [&'a Window],
    column_name: &'a str,
) -> Option<&'a Expr> {
    windows.iter().find_map(|window| {
        window
            .window_expr
            .iter()
            .find(|candidate| candidate.schema_name().to_string() == column_name)
    })
}
/// Rewrites the expression of `sort_expr` so it can be emitted in an
/// ORDER BY clause.
///
/// * Aliases are stripped down to their inner expression.
/// * Qualified columns (with a relation) are kept as-is.
/// * Unqualified columns produced by `agg` are unprojected into their
///   aggregate expressions.
/// * Unqualified columns that resolve, through `input`'s schema, to a
///   scalar-function expression of an `input` `Projection` are replaced by
///   that scalar function (the planner turns a repeated SELECT expression
///   into a column reference in ORDER BY).
///
/// # Errors
/// Propagates errors from [`unproject_agg_exprs`].
pub(crate) fn unproject_sort_expr(
    mut sort_expr: SortExpr,
    agg: Option<&Aggregate>,
    input: &LogicalPlan,
) -> Result<SortExpr> {
    let rewritten = sort_expr.expr.transform(|sub_expr| match sub_expr {
        Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)),
        Expr::Column(col) if col.relation.is_some() => {
            Ok(Transformed::no(Expr::Column(col)))
        }
        Expr::Column(col) => {
            if let Some(agg) = agg.filter(|a| a.schema.is_column_from_schema(&col)) {
                let unprojected = unproject_agg_exprs(Expr::Column(col), agg, None)?;
                return Ok(Transformed::yes(unprojected));
            }

            let projected = match input {
                LogicalPlan::Projection(Projection { expr, schema, .. }) => schema
                    .index_of_column(&col)
                    .ok()
                    .and_then(|idx| expr.get(idx)),
                _ => None,
            };
            if let Some(Expr::ScalarFunction(scalar_fn)) = projected {
                return Ok(Transformed::yes(Expr::ScalarFunction(scalar_fn.clone())));
            }

            Ok(Transformed::no(Expr::Column(col)))
        }
        other => Ok(Transformed::no(other)),
    })?;

    sort_expr.expr = rewritten.data;
    Ok(sort_expr)
}
/// Tries to collapse `plan` into a bare table scan plus a list of filter
/// predicates.
///
/// Succeeds only when `plan` is a chain made solely of `SubqueryAlias` and
/// `Filter` nodes ending in a `TableScan`. The result is a freshly built scan
/// of the same table and projection (re-aliased when a `SubqueryAlias` was
/// seen; the last alias met on the way down wins). It is returned together
/// with the de-duplicated `Filter` predicates followed by the scan's own
/// pushed-down filters, in first-seen order. Pushed-down filters are
/// rewritten to the alias name when one is present.
///
/// Returns `Ok(None)` as soon as any other node kind appears in the chain.
///
/// # Errors
/// Propagates failures from rewriting the filters or building the new scan.
pub(crate) fn try_transform_to_simple_table_scan_with_filters(
    plan: &LogicalPlan,
) -> Result<Option<(LogicalPlan, Vec<Expr>)>> {
    // `IndexSet::insert` ignores duplicates and keeps first-insertion order.
    let mut collected: IndexSet<Expr> = IndexSet::new();
    let mut alias_name = None;
    let mut current = plan;

    loop {
        match current {
            LogicalPlan::SubqueryAlias(alias) => {
                alias_name = Some(alias.alias.clone());
                current = alias.input.as_ref();
            }
            LogicalPlan::Filter(filter) => {
                collected.insert(filter.predicate.clone());
                current = filter.input.as_ref();
            }
            LogicalPlan::TableScan(table_scan) => {
                let table_schema = table_scan.source.schema();
                // Pushed-down filters reference the base table; re-point them
                // at the alias when the scan is aliased.
                let mut rewriter = alias_name.as_ref().map(|name| TableAliasRewriter {
                    table_schema: &table_schema,
                    alias_name: name.clone(),
                });

                for pushed in table_scan.filters.iter().cloned() {
                    let pushed = match rewriter.as_mut() {
                        Some(r) => pushed.rewrite(r).data()?,
                        None => pushed,
                    };
                    collected.insert(pushed);
                }

                let mut builder = LogicalPlanBuilder::scan(
                    table_scan.table_name.clone(),
                    Arc::clone(&table_scan.source),
                    table_scan.projection.clone(),
                )?;
                if let Some(name) = alias_name.take() {
                    builder = builder.alias(name)?;
                }

                let scan = builder.build()?;
                return Ok(Some((scan, collected.into_iter().collect())));
            }
            _ => return Ok(None),
        }
    }
}
pub(crate) fn date_part_to_sql(
unparser: &Unparser,
style: DateFieldExtractStyle,
date_part_args: &[Expr],
) -> Result<Option<ast::Expr>> {
match (style, date_part_args.len()) {
(DateFieldExtractStyle::Extract, 2) => {
let date_expr = unparser.expr_to_sql(&date_part_args[1])?;
if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] {
let field = match field.to_lowercase().as_str() {
"year" => ast::DateTimeField::Year,
"month" => ast::DateTimeField::Month,
"day" => ast::DateTimeField::Day,
"hour" => ast::DateTimeField::Hour,
"minute" => ast::DateTimeField::Minute,
"second" => ast::DateTimeField::Second,
_ => return Ok(None),
};
return Ok(Some(ast::Expr::Extract {
field,
expr: Box::new(date_expr),
syntax: ast::ExtractSyntax::From,
}));
}
}
(DateFieldExtractStyle::Strftime, 2) => {
let column = unparser.expr_to_sql(&date_part_args[1])?;
if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] {
let field = match field.to_lowercase().as_str() {
"year" => "%Y",
"month" => "%m",
"day" => "%d",
"hour" => "%H",
"minute" => "%M",
"second" => "%S",
_ => return Ok(None),
};
return Ok(Some(ast::Expr::Function(ast::Function {
name: ast::ObjectName::from(vec![ast::Ident {
value: "strftime".to_string(),
quote_style: None,
span: Span::empty(),
}]),
args: ast::FunctionArguments::List(ast::FunctionArgumentList {
duplicate_treatment: None,
args: vec![
ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
ast::Expr::value(ast::Value::SingleQuotedString(
field.to_string(),
)),
)),
ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(column)),
],
clauses: vec![],
}),
filter: None,
null_treatment: None,
over: None,
within_group: vec![],
parameters: ast::FunctionArguments::None,
uses_odbc_syntax: false,
})));
}
}
(DateFieldExtractStyle::DatePart, _) => {
return Ok(Some(
unparser.scalar_function_to_sql("date_part", date_part_args)?,
));
}
_ => {}
};
Ok(None)
}
/// Renders a character-length call using the function name selected by
/// `style` (`character_length` or `length`), passing `character_length_args`
/// through unchanged. Always returns `Ok(Some(_))` unless conversion fails.
pub(crate) fn character_length_to_sql(
    unparser: &Unparser,
    style: CharacterLengthStyle,
    character_length_args: &[Expr],
) -> Result<Option<ast::Expr>> {
    let rendered = unparser.scalar_function_to_sql(
        match style {
            CharacterLengthStyle::Length => "length",
            CharacterLengthStyle::CharacterLength => "character_length",
        },
        character_length_args,
    )?;
    Ok(Some(rendered))
}
/// Renders `from_unixtime(x)` for SQLite as `datetime(x, 'unixepoch')`.
///
/// # Errors
/// Returns an internal error unless exactly one argument is supplied, and
/// propagates conversion errors from the unparser.
pub(crate) fn sqlite_from_unixtime_to_sql(
    unparser: &Unparser,
    from_unixtime_args: &[Expr],
) -> Result<Option<ast::Expr>> {
    match from_unixtime_args {
        [epoch_value] => {
            let modifier =
                Expr::Literal(ScalarValue::Utf8(Some("unixepoch".to_string())), None);
            let rendered = unparser
                .scalar_function_to_sql("datetime", &[epoch_value.clone(), modifier])?;
            Ok(Some(rendered))
        }
        _ => internal_err!(
            "from_unixtime for SQLite expects 1 argument, found {}",
            from_unixtime_args.len()
        ),
    }
}
/// Renders `date_trunc(unit, source)` for SQLite as
/// `strftime('<format>', source)`, where the format keeps only the
/// components down to `unit`.
///
/// Returns `Ok(None)` when `unit` is not a Utf8 string literal or is not one
/// of year, month, day, hour, minute or second (case-insensitive).
///
/// # Errors
/// Returns an internal error unless exactly two arguments are supplied, and
/// propagates conversion errors from the unparser.
pub(crate) fn sqlite_date_trunc_to_sql(
    unparser: &Unparser,
    date_trunc_args: &[Expr],
) -> Result<Option<ast::Expr>> {
    let (unit_arg, source) = match date_trunc_args {
        [unit_arg, source] => (unit_arg, source),
        _ => {
            return internal_err!(
                "date_trunc for SQLite expects 2 arguments, found {}",
                date_trunc_args.len()
            )
        }
    };

    let unit = match unit_arg {
        Expr::Literal(ScalarValue::Utf8(Some(unit)), _) => unit,
        _ => return Ok(None),
    };

    let format = match unit.to_lowercase().as_str() {
        "year" => "%Y",
        "month" => "%Y-%m",
        "day" => "%Y-%m-%d",
        "hour" => "%Y-%m-%d %H",
        "minute" => "%Y-%m-%d %H:%M",
        "second" => "%Y-%m-%d %H:%M:%S",
        _ => return Ok(None),
    };

    let format_literal = Expr::Literal(ScalarValue::Utf8(Some(format.to_string())), None);
    let rendered =
        unparser.scalar_function_to_sql("strftime", &[format_literal, source.clone()])?;
    Ok(Some(rendered))
}