use std::sync::Arc;
use hamelin_lib::{
err::TranslationError,
tree::{
ast::expression::Expression,
builder::{call, field_ref, lambda1, pipeline, ExpressionBuilder},
typed_ast::{
command::TypedCommand,
context::StatementTranslationContext,
environment::TypeEnvironment,
expression::{MapExpressionAlgebra, TypedBroadcastApply, TypedExpression},
pipeline::TypedPipeline,
},
},
};
use crate::unique::UniqueNameGenerator;
/// Lowers every `BroadcastApply` node in `pipeline` into an explicit
/// `transform(array, lambda)` call, then re-type-checks the rewritten AST.
///
/// Pipelines that contain no broadcast nodes are returned untouched, so the
/// pass is a cheap no-op for the common case.
///
/// # Errors
/// Propagates validation errors from the input pipeline and any lowering
/// failure reported while rewriting a command.
pub fn lower_broadcast_apply(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    if pipeline_has_broadcast(&pipeline)? {
        let lowered_ast = transform_pipeline(&pipeline, ctx)?;
        Ok(Arc::new(TypedPipeline::from_ast_with_context(
            Arc::new(lowered_ast),
            ctx,
        )))
    } else {
        Ok(pipeline)
    }
}
/// Returns `true` when any command of the (validated) pipeline contains a
/// `BroadcastApply` expression.
///
/// # Errors
/// Fails if the pipeline did not validate (`valid_ref` propagates the error).
fn pipeline_has_broadcast(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    let commands = &pipeline.valid_ref()?.commands;
    Ok(commands.iter().any(command_has_broadcast))
}
/// Scans one command's expression tree for a `BroadcastApply` node.
fn command_has_broadcast(cmd: &Arc<TypedCommand>) -> bool {
    use hamelin_lib::tree::typed_ast::expression::TypedExpressionKind;
    let hit = cmd.find_expression(&mut |expr| match &expr.kind {
        TypedExpressionKind::BroadcastApply(_) => true,
        _ => false,
    });
    hit.is_some()
}
/// Rebuilds the pipeline's untyped AST with every broadcast application
/// lowered to a `transform` call.
///
/// A single `UniqueNameGenerator` is shared across all commands so the
/// generated lambda-parameter names stay unique for the whole pipeline.
///
/// # Errors
/// Fails if the pipeline did not validate or if any command fails to lower.
fn transform_pipeline(
    in_pipeline: &TypedPipeline,
    ctx: &mut StatementTranslationContext,
) -> Result<hamelin_lib::tree::ast::pipeline::Pipeline, Arc<TranslationError>> {
    let valid = in_pipeline.valid_ref()?;
    let mut name_gen = UniqueNameGenerator::new("__broadcast");
    let builder = valid.commands.iter().try_fold(
        pipeline().at(in_pipeline.ast.span.clone()),
        |acc, cmd| {
            let lowered = transform_command(cmd, &mut name_gen, ctx)?;
            Ok::<_, Arc<TranslationError>>(acc.command(lowered))
        },
    )?;
    Ok(builder.build())
}
/// Runs the lowering algebra over a single command and converts any error
/// the algebra recorded into a `TranslationError` anchored at the command's
/// AST span.
///
/// # Errors
/// Returns the first lowering failure recorded during the traversal.
fn transform_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<hamelin_lib::tree::ast::command::Command>, Arc<TranslationError>> {
    let mut algebra = BroadcastLoweringAlgebra {
        name_gen,
        schema: &cmd.input_schema,
        error: None,
    };
    let lowered = cmd.cata_expressions(&mut algebra);
    // The algebra cannot abort the traversal, so failures only surface here,
    // after the whole command has been walked.
    match algebra.error {
        Some(message) => Err(ctx.error(message).at(&*cmd.ast).emit()),
        None => Ok(lowered),
    }
}
/// Expression-rewriting algebra that lowers `BroadcastApply` nodes into
/// `transform(array, lambda)` calls.
///
/// `cata_expressions` offers no error channel, so a lowering failure is
/// recorded in `error` and surfaced by the caller after the traversal.
struct BroadcastLoweringAlgebra<'a> {
    // Produces lambda-parameter names that do not collide with `schema`.
    name_gen: &'a mut UniqueNameGenerator,
    // Input schema of the command currently being rewritten.
    schema: &'a TypeEnvironment,
    // Failure message recorded during traversal; checked by `transform_command`.
    error: Option<String>,
}
impl BroadcastLoweringAlgebra<'_> {
    /// Records a lowering failure. Keeps the FIRST error: previously a later
    /// failure in the same command overwrote an earlier one, so the emitted
    /// diagnostic described the last symptom instead of the root cause.
    fn record_error(&mut self, message: String) {
        if self.error.is_none() {
            self.error = Some(message);
        }
    }
}

