use std::sync::Arc;
use hamelin_lib::{
err::TranslationError,
func::defs::Transform,
tree::{
ast::{
clause::{SortExpression, SortOrder},
command::Command,
expression::Expression,
identifier::SimpleIdentifier,
node::Span,
},
builder::{
agg_command, array, call, cast, drop_command, eq, explode_command, field_ref, is_null,
let_command, null, pipeline, subtract, ExpressionBuilder, IntoExpressionBuilder,
},
typed_ast::{
command::TypedCommand,
context::StatementTranslationContext,
environment::TypeEnvironment,
expression::{TypedApply, TypedExpression, TypedExpressionKind},
pipeline::TypedPipeline,
},
},
types::Type,
};
use crate::unique::UniqueNameGenerator;
/// Bundles the [`UniqueNameGenerator`]s for every temporary column that the
/// transform-lowering pass introduces: row ids, index arrays, exploded
/// elements, substituted lambda bodies, aggregated results, and the
/// emptiness flags.
struct TransformNameGenerators {
    row_id: UniqueNameGenerator,
    indices: UniqueNameGenerator,
    elem: UniqueNameGenerator,
    idx: UniqueNameGenerator,
    body: UniqueNameGenerator,
    result: UniqueNameGenerator,
    is_empty: UniqueNameGenerator,
}

