use std::sync::Arc;
use hamelin_lib::{
err::TranslationError,
func::defs::{Coalesce, ToJsonString, Transform},
tree::{
ast::{
clause::{SortExpression, SortOrder},
command::Command,
expression::Expression,
identifier::SimpleIdentifier,
node::Span,
},
builder::{
agg_command, array, call, cast, drop_command, eq, explode_command, field, field_ref,
is_null, null, pipeline, set_command, subtract, ExpressionBuilder,
},
typed_ast::{
command::TypedCommand,
context::StatementTranslationContext,
environment::TypeEnvironment,
expression::{
FieldAccess, MapExpressionAlgebra, TypedApply, TypedExpression, TypedExpressionKind,
},
pipeline::TypedPipeline,
},
},
types::{array::Array, Type},
};
use crate::unique::UniqueNameGenerator;
/// How a single `transform(array, lambda)` call inside a command is rewritten.
enum TransformReplacement {
/// A fast path matched: the call is replaced by this array-level expression.
Vectorized(Arc<Expression>),
/// Slow path: the array is EXPLODEd, the lambda body is evaluated per element, and the
/// results are AGGregated back per row. The call is then replaced by an expression that
/// reads the columns produced by that rewrite.
ExplodeAggregate {
/// Column holding the `array_agg` of the per-element lambda results.
result_name: SimpleIdentifier,
/// Column holding the flag computed as `coalesce(len(indices) = 0, true)` before exploding.
is_empty_name: SimpleIdentifier,
/// AST of the original array argument; re-tested with `is_null` to keep NULL input NULL.
array_ast: Arc<Expression>,
/// Resolved type of the original `transform` call; used to type the empty-array result.
result_type: Arc<Type>,
},
}
/// One fresh-name generator per kind of temporary column introduced by the lowering.
///
/// Names are requested with the command's input schema (`next(&cmd.input_schema)`);
/// NOTE(review): presumably this keeps generated names from colliding with existing
/// columns — confirm against `UniqueNameGenerator`.
struct TransformNameGenerators {
/// Per-row `uuid()` key used as the AGG grouping key.
row_id: UniqueNameGenerator,
/// `sequence(0, len(array) - 1, 1)` index array.
indices: UniqueNameGenerator,
/// Exploded array element.
elem: UniqueNameGenerator,
/// Exploded element index; orders the `array_agg`.
idx: UniqueNameGenerator,
/// Lambda body evaluated on one exploded element.
body: UniqueNameGenerator,
/// Aggregated result array.
result: UniqueNameGenerator,
/// Empty/NULL-input flag.
is_empty: UniqueNameGenerator,
/// Fresh lambda parameter names used when `attempt_vectorize` unfuses a chained body.
unfuse_param: UniqueNameGenerator,
}
impl TransformNameGenerators {
    /// Creates one generator per temporary-column kind.
    ///
    /// The prefixes are `__row_id`, `__indices`, `__elem`, `__idx`, `__body`, `__result`,
    /// `__is_empty` and `__unfuse`.
    fn new() -> Self {
        let make = |prefix: &'static str| UniqueNameGenerator::new(prefix);
        Self {
            row_id: make("__row_id"),
            indices: make("__indices"),
            elem: make("__elem"),
            idx: make("__idx"),
            body: make("__body"),
            result: make("__result"),
            is_empty: make("__is_empty"),
            unfuse_param: make("__unfuse"),
        }
    }
}
/// Rewrites every `transform(array, lambda)` call in `pipeline` into plain commands.
///
/// Lowering runs to a fixed point. Each pass rewrites the innermost `transform` calls
/// (those whose arguments contain no other `transform`) and then re-type-checks the
/// pipeline, so nested calls are peeled one layer per pass.
///
/// - Calls matched by a fast path are replaced in place by an array-level expression.
/// - All other calls go through an EXPLODE + AGG rewrite. Its result columns must outlive
///   the command that reads them.
///
/// Those result columns are dropped once, after lowering finishes. The `DROP` commands are
/// inserted just before the first DML command, or appended at the end when there is none.
///
/// # Errors
///
/// Returns the translation error if the pipeline (or a re-type-checked intermediate
/// pipeline) is invalid, or if a `transform` call is malformed.
pub fn lower_transform(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    use hamelin_lib::tree::ast::command::CommandClass;

    let mut current = pipeline;
    let mut deferred_drops: Vec<Arc<Command>> = Vec::new();

    while pipeline_has_transform(&current)? {
        let new_ast = transform_pipeline(&current, ctx, &mut deferred_drops)?;
        current = Arc::new(TypedPipeline::from_ast_with_context(Arc::new(new_ast), ctx));
    }

    if !deferred_drops.is_empty() {
        let mut new_ast = (*current.ast).clone();
        // The drops must precede any DML command so the temporaries are never written out.
        let insert_pos = new_ast
            .commands
            .iter()
            .position(|c| c.kind.command_class() == CommandClass::Dml)
            .unwrap_or(new_ast.commands.len());
        // Insert every drop with a single shift of the tail instead of one `insert` per drop.
        new_ast
            .commands
            .splice(insert_pos..insert_pos, deferred_drops);
        current = Arc::new(TypedPipeline::from_ast_with_context(Arc::new(new_ast), ctx));
    }

    Ok(current)
}
/// Reports whether any command of the type-checked pipeline still contains a `transform` call.
///
/// # Errors
///
/// Propagates the pipeline's translation error when it did not type-check.
fn pipeline_has_transform(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;
    let has_any = valid.commands.iter().any(command_has_transform);
    Ok(has_any)
}
/// Reports whether any expression found inside `cmd` is a `transform` call.
fn command_has_transform(cmd: &Arc<TypedCommand>) -> bool {
    let first_transform = cmd.find_expression(&mut |candidate| is_transform_call(candidate));
    first_transform.is_some()
}
/// Returns `true` when `expr` applies the `Transform` function definition.
///
/// The check compares type ids, not function names.
fn is_transform_call(expr: &TypedExpression) -> bool {
    matches!(
        &expr.kind,
        TypedExpressionKind::Apply(apply)
            if apply.function_def.type_id() == std::any::TypeId::of::<Transform>()
    )
}
/// Rebuilds `in_pipeline`'s AST with every command passed through `transform_command`.
///
/// One `TransformNameGenerators` is shared by all commands of this pass. Slow-path
/// `DROP`s of result columns are pushed onto `deferred_drops` rather than emitted here.
///
/// # Errors
///
/// Fails if the pipeline is not valid or if any command cannot be lowered.
fn transform_pipeline(
    in_pipeline: &TypedPipeline,
    ctx: &mut StatementTranslationContext,
    deferred_drops: &mut Vec<Arc<Command>>,
) -> Result<hamelin_lib::tree::ast::pipeline::Pipeline, Arc<TranslationError>> {
    let valid = in_pipeline.valid_ref()?;
    let mut out = pipeline().at(in_pipeline.ast.span.clone());
    let mut gens = TransformNameGenerators::new();
    for typed_cmd in valid.commands.iter() {
        let lowered = transform_command(typed_cmd, &mut gens, ctx, deferred_drops)?;
        out = lowered.into_iter().fold(out, |acc, c| acc.command(c));
    }
    Ok(out.build())
}
fn transform_command(
cmd: &Arc<TypedCommand>,
name_gens: &mut TransformNameGenerators,
ctx: &mut StatementTranslationContext,
deferred_drops: &mut Vec<Arc<Command>>,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
let mut transforms = collect_transforms(cmd);
if transforms.is_empty() {
return Ok(vec![cmd.ast.clone()]);
}
transforms.reverse();
let mut result: Vec<Arc<Command>> = Vec::new();
let mut replacements: Vec<(Arc<Expression>, TransformReplacement)> = Vec::new();
let mut prior_result_columns: Vec<SimpleIdentifier> = Vec::new();
for transform_expr in transforms {
let TypedExpressionKind::Apply(apply) = &transform_expr.kind else {
continue;
};
let array_expr = apply.parameter_binding.get_by_name("array").map_err(|e| {
ctx.error("transform() missing 'array' parameter")
.at(&*transform_expr.ast)
.with_source_boxed(e.into())
.emit()
})?;
let lambda_expr = apply.parameter_binding.get_by_name("lambda").map_err(|e| {
ctx.error("transform() missing 'lambda' parameter")
.at(&*transform_expr.ast)
.with_source_boxed(e.into())
.emit()
})?;
let TypedExpressionKind::Lambda(lambda) = &lambda_expr.kind else {
return Err(ctx
.error("transform() second argument must be a lambda")
.at(&*transform_expr.ast)
.emit());
};
let lambda_param = lambda.parameters[0].name.clone();
if let Some(replacement_ast) = attempt_vectorize(
array_expr,
&lambda.body,
&lambda_param,
&cmd.input_schema,
name_gens,
ctx,
) {
replacements.push((
transform_expr.ast.clone(),
TransformReplacement::Vectorized(replacement_ast),
));
continue;
}
let row_id_name = name_gens.row_id.next(&cmd.input_schema);
let indices_name = name_gens.indices.next(&cmd.input_schema);
let elem_name = name_gens.elem.next(&cmd.input_schema);
let idx_name = name_gens.idx.next(&cmd.input_schema);
let body_name = name_gens.body.next(&cmd.input_schema);
let result_name = name_gens.result.next(&cmd.input_schema);
let is_empty_name = name_gens.is_empty.next(&cmd.input_schema);
let array_ast = array_expr.ast.as_ref().clone();
let set_row_id = set_command()
.named_field(row_id_name.clone(), call("uuid"))
.at(cmd.ast.span)
.build();
result.push(Arc::new(set_row_id));
let len_expr = call("len").arg(array_expr.ast.clone());
let indices_expr = call("sequence").arg(0).arg(subtract(len_expr, 1)).arg(1);
let set_indices = set_command()
.named_field(indices_name.clone(), indices_expr)
.at(cmd.ast.span)
.build();
result.push(Arc::new(set_indices));
let set_is_empty = set_command()
.named_field(
is_empty_name.clone(),
call("coalesce")
.arg(eq(call("len").arg(field_ref(indices_name.clone())), 0))
.arg(true),
)
.at(cmd.ast.span)
.build();
result.push(Arc::new(set_is_empty));
let sentinel = cast(
array().element(null()),
array_expr.resolved_type.as_ref().clone(),
);
let padded_array = call("if")
.arg(field_ref(is_empty_name.clone()))
.arg(sentinel)
.arg(Arc::new(array_ast));
let padded_indices = call("if")
.arg(field_ref(is_empty_name.clone()))
.arg(array().element(0))
.arg(field_ref(indices_name.clone()));
let explode = explode_command()
.named_field(elem_name.clone(), padded_array)
.named_field(idx_name.clone(), padded_indices)
.at(cmd.ast.span)
.build();
result.push(Arc::new(explode));
let body_with_substitution =
substitute_lambda_param(&lambda.body, &lambda_param, &elem_name);
let set_body = set_command()
.named_field(body_name.clone(), body_with_substitution)
.at(cmd.ast.span)
.build();
result.push(Arc::new(set_body));
let agg = build_agg_command(
&result_name,
&body_name,
&row_id_name,
&idx_name,
&is_empty_name,
&cmd.input_schema,
&[
row_id_name.clone(),
indices_name.clone(),
elem_name.clone(),
idx_name.clone(),
body_name.clone(),
is_empty_name.clone(),
],
&prior_result_columns,
cmd.ast.span,
);
result.push(Arc::new(agg));
let drop = drop_command()
.field(row_id_name)
.field(indices_name)
.field(elem_name)
.field(idx_name)
.field(body_name)
.at(cmd.ast.span)
.build();
result.push(Arc::new(drop));
prior_result_columns.push(result_name.clone());
prior_result_columns.push(is_empty_name.clone());
replacements.push((
transform_expr.ast.clone(),
TransformReplacement::ExplodeAggregate {
result_name,
is_empty_name,
array_ast: array_expr.ast.clone(),
result_type: transform_expr.resolved_type.clone(),
},
));
}
let transformed_cmd = replace_transforms_in_command(cmd, &replacements);
result.push(transformed_cmd);
let drop_fields: Vec<_> = replacements
.iter()
.filter_map(|(_, r)| match r {
TransformReplacement::ExplodeAggregate {
result_name,
is_empty_name,
..
} => Some((result_name.clone(), is_empty_name.clone())),
TransformReplacement::Vectorized(_) => None,
})
.collect();
if !drop_fields.is_empty() {
let mut drop_results = drop_command().at(cmd.ast.span);
for (result_name, is_empty_name) in &drop_fields {
drop_results = drop_results.field(result_name.clone());
drop_results = drop_results.field(is_empty_name.clone());
}
deferred_drops.push(Arc::new(drop_results.build()));
}
Ok(result)
}
/// Collects the innermost `transform` calls of `cmd`, in expression-traversal order.
///
/// "Innermost" means no argument (array or lambda) contains another `transform` call. Outer
/// calls are picked up by a later lowering pass, once their inner calls have been rewritten.
/// A call whose AST node is reachable more than once (a shared `Arc`) is returned only once.
fn collect_transforms(cmd: &Arc<TypedCommand>) -> Vec<Arc<TypedExpression>> {
    let mut transforms: Vec<Arc<TypedExpression>> = Vec::new();
    // A predicate that never reports a match makes `find_expression` visit every expression
    // exactly once. Recording matches along the way yields the same order as restarting the
    // search after each hit, but without re-walking the command for every call found.
    let _ = cmd.find_expression(&mut |expr| {
        let node: &TypedExpression = expr;
        if is_transform_call(node)
            && !transforms.iter().any(|t| Arc::ptr_eq(&t.ast, &node.ast))
            && !has_nested_transform(node)
        {
            transforms.push(Arc::new(node.clone()));
        }
        false
    });
    transforms
}
/// Returns `true` when `expr` is a function application and any of its bound arguments
/// contains a `transform` call (as found by that argument's `find`).
///
/// Non-`Apply` expressions report `false`.
fn has_nested_transform(expr: &TypedExpression) -> bool {
    match &expr.kind {
        TypedExpressionKind::Apply(apply) => apply.parameter_binding.iter().any(|param| {
            let hit = param.find(&mut |inner| is_transform_call(inner));
            hit.is_some()
        }),
        _ => false,
    }
}
/// Rebuilds `body`'s AST, turning every field reference named `param_name` into a field
/// reference to `replacement_name`.
///
/// Other field references keep their original AST. All remaining node kinds are handled by
/// the default `MapExpressionAlgebra` methods.
fn substitute_lambda_param(
    body: &TypedExpression,
    param_name: &SimpleIdentifier,
    replacement_name: &SimpleIdentifier,
) -> Arc<Expression> {
    struct ParamRenamer<'a> {
        from: &'a SimpleIdentifier,
        to: &'a SimpleIdentifier,
    }

    impl MapExpressionAlgebra for ParamRenamer<'_> {
        fn field_reference(
            &mut self,
            node: &hamelin_lib::tree::typed_ast::expression::TypedFieldReference,
            expr: &TypedExpression,
        ) -> Arc<Expression> {
            let refers_to_param = node
                .field_name
                .valid_ref()
                .map(|name| name == self.from)
                .unwrap_or(false);
            if !refers_to_param {
                return expr.ast.clone();
            }
            Arc::new(field_ref(self.to.clone()).build())
        }
    }

    let mut renamer = ParamRenamer {
        from: param_name,
        to: replacement_name,
    };
    body.cata(&mut renamer)
}
/// Builds the AGG command that folds exploded rows back into one row per `row_id_name`.
///
/// Produced aggregates:
/// - `result_name`: `array_agg(body_name)`, with the aggregation sorted by `idx_name`
///   ascending;
/// - `is_empty_name`: `any_value(is_empty_name)`;
/// - every input-schema column not named in `temp_columns`, and every column in
///   `prior_result_columns`: `any_value(column)`, so each survives the grouping.
#[allow(clippy::too_many_arguments)] // each argument is a distinct generated column name
fn build_agg_command(
    result_name: &SimpleIdentifier,
    body_name: &SimpleIdentifier,
    row_id_name: &SimpleIdentifier,
    idx_name: &SimpleIdentifier,
    is_empty_name: &SimpleIdentifier,
    input_schema: &TypeEnvironment,
    temp_columns: &[SimpleIdentifier],
    prior_result_columns: &[SimpleIdentifier],
    span: Span,
) -> Command {
    let mut builder = agg_command()
        .named_aggregate(
            result_name.clone(),
            call("array_agg").arg(field_ref(body_name.clone())),
        )
        .at(span)
        .named_aggregate(
            is_empty_name.clone(),
            call("any_value").arg(field_ref(is_empty_name.clone())),
        )
        .group_by(row_id_name.clone());

    for (column, _) in input_schema.as_struct().iter() {
        let is_temporary = temp_columns.iter().any(|t| t.as_str() == column.name());
        if !is_temporary {
            builder = builder.named_aggregate(
                column.clone(),
                call("any_value").arg(field_ref(column.clone())),
            );
        }
    }

    builder = prior_result_columns.iter().fold(builder, |acc, carried| {
        acc.named_aggregate(
            carried.clone(),
            call("any_value").arg(field_ref(carried.clone())),
        )
    });

    builder
        .sort_expr(SortExpression {
            span: Span::NONE,
            expression: Arc::new(field_ref(idx_name.clone()).build()),
            order: Some(SortOrder::Asc),
        })
        .build()
}
/// Rebuilds `cmd`'s AST, swapping out each `transform` call listed in `replacements`.
///
/// Calls are matched by AST pointer identity (`Arc::ptr_eq`).
/// - A `Vectorized` entry substitutes its expression directly.
/// - An `ExplodeAggregate` entry becomes
///   `if(is_null(array), null, if(is_empty, cast([], result_type), result))`.
///
/// Every other application is rebuilt from its already-rewritten children.
fn replace_transforms_in_command(
    cmd: &Arc<TypedCommand>,
    replacements: &[(Arc<Expression>, TransformReplacement)],
) -> Arc<Command> {
    struct Replacer<'a> {
        table: &'a [(Arc<Expression>, TransformReplacement)],
    }

    impl MapExpressionAlgebra for Replacer<'_> {
        fn apply(
            &mut self,
            node: &TypedApply,
            expr: &TypedExpression,
            children: hamelin_lib::func::def::ParameterBinding<Arc<Expression>>,
        ) -> Arc<Expression> {
            let is_transform =
                node.function_def.type_id() == std::any::TypeId::of::<Transform>();
            let hit = if is_transform {
                self.table
                    .iter()
                    .find(|(ast, _)| Arc::ptr_eq(&expr.ast, ast))
            } else {
                None
            };
            match hit {
                Some((_, TransformReplacement::Vectorized(ast))) => ast.clone(),
                Some((
                    _,
                    TransformReplacement::ExplodeAggregate {
                        result_name,
                        is_empty_name,
                        array_ast,
                        result_type,
                    },
                )) => {
                    let empty_or_result = call("if")
                        .arg(field_ref(is_empty_name.clone()))
                        .arg(cast(array(), result_type.as_ref().clone()))
                        .arg(field_ref(result_name.clone()));
                    let null_guarded = call("if")
                        .arg(is_null(array_ast.clone()))
                        .arg(null())
                        .arg(empty_or_result)
                        .build();
                    Arc::new(null_guarded)
                }
                None => node.replace_children_ast(expr, children),
            }
        }
    }

    let mut replacer = Replacer {
        table: replacements,
    };
    cmd.cata_expressions(&mut replacer)
}
/// Signature shared by all fast-path matchers: `(array_expr, lambda_body, lambda_param)`.
/// A matcher returns the array-level replacement when it recognises the body shape, or
/// `None` so the next matcher can try.
type FastPathMatcher =
fn(&TypedExpression, &TypedExpression, &SimpleIdentifier) -> Option<Arc<Expression>>;
/// Fast-path matchers, tried in this order by `match_fast_path`; the first hit wins.
const TRANSFORM_FAST_PATHS: &[FastPathMatcher] = &[
vectorize_identity,
vectorize_struct_field_access,
vectorize_tuple_element_access,
vectorize_variant_field_access,
vectorize_array_element_cast,
vectorize_to_json_string_variant,
vectorize_coalesce_with_default,
];
/// Tries to replace `transform(array_expr, lambda_param -> body)` with an array-level expression.
///
/// Steps:
/// 1. Try the fast paths on the whole body.
/// 2. Otherwise, peel the outermost link off the body (see `decompose_link`) and recursively
///    vectorize the inner part.
/// 3. Type-check that inner array expression in `bindings`.
/// 4. Retry the fast paths on the peeled link alone, with a fresh parameter standing for
///    elements of the inner array.
///
/// Returns `None` as soon as any step fails.
fn attempt_vectorize(
    array_expr: &TypedExpression,
    body: &TypedExpression,
    lambda_param: &SimpleIdentifier,
    bindings: &Arc<TypeEnvironment>,
    name_gens: &mut TransformNameGenerators,
    ctx: &mut StatementTranslationContext,
) -> Option<Arc<Expression>> {
    let direct = match_fast_path(array_expr, body, lambda_param);
    if direct.is_some() {
        return direct;
    }

    let (inner_body, rebuild) = decompose_link(body, lambda_param)?;
    let inner_ast = attempt_vectorize(
        array_expr,
        inner_body,
        lambda_param,
        bindings,
        name_gens,
        ctx,
    )?;

    let inner_typed = type_check_in_env(&inner_ast, bindings, ctx);
    let element_type = match inner_typed.resolved_type.as_ref() {
        Type::Array(arr) => arr.element_type.as_ref().clone(),
        _ => return None,
    };

    let fresh_param = name_gens.unfuse_param.next(bindings);
    let outer_ast = rebuild(Arc::new(field_ref(fresh_param.clone()).build()));
    let extended_bindings =
        Arc::new(TypeEnvironment::clone(bindings).with(fresh_param.clone().into(), element_type));
    let outer_typed = type_check_in_env(&outer_ast, &extended_bindings, ctx);

    match_fast_path(&inner_typed, &outer_typed, &fresh_param)
}
/// Returns the first replacement produced by the entries of `TRANSFORM_FAST_PATHS`, in order.
fn match_fast_path(
    array_expr: &TypedExpression,
    body: &TypedExpression,
    lambda_param: &SimpleIdentifier,
) -> Option<Arc<Expression>> {
    TRANSFORM_FAST_PATHS
        .iter()
        .find_map(|matcher| matcher(array_expr, body, lambda_param))
}
/// Type-checks `ast` against `bindings`, using the statement's default expression context.
///
/// No `Result` is returned. NOTE(review): presumably type errors are carried in-band;
/// callers here only inspect `resolved_type`.
fn type_check_in_env(
    ast: &Arc<Expression>,
    bindings: &Arc<TypeEnvironment>,
    ctx: &mut StatementTranslationContext,
) -> TypedExpression {
    let mut scoped = ctx.default_expression_context(bindings);
    let owned_ast = Arc::clone(ast);
    TypedExpression::from_ast_with_context(owned_ast, &mut scoped)
}
/// Splits `body` into its outermost "link" and the sub-expression that carries the lambda
/// parameter.
///
/// Supported links:
/// - a field lookup whose receiver references the parameter;
/// - a cast whose operand references the parameter;
/// - a function application in which exactly one bound argument references the parameter.
///
/// Returns that inner sub-expression together with a closure that rebuilds the link's AST
/// around a replacement for it. Returns `None` for any other node kind, or when the
/// parameter is not referenced where required (or, for applications, is referenced by more
/// than one argument).
fn decompose_link<'a>(
body: &'a TypedExpression,
lambda_param: &SimpleIdentifier,
) -> Option<(
&'a TypedExpression,
Box<dyn FnOnce(Arc<Expression>) -> Arc<Expression> + 'a>,
)> {
match &body.kind {
TypedExpressionKind::FieldLookup(fl) => {
if !subtree_references_lambda_param(&fl.value, lambda_param) {
return None;
}
// The field identifier is taken from the untyped AST so the rebuilt lookup is
// written exactly as the user wrote it.
let field_id = match &body.ast.kind {
hamelin_lib::tree::ast::expression::ExpressionKind::FieldLookup(fl_ast) => {
fl_ast.field_identifier.clone()
}
_ => return None,
};
let span = body.ast.span.clone();
let rebuild: Box<dyn FnOnce(Arc<Expression>) -> Arc<Expression>> =
Box::new(move |slot| {
Arc::new(Expression {
span,
kind: hamelin_lib::tree::ast::expression::FieldLookup {
value: slot,
field_identifier: field_id,
}
.into(),
})
});
Some((fl.value.as_ref(), rebuild))
}
TypedExpressionKind::Cast(cast_node) => {
if !subtree_references_lambda_param(&cast_node.value, lambda_param) {
return None;
}
let target_type = cast_node.target_type.clone();
let span = body.ast.span.clone();
let rebuild: Box<dyn FnOnce(Arc<Expression>) -> Arc<Expression>> =
Box::new(move |slot| {
Arc::new(Expression {
span,
kind: hamelin_lib::tree::ast::expression::Cast {
expression: slot,
target_type: Arc::new(target_type),
}
.into(),
})
});
Some((cast_node.value.as_ref(), rebuild))
}
TypedExpressionKind::Apply(apply) => {
// Find the single argument that mentions the parameter; bail out if there are several.
let mut x_slot_index: Option<usize> = None;
for (i, arg) in apply.parameter_binding.iter().enumerate() {
if subtree_references_lambda_param(arg, lambda_param) {
if x_slot_index.is_some() {
return None;
}
x_slot_index = Some(i);
}
}
let x_slot_index = x_slot_index?;
let inner_sub_body = apply.parameter_binding.get_by_index(x_slot_index).ok()?;
let function_name = apply.function_def.name().to_string();
// NOTE(review): every bound argument is re-emitted positionally with no named args;
// confirm this matches the original call when it used named or defaulted parameters.
let arg_asts: Vec<Arc<Expression>> = apply
.parameter_binding
.iter()
.map(|a| a.ast.clone())
.collect();
let span = body.ast.span.clone();
let rebuild: Box<dyn FnOnce(Arc<Expression>) -> Arc<Expression>> =
Box::new(move |slot| {
let mut positional = arg_asts;
positional[x_slot_index] = slot;
Arc::new(Expression {
span,
kind: hamelin_lib::tree::ast::expression::FunctionCall {
name: SimpleIdentifier::new(function_name).into(),
positional_args: positional,
named_args: Default::default(),
}
.into(),
})
});
Some((inner_sub_body.as_ref(), rebuild))
}
_ => None,
}
}
fn subtree_references_lambda_param(expr: &TypedExpression, param_name: &SimpleIdentifier) -> bool {
expr.find(&mut |e| {
if let TypedExpressionKind::FieldReference(field_ref) = &e.kind {
return field_ref
.field_name
.valid_ref()
.map(|name| name == param_name)
.unwrap_or(false);
}
false
})
.is_some()
}
/// Returns `true` when `expr` is itself a valid field reference whose name equals `param_name`.
fn is_lambda_param_ref(expr: &TypedExpression, param_name: &SimpleIdentifier) -> bool {
    match &expr.kind {
        TypedExpressionKind::FieldReference(reference) => reference
            .field_name
            .valid_ref()
            .map(|name| name == param_name)
            .unwrap_or(false),
        _ => false,
    }
}
/// Fast path for the identity lambda `x -> x`: the array argument itself is the result.
fn vectorize_identity(
    array_expr: &TypedExpression,
    body: &TypedExpression,
    lambda_param: &SimpleIdentifier,
) -> Option<Arc<Expression>> {
    is_lambda_param_ref(body, lambda_param).then(|| array_expr.ast.clone())
}
/// Fast path for `x -> x.name` on struct elements: emits the field access `array.name`.
///
/// This relies on a field access applied to an array producing the per-element projection.
fn vectorize_struct_field_access(
    array_expr: &TypedExpression,
    body: &TypedExpression,
    lambda_param: &SimpleIdentifier,
) -> Option<Arc<Expression>> {
    match &body.kind {
        TypedExpressionKind::FieldLookup(lookup)
            if is_lambda_param_ref(&lookup.value, lambda_param) =>
        {
            match &lookup.access {
                FieldAccess::StructField(field_id) => {
                    let name = field_id.valid_ref().ok()?;
                    let projected = field(array_expr.ast.clone(), name.as_str()).build();
                    Some(Arc::new(projected))
                }
                _ => None,
            }
        }
        _ => None,
    }
}
/// Fast path for tuple element access `x -> x.fN`: emits the field access `array.fN`.
///
/// The field name is `f` followed by the tuple element index.
fn vectorize_tuple_element_access(
    array_expr: &TypedExpression,
    body: &TypedExpression,
    lambda_param: &SimpleIdentifier,
) -> Option<Arc<Expression>> {
    match &body.kind {
        TypedExpressionKind::FieldLookup(lookup)
            if is_lambda_param_ref(&lookup.value, lambda_param) =>
        {
            if let FieldAccess::TupleElement(idx) = &lookup.access {
                let element_field = format!("f{}", idx);
                let projected = field(array_expr.ast.clone(), element_field.as_str()).build();
                Some(Arc::new(projected))
            } else {
                None
            }
        }
        _ => None,
    }
}
/// Fast path for `x -> x.name` on variant elements: emits the field access `array.name`.
fn vectorize_variant_field_access(
    array_expr: &TypedExpression,
    body: &TypedExpression,
    lambda_param: &SimpleIdentifier,
) -> Option<Arc<Expression>> {
    match &body.kind {
        TypedExpressionKind::FieldLookup(lookup)
            if is_lambda_param_ref(&lookup.value, lambda_param) =>
        {
            if let FieldAccess::VariantField(field_id) = &lookup.access {
                let name = field_id.valid_ref().ok()?;
                let projected = field(array_expr.ast.clone(), name.as_str()).build();
                Some(Arc::new(projected))
            } else {
                None
            }
        }
        _ => None,
    }
}
/// Fast path for `x -> cast(x, T)`: emits `cast(array, array<T>)`.
fn vectorize_array_element_cast(
    array_expr: &TypedExpression,
    body: &TypedExpression,
    lambda_param: &SimpleIdentifier,
) -> Option<Arc<Expression>> {
    match &body.kind {
        TypedExpressionKind::Cast(cast_node)
            if is_lambda_param_ref(&cast_node.value, lambda_param) =>
        {
            let elementwise_type = Array::new(cast_node.target_type.clone()).into();
            let whole_array_cast = cast(array_expr.ast.clone(), elementwise_type).build();
            Some(Arc::new(whole_array_cast))
        }
        _ => None,
    }
}
/// Fast path for `x -> to_json_string(x)`: emits `array_variant_to_json(array)`.
///
/// Only single-argument `ToJsonString` applications whose argument is the bare lambda
/// parameter match.
fn vectorize_to_json_string_variant(
    array_expr: &TypedExpression,
    body: &TypedExpression,
    lambda_param: &SimpleIdentifier,
) -> Option<Arc<Expression>> {
    let TypedExpressionKind::Apply(apply) = &body.kind else {
        return None;
    };
    let is_to_json = apply.function_def.type_id() == std::any::TypeId::of::<ToJsonString>();
    if !is_to_json || apply.parameter_binding.len() != 1 {
        return None;
    }
    let sole_arg = apply.parameter_binding.get_by_index(0).ok()?;
    is_lambda_param_ref(sole_arg, lambda_param).then(|| {
        Arc::new(
            call("array_variant_to_json")
                .arg(array_expr.ast.clone())
                .build(),
        )
    })
}
/// Fast path for `x -> coalesce(x, default)`: emits `array_coalesce(array, default)`.
///
/// Matches only when the first argument is the bare lambda parameter and `default` does not
/// reference the parameter anywhere.
fn vectorize_coalesce_with_default(
    array_expr: &TypedExpression,
    body: &TypedExpression,
    lambda_param: &SimpleIdentifier,
) -> Option<Arc<Expression>> {
    let TypedExpressionKind::Apply(apply) = &body.kind else {
        return None;
    };
    let binding = &apply.parameter_binding;
    let is_binary_coalesce = apply.function_def.type_id()
        == std::any::TypeId::of::<Coalesce>()
        && binding.len() == 2;
    if !is_binary_coalesce {
        return None;
    }
    let value = binding.get_by_index(0).ok()?;
    let fallback = binding.get_by_index(1).ok()?;
    let shape_matches = is_lambda_param_ref(value, lambda_param)
        && !subtree_references_lambda_param(fallback, lambda_param);
    shape_matches.then(|| {
        Arc::new(
            call("array_coalesce")
                .arg(array_expr.ast.clone())
                .arg(fallback.ast.clone())
                .build(),
        )
    })
}
// Unit tests for `lower_transform`: each builds a small pipeline, type-checks it, lowers it,
// and counts how many EXPLODE/AGG commands the result contains.
#[cfg(test)]
mod tests {
use super::*;
use hamelin_lib::tree::{
ast::pipeline::Pipeline,
builder::{
array, call, cast, field, field_ref, lambda1, pipeline, set_command, struct_literal,
tuple,
},
};
use hamelin_lib::type_check;
use std::sync::Arc;
/// Asserts that lowering `input` emits no EXPLODE or AGG command (every transform vectorized).
#[track_caller]
fn assert_vectorized(input: Pipeline) {
let result = run_lower_transform(input);
let count = count_explode_agg(&result);
assert!(
count == 0,
"expected fully vectorized lowering, but slow path emitted {} EXPLODE/AGG commands.\n\
pipeline:\n{:#?}",
count,
result.ast
);
}
/// Asserts that lowering `input` emits at least two EXPLODE/AGG commands (slow path taken).
#[track_caller]
fn assert_falls_back(input: Pipeline) {
let result = run_lower_transform(input);
let count = count_explode_agg(&result);
assert!(
count >= 2,
"expected fallback to EXPLODE + AGG, got {} EXPLODE/AGG commands.\n\
pipeline:\n{:#?}",
count,
result.ast.commands
);
}
/// Counts the EXPLODE and AGG commands of a valid typed pipeline (panics if it is invalid).
fn count_explode_agg(pipeline: &TypedPipeline) -> usize {
use hamelin_lib::tree::typed_ast::command::TypedCommandKind;
let valid = pipeline.valid_ref().unwrap();
valid
.commands
.iter()
.filter(|cmd| {
matches!(
cmd.kind,
TypedCommandKind::Explode(_) | TypedCommandKind::Agg(_)
)
})
.count()
}
/// Type-checks `input` and runs `lower_transform` with a default context, unwrapping errors.
fn run_lower_transform(input: Pipeline) -> Arc<TypedPipeline> {
let input_typed = type_check(input).output;
let mut ctx = StatementTranslationContext::default();
lower_transform(Arc::new(input_typed), &mut ctx).unwrap()
}
// `x -> cast(x, STRING)` should hit the array-element cast fast path.
#[test]
fn unfuse_array_element_cast() {
use hamelin_lib::types::STRING;
assert_vectorized(
pipeline()
.command(
set_command()
.named_field("arr", call("sequence").arg(1).arg(3).arg(1))
.build(),
)
.command(
set_command()
.named_field(
"casted",
call("transform")
.arg(field_ref("arr"))
.arg(lambda1("x").body(cast(field_ref("x"), STRING))),
)
.build(),
)
.build(),
);
}
// Two struct-field projections of the same array in one SET should both vectorize.
#[test]
fn unfuse_two_sibling_transforms_same_array() {
assert_vectorized(
pipeline()
.command(
set_command()
.named_field(
"arr",
array()
.element(struct_literal().field("k", "a").field("v", 1))
.element(struct_literal().field("k", "b").field("v", 2)),
)
.build(),
)
.command(
set_command()
.named_field(
"keys",
call("transform")
.arg(field_ref("arr"))
.arg(lambda1("x").body(field(field_ref("x"), "k"))),
)
.named_field(
"vals",
call("transform")
.arg(field_ref("arr"))
.arg(lambda1("x").body(field(field_ref("x"), "v"))),
)
.build(),
)
.build(),
);
}
// Tuple element access `x.f0` should hit the tuple-element fast path.
#[test]
fn unfuse_tuple_element_access() {
assert_vectorized(
pipeline()
.command(
set_command()
.named_field(
"arr",
array()
.element(tuple().element(1).element("a"))
.element(tuple().element(2).element("b")),
)
.build(),
)
.command(
set_command()
.named_field(
"firsts",
call("transform")
.arg(field_ref("arr"))
.arg(lambda1("x").body(field(field_ref("x"), "f0"))),
)
.build(),
)
.build(),
);
}
// Field access on `parse_json` (variant) elements should hit the variant-field fast path.
#[test]
fn unfuse_variant_field_access() {
assert_vectorized(
pipeline()
.command(
set_command()
.named_field(
"arr",
array()
.element(call("parse_json").arg(r#"{"foo": 1}"#))
.element(call("parse_json").arg(r#"{"foo": 2}"#)),
)
.build(),
)
.command(
set_command()
.named_field(
"foos",
call("transform")
.arg(field_ref("arr"))
.arg(lambda1("x").body(field(field_ref("x"), "foo"))),
)
.build(),
)
.build(),
);
}
// `to_json_string(x)` over variant elements should hit the to-JSON fast path.
#[test]
fn unfuse_to_json_string_variant() {
assert_vectorized(
pipeline()
.command(
set_command()
.named_field(
"arr",
array()
.element(call("parse_json").arg("1"))
.element(call("parse_json").arg("2")),
)
.build(),
)
.command(
set_command()
.named_field(
"jsons",
call("transform")
.arg(field_ref("arr"))
.arg(lambda1("x").body(call("to_json_string").arg(field_ref("x")))),
)
.build(),
)
.build(),
);
}
// A chained body `coalesce(cast(x.q, STRING), "0")` should vectorize by unfusing the
// field access, cast and coalesce links one at a time.
#[test]
fn unfuse_coalesce_of_typed_cast_on_struct_field() {
use hamelin_lib::types::STRING;
assert_vectorized(
pipeline()
.command(
set_command()
.named_field(
"arr",
array()
.element(struct_literal().field("q", 1))
.element(struct_literal().field("q", 2)),
)
.build(),
)
.command(
set_command()
.named_field(
"result",
call("transform").arg(field_ref("arr")).arg(
lambda1("x").body(
call("coalesce")
.arg(cast(field(field_ref("x"), "q"), STRING))
.arg("0"),
),
),
)
.build(),
)
.build(),
);
}
// `upper(x)` has no fast path, so lowering must fall back to EXPLODE + AGG.
#[test]
fn fallback_partial_vectorizable_body() {
assert_falls_back(
pipeline()
.command(
set_command()
.named_field("arr", array().element("a").element("b"))
.build(),
)
.command(
set_command()
.named_field(
"uppered",
call("transform")
.arg(field_ref("arr"))
.arg(lambda1("x").body(call("upper").arg(field_ref("x")))),
)
.build(),
)
.build(),
);
}
}