use crate::ast::assignment_clause::HamelinAssignmentClause;
use crate::ast::expression::HamelinExpression;
use crate::ast::pipeline::HamelinPipeline;
use crate::ast::sort_expression::HamelinSortExpression;
use crate::env::Environment;
use crate::translation::projection_builder::{ProjectionBuilder, ProjectionBuilderExt};
use crate::translation::sql_query_helpers::prepend_projections;
use crate::translation::{PendingQuery, PendingQueryResult};
use hamelin_lib::antlr::hamelinparser::{
ExpressionContextAll, GroupClauseContextAttrs, WindowCommandContext, WindowCommandContextAttrs,
};
use hamelin_lib::err::{TranslationError, TranslationErrors};
use hamelin_lib::func::def::{FunctionTranslationContext, SpecialPosition};
use hamelin_lib::sql::expression::identifier::SimpleIdentifier;
use hamelin_lib::sql::expression::literal::ColumnReference;
use hamelin_lib::sql::expression::operator::Operator;
use hamelin_lib::sql::expression::{Direction, Leaf, OrderByExpression, SQLExpression};
use hamelin_lib::sql::query::window::{
FrameBoundary, FrameType, NamedWindowReference, WindowFrame, WindowReference,
WindowSpecification,
};
use hamelin_lib::types::matcher::{Matcher, NumericMatcher};
use hamelin_lib::types::{Type, CALENDAR_INTERVAL, INTERVAL, ROWS, TIMESTAMP};
use std::rc::Rc;
/// Applies a `WINDOW` command to the query held in `pending_query_result`.
///
/// The steps, in order:
/// 1. Build the window specification. If that fails, the errors are recorded and a default
///    specification is used so the rest of the command is still translated.
/// 2. Decide how expressions refer to the window: inline when the command has exactly one
///    assignment clause, otherwise through a named window `w<N>`, where `N` is the number
///    of windows already registered on the query.
/// 3. Translate every assignment clause and every group clause into projections.
/// 4. Push the current query down when it references any column used by the window's
///    partition or ordering expressions, then prepend the new projections.
/// 5. Register the named window (if one was chosen) and store the updated query and
///    environment back into `pending_query_result`.
///
/// Translation failures are accumulated in `pending_query_result.errors`; nothing is returned.
pub fn translate(
    ctx: &WindowCommandContext<'static>,
    pipeline: &HamelinPipeline,
    pending_query_result: &mut PendingQueryResult,
) {
    let window_spec = get_window_specification(ctx, pipeline, &pending_query_result.translation)
        .unwrap_or_else(|spec_errors| {
            pending_query_result.errors.extend(spec_errors);
            WindowSpecification::default()
        });

    let assignment_clauses = ctx.assignmentClause_all();
    let window_ref: WindowReference = match assignment_clauses.len() {
        // A single window expression can carry its specification inline.
        1 => window_spec.clone().into(),
        // Otherwise the specification is shared through a named window.
        _ => {
            let window_name = format!(
                "w{}",
                pending_query_result.translation.query.windows.windows.len()
            );
            NamedWindowReference::new(SimpleIdentifier::new(&window_name)).into()
        }
    };

    let fctx = FunctionTranslationContext::default()
        .with_special_allowed(SpecialPosition::Agg)
        .with_special_allowed(SpecialPosition::Window)
        .with_window(window_ref.clone());

    let mut new_projections = ProjectionBuilder::default();

    // Assignment clauses: the windowed expressions themselves.
    for clause in assignment_clauses {
        let expr_ctx = pipeline
            .context
            .expression_translation_context(&pending_query_result.translation.env, fctx.clone());
        let outcome = HamelinAssignmentClause::new(clause.clone(), expr_ctx).to_sql();
        if let Some((identifier, translation)) =
            pending_query_result.errors.consume_errors(outcome)
        {
            new_projections.intialize_struct_reference_then_bind(
                identifier,
                translation.sql,
                translation.typ,
                &pending_query_result.translation.env,
            );
        }
    }

    // Group clauses: the grouping expressions are projected as well.
    for group_clause in ctx.groupClause_all() {
        let expr_ctx = pipeline
            .context
            .expression_translation_context(&pending_query_result.translation.env, fctx.clone());
        let outcome =
            TranslationErrors::expect(group_clause.as_ref(), group_clause.assignmentClause())
                .and_then(|assignment_clause| {
                    HamelinAssignmentClause::new(assignment_clause.clone(), expr_ctx).to_sql()
                });
        if let Some((identifier, translation)) =
            pending_query_result.errors.consume_errors(outcome)
        {
            new_projections.bind(identifier, translation.sql, translation.typ);
        }
    }

    // The resulting environment is the previous one with the new bindings layered on top.
    let projected_env = Environment::new(new_projections.clone().build_hamelin_type());
    let env = pending_query_result
        .translation
        .env
        .clone()
        .prepend_overwrite(&projected_env);

    // Every column the window's PARTITION BY / ORDER BY expressions refer to.
    let mut window_column_refs: Vec<ColumnReference> = Vec::new();
    for partition_expr in window_spec.partition_by.iter() {
        window_column_refs.extend(partition_expr.get_column_references());
    }
    for order_expr in window_spec.order_by.iter() {
        window_column_refs.extend(order_expr.expression.get_column_references());
    }

    let current_query = pending_query_result.translation.query.clone();
    let mut query = if current_query.references_columns_in_column_refs(&window_column_refs[..]) {
        current_query.push_down()
    } else {
        current_query
    };

    let built_projections = new_projections
        .build_projections()
        .map_err(|e| TranslationError::wrap_box(ctx, e.into()).single());
    if let Some(projections) = pending_query_result.errors.consume_errors(built_projections) {
        query = prepend_projections(&query, projections, &env);
    }

    if let WindowReference::NamedWindowReference(named) = window_ref {
        query = query.with_window(named.name, window_spec);
    }

    pending_query_result.translation.query = query;
    pending_query_result.translation.env = env;
}
/// Builds the [`WindowSpecification`] described by a `WINDOW` command.
///
/// * `partition_by` comes from the command's group clauses, each translated against the
///   environment of `previous`.
/// * `order_by` comes from the command's sort expressions. When there are none, it falls
///   back to ascending order on the `timestamp` column if [`get_timestamp_ref`] finds one,
///   and is empty otherwise.
/// * `frame` is present only when the command has a `within` expression.
///
/// # Errors
///
/// Returns the translation errors raised by any group clause, the sort expressions, or the
/// `within` expression.
fn get_window_specification(
    ctx: &WindowCommandContext<'static>,
    pipeline: &HamelinPipeline,
    previous: &PendingQuery,
) -> Result<WindowSpecification, TranslationErrors> {
    let partition_by = TranslationErrors::from_vec(
        ctx.groupClause_all()
            .iter()
            .map(|group_clause| {
                let assignment = TranslationErrors::expect(
                    group_clause.as_ref(),
                    group_clause.assignmentClause(),
                )?
                .clone();
                let expr_ctx = pipeline
                    .context
                    .default_expression_translation_context(&previous.env);
                HamelinAssignmentClause::new(assignment, expr_ctx)
                    .to_sql()
                    .map(|(_, translation)| translation.sql)
            })
            .collect(),
    )?;

    let sort = if ctx.sortExpression_all().is_empty() {
        // No explicit SORT: order by `timestamp` when available, otherwise leave it empty.
        match get_timestamp_ref(previous) {
            Some((ts, typ)) => (
                vec![OrderByExpression::new(ts.into(), Direction::ASC)],
                vec![typ],
            ),
            None => (Vec::new(), Vec::new()),
        }
    } else {
        let expr_ctx = pipeline
            .context
            .expression_translation_context(&previous.env, FunctionTranslationContext::default());
        HamelinSortExpression::new(ctx.sortExpression_all(), expr_ctx).translate()?
    };

    let frame = match &ctx.within {
        Some(within) => Some(get_window_frame(within.clone(), &sort, pipeline, previous)?),
        None => None,
    };

    let (order_by, _) = sort;
    Ok(WindowSpecification {
        partition_by,
        order_by,
        frame,
    })
}
/// Looks up the identifier `timestamp` in the environment of `previous`.
///
/// Returns a column reference to `timestamp` paired with its type when the lookup yields
/// a value equal to [`TIMESTAMP`]; returns `None` in every other case.
pub fn get_timestamp_ref(previous: &PendingQuery) -> Option<(ColumnReference, Type)> {
    previous
        .env
        .lookup(&SimpleIdentifier::new("timestamp").into())
        .into_iter()
        .find(|candidate| *candidate == TIMESTAMP)
        .map(|typ| {
            (
                ColumnReference::new(SimpleIdentifier::new("timestamp").into()),
                typ,
            )
        })
}
fn get_window_frame(
ctx: Rc<ExpressionContextAll<'static>>,
sort: &(Vec<OrderByExpression>, Vec<Type>),
pipeline: &HamelinPipeline,
previous: &PendingQuery,
) -> Result<WindowFrame, TranslationErrors> {
let expression = HamelinExpression::new(
ctx.clone(),
pipeline
.context
.default_expression_translation_context(&previous.env),
);
let expression_translation = expression.translate()?;
let res = match expression_translation.typ {
t if t == INTERVAL || t == CALENDAR_INTERVAL => match expression_translation.sql {
SQLExpression::UnaryOperatorApply(uao) if uao.operator == Operator::Minus => {
check_column_types_match_sort_type(sort, TIMESTAMP, ctx.clone())?;
WindowFrame {
frame_type: FrameType::RANGE,
preceding: FrameBoundary::BoundaryExpression(uao.operand),
following: FrameBoundary::CurrentRowBoundary,
}
}
e => {
check_column_types_match_sort_type(&sort, TIMESTAMP, ctx.clone())?;
WindowFrame {
frame_type: FrameType::RANGE,
preceding: FrameBoundary::CurrentRowBoundary,
following: FrameBoundary::BoundaryExpression(Box::new(e)),
}
}
},
t if t == ROWS => match expression_translation.sql {
SQLExpression::UnaryOperatorApply(uao) if uao.operator == Operator::Minus => {
WindowFrame {
frame_type: FrameType::ROWS,
preceding: FrameBoundary::BoundaryExpression(uao.operand),
following: FrameBoundary::CurrentRowBoundary,
}
}
e => WindowFrame {
frame_type: FrameType::ROWS,
preceding: FrameBoundary::CurrentRowBoundary,
following: FrameBoundary::BoundaryExpression(Box::new(e)),
},
},
Type::Range(r) if *r.of == ROWS => {
let range = expression_translation
.sql
.unwrap_cast()
.expression
.unwrap_row_literal()
.values;
let pair = (range[0].clone(), range[1].clone());
match pair {
(
SQLExpression::Leaf(Leaf::NullLiteral(_)),
SQLExpression::Leaf(Leaf::NullLiteral(_)),
) => WindowFrame {
frame_type: FrameType::ROWS,
preceding: FrameBoundary::Unbounded,
following: FrameBoundary::Unbounded,
},
(_, SQLExpression::UnaryOperatorApply(uao)) if uao.operator == Operator::Minus => {
return TranslationError::msg(
ctx.as_ref(),
"End of the range must be positive.",
)
.single_result();
}
(SQLExpression::Leaf(Leaf::NullLiteral(_)), end) => WindowFrame {
frame_type: FrameType::ROWS,
preceding: FrameBoundary::Unbounded,
following: FrameBoundary::BoundaryExpression(Box::new(end)),
},
(
SQLExpression::UnaryOperatorApply(uao),
SQLExpression::Leaf(Leaf::NullLiteral(_)),
) if uao.operator == Operator::Minus => WindowFrame {
frame_type: FrameType::ROWS,
preceding: FrameBoundary::BoundaryExpression(uao.operand),
following: FrameBoundary::Unbounded,
},
(SQLExpression::UnaryOperatorApply(uao), end)
if uao.operator == Operator::Minus =>
{
WindowFrame {
frame_type: FrameType::ROWS,
preceding: FrameBoundary::BoundaryExpression(uao.operand),
following: FrameBoundary::BoundaryExpression(Box::new(end)),
}
}
_ => {
return TranslationError::msg(
ctx.as_ref(),
"Beginning of the range must be negative.",
)
.single_result();
}
}
}
Type::Range(r)
if *r.of == INTERVAL
|| *r.of == CALENDAR_INTERVAL
|| NumericMatcher::default().matches(&*r.of) =>
{
let range = expression_translation
.sql
.unwrap_cast()
.expression
.unwrap_row_literal()
.values;
let pair = (range[0].clone(), range[1].clone());
let range_type = match *r.of {
INTERVAL | CALENDAR_INTERVAL => TIMESTAMP,
t => t,
};
check_column_types_match_sort_type(sort, range_type, ctx.clone())?;
match pair {
(
SQLExpression::Leaf(Leaf::NullLiteral(_)),
SQLExpression::Leaf(Leaf::NullLiteral(_)),
) => WindowFrame {
frame_type: FrameType::RANGE,
preceding: FrameBoundary::Unbounded,
following: FrameBoundary::Unbounded,
},
(_, SQLExpression::UnaryOperatorApply(uao)) if uao.operator == Operator::Minus => {
return TranslationError::msg(
ctx.as_ref(),
"End of the range must be positive.",
)
.single_result();
}
(SQLExpression::Leaf(Leaf::NullLiteral(_)), end) => WindowFrame {
frame_type: FrameType::RANGE,
preceding: FrameBoundary::Unbounded,
following: FrameBoundary::BoundaryExpression(Box::new(end)),
},
(
SQLExpression::UnaryOperatorApply(uao),
SQLExpression::Leaf(Leaf::NullLiteral(_)),
) if uao.operator == Operator::Minus => WindowFrame {
frame_type: FrameType::RANGE,
preceding: FrameBoundary::BoundaryExpression(uao.operand),
following: FrameBoundary::Unbounded,
},
(SQLExpression::UnaryOperatorApply(uao), end)
if uao.operator == Operator::Minus =>
{
WindowFrame {
frame_type: FrameType::RANGE,
preceding: FrameBoundary::BoundaryExpression(uao.operand),
following: FrameBoundary::BoundaryExpression(Box::new(end)),
}
}
_ => {
return TranslationError::msg(
ctx.as_ref(),
"Beginning of the range must be negative.",
)
.single_result();
}
}
}
_ => {
return TranslationError::msg(
ctx.as_ref(),
"Bad window frame. It did not match any supported pattern.",
)
.single_result();
}
};
Ok(res)
}
/// Verifies that the window's SORT is usable with a frame whose bounds have type `typ`.
///
/// `expr` is the translated SORT: the ORDER BY expressions and, alongside them, their
/// types. The check passes (`Ok(true)`) when there is exactly one SORT expression and its
/// type either equals `typ` or both types are numeric according to [`NumericMatcher`].
///
/// # Errors
///
/// Reported against `ctx`:
/// * no SORT expression at all,
/// * more than one SORT expression,
/// * a single SORT expression whose type does not match `typ`.
///
/// # Panics
///
/// Indexes `types[0]` when there is exactly one SORT expression.
/// NOTE(review): this assumes the type list is parallel to the expression list; confirm
/// against the producers of `expr`.
fn check_column_types_match_sort_type(
    expr: &(Vec<OrderByExpression>, Vec<Type>),
    typ: Type,
    ctx: Rc<ExpressionContextAll<'static>>,
) -> Result<bool, TranslationErrors> {
    let (sorts, types) = expr;
    match sorts.len() {
        // Previously an empty SORT fell through to the "single column" message and the
        // dedicated message below could never be produced.
        0 => TranslationError::msg(ctx.as_ref(), "Window frame requires a SORT expression.")
            .single_result(),
        1 => {
            let expr_type = &types[0];
            let numeric = NumericMatcher::default();
            if expr_type == &typ || (numeric.matches(expr_type) && numeric.matches(&typ)) {
                Ok(true)
            } else {
                TranslationError::msg(
                    ctx.as_ref(),
                    &format!("Window frame requires the SORT expression type to match the RANGE expression types: {} does not match {}", expr_type, typ),
                )
                .single_result()
            }
        }
        _ => TranslationError::msg(
            ctx.as_ref(),
            "Window frame allows only a single column for SORT.",
        )
        .single_result(),
    }
}