use std::rc::Rc;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::pipeline::Pipeline;
use hamelin_lib::tree::builder::{self, select_command};
use hamelin_lib::tree::typed_ast::clause::ResolvedEnvironment;
use hamelin_lib::tree::typed_ast::command::{TypedAppendCommand, TypedCommandKind};
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;
use hamelin_lib::types::struct_type::Struct;
use super::super::expand_struct::build_widening_expression;
pub fn align_append_schema(
pipeline: Rc<TypedPipeline>,
ctx: &mut StatementTranslationContext,
) -> Result<Rc<TypedPipeline>, Rc<TranslationError>> {
let Some((append_cmd, append_idx, input_schema)) = find_append_command(&pipeline)? else {
return Ok(pipeline);
};
let target_schema = match &append_cmd.table.resolved {
ResolvedEnvironment::Resolved(env) => env.flatten(),
ResolvedEnvironment::Error(_) => {
return Ok(pipeline);
}
};
let source_schema = input_schema.flatten();
if source_schema == target_schema {
return Ok(pipeline);
}
let new_pipeline =
build_aligned_pipeline(&pipeline, &source_schema, &target_schema, append_idx)?;
Ok(Rc::new(TypedPipeline::from_ast_with_context(
Rc::new(new_pipeline),
ctx,
)))
}
/// Locates the first `APPEND` command in a valid pipeline.
///
/// Returns `Ok(Some((append, index, input_schema)))`, where:
/// - `append` is the first `TypedCommandKind::Append` found;
/// - `index` is its position in the command list;
/// - `input_schema` is the schema flowing into it.
///
/// Returns `Ok(None)` when the pipeline contains no `APPEND`.
///
/// # Errors
/// Propagates the error from `valid_ref()` when the pipeline is not valid.
fn find_append_command(
    pipeline: &TypedPipeline,
) -> Result<Option<(&TypedAppendCommand, usize, &Rc<TypeEnvironment>)>, Rc<TranslationError>> {
    let located = pipeline
        .valid_ref()?
        .commands
        .iter()
        .enumerate()
        .find_map(|(position, command)| {
            if let TypedCommandKind::Append(append) = &command.kind {
                Some((append, position, &command.input_schema))
            } else {
                None
            }
        });
    Ok(located)
}
/// Rebuilds the untyped pipeline AST with an alignment `SELECT` inserted right before the
/// command at `append_idx`.
///
/// Every original command AST is cloned through in its original order. The new pipeline
/// reuses the original pipeline's span. The injected `SELECT` comes from
/// `build_alignment_select(source_schema, target_schema)`.
///
/// # Errors
/// Propagates the error from `valid_ref()` when the pipeline is not valid.
fn build_aligned_pipeline(
    pipeline: &TypedPipeline,
    source_schema: &Struct,
    target_schema: &Struct,
    append_idx: usize,
) -> Result<Pipeline, Rc<TranslationError>> {
    let commands = &pipeline.valid_ref()?.commands;
    let start = builder::pipeline().at(pipeline.ast.span.clone());

    let finished = commands
        .iter()
        .enumerate()
        .fold(start, |acc, (position, command)| {
            // The projection must precede the APPEND so it sees the aligned columns.
            let acc = if position == append_idx {
                acc.command(build_alignment_select(source_schema, target_schema))
            } else {
                acc
            };
            acc.command(command.ast.clone())
        });

    Ok(finished.build())
}
/// Builds a `SELECT` whose output fields are exactly those of `target_schema`.
///
/// Fields are emitted in `target_schema.fields` iteration order. Each field's expression comes
/// from `build_widening_expression`, which receives:
/// - the field name;
/// - the source type of the same-named field, or `None` when the source lacks that field;
/// - the target type.
///
/// Source fields that are absent from the target are not projected.
fn build_alignment_select(
    source_schema: &Struct,
    target_schema: &Struct,
) -> builder::SelectCommandBuilder {
    target_schema
        .fields
        .iter()
        .fold(select_command(), |select, (column, column_type)| {
            let name = column.name.as_str();
            let incoming_type = source_schema.fields.get(column);
            let widened = build_widening_expression(name, incoming_type, column_type);
            select.named_field(name, widened)
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        provider::EnvironmentProvider,
        sql::{
            expression::identifier::Identifier as SqlIdentifier,
            query::TableReference as SqlTableReference,
        },
        tree::{
            ast::{IntoTyped, TypeCheckExecutor},
            builder::{append_command, let_command, pipeline},
        },
        types::{struct_type::Struct, INT},
    };
    use std::sync::Arc;

    /// Test double for `EnvironmentProvider`.
    ///
    /// Only `my_table` exists. Its `INT` columns are inserted in the order `b`, `c`, `a`.
    /// Looking up any other table fails.
    #[derive(Debug)]
    struct MockProvider;

    impl EnvironmentProvider for MockProvider {
        fn reflect_columns(&self, table: SqlTableReference) -> anyhow::Result<Struct> {
            let known_table: SqlIdentifier = "my_table".parse().unwrap();
            if table.name != known_table {
                anyhow::bail!("Table not found: {}", table.name);
            }
            // This column order deliberately differs from the columns the test pipelines produce.
            let mut schema = Struct::default();
            for column in ["b", "c", "a"] {
                schema.fields.insert(column.parse().unwrap(), INT);
            }
            Ok(schema)
        }

        fn reflect_datasets(&self) -> anyhow::Result<Vec<SqlIdentifier>> {
            Ok(Vec::new())
        }
    }

    #[test]
    fn test_align_append_schema_inserts_select() -> Result<(), Rc<TranslationError>> {
        let source = pipeline()
            .command(let_command().named_field("a", 1).named_field("b", 2))
            .command(append_command("my_table"))
            .build();
        let typed = source.typed_with().with_provider(Arc::new(MockProvider)).typed();

        let registry = Arc::new(hamelin_lib::func::registry::FunctionRegistry::default());
        let mut ctx = StatementTranslationContext::new(registry, Arc::new(MockProvider));

        let aligned = align_append_schema(Rc::new(typed), &mut ctx)?;
        let commands = &aligned.valid_ref()?.commands;

        // Expected order: LET, then the injected SELECT, then the original APPEND.
        assert_eq!(commands.len(), 3);
        assert!(matches!(&commands[1].kind, TypedCommandKind::Select(_)));
        assert!(matches!(&commands[2].kind, TypedCommandKind::Append(_)));
        Ok(())
    }

    #[test]
    fn test_no_append_no_change() -> Result<(), Rc<TranslationError>> {
        let source = pipeline()
            .command(let_command().named_field("a", 1))
            .build();
        let typed = source.typed_with().with_provider(Arc::new(MockProvider)).typed();

        let registry = Arc::new(hamelin_lib::func::registry::FunctionRegistry::default());
        let mut ctx = StatementTranslationContext::new(registry, Arc::new(MockProvider));

        // With no APPEND there is nothing to align, so the AST must come back untouched.
        let result = align_append_schema(Rc::new(typed.clone()), &mut ctx)?;
        assert_eq!(result.ast, typed.ast);
        Ok(())
    }
}