hamelin_translation 0.7.10

Lowering and IR for Hamelin query language
Documentation
//! Lowering entry point for converting TypedStatement to IR.
//!
//! This module provides the public API for lowering Hamelin's type-checked AST
//! to the intermediate representation (IR).

use std::sync::Arc;

use hamelin_lib::err::{ContextualTranslationErrors, Stage};
use hamelin_lib::func::registry::FunctionRegistry;
use hamelin_lib::provider::{EnvironmentProvider, NoOpProvider};
use hamelin_lib::tree::ast::identifier::{CompoundIdentifier, Identifier, SimpleIdentifier};
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::query::TypedStatement;

use crate::ir::IRStatement;
use crate::normalize::normalize_statement;

/// Lower a TypedStatement to IR with default options.
///
/// This is the simplest entry point for lowering. For custom configuration,
/// use [`normalize_with()`] instead.
pub fn lower(statement: Arc<TypedStatement>) -> Result<IRStatement, ContextualTranslationErrors> {
    NormalizationOptions::default().lower(statement)
}

/// Create a NormalizationOptions builder for custom configuration.
pub fn normalize_with() -> NormalizationOptions {
    NormalizationOptions::default()
}

/// Builder-style configuration for normalizing a [`TypedStatement`] and lowering it to IR.
///
/// Bundles the function registry, environment provider, well-known field
/// names, and the flags that switch individual normalization stages on or off.
/// Obtain an instance pre-filled with defaults via [`normalize_with()`].
#[derive(Clone)]
pub struct NormalizationOptions {
    /// Function registry handed to the translation context.
    registry: Arc<FunctionRegistry>,
    /// Environment provider consulted for table schema lookups.
    provider: Arc<dyn EnvironmentProvider>,
    /// Field that WITHIN normalization treats as the timestamp.
    timestamp_field: Identifier,
    /// Field that PARSE reads when no source is specified.
    message_field: Identifier,
    /// Whether `transform()` calls are lowered (for backends without lambda support).
    lower_transform: bool,
    /// Whether statement-level normalization passes are skipped.
    skip_statement_passes: bool,
    /// Whether all pipeline-level normalization passes are skipped.
    skip_pipeline_passes: bool,
    /// Whether only the projection-fusion pass is skipped.
    skip_projection_fusion: bool,
}

impl Default for NormalizationOptions {
    fn default() -> Self {
        Self {
            registry: Arc::new(FunctionRegistry::default()),
            provider: Arc::new(NoOpProvider::default()),
            timestamp_field: SimpleIdentifier::new("timestamp").into(),
            message_field: CompoundIdentifier::new(
                SimpleIdentifier::new("event"),
                SimpleIdentifier::new("original"),
                vec![],
            )
            .into(),
            lower_transform: false,
            skip_statement_passes: false,
            skip_pipeline_passes: false,
            skip_projection_fusion: false,
        }
    }
}

impl NormalizationOptions {
    /// Set a custom function registry.
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn with_registry(mut self, registry: Arc<FunctionRegistry>) -> Self {
        self.registry = registry;
        self
    }

    /// Set a custom environment provider for table schema lookups.
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn with_provider(mut self, provider: Arc<dyn EnvironmentProvider>) -> Self {
        self.provider = provider;
        self
    }

    /// Set the timestamp field used by WITHIN normalization.
    ///
    /// Defaults to "timestamp".
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn with_timestamp_field(mut self, field: impl Into<Identifier>) -> Self {
        self.timestamp_field = field.into();
        self
    }

    /// Set the message field used by PARSE when no source is specified.
    ///
    /// Defaults to "event.original".
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn with_message_field(mut self, field: impl Into<Identifier>) -> Self {
        self.message_field = field.into();
        self
    }

    /// Enable transform() lowering for backends that don't support lambdas.
    ///
    /// When enabled, `transform(arr, x -> body)` is lowered to an EXPLODE + AGG pattern.
    /// This is needed for DataFusion which doesn't support lambda expressions.
    ///
    /// Defaults to false (transform() is passed through for backends like Trino that support it).
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn with_lower_transform(mut self) -> Self {
        self.lower_transform = true;
        self
    }

    /// Skip statement-level normalization passes (lower_match, nest_from_aliases, etc.).
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn skip_statement_passes(mut self) -> Self {
        self.skip_statement_passes = true;
        self
    }

    /// Skip all pipeline-level normalization passes (implies skip_projection_fusion).
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn skip_pipeline_passes(mut self) -> Self {
        self.skip_pipeline_passes = true;
        self
    }

    /// Skip just the fuse_projections pass.
    #[must_use = "builder methods consume `self` and return the updated options"]
    pub fn skip_projection_fusion(mut self) -> Self {
        self.skip_projection_fusion = true;
        self
    }

    /// Build a StatementTranslationContext from these options.
    ///
    /// The registry and provider `Arc`s are shared (cloned by reference count), and
    /// each enabled flag is forwarded to the corresponding context builder method.
    fn build_context(&self) -> StatementTranslationContext {
        let mut ctx =
            StatementTranslationContext::new(self.registry.clone(), self.provider.clone())
                .with_timestamp_field(self.timestamp_field.clone())
                .with_message_field(self.message_field.clone());

        if self.lower_transform {
            ctx = ctx.with_lower_transform();
        }

        if self.skip_statement_passes {
            ctx = ctx.with_skip_statement_passes();
        }

        if self.skip_pipeline_passes {
            ctx = ctx.with_skip_pipeline_passes();
        }

        if self.skip_projection_fusion {
            ctx = ctx.with_skip_projection_fusion();
        }

        ctx
    }

    /// Run normalization passes only (without converting to IR).
    ///
    /// Returns the normalized TypedStatement, suitable for printing via Display.
    ///
    /// # Errors
    ///
    /// Returns whatever errors `normalize_statement` reports.
    pub fn normalize(
        self,
        statement: Arc<TypedStatement>,
    ) -> Result<Arc<TypedStatement>, ContextualTranslationErrors> {
        let mut ctx = self.build_context();
        normalize_statement(statement, &mut ctx)
    }

    /// Lower a TypedStatement to IR.
    ///
    /// Runs all normalization passes and converts to IR.
    ///
    /// # Errors
    ///
    /// - If normalization fails, its errors are returned unchanged.
    /// - If IR conversion fails, the error is tagged with [`Stage::Lowering`]. It is
    ///   returned alongside the rendered text of the normalized statement.
    pub fn lower(
        self,
        statement: Arc<TypedStatement>,
    ) -> Result<IRStatement, ContextualTranslationErrors> {
        let mut ctx = self.build_context();

        let normalized = normalize_statement(statement, &mut ctx)?;
        // Keep a cheap shared handle so the normalized query text is rendered
        // only when conversion actually fails, not on every successful lowering.
        let source = Arc::clone(&normalized);

        IRStatement::from_typed(normalized, &mut ctx).map_err(|e| {
            ContextualTranslationErrors::new(
                source.ast.to_string(),
                (*e).clone().with_stage(Stage::Lowering).single(),
            )
        })
    }
}