use std::rc::Rc;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
ast::{command::Command, identifier::SimpleIdentifier},
builder::{self, column_ref, drop_command, explode_command, field, let_command},
typed_ast::{
command::{TypedCommand, TypedCommandKind, TypedUnnestCommand},
context::StatementTranslationContext,
expression::TypedExpressionKind,
pipeline::TypedPipeline,
},
};
use hamelin_lib::types::struct_type::Struct;
use hamelin_lib::types::Type;
use super::super::unique::UniqueNameGenerator;
pub fn lower_unnest(
pipeline: Rc<TypedPipeline>,
ctx: &mut StatementTranslationContext,
) -> Result<Rc<TypedPipeline>, Rc<TranslationError>> {
if !pipeline
.valid_ref()?
.commands
.iter()
.any(|cmd| matches!(&cmd.kind, TypedCommandKind::Unnest(_)))
{
return Ok(pipeline);
}
let valid = pipeline.valid_ref()?;
let mut name_gen = UniqueNameGenerator::new("__unnest");
let mut pipe_builder = builder::pipeline();
for cmd in &valid.commands {
for c in lower_command(cmd, &mut name_gen, ctx)? {
pipe_builder = pipe_builder.command(c);
}
}
let new_ast = pipe_builder.build().at(pipeline.ast.span);
Ok(Rc::new(TypedPipeline::from_ast_with_context(
Rc::new(new_ast),
ctx,
)))
}
/// Lowers one command of the pipeline.
///
/// `Unnest` commands are expanded by `lower_unnest_command`; every other kind
/// of command is passed through as a single-element vector holding a clone of
/// its original AST handle.
fn lower_command(
    cmd: &Rc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
    ctx: &mut StatementTranslationContext,
) -> Result<Vec<Rc<Command>>, Rc<TranslationError>> {
    match &cmd.kind {
        TypedCommandKind::Unnest(unnest_cmd) => {
            lower_unnest_command(unnest_cmd, cmd, name_gen, ctx)
        }
        // Not an unnest: keep the command exactly as written.
        _ => Ok(vec![cmd.ast.clone()]),
    }
}
/// Expands a single `Unnest` command into the commands that replace it.
///
/// The emitted sequence is, in order:
///
/// 1. a `let` binding the unnested expression to a generated name — only when
///    the expression is not a (valid) column reference;
/// 2. an `explode` of the working column — only when the expression type is an
///    array of structs;
/// 3. a `let` with one named field per struct field, each reading that field
///    off the working column;
/// 4. a `drop` of the working column.
///
/// Every generated command is positioned at the span of the original command.
///
/// # Errors
///
/// Returns the error emitted by `extract_struct_type` when the expression is
/// neither a struct nor an array of structs.
fn lower_unnest_command(
    unnest_cmd: &TypedUnnestCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
    ctx: &mut StatementTranslationContext,
) -> Result<Vec<Rc<Command>>, Rc<TranslationError>> {
    let source = &unnest_cmd.expression;
    let span = cmd.ast.span;

    // A plain column reference can be used in place; anything else has to be
    // materialised into a temporary column first.
    let referenced_column: Option<SimpleIdentifier> =
        if let TypedExpressionKind::ColumnReference(reference) = &source.kind {
            reference.column_name.clone().valid().ok()
        } else {
            None
        };

    let (needs_explode, row_struct) =
        extract_struct_type(source.resolved_type.as_ref(), ctx, source.ast.as_ref())?;

    let mut lowered: Vec<Rc<Command>> = Vec::new();

    let working_column = if let Some(existing) = referenced_column {
        existing
    } else {
        let generated = name_gen.next();
        lowered.push(Rc::new(
            let_command()
                .named_field(generated.clone(), source.ast.as_ref().clone())
                .at(span)
                .build(),
        ));
        generated
    };

    // Arrays of structs are exploded in place, under the same column name.
    if needs_explode {
        lowered.push(Rc::new(
            explode_command()
                .named_field(working_column.clone(), column_ref(working_column.as_str()))
                .at(span)
                .build(),
        ));
    }

    // Project every struct field under its own name.
    // NOTE(review): if the struct has a field named like the working column,
    // the final drop presumably removes that projected field as well — confirm.
    let projection = row_struct
        .fields
        .iter()
        .fold(let_command().at(span), |acc, (field_name, _)| {
            let value = field(
                column_ref(working_column.as_str()),
                field_name.name.as_str(),
            );
            let target: SimpleIdentifier = field_name.clone().into();
            acc.named_field(target, value)
        })
        .build();
    lowered.push(Rc::new(projection));

    // The struct (or temporary) column is no longer needed once projected.
    lowered.push(Rc::new(
        drop_command().field(working_column).at(span).build(),
    ));

    Ok(lowered)
}
/// Determines the struct type that an `Unnest` expression expands into.
///
/// Returns `(false, struct)` when `expr_type` is a struct and `(true, struct)`
/// when it is an array whose element type is a struct; the struct is cloned.
///
/// # Errors
///
/// For any other type an error is reported through `ctx`, attached to
/// `expr_ast`, and returned. The message names the offending type and wraps it
/// in `ARRAY<…>` when the outer type was an array.
fn extract_struct_type(
    expr_type: &Type,
    ctx: &mut StatementTranslationContext,
    expr_ast: &hamelin_lib::tree::ast::expression::Expression,
) -> Result<(bool, Struct), Rc<TranslationError>> {
    // Peel off at most one array layer, remembering whether we did.
    let (is_array, candidate): (bool, &Type) = match expr_type {
        Type::Array(arr) => (true, arr.element_type.as_ref()),
        other => (false, other),
    };

    if let Type::Struct(s) = candidate {
        return Ok((is_array, s.clone()));
    }

    let message = if is_array {
        format!(
            "UNNEST requires STRUCT or ARRAY<STRUCT>, found ARRAY<{}>",
            candidate
        )
    } else {
        format!(
            "UNNEST requires STRUCT or ARRAY<STRUCT>, found {}",
            candidate
        )
    };
    Err(ctx.error(message).at(expr_ast).emit())
}
// Table-driven tests for `lower_unnest`. Each `#[case]` supplies an input
// pipeline, the pipeline expected after lowering, and the expected flattened
// output schema.
#[cfg(test)]
mod tests {
use super::*;
use hamelin_lib::{
tree::{
ast::{pipeline::Pipeline, IntoTyped, TypeCheckExecutor},
builder::{
array, column_ref, drop_command, field, let_command, pipeline, select_command,
struct_literal, unnest_command,
},
},
types::{struct_type::Struct, INT, STRING},
};
use pretty_assertions::assert_eq;
use rstest::rstest;
use std::rc::Rc;
#[rstest]
// No unnest command present: the expected pipeline is identical to the input.
#[case::no_unnest_passthrough(
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
Struct::default().with_str("a", INT).with_str("b", INT)
)]
// Unnest of a struct-typed column: expect a `let` projecting each field off
// the column followed by a `drop` of the column; no `explode`.
#[case::unnest_struct_column(
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1).field("b", "hello"))
.build())
.command(unnest_command(column_ref("x")).build())
.build(),
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1).field("b", "hello"))
.build())
.command(let_command()
.named_field("a", field(column_ref("x"), "a"))
.named_field("b", field(column_ref("x"), "b"))
.build())
.command(drop_command().field("x").build())
.build(),
Struct::default().with_str("a", INT).with_str("b", STRING)
)]
// Unnest of an array-of-struct column: expect an `explode` of the column
// under its own name before the field projection and the `drop`.
#[case::unnest_array_struct_column(
pipeline()
.command(select_command()
.named_field("arr", array().element(struct_literal().field("a", 1).field("b", 2)))
.build())
.command(unnest_command(column_ref("arr")).build())
.build(),
pipeline()
.command(select_command()
.named_field("arr", array().element(struct_literal().field("a", 1).field("b", 2)))
.build())
.command(explode_command().named_field("arr", column_ref("arr")).build())
.command(let_command()
.named_field("a", field(column_ref("arr"), "a"))
.named_field("b", field(column_ref("arr"), "b"))
.build())
.command(drop_command().field("arr").build())
.build(),
Struct::default().with_str("a", INT).with_str("b", INT)
)]
// Unnest of a non-column expression (a struct literal): the expression is
// first bound to the generated name `__unnest_0`, which is projected and then
// dropped; the pre-existing column `x` stays in the output schema.
#[case::unnest_complex_expression(
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1).field("b", 2))
.build())
.command(unnest_command(struct_literal().field("a", 10).field("b", 20)).build())
.build(),
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1).field("b", 2))
.build())
.command(let_command()
.named_field("__unnest_0", struct_literal().field("a", 10).field("b", 20))
.build())
.command(let_command()
.named_field("a", field(column_ref("__unnest_0"), "a"))
.named_field("b", field(column_ref("__unnest_0"), "b"))
.build())
.command(drop_command().field("__unnest_0").build())
.build(),
Struct::default()
.with_str("a", INT)
.with_str("b", INT)
.with_str("x", Struct::default().with_str("a", INT).with_str("b", INT).into())
)]
// Two consecutive unnests of different columns: each one is lowered
// independently into its own `let` + `drop` pair. Note that the expected
// schema lists `b` before `a`.
#[case::multiple_unnest(
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1))
.named_field("y", struct_literal().field("b", 2))
.build())
.command(unnest_command(column_ref("x")).build())
.command(unnest_command(column_ref("y")).build())
.build(),
pipeline()
.command(select_command()
.named_field("x", struct_literal().field("a", 1))
.named_field("y", struct_literal().field("b", 2))
.build())
.command(let_command().named_field("a", field(column_ref("x"), "a")).build())
.command(drop_command().field("x").build())
.command(let_command().named_field("b", field(column_ref("y"), "b")).build())
.command(drop_command().field("y").build())
.build(),
Struct::default()
.with_str("b", INT)
.with_str("a", INT)
)]
// Shared test body: converts both pipelines with `typed_with().typed()`,
// lowers the input, then checks the lowered AST against the expected AST and
// the flattened result environment against the expected schema.
fn test_lower_unnest(
#[case] input: Pipeline,
#[case] expected: Pipeline,
#[case] expected_output_schema: Struct,
) {
let input_typed = input.typed_with().typed();
let expected_typed = expected.typed_with().typed();
let mut ctx = StatementTranslationContext::default();
// `unwrap` is intentional: every case here is expected to lower successfully.
let result = lower_unnest(Rc::new(input_typed), &mut ctx).unwrap();
assert_eq!(result.ast, expected_typed.ast);
let result_schema = result.environment().flatten();
assert_eq!(result_schema, expected_output_schema);
}
}