hamelin_translation 0.4.2

Lowering and IR for Hamelin query language
Documentation
//! Pass: JOIN/LOOKUP right-side hoisting.
//!
//! Transforms JOIN and LOOKUP commands by hoisting the right side into a CTE
//! with NEST, which prepares the nested struct representation.
//!
//! Example:
//! ```text
//! FROM events | JOIN users ON user_id == users.id
//! ```
//! becomes:
//! ```text
//! WITH __join_0 = (FROM users | NEST users)
//! FROM events | JOIN __join_0 ON user_id == users.id
//! ```
//!
//! This pass must run BEFORE `nest_from_aliases` since both emit NEST commands
//! that are later processed by `lower_nest`.
//!
//! The NEST command wraps all right-side fields under the alias name, so after
//! joining, `users.id` correctly refers to the nested struct field.

use std::sync::Arc;

use hamelin_lib::{
    err::TranslationError,
    tree::{
        ast::query::Query,
        builder::{self, join_table_reference, lookup_table_reference, query, ExpressionBuilder},
        typed_ast::{
            command::TypedCommandKind, context::StatementTranslationContext,
            pipeline::TypedPipeline, query::TypedStatement,
        },
    },
};

use super::super::unique::UniqueNameGenerator;
use hamelin_lib::tree::builder::pipeline as pipeline_builder;

/// Hoist the right side of every JOIN and LOOKUP command into its own CTE.
///
/// Each hoisted CTE reads the right-hand table and applies NEST with the
/// command's alias, so
/// ```text
/// FROM events | JOIN users ON user_id == users.id
/// ```
/// is rewritten as
/// ```text
/// WITH __join_0 = (FROM users | NEST users)
/// FROM events | JOIN __join_0 ON user_id == users.id
/// ```
///
/// A statement without any JOIN/LOOKUP command is handed back as the very
/// same `Arc`. Otherwise a new `TypedStatement` is built from the rewritten
/// query via `TypedStatement::from_ast_with_context` using `ctx`. The emitted
/// NEST commands are left for the later `lower_nest` pass.
///
/// # Errors
///
/// Propagates any error raised while inspecting or rewriting the statement's
/// pipelines.
pub fn lower_joins(
    statement: Arc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
    // Nothing to hoist: return the input untouched.
    let needs_lowering = statement_has_joins(&statement)?;
    if !needs_lowering {
        return Ok(statement);
    }

    // Generated CTE names all share the `__join` prefix.
    let mut cte_names = UniqueNameGenerator::new("__join");
    let lowered_ast = transform_statement(&statement, &mut cte_names)?;

    let lowered = TypedStatement::from_ast_with_context(Arc::new(lowered_ast), ctx);
    Ok(Arc::new(lowered))
}

/// Report whether any pipeline yielded by `statement.iter()` contains a JOIN
/// or LOOKUP command.
///
/// Every pipeline is inspected even after a match has been found, so an error
/// from any pipeline is returned rather than being skipped.
fn statement_has_joins(statement: &TypedStatement) -> Result<bool, Arc<TranslationError>> {
    let mut found = false;
    for pipeline in statement.iter() {
        // Run the check before OR-ing so it is never short-circuited away.
        let has_joins = pipeline_has_joins(pipeline)?;
        found = has_joins || found;
    }
    Ok(found)
}

/// Report whether at least one command in `pipeline` is a JOIN or a LOOKUP.
///
/// # Errors
///
/// Returns the error produced by `valid_ref()` when the pipeline is not valid.
fn pipeline_has_joins(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;

    Ok(valid.commands.iter().any(|command| {
        matches!(
            &command.kind,
            TypedCommandKind::Join(_) | TypedCommandKind::Lookup(_)
        )
    }))
}