impl TransformNameGenerators {
    /// Creates one generator per temporary column, each with a distinct
    /// double-underscore prefix so generated names cannot collide with
    /// user-visible fields.
    fn new() -> Self {
        let gen = UniqueNameGenerator::new;
        Self {
            row_id: gen("__row_id"),
            indices: gen("__indices"),
            elem: gen("__elem"),
            idx: gen("__idx"),
            body: gen("__body"),
            result: gen("__result"),
            is_empty: gen("__is_empty"),
        }
    }
}
/// Repeatedly lowers `transform()` calls in `pipeline` until none remain,
/// then appends any deferred `DROP` commands (for the temporary result /
/// emptiness columns) just before the first DML command.
///
/// Nested transforms are handled by fixed-point iteration: each pass rewrites
/// only the innermost calls, then the pipeline is re-type-checked and scanned
/// again.
///
/// # Errors
/// Returns a [`TranslationError`] when the pipeline fails validation or a
/// `transform()` call is malformed (missing parameters, non-lambda body).
///
/// Fix: the loop condition and call previously read `¤t` — mojibake for
/// `&current` (the `&curren` prefix was swallowed by the HTML entity for `¤`)
/// — which did not compile.
pub fn lower_transform(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    let mut current = pipeline;
    let mut deferred_drops: Vec<Arc<Command>> = Vec::new();
    // Fixed point: each pass eliminates the innermost transforms; outer
    // (nested) transforms become innermost on the next pass.
    while pipeline_has_transform(&current)? {
        let new_ast = transform_pipeline(&current, ctx, &mut deferred_drops)?;
        current = Arc::new(TypedPipeline::from_ast_with_context(Arc::new(new_ast), ctx));
    }
    if !deferred_drops.is_empty() {
        use hamelin_lib::tree::ast::command::CommandClass;
        let mut new_ast = (*current.ast).clone();
        // Insert the drops before the first DML command (or at the end when
        // there is none) so the temporary columns never reach the output.
        let insert_pos = new_ast
            .commands
            .iter()
            .position(|c| c.kind.command_class() == CommandClass::Dml)
            .unwrap_or(new_ast.commands.len());
        for (i, drop) in deferred_drops.into_iter().enumerate() {
            new_ast.commands.insert(insert_pos + i, drop);
        }
        current = Arc::new(TypedPipeline::from_ast_with_context(Arc::new(new_ast), ctx));
    }
    Ok(current)
}
/// Reports whether any command in the (validated) pipeline still contains a
/// `transform()` call.
fn pipeline_has_transform(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    let commands = &pipeline.valid_ref()?.commands;
    Ok(commands.iter().any(command_has_transform))
}
/// Reports whether `cmd` contains at least one `transform()` call anywhere
/// in its expressions.
fn command_has_transform(cmd: &Arc<TypedCommand>) -> bool {
    let found = cmd.find_expression(&mut |e| is_transform_call(e));
    found.is_some()
}
/// True when `expr` is an application of the built-in `Transform` function
/// (identified by the function definition's `TypeId`).
fn is_transform_call(expr: &TypedExpression) -> bool {
    match &expr.kind {
        TypedExpressionKind::Apply(apply) => {
            apply.function_def.type_id() == std::any::TypeId::of::<Transform>()
        }
        _ => false,
    }
}
/// Runs one lowering pass over `in_pipeline`, rewriting every command through
/// [`transform_command`] and rebuilding the pipeline AST.
///
/// `DROP`s of temporary result columns that must run after the whole pipeline
/// are appended to `deferred_drops` rather than emitted inline.
fn transform_pipeline(
    in_pipeline: &TypedPipeline,
    ctx: &mut StatementTranslationContext,
    deferred_drops: &mut Vec<Arc<Command>>,
) -> Result<hamelin_lib::tree::ast::pipeline::Pipeline, Arc<TranslationError>> {
    let valid = in_pipeline.valid_ref()?;
    let mut name_gens = TransformNameGenerators::new();
    let mut builder = pipeline().at(in_pipeline.ast.span.clone());
    for cmd in &valid.commands {
        for lowered in transform_command(cmd, &mut name_gens, ctx, deferred_drops)? {
            builder = builder.command(lowered);
        }
    }
    Ok(builder.build())
}
/// Lowers every innermost `transform()` call in `cmd` into a sequence of
/// primitive commands and returns the rewritten command list.
///
/// For each transform the emitted plan is:
/// 1. `LET __row_id = uuid()` — tags each input row so it can be regrouped.
/// 2. `LET __indices = sequence(0, len(arr) - 1, 1)` — element positions.
/// 3. `LET __is_empty = coalesce(len(__indices) = 0, true)` — flags empty (or
///    null — hence the `coalesce`) input arrays.
/// 4. `EXPLODE` over a padded array/index pair so empty arrays still keep
///    their row through the explode.
/// 5. `LET __body = <lambda body with the parameter substituted>`.
/// 6. `AGG __result = array_agg(__body) ... BY __row_id SORT __idx` — regroup.
/// 7. `DROP` of the per-transform temporaries.
///
/// Finally the original command is re-emitted with each transform call
/// replaced by a reference to its `__result` column, and the `__result` /
/// `__is_empty` columns are scheduled for a deferred drop.
fn transform_command(
    cmd: &Arc<TypedCommand>,
    name_gens: &mut TransformNameGenerators,
    ctx: &mut StatementTranslationContext,
    deferred_drops: &mut Vec<Arc<Command>>,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    let mut transforms = collect_transforms(cmd);
    if transforms.is_empty() {
        // Nothing to lower: pass the command through untouched.
        return Ok(vec![cmd.ast.clone()]);
    }
    // NOTE(review): presumably reversed so transforms are lowered in source
    // order (collect_transforms appears to find them in reverse) — confirm.
    transforms.reverse();
    let mut result: Vec<Arc<Command>> = Vec::new();
    // One entry per lowered transform:
    // (original call AST, result column, is_empty column, array argument AST,
    //  resolved result type).
    let mut replacements: Vec<(
        Arc<Expression>,
        SimpleIdentifier,
        SimpleIdentifier,
        Arc<Expression>,
        Arc<Type>,
    )> = Vec::new();
    // Result/is_empty columns of transforms already lowered in this command;
    // later AGGs must carry them through.
    let mut prior_result_columns: Vec<SimpleIdentifier> = Vec::new();
    for transform_expr in transforms {
        let TypedExpressionKind::Apply(apply) = &transform_expr.kind else {
            // collect_transforms only yields Apply nodes; skip defensively.
            continue;
        };
        // Fresh, schema-unique names for this transform's temporaries.
        let row_id_name = name_gens.row_id.next(&cmd.input_schema);
        let indices_name = name_gens.indices.next(&cmd.input_schema);
        let elem_name = name_gens.elem.next(&cmd.input_schema);
        let idx_name = name_gens.idx.next(&cmd.input_schema);
        let body_name = name_gens.body.next(&cmd.input_schema);
        let result_name = name_gens.result.next(&cmd.input_schema);
        let is_empty_name = name_gens.is_empty.next(&cmd.input_schema);
        let array_expr = apply.parameter_binding.get_by_name("array").map_err(|e| {
            ctx.error("transform() missing 'array' parameter")
                .at(&*transform_expr.ast)
                .with_source_boxed(e.into())
                .emit()
        })?;
        let lambda_expr = apply.parameter_binding.get_by_name("lambda").map_err(|e| {
            ctx.error("transform() missing 'lambda' parameter")
                .at(&*transform_expr.ast)
                .with_source_boxed(e.into())
                .emit()
        })?;
        let TypedExpressionKind::Lambda(lambda) = &lambda_expr.kind else {
            return Err(ctx
                .error("transform() second argument must be a lambda")
                .at(&*transform_expr.ast)
                .emit());
        };
        let lambda_param = lambda.parameters[0].name.clone();
        let array_ast = array_expr.ast.as_ref().clone();
        // 1. LET __row_id = uuid(): stable per-row tag for regrouping.
        let let_row_id = let_command()
            .named_field(row_id_name.clone(), call("uuid"))
            .at(cmd.ast.span)
            .build();
        result.push(Arc::new(let_row_id));
        // 2. LET __indices = sequence(0, len(arr) - 1, 1).
        let len_expr = call("len").arg(AstExpressionWrapper(array_expr.ast.clone()));
        let indices_expr = call("sequence").arg(0).arg(subtract(len_expr, 1)).arg(1);
        let let_indices = let_command()
            .named_field(indices_name.clone(), indices_expr)
            .at(cmd.ast.span)
            .build();
        result.push(Arc::new(let_indices));
        // 3. LET __is_empty = coalesce(len(__indices) = 0, true): coalesce
        // maps a null array (null indices) to "empty" as well.
        let let_is_empty = let_command()
            .named_field(
                is_empty_name.clone(),
                call("coalesce")
                    .arg(eq(call("len").arg(field_ref(indices_name.clone())), 0))
                    .arg(true),
            )
            .at(cmd.ast.span)
            .build();
        result.push(Arc::new(let_is_empty));
        // Sentinel [null] cast to the array's type: exploded in place of an
        // empty/null array so the row survives the EXPLODE.
        let sentinel = cast(
            array().element(null()),
            array_expr.resolved_type.as_ref().clone(),
        );
        let padded_array = call("if")
            .arg(field_ref(is_empty_name.clone()))
            .arg(sentinel)
            .arg(AstExpressionWrapper(Arc::new(array_ast)));
        let padded_indices = call("if")
            .arg(field_ref(is_empty_name.clone()))
            .arg(array().element(0))
            .arg(field_ref(indices_name.clone()));
        // 4. EXPLODE __elem, __idx over the padded pair.
        let explode = explode_command()
            .named_field(elem_name.clone(), padded_array)
            .named_field(idx_name.clone(), padded_indices)
            .at(cmd.ast.span)
            .build();
        result.push(Arc::new(explode));
        // 5. LET __body = lambda body with its parameter rewritten to __elem.
        let body_with_substitution =
            substitute_lambda_param(&lambda.body, &lambda_param, &elem_name);
        let let_body = let_command()
            .named_field(
                body_name.clone(),
                AstExpressionWrapper(body_with_substitution),
            )
            .at(cmd.ast.span)
            .build();
        result.push(Arc::new(let_body));
        // 6. AGG back to one row per __row_id, collecting bodies in __idx
        // order and carrying through non-temporary columns.
        let agg = build_agg_command(
            &result_name,
            &body_name,
            &row_id_name,
            &idx_name,
            &is_empty_name,
            &cmd.input_schema,
            &[
                row_id_name.clone(),
                indices_name.clone(),
                elem_name.clone(),
                idx_name.clone(),
                body_name.clone(),
                is_empty_name.clone(),
            ],
            &prior_result_columns,
            cmd.ast.span,
        );
        result.push(Arc::new(agg));
        // 7. DROP the per-transform temporaries (but NOT __result/__is_empty,
        // which the rewritten command still reads — dropped deferred below).
        let drop = drop_command()
            .field(row_id_name)
            .field(indices_name)
            .field(elem_name)
            .field(idx_name)
            .field(body_name)
            .at(cmd.ast.span)
            .build();
        result.push(Arc::new(drop));
        prior_result_columns.push(result_name.clone());
        prior_result_columns.push(is_empty_name.clone());
        replacements.push((
            transform_expr.ast.clone(),
            result_name,
            is_empty_name,
            array_expr.ast.clone(),
            transform_expr.resolved_type.clone(),
        ));
    }
    // Re-emit the original command with each transform call swapped for its
    // __result column (wrapped in null/empty handling).
    let transformed_cmd = replace_transforms_in_command(cmd, &replacements);
    result.push(transformed_cmd);
    if !replacements.is_empty() {
        // __result/__is_empty are still referenced by the rewritten command,
        // so their DROP is deferred to the end of the pipeline.
        let mut drop_results = drop_command().at(cmd.ast.span);
        for (_, result_name, is_empty_name, _, _) in &replacements {
            drop_results = drop_results.field(result_name.clone());
            drop_results = drop_results.field(is_empty_name.clone());
        }
        deferred_drops.push(Arc::new(drop_results.build()));
    }
    Ok(result)
}
/// Collects every innermost `transform()` call in `cmd`.
///
/// Only transforms whose arguments contain no further transforms are picked
/// up; outer (nesting) transforms are left for a later lowering pass. Call
/// sites are tracked by AST pointer identity so each node is returned once.
fn collect_transforms(cmd: &Arc<TypedCommand>) -> Vec<Arc<TypedExpression>> {
    let mut transforms = Vec::new();
    let mut visited: Vec<Arc<Expression>> = Vec::new();
    while let Some(expr) = cmd.find_expression(&mut |expr| {
        is_transform_call(expr)
            && !visited.iter().any(|v| Arc::ptr_eq(v, &expr.ast))
            && !has_nested_transform(expr)
    }) {
        visited.push(expr.ast.clone());
        transforms.push(Arc::new(expr.clone()));
    }
    transforms
}
/// True when `expr` is a function application with a `transform()` call
/// somewhere inside one of its arguments.
fn has_nested_transform(expr: &TypedExpression) -> bool {
    if let TypedExpressionKind::Apply(apply) = &expr.kind {
        apply
            .parameter_binding
            .iter()
            .any(|arg| arg.find(&mut |e| is_transform_call(e)).is_some())
    } else {
        false
    }
}
/// Rewrites `body` so that every reference to the lambda parameter
/// `param_name` becomes a reference to `replacement_name` (the exploded
/// element column), returning the resulting untyped AST.
fn substitute_lambda_param(
    body: &TypedExpression,
    param_name: &SimpleIdentifier,
    replacement_name: &SimpleIdentifier,
) -> Arc<Expression> {
    /// Catamorphism that renames one field reference and leaves every other
    /// node untouched.
    struct Renamer<'a> {
        from: &'a SimpleIdentifier,
        to: &'a SimpleIdentifier,
    }
    impl hamelin_lib::tree::typed_ast::expression::MapExpressionAlgebra for Renamer<'_> {
        fn field_reference(
            &mut self,
            node: &hamelin_lib::tree::typed_ast::expression::TypedFieldReference,
            expr: &TypedExpression,
        ) -> Arc<Expression> {
            // A name that fails validation never matches the parameter.
            let hit = matches!(node.field_name.valid_ref(), Ok(name) if name == self.from);
            if hit {
                Arc::new(field_ref(self.to.clone()).build())
            } else {
                expr.ast.clone()
            }
        }
    }
    let mut renamer = Renamer {
        from: param_name,
        to: replacement_name,
    };
    body.cata(&mut renamer)
}
/// Builds the AGG command that regroups exploded rows back into one row per
/// original input row.
///
/// Collects the lambda results into `result_name` (ordered by `idx_name`),
/// carries `is_empty_name` through with `any_value`, groups by `row_id_name`,
/// and re-emits via `any_value` every input-schema column that is not one of
/// this pass's temporaries, plus the result columns of transforms lowered
/// earlier in the same command.
fn build_agg_command(
    result_name: &SimpleIdentifier,
    body_name: &SimpleIdentifier,
    row_id_name: &SimpleIdentifier,
    idx_name: &SimpleIdentifier,
    is_empty_name: &SimpleIdentifier,
    input_schema: &TypeEnvironment,
    temp_columns: &[SimpleIdentifier],
    prior_result_columns: &[SimpleIdentifier],
    span: Span,
) -> Command {
    let mut agg = agg_command()
        .named_aggregate(
            result_name.clone(),
            call("array_agg").arg(field_ref(body_name.clone())),
        )
        .at(span)
        .named_aggregate(
            is_empty_name.clone(),
            call("any_value").arg(field_ref(is_empty_name.clone())),
        )
        .group_by(row_id_name.clone());
    // Pass every non-temporary input column through unchanged.
    for (column, _) in input_schema.as_struct().iter() {
        if temp_columns.iter().any(|t| t.as_str() == column.name()) {
            continue;
        }
        let column = column.clone();
        agg = agg.named_aggregate(
            column.clone(),
            call("any_value").arg(field_ref(column)),
        );
    }
    // Results of transforms already lowered in this command must also
    // survive the regrouping.
    for col in prior_result_columns {
        agg = agg.named_aggregate(col.clone(), call("any_value").arg(field_ref(col.clone())));
    }
    // Sorting by the element index restores the original array order inside
    // array_agg.
    let ordered = agg.sort_expr(SortExpression {
        span: Span::NONE,
        expression: Arc::new(field_ref(idx_name.clone()).build()),
        order: Some(SortOrder::Asc),
    });
    ordered.build()
}
/// Re-emits `cmd` with every lowered `transform()` call replaced by a read of
/// its aggregated result column.
///
/// Each replacement becomes
/// `if(arr IS NULL, null, if(__is_empty, [] :: result_type, __result))`
/// so that null input arrays yield null and empty arrays yield a typed empty
/// array rather than the sentinel produced by the explode.
fn replace_transforms_in_command(
    cmd: &Arc<TypedCommand>,
    // (original call AST, result column, is_empty column, array argument AST,
    //  resolved result type) — same tuples transform_command accumulated.
    replacements: &[(
        Arc<Expression>,
        SimpleIdentifier,
        SimpleIdentifier,
        Arc<Expression>,
        Arc<Type>,
    )],
) -> Arc<Command> {
    // Catamorphism that swaps matched transform applications for their
    // replacement expression and rebuilds every other node from its children.
    struct ReplaceTransformsAlgebra<'a> {
        replacements: &'a [(
            Arc<Expression>,
            SimpleIdentifier,
            SimpleIdentifier,
            Arc<Expression>,
            Arc<Type>,
        )],
    }
    impl hamelin_lib::tree::typed_ast::expression::MapExpressionAlgebra
        for ReplaceTransformsAlgebra<'_>
    {
        fn apply(
            &mut self,
            node: &TypedApply,
            expr: &TypedExpression,
            children: hamelin_lib::func::def::ParameterBinding<Arc<Expression>>,
        ) -> Arc<Expression> {
            if node.function_def.type_id() == std::any::TypeId::of::<Transform>() {
                for (transform_ast, result_name, is_empty_name, array_ast, result_type) in
                    self.replacements
                {
                    // Match by AST pointer identity: only the exact call
                    // sites that were lowered are replaced.
                    if Arc::ptr_eq(&expr.ast, transform_ast) {
                        // Typed empty array for the is_empty branch.
                        let empty_result = cast(array(), result_type.as_ref().clone());
                        let inner_if = call("if")
                            .arg(field_ref(is_empty_name.clone()))
                            .arg(empty_result)
                            .arg(field_ref(result_name.clone()));
                        // Null input array propagates null, not [].
                        let if_expr = call("if")
                            .arg(is_null(AstExpressionWrapper(array_ast.clone())))
                            .arg(null())
                            .arg(inner_if)
                            .build();
                        return Arc::new(if_expr);
                    }
                }
            }
            // Not a lowered transform: rebuild the node from its (already
            // rewritten) children.
            node.replace_children_ast(expr, children)
        }
    }
    let mut alg = ReplaceTransformsAlgebra { replacements };
    cmd.cata_expressions(&mut alg)
}
/// Adapter that lets an already-built [`Expression`] AST node be fed back
/// into the expression-builder API (e.g. as an argument of `call(...)`).
#[derive(Debug)]
struct AstExpressionWrapper(Arc<Expression>);

