hamelin_translation 0.3.10

Lowering and IR for Hamelin query language
Documentation
//! Struct expansion helpers.
//!
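//! Provides functions for expanding struct expressions to match a target struct type.
//! Used by `expand_array_literals` for array element widening, and by
//! `expand_union_schemas` (via `build_widening_expression`) for widening
//! individual columns to a target schema.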
//!
//! Expansion transforms a struct expression to match a target struct type by:
//! 1. Adding missing fields as typed NULLs
//! 2. Reordering fields to match target type's order
//! 3. Recursively expanding nested struct fields
//!
//! This module only builds untyped AST expressions using builders; it does not
//! type-check them. Types are re-established when the calling pass re-typechecks
//! the transformed pipeline/statement.
//!
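//! When a complex expression would otherwise appear in several field accesses
//! (and so be evaluated more than once), it is hoisted instead: LET commands that
//! bind it to a fresh column, and DROP commands that remove that column afterwards,
//! are returned alongside the expanded expression.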

use std::rc::Rc;

use hamelin_lib::tree::{
    ast::{command::Command, expression::Expression},
    builder::{
        cast, column_ref, drop_command, field, let_command, null, struct_literal, ExpressionBuilder,
    },
    typed_ast::expression::{TypedExpression, TypedExpressionKind, TypedStructLiteral},
};
use hamelin_lib::types::{struct_type::Struct, Type};

use super::unique::UniqueNameGenerator;

// ---------------------------------------------------------------------------
// Helper Functions
// ---------------------------------------------------------------------------

/// Report whether `expr` is a plain column reference, optionally followed by a
/// chain of field lookups (`col`, `col.field`, `col.field.nested`, ...).
///
/// Such expressions are pure lookups rather than computations, so callers may
/// repeat them freely without paying for duplicate evaluation.
pub fn is_column_reference_chain(expr: &TypedExpression) -> bool {
    let mut current = expr;
    loop {
        match &current.kind {
            // Reached the root of the chain: a bare column.
            TypedExpressionKind::ColumnReference(_) => return true,
            // Step one level inward and keep walking.
            TypedExpressionKind::FieldLookup(lookup) => current = &lookup.value,
            // Anything else (calls, literals, ...) is a computation.
            _ => return false,
        }
    }
}

/// Count the field accesses needed to build the expanded struct.
///
/// Every field of `target` that also exists in `source` contributes one access.
/// When both sides of such a field are structs, the accesses into that nested
/// struct are counted as well, recursively. Fields missing from `source`
/// contribute nothing, because they are filled with typed NULLs.
///
/// The expansion logic uses this count to decide whether a complex source
/// expression must be hoisted.
pub fn count_needed_field_accesses(source: &Struct, target: &Struct) -> usize {
    target
        .fields
        .iter()
        .filter_map(|(name, target_ty)| {
            source
                .fields
                .get(name)
                .map(|source_ty| (source_ty, target_ty))
        })
        .map(|types| match types {
            // Nested struct: the access to the field itself plus its inner accesses.
            (Type::Struct(inner_src), Type::Struct(inner_dst)) => {
                1 + count_needed_field_accesses(inner_src, inner_dst)
            }
            _ => 1,
        })
        .sum()
}

// ---------------------------------------------------------------------------
// Core Expansion Functions
// ---------------------------------------------------------------------------

/// Expand a typed struct expression to match a target struct type.
///
/// Returns `(expanded_expression, before_commands, after_commands)` where:
/// - `expanded_expression` is the AST expression built with builders
/// - `before_commands` are LET commands that must be inserted before the command
///   that uses this expression (to avoid duplicate evaluation of complex expressions)
/// - `after_commands` are DROP commands that must be inserted after the command
///   to clean up the hoisted variables
///
/// When `source_type` already equals `target_type`, the original AST is returned
/// unchanged. Otherwise the result is a struct literal whose fields follow
/// `target_type`'s order, with fields absent from `source_type` filled by typed
/// NULLs. The strategy depends on the kind of `expr`:
/// - struct literals are rebuilt field by field, recursing into nested literals;
/// - column reference chains are duplicated into each field access;
/// - other expressions are hoisted into a LET when they would otherwise be
///   referenced more than once, and inlined otherwise.
///
/// The caller is responsible for:
/// - Ensuring `expr` has a struct type matching `source_type`
/// - Re-typechecking after all transformations are complete
pub fn expand_struct_to_type(
    expr: &Rc<TypedExpression>,
    source_type: &Struct,
    target_type: &Struct,
    name_gen: &mut UniqueNameGenerator,
) -> (Rc<Expression>, Vec<Command>, Vec<Command>) {
    // No expansion needed: return the original AST (cheap Rc clone).
    if source_type == target_type {
        return (expr.ast.clone(), Vec::new(), Vec::new());
    }

    match &expr.kind {
        // Struct literal: rebuild it in place, recursing into nested literals.
        TypedExpressionKind::StructLiteral(lit) => {
            let (expanded, before, after) = expand_struct_literal(lit, target_type, name_gen);
            (Rc::new(expanded), before, after)
        }

        // A complex expression that would appear in several field accesses is
        // bound to a hoisted column once, so it is evaluated only once. Column
        // reference chains are excluded because they are plain lookups and can
        // be duplicated freely.
        _ if !is_column_reference_chain(expr)
            && count_needed_field_accesses(source_type, target_type) > 1 =>
        {
            let hoisted_name = name_gen.next();
            let let_cmd = let_command()
                .named_field(hoisted_name.clone(), expr.ast.clone())
                .build();
            let drop_cmd = drop_command().field(hoisted_name.clone()).build();
            let expanded =
                build_expanded_struct_from_column(hoisted_name.as_str(), source_type, target_type);
            (Rc::new(expanded), vec![let_cmd], vec![drop_cmd])
        }

        // A column reference chain, or a complex expression needing at most one
        // field access, can be inlined directly. With zero shared fields this
        // still yields a struct literal of typed NULLs (and never references
        // `expr`). That matches the struct-literal and column-reference paths,
        // instead of collapsing to a NULL struct.
        _ => (
            Rc::new(build_expanded_struct(expr, source_type, target_type)),
            Vec::new(),
            Vec::new(),
        ),
    }
}

/// Rebuild a struct literal so that its fields match `target_type`.
///
/// The rebuilt literal lists its fields in `target_type`'s order:
/// - a field present in the literal keeps its original AST, unless both the
///   target field type and the field expression's resolved type are structs, in
///   which case the field is expanded recursively via `expand_struct_to_type`;
/// - a field missing from the literal becomes a NULL cast to the target field type.
///
/// Any LET/DROP commands produced by the nested expansions are collected and
/// returned, in order, alongside the rebuilt literal.
fn expand_struct_literal(
    lit: &TypedStructLiteral,
    target_type: &Struct,
    name_gen: &mut UniqueNameGenerator,
) -> (Expression, Vec<Command>, Vec<Command>) {
    let mut before_cmds = Vec::new();
    let mut after_cmds = Vec::new();
    let mut rebuilt = struct_literal();

    for (target_name, target_field_type) in target_type.fields.iter() {
        let wanted = target_name.name.as_str();
        // Look up the literal's expression for this field name, if any.
        let matching_expr = lit
            .fields
            .iter()
            .find(|(ident, _)| {
                ident
                    .valid_ref()
                    .map(|s| s.as_str() == wanted)
                    .unwrap_or(false)
            })
            .map(|(_, value)| value);

        rebuilt = match matching_expr {
            // Not supplied by the literal: fill with a typed NULL.
            None => rebuilt.field(target_name.clone(), cast(null(), target_field_type.clone())),
            Some(value) => match (target_field_type, value.resolved_type.as_ref()) {
                // Struct-typed on both sides: expand the nested value recursively.
                (Type::Struct(nested_target), Type::Struct(nested_source)) => {
                    let (expanded, before, after) =
                        expand_struct_to_type(value, nested_source, nested_target, name_gen);
                    before_cmds.extend(before);
                    after_cmds.extend(after);
                    rebuilt.field(target_name.clone(), expanded)
                }
                // Otherwise keep the original expression untouched.
                _ => rebuilt.field(target_name.clone(), value.ast.clone()),
            },
        };
    }

    (rebuilt.build(), before_cmds, after_cmds)
}

/// Build a struct literal shaped like `target_type` by reading fields from `source_expr`.
///
/// For each target field, in target order:
/// - if `source_type` has it, emit `source_expr.<field>` on the original AST,
///   expanding it further when both sides are structs of different types;
/// - otherwise emit a NULL cast to the target field's type.
///
/// The source AST is repeated in every field access, so this is only appropriate
/// when that duplication is harmless (e.g. a column reference chain).
fn build_expanded_struct(
    source_expr: &Rc<TypedExpression>,
    source_type: &Struct,
    target_type: &Struct,
) -> Expression {
    target_type
        .fields
        .iter()
        .fold(struct_literal(), |acc, (name, wanted_type)| {
            match source_type.fields.get(name) {
                // Absent from the source: typed NULL placeholder.
                None => acc.field(name.clone(), cast(null(), wanted_type.clone())),
                Some(available_type) => {
                    let access = field(source_expr.ast.clone(), name.name.as_str());
                    match (available_type, wanted_type) {
                        (Type::Struct(inner_src), Type::Struct(inner_dst))
                            if inner_src != inner_dst =>
                        {
                            let widened =
                                build_expanded_struct_fields(access.build(), inner_src, inner_dst);
                            acc.field(name.clone(), widened)
                        }
                        _ => acc.field(name.clone(), access),
                    }
                }
            }
        })
        .build()
}

/// Build a struct literal shaped like `target_type` from the column `column_name`.
///
/// This is the counterpart of the typed-expression builder for expressions that
/// were hoisted into a LET: every field access reads from a fresh reference to
/// the hoisted column. Fields missing from `source_type` become NULLs cast to the
/// target field type. Nested struct fields whose types differ are expanded
/// recursively.
fn build_expanded_struct_from_column(
    column_name: &str,
    source_type: &Struct,
    target_type: &Struct,
) -> Expression {
    let mut literal = struct_literal();

    for (name, wanted_type) in target_type.fields.iter() {
        literal = match source_type.fields.get(name) {
            None => literal.field(name.clone(), cast(null(), wanted_type.clone())),
            Some(available_type) => {
                let access = field(column_ref(column_name), name.name.as_str());
                match (available_type, wanted_type) {
                    // Differing nested structs need their own field-by-field widening.
                    (Type::Struct(inner_src), Type::Struct(inner_dst))
                        if inner_src != inner_dst =>
                    {
                        let widened =
                            build_expanded_struct_fields(access.build(), inner_src, inner_dst);
                        literal.field(name.clone(), widened)
                    }
                    _ => literal.field(name.clone(), access),
                }
            }
        };
    }

    literal.build()
}

/// Build an expression that widens one column from its source type to a target type.
///
/// Used by `expand_union_schemas` to build SELECT projections that widen
/// individual fields from the source schema to the target schema.
///
/// - Column absent from the source (`source_type` is `None`): a NULL cast to `target_type`.
/// - Both types are structs and differ: a struct literal rebuilt field by field
///   from `column_ref(field_name)`.
/// - Otherwise: the plain `column_ref(field_name)`.
pub fn build_widening_expression(
    field_name: &str,
    source_type: Option<&Type>,
    target_type: &Type,
) -> Expression {
    match (source_type, target_type) {
        (None, _) => cast(null(), target_type.clone()).build(),
        (Some(Type::Struct(src)), Type::Struct(dst)) if src != dst => {
            build_expanded_struct_fields(column_ref(field_name).build(), src, dst)
        }
        (Some(_), _) => column_ref(field_name).build(),
    }
}

/// Build a struct literal shaped like `target_type` from an untyped AST expression.
///
/// This serves nested expansion, where only the AST of the enclosing field
/// access is available rather than a `TypedExpression`. Each target field is
/// either `source_expr.<field>` (expanded recursively when both sides are
/// differing structs) or a NULL cast to the target field type when the source
/// lacks it.
fn build_expanded_struct_fields(
    source_expr: Expression,
    source_type: &Struct,
    target_type: &Struct,
) -> Expression {
    target_type
        .fields
        .iter()
        .fold(struct_literal(), |acc, (name, wanted_type)| {
            let Some(available_type) = source_type.fields.get(name) else {
                return acc.field(name.clone(), cast(null(), wanted_type.clone()));
            };
            let access = field(source_expr.clone(), name.name.as_str());
            match (available_type, wanted_type) {
                (Type::Struct(inner_src), Type::Struct(inner_dst)) if inner_src != inner_dst => {
                    let widened =
                        build_expanded_struct_fields(access.build(), inner_src, inner_dst);
                    acc.field(name.clone(), widened)
                }
                _ => acc.field(name.clone(), access),
            }
        })
        .build()
}

#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::sql::expression::identifier::SimpleIdentifier;
    use hamelin_lib::types::INT;

    /// Build a struct type whose fields appear in exactly the given order.
    fn make_struct(fields: &[(&str, Type)]) -> Struct {
        let entries = fields
            .iter()
            .map(|(name, typ)| (SimpleIdentifier::new(name), typ.clone()));
        Struct::new(entries)
    }

    /// Build a struct type with the given field names, every field typed `INT`.
    fn int_struct(names: &[&str]) -> Struct {
        let fields: Vec<(&str, Type)> = names.iter().map(|&name| (name, INT)).collect();
        make_struct(&fields)
    }

    #[test]
    fn test_types_match_exactly_same_order() {
        // Same names, same types, same order: equal.
        assert_eq!(int_struct(&["a", "b"]), int_struct(&["a", "b"]));
    }

    #[test]
    fn test_types_match_exactly_different_order() {
        // Field order participates in struct equality.
        assert_ne!(int_struct(&["a", "b"]), int_struct(&["b", "a"]));
    }

    #[test]
    fn test_types_match_exactly_missing_field() {
        assert_ne!(int_struct(&["a"]), int_struct(&["a", "b"]));
    }

    #[test]
    fn test_count_needed_field_accesses() {
        let source = int_struct(&["a", "b"]);
        let target = int_struct(&["a", "b", "c"]);
        // `a` and `b` are copied from the source; `c` becomes a typed NULL.
        assert_eq!(count_needed_field_accesses(&source, &target), 2);
    }

    #[test]
    fn test_count_nested_field_accesses() {
        let source = make_struct(&[("nested", Type::Struct(int_struct(&["x"])))]);
        let target = make_struct(&[("nested", Type::Struct(int_struct(&["x", "y"])))]);
        // One access for `nested` itself plus one for `nested.x`.
        assert_eq!(count_needed_field_accesses(&source, &target), 2);
    }
}