use std::sync::Arc;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::pipeline::Pipeline;
use hamelin_lib::tree::builder::{self, select_command};
use hamelin_lib::tree::typed_ast::clause::ResolvedEnvironment;
use hamelin_lib::tree::typed_ast::command::{TypedAppendCommand, TypedCommandKind};
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;
use hamelin_lib::types::struct_type::Struct;
use super::super::expand_struct::build_widening_expression;
pub fn align_append_schema(
pipeline: Arc<TypedPipeline>,
ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
let Some((append_cmd, append_idx, input_schema)) = find_append_command(&pipeline)? else {
return Ok(pipeline);
};
let target_schema = match &append_cmd.table.resolved {
ResolvedEnvironment::Resolved(env) => env.as_struct(),
ResolvedEnvironment::Error(_) => {
return Ok(pipeline);
}
};
let source_schema = input_schema.as_struct();
if source_schema == target_schema {
return Ok(pipeline);
}
let new_pipeline =
build_aligned_pipeline(&pipeline, &source_schema, &target_schema, append_idx)?;
Ok(Arc::new(TypedPipeline::from_ast_with_context(
Arc::new(new_pipeline),
ctx,
)))
}
/// Locates the first `APPEND` command in `pipeline`.
///
/// On a hit, returns the append command itself, its index among the
/// pipeline's commands, and the type environment flowing into it. Returns
/// `Ok(None)` when no command is an `APPEND`.
///
/// # Errors
///
/// Propagates the error returned by `TypedPipeline::valid_ref`.
fn find_append_command(
    pipeline: &TypedPipeline,
) -> Result<Option<(&TypedAppendCommand, usize, &Arc<TypeEnvironment>)>, Arc<TranslationError>> {
    let first_append = pipeline
        .valid_ref()?
        .commands
        .iter()
        .enumerate()
        .find_map(|(position, command)| match &command.kind {
            TypedCommandKind::Append(append) => Some((append, position, &command.input_schema)),
            _ => None,
        });
    Ok(first_append)
}
/// Rebuilds `pipeline`'s AST with an alignment `SELECT` inserted immediately
/// before the command at `append_idx`.
///
/// Every existing command is carried over, in order, via a clone of its AST.
/// The only addition is the projection from `build_alignment_select`. The new
/// pipeline reuses the original pipeline's span.
///
/// # Errors
///
/// Propagates the error returned by `TypedPipeline::valid_ref`.
fn build_aligned_pipeline(
    pipeline: &TypedPipeline,
    source_schema: &Struct,
    target_schema: &Struct,
    append_idx: usize,
) -> Result<Pipeline, Arc<TranslationError>> {
    let commands = &pipeline.valid_ref()?.commands;
    let span = pipeline.ast.span.clone();

    let assembled = commands.iter().enumerate().fold(
        builder::pipeline().at(span),
        |acc, (position, command)| {
            // The projection must sit directly in front of the APPEND it feeds.
            let acc = if position == append_idx {
                acc.command(build_alignment_select(source_schema, target_schema))
            } else {
                acc
            };
            acc.command(command.ast.clone())
        },
    );

    Ok(assembled.build())
}
/// Builds the `SELECT` that reshapes rows shaped like `source_schema` into
/// `target_schema`.
///
/// Exactly one output field is emitted per target field, in target order.
/// Source fields that the target does not declare are therefore dropped. Each
/// field's expression comes from `build_widening_expression`. That call gets
/// the field name, the result of looking the name up in `source_schema`
/// (presumably `None` when the source lacks the field), and the target type.
fn build_alignment_select(
    source_schema: &Struct,
    target_schema: &Struct,
) -> builder::SelectCommandBuilder {
    target_schema
        .iter()
        .fold(select_command(), |select, (name, target_type)| {
            let source_type = source_schema.lookup(name);
            let widened = build_widening_expression(name.name(), source_type, target_type);
            select.named_field(name.name(), widened)
        })
}
// Tests for `align_append_schema`. Each rstest case pairs an input pipeline
// with the pipeline the pass is expected to produce. Both are type-checked
// against `TestProvider`, and their ASTs are compared.
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::EnvironmentProvider,
        tree::{
            ast::{
                identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
                pipeline::Pipeline,
            },
            builder::{
                append_command, array, cast, field_ref, null, pipeline, select_command,
                set_command, struct_literal,
            },
            options::TypeCheckOptions,
        },
        type_check_with_options,
        types::{array::Array, struct_type::Struct, Type, INT},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    /// Fixed in-memory catalog used as the `APPEND` target in these tests.
    ///
    /// - `my_table`: `{b: INT, c: INT, a: INT}`. It is declared out of
    ///   alphabetical order, so aligning to it requires reordering fields.
    /// - `nested_table`: `{nested: {a: INT, b: INT}}`.
    /// - `array_table`: `{items: [{a: INT, b: INT}]}`.
    ///
    /// Any other name fails with a "Table not found" error.
    #[derive(Debug)]
    struct TestProvider;

    impl EnvironmentProvider for TestProvider {
        fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
            let my_table: Identifier = AstSimpleIdentifier::new("my_table").into();
            let nested_table: Identifier = AstSimpleIdentifier::new("nested_table").into();
            let array_table: Identifier = AstSimpleIdentifier::new("array_table").into();
            if name == &my_table {
                Ok(Struct::default()
                    .with_str("b", INT)
                    .with_str("c", INT)
                    .with_str("a", INT))
            } else if name == &nested_table {
                let nested: Type = Struct::default()
                    .with_str("a", INT)
                    .with_str("b", INT)
                    .into();
                Ok(Struct::default().with_str("nested", nested))
            } else if name == &array_table {
                let elem: Type = Struct::default()
                    .with_str("a", INT)
                    .with_str("b", INT)
                    .into();
                Ok(Struct::default().with_str("items", Array::new(elem).into()))
            } else {
                anyhow::bail!("Table not found: {}", name)
            }
        }

        // This test catalog lists no datasets.
        fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
            Ok(vec![])
        }
    }

    /// Type-checks `input` and `expected` with the same registry and
    /// provider, then runs `align_append_schema` on the typed input. Asserts
    /// that the rewritten AST equals the typed `expected` AST.
    ///
    /// Only the `.output` of each type-check result is used. Anything else
    /// the result carries (e.g. diagnostics, if any) is not inspected here.
    fn run_test(input: Pipeline, expected: Pipeline) -> Result<(), Arc<TranslationError>> {
        let provider = Arc::new(TestProvider);
        let registry = Arc::new(FunctionRegistry::default());
        // Fresh options per type-check; the registry and provider Arcs are shared.
        let tc_opts = || {
            TypeCheckOptions::builder()
                .registry(registry.clone())
                .provider(provider.clone())
                .build()
        };
        let input_typed = type_check_with_options(input, tc_opts()).output;
        let expected_typed = type_check_with_options(expected, tc_opts()).output;
        let mut ctx = StatementTranslationContext::new(registry, provider);
        let result = align_append_schema(Arc::new(input_typed), &mut ctx)?;
        assert_eq!(result.ast, expected_typed.ast);
        Ok(())
    }

    #[rstest]
    // Without an APPEND the pass leaves the pipeline unchanged.
    #[case::no_append_passthrough(
        pipeline().command(set_command().named_field("a", 1)).build(),
        pipeline().command(set_command().named_field("a", 1)).build()
    )]
    // Source rows are {a, b} and `my_table` is {b, c, a}. The inserted SELECT
    // emits the target's order, and the absent `c` becomes NULL cast to INT.
    #[case::reorder_and_missing_field(
        pipeline()
            .command(set_command().named_field("a", 1).named_field("b", 2))
            .command(append_command("my_table"))
            .build(),
        pipeline()
            .command(set_command().named_field("a", 1).named_field("b", 2))
            .command(select_command()
                .named_field("b", field_ref("b"))
                .named_field("c", cast(null(), INT))
                .named_field("a", field_ref("a")))
            .command(append_command("my_table"))
            .build()
    )]
    // Source `nested` is {a}, target `nested` is {a, b}: the column is widened
    // as a whole rather than field by field.
    #[case::nested_struct_widening(
        pipeline()
            .command(set_command().named_field("nested", struct_literal().field("a", 1)))
            .command(append_command("nested_table"))
            .build(),
        {
            // The expected projection casts `nested` to the target struct
            // type; CastKind::StructExpansion is what fills in the missing
            // field with null.
            let target_struct: Type = Struct::default()
                .with_str("a", INT)
                .with_str("b", INT)
                .into();
            pipeline()
                .command(set_command().named_field("nested", struct_literal().field("a", 1)))
                .command(select_command()
                    .named_field("nested", cast(field_ref("nested"), target_struct)))
                .command(append_command("nested_table"))
                .build()
        }
    )]
    // Source `items` is [{a}], target `items` is [{a, b}]: the array column
    // is cast as a whole to the target array type.
    #[case::array_of_structs_widening(
        pipeline()
            .command(set_command().named_field("items", array().element(struct_literal().field("a", 1))))
            .command(append_command("array_table"))
            .build(),
        {
            // The expected projection casts `items` to the target array type;
            // CastKind::ArrayElementCast(StructExpansion) is what widens each
            // element struct.
            let target_elem: Type = Struct::default()
                .with_str("a", INT)
                .with_str("b", INT)
                .into();
            let target_array: Type = Array::new(target_elem).into();
            pipeline()
                .command(set_command().named_field("items", array().element(struct_literal().field("a", 1))))
                .command(select_command()
                    .named_field("items", cast(field_ref("items"), target_array)))
                .command(append_command("array_table"))
                .build()
        }
    )]
    fn test_align_append_schema(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
    ) -> Result<(), Arc<TranslationError>> {
        run_test(input, expected)
    }
}