use hamelin_lib::err::{TranslationError, TranslationErrors};
use hamelin_lib::func::def::{FunctionTranslationContext, SpecialPosition};
use hamelin_lib::sql::expression::identifier::{
Identifier as SQLIdentifier, SimpleIdentifier as SQLSimpleIdentifier,
};
use hamelin_lib::sql::expression::literal::{ColumnReference, IntegerLiteral};
use hamelin_lib::sql::expression::{Direction, OrderByExpression};
use hamelin_lib::sql::query::projection::{Binding, ColumnProjection, Projection};
use hamelin_lib::sql::query::set::SetOperation;
use hamelin_lib::sql::query::window::{
FrameBoundary, FrameType, WindowFrame as SQLWindowFrame, WindowReference, WindowSpecification,
};
use hamelin_lib::sql::query::{
Join, JoinClause, JoinType as SQLJoinType, SQLQuery, SQLQueryExpression, SubQuery,
TableExpression, TableReference,
};
use hamelin_lib::tree::ast::clause::SortOrder;
use hamelin_lib::tree::ast::identifier::SimpleIdentifier;
use hamelin_translation::{
IRAggCommand, IRCommand, IRCommandKind, IRExplodeCommand, IRFromCommand, IRInput,
IRJoinCommand, IRLimitCommand, IRSelectCommand, IRSortCommand, IRSortExpression,
IRWhereCommand, IRWindowCommand, JoinType as IRJoinType,
};
use hamelin_translation::{RangeBound, RowBound, WindowFrame};
use crate::context::TranslationContext;
use crate::expr::translate_expression;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
/// Outcome of translating one command: the resulting [`SQLQuery`] on success,
/// or the [`TranslationErrors`] that prevented translation.
pub type CommandTranslationResult = Result<SQLQuery, TranslationErrors>;
/// The query a command is applied to, if any. Translators that receive `None`
/// either build a query without a `FROM` clause or (for joins) report an error.
pub type QueryInput = Option<SQLQuery>;
/// Builds one `name AS name`-style binding per column of `schema`.
///
/// Each column name is converted to a SQL identifier, and the binding's
/// expression is a plain column reference to that same identifier. Columns
/// are emitted in the iteration order of the schema's struct view.
fn build_projections_from_schema(schema: &TypeEnvironment) -> Vec<Projection> {
    let mut projections = Vec::new();
    for column in schema.as_struct().keys() {
        let alias: SQLSimpleIdentifier = column.clone().into();
        let reference = ColumnReference::new(alias.clone().into());
        projections.push(Binding::new(alias, reference.into()).into());
    }
    projections
}
/// Translates IR commands into SQL queries, one command at a time.
///
/// Every method receives the command being translated together with the
/// query it applies to (`QueryInput`) and returns the query that results from
/// applying the command. All methods except [`translate_explode`] have default
/// implementations; implementors must supply `translate_explode` and may
/// override any of the others.
///
/// [`translate_explode`]: CommandTranslator::translate_explode
pub trait CommandTranslator {
    /// Dispatches `cmd` to the translator method matching its kind.
    fn translate_command(
        &self,
        ctx: &mut TranslationContext,
        cmd: &IRCommand,
        query: QueryInput,
    ) -> CommandTranslationResult {
        match &cmd.kind {
            IRCommandKind::From(from_cmd) => self.translate_from(ctx, cmd, from_cmd, query),
            IRCommandKind::Where(where_cmd) => self.translate_where(ctx, cmd, where_cmd, query),
            IRCommandKind::Select(select_cmd) => self.translate_select(ctx, cmd, select_cmd, query),
            IRCommandKind::Sort(sort_cmd) => self.translate_sort(ctx, cmd, sort_cmd, query),
            IRCommandKind::Limit(limit_cmd) => self.translate_limit(ctx, cmd, limit_cmd, query),
            IRCommandKind::Agg(agg_cmd) => self.translate_agg(ctx, cmd, agg_cmd, query),
            IRCommandKind::Window(window_cmd) => self.translate_window(ctx, cmd, window_cmd, query),
            IRCommandKind::Join(join_cmd) => self.translate_join(ctx, cmd, join_cmd, query),
            IRCommandKind::Explode(explode_cmd) => {
                self.translate_explode(ctx, cmd, explode_cmd, query)
            }
        }
    }

