use std::collections::HashMap;
use std::sync::Arc;
use datafusion::common::{Column, ScalarValue, TableReference, UnnestOptions};
use datafusion::datasource::provider_as_source;
use datafusion::logical_expr::expr::{
AggregateFunction, Sort as DFSort, WindowFunction, WindowFunctionDefinition,
WindowFunctionParams,
};
use datafusion::logical_expr::{
ident, Expr, JoinType as DFJoinType, LogicalPlan, LogicalPlanBuilder, SortExpr,
WindowFrame as DFWindowFrame, WindowFrameBound as DFWindowFrameBound,
WindowFrameUnits as DFWindowFrameUnits,
};
use datafusion::prelude::SessionContext;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::clause::SortOrder;
use hamelin_lib::tree::ast::identifier::Identifier;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_translation::{
IRAggCommand, IRAssignment, IRCommand, IRExplodeCommand, IRExpression, IRFromCommand, IRInput,
IRJoinCommand, IRLimitCommand, IRSelectCommand, IRSortCommand, IRSortExpression,
IRWhereCommand, IRWindowCommand, JoinType, RangeBound, RowBound,
WindowFrame as HamelinWindowFrame,
};
use crate::expr::{translate_expr_with_ctx, ExprTranslationContext};
/// Translates a Hamelin IR expression into a DataFusion [`Expr`].
///
/// The IR wrapper is unwrapped with `inner()` and the result is handed to
/// [`translate_expr_with_ctx`] together with the supplied context.
///
/// # Errors
/// Propagates whatever error `translate_expr_with_ctx` reports.
fn translate_ir_expr(
    expr: &IRExpression,
    ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
    let typed_expr = expr.inner();
    translate_expr_with_ctx(typed_expr, ctx)
}
/// Translates a `FROM` command into a logical plan.
///
/// Each input is translated with [`translate_from_input`]. A single input is
/// returned as-is; several inputs are combined left-to-right with
/// `LogicalPlanBuilder::union`.
///
/// `_output_schema` and `_expr_ctx` are accepted but not used by this function.
///
/// # Errors
/// * `"FROM command has no inputs"` when the command lists no inputs.
/// * Any error from translating an input.
/// * Any DataFusion error raised while building a union, wrapped with `command`.
pub async fn translate_from_command(
    cmd: &IRFromCommand,
    ctx: &SessionContext,
    ctes: &HashMap<Identifier, Arc<LogicalPlan>>,
    _output_schema: &TypeEnvironment,
    command: &IRCommand,
    _expr_ctx: &ExprTranslationContext,
) -> Result<LogicalPlan, Arc<TranslationError>> {
    // Translate every input before combining any of them, so an input
    // translation failure is always reported ahead of a union failure.
    let mut translated = Vec::new();
    for source in &cmd.inputs {
        translated.push(translate_from_input(source, ctx, ctes, command).await?);
    }

    let mut remaining = translated.into_iter();
    let first = match remaining.next() {
        Some(plan) => plan,
        None => {
            return Err(Arc::new(TranslationError::msg(
                command,
                "FROM command has no inputs",
            )));
        }
    };

    // With exactly one input the fold body never runs and `first` is returned.
    remaining.try_fold(first, |combined, next| {
        LogicalPlanBuilder::from(combined)
            .union(next)
            .and_then(|builder| builder.build())
            .map_err(|e| Arc::new(TranslationError::wrap(command, e)))
    })
}
async fn translate_from_input(
input: &IRInput,
ctx: &SessionContext,
ctes: &HashMap<Identifier, Arc<LogicalPlan>>,
command: &IRCommand,
) -> Result<LogicalPlan, Arc<TranslationError>> {
match input {
IRInput::Table(identifier) => scan_table_or_cte(identifier, ctx, ctes, command).await,
IRInput::With(name, _pipeline) => {
let cte_ident: Identifier = name.clone().into();
if let Some(cte_plan) = ctes.get(&cte_ident) {
Ok(cte_plan.as_ref().clone())
} else {
Err(Arc::new(TranslationError::msg(
command,
&format!("CTE '{}' not found", name.as_str()),
)))
}
}
}
}
/// Resolves `identifier` to a logical plan, preferring a CTE over a table.
///
/// * If `ctes` contains the identifier, the CTE plan is cloned and wrapped in
///   a projection that selects every field of its schema as a column
///   qualified by the dot-joined identifier.
/// * Otherwise the identifier is interpreted as a 1-, 2- or 3-part table
///   reference, its provider is fetched from the session context, and a table
///   scan named with the dot-joined identifier is produced.
///
/// # Errors
/// * `"Invalid table identifier ..."` when the identifier does not have 1-3
///   segments.
/// * Any DataFusion error from the provider lookup or plan construction,
///   wrapped with `command`.
async fn scan_table_or_cte(
    identifier: &Identifier,
    ctx: &SessionContext,
    ctes: &HashMap<Identifier, Arc<LogicalPlan>>,
    command: &IRCommand,
) -> Result<LogicalPlan, Arc<TranslationError>> {
    let segments = identifier.segments();
    // Dot-joined rendering of the identifier; serves as the column qualifier
    // for CTEs and as the scan name for tables.
    let dotted_name = segments
        .iter()
        .map(|segment| segment.as_str())
        .collect::<Vec<_>>()
        .join(".");

    if let Some(cte_plan) = ctes.get(identifier) {
        // The whole dotted name is used as a single bare qualifier.
        let qualifier = TableReference::bare(dotted_name.as_str());
        let projections: Vec<Expr> = cte_plan
            .schema()
            .fields()
            .iter()
            .map(|field| Expr::Column(Column::new(Some(qualifier.clone()), field.name())))
            .collect();
        return LogicalPlanBuilder::from(cte_plan.as_ref().clone())
            .project(projections)
            .and_then(|builder| builder.build())
            .map_err(|e| Arc::new(TranslationError::wrap(command, e)));
    }

    let table_ref = match segments {
        [name] => TableReference::bare(name.as_str()),
        [schema, name] => TableReference::partial(schema.as_str(), name.as_str()),
        [catalog, schema, name] => {
            TableReference::full(catalog.as_str(), schema.as_str(), name.as_str())
        }
        _ => {
            return Err(Arc::new(TranslationError::msg(
                command,
                &format!(
                    "Invalid table identifier '{}': expected 1-3 parts, got {}",
                    identifier,
                    segments.len()
                ),
            )));
        }
    };

    let provider = ctx
        .table_provider(table_ref)
        .await
        .map_err(|e| Arc::new(TranslationError::wrap(command, e)))?;

    // NOTE(review): the scan is named from the dotted string rather than from
    // the structured `table_ref` used for the lookup above; confirm the two
    // resolve identically for mixed-case or dotted segment names.
    LogicalPlanBuilder::scan(dotted_name, provider_as_source(provider), None)
        .and_then(|builder| builder.build())
        .map_err(|e| Arc::new(TranslationError::wrap(command, e)))
}
/// Translates a `JOIN` command by joining `input` (left side) with the table
/// or CTE named by `cmd.right` (right side).
///
/// The right side is resolved with [`scan_table_or_cte`]; the IR join type is
/// mapped onto the matching DataFusion join type; and the translated
/// condition is supplied as the single `join_on` predicate.
///
/// # Errors
/// Propagates failures from resolving the right side, translating the
/// condition, or building the join (the latter wrapped with `command`).
pub async fn translate_join_command(
    cmd: &IRJoinCommand,
    input: LogicalPlan,
    ctx: &SessionContext,
    ctes: &HashMap<Identifier, Arc<LogicalPlan>>,
    command: &IRCommand,
    expr_ctx: &ExprTranslationContext,
) -> Result<LogicalPlan, Arc<TranslationError>> {
    // Resolve the right-hand relation first; its errors take precedence over
    // errors in the join condition.
    let right_name: Identifier = cmd.right.clone().into();
    let right_side = scan_table_or_cte(&right_name, ctx, ctes, command).await?;

    let df_join_type = match cmd.join_type {
        JoinType::Inner => DFJoinType::Inner,
        JoinType::Left => DFJoinType::Left,
    };

    let on_predicate = translate_ir_expr(&cmd.condition, expr_ctx)?;

    LogicalPlanBuilder::from(input)
        .join_on(right_side, df_join_type, vec![on_predicate])
        .and_then(|builder| builder.build())
        .map_err(|e| Arc::new(TranslationError::wrap(command, e)))
}
/// Translates a `WHERE` command into a filter over `input`.
///
/// # Errors
/// Propagates predicate translation errors and wraps any DataFusion error
/// raised while building the filter with `command`.
pub fn translate_where_command(
    cmd: &IRWhereCommand,
    input: LogicalPlan,
    command: &IRCommand,
    expr_ctx: &ExprTranslationContext,
) -> Result<LogicalPlan, Arc<TranslationError>> {
    let condition = translate_ir_expr(&cmd.predicate, expr_ctx)?;
    LogicalPlanBuilder::from(input)
        .filter(condition)
        .and_then(|builder| builder.build())
        .map_err(|e| Arc::new(TranslationError::wrap(command, e)))
}
/// Translates a `SELECT` command into a projection over `input`.
///
/// Every assignment becomes one projected expression aliased to the
/// assignment's identifier, in the order the assignments are listed.
///
/// # Errors
/// Returns the first expression translation error encountered, or a wrapped
/// DataFusion error if the projection cannot be built.
pub fn translate_select_command(
    cmd: &IRSelectCommand,
    input: LogicalPlan,
    command: &IRCommand,
    expr_ctx: &ExprTranslationContext,
) -> Result<LogicalPlan, Arc<TranslationError>> {
    let columns = cmd
        .assignments
        .iter()
        .map(|assignment| {
            translate_ir_expr(&assignment.expression, expr_ctx)
                .map(|value| value.alias(assignment.identifier.as_str()))
        })
        .collect::<Result<Vec<Expr>, Arc<TranslationError>>>()?;

    LogicalPlanBuilder::from(input)
        .project(columns)
        .and_then(|builder| builder.build())
        .map_err(|e| Arc::new(TranslationError::wrap(command, e)))
}
/// Translates a `LIMIT` command into a limit node over `input`.
///
/// No rows are skipped; at most `cmd.count` rows are fetched. `_expr_ctx` is
/// accepted but not used by this function.
///
/// # Errors
/// Wraps any DataFusion error raised while building the limit with `command`.
pub fn translate_limit_command(
    cmd: &IRLimitCommand,
    input: LogicalPlan,
    command: &IRCommand,
    _expr_ctx: &ExprTranslationContext,
) -> Result<LogicalPlan, Arc<TranslationError>> {
    let skip = 0;
    let fetch = Some(cmd.count);
    LogicalPlanBuilder::from(input)
        .limit(skip, fetch)
        .and_then(|builder| builder.build())
        .map_err(|e| Arc::new(TranslationError::wrap(command, e)))
}
/// Translates a `SORT` command into a sort node over `input`.
///
/// Each sort key is converted with [`translate_sort_expression`], preserving
/// the order in which the keys are listed.
///
/// # Errors
/// Returns the first key translation error encountered, or a wrapped
/// DataFusion error if the sort cannot be built.
pub fn translate_sort_command(
    cmd: &IRSortCommand,
    input: LogicalPlan,
    command: &IRCommand,
    expr_ctx: &ExprTranslationContext,
) -> Result<LogicalPlan, Arc<TranslationError>> {
    let keys = cmd
        .sort_by
        .iter()
        .map(|key| translate_sort_expression(key, expr_ctx))
        .collect::<Result<Vec<SortExpr>, Arc<TranslationError>>>()?;

    LogicalPlanBuilder::from(input)
        .sort(keys)
        .and_then(|builder| builder.build())
        .map_err(|e| Arc::new(TranslationError::wrap(command, e)))
}
/// Converts one IR sort expression into a DataFusion [`SortExpr`].
///
/// The direction is ascending only when the IR order is `SortOrder::Asc`;
/// every other order is treated as descending. `nulls_first` is always
/// `false`, i.e. nulls sort last in both directions.
///
/// # Errors
/// Propagates the error from translating the sort key expression.
fn translate_sort_expression(
    sort_expr: &IRSortExpression,
    expr_ctx: &ExprTranslationContext,
) -> Result<SortExpr, Arc<TranslationError>> {
    let is_ascending = matches!(sort_expr.order, SortOrder::Asc);
    let nulls_first = false;
    translate_ir_expr(&sort_expr.expression, expr_ctx)
        .map(|key| key.sort(is_ascending, nulls_first))
}
/// Translates an `EXPLODE` command into an unnest followed by a projection.
///
/// The columns named in `cmd.columns` are unnested with
/// `preserve_nulls = false`. The result is then projected onto the field
/// names of `output_schema`, in the order the schema yields them.
/// `_expr_ctx` is accepted but not used by this function.
///
/// # Errors
/// Wraps any DataFusion error from the unnest or the projection with
/// `command`.
pub fn translate_explode_command(
    cmd: &IRExplodeCommand,
    input: LogicalPlan,
    output_schema: &TypeEnvironment,
    command: &IRCommand,
    _expr_ctx: &ExprTranslationContext,
) -> Result<LogicalPlan, Arc<TranslationError>> {
    let wrap = |e| Arc::new(TranslationError::wrap(command, e));

    let exploded_columns: Vec<Column> = cmd
        .columns
        .iter()
        .map(|column| Column::from_name(column.as_str()))
        .collect();
    let options = UnnestOptions::default().with_preserve_nulls(false);

    let unnested = LogicalPlanBuilder::from(input)
        .unnest_columns_with_options(exploded_columns, options)
        .and_then(|builder| builder.build())
        .map_err(wrap)?;

    // Re-project so the output columns follow the declared schema order.
    let ordered_columns: Vec<Expr> = output_schema
        .as_struct()
        .iter()
        .map(|(name, _)| ident(name.name()))
        .collect();

    LogicalPlanBuilder::from(unnested)
        .project(ordered_columns)
        .and_then(|builder| builder.build())
        .map_err(wrap)
}
/// Translates an `AGG` command into an aggregate node over `input`.
///
/// Group-by assignments become aliased grouping expressions. Aggregate
/// assignments become aliased aggregate expressions, and when the command
/// carries sort keys they are attached to each aggregate function as its
/// ordering (see [`translate_assignment_with_order_by`]).
///
/// # Errors
/// Errors are reported in this order: group-by translation, sort key
/// translation, aggregate translation, then any DataFusion error from
/// building the aggregate (wrapped with `command`).
pub fn translate_agg_command(
    cmd: &IRAggCommand,
    input: LogicalPlan,
    command: &IRCommand,
    expr_ctx: &ExprTranslationContext,
) -> Result<LogicalPlan, Arc<TranslationError>> {
    let grouping = cmd
        .group_by
        .iter()
        .map(|assignment| translate_assignment(assignment, expr_ctx))
        .collect::<Result<Vec<Expr>, Arc<TranslationError>>>()?;

    let ordering = cmd
        .sort_by
        .iter()
        .map(|key| translate_sort_expression(key, expr_ctx))
        .collect::<Result<Vec<SortExpr>, Arc<TranslationError>>>()?;

    let aggregates = cmd
        .aggregates
        .iter()
        .map(|assignment| translate_assignment_with_order_by(assignment, &ordering, expr_ctx))
        .collect::<Result<Vec<Expr>, Arc<TranslationError>>>()?;

    LogicalPlanBuilder::from(input)
        .aggregate(grouping, aggregates)
        .and_then(|builder| builder.build())
        .map_err(|e| Arc::new(TranslationError::wrap(command, e)))
}
/// Translates an assignment into an aliased expression, attaching `order_by`
/// to it when it is an aggregate function.
///
/// The ordering replaces the aggregate's existing `order_by` only when
/// `order_by` is non-empty and the translated expression is an
/// `Expr::AggregateFunction` at the top level. Any other expression is
/// aliased unchanged and the ordering is ignored.
///
/// # Errors
/// Propagates the error from translating the assignment's expression.
fn translate_assignment_with_order_by(
    assignment: &IRAssignment,
    order_by: &[SortExpr],
    expr_ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
    let ordered = match translate_ir_expr(&assignment.expression, expr_ctx)? {
        Expr::AggregateFunction(mut agg) if !order_by.is_empty() => {
            agg.params.order_by = order_by.to_vec();
            Expr::AggregateFunction(agg)
        }
        unchanged => unchanged,
    };
    Ok(ordered.alias(assignment.identifier.as_str()))
}
/// Translates an assignment into a DataFusion expression aliased to the
/// assignment's identifier.
///
/// # Errors
/// Propagates the error from translating the assignment's expression.
fn translate_assignment(
    assignment: &IRAssignment,
    expr_ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
    translate_ir_expr(&assignment.expression, expr_ctx)
        .map(|value| value.alias(assignment.identifier.as_str()))
}
/// Translates a `WINDOW` command into a window node followed by a projection.
///
/// All projections share the same partitioning, ordering and frame:
/// * partition keys come from `cmd.partition_by`;
/// * ordering keys come from `cmd.sort_by`, ascending only for
///   `SortOrder::Asc` and always with `nulls_first = false`;
/// * the frame comes from `cmd.frame` when present; otherwise
///   [`convert_to_window_expr`] picks a default.
///
/// Each projection is turned into a window expression aliased to its
/// identifier. The window plan is then projected onto the field names of
/// `output_schema`, in the order the schema yields them.
///
/// # Errors
/// Errors are reported in this order: partition keys, ordering keys, frame
/// conversion, projections, then any DataFusion error from building the
/// window or the final projection (wrapped with `command`).
pub fn translate_window_command(
    cmd: &IRWindowCommand,
    input: LogicalPlan,
    output_schema: &TypeEnvironment,
    command: &IRCommand,
    expr_ctx: &ExprTranslationContext,
) -> Result<LogicalPlan, Arc<TranslationError>> {
    let mut partition_exprs: Vec<Expr> = Vec::new();
    for partition_key in cmd.partition_by.iter() {
        partition_exprs.push(translate_ir_expr(partition_key, expr_ctx)?);
    }

    let mut order_by_exprs: Vec<DFSort> = Vec::new();
    for sort_key in cmd.sort_by.iter() {
        let key_expr = translate_ir_expr(&sort_key.expression, expr_ctx)?;
        let is_ascending = matches!(sort_key.order, SortOrder::Asc);
        order_by_exprs.push(DFSort::new(key_expr, is_ascending, false));
    }

    let explicit_frame = cmd
        .frame
        .as_ref()
        .map(|frame| convert_window_frame(frame, command, expr_ctx))
        .transpose()?;

    // Determines which default frame applies when none is given explicitly.
    let has_order_by = !cmd.sort_by.is_empty();

    let mut window_exprs: Vec<Expr> = Vec::new();
    for assignment in cmd.projections.iter() {
        let base = translate_ir_expr(&assignment.expression, expr_ctx)?;
        let windowed = convert_to_window_expr(
            base,
            partition_exprs.clone(),
            order_by_exprs.clone(),
            explicit_frame.clone(),
            has_order_by,
            command,
        )?;
        window_exprs.push(windowed.alias(assignment.identifier.as_str()));
    }

    let windowed_plan = LogicalPlanBuilder::from(input)
        .window(window_exprs)
        .and_then(|builder| builder.build())
        .map_err(|e| Arc::new(TranslationError::wrap(command, e)))?;

    // Re-project so the output columns follow the declared schema order.
    let ordered_columns: Vec<Expr> = output_schema
        .as_struct()
        .iter()
        .map(|(name, _)| ident(name.name()))
        .collect();

    LogicalPlanBuilder::from(windowed_plan)
        .project(ordered_columns)
        .and_then(|builder| builder.build())
        .map_err(|e| Arc::new(TranslationError::wrap(command, e)))
}
/// Converts a Hamelin window frame into a DataFusion window frame.
///
/// `Rows` frames map to `ROWS` units using [`convert_row_bound`]; `Range`
/// frames map to `RANGE` units using [`convert_range_bound`]. In both cases
/// the start bound is converted before the end bound.
///
/// # Errors
/// Only `Range` frames can fail, when a bound expression does not translate
/// to a literal.
fn convert_window_frame(
    frame: &HamelinWindowFrame,
    command: &IRCommand,
    expr_ctx: &ExprTranslationContext,
) -> Result<DFWindowFrame, Arc<TranslationError>> {
    let (units, lower, upper) = match frame {
        HamelinWindowFrame::Rows { start, end } => (
            DFWindowFrameUnits::Rows,
            convert_row_bound(start, true),
            convert_row_bound(end, false),
        ),
        HamelinWindowFrame::Range { start, end } => (
            DFWindowFrameUnits::Range,
            convert_range_bound(start, true, command, expr_ctx)?,
            convert_range_bound(end, false, command, expr_ctx)?,
        ),
    };
    Ok(DFWindowFrame::new_bounds(units, lower, upper))
}
/// Converts a `ROWS` frame bound into a DataFusion frame bound.
///
/// An unbounded bound has no direction of its own, so `is_start` chooses it:
/// unbounded *preceding* for the frame start, unbounded *following* for the
/// frame end. Both are encoded with a `ScalarValue::Null` offset. Explicit
/// offsets are carried as `UInt64` scalars.
fn convert_row_bound(bound: &RowBound, is_start: bool) -> DFWindowFrameBound {
    match (bound, is_start) {
        (RowBound::Unbounded, true) => DFWindowFrameBound::Preceding(ScalarValue::Null),
        (RowBound::Unbounded, false) => DFWindowFrameBound::Following(ScalarValue::Null),
        (RowBound::CurrentRow, _) => DFWindowFrameBound::CurrentRow,
        (RowBound::Preceding(offset), _) => {
            DFWindowFrameBound::Preceding(ScalarValue::UInt64(Some(*offset)))
        }
        (RowBound::Following(offset), _) => {
            DFWindowFrameBound::Following(ScalarValue::UInt64(Some(*offset)))
        }
    }
}
/// Converts a `RANGE` frame bound into a DataFusion frame bound.
///
/// An unbounded bound becomes unbounded *preceding* when `is_start` is true
/// and unbounded *following* otherwise, both encoded with a
/// `ScalarValue::Null` offset. Explicit offsets are IR expressions that must
/// translate to literals (see [`ir_expr_to_scalar`]).
///
/// # Errors
/// Fails when an offset expression cannot be translated or is not a literal.
fn convert_range_bound(
    bound: &RangeBound,
    is_start: bool,
    command: &IRCommand,
    expr_ctx: &ExprTranslationContext,
) -> Result<DFWindowFrameBound, Arc<TranslationError>> {
    let converted = match (bound, is_start) {
        (RangeBound::Unbounded, true) => DFWindowFrameBound::Preceding(ScalarValue::Null),
        (RangeBound::Unbounded, false) => DFWindowFrameBound::Following(ScalarValue::Null),
        (RangeBound::CurrentRow, _) => DFWindowFrameBound::CurrentRow,
        (RangeBound::Preceding(offset), _) => {
            DFWindowFrameBound::Preceding(ir_expr_to_scalar(offset, command, expr_ctx)?)
        }
        (RangeBound::Following(offset), _) => {
            DFWindowFrameBound::Following(ir_expr_to_scalar(offset, command, expr_ctx)?)
        }
    };
    Ok(converted)
}
/// Translates an IR expression and extracts its value as a [`ScalarValue`].
///
/// Only expressions that translate to `Expr::Literal` are accepted; the
/// literal's second field is discarded.
///
/// # Errors
/// Propagates translation errors, and returns
/// `"Expected literal for window frame bound, got: ..."` for any non-literal
/// expression.
fn ir_expr_to_scalar(
    expr: &IRExpression,
    command: &IRCommand,
    expr_ctx: &ExprTranslationContext,
) -> Result<ScalarValue, Arc<TranslationError>> {
    match translate_ir_expr(expr, expr_ctx)? {
        Expr::Literal(value, _) => Ok(value),
        other => {
            let message = format!("Expected literal for window frame bound, got: {other}");
            Err(Arc::new(TranslationError::msg(command, &message)))
        }
    }
}
/// Turns an aggregate or window function expression into a window expression
/// with the given partitioning, ordering and frame.
///
/// * `Expr::AggregateFunction`: the aggregate UDF becomes the window function
///   definition. Its args, filter, null treatment and distinct flag are
///   carried over; its own `order_by` is not.
/// * `Expr::WindowFunction`: the function and its args, filter, null
///   treatment and distinct flag are kept, while partitioning, ordering and
///   frame are replaced by the values passed in.
///
/// When `window_frame` is `None`, a default is built with
/// `DFWindowFrame::new(Some(true))` if `has_order_by` is set, and with
/// `DFWindowFrame::new(None)` otherwise.
///
/// # Errors
/// Returns `"WINDOW expression must be an aggregate or window function, ..."`
/// for any other kind of expression.
fn convert_to_window_expr(
    expr: Expr,
    partition_by: Vec<Expr>,
    order_by: Vec<DFSort>,
    window_frame: Option<DFWindowFrame>,
    has_order_by: bool,
    command: &IRCommand,
) -> Result<Expr, Arc<TranslationError>> {
    // Resolve the effective frame once; both supported arms use it.
    let frame = window_frame.unwrap_or_else(|| {
        let strict_ordering = if has_order_by { Some(true) } else { None };
        DFWindowFrame::new(strict_ordering)
    });

    let window_func = match expr {
        Expr::AggregateFunction(AggregateFunction { func, params }) => WindowFunction {
            fun: WindowFunctionDefinition::AggregateUDF(func),
            params: WindowFunctionParams {
                args: params.args,
                partition_by,
                order_by,
                window_frame: frame,
                filter: params.filter,
                null_treatment: params.null_treatment,
                distinct: params.distinct,
            },
        },
        Expr::WindowFunction(existing) => {
            let WindowFunction { fun, params } = *existing;
            WindowFunction {
                fun,
                params: WindowFunctionParams {
                    partition_by,
                    order_by,
                    window_frame: frame,
                    // args, filter, null_treatment and distinct are kept.
                    ..params
                },
            }
        }
        other => {
            return Err(Arc::new(TranslationError::msg(
                command,
                &format!(
                    "WINDOW expression must be an aggregate or window function, got: {:?}",
                    other
                ),
            )));
        }
    };

    Ok(Expr::WindowFunction(Box::new(window_func)))
}