use std::sync::Arc;
use hamelin_lib::tree::builder::pipeline as pipeline_builder;
use hamelin_lib::{
err::TranslationError,
tree::{
ast::{identifier::Identifier, pipeline::Pipeline, query::Query},
builder::{self, query, select_command},
typed_ast::{
clause::TypedFromClause,
command::{TypedCommandKind, TypedUnionCommand},
context::StatementTranslationContext,
environment::TypeEnvironment,
pipeline::TypedPipeline,
query::TypedStatement,
},
},
types::struct_type::Struct,
};
use super::super::expand_struct::build_widening_expression;
use crate::unique::UniqueNameGenerator;
/// Entry point: expands unions whose clauses have divergent schemas.
///
/// When no expansion is needed the input statement is returned untouched;
/// otherwise the statement's AST is rewritten (divergent union arms become
/// widening CTEs) and re-typed against `ctx`.
pub fn expand_union_schemas(
    statement: Arc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
    if statement_needs_expansion(&statement)? {
        // CTE names produced by the rewrite all share the "__union" prefix.
        let mut name_gen = UniqueNameGenerator::new("__union");
        let rewritten = transform_statement(&statement, &mut name_gen)?;
        Ok(Arc::new(TypedStatement::from_ast_with_context(
            Arc::new(rewritten),
            ctx,
        )))
    } else {
        // Fast path: nothing to rewrite, hand the Arc straight back.
        Ok(statement)
    }
}
/// Returns true when any pipeline in the statement contains a union that
/// needs schema expansion. Deliberately visits every pipeline (no early
/// return) so a validation error in a later pipeline is still surfaced.
fn statement_needs_expansion(statement: &TypedStatement) -> Result<bool, Arc<TranslationError>> {
    let mut needs_expansion = false;
    for pipeline in statement.iter() {
        if pipeline_needs_expansion(pipeline)? {
            needs_expansion = true;
        }
    }
    Ok(needs_expansion)
}
/// Returns true when this pipeline contains at least one union command
/// whose clause schemas diverge from the union's output schema.
fn pipeline_needs_expansion(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    for command in &pipeline.valid_ref()?.commands {
        if let TypedCommandKind::Union(union_cmd) = &command.kind {
            if union_needs_expansion(union_cmd, &command.output_schema) {
                return Ok(true);
            }
        }
    }
    Ok(false)
}
/// A union needs expansion when it has multiple clauses and at least one
/// clause's schema differs from the union's output schema.
fn union_needs_expansion(cmd: &TypedUnionCommand, output_schema: &TypeEnvironment) -> bool {
    // A union with zero or one branch has nothing to reconcile.
    if cmd.clauses.len() <= 1 {
        return false;
    }
    let target = output_schema.as_struct();
    // Expansion is required unless every branch already matches the target.
    !cmd.clauses
        .iter()
        .all(|clause| clause.environment().as_struct() == target)
}
/// Rebuilds the statement's AST, running every pipeline (named definitions
/// and the main pipeline) through `transform_pipeline` and merging the
/// resulting queries — including any CTEs the rewrite created — into one.
fn transform_statement(
    statement: &TypedStatement,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Arc<TranslationError>> {
    let mut builder = query();
    // Scalar definitions are carried over unchanged.
    for scalar_def in &statement.scalar_defs {
        let def_name = scalar_def.name.valid_ref()?.clone();
        builder = builder.def_expression(def_name, scalar_def.expression.ast.clone());
    }
    // Each named pipeline is rewritten and merged back in as a CTE.
    for pipeline_def in &statement.pipeline_defs {
        let rewritten = transform_pipeline(&pipeline_def.pipeline, statement, name_gen)?;
        let def_name = pipeline_def.name.clone().valid()?;
        builder = builder.merge_as_cte(rewritten, def_name);
    }
    // The main pipeline is rewritten last and attached as the query body.
    let main = transform_pipeline(&statement.pipeline, statement, name_gen)?;
    Ok(builder.merge_as_main(main))
}
/// Rewrites a single pipeline, replacing any union whose clauses have
/// divergent schemas with a union over freshly-defined widening CTEs.
///
/// Each divergent reference clause becomes a CTE (named via `name_gen`)
/// that selects the clause's columns widened to the union's output schema;
/// the union is rebuilt to reference those CTEs. All other commands are
/// copied through verbatim. Returns a query holding the rewritten pipeline
/// as its main body plus any CTE definitions created along the way.
///
/// Fixes vs. previous revision: removed a redundant `table_name.clone()`
/// (the identifier was never used after the call), and renamed the local
/// builder so it no longer shadows the imported `pipeline_builder` function.
fn transform_pipeline(
    pipeline: &TypedPipeline,
    statement: &TypedStatement,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Arc<TranslationError>> {
    let mut query_builder = query();
    let mut new_pipeline = pipeline_builder().at(pipeline.ast.span.clone());
    for cmd in &pipeline.valid_ref()?.commands {
        match &cmd.kind {
            TypedCommandKind::Union(union_cmd)
                if union_needs_expansion(union_cmd, &cmd.output_schema) =>
            {
                let output_struct = cmd.output_schema.as_struct();
                let mut union_builder = builder::union_command().at(cmd.ast.span.clone());
                for clause in &union_cmd.clauses {
                    match clause {
                        TypedFromClause::Reference(ref_clause) => {
                            let table_name = ref_clause.ast.identifier.clone().valid()?;
                            let clause_struct = clause.environment().as_struct();
                            let cte_name = name_gen.next(statement);
                            // Define a CTE that widens this clause's schema
                            // up to the union's output schema, then point the
                            // rebuilt union at it.
                            let cte_pipeline = build_widening_pipeline(
                                table_name,
                                &clause_struct,
                                &output_struct,
                            );
                            query_builder =
                                query_builder.def_pipeline(cte_name.clone(), cte_pipeline);
                            union_builder = union_builder.table_reference(cte_name);
                        }
                        // NOTE(review): alias clauses are silently dropped
                        // from the rebuilt union rather than widened — confirm
                        // that aliased arms can never carry a divergent schema
                        // when this rewrite fires.
                        TypedFromClause::Alias(_) => {
                            continue;
                        }
                        TypedFromClause::Error(e) => return Err(e.clone()),
                    }
                }
                new_pipeline = new_pipeline.command(union_builder);
            }
            // Non-union commands (and unions whose schemas already agree)
            // are re-emitted unchanged.
            _ => new_pipeline = new_pipeline.command(cmd.ast.clone()),
        }
    }
    Ok(query_builder.main(new_pipeline.build()).build())
}
/// Builds `from <table_name> | select ...` where the select projects every
/// field of `target_struct`, each produced by a widening expression from the
/// corresponding source field type.
fn build_widening_pipeline(
    table_name: Identifier,
    source_struct: &Struct,
    target_struct: &Struct,
) -> Pipeline {
    // One projected field per target column, widened from the source type.
    let select = target_struct
        .iter()
        .fold(select_command(), |sel, (field_name, field_type)| {
            let source_type = source_struct.lookup(field_name);
            let expr = build_widening_expression(field_name.name(), source_type, field_type);
            sel.named_field(field_name.name(), expr)
        });
    builder::pipeline()
        .from(|f| f.table_reference(table_name))
        .command(select)
        .build()
}
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::EnvironmentProvider,
        tree::{
            ast::identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
            builder::{cast, field_ref, query, select_command, QueryBuilderWithMain},
        },
        type_check_with_provider,
        types::{array::Array, struct_type::Struct, Type, INT},
    };
    use std::sync::Arc;

    // Schema provider exposing two flat tables with partially overlapping
    // columns: events(a, b) and logs(a, c). Any other table name errors.
    #[derive(Debug)]
    struct MockProvider;
    impl EnvironmentProvider for MockProvider {
        fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
            let events: Identifier = AstSimpleIdentifier::new("events").into();
            let logs: Identifier = AstSimpleIdentifier::new("logs").into();
            if name == &events {
                Ok(Struct::default().with_str("a", INT).with_str("b", INT))
            } else if name == &logs {
                Ok(Struct::default().with_str("a", INT).with_str("c", INT))
            } else {
                anyhow::bail!("Table not found: {}", name)
            }
        }
        fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
            Ok(vec![])
        }
    }

    // Type-checks the built query against MockProvider and returns the
    // typed statement for the pass under test.
    fn typed_query(builder: QueryBuilderWithMain) -> TypedStatement {
        type_check_with_provider(builder.build(), Arc::new(MockProvider)).output
    }

    // A union over a single table never needs expansion (nothing to
    // reconcile against).
    #[test]
    fn test_single_table_no_expansion() -> Result<(), Arc<TranslationError>> {
        let q = query().main(pipeline_builder().union(|u| u.table_reference("events")));
        let statement = typed_query(q);
        assert!(!statement_needs_expansion(&statement)?);
        Ok(())
    }

    // Two branches with byte-identical schemas need no widening.
    #[test]
    fn test_identical_schemas_no_expansion() -> Result<(), Arc<TranslationError>> {
        let q = query().main(
            pipeline_builder().union(|u| u.table_reference("events").table_reference("events")),
        );
        let statement = typed_query(q);
        assert!(!statement_needs_expansion(&statement)?);
        Ok(())
    }

    // events(a, b) union logs(a, c) diverges, so the pass must fire and
    // produce one "__union_N" CTE per branch, numbered from zero.
    #[test]
    fn test_different_schemas_needs_expansion() -> Result<(), Arc<TranslationError>> {
        let q = query().main(
            pipeline_builder().union(|u| u.table_reference("events").table_reference("logs")),
        );
        let statement = typed_query(q);
        assert!(statement_needs_expansion(&statement)?);
        let registry = Arc::new(FunctionRegistry::default());
        let provider = Arc::new(MockProvider);
        let mut ctx = StatementTranslationContext::new(registry, provider);
        let transformed = expand_union_schemas(Arc::new(statement), &mut ctx)?;
        // One widening CTE per union branch.
        assert_eq!(transformed.pipeline_defs.len(), 2);
        let cte_name_0 = transformed.pipeline_defs[0].name.valid_ref().unwrap();
        let cte_name_1 = transformed.pipeline_defs[1].name.valid_ref().unwrap();
        assert_eq!(cte_name_0.to_string(), "__union_0");
        assert_eq!(cte_name_1.to_string(), "__union_1");
        Ok(())
    }

    // Branches whose top-level column is itself a struct: the generated CTE
    // should cast the nested struct up to the widened struct type.
    #[test]
    fn test_nested_struct_schema_widening() -> Result<(), Arc<TranslationError>> {
        // Provider with events.nested = {a} and logs.nested = {a, b}.
        #[derive(Debug)]
        struct NestedProvider;
        impl EnvironmentProvider for NestedProvider {
            fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
                let events: Identifier = AstSimpleIdentifier::new("events").into();
                let logs: Identifier = AstSimpleIdentifier::new("logs").into();
                let nested_events: Type = Struct::default().with_str("a", INT).into();
                let nested_logs: Type = Struct::default()
                    .with_str("a", INT)
                    .with_str("b", INT)
                    .into();
                if name == &events {
                    Ok(Struct::default().with_str("nested", nested_events))
                } else if name == &logs {
                    Ok(Struct::default().with_str("nested", nested_logs))
                } else {
                    anyhow::bail!("Table not found: {}", name)
                }
            }
            fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
                Ok(vec![])
            }
        }
        let q = query().main(
            pipeline_builder().union(|u| u.table_reference("events").table_reference("logs")),
        );
        let statement = type_check_with_provider(q.build(), Arc::new(NestedProvider)).output;
        let registry = Arc::new(FunctionRegistry::default());
        let provider = Arc::new(NestedProvider);
        let mut ctx = StatementTranslationContext::new(registry, provider);
        let transformed = expand_union_schemas(Arc::new(statement), &mut ctx)?;
        assert_eq!(transformed.pipeline_defs.len(), 2);
        // The widened target type for the "nested" column: {a, b}.
        let target_struct: Type = Struct::default()
            .with_str("a", INT)
            .with_str("b", INT)
            .into();
        // Expected first CTE: from events | select nested = cast(nested, {a, b}).
        let expected_events = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .command(
                    select_command()
                        .named_field("nested", cast(field_ref("nested"), target_struct))
                        .build(),
                ),
        );
        let expected_typed =
            type_check_with_provider(expected_events.build(), Arc::new(NestedProvider)).output;
        // Compare ASTs, not typed trees, to keep the assertion structural.
        assert_eq!(
            transformed.pipeline_defs[0].pipeline.ast,
            expected_typed.pipeline.ast
        );
        Ok(())
    }

    // Same as the nested-struct case, but the divergent type sits inside an
    // array element: the CTE should cast the whole array column.
    #[test]
    fn test_array_of_structs_schema_widening() -> Result<(), Arc<TranslationError>> {
        // Provider with events.items = [{a}] and logs.items = [{a, b}].
        #[derive(Debug)]
        struct ArrayProvider;
        impl EnvironmentProvider for ArrayProvider {
            fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
                let events: Identifier = AstSimpleIdentifier::new("events").into();
                let logs: Identifier = AstSimpleIdentifier::new("logs").into();
                let events_elem: Type = Struct::default().with_str("a", INT).into();
                let logs_elem: Type = Struct::default()
                    .with_str("a", INT)
                    .with_str("b", INT)
                    .into();
                if name == &events {
                    Ok(Struct::default().with_str("items", Array::new(events_elem).into()))
                } else if name == &logs {
                    Ok(Struct::default().with_str("items", Array::new(logs_elem).into()))
                } else {
                    anyhow::bail!("Table not found: {}", name)
                }
            }
            fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
                Ok(vec![])
            }
        }
        let q = query().main(
            pipeline_builder().union(|u| u.table_reference("events").table_reference("logs")),
        );
        let statement = type_check_with_provider(q.build(), Arc::new(ArrayProvider)).output;
        let registry = Arc::new(FunctionRegistry::default());
        let provider = Arc::new(ArrayProvider);
        let mut ctx = StatementTranslationContext::new(registry, provider);
        let transformed = expand_union_schemas(Arc::new(statement), &mut ctx)?;
        assert_eq!(transformed.pipeline_defs.len(), 2);
        // The widened element type {a, b}, wrapped in an array.
        let target_elem: Type = Struct::default()
            .with_str("a", INT)
            .with_str("b", INT)
            .into();
        let target_array: Type = Array::new(target_elem).into();
        // Expected first CTE: from events | select items = cast(items, [{a, b}]).
        let expected_events = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .command(
                    select_command()
                        .named_field("items", cast(field_ref("items"), target_array))
                        .build(),
                ),
        );
        let expected_typed =
            type_check_with_provider(expected_events.build(), Arc::new(ArrayProvider)).output;
        assert_eq!(
            transformed.pipeline_defs[0].pipeline.ast,
            expected_typed.pipeline.ast
        );
        Ok(())
    }
}