    /// Translates a FROM command.
    ///
    /// A single input becomes `SELECT <output columns> FROM <input>`. Several
    /// inputs are each translated that way, combined left-to-right with
    /// `SetOperation`, and the combination is wrapped in a subquery that is
    /// selected from with the same output columns. The incoming query is
    /// ignored.
    ///
    /// # Errors
    /// Returns an error when the command has no inputs.
    fn translate_from(
        &self,
        _ctx: &mut TranslationContext,
        cmd: &IRCommand,
        from_cmd: &IRFromCommand,
        _query: QueryInput,
    ) -> CommandTranslationResult {
        // `reduce` covers all three cases without indexing or `remove(0)`:
        // no input -> `None`, one input -> that query untouched, several
        // inputs -> a left-nested chain of set operations.
        let combined = from_cmd
            .inputs
            .iter()
            .map(|input| -> SQLQueryExpression {
                SQLQuery::default()
                    .from(translate_input(input))
                    .select(build_projections_from_schema(&cmd.output_schema))
                    .into()
            })
            .reduce(|acc, next| SetOperation::new(acc, next).into());

        match combined {
            None => Err(TranslationError::msg(cmd, "FROM command has no inputs").single()),
            Some(SQLQueryExpression::SQLQuery(q)) => Ok(q),
            Some(SQLQueryExpression::SetOperation(set_op)) => {
                let subquery = SubQuery::new(set_op.into());
                let projections = build_projections_from_schema(&cmd.output_schema);
                Ok(SQLQuery::default()
                    .from(subquery.into())
                    .select(projections))
            }
        }
    }

    /// Translates a WHERE command.
    ///
    /// With an incoming query, the query is wrapped in a subquery and filtered
    /// by the predicate while projecting every output-schema column. Without
    /// one, the result is a bare query carrying only the predicate.
    ///
    /// # Errors
    /// Propagates any error from translating the predicate.
    fn translate_where(
        &self,
        ctx: &mut TranslationContext,
        cmd: &IRCommand,
        where_cmd: &IRWhereCommand,
        query: QueryInput,
    ) -> CommandTranslationResult {
        let predicate_sql = translate_expression(ctx, &where_cmd.predicate)?;
        match query {
            Some(q) => {
                let subquery = SubQuery::new(q.into());
                let projections: Vec<Projection> = cmd
                    .output_schema
                    .as_struct()
                    .keys()
                    .map(|name| {
                        let sql_name: SQLSimpleIdentifier = name.clone().into();
                        ColumnProjection::new(sql_name.into()).into()
                    })
                    .collect();
                Ok(SQLQuery::default()
                    .from(subquery.into())
                    .select(projections)
                    .where_(predicate_sql))
            }
            None => Ok(SQLQuery::default().where_(predicate_sql)),
        }
    }

    /// Translates a SELECT command into a projection over the incoming query.
    ///
    /// The incoming query becomes a subquery. Its ORDER BY, if any, is
    /// re-applied on the outer query so the ordering is stated at the top
    /// level of the result.
    ///
    /// # Errors
    /// Propagates any error from translating an assignment expression.
    fn translate_select(
        &self,
        ctx: &mut TranslationContext,
        _cmd: &IRCommand,
        select_cmd: &IRSelectCommand,
        query: QueryInput,
    ) -> CommandTranslationResult {
        let mut projections = Vec::with_capacity(select_cmd.assignments.len());
        for assignment in &select_cmd.assignments {
            let expr_sql = translate_expression(ctx, &assignment.expression)?;
            let name: SQLSimpleIdentifier = assignment.identifier.clone().into();
            projections.push(Binding::new(name, expr_sql).into());
        }
        match query {
            None => Ok(SQLQuery::default().select(projections)),
            Some(mut inner) => {
                // The inner ORDER BY may only be removed when the inner query
                // has no LIMIT: with a LIMIT, the ordering decides which rows
                // survive, so it has to stay inside as well.
                let order_by = if inner.limit.is_some() {
                    inner.order_by.clone()
                } else {
                    inner.order_by.take()
                };
                let subquery = SubQuery::new(inner.into());
                let mut result = SQLQuery::default()
                    .from(subquery.into())
                    .select(projections);
                if let Some(order_by) = order_by {
                    result = result.order_by(order_by);
                }
                Ok(result)
            }
        }
    }