/// Rewrite every pipeline of `statement` and assemble the results into a new `Query`.
///
/// The existing WITH clauses are handled first, in order. Each one is rewritten
/// by `transform_pipeline` and merged under its original name, together with
/// whatever CTEs that rewrite produced. The main pipeline is rewritten last and
/// merged in as the main query.
///
/// # Errors
///
/// Propagates failures from `transform_pipeline` and from validating a WITH
/// clause name.
fn transform_statement(
    statement: &TypedStatement,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Arc<TranslationError>> {
    // Accumulate the rewritten WITH clauses into one builder.
    let with_ctes = statement.with_clauses.iter().try_fold(
        query(),
        |acc, clause| -> Result<_, Arc<TranslationError>> {
            let lowered = transform_pipeline(&clause.pipeline, statement, name_gen)?;
            let name = clause.name.clone().valid()?;
            Ok(acc.merge_as_cte(lowered, name))
        },
    )?;

    // The main pipeline goes last.
    let lowered_main = transform_pipeline(&statement.pipeline, statement, name_gen)?;
    Ok(with_ctes.merge_as_main(lowered_main))
}

/// Rewrite one pipeline, hoisting the right side of each JOIN/LOOKUP into a CTE.
///
/// For every JOIN or LOOKUP command a CTE of the form `FROM <table> | NEST <alias>`
/// is added, and the command is replaced by an equivalent one that targets the
/// generated CTE name. All other commands are copied over from their original AST.
///
/// The returned `Query` holds the generated CTEs, with the rewritten pipeline
/// as its main pipeline.
///
/// # Errors
///
/// Propagates validation errors from the pipeline, a right-side alias, or a
/// right-side table identifier.
fn transform_pipeline(
    pipeline: &TypedPipeline,
    statement: &TypedStatement,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Arc<TranslationError>> {
    let mut ctes = query();
    // The rebuilt pipeline is created at the original pipeline's span.
    let mut rewritten = pipeline_builder().at(pipeline.ast.span.clone());

    for cmd in &pipeline.valid_ref()?.commands {
        rewritten = match &cmd.kind {
            TypedCommandKind::Join(join_cmd) => {
                // Validate alias and table before a name is drawn from the generator.
                let alias = join_cmd.right.alias.valid_ref()?;
                let source_table = join_cmd.right.ast.table.identifier.valid_ref()?.clone();

                // The statement is handed to the generator, presumably so the new
                // name cannot clash with an existing CTE (see `UniqueNameGenerator`).
                let cte_name = name_gen.next(statement);

                // FROM <table> | NEST <alias>
                let nested_source = builder::pipeline()
                    .from(|f| f.table_reference(source_table))
                    .nest(alias.clone())
                    .build();
                ctes = ctes.with(cte_name.clone(), nested_source);

                // A JOIN without an ON clause gets the literal `true` as its condition.
                let on_expr = match join_cmd.condition.as_ref() {
                    Some(c) => c.ast.as_ref().clone(),
                    None => builder::boolean(true).build(),
                };

                rewritten.command(join_table_reference(cte_name).on(on_expr).build())
            }
            TypedCommandKind::Lookup(lookup_cmd) => {
                // Validate alias and table before a name is drawn from the generator.
                let alias = lookup_cmd.right.alias.valid_ref()?;
                let source_table = lookup_cmd.right.ast.table.identifier.valid_ref()?.clone();

                // Same naming scheme as for JOIN.
                let cte_name = name_gen.next(statement);

                // FROM <table> | NEST <alias>
                let nested_source = builder::pipeline()
                    .from(|f| f.table_reference(source_table))
                    .nest(alias.clone())
                    .build();
                ctes = ctes.with(cte_name.clone(), nested_source);

                // A LOOKUP without an ON clause gets the literal `true` as its condition.
                let on_expr = match lookup_cmd.condition.as_ref() {
                    Some(c) => c.ast.as_ref().clone(),
                    None => builder::boolean(true).build(),
                };

                rewritten.command(lookup_table_reference(cte_name).on(on_expr).build())
            }
            // Everything else passes through unchanged.
            _ => rewritten.command(cmd.ast.clone()),
        };
    }

    Ok(ctes.main(rewritten.build()).build())
}

#[cfg(test)]
mod tests {
    //! Unit tests for the JOIN/LOOKUP hoisting pass.
    //!
    //! The tests build queries with the AST builders, type them against an
    //! in-memory [`MockProvider`], and run them through `lower_joins`.

    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::EnvironmentProvider,
        sql::{expression::identifier::Identifier as SqlIdentifier, query::TableReference},
        tree::{
            ast::{IntoTyped, TypeCheckExecutor},
            builder::{column_ref, eq, field, query, HasMain, QueryBuilder},
        },
        types::{struct_type::Struct, INT},
    };
    use std::sync::Arc;

    /// Provider that knows exactly two tables:
    /// `events(timestamp, user_id)` and `users(id, name)`, all columns `INT`.
    #[derive(Debug)]
    struct MockProvider;

    impl EnvironmentProvider for MockProvider {
        fn reflect_columns(&self, table: TableReference) -> anyhow::Result<Struct> {
            let mut fields = Struct::default();
            let events: SqlIdentifier = "events".parse().unwrap();
            let users: SqlIdentifier = "users".parse().unwrap();

            if table.name == events {
                fields.fields.insert("timestamp".parse().unwrap(), INT);
                fields.fields.insert("user_id".parse().unwrap(), INT);
                Ok(fields)
            } else if table.name == users {
                fields.fields.insert("id".parse().unwrap(), INT);
                fields.fields.insert("name".parse().unwrap(), INT);
                Ok(fields)
            } else {
                anyhow::bail!("Table not found: {}", table.name)
            }
        }

        fn reflect_datasets(&self) -> anyhow::Result<Vec<SqlIdentifier>> {
            Ok(vec![])
        }
    }

    /// Build the query and type-check it against [`MockProvider`].
    fn typed_query(builder: QueryBuilder<HasMain>) -> TypedStatement {
        builder
            .build()
            .typed_with()
            .with_provider(Arc::new(MockProvider))
            .typed()
    }

    /// Run `lower_joins` with a fresh translation context backed by the default
    /// function registry and [`MockProvider`].
    fn lower(statement: Arc<TypedStatement>) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
        let registry = Arc::new(FunctionRegistry::default());
        let provider = Arc::new(MockProvider);
        let mut ctx = StatementTranslationContext::new(registry, provider);
        lower_joins(statement, &mut ctx)
    }

    #[test]
    fn test_no_joins_passthrough() -> Result<(), Arc<TranslationError>> {
        // FROM events | WHERE timestamp == 10
        let q = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .where_cmd(eq(column_ref("timestamp"), 10)),
        );

        let statement = Arc::new(typed_query(q));

        // Should not have joins
        assert!(!statement_has_joins(&statement)?);

        // The pass must hand back the very same statement, not a rebuilt copy
        let lowered = lower(Arc::clone(&statement))?;
        assert!(Arc::ptr_eq(&statement, &lowered));
        Ok(())
    }

    #[test]
    fn test_join_generates_cte() -> Result<(), Arc<TranslationError>> {
        // FROM events | JOIN users ON user_id == users.id
        let q = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .join(
                    "users",
                    eq(column_ref("user_id"), field(column_ref("users"), "id")),
                ),
        );

        let statement = typed_query(q);

        // Should have join
        assert!(statement_has_joins(&statement)?);

        // Transform it
        let transformed = lower(Arc::new(statement))?;

        // Should now have a CTE
        assert_eq!(transformed.with_clauses.len(), 1);

        // CTE name should be __join_0
        let cte_name = transformed.with_clauses[0].name.valid_ref().unwrap();
        assert_eq!(cte_name.to_string(), "__join_0");
        Ok(())
    }

    #[test]
    fn test_lookup_generates_cte() -> Result<(), Arc<TranslationError>> {
        // FROM events | LOOKUP users ON user_id == users.id
        let q = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .lookup("users", |l| {
                    l.on(eq(column_ref("user_id"), field(column_ref("users"), "id")))
                }),
        );

        let statement = typed_query(q);

        // LOOKUP commands must be detected by the same check as JOIN commands
        assert!(statement_has_joins(&statement)?);

        // Transform it
        let transformed = lower(Arc::new(statement))?;

        // Should now have a CTE
        assert_eq!(transformed.with_clauses.len(), 1);

        // CTE name should be __join_0
        let cte_name = transformed.with_clauses[0].name.valid_ref().unwrap();
        assert_eq!(cte_name.to_string(), "__join_0");
        Ok(())
    }

    #[test]
    fn test_multiple_joins_generate_multiple_ctes() -> Result<(), Arc<TranslationError>> {
        // FROM events | JOIN users ON ... | JOIN users ON ...
        // (Yes, joining users twice is contrived but tests the CTE generation)
        let q = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .join(
                    "users",
                    eq(column_ref("user_id"), field(column_ref("users"), "id")),
                )
                .join(
                    "users",
                    eq(column_ref("user_id"), field(column_ref("users"), "id")),
                ),
        );

        let statement = typed_query(q);

        // Transform it
        let transformed = lower(Arc::new(statement))?;

        // Should now have 2 CTEs, named in generation order
        assert_eq!(transformed.with_clauses.len(), 2);

        let cte_name_0 = transformed.with_clauses[0].name.valid_ref().unwrap();
        let cte_name_1 = transformed.with_clauses[1].name.valid_ref().unwrap();
        assert_eq!(cte_name_0.to_string(), "__join_0");
        assert_eq!(cte_name_1.to_string(), "__join_1");
        Ok(())
    }
}