impl ExpressionBuilder for AstExpressionWrapper {
    /// "Building" simply clones the wrapped AST node.
    fn build(&self) -> Expression {
        (*self.0).clone()
    }
}

impl IntoExpressionBuilder for AstExpressionWrapper {
    fn into_expression_builder(self) -> Box<dyn ExpressionBuilder> {
        Box::new(self)
    }
}
#[cfg(test)]
mod tests {
    //! Golden tests for the transform-lowering pass: each `#[case]` pairs an
    //! input pipeline with the exact lowered pipeline the pass must produce,
    //! plus the expected output schema after lowering.
    use super::*;
    use hamelin_lib::type_check;
    use hamelin_lib::{
        tree::{
            ast::pipeline::Pipeline,
            builder::{
                agg_command, array, call, cast, drop_command, eq, explode_command, field_ref,
                is_null, lambda1, let_command, multiply, null, pipeline, select_command,
                sort_command,
            },
        },
        types::{array::Array, struct_type::Struct, INT},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    #[rstest]
    // A pipeline with no transform() calls must pass through unchanged.
    #[case::no_transform_passthrough(
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        Struct::default().with_str("a", INT).with_str("b", INT)
    )]
    // transform() over a literal sequence: the array argument is inlined
    // wherever the lowering references "arr".
    #[case::simple_transform_lowered(
        pipeline()
            .command(select_command().named_field("x", 1).build())
            .command(let_command()
                .named_field("doubled", call("transform")
                    .arg(call("sequence").arg(1).arg(3).arg(1))
                    .arg(lambda1("elem").body(multiply(field_ref("elem"), 2))))
                .build())
            .build(),
        pipeline()
            .command(select_command().named_field("x", 1).build())
            // LET __row_id_0 = uuid()
            .command(let_command()
                .named_field("__row_id_0", call("uuid"))
                .build())
            // LET __indices_0 = sequence(0, len(arr) - 1, 1)
            .command(let_command()
                .named_field("__indices_0", call("sequence")
                    .arg(0)
                    .arg(subtract(call("len").arg(call("sequence").arg(1).arg(3).arg(1)), 1))
                    .arg(1))
                .build())
            // LET __is_empty_0 = coalesce(len(__indices_0) = 0, true)
            .command(let_command()
                .named_field(
                    "__is_empty_0",
                    call("coalesce")
                        .arg(eq(call("len").arg(field_ref("__indices_0")), 0))
                        .arg(true),
                )
                .build())
            // EXPLODE __elem_0 = if(__is_empty_0, [null] AS array(int), arr), __idx_0 = if(__is_empty_0, [0], __indices_0)
            .command(explode_command()
                .named_field("__elem_0", call("if")
                    .arg(field_ref("__is_empty_0"))
                    .arg(cast(array().element(null()), Array::new(INT).into()))
                    .arg(call("sequence").arg(1).arg(3).arg(1)))
                .named_field("__idx_0", call("if")
                    .arg(field_ref("__is_empty_0"))
                    .arg(array().element(0))
                    .arg(field_ref("__indices_0")))
                .build())
            // LET __body_0 = __elem_0 * 2
            .command(let_command()
                .named_field("__body_0", multiply(field_ref("__elem_0"), 2))
                .build())
            // AGG __result_0 = array_agg(__body_0), __is_empty_0 = any_value(__is_empty_0), x = any_value(x)
            // BY __row_id_0 SORT __idx_0
            .command(agg_command()
                .named_aggregate("__result_0", call("array_agg").arg(field_ref("__body_0")))
                .named_aggregate("__is_empty_0", call("any_value").arg(field_ref("__is_empty_0")))
                .named_aggregate("x", call("any_value").arg(field_ref("x")))
                .group_by("__row_id_0")
                .sort(sort_command().asc(field_ref("__idx_0")))
                .build())
            // DROP __row_id_0, __indices_0, __elem_0, __idx_0, __body_0
            .command(drop_command()
                .field("__row_id_0")
                .field("__indices_0")
                .field("__elem_0")
                .field("__idx_0")
                .field("__body_0")
                .build())
            // LET doubled = if(arr IS NULL, NULL, if(__is_empty_0, [] AS array(int), __result_0))
            // (The arr here is sequence(1,3,1) which is the original array expression)
            .command(let_command()
                .named_field("doubled", call("if")
                    .arg(is_null(call("sequence").arg(1).arg(3).arg(1)))
                    .arg(null())
                    .arg(call("if")
                        .arg(field_ref("__is_empty_0"))
                        .arg(cast(array(), Array::new(INT).into()))
                        .arg(field_ref("__result_0"))))
                .build())
            // DROP __result_0, __is_empty_0
            .command(drop_command()
                .field("__result_0")
                .field("__is_empty_0")
                .build())
            .build(),
        Struct::default()
            .with_str("doubled", Array::new(INT).into())
            .with_str("x", INT)
    )]
    // transform() over a field reference: the lowering references the "arr"
    // column directly instead of repeating the array expression.
    #[case::transform_with_field_ref_array(
        pipeline()
            .command(select_command()
                .named_field("arr", call("sequence").arg(1).arg(3).arg(1))
                .build())
            .command(let_command()
                .named_field("doubled", call("transform")
                    .arg(field_ref("arr"))
                    .arg(lambda1("x").body(multiply(field_ref("x"), 2))))
                .build())
            .build(),
        pipeline()
            .command(select_command()
                .named_field("arr", call("sequence").arg(1).arg(3).arg(1))
                .build())
            // LET __row_id_0 = uuid()
            .command(let_command()
                .named_field("__row_id_0", call("uuid"))
                .build())
            // LET __indices_0 = sequence(0, len(arr) - 1, 1)
            .command(let_command()
                .named_field("__indices_0", call("sequence")
                    .arg(0)
                    .arg(subtract(call("len").arg(field_ref("arr")), 1))
                    .arg(1))
                .build())
            // LET __is_empty_0 = coalesce(len(__indices_0) = 0, true)
            .command(let_command()
                .named_field(
                    "__is_empty_0",
                    call("coalesce")
                        .arg(eq(call("len").arg(field_ref("__indices_0")), 0))
                        .arg(true),
                )
                .build())
            // EXPLODE __elem_0 = if(__is_empty_0, [null] AS array(int), arr), __idx_0 = if(__is_empty_0, [0], __indices_0)
            .command(explode_command()
                .named_field("__elem_0", call("if")
                    .arg(field_ref("__is_empty_0"))
                    .arg(cast(array().element(null()), Array::new(INT).into()))
                    .arg(field_ref("arr")))
                .named_field("__idx_0", call("if")
                    .arg(field_ref("__is_empty_0"))
                    .arg(array().element(0))
                    .arg(field_ref("__indices_0")))
                .build())
            // LET __body_0 = __elem_0 * 2
            .command(let_command()
                .named_field("__body_0", multiply(field_ref("__elem_0"), 2))
                .build())
            // AGG __result_0 = array_agg(__body_0), __is_empty_0 = any_value(__is_empty_0),
            // arr = any_value(arr) BY __row_id_0 SORT __idx_0
            .command(agg_command()
                .named_aggregate("__result_0", call("array_agg").arg(field_ref("__body_0")))
                .named_aggregate("__is_empty_0", call("any_value").arg(field_ref("__is_empty_0")))
                .named_aggregate("arr", call("any_value").arg(field_ref("arr")))
                .group_by("__row_id_0")
                .sort(sort_command().asc(field_ref("__idx_0")))
                .build())
            // DROP __row_id_0, __indices_0, __elem_0, __idx_0, __body_0
            .command(drop_command()
                .field("__row_id_0")
                .field("__indices_0")
                .field("__elem_0")
                .field("__idx_0")
                .field("__body_0")
                .build())
            // LET doubled = if(arr IS NULL, NULL, if(__is_empty_0, [] AS array(int), __result_0))
            .command(let_command()
                .named_field("doubled", call("if")
                    .arg(is_null(field_ref("arr")))
                    .arg(null())
                    .arg(call("if")
                        .arg(field_ref("__is_empty_0"))
                        .arg(cast(array(), Array::new(INT).into()))
                        .arg(field_ref("__result_0"))))
                .build())
            // DROP __result_0, __is_empty_0
            .command(drop_command()
                .field("__result_0")
                .field("__is_empty_0")
                .build())
            .build(),
        Struct::default()
            .with_str("doubled", Array::new(INT).into())
            .with_str("arr", Array::new(INT).into())
    )]
    fn test_lower_transform(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
        #[case] expected_output_schema: Struct,
    ) {
        // Type-check both pipelines so the lowered output can be compared at
        // the AST level against an equally type-checked expectation.
        let input_typed = type_check(input).output;
        let expected_typed = type_check(expected).output;
        let mut ctx = StatementTranslationContext::default();
        let result = lower_transform(Arc::new(input_typed), &mut ctx).unwrap();
        assert_eq!(result.ast, expected_typed.ast);
        // The lowered pipeline's schema must match, with temporaries dropped.
        let result_schema = result.environment().as_struct().clone();
        assert_eq!(result_schema, expected_output_schema);
    }
}