    /// Translates a SORT command by attaching an ORDER BY clause.
    ///
    /// When the incoming query already has a LIMIT, it is first wrapped in a
    /// subquery so the sort applies to the limited rows; adding ORDER BY to
    /// the same query level would instead change which rows the LIMIT keeps.
    ///
    /// # Errors
    /// Propagates any error from translating a sort expression.
    fn translate_sort(
        &self,
        ctx: &mut TranslationContext,
        cmd: &IRCommand,
        sort_cmd: &IRSortCommand,
        query: QueryInput,
    ) -> CommandTranslationResult {
        let order_bys = translate_sort_expressions(ctx, &sort_cmd.sort_by)?;
        match query {
            Some(q) if q.limit.is_some() => {
                let subquery = SubQuery::new(q.into());
                let projections = build_projections_from_schema(&cmd.output_schema);
                Ok(SQLQuery::default()
                    .from(subquery.into())
                    .select(projections)
                    .order_by(order_bys))
            }
            Some(q) => Ok(q.order_by(order_bys)),
            None => Ok(SQLQuery::default().order_by(order_bys)),
        }
    }

    /// Translates a LIMIT command.
    ///
    /// If the incoming query already has a LIMIT, it is wrapped in a subquery
    /// and the new limit is applied outside; otherwise the limit is set on the
    /// incoming query directly.
    fn translate_limit(
        &self,
        _ctx: &mut TranslationContext,
        cmd: &IRCommand,
        limit_cmd: &IRLimitCommand,
        query: QueryInput,
    ) -> CommandTranslationResult {
        // NOTE(review): `as i64` silently wraps if `count` can exceed
        // `i64::MAX`; confirm the field's type makes that impossible.
        let count_sql = IntegerLiteral::from_int(limit_cmd.count as i64).into();
        match query {
            Some(q) if q.limit.is_some() => {
                let subquery = SubQuery::new(q.into());
                let projections = build_projections_from_schema(&cmd.output_schema);
                Ok(SQLQuery::default()
                    .from(subquery.into())
                    .select(projections)
                    .limit(count_sql))
            }
            Some(q) => Ok(q.limit(count_sql)),
            None => Ok(SQLQuery::default().limit(count_sql)),
        }
    }

    /// Translates an AGG command into a grouped projection.
    ///
    /// Group-by assignments are projected first and also form the GROUP BY
    /// list; aggregate assignments follow. While these expressions are
    /// translated, `ctx.fctx` is temporarily replaced by a context that allows
    /// aggregate-position functions and carries the command's sort order. The
    /// previous `fctx` is restored before returning, on success and on error.
    ///
    /// # Errors
    /// Propagates any error from translating a sort, group-by or aggregate
    /// expression.
    fn translate_agg(
        &self,
        ctx: &mut TranslationContext,
        _cmd: &IRCommand,
        agg_cmd: &IRAggCommand,
        query: QueryInput,
    ) -> CommandTranslationResult {
        let order_by = translate_sort_expressions(ctx, &agg_cmd.sort_by)?;
        let saved_fctx = std::mem::replace(
            &mut ctx.fctx,
            FunctionTranslationContext::default()
                .with_special_allowed(SpecialPosition::Agg)
                .with_order_by(order_by),
        );

        // Run the fallible part in a closure so that `?` cannot skip the
        // restoration of `ctx.fctx` below.
        let translated = (|| -> Result<(Vec<_>, Vec<Projection>), TranslationErrors> {
            let mut group_by_exprs = Vec::with_capacity(agg_cmd.group_by.len());
            let mut projections =
                Vec::with_capacity(agg_cmd.group_by.len() + agg_cmd.aggregates.len());
            for assignment in &agg_cmd.group_by {
                let expr_sql = translate_expression(ctx, &assignment.expression)?;
                group_by_exprs.push(expr_sql.clone());
                let name: SQLSimpleIdentifier = assignment.identifier.clone().into();
                projections.push(Binding::new(name, expr_sql).into());
            }
            for assignment in &agg_cmd.aggregates {
                let expr_sql = translate_expression(ctx, &assignment.expression)?;
                let name: SQLSimpleIdentifier = assignment.identifier.clone().into();
                projections.push(Binding::new(name, expr_sql).into());
            }
            Ok((group_by_exprs, projections))
        })();
        ctx.fctx = saved_fctx;
        let (group_by_exprs, projections) = translated?;

        let mut result = SQLQuery::default().select(projections);
        if !group_by_exprs.is_empty() {
            result = result.group_by(group_by_exprs);
        }
        match query {
            Some(q) => Ok(result.from(SubQuery::new(q.into()).into())),
            None => Ok(result),
        }
    }

