use std::sync::Arc;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
ast::{command::Command, identifier::SimpleIdentifier},
builder::{self, column_ref, drop_command, explode_command, field, let_command},
typed_ast::{
command::{TypedCommand, TypedCommandKind, TypedUnnestCommand},
context::StatementTranslationContext,
pipeline::TypedPipeline,
},
};
use hamelin_lib::types::struct_type::Struct;
use hamelin_lib::types::Type;
use super::super::unique::UniqueNameGenerator;
/// Rewrites every `UNNEST` command in `pipeline` into primitive commands.
///
/// Each `UNNEST <expr>` is replaced by the command sequence produced by
/// `lower_unnest_command` (a `LET` of a temporary, an optional `EXPLODE`, a `LET` of the
/// struct's fields, and a `DROP` of the temporary). Every other command is copied through
/// unchanged. The rebuilt AST is then re-typed via `TypedPipeline::from_ast_with_context`.
///
/// A pipeline containing no `UNNEST` is returned as-is (the same `Arc`), without re-typing.
///
/// # Errors
/// Propagates the error returned by `valid_ref()` on the input pipeline, and the error
/// reported when an `UNNEST` expression is neither a `STRUCT` nor an `ARRAY<STRUCT>`.
pub fn lower_unnest(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    // Validate once and reuse the result for both the fast-path check and the rewrite.
    let valid = pipeline.valid_ref()?;
    let has_unnest = valid
        .commands
        .iter()
        .any(|cmd| matches!(&cmd.kind, TypedCommandKind::Unnest(_)));
    if !has_unnest {
        return Ok(pipeline);
    }

    // One generator for the whole pipeline, so successive UNNESTs get distinct temporaries.
    let mut name_gen = UniqueNameGenerator::new("__unnest");
    let mut pipe_builder = builder::pipeline();
    for cmd in &valid.commands {
        for lowered in lower_command(cmd, &mut name_gen, ctx)? {
            pipe_builder = pipe_builder.command(lowered);
        }
    }

    let new_ast = pipe_builder.build().at(pipeline.ast.span);
    Ok(Arc::new(TypedPipeline::from_ast_with_context(
        Arc::new(new_ast),
        ctx,
    )))
}
/// Lowers a single typed command into one or more AST commands.
///
/// `UNNEST` commands are expanded by `lower_unnest_command`. Any other command is
/// returned as a one-element vector sharing the command's original AST node.
fn lower_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
    ctx: &mut StatementTranslationContext,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    match &cmd.kind {
        TypedCommandKind::Unnest(unnest) => lower_unnest_command(unnest, cmd, name_gen, ctx),
        _ => Ok(vec![Arc::clone(&cmd.ast)]),
    }
}
/// Expands one `UNNEST <expr>` into primitive commands.
///
/// The emitted sequence, all spanned at the original command, is:
/// 1. `LET <tmp> = <expr>`: evaluates the expression once under a fresh column name.
/// 2. `EXPLODE <tmp> = <tmp>`: only when `<expr>` has type `ARRAY<STRUCT>`.
/// 3. `LET f1 = <tmp>.f1, f2 = <tmp>.f2, ...`: one assignment per struct field. This is
///    omitted when the struct has no fields, so no empty `LET` is produced.
/// 4. `DROP <tmp>`: removes the temporary column.
///
/// # Errors
/// Returns the error emitted by `extract_struct_type` when the expression is neither a
/// `STRUCT` nor an `ARRAY<STRUCT>`.
fn lower_unnest_command(
    unnest_cmd: &TypedUnnestCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
    ctx: &mut StatementTranslationContext,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    let expr = &unnest_cmd.expression;
    let expr_type = expr.resolved_type.as_ref();
    let (is_array, struct_type) = extract_struct_type(expr_type, ctx, expr.ast.as_ref())?;
    let span = cmd.ast.span;

    // The temporary must avoid the input columns (checked by the generator) and also the
    // struct's own field names. If a field shared the temporary's name, the field LET
    // below would overwrite the temporary and the final DROP would delete the user's
    // unnested column.
    // NOTE(review): relies on `name_gen.next` never repeating a name (it is counter-based
    // per the multiple-UNNEST test expectations), which guarantees this loop terminates.
    let temp_name = loop {
        let candidate = name_gen.next(&cmd.input_schema);
        let clashes_with_field = struct_type
            .fields
            .iter()
            .any(|(field_name, _)| field_name.name.as_str() == candidate.as_str());
        if !clashes_with_field {
            break candidate;
        }
    };

    let mut result: Vec<Arc<Command>> = Vec::with_capacity(4);

    // Materialize the expression once so arbitrary expressions, not just column
    // references, can be unnested.
    result.push(Arc::new(
        let_command()
            .named_field(temp_name.clone(), expr.ast.as_ref().clone())
            .at(span)
            .build(),
    ));

    if is_array {
        // Replace the array in the temporary with one element per output row.
        result.push(Arc::new(
            explode_command()
                .named_field(temp_name.clone(), column_ref(temp_name.as_str()))
                .at(span)
                .build(),
        ));
    }

    // Promote each struct field to a top-level column of the same name.
    let mut let_builder = let_command().at(span);
    let mut has_fields = false;
    for (field_name, _field_type) in struct_type.fields.iter() {
        let field_expr = field(column_ref(temp_name.as_str()), field_name.name.as_str());
        let field_id: SimpleIdentifier = field_name.clone().into();
        let_builder = let_builder.named_field(field_id, field_expr);
        has_fields = true;
    }
    if has_fields {
        result.push(Arc::new(let_builder.build()));
    }

    result.push(Arc::new(drop_command().field(temp_name).at(span).build()));
    Ok(result)
}
/// Resolves the struct type that an `UNNEST` expression expands into.
///
/// Returns `(false, s)` for an expression of type `STRUCT s` and `(true, s)` for
/// `ARRAY<STRUCT s>`. The boolean tells the caller whether an `EXPLODE` is needed.
///
/// # Errors
/// Emits (via `ctx`) and returns an error located at `expr_ast` for any other type.
fn extract_struct_type(
    expr_type: &Type,
    ctx: &mut StatementTranslationContext,
    expr_ast: &hamelin_lib::tree::ast::expression::Expression,
) -> Result<(bool, Struct), Arc<TranslationError>> {
    // Successful shapes return immediately; every other shape yields the diagnostic text.
    let message = match expr_type {
        Type::Struct(fields) => return Ok((false, fields.clone())),
        Type::Array(array_type) => match array_type.element_type.as_ref() {
            Type::Struct(fields) => return Ok((true, fields.clone())),
            element => format!(
                "UNNEST requires STRUCT or ARRAY<STRUCT>, found ARRAY<{}>",
                element
            ),
        },
        unsupported => format!(
            "UNNEST requires STRUCT or ARRAY<STRUCT>, found {}",
            unsupported
        ),
    };
    Err(ctx.error(message).at(expr_ast).emit())
}
#[cfg(test)]
mod tests {
// Table-driven checks for `lower_unnest`. Each case supplies:
//   - an input pipeline,
//   - the hand-built pipeline the lowering is expected to emit, and
//   - the flattened output schema expected once the lowered pipeline is re-typed.
use super::*;
use hamelin_lib::{
tree::{
ast::{pipeline::Pipeline, IntoTyped, TypeCheckExecutor},
builder::{
array, column_ref, drop_command, field, let_command, pipeline, select_command,
struct_literal, unnest_command,
},
},
types::{array::Array, struct_type::Struct, INT, STRING},
};
use pretty_assertions::assert_eq;
use rstest::rstest;
use std::sync::Arc;
#[rstest]
// A pipeline with no UNNEST must be returned unchanged.
#[case::no_unnest_passthrough(
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
Struct::default().with_str("a", INT).with_str("b", INT)
)]
// UNNEST of a STRUCT column: LET temp, LET each field, DROP temp (no EXPLODE).
#[case::unnest_struct_column(
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1).field("b", "hello"))
.build())
.command(unnest_command(column_ref("x")).build())
.build(),
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1).field("b", "hello"))
.build())
.command(let_command()
.named_field("__unnest_0", column_ref("x"))
.build())
.command(let_command()
.named_field("a", field(column_ref("__unnest_0"), "a"))
.named_field("b", field(column_ref("__unnest_0"), "b"))
.build())
.command(drop_command().field("__unnest_0").build())
.build(),
Struct::default()
.with_str("a", INT)
.with_str("b", STRING)
.with_str("x", Struct::default().with_str("a", INT).with_str("b", STRING).into())
)]
// UNNEST of an ARRAY<STRUCT> column: same as above plus an EXPLODE of the temporary.
#[case::unnest_array_struct_column(
pipeline()
.command(select_command()
.named_field("arr", array().element(struct_literal().field("a", 1).field("b", 2)))
.build())
.command(unnest_command(column_ref("arr")).build())
.build(),
pipeline()
.command(select_command()
.named_field("arr", array().element(struct_literal().field("a", 1).field("b", 2)))
.build())
.command(let_command()
.named_field("__unnest_0", column_ref("arr"))
.build())
.command(explode_command().named_field("__unnest_0", column_ref("__unnest_0")).build())
.command(let_command()
.named_field("a", field(column_ref("__unnest_0"), "a"))
.named_field("b", field(column_ref("__unnest_0"), "b"))
.build())
.command(drop_command().field("__unnest_0").build())
.build(),
Struct::default()
.with_str("a", INT)
.with_str("b", INT)
.with_str("arr", Array::new(Struct::default().with_str("a", INT).with_str("b", INT).into()).into())
)]
// UNNEST of a non-column expression (a struct literal) is materialized into the temporary.
#[case::unnest_complex_expression(
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1).field("b", 2))
.build())
.command(unnest_command(struct_literal().field("a", 10).field("b", 20)).build())
.build(),
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1).field("b", 2))
.build())
.command(let_command()
.named_field("__unnest_0", struct_literal().field("a", 10).field("b", 20))
.build())
.command(let_command()
.named_field("a", field(column_ref("__unnest_0"), "a"))
.named_field("b", field(column_ref("__unnest_0"), "b"))
.build())
.command(drop_command().field("__unnest_0").build())
.build(),
Struct::default()
.with_str("a", INT)
.with_str("b", INT)
.with_str("x", Struct::default().with_str("a", INT).with_str("b", INT).into())
)]
// Two UNNESTs get distinct temporaries (`__unnest_0`, then `__unnest_1`).
// Per the expected schema, columns added by the later LET are listed ahead of earlier
// ones, so `b` precedes `a`.
#[case::multiple_unnest(
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1))
.named_field("y", struct_literal().field("b", 2))
.build())
.command(unnest_command(column_ref("x")).build())
.command(unnest_command(column_ref("y")).build())
.build(),
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1))
.named_field("y", struct_literal().field("b", 2))
.build())
.command(let_command().named_field("__unnest_0", column_ref("x")).build())
.command(let_command().named_field("a", field(column_ref("__unnest_0"), "a")).build())
.command(drop_command().field("__unnest_0").build())
.command(let_command().named_field("__unnest_1", column_ref("y")).build())
.command(let_command().named_field("b", field(column_ref("__unnest_1"), "b")).build())
.command(drop_command().field("__unnest_1").build())
.build(),
Struct::default()
.with_str("b", INT)
.with_str("a", INT)
.with_str("x", Struct::default().with_str("a", INT).into())
.with_str("y", Struct::default().with_str("b", INT).into())
)]
// Lowers `input` and checks two things:
//   - the resulting AST equals the AST of the typed `expected` pipeline, and
//   - the re-typed lowered pipeline flattens to `expected_output_schema`.
fn test_lower_unnest(
#[case] input: Pipeline,
#[case] expected: Pipeline,
#[case] expected_output_schema: Struct,
) {
let input_typed = input.typed_with().typed();
let expected_typed = expected.typed_with().typed();
let mut ctx = StatementTranslationContext::default();
let result = lower_unnest(Arc::new(input_typed), &mut ctx).unwrap();
// Structural check: the lowering emits exactly the expected command sequence.
assert_eq!(result.ast, expected_typed.ast);
// Semantic check: the lowered pipeline types to the expected columns.
let result_schema = result.environment().flatten();
assert_eq!(result_schema, expected_output_schema);
}
}