use std::mem::swap;
use hamelin_lib::operator::Operator;
use hamelin_lib::sql::expression::apply::BinaryOperatorApply;
use hamelin_lib::sql::expression::identifier::Identifier;
use hamelin_lib::sql::expression::{OrderByExpression, SQLExpression};
use hamelin_lib::sql::query::projection::{Binding, Projection};
use hamelin_lib::sql::query::SQLQuery;
use crate::env::Environment;
/// Reports whether `query` involves window functions at this SELECT level.
///
/// Returns `true` in either of two cases:
/// - the query's named-window list (`query.windows.windows`) is non-empty;
/// - any binding projection's expression contains a window expression.
///
/// Plain column projections never count.
fn query_has_window_functions(query: &SQLQuery) -> bool {
    if !query.windows.windows.is_empty() {
        return true;
    }
    query.projections.iter().any(|projection| {
        if let Projection::Binding(Binding { expression, .. }) = projection {
            expression_has_window_function(expression)
        } else {
            false
        }
    })
}
/// Walks `expr` recursively and reports whether a window expression occurs anywhere in it.
///
/// Every `SQLExpression` variant is listed explicitly (there is no wildcard arm). Adding a
/// variant to `SQLExpression` therefore fails to compile here until this walk is updated.
fn expression_has_window_function(expr: &SQLExpression) -> bool {
    let recurse = expression_has_window_function;
    match expr {
        // Terminal cases.
        SQLExpression::WindowExpression(_) => true,
        SQLExpression::Leaf(_) => false,

        // Nodes wrapping a single sub-expression.
        SQLExpression::UnaryOperatorApply(u) => recurse(&u.operand),
        SQLExpression::Cast(c) => recurse(&c.expression),
        SQLExpression::TryCast(t) => recurse(&t.expression),
        SQLExpression::Dot(d) => recurse(&d.expression),
        SQLExpression::Lambda(l) => recurse(&l.body),
        SQLExpression::OrderByExpression(o) => recurse(&o.expression),
        // Only the `value` operand of these built-ins is inspected.
        SQLExpression::RegexpExtractFunction(r) => recurse(&r.value),
        SQLExpression::RegexpCountFunction(r) => recurse(&r.value),
        SQLExpression::ExtractFunction(e) => recurse(&e.value),

        // Nodes with exactly two sub-expressions.
        SQLExpression::BinaryOperatorApply(b) => recurse(&b.left) || recurse(&b.right),
        SQLExpression::SQLIndexLookup(i) => recurse(&i.expression) || recurse(&i.index),

        // Nodes carrying a list of sub-expressions.
        SQLExpression::FunctionCallApply(f) => f.arguments.iter().any(recurse),
        SQLExpression::RowLiteral(r) => r.values.iter().any(recurse),
        SQLExpression::ArrayLiteral(a) => a.elements.iter().any(recurse),
        SQLExpression::TupleLiteral(t) => t.elements.iter().any(recurse),

        // CASE: the optional ELSE branch plus every WHEN condition/result pair.
        SQLExpression::Case(c) => {
            let else_has = c.else_expression.as_ref().is_some_and(|e| recurse(e));
            else_has
                || c.when_expressions
                    .iter()
                    .any(|(cond, result)| recurse(cond) || recurse(result))
        }
    }
}
/// Returns a copy of `query` whose projection list starts with `projections`.
///
/// Steps:
/// 1. If any new projection references a column projected by `query` itself (per
///    `references_column_in_projections`), the query is pushed down first
///    (`SQLQuery::push_down`).
/// 2. Existing projections named like one of the new projections are dropped via
///    `remove_projections`, so the new definitions take their place.
/// 3. The new projections come first, followed by the surviving existing ones in their
///    original order.
/// 4. Each column projection from `env` is appended unless a projection whose final
///    identifier segment matches is already present. This includes env projections
///    appended earlier in the same loop.
pub fn prepend_projections(
    query: &SQLQuery,
    mut projections: Vec<Projection>,
    env: &Environment,
) -> SQLQuery {
    let mut ret = query.clone();
    if ret.references_column_in_projections(&projections) {
        ret = ret.push_down();
    }
    let new_names: Vec<Identifier> = projections
        .iter()
        .map(|p| match p {
            Projection::Binding(Binding { name, .. }) => name.clone().into(),
            Projection::ColumnProjection(cp) => cp.identifier.clone(),
        })
        .collect();
    ret = ret.remove_projections(&new_names[..]);
    // After the swap, `ret.projections` holds the new projections and `projections` holds
    // the surviving old ones. `projections` is owned and unused afterwards, so move its
    // elements onto the end instead of cloning each one.
    swap(&mut ret.projections, &mut projections);
    ret.projections.extend(projections);
    for projection in env.get_column_projections() {
        // Dedup on the last identifier segment only, so a qualified and an unqualified
        // name with the same final segment count as the same column.
        let already_projected = ret.projections.iter().any(|p| match p {
            Projection::ColumnProjection(cp) => {
                cp.identifier.last() == projection.identifier.last()
            }
            Projection::Binding(Binding { name, .. }) => {
                let id: Identifier = name.clone().into();
                id.last() == projection.identifier.last()
            }
        });
        if !already_projected {
            ret.projections.push(projection.into());
        }
    }
    ret
}
pub fn add_filter_condition(
query: &SQLQuery,
condition: SQLExpression,
env: &Environment,
) -> SQLQuery {
let mut ret = query.clone();
let needs_pushdown = ret
.references_columns_in_column_refs(&condition.get_column_references()[..])
|| query_has_window_functions(&ret);
if needs_pushdown {
ret = ret.push_down().select(
env.get_column_projections()
.into_iter()
.map(|cp| cp.into())
.collect(),
);
}
let new_where = ret
.where_
.map(|current| {
BinaryOperatorApply::new(
Operator::And.try_into().unwrap(),
condition.clone(),
current.clone(),
)
.into()
})
.unwrap_or(condition);
ret.where_ = Some(new_where);
ret
}
/// Returns `query` ordered by `order_by`.
///
/// In two cases the query is first pushed down and re-selected with `env`'s column
/// projections, and the ordering is placed on that outer query:
/// - any ordering expression references a column defined by the query's own projections;
/// - the query already has a `LIMIT`. Within one SELECT, SQL applies ORDER BY before
///   LIMIT, so sorting at the same level would change which rows the limit keeps.
pub fn add_order_expression(
    query: SQLQuery,
    order_by: Vec<OrderByExpression>,
    env: &Environment,
) -> SQLQuery {
    let push_down_needed = query.limit.is_some()
        || order_by
            .iter()
            .map(|e| SQLExpression::from(e.clone()).get_column_references())
            .any(|c| query.references_columns_in_column_refs(&c[..]));
    let ret_query = if push_down_needed {
        query.push_down().select(
            env.get_column_projections()
                .into_iter()
                .map(Projection::from)
                .collect(),
        )
    } else {
        // `query` is owned, so it can be used directly without cloning.
        query
    };
    ret_query.order_by(order_by)
}
pub fn apply_limit(query: SQLQuery, limit: SQLExpression, env: &Environment) -> SQLQuery {
if query.limit.is_some() {
query
.push_down()
.select(
env.get_column_projections()
.into_iter()
.map(|cp| Projection::from(cp))
.collect(),
)
.limit(limit)
} else {
query.limit(limit)
}
}