hamelin_translation 0.9.6

Lowering and IR for Hamelin query language
Documentation
//! Normalization passes for typed AST.
//!
//! Normalization transforms typed AST into a more uniform form for translation to SQL.
//! Each pass transforms AST and re-typechecks to maintain type safety.
//!
//! ## Pass Types
//!
//! **Statement normalizers** (`statement_normalizers/`):
//! Transform a full `TypedStatement`, may generate new CTEs.
//! - `lower_match` - Lowers MATCH to FROM + SET + WHERE + WINDOW + WHERE + DROP (must run first)
//! - `nest_from_aliases` - Converts aliased FROM to CTEs with NEST for alias nesting
//! - `from_to_union` - Converts multi-source FROM to UNION
//! - `expand_union_schemas` - Generates CTEs for UNION with differing schemas
//!
//! **Pipeline normalizers** (`pipeline_normalizers/`):
//! Transform a single pipeline without generating CTEs.
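//! - `lower_broadcast_apply` - BroadcastApply → transform(array, lambda)
//! - `normalize_within` - WITHIN → WHERE with explicit timestamp bounds
//! - `lower_distinct` - DISTINCT → AGG (group-only)
//! - `normalize_agg` - AGG compound identifiers → flat AGG + SET/DROP
//! - `extract_agg_aggregates` - AGG nested aggregates → flat AGG + SET/DROP
//! - `normalize_window` - WINDOW compound identifiers → flat WINDOW + SET/DROP
//! - `extract_window_aggregates` - Runs right after `normalize_window` (presumably the WINDOW
//!   counterpart of `extract_agg_aggregates`)
//! - `normalize_explode` - EXPLODE compound identifiers → flat EXPLODE + SET/DROP
//! - `lower_unnest` - UNNEST → EXPLODE (if array) + SET + DROP
//! - `lower_parse` - PARSE → SET + WHERE (regex extraction + filter)
//! - `lower_nest` - NEST → SELECT with compound identifiers (struct packing)
//! - `expand_array_literals` - Expands array literal elements to match element type
//! - `desugar_in_array_literals` - Rewrites `x IN [1,2,3]` to `x IN (1,2,3)` when all elements are literals
//! - `desugar_map_from_arrays` - Rewrites `map([k1,k2,...], [v1,v2,...])` to `map(k1:v1, k2:v2, ...)`
//!   when keys are literals
//! - `lower_ts_trunc_interval` - TsTrunc(interval) → TsTrunc(now() + interval)
//! - `dedup_append_source` - Deduplicates source rows by DISTINCT keys before APPEND
//! - `align_append_schema` - Inserts SELECT before APPEND to align pipeline to target table schema
//! - `lower_filter` - filter(arr, lambda) → flatten(transform(arr, lambda)) (conditional, for
//!   backends without a native array filter function)
//! - `recognize_transform_values` - `map(map_keys(m), transform(map_values(m), λ))` →
//!   `transform_values(m, λ)`
//! - `rewrite_transform_values` - transform_values(m, λ) → vectorized map operator or the generic
//!   map/transform form (conditional, gated together with `lower_transform`)
//! - `lower_transform` - transform(arr, lambda) → EXPLODE + AGG (conditional, for DataFusion)
//! - `fuse_projections` - Fuses SET/DROP/SELECT into minimal SELECT commands (must be last)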
//!
//! ## Helpers
//!
//! - `compound_lowering` - Shared logic for AGG/WINDOW compound identifier lowering
//! - `expand_struct` - Struct widening for array literals and FROM schema expansion
//! - `unique` - Unique name generation for synthesized identifiers
//!
//! ## Pass Order
//!
//! 1. Statement normalizers (can generate CTEs)
//! 2. Pipeline normalizers (`fuse_projections` must be last to catch all SET/DROP)
//!
//! ## JOIN/LOOKUP Lowering
//!
//! JOIN/LOOKUP right-side hoisting is done during IR lowering (`ir.rs`), not
//! during normalization. This avoids double-nesting that occurred when the
//! normalization pass emitted NEST and then re-typechecked the JOIN.

mod compound_lowering;
mod expand_struct;
mod pipeline_normalizers;
mod special_function_extraction;
mod statement_normalizers;

use std::sync::Arc;

use hamelin_lib::err::{ContextualTranslationErrors, Stage, TranslationError, TranslationErrors};
use hamelin_lib::parse_with_options;
use hamelin_lib::tree::ast::query::DefBody;
use hamelin_lib::tree::ast::query::Query;
use hamelin_lib::tree::builder::query;
use hamelin_lib::tree::options::{ParseOptions, TypeCheckOptions};
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;
use hamelin_lib::tree::typed_ast::query::TypedStatement;
use hamelin_lib::type_check_with_options;

use pipeline_normalizers::{
    dedup_append_source, desugar_in_array_literals, desugar_map_from_arrays, expand_array_literals,
    extract_agg_aggregates, extract_window_aggregates, fuse_projections, lower_broadcast_apply,
    lower_distinct, lower_nest, lower_parse, lower_transform, lower_ts_trunc_interval,
    lower_unnest, normalize_agg, normalize_explode, normalize_window, normalize_within,
    recognize_transform_values, rewrite_transform_values,
};
use statement_normalizers::{expand_union_schemas, from_to_union, lower_match, nest_from_aliases};

use pipeline_normalizers::align_append_schema;

use crate::normalize::pipeline_normalizers::lower_filter;

/// Normalize a full statement through all passes.
///
/// This is the top-level entry point for normalization. It applies passes in order:
/// 1. Statement normalizers: `lower_match`, `nest_from_aliases`, `from_to_union`, `expand_union_schemas`
/// 2. Pipeline normalizers (for each CTE + main pipeline)
pub fn normalize_statement(
    statement: Arc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedStatement>, ContextualTranslationErrors> {
    // Statement-level passes (can generate CTEs)
    let statement = if ctx.skip_statement_passes {
        statement
    } else {
        // lower_match must run first since it generates FROM with aliases
        let statement = with_statement_src(statement, |s| lower_match(s, ctx))?;
        let statement = with_statement_src(statement, |s| nest_from_aliases(s, ctx))?;
        let statement = with_statement_src(statement, |s| from_to_union(s, ctx))?;
        with_statement_src(statement, |s| expand_union_schemas(s, ctx))?
    };

    // Apply pipeline passes to each CTE + main pipeline. Rebuild DEFs in source order:
    // scalar and tabular DEFs may be interleaved in the original AST.
    let mut builder = query();

    match statement.ast.valid_ref() {
        Ok(valid) => {
            let mut scalars = statement.scalar_defs.iter();
            let mut pipelines = statement.pipeline_defs.iter();

            for def in &valid.defs {
                match &def.body {
                    DefBody::Expression(_) => {
                        let sd = scalars
                            .next()
                            .ok_or_else(|| {
                                Arc::new(TranslationError::msg(
                                    statement.ast.as_ref(),
                                    "internal: scalar_defs shorter than expression defs in AST",
                                ))
                            })
                            .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                        let name = sd
                            .name
                            .valid_ref()
                            .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                        builder = builder.def_expression(name.clone(), sd.expression.ast.clone());
                    }
                    DefBody::Pipeline(_) => {
                        let pd = pipelines
                            .next()
                            .ok_or_else(|| {
                                Arc::new(TranslationError::msg(
                                    statement.ast.as_ref(),
                                    "internal: pipeline_defs shorter than pipeline defs in AST",
                                ))
                            })
                            .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                        let name = pd
                            .name
                            .valid_ref()
                            .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                        let normalized = normalize_pipeline(pd.pipeline.clone(), ctx)
                            .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                        ctx.register_cte(name.clone(), normalized.environment());
                        builder = builder.def_pipeline(name.clone(), normalized);
                    }
                }
            }

            if scalars.next().is_some() || pipelines.next().is_some() {
                return Err(contextual_from_arc(
                    statement.ast.to_string(),
                    Arc::new(TranslationError::msg(
                        statement.ast.as_ref(),
                        "internal: leftover DEF entries after rebuilding query",
                    )),
                ));
            }
        }
        Err(_) => {
            for sd in &statement.scalar_defs {
                let name = sd
                    .name
                    .valid_ref()
                    .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                builder = builder.def_expression(name.clone(), sd.expression.ast.clone());
            }
            for pd in &statement.pipeline_defs {
                let name = pd
                    .name
                    .valid_ref()
                    .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                let normalized = normalize_pipeline(pd.pipeline.clone(), ctx)
                    .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                ctx.register_cte(name.clone(), normalized.environment());
                builder = builder.def_pipeline(name.clone(), normalized);
            }
        }
    }

    let normalized_main = normalize_pipeline(statement.pipeline.clone(), ctx)
        .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
    let built = builder.main(normalized_main).build();

    let opts = TypeCheckOptions::builder()
        .registry(ctx.registry.clone())
        .provider(ctx.provider.clone())
        .interner(ctx.interner.clone())
        .maybe_timestamp_field(Some(ctx.timestamp_field.clone()))
        .maybe_message_field(Some(ctx.message_field.clone()))
        .build();
    let built_arc = Arc::new(built);
    let wte = type_check_with_options::<Query>(built_arc.clone(), opts.clone());

    if wte.errors.is_empty() {
        Ok(Arc::new(wte.output))
    } else {
        Err(normalize_typecheck_failure((*built_arc).clone(), wte.errors, opts.clone()).take(1))
    }
}

fn contextual_from_arc(
    hamelin: impl Into<String>,
    err: Arc<TranslationError>,
) -> ContextualTranslationErrors {
    ContextualTranslationErrors::new(
        hamelin.into(),
        (*err).clone().with_stage(Stage::Normalization).single(),
    )
}

/// Run a statement-level pass `f` and attach the statement's source text to any failure.
///
/// The source is rendered from `statement.ast` before `f` runs, because `f` takes ownership
/// of the statement.
fn with_statement_src(
    statement: Arc<TypedStatement>,
    f: impl FnOnce(Arc<TypedStatement>) -> Result<Arc<TypedStatement>, Arc<TranslationError>>,
) -> Result<Arc<TypedStatement>, ContextualTranslationErrors> {
    let printed = statement.ast.to_string();
    match f(statement) {
        Ok(transformed) => Ok(transformed),
        Err(failure) => Err(contextual_from_arc(printed, failure)),
    }
}

fn with_normalization_stage(errors: TranslationErrors) -> TranslationErrors {
    TranslationErrors(
        errors
            .0
            .into_iter()
            .map(|e| e.with_stage(Stage::Normalization))
            .collect(),
    )
}

/// Build the error report for a failed post-normalization type check.
///
/// The normalized `Query` is printed, re-parsed, and type-checked a second time so that the
/// reported spans match `synthetic_source`, the printed text used as the report's source.
///
/// Fallbacks, checked in this order:
/// - `initial_errors` is empty: a single generic error is reported.
/// - the printed query does not re-parse: the initial errors are reported, with a warning on
///   the first one that spans may be inaccurate.
/// - the re-parsed query type-checks without errors: the initial errors are reported, with a
///   note on the first one that its spans come from the built AST.
///
/// Every reported error is tagged with `Stage::Normalization`.
fn normalize_typecheck_failure(
    built: Query,
    initial_errors: TranslationErrors,
    opts: TypeCheckOptions,
) -> ContextualTranslationErrors {
    // Attach `note` to the first error (if any), then tag all errors with the stage.
    fn annotated(errors: TranslationErrors, note: &'static str) -> TranslationErrors {
        let mut errors = errors;
        if let Some(head) = errors.0.first_mut() {
            head.add_context(0..=0, note);
        }
        with_normalization_stage(errors)
    }

    let synthetic_source = built.to_string();

    if initial_errors.0.is_empty() {
        let generic = TranslationError::msg(
            &built,
            "type check failed after normalization (no error details)",
        )
        .with_stage(Stage::Normalization)
        .single();
        return ContextualTranslationErrors::new(synthetic_source, generic);
    }

    // Re-parse with the same interner as the type checker when one is configured.
    let parse_opts = if let Some(interner) = &opts.interner {
        ParseOptions::builder().interner(interner.clone()).build()
    } else {
        ParseOptions::default()
    };
    let reparsed = parse_with_options(&synthetic_source, parse_opts);

    let reported = if !reparsed.errors.is_empty() {
        annotated(
            initial_errors,
            "could not re-parse printed normalized query; spans may be inaccurate",
        )
    } else {
        let retry = type_check_with_options::<Query>(Arc::new(reparsed.output), opts);
        if retry.errors.is_empty() {
            // The failure did not reproduce on the printed query, so only the original
            // errors are available.
            annotated(
                initial_errors,
                "error spans are from the first type check on the built AST, not this printed query; highlighting may be inaccurate",
            )
        } else {
            with_normalization_stage(retry.errors)
        }
    };

    ContextualTranslationErrors::new(synthetic_source, reported)
}

/// Apply the pipeline-level normalization passes to one pipeline, in a fixed order.
///
/// Pipeline-level pass contract: `Arc<TypedPipeline> -> Result<Arc<TypedPipeline>, ...>`,
/// so the output of each pass feeds the next one.
///
/// Flags on `ctx` control what runs:
/// - `skip_pipeline_passes`: return the input untouched.
/// - `lower_filter`: run `lower_filter`.
/// - `lower_transform`: run `rewrite_transform_values`, `lower_transform`, and a second
///   `normalize_explode`.
/// - `skip_projection_fusion`: return without running `fuse_projections`.
///
/// # Errors
///
/// Returns the error of the first pass that fails; later passes are not run.
pub fn normalize_pipeline(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    if ctx.skip_pipeline_passes {
        return Ok(pipeline);
    }

    // One binding threaded through every pass; each pass consumes it and hands back the
    // rewritten pipeline.
    let mut pipeline = pipeline;

    pipeline = lower_broadcast_apply(pipeline, ctx)?;
    pipeline = normalize_within(pipeline, ctx)?;
    pipeline = lower_distinct(pipeline, ctx)?;
    pipeline = normalize_agg(pipeline, ctx)?;
    pipeline = extract_agg_aggregates(pipeline, ctx)?;
    pipeline = normalize_window(pipeline, ctx)?;
    pipeline = extract_window_aggregates(pipeline, ctx)?;
    pipeline = normalize_explode(pipeline, ctx)?;
    pipeline = lower_unnest(pipeline, ctx)?;
    pipeline = lower_parse(pipeline, ctx)?;
    pipeline = lower_nest(pipeline, ctx)?;

    // Struct widening in expand_array_literals can emit transform() calls, which is why it
    // is scheduled ahead of lower_transform.
    pipeline = expand_array_literals(pipeline, ctx)?;
    // IN / NOT IN over an array literal becomes the tuple form for better SQL optimization.
    pipeline = desugar_in_array_literals(pipeline, ctx)?;
    // map([k1,k2,...], [v1,v2,...]) becomes map(k1:v1, k2:v2, ...) when keys are literals;
    // this sidesteps a DataFusion scalar/columnar mismatch in the built-in map() UDF.
    pipeline = desugar_map_from_arrays(pipeline, ctx)?;
    // TsTrunc(interval) → TsTrunc(now() + interval), so backends always see timestamps.
    pipeline = lower_ts_trunc_interval(pipeline, ctx)?;
    // Source rows are deduplicated by DISTINCT keys before APPEND. This has to precede
    // align_append_schema, which also inserts commands before APPEND.
    pipeline = dedup_append_source(pipeline, ctx)?;
    // align_append_schema may also emit transform() calls for struct widening, so it too
    // is scheduled ahead of lower_transform.
    pipeline = align_append_schema(pipeline, ctx)?;

    // Only for backends without a native array filter function. lower_filter turns
    // filter(arr, x -> pred) into
    // flatten(transform(arr, x -> if(pred, [x], cast([] AS array(T))))),
    // hence it has to come before lower_transform.
    if ctx.lower_filter {
        pipeline = lower_filter(pipeline, ctx)?;
    }

    // `map(map_keys(m), transform(map_values(m), λ))` is recognized as
    // `transform_values(m, λ)`. The match is narrow (same simple column on both sides) and
    // yields only the internal `transform_values` operator, which any backend can consume.
    // Backends with native `transform_values` support translate it directly; the
    // DataFusion-specific lowering is `rewrite_transform_values`, run further down.
    pipeline = recognize_transform_values(pipeline, ctx)?;

    // Only for backends that don't support lambdas. This group must come AFTER
    // expand_array_literals and align_append_schema (both can generate transform calls)
    // and BEFORE fuse_projections.
    //
    // `rewrite_transform_values` shares the gate with `lower_transform` because its fast
    // paths emit DataFusion-specific UDFs (`map_variant_to_json_values`,
    // `map_variant_get_values`). Backends with native lambdas (e.g. Trino) should keep
    // `transform_values` intact for their own translator.
    if ctx.lower_transform {
        // Each `transform_values(m, λ)` becomes either a direct vectorized map operator
        // (fast path) or the generic `map(map_keys(m), transform(map_values(m), λ))` form,
        // whose residual `transform()` is lowered by the next pass.
        pipeline = rewrite_transform_values(pipeline, ctx)?;
        pipeline = lower_transform(pipeline, ctx)?;
        // lower_transform generates EXPLODE commands, so normalize_explode runs again.
        pipeline = normalize_explode(pipeline, ctx)?;
    }

    // fuse_projections stays last so it sees every SET/DROP produced above.
    if !ctx.skip_projection_fusion {
        pipeline = fuse_projections(pipeline, ctx)?;
    }

    Ok(pipeline)
}