// hamelin_translation 0.9.7
//
// Lowering and IR for Hamelin query language
// Documentation
//! Pass: FROM alias nesting.
//!
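//! Transforms aliased FROM clauses into CTEs that nest the source table's
//! fields under the alias.
//!
//! Example:
//! ```text
//! FROM x = events  -- events: {a, b}
//! ```
//! becomes (conceptually `FROM events | NEST x`):
//! ```text
//! DEF __alias_0 = FROM events | SELECT a, b, x = {a: a, b: b}
//! FROM __alias_0
//! ```
//!
//! The CTE is built directly as a SELECT rather than as a NEST command: the
//! table's base fields come first (in schema order), followed by the alias as
//! a struct literal of those fields. A base field whose name equals the alias
//! is omitted from the top level because the alias shadows it.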
//!
//! This pass operates at the statement level, taking a TypedStatement and
//! returning a TypedStatement. It uses builders to construct AST, then
//! runs type-checking to derive types.

use std::sync::Arc;

use hamelin_lib::{
    err::TranslationError,
    tree::{
        ast::query::Query,
        builder::{self, field_ref, from_command, query, select_command, struct_literal},
        typed_ast::{
            clause::TypedFromClause, command::TypedCommandKind,
            context::StatementTranslationContext, pipeline::TypedPipeline, query::TypedStatement,
        },
    },
};

use crate::unique::UniqueNameGenerator;
use hamelin_lib::tree::builder::pipeline as pipeline_builder;

/// Rewrite every aliased FROM clause in `statement` into a generated CTE.
///
/// A statement such as
/// ```text
/// FROM x = events | WHERE x.a > 10
/// ```
/// is rewritten so that the aliased table is read through a generated
/// definition named with the `__alias` prefix:
/// ```text
/// DEF __alias_0 = FROM events | SELECT <base fields>, x = {<field>: <field>, ...};
/// FROM __alias_0 | WHERE x.a > 10
/// ```
///
/// If no pipeline in the statement contains an aliased FROM clause, the input
/// `Arc` is returned untouched. Otherwise a new AST is built and type-checked
/// again against `ctx` to produce the returned statement.
///
/// # Errors
///
/// Returns the underlying error if any pipeline, alias, or table identifier
/// in the statement is invalid.
pub fn nest_from_aliases(
    statement: Arc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
    if statement_has_aliases(&statement)? {
        // Generated CTE names all share the `__alias` prefix.
        let mut cte_names = UniqueNameGenerator::new("__alias");
        let rewritten = transform_statement(&statement, &mut cte_names)?;

        // The rewritten AST carries no types yet; derive them again.
        let retyped = TypedStatement::from_ast_with_context(Arc::new(rewritten), ctx);
        Ok(Arc::new(retyped))
    } else {
        // Nothing to rewrite: hand the original statement back as-is.
        Ok(statement)
    }
}

/// Report whether any pipeline of `statement` contains an aliased FROM clause.
///
/// Every pipeline yielded by `statement.iter()` is inspected, even after an
/// alias has been found, so an invalid pipeline anywhere in the statement is
/// reported as an error.
fn statement_has_aliases(statement: &TypedStatement) -> Result<bool, Arc<TranslationError>> {
    let mut found = false;
    for pipeline in statement.iter() {
        found |= pipeline_has_aliases(pipeline)?;
    }
    Ok(found)
}

/// Report whether any FROM command in `pipeline` has at least one aliased clause.
///
/// Commands other than FROM never count. Fails if the pipeline itself is not
/// valid.
fn pipeline_has_aliases(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;

    Ok(valid.commands.iter().any(|cmd| match &cmd.kind {
        TypedCommandKind::From(from_cmd) => from_cmd
            .clauses
            .iter()
            .any(|clause| matches!(clause, TypedFromClause::Alias(_))),
        _ => false,
    }))
}

/// Rebuild the whole statement as an untyped `Query` with aliases rewritten.
///
/// Scalar DEFs are copied over with their original expression AST. Each
/// existing pipeline DEF and the main pipeline are passed through
/// `transform_pipeline`; the results are merged into the output with
/// `merge_as_cte` (under the DEF's own name) and `merge_as_main` respectively.
/// `name_gen` is shared across all pipelines of the statement.
fn transform_statement(
    statement: &TypedStatement,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Arc<TranslationError>> {
    let mut rebuilt = query();

    // Scalar definitions need no rewriting.
    for scalar_def in &statement.scalar_defs {
        let def_name = scalar_def.name.valid_ref()?.clone();
        let def_expr = scalar_def.expression.ast.clone();
        rebuilt = rebuilt.def_expression(def_name, def_expr);
    }

    // Tabular definitions: transforming one may produce extra generated CTEs,
    // which travel inside the returned query and are merged along with it.
    for pipeline_def in &statement.pipeline_defs {
        let lowered = transform_pipeline(&pipeline_def.pipeline, statement, name_gen)?;
        let def_name = pipeline_def.name.clone().valid()?;
        rebuilt = rebuilt.merge_as_cte(lowered, def_name);
    }

    // The main pipeline goes last and becomes the main of the result.
    transform_pipeline(&statement.pipeline, statement, name_gen)
        .map(|main_query| rebuilt.merge_as_main(main_query))
}

/// Transform a pipeline, generating CTEs for aliased FROM clauses.
///
/// Returns a Query containing any generated CTEs and the transformed pipeline as main.
fn transform_pipeline(
    pipeline: &TypedPipeline,
    statement: &TypedStatement,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Arc<TranslationError>> {
    let mut query_builder = query();
    let mut pipeline_builder = pipeline_builder().at(pipeline.ast.span.clone());

    for cmd in &pipeline.valid_ref()?.commands {
        match &cmd.kind {
            TypedCommandKind::From(from_cmd)
                if from_cmd
                    .clauses
                    .iter()
                    .any(|c| matches!(c, TypedFromClause::Alias(_))) =>
            {
                // Transform aliased clauses into CTEs, building FROM command directly
                let mut from_builder = from_command().at(cmd.ast.span.clone());

                for clause in &from_cmd.clauses {
                    match clause {
                        TypedFromClause::Alias(alias_clause) => {
                            // Get the alias name
                            let alias = alias_clause.alias.valid_ref()?;

                            // Check against existing CTE names in the statement
                            let cte_name = name_gen.next(statement);

                            // Get the table name
                            let table_name = alias_clause
                                .ast
                                .table
                                .identifier
                                .valid_ref()
                                .map(|id| id.clone())?;

                            // Build SELECT with explicit column ordering:
                            // 1. All base fields first (preserving table schema order)
                            // 2. Then the nested alias at the end
                            //
                            // This is verbose but necessary because SET uses with_single_parent
                            // which puts new bindings first when flattened. We need base fields
                            // first to match legacy Environment::merge() behavior.
                            let table_env = alias_clause.resolved.environment();
                            let table_schema = table_env.as_struct();

                            // Build struct literal for the alias: {field1: field1, ...}
                            let mut struct_builder = struct_literal();
                            for (field_name, _) in table_schema.iter() {
                                struct_builder = struct_builder
                                    .field(field_name.name(), field_ref(field_name.name()));
                            }

                            // Build SELECT: base fields first, then alias.
                            // Skip any base field whose name matches the alias
                            // (the alias shadows it).
                            let mut select_builder = select_command();
                            for (field_name, _) in table_schema.iter() {
                                if field_name.name() == alias.as_str() {
                                    continue;
                                }
                                select_builder = select_builder.field(field_name.to_string());
                            }
                            select_builder =
                                select_builder.named_field(alias.clone(), struct_builder);

                            let cte_pipeline = builder::pipeline()
                                .from(|f| f.table_reference(table_name))
                                .command(select_builder)
                                .build();

                            // Add CTE to query builder
                            query_builder =
                                query_builder.def_pipeline(cte_name.clone(), cte_pipeline);

                            // Replace aliased clause with reference to CTE
                            from_builder = from_builder.table_reference(cte_name);
                        }
                        TypedFromClause::Reference(ref_clause) => {
                            // Non-aliased reference - pass through
                            from_builder = from_builder
                                .table_reference(ref_clause.ast.identifier.clone().valid()?);
                        }
                        TypedFromClause::Error(e) => return Err(e.clone()),
                    }
                }

                pipeline_builder = pipeline_builder.command(from_builder);
            }
            _ => {
                // Non-FROM command - keep as-is
                pipeline_builder = pipeline_builder.command(cmd.ast.clone());
            }
        }
    }

    Ok(query_builder.main(pipeline_builder.build()).build())
}

#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::EnvironmentProvider,
        tree::{
            ast::identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
            builder::{eq, field, field_ref, query, QueryBuilderWithMain},
        },
        type_check_with_provider,
        types::{struct_type::Struct, INT},
    };
    use std::sync::Arc;

    /// Environment provider that knows exactly one table: `events: {a: INT, b: INT}`.
    #[derive(Debug)]
    struct MockProvider;

    impl EnvironmentProvider for MockProvider {
        fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
            let events: Identifier = AstSimpleIdentifier::new("events").into();

            if name == &events {
                Ok(Struct::default().with_str("a", INT).with_str("b", INT))
            } else {
                anyhow::bail!("Table not found: {}", name)
            }
        }

        fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
            Ok(vec![])
        }
    }

    /// Build the query and type-check it against `MockProvider`.
    fn typed_query(builder: QueryBuilderWithMain) -> TypedStatement {
        type_check_with_provider(builder.build(), Arc::new(MockProvider)).output
    }

    /// Run `nest_from_aliases` on `statement` with a fresh translation context.
    fn run_pass(
        statement: Arc<TypedStatement>,
    ) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
        let registry = Arc::new(FunctionRegistry::default());
        let provider = Arc::new(MockProvider);
        let mut ctx = StatementTranslationContext::new(registry, provider);
        nest_from_aliases(statement, &mut ctx)
    }

    #[test]
    fn test_no_aliases_passthrough() -> Result<(), Arc<TranslationError>> {
        // FROM events | WHERE a == 10
        let q = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .where_cmd(eq(field_ref("a"), 10)),
        );

        let statement = Arc::new(typed_query(q));

        // Should not have aliases
        assert!(!statement_has_aliases(&statement)?);

        // The pass must early-exit and return the very same statement.
        let transformed = run_pass(Arc::clone(&statement))?;
        assert!(Arc::ptr_eq(&statement, &transformed));
        Ok(())
    }

    #[test]
    fn test_single_alias_generates_cte() -> Result<(), Arc<TranslationError>> {
        // FROM x = events | WHERE x.a == 10
        let q = query().main(
            pipeline_builder()
                .from(|f| f.table_alias("x", "events"))
                .where_cmd(eq(field(field_ref("x"), "a"), 10)),
        );

        let statement = typed_query(q);

        // Should have alias
        assert!(statement_has_aliases(&statement)?);

        // Transform it
        let transformed = run_pass(Arc::new(statement))?;

        // Should now have a CTE
        assert_eq!(transformed.pipeline_defs.len(), 1);

        // CTE name should be __alias_0
        let cte_name = transformed.pipeline_defs[0].name.valid_ref().unwrap();
        assert_eq!(cte_name.to_string(), "__alias_0");
        Ok(())
    }

    #[test]
    fn test_multiple_aliases_generate_multiple_ctes() -> Result<(), Arc<TranslationError>> {
        // FROM x = events, y = events
        let q = query().main(
            pipeline_builder().from(|f| f.table_alias("x", "events").table_alias("y", "events")),
        );

        let statement = typed_query(q);

        // Transform it
        let transformed = run_pass(Arc::new(statement))?;

        // Should now have 2 CTEs, named in generation order
        assert_eq!(transformed.pipeline_defs.len(), 2);

        let cte_name_0 = transformed.pipeline_defs[0].name.valid_ref().unwrap();
        let cte_name_1 = transformed.pipeline_defs[1].name.valid_ref().unwrap();
        assert_eq!(cte_name_0.to_string(), "__alias_0");
        assert_eq!(cte_name_1.to_string(), "__alias_1");
        Ok(())
    }
}