use std::sync::Arc;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
ast::{command::Command, identifier::Identifier, identifier::SimpleIdentifier},
builder::{self, drop_command, explode_command, field_ref, let_command},
typed_ast::{
command::{TypedCommand, TypedCommandKind, TypedExplodeCommand},
context::StatementTranslationContext,
expression::TypedExpressionKind,
pipeline::TypedPipeline,
},
};
use crate::unique::UniqueNameGenerator;
/// Rewrites every non-canonical `explode` in `pipeline` into a canonical one
/// (items of the shape `name = field(name)`), bracketed by helper `let`/`drop`
/// commands, and builds a fresh typed pipeline from the rewritten AST.
/// A pipeline whose explodes are already canonical is returned unchanged.
pub fn normalize_explode(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    let needs_rewrite = pipeline
        .valid_ref()?
        .commands
        .iter()
        .any(explode_needs_normalization);
    if !needs_rewrite {
        // Fast path: nothing to rewrite, hand the pipeline back untouched.
        return Ok(pipeline);
    }
    let valid = pipeline.valid_ref()?;
    let mut name_gen = UniqueNameGenerator::new("__explode");
    let mut pipe_builder = builder::pipeline();
    for cmd in &valid.commands {
        // Each command expands to one or more replacement commands.
        let replacements = normalize_command(cmd, &mut name_gen)?;
        pipe_builder = replacements
            .into_iter()
            .fold(pipe_builder, |builder, command| builder.command(command));
    }
    let new_ast = pipe_builder.build().at(pipeline.ast.span);
    Ok(Arc::new(TypedPipeline::from_ast_with_context(
        Arc::new(new_ast),
        ctx,
    )))
}
/// True when `cmd` is an explode that is not yet in canonical form; any other
/// command kind never needs normalization.
fn explode_needs_normalization(cmd: &Arc<TypedCommand>) -> bool {
    match &cmd.kind {
        TypedCommandKind::Explode(explode_cmd) => !is_canonical_explode(explode_cmd),
        _ => false,
    }
}
/// A canonical explode item has the shape `name = field(name)`: a simple
/// identifier target whose expression is a field reference to that same
/// column. Returns true only when every item is canonical (vacuously true
/// for an empty item list).
fn is_canonical_explode(explode_cmd: &TypedExplodeCommand) -> bool {
    explode_cmd.items.iter().all(|item| {
        let target = item.assignment.identifier.valid_ref();
        let (Ok(Identifier::Simple(simple_id)), TypedExpressionKind::FieldReference(col_ref)) =
            (target, &item.assignment.expression.kind)
        else {
            return false;
        };
        col_ref
            .field_name
            .valid_ref()
            .map_or(false, |col_name| simple_id.as_str() == col_name.as_str())
    })
}
/// Expands one command into its replacement sequence: a non-canonical explode
/// is transformed, everything else passes through as its original AST node.
fn normalize_command(
    cmd: &Arc<TypedCommand>,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    match &cmd.kind {
        TypedCommandKind::Explode(explode_cmd) if !is_canonical_explode(explode_cmd) => {
            transform_explode(explode_cmd, cmd, name_gen)
        }
        // Non-explode commands and already-canonical explodes are unchanged.
        _ => Ok(vec![cmd.ast.clone()]),
    }
}
/// Per-item record of how an explode target was normalized; used afterwards
/// to reconstruct the column list of the rewritten explode command.
enum ItemNormalization {
/// Item was already `name = field(name)`; kept as-is.
Canonical(SimpleIdentifier),
/// Simple target with a non-trivial expression; a preceding `let` binds the
/// expression to `col_name` so the explode can reference it canonically.
SimpleNonCanonical { col_name: SimpleIdentifier },
/// Compound target; the value is exploded through the generated `temp_name`
/// column, copied back to the compound target afterwards, and dropped.
Compound { temp_name: SimpleIdentifier },
}
/// Expands a non-canonical `explode` into a canonical one bracketed by
/// `let`/`drop` commands.
///
/// Per item:
/// - canonical items (`x = field(x)`) pass through unchanged;
/// - `x = <expr>` gains a preceding `let x = <expr>` and explodes `x` canonically;
/// - compound targets (`a.b = <expr>`) go through a generated temporary:
///   a preceding `let __explode_N = <expr>`, an explode on the temporary,
///   a trailing `let a.b = field(__explode_N)` and a trailing `drop __explode_N`.
///
/// Result order: pre-lets, the rewritten explode, post-lets, post-drops.
fn transform_explode(
    explode_cmd: &TypedExplodeCommand,
    cmd: &TypedCommand,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Vec<Arc<Command>>, Arc<TranslationError>> {
    let mut pre_lets: Vec<Arc<Command>> = Vec::new();
    let mut post_lets: Vec<Arc<Command>> = Vec::new();
    let mut post_drops: Vec<Arc<Command>> = Vec::new();
    let mut normalizations: Vec<ItemNormalization> = Vec::new();
    for item in &explode_cmd.items {
        let identifier = item.assignment.identifier.valid_ref()?;
        // Decide canonicality once, keeping the bound name, so we never have
        // to re-match the identifier (no `unreachable!()` needed below).
        let canonical_name = match (identifier, &item.assignment.expression.kind) {
            (Identifier::Simple(simple_id), TypedExpressionKind::FieldReference(col_ref)) => {
                match col_ref.field_name.valid_ref() {
                    Ok(col_name) if simple_id.as_str() == col_name.as_str() => {
                        Some(simple_id.clone())
                    }
                    _ => None,
                }
            }
            _ => None,
        };
        if let Some(simple_id) = canonical_name {
            normalizations.push(ItemNormalization::Canonical(simple_id));
            continue;
        }
        match identifier {
            Identifier::Simple(simple_id) => {
                // Materialize the expression under the target name before the
                // explode, so the explode itself becomes canonical.
                let col_name = simple_id.clone();
                pre_lets.push(Arc::new(
                    let_command()
                        .named_field(
                            col_name.clone(),
                            item.assignment.expression.ast.as_ref().clone(),
                        )
                        .at(cmd.ast.span)
                        .build(),
                ));
                normalizations.push(ItemNormalization::SimpleNonCanonical { col_name });
            }
            Identifier::Compound(compound) => {
                // Compound targets can't be exploded directly: route the value
                // through a unique temporary column, copy it back afterwards,
                // and drop the temporary.
                let temp_name: SimpleIdentifier = name_gen.next(&cmd.input_schema);
                pre_lets.push(Arc::new(
                    let_command()
                        .named_field(
                            temp_name.clone(),
                            item.assignment.expression.ast.as_ref().clone(),
                        )
                        .at(cmd.ast.span)
                        .build(),
                ));
                let original: Identifier = compound.clone().into();
                post_lets.push(Arc::new(
                    let_command()
                        .named_field(original, field_ref(temp_name.as_str()))
                        .at(cmd.ast.span)
                        .build(),
                ));
                post_drops.push(Arc::new(
                    drop_command()
                        .field(temp_name.clone())
                        .at(cmd.ast.span)
                        .build(),
                ));
                normalizations.push(ItemNormalization::Compound { temp_name });
            }
        }
    }
    // Every item contributed exactly one exploded column name.
    let col_names: Vec<SimpleIdentifier> = normalizations
        .iter()
        .map(|norm| match norm {
            ItemNormalization::Canonical(name) => name.clone(),
            ItemNormalization::SimpleNonCanonical { col_name } => col_name.clone(),
            ItemNormalization::Compound { temp_name } => temp_name.clone(),
        })
        .collect();
    // A fully canonical (including empty) explode never reaches this function,
    // so there is at least one column to seed the builder with.
    let mut names = col_names.iter();
    let first_col = names
        .next()
        .expect("non-canonical explode must have at least one item");
    let explode_builder = names.fold(
        explode_command().named_field(first_col.clone(), field_ref(first_col.as_str())),
        |builder, col| builder.named_field(col.clone(), field_ref(col.as_str())),
    );
    let explode = explode_builder.at(cmd.ast.span).build();
    let mut result = pre_lets;
    result.push(Arc::new(explode));
    result.extend(post_lets);
    result.extend(post_drops);
    Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
use hamelin_lib::type_check;
use hamelin_lib::{
tree::{
ast::pipeline::Pipeline,
builder::{
array, drop_command, explode_command, field_ref, let_command, pipeline,
select_command,
},
},
types::{array::Array, struct_type::Struct, INT},
};
use pretty_assertions::assert_eq;
use rstest::rstest;
use std::sync::Arc;
#[rstest]
// Already-canonical explode (`arr = field(arr)`): pipeline passes through
// structurally unchanged.
#[case::canonical_unchanged(
pipeline()
.command(select_command().named_field("arr", array().element(1).element(2)).build())
.command(explode_command().named_field("arr", field_ref("arr")).build())
.build(),
pipeline()
.command(select_command().named_field("arr", array().element(1).element(2)).build())
.command(explode_command().named_field("arr", field_ref("arr")).build())
.build(),
Struct::default().with_str("arr", INT)
)]
// Simple target with a different source column: a `let x = field(arr)` is
// inserted so the explode becomes canonical (`x = field(x)`).
#[case::simple_id_different_expr(
pipeline()
.command(select_command().named_field("arr", array().element(1).element(2)).build())
.command(explode_command().named_field("x", field_ref("arr")).build())
.build(),
pipeline()
.command(select_command().named_field("arr", array().element(1).element(2)).build())
.command(let_command().named_field("x", field_ref("arr")).build())
.command(explode_command().named_field("x", field_ref("x")).build())
.build(),
Struct::default().with_str("x", INT).with_str("arr", Array::new(INT).into())
)]
// A pipeline with no explode at all is returned unchanged (fast path).
#[case::no_explode_passthrough(
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
Struct::default().with_str("a", INT).with_str("b", INT)
)]
// Two independent non-canonical explodes are each normalized in place,
// preserving their relative order in the pipeline.
#[case::multiple_explodes(
pipeline()
.command(select_command()
.named_field("arr1", array().element(1))
.named_field("arr2", array().element(2))
.build())
.command(explode_command().named_field("x", field_ref("arr1")).build())
.command(explode_command().named_field("y", field_ref("arr2")).build())
.build(),
pipeline()
.command(select_command()
.named_field("arr1", array().element(1))
.named_field("arr2", array().element(2))
.build())
.command(let_command().named_field("x", field_ref("arr1")).build())
.command(explode_command().named_field("x", field_ref("x")).build())
.command(let_command().named_field("y", field_ref("arr2")).build())
.command(explode_command().named_field("y", field_ref("y")).build())
.build(),
Struct::default()
.with_str("y", INT)
.with_str("x", INT)
.with_str("arr1", Array::new(INT).into())
.with_str("arr2", Array::new(INT).into())
)]
// Compound target (`result.item = field(arr)`): routed through the generated
// `__explode_0` temporary — pre-let, canonical explode, copy-back let, drop.
#[case::compound_id(
pipeline()
.command(select_command().named_field("arr", array().element(1)).build())
.command(explode_command()
.named_field(
hamelin_lib::tree::ast::identifier::CompoundIdentifier::new("result".into(), "item".into(), vec![]),
field_ref("arr")
)
.build())
.build(),
pipeline()
.command(select_command().named_field("arr", array().element(1)).build())
.command(let_command().named_field("__explode_0", field_ref("arr")).build())
.command(explode_command().named_field("__explode_0", field_ref("__explode_0")).build())
.command(let_command()
.named_field(
hamelin_lib::tree::ast::identifier::CompoundIdentifier::new("result".into(), "item".into(), vec![]),
field_ref("__explode_0")
)
.build())
.command(drop_command().field("__explode_0").build())
.build(),
Struct::default()
.with_str("result", Struct::default().with_str("item", INT).into())
.with_str("arr", Array::new(INT).into())
)]
// Each case type-checks both pipelines independently, then asserts the
// normalized AST matches the expected AST and the resulting environment
// matches the expected output schema.
fn test_normalize_explode(
#[case] input: Pipeline,
#[case] expected: Pipeline,
#[case] expected_output_schema: Struct,
) {
let input_typed = type_check(input).output;
let expected_typed = type_check(expected).output;
let mut ctx = StatementTranslationContext::default();
let result = normalize_explode(Arc::new(input_typed), &mut ctx).unwrap();
// Structural comparison on the untyped AST, not the typed tree.
assert_eq!(result.ast, expected_typed.ast);
let result_schema = result.environment().as_struct().clone();
assert_eq!(result_schema, expected_output_schema);
}
}