impl MapExpressionAlgebra for BroadcastLoweringAlgebra<'_> {
    /// Lowers one `BroadcastApply` node: the argument at
    /// `node.broadcast_position` (the array-valued one) becomes the first
    /// argument of a `transform` call, and the original expression — with
    /// that argument replaced by a fresh lambda parameter — becomes the
    /// lambda body.
    ///
    /// This algebra has no error channel, so on failure it records the
    /// message via `record_error` and returns a structurally valid AST;
    /// `transform_command` discards that AST once it sees the error.
    fn broadcast_apply(
        &mut self,
        node: &TypedBroadcastApply,
        expr: &TypedExpression,
        children: hamelin_lib::func::def::ParameterBinding<Arc<Expression>>,
    ) -> Arc<Expression> {
        // Fresh parameter name guaranteed not to shadow a schema column.
        let param_name = self.name_gen.next(self.schema);
        let array_arg = match children.get_by_index(node.broadcast_position) {
            Ok(arg) => arg.clone(),
            Err(e) => {
                self.record_error(format!(
                    "broadcast lowering: invalid broadcast_position {}: {e}",
                    node.broadcast_position
                ));
                // `children` is still intact here, so rebuild the node as-is.
                return node.replace_children_ast(expr, children);
            }
        };
        let param_ref = Arc::new(field_ref(param_name.clone()).build());
        let body_children = match children.replace_by_index(node.broadcast_position, param_ref) {
            Ok(c) => c,
            Err(e) => {
                self.record_error(format!(
                    "broadcast lowering: failed to replace at position {}: {e}",
                    node.broadcast_position
                ));
                // `children` was consumed by `replace_by_index`, so the only
                // safe fallback is a clone of the node's original AST.
                return Arc::new(expr.ast.as_ref().clone());
            }
        };
        // transform(<array_arg>, |param| <original expr with param spliced in>)
        let body_ast = node.replace_children_ast(expr, body_children);
        let lambda_ast = lambda1(param_name.as_str()).body(body_ast).build();
        let transform_call = call("transform").arg(array_arg).arg(lambda_ast).build();
        Arc::new(transform_call)
    }
}
#[cfg(test)]
mod tests {
    //! End-to-end tests for the lowering pass: each case type-checks an
    //! input pipeline, runs `lower_broadcast_apply`, and compares both the
    //! rewritten AST and the resulting output schema against an expected
    //! pipeline that is independently type-checked.
    use super::*;
    use hamelin_lib::{
        provider::EnvironmentProvider,
        tree::{
            ast::{
                identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
                pipeline::Pipeline,
            },
            builder::{
                add, array, call, field_ref, gt, lambda1, multiply, pipeline, query,
                QueryBuilderWithMain,
            },
        },
        type_check_with_provider,
        types::{array::Array, struct_type::Struct, Type, INT, STRING},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    // Test catalog exposing a single table:
    //   events(x: INT, s: STRING, arr_col: ARRAY<INT>)
    #[derive(Debug)]
    struct MockProvider;

    impl EnvironmentProvider for MockProvider {
        fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
            let events: Identifier = AstSimpleIdentifier::new("events").into();
            if name == &events {
                Ok(Struct::default()
                    .with_str("x", INT)
                    .with_str("s", STRING)
                    .with_str("arr_col", Array::new(INT).into()))
            } else {
                anyhow::bail!("Table not found: {}", name)
            }
        }

        fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
            Ok(vec![])
        }
    }

    // Type-checks a query fixture; panics if the fixture itself is
    // ill-typed so broken test data fails loudly rather than silently.
    fn typed_pipeline(builder: QueryBuilderWithMain) -> Arc<TypedPipeline> {
        let result = type_check_with_provider(builder.build(), Arc::new(MockProvider));
        let statement = result
            .into_result()
            .expect("test fixture should type-check cleanly");
        statement.pipeline.clone()
    }

    // Fresh translation context backed by the mock catalog.
    fn make_ctx() -> StatementTranslationContext {
        StatementTranslationContext::new(
            Arc::new(hamelin_lib::func::registry::FunctionRegistry::default()),
            Arc::new(MockProvider),
        )
    }

    // Schema of `events` unchanged (used by cases that add no fields).
    fn base_schema() -> Struct {
        Struct::default()
            .with_str("x", INT)
            .with_str("s", STRING)
            .with_str("arr_col", Array::new(INT).into())
    }

    // `events` schema with one SET-added field prepended.
    fn schema_with_set(field: &str, typ: Type) -> Struct {
        Struct::default()
            .with_str(field, typ)
            .with_str("x", INT)
            .with_str("s", STRING)
            .with_str("arr_col", Array::new(INT).into())
    }

    #[rstest]
    #[case::no_broadcast_passthrough(
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("y", add(field_ref("x"), 1)))
            .build(),
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("y", add(field_ref("x"), 1)))
            .build(),
        schema_with_set("y", INT)
    )]
    #[case::left_broadcast(
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("arr", multiply(array().element(1).element(2).element(3), 10)))
            .build(),
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("arr", call("transform")
                .arg(array().element(1).element(2).element(3))
                .arg(lambda1("__broadcast_0").body(multiply(field_ref("__broadcast_0"), 10)))))
            .build(),
        schema_with_set("arr", Array::new(INT).into())
    )]
    #[case::right_broadcast(
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("arr", multiply(10, array().element(1).element(2).element(3))))
            .build(),
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("arr", call("transform")
                .arg(array().element(1).element(2).element(3))
                .arg(lambda1("__broadcast_0").body(multiply(10, field_ref("__broadcast_0"))))))
            .build(),
        schema_with_set("arr", Array::new(INT).into())
    )]
    #[case::function_broadcast(
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("arr", call("upper").arg(array().element("a").element("b"))))
            .build(),
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("arr", call("transform")
                .arg(array().element("a").element("b"))
                .arg(lambda1("__broadcast_0").body(call("upper").arg(field_ref("__broadcast_0"))))))
            .build(),
        schema_with_set("arr", Array::new(STRING).into())
    )]
    #[case::two_broadcasts_same_command(
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l
                .named_field("a", add(array().element(1).element(2), 5))
                .named_field("b", multiply(array().element(10).element(20), 2)))
            .build(),
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l
                .named_field("a", call("transform")
                    .arg(array().element(1).element(2))
                    .arg(lambda1("__broadcast_0").body(add(field_ref("__broadcast_0"), 5))))
                .named_field("b", call("transform")
                    .arg(array().element(10).element(20))
                    .arg(lambda1("__broadcast_1").body(multiply(field_ref("__broadcast_1"), 2)))))
            .build(),
        // Two fields prepended: a first, then b
        Struct::default()
            .with_str("a", Array::new(INT).into())
            .with_str("b", Array::new(INT).into())
            .with_str("x", INT)
            .with_str("s", STRING)
            .with_str("arr_col", Array::new(INT).into())
    )]
    #[case::broadcast_over_column(
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("result", multiply(field_ref("arr_col"), 10)))
            .build(),
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("result", call("transform")
                .arg(field_ref("arr_col"))
                .arg(lambda1("__broadcast_0").body(multiply(field_ref("__broadcast_0"), 10)))))
            .build(),
        schema_with_set("result", Array::new(INT).into())
    )]
    #[case::broadcast_nested_in_coalesce(
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("arr", call("coalesce")
                .arg(call("upper").arg(array().element("a").element("b")))
                .arg(array().element("default"))))
            .build(),
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("arr", call("coalesce")
                .arg(call("transform")
                    .arg(array().element("a").element("b"))
                    .arg(lambda1("__broadcast_0").body(call("upper").arg(field_ref("__broadcast_0")))))
                .arg(array().element("default"))))
            .build(),
        schema_with_set("arr", Array::new(STRING).into())
    )]
    #[case::broadcast_in_where(
        pipeline()
            .from(|f| f.table_reference("events"))
            .where_cmd(gt(call("len").arg(multiply(field_ref("arr_col"), 10)), 0))
            .build(),
        pipeline()
            .from(|f| f.table_reference("events"))
            .where_cmd(gt(
                call("len").arg(
                    call("transform")
                        .arg(field_ref("arr_col"))
                        .arg(lambda1("__broadcast_0").body(multiply(field_ref("__broadcast_0"), 10)))),
                0))
            .build(),
        // WHERE doesn't add fields
        base_schema()
    )]
    #[case::broadcasts_across_commands(
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("a", multiply(array().element(1).element(2), 5)))
            .set_cmd(|l| l.named_field("b", add(array().element(10).element(20), 3)))
            .build(),
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("a", call("transform")
                .arg(array().element(1).element(2))
                .arg(lambda1("__broadcast_0").body(multiply(field_ref("__broadcast_0"), 5)))))
            .set_cmd(|l| l.named_field("b", call("transform")
                .arg(array().element(10).element(20))
                .arg(lambda1("__broadcast_1").body(add(field_ref("__broadcast_1"), 3)))))
            .build(),
        // First SET adds "a", second SET adds "b" at beginning
        Struct::default()
            .with_str("b", Array::new(INT).into())
            .with_str("a", Array::new(INT).into())
            .with_str("x", INT)
            .with_str("s", STRING)
            .with_str("arr_col", Array::new(INT).into())
    )]
    #[case::name_collision_avoided(
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("__broadcast_0", 42)) // Create column with generated name
            .set_cmd(|l| l.named_field("result", multiply(field_ref("arr_col"), 10)))
            .build(),
        pipeline()
            .from(|f| f.table_reference("events"))
            .set_cmd(|l| l.named_field("__broadcast_0", 42))
            .set_cmd(|l| l.named_field("result", call("transform")
                .arg(field_ref("arr_col"))
                // Generator skips __broadcast_0 (exists) and uses __broadcast_1
                .arg(lambda1("__broadcast_1").body(multiply(field_ref("__broadcast_1"), 10)))))
            .build(),
        // __broadcast_0 added first, then result
        Struct::default()
            .with_str("result", Array::new(INT).into())
            .with_str("__broadcast_0", INT)
            .with_str("x", INT)
            .with_str("s", STRING)
            .with_str("arr_col", Array::new(INT).into())
    )]
    // Asserts both the rewritten AST (structural equality against the
    // expected fixture) and the pipeline's output schema.
    fn test_lower_broadcast_apply(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
        #[case] expected_output_schema: Struct,
    ) -> Result<(), Arc<TranslationError>> {
        let input_typed = typed_pipeline(query().main(input));
        let expected_typed = typed_pipeline(query().main(expected));
        let mut ctx = make_ctx();
        let result = lower_broadcast_apply(input_typed, &mut ctx)?;
        assert_eq!(result.ast, expected_typed.ast);
        assert_eq!(
            result.environment().as_struct().clone(),
            expected_output_schema
        );
        Ok(())
    }
}