use std::sync::Arc;
use hamelin_lib::{
err::TranslationError,
tree::{
ast::{command::Command, expression::Expression, pipeline::Pipeline},
builder::{array, pipeline, ExpressionBuilder},
typed_ast::{
command::TypedCommand,
context::StatementTranslationContext,
environment::TypeEnvironment,
expression::{
MapExpressionAlgebra, TypedArrayLiteral, TypedExpression, TypedExpressionKind,
},
pipeline::TypedPipeline,
},
},
types::{struct_type::Struct, Type},
};
use super::super::expand_struct::expand_struct_to_type_with_ast;
use crate::unique::UniqueNameGenerator;
/// Rewrites array literals whose struct-typed elements do not all share the array's resolved
/// element struct type.
///
/// When no command in `pipeline` contains such a literal, the very same `Arc` is handed back
/// and `ctx` is not touched. Otherwise a new untyped AST is produced by `transform_pipeline`
/// and typed again through `TypedPipeline::from_ast_with_context` using `ctx`.
///
/// # Errors
///
/// Propagates any error produced while checking or transforming the pipeline (both steps
/// call `valid_ref()` on it).
pub fn expand_array_literals(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    if pipeline_needs_expansion(&pipeline)? {
        let rewritten_ast = Arc::new(transform_pipeline(&pipeline)?);
        let retyped = TypedPipeline::from_ast_with_context(rewritten_ast, ctx);
        Ok(Arc::new(retyped))
    } else {
        // Nothing to rewrite: return the input untouched.
        Ok(pipeline)
    }
}
fn pipeline_needs_expansion(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
let valid = pipeline.valid_ref()?;
Ok(valid.commands.iter().any(command_needs_expansion))
}
/// Returns `true` when `find_expression` locates an expression of `cmd` for which
/// `expression_needs_expansion` holds.
fn command_needs_expansion(cmd: &Arc<TypedCommand>) -> bool {
    let hit = cmd.find_expression(&mut |candidate| expression_needs_expansion(candidate));
    hit.is_some()
}
/// Reports whether `find` locates, starting from `expr`, an array literal with mismatched
/// struct elements.
///
/// A node qualifies when it is an array literal, its resolved type is an array whose element
/// type is a struct, and at least one of its elements resolves to a struct that is not equal
/// to that element struct. Elements that do not resolve to a struct never count as a
/// mismatch.
fn expression_needs_expansion(expr: &TypedExpression) -> bool {
    expr.find(&mut |node| match &node.kind {
        TypedExpressionKind::ArrayLiteral(literal) => match node.resolved_type.as_ref() {
            Type::Array(array_type) => match array_type.element_type.as_ref() {
                Type::Struct(target) => literal.elements.iter().any(|element| {
                    matches!(
                        element.resolved_type.as_ref(),
                        Type::Struct(actual) if actual != target
                    )
                }),
                _ => false,
            },
            _ => false,
        },
        _ => false,
    })
    .is_some()
}
/// Builds a new untyped `Pipeline`, positioned at the span of the input's AST, from the
/// commands that `transform_command` yields for each command of `in_pipeline`, in order.
///
/// One `UniqueNameGenerator` seeded with `"__expand"` is shared by all commands.
///
/// # Errors
///
/// Propagates the error returned by `valid_ref()` on the pipeline.
fn transform_pipeline(in_pipeline: &TypedPipeline) -> Result<Pipeline, Arc<TranslationError>> {
    let valid = in_pipeline.valid_ref()?;
    let mut name_gen = UniqueNameGenerator::new("__expand");
    let seed = pipeline().at(in_pipeline.ast.span.clone());
    let assembled = valid
        .commands
        .iter()
        .flat_map(|cmd| transform_command(cmd, &mut name_gen))
        .fold(seed, |acc, command| acc.command(command));
    Ok(assembled.build())
}
/// Runs an `ArrayExpansionAlgebra` over the expressions of `cmd` and returns the resulting
/// command sequence.
///
/// The sequence is: every command the algebra collected in `before_commands`, then the
/// rewritten command itself, then every command collected in `after_commands`.
fn transform_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
) -> Vec<Arc<Command>> {
    let mut algebra = ArrayExpansionAlgebra {
        input_schema: Arc::clone(&cmd.input_schema),
        name_gen,
        before_commands: Vec::new(),
        after_commands: Vec::new(),
    };
    let rewritten = cmd.cata_expressions(&mut algebra);

    // Take the collected commands out of the algebra once the traversal is finished.
    let ArrayExpansionAlgebra {
        before_commands,
        after_commands,
        ..
    } = algebra;

    let mut sequence: Vec<Arc<Command>> =
        Vec::with_capacity(before_commands.len() + 1 + after_commands.len());
    sequence.extend(before_commands.into_iter().map(Arc::new));
    sequence.push(rewritten);
    sequence.extend(after_commands.into_iter().map(Arc::new));
    sequence
}
/// State carried while rewriting the expressions of a single command.
///
/// `transform_command` creates one instance per command, runs it through
/// `cata_expressions`, and afterwards drains the two command lists.
struct ArrayExpansionAlgebra<'a> {
/// Clone of the command's `input_schema`; forwarded to `expand_array_literal`.
input_schema: Arc<TypeEnvironment>,
/// Name generator borrowed from the caller and forwarded to `expand_array_literal`
/// (presumably used there to create fresh names; its use is not visible in this file).
name_gen: &'a mut UniqueNameGenerator,
/// Commands reported by the expansion that `transform_command` places before the
/// rewritten command.
before_commands: Vec<Command>,
/// Commands reported by the expansion that `transform_command` places after the
/// rewritten command.
after_commands: Vec<Command>,
}
impl MapExpressionAlgebra for ArrayExpansionAlgebra<'_> {
    /// Rebuilds an array literal from its already-transformed `children`.
    ///
    /// The literal is expanded through `expand_array_literal` only when its resolved type is
    /// an array of struct and at least one element resolves to a different struct. In every
    /// other case a plain `ArrayLiteral` carrying `children` is returned at the original
    /// span. Commands reported by the expansion are appended to `before_commands` and
    /// `after_commands`.
    fn array_literal(
        &mut self,
        node: &TypedArrayLiteral,
        expr: &TypedExpression,
        children: Vec<Arc<Expression>>,
    ) -> Arc<Expression> {
        // `Some(target)` only when the array's element type is a struct and some element
        // resolves to a struct other than `target`.
        let widening_target = match expr.resolved_type.as_ref() {
            Type::Array(array_type) => match array_type.element_type.as_ref() {
                Type::Struct(target) => {
                    let has_mismatch = node.elements.iter().any(|element| {
                        matches!(
                            element.resolved_type.as_ref(),
                            Type::Struct(actual) if actual != target
                        )
                    });
                    has_mismatch.then_some(target)
                }
                _ => None,
            },
            _ => None,
        };

        let Some(target) = widening_target else {
            // No rewrite required: reassemble the literal from the transformed children.
            return Arc::new(Expression {
                span: expr.ast.span.clone(),
                kind: hamelin_lib::tree::ast::expression::ArrayLiteral { elements: children }
                    .into(),
            });
        };

        let (expanded_ast, before, after) = expand_array_literal(
            node,
            &children,
            target,
            &self.input_schema,
            self.name_gen,
        );
        self.before_commands.extend(before);
        self.after_commands.extend(after);
        expanded_ast
    }
}
/// Rebuilds an array literal so that struct elements differing from `target_struct` are
/// replaced by the expression returned from `expand_struct_to_type_with_ast`.
///
/// Elements are paired positionally with `children` (their already-transformed ASTs).
/// Elements whose resolved type is not a struct, or is a struct equal to `target_struct`,
/// keep their child AST unchanged.
///
/// Returns the new array expression together with all "before" and all "after" commands
/// reported by the per-element expansions, in element order.
fn expand_array_literal(
    arr: &TypedArrayLiteral,
    children: &[Arc<Expression>],
    target_struct: &Struct,
    input_schema: &Arc<TypeEnvironment>,
    name_gen: &mut UniqueNameGenerator,
) -> (Arc<Expression>, Vec<Command>, Vec<Command>) {
    let mut before_all = Vec::new();
    let mut after_all = Vec::new();
    let mut builder = array();

    for (element, original_ast) in arr.elements.iter().zip(children) {
        builder = match element.resolved_type.as_ref() {
            Type::Struct(source_struct) if source_struct != target_struct => {
                let (widened, before, after) = expand_struct_to_type_with_ast(
                    element,
                    Some(original_ast),
                    source_struct,
                    target_struct,
                    name_gen,
                    input_schema,
                );
                before_all.extend(before);
                after_all.extend(after);
                builder.element(widened)
            }
            // Already the target struct, or not a struct at all: keep as is.
            _ => builder.element(original_ast.clone()),
        };
    }

    let rebuilt: Expression = builder.build();
    (Arc::new(rebuilt), before_all, after_all)
}
#[cfg(test)]
mod tests {
use super::*;
use hamelin_lib::{
err::Context,
func::registry::FunctionRegistry,
provider::EnvironmentProvider,
tree::{
ast::identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
builder::{
array, cast, field_ref, let_command, pipeline, query, struct_literal,
NullLiteralBuilder, QueryBuilderWithMain,
},
typed_ast::command::TypedCommandKind,
},
type_check_with_provider,
types::{array::Array, struct_type::Struct, Type, INT, STRING},
};
use pretty_assertions::assert_eq;
use rstest::rstest;
/// Environment provider for the tests: it knows a single table, `events`, with columns
/// `x: INT` and `s: {a: INT}`, and reports no datasets.
#[derive(Debug)]
struct MockProvider;

impl EnvironmentProvider for MockProvider {
    fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
        let events: Identifier = AstSimpleIdentifier::new("events").into();
        if name != &events {
            anyhow::bail!("Table not found: {}", name);
        }
        let nested = Struct::default().with_str("a", INT);
        Ok(Struct::default()
            .with_str("x", INT)
            .with_str("s", nested.into()))
    }

    fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
        Ok(Vec::new())
    }
}
/// Builds the query, type-checks it against `MockProvider`, and returns the typed pipeline
/// of the resulting statement.
fn typed_pipeline(builder: QueryBuilderWithMain) -> Arc<TypedPipeline> {
    let provider = Arc::new(MockProvider);
    let checked = type_check_with_provider(builder.build(), provider);
    checked.output.pipeline.clone()
}
/// Wraps `message` in a `TranslationError` whose context covers the range `0..=0`.
fn test_error(message: impl Into<String>) -> Arc<TranslationError> {
    let text: String = message.into();
    let context = Context::new(0..=0, &text);
    Arc::new(TranslationError::new(context))
}
/// Returns the resolved types of the elements of the array literal assigned to `field_name`
/// by the first matching assignment of a `let` command.
///
/// Assignments whose identifier is invalid, is not a simple identifier, or has another name
/// are skipped.
///
/// # Errors
///
/// Fails when the pipeline is not valid, when the matching assignment is not an array
/// literal, or when no assignment to `field_name` exists.
fn array_element_types(
    pipeline: &TypedPipeline,
    field_name: &str,
) -> Result<Vec<Type>, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;
    for cmd in &valid.commands {
        if let TypedCommandKind::Let(let_cmd) = &cmd.kind {
            for assignment in &let_cmd.projections.assignments {
                let names_field = match assignment.identifier.valid_ref() {
                    Ok(Identifier::Simple(simple)) => simple.as_str() == field_name,
                    _ => false,
                };
                if !names_field {
                    continue;
                }
                // The first assignment with the requested name decides the outcome.
                return match &assignment.expression.kind {
                    TypedExpressionKind::ArrayLiteral(arr) => Ok(arr
                        .elements
                        .iter()
                        .map(|elem| elem.resolved_type.as_ref().clone())
                        .collect()),
                    _ => Err(test_error("expected array literal assignment")),
                };
            }
        }
    }
    Err(test_error("array literal assignment not found"))
}
// Diagnostic test: prints the element types of a mixed struct array, whether each element
// differs from the array's element type, and the AST produced by the expansion. It makes
// no assertions; it only fails if one of the fallible steps returns an error.
#[test]
fn test_debug_expansion() -> Result<(), Arc<TranslationError>> {
    let mixed_array = array()
        .element(struct_literal().field("a", 1))
        .element(struct_literal().field("a", 2).field("b", 3));
    let input = pipeline()
        .from(|f| f.table_reference("events"))
        .command(let_command().named_field("arr", mixed_array).build())
        .build();
    let input_typed = typed_pipeline(query().main(input));

    let element_types = array_element_types(&input_typed, "arr")?;
    println!("Element types: {:?}", element_types);

    let valid = input_typed.valid_ref()?;
    for cmd in &valid.commands {
        if let TypedCommandKind::Let(let_cmd) = &cmd.kind {
            for assignment in &let_cmd.projections.assignments {
                let expression = &assignment.expression;
                if let TypedExpressionKind::ArrayLiteral(arr) = &expression.kind {
                    if let Type::Array(arr_type) = expression.resolved_type.as_ref() {
                        println!("Array element type: {:?}", arr_type.element_type);
                        for (index, elem) in arr.elements.iter().enumerate() {
                            let differs =
                                elem.resolved_type.as_ref() != arr_type.element_type.as_ref();
                            println!(
                                "Element {} type: {:?}, needs expansion: {}",
                                index, elem.resolved_type, differs
                            );
                        }
                    }
                }
            }
        }
    }

    let mut ctx = StatementTranslationContext::new(
        Arc::new(FunctionRegistry::default()),
        Arc::new(MockProvider),
    );
    let result = expand_array_literals(input_typed, &mut ctx)?;
    println!("Result AST: {:?}", result.ast);
    Ok(())
}
// Before expansion, each element of the literal keeps its own struct type:
// `{a: INT}` for the first element and `{a: INT, b: INT}` for the second.
#[test]
fn test_array_literal_element_types() -> Result<(), Arc<TranslationError>> {
    let mixed_array = array()
        .element(struct_literal().field("a", 1))
        .element(struct_literal().field("a", 2).field("b", 3));
    let input = pipeline()
        .from(|f| f.table_reference("events"))
        .command(let_command().named_field("arr", mixed_array).build())
        .build();
    let typed = typed_pipeline(query().main(input));

    let narrow = Struct::default().with_str("a", INT);
    let wide = Struct::default().with_str("a", INT).with_str("b", INT);
    let expected: Vec<Type> = vec![narrow.into(), wide.into()];

    assert_eq!(array_element_types(&typed, "arr")?, expected);
    Ok(())
}
// Table-driven test for `expand_array_literals`.
//
// Each case supplies three values: the input pipeline, the pipeline whose AST the pass is
// expected to produce, and the expected output schema of the resulting pipeline.
#[rstest]
// Case: the pipeline contains no array literal, so the expected pipeline equals the input.
#[case::no_expansion_needed(
pipeline()
.from(|f| f.table_reference("events"))
.let_cmd(|l| l.named_field("y", 1))
.build(),
pipeline()
.from(|f| f.table_reference("events"))
.let_cmd(|l| l.named_field("y", 1))
.build(),
Struct::default()
.with_str("y", INT)
.with_str("x", INT)
.with_str("s", Struct::default().with_str("a", INT).into())
)]
// Case: the first element lacks field `b`; the expected literal gains `b` as a null
// cast to INT.
#[case::struct_element_missing_field(
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(struct_literal().field("a", 1))
.element(struct_literal().field("a", 2).field("b", 3)),
)
.build())
.build(),
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(struct_literal()
.field("a", 1)
.field("b", cast(NullLiteralBuilder::new(), INT)))
.element(struct_literal().field("a", 2).field("b", 3)),
)
.build())
.build(),
Struct::default()
.with_str(
"arr",
Array::new(
Struct::default().with_str("a", INT).with_str("b", INT).into(),
)
.into(),
)
.with_str("x", INT)
.with_str("s", Struct::default().with_str("a", INT).into())
)]
// Case: the mismatch is one level down, inside the `nested` struct field; the expected
// inner struct of the first element gains `y` as a null cast to INT.
#[case::nested_struct_widening(
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(
struct_literal().field(
"nested",
struct_literal().field("x", 1),
),
)
.element(
struct_literal().field(
"nested",
struct_literal().field("x", 2).field("y", 3),
),
),
)
.build())
.build(),
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(
struct_literal().field(
"nested",
struct_literal()
.field("x", 1)
.field("y", cast(NullLiteralBuilder::new(), INT)),
),
)
.element(
struct_literal().field(
"nested",
struct_literal().field("x", 2).field("y", 3),
),
),
)
.build())
.build(),
Struct::default()
.with_str(
"arr",
Array::new(
Struct::default()
.with_str(
"nested",
Struct::default().with_str("x", INT).with_str("y", INT).into(),
)
.into(),
)
.into(),
)
.with_str("x", INT)
.with_str("s", Struct::default().with_str("a", INT).into())
)]
// Case: the first element is a struct literal cast to `{a, b}` while the second element
// has `{a, b, c}`; the expected first element is wrapped in a further cast to `{a, b, c}`.
#[case::simple_cast_struct_literal(
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(cast(
struct_literal().field("a", 1).field("b", 2),
Struct::default().with_str("a", INT).with_str("b", INT).into(),
))
.element(
struct_literal()
.field("a", 3)
.field("b", 4)
.field("c", 5),
),
)
.build())
.build(),
// No hoisting because cast(struct_literal()) is simple
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(
// Simple expr is just cast to target type (no hoisting)
cast(
cast(
struct_literal().field("a", 1).field("b", 2),
Struct::default().with_str("a", INT).with_str("b", INT).into(),
),
Struct::default()
.with_str("a", INT)
.with_str("b", INT)
.with_str("c", INT)
.into(),
),
)
.element(
struct_literal()
.field("a", 3)
.field("b", 4)
.field("c", 5),
),
)
.build())
.build(),
Struct::default()
.with_str(
"arr",
Array::new(
Struct::default()
.with_str("a", INT)
.with_str("b", INT)
.with_str("c", INT)
.into(),
)
.into(),
)
.with_str("x", INT)
.with_str("s", Struct::default().with_str("a", INT).into())
)]
// Case: the first element is a reference to the table column `s` (`{a: INT}`); the
// expected element is that reference wrapped in a cast to `{a, b}`.
#[case::field_reference_chain_expansion(
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(field_ref("s")) // s: {a: Int} from table
.element(struct_literal().field("a", 1).field("b", 2)),
)
.build())
.build(),
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(
// Column ref is cast to target type
cast(
field_ref("s"),
Struct::default().with_str("a", INT).with_str("b", INT).into(),
),
)
.element(struct_literal().field("a", 1).field("b", 2)),
)
.build())
.build(),
Struct::default()
.with_str(
"arr",
Array::new(
Struct::default().with_str("a", INT).with_str("b", INT).into(),
)
.into(),
)
.with_str("x", INT)
.with_str("s", Struct::default().with_str("a", INT).into())
)]
// Case: the first element is a struct literal cast to `{a}`; the expected element is
// that cast wrapped in a second cast to `{a, b}`, with no additional commands.
#[case::simple_cast_no_hoisting(
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(cast(
struct_literal().field("a", 1),
Struct::default().with_str("a", INT).into(),
))
.element(struct_literal().field("a", 2).field("b", 3)),
)
.build())
.build(),
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(
// Cast expression is simple enough, just wrap in another cast
cast(
cast(
struct_literal().field("a", 1),
Struct::default().with_str("a", INT).into(),
),
Struct::default().with_str("a", INT).with_str("b", INT).into(),
),
)
.element(struct_literal().field("a", 2).field("b", 3)),
)
.build())
.build(),
Struct::default()
.with_str(
"arr",
Array::new(
Struct::default().with_str("a", INT).with_str("b", INT).into(),
)
.into(),
)
.with_str("x", INT)
.with_str("s", Struct::default().with_str("a", INT).into())
)]
// Case: every element already has fields `left` and `right`, some given as explicit
// null casts to STRING; the expected pipeline equals the input.
#[case::explicit_null_casts_preserved(
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(struct_literal().field("left", "apple").field("right", "zebra"))
.element(struct_literal().field("left", cast(NullLiteralBuilder::new(), STRING)).field("right", "yak"))
.element(struct_literal().field("left", "cherry").field("right", cast(NullLiteralBuilder::new(), STRING)))
.element(struct_literal().field("left", "date").field("right", "walrus")),
)
.build())
.build(),
// Expected: unchanged - all elements already have the same struct type
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(struct_literal().field("left", "apple").field("right", "zebra"))
.element(struct_literal().field("left", cast(NullLiteralBuilder::new(), STRING)).field("right", "yak"))
.element(struct_literal().field("left", "cherry").field("right", cast(NullLiteralBuilder::new(), STRING)))
.element(struct_literal().field("left", "date").field("right", "walrus")),
)
.build())
.build(),
Struct::default()
.with_str(
"arr",
Array::new(
Struct::default().with_str("left", STRING).with_str("right", STRING).into(),
)
.into(),
)
.with_str("x", INT)
.with_str("s", Struct::default().with_str("a", INT).into())
)]
// Case: the outer structs differ only in the element type of their `items` array field;
// the expected first element has its `items` array wrapped in a cast to array<{a, b}>.
#[case::nested_array_in_struct_widening(
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
// First element: {items: [{a: 1}]}
.element(struct_literal().field(
"items",
array().element(struct_literal().field("a", 1)),
))
// Second element: {items: [{a: 2, b: 3}]}
.element(struct_literal().field(
"items",
array().element(struct_literal().field("a", 2).field("b", 3)),
)),
)
.build())
.build(),
// Expected: first element's items field is cast to array<{a, b}>
pipeline()
.from(|f| f.table_reference("events"))
.command(let_command()
.named_field(
"arr",
array()
.element(struct_literal().field(
"items",
cast(
array().element(struct_literal().field("a", 1)),
Array::new(
Struct::default().with_str("a", INT).with_str("b", INT).into(),
)
.into(),
),
))
.element(struct_literal().field(
"items",
array().element(struct_literal().field("a", 2).field("b", 3)),
)),
)
.build())
.build(),
Struct::default()
.with_str(
"arr",
Array::new(
Struct::default()
.with_str(
"items",
Array::new(
Struct::default().with_str("a", INT).with_str("b", INT).into(),
)
.into(),
)
.into(),
)
.into(),
)
.with_str("x", INT)
.with_str("s", Struct::default().with_str("a", INT).into())
)]
fn test_expand_array_literals(
#[case] input: Pipeline,
#[case] expected: Pipeline,
#[case] expected_output_schema: Struct,
) -> Result<(), Arc<TranslationError>> {
// Type-check both pipelines against the same mock provider so their ASTs are comparable.
let input_typed = typed_pipeline(query().main(input));
let expected_typed = typed_pipeline(query().main(expected));
// Fresh translation context used by the pass when it re-types the rewritten AST.
let registry = Arc::new(FunctionRegistry::default());
let provider = Arc::new(MockProvider);
let mut ctx = StatementTranslationContext::new(registry, provider);
let result = expand_array_literals(input_typed, &mut ctx)?;
// The rewritten AST must equal the AST of the hand-written expected pipeline.
assert_eq!(result.ast, expected_typed.ast);
// The output schema of the rewritten pipeline must match the expected struct.
assert_eq!(
result.environment().as_struct().clone(),
expected_output_schema
);
Ok(())
}
}