    /// Translates a WINDOW command.
    ///
    /// Builds a window specification from the command's partitioning, sort
    /// order and optional frame, then translates the window projections with
    /// `ctx.fctx` temporarily replaced by a context that allows window- and
    /// aggregate-position functions over that specification. The previous
    /// `fctx` is restored before returning, on success and on error. Output
    /// columns that are not window projections are passed through as plain
    /// column references after the window projections.
    ///
    /// # Errors
    /// Propagates any error from translating partition, sort, frame or
    /// projection expressions.
    fn translate_window(
        &self,
        ctx: &mut TranslationContext,
        cmd: &IRCommand,
        window_cmd: &IRWindowCommand,
        query: QueryInput,
    ) -> CommandTranslationResult {
        let mut partition_by = Vec::with_capacity(window_cmd.partition_by.len());
        for expr in &window_cmd.partition_by {
            partition_by.push(translate_expression(ctx, expr)?);
        }
        let order_by = translate_sort_expressions(ctx, &window_cmd.sort_by)?;
        let frame = window_cmd
            .frame
            .as_ref()
            .map(|f| translate_window_frame(ctx, f))
            .transpose()?;
        let window_spec = WindowSpecification {
            partition_by,
            order_by,
            frame,
        };
        let saved_fctx = std::mem::replace(
            &mut ctx.fctx,
            FunctionTranslationContext::default()
                .with_special_allowed(SpecialPosition::Window)
                .with_special_allowed(SpecialPosition::Agg)
                .with_window(WindowReference::WindowSpecification(window_spec)),
        );

        // Run the fallible part in a closure so that `?` cannot skip the
        // restoration of `ctx.fctx` below.
        let translated = (|| -> Result<Vec<Projection>, TranslationErrors> {
            let mut projections = Vec::with_capacity(window_cmd.projections.len());
            for assignment in &window_cmd.projections {
                let expr_sql = translate_expression(ctx, &assignment.expression)?;
                let name: SQLSimpleIdentifier = assignment.identifier.clone().into();
                projections.push(Binding::new(name, expr_sql).into());
            }
            Ok(projections)
        })();
        ctx.fctx = saved_fctx;
        let mut projections = translated?;

        // Pass through every output column the window projections did not
        // already define.
        let window_projection_names: std::collections::HashSet<_> = window_cmd
            .projections
            .iter()
            .map(|a| a.identifier.clone())
            .collect();
        for (col_name, _col_type) in cmd.output_schema.as_struct().iter() {
            let simple_id: SimpleIdentifier = col_name.clone().into();
            if !window_projection_names.contains(&simple_id) {
                let sql_name: SQLSimpleIdentifier = simple_id.into();
                let col_ref = ColumnReference::new(sql_name.clone().into());
                projections.push(Binding::new(sql_name, col_ref.into()).into());
            }
        }

        let result = SQLQuery::default().select(projections);
        match query {
            Some(q) => Ok(result.from(SubQuery::new(q.into()).into())),
            None => Ok(result),
        }
    }

    /// Translates a JOIN command.
    ///
    /// The incoming query becomes the left side as a subquery aliased `_left`.
    /// The right side is a table reference named after `join_cmd.right` and
    /// aliased with that same name. The output-schema columns are projected
    /// from the join.
    ///
    /// # Errors
    /// Returns an error when there is no incoming query, and propagates any
    /// error from translating the join condition.
    fn translate_join(
        &self,
        ctx: &mut TranslationContext,
        cmd: &IRCommand,
        join_cmd: &IRJoinCommand,
        query: QueryInput,
    ) -> CommandTranslationResult {
        let query = query
            .ok_or_else(|| TranslationError::msg(cmd, "JOIN requires a left side").single())?;
        let left_alias = SQLSimpleIdentifier::new("_left");
        let left_subquery = SubQuery::new(query.into()).alias(left_alias.into());
        let right_alias: SQLSimpleIdentifier = join_cmd.right.clone().into();
        let right_table =
            TableReference::new(right_alias.clone().into()).alias(right_alias.into());
        let condition_sql = translate_expression(ctx, &join_cmd.condition)?;
        let sql_join_type = match join_cmd.join_type {
            IRJoinType::Inner => SQLJoinType::INNER,
            IRJoinType::Left => SQLJoinType::LEFT,
        };
        let join = Join::new(left_subquery.into()).with_clause(JoinClause {
            table: right_table.into(),
            join_type: sql_join_type,
            condition: Some(condition_sql),
        });
        let projections = build_projections_from_schema(&cmd.output_schema);
        Ok(SQLQuery::default().from(join.into()).select(projections))
    }

    /// Translates an EXPLODE command.
    ///
    /// This is the only method without a default implementation; each
    /// implementor must provide it.
    fn translate_explode(
        &self,
        _ctx: &mut TranslationContext,
        cmd: &IRCommand,
        _explode_cmd: &IRExplodeCommand,
        _query: QueryInput,
    ) -> CommandTranslationResult;
}
/// Converts a FROM input into a table expression.
///
/// Both a plain table and a `With` input become a table reference to their
/// identifier; the pipeline carried by a `With` input is not used here.
fn translate_input(input: &IRInput) -> TableExpression {
    let table_name: SQLIdentifier = match input {
        IRInput::Table(table) => table.clone().into(),
        IRInput::With(alias, _pipeline) => alias.clone().into(),
    };
    TableReference::new(table_name).into()
}
/// Translates IR sort expressions into SQL ORDER BY expressions, preserving
/// their order.
///
/// # Errors
/// Stops at, and returns, the first expression translation error.
fn translate_sort_expressions(
    ctx: &TranslationContext,
    sort_exprs: &[IRSortExpression],
) -> Result<Vec<OrderByExpression>, TranslationErrors> {
    sort_exprs
        .iter()
        .map(|item| -> Result<OrderByExpression, TranslationErrors> {
            let sql = translate_expression(ctx, &item.expression)?;
            let direction = match item.order {
                SortOrder::Asc => Direction::ASC,
                SortOrder::Desc => Direction::DESC,
            };
            Ok(OrderByExpression::new(sql, direction))
        })
        .collect()
}
/// Converts an IR window frame into its SQL counterpart.
///
/// The frame's `start` bound fills the SQL frame's `preceding` slot and its
/// `end` bound fills the `following` slot. Row bounds are translated
/// infallibly; range bounds go through expression translation.
///
/// # Errors
/// Propagates any error from translating a range bound expression.
fn translate_window_frame(
    ctx: &TranslationContext,
    frame: &WindowFrame,
) -> Result<SQLWindowFrame, TranslationErrors> {
    let (frame_type, preceding, following) = match frame {
        WindowFrame::Rows { start, end } => (
            FrameType::ROWS,
            translate_row_bound(start),
            translate_row_bound(end),
        ),
        WindowFrame::Range { start, end } => (
            FrameType::RANGE,
            translate_range_bound(ctx, start)?,
            translate_range_bound(ctx, end)?,
        ),
    };
    Ok(SQLWindowFrame {
        frame_type,
        preceding,
        following,
    })
}
/// Converts a ROWS frame bound into a SQL frame boundary.
///
/// `Preceding(n)` and `Following(n)` both yield the same integer boundary
/// expression; the direction is not encoded in the returned value.
// NOTE(review): presumably the direction comes from which slot of the SQL
// window frame the boundary is stored in — confirm that a `Following` start
// or a `Preceding` end is rendered as intended.
fn translate_row_bound(bound: &RowBound) -> FrameBoundary {
    let offset = match bound {
        RowBound::Unbounded => return FrameBoundary::Unbounded,
        RowBound::CurrentRow => return FrameBoundary::CurrentRowBoundary,
        RowBound::Preceding(rows) | RowBound::Following(rows) => *rows as i64,
    };
    FrameBoundary::BoundaryExpression(Box::new(IntegerLiteral::from_int(offset).into()))
}
/// Converts a RANGE frame bound into a SQL frame boundary.
///
/// `Preceding` and `Following` both translate their offset expression and
/// wrap it in a boundary expression; the direction is not encoded in the
/// returned value.
///
/// # Errors
/// Propagates any error from translating the offset expression.
fn translate_range_bound(
    ctx: &TranslationContext,
    bound: &RangeBound,
) -> Result<FrameBoundary, TranslationErrors> {
    let boundary = match bound {
        RangeBound::Unbounded => FrameBoundary::Unbounded,
        RangeBound::CurrentRow => FrameBoundary::CurrentRowBoundary,
        RangeBound::Preceding(offset) | RangeBound::Following(offset) => {
            FrameBoundary::BoundaryExpression(Box::new(translate_expression(ctx, offset)?))
        }
    };
    Ok(boundary)
}