hamelin_translation 0.9.2

Lowering and IR for Hamelin query language
Documentation
//! Shared logic for lifting nested special-position function calls to top-level in a command
//! (used by WINDOW and AGG normalization passes).

use std::sync::Arc;

use hamelin_lib::func::def::ParameterBinding;
use hamelin_lib::tree::ast::expression::Expression;
use hamelin_lib::tree::ast::identifier::SimpleIdentifier;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::tree::typed_ast::expression::{
    MapExpressionAlgebra, TypedApply, TypedExpression, TypedExpressionKind,
};

use crate::unique::UniqueNameGenerator;

use hamelin_lib::tree::builder::{field_ref, ExpressionBuilder};

/// Returns `true` when `expr` is itself a call to a function that declares a special
/// position (aggregate, window, match, …), i.e. an `Apply` node whose function definition
/// reports `special_position().is_some()`.
///
/// Only the root node is inspected; arguments are not searched.
pub fn is_special_position_call(expr: &TypedExpression) -> bool {
    if let TypedExpressionKind::Apply(call) = &expr.kind {
        call.function_def.special_position().is_some()
    } else {
        false
    }
}

/// Reports whether `expr` contains a special-position call somewhere below the root.
///
/// A root that is itself a special-position call always yields `false`, whatever its
/// arguments contain, because such a call is already at top level. Otherwise the tree is
/// searched with `find` for any node satisfying [`is_special_position_call`].
pub fn has_nested_special_function(expr: &Arc<TypedExpression>) -> bool {
    // `&&` short-circuits, so the search only runs when the root is an ordinary expression.
    !is_special_position_call(expr) && expr.find(&mut is_special_position_call).is_some()
}

/// Expression algebra that hoists special-position calls out of an expression.
///
/// The algebra replaces each special-position call it encounters with a reference to a
/// freshly named synthetic column. It records the original call so the caller can compute
/// that column separately.
pub struct ExtractSpecialFunctionsAlgebra<'a> {
    /// Generator used to mint a new synthetic column name per extracted call.
    pub name_gen: &'a mut UniqueNameGenerator,
    /// Environment handed to the name generator when minting a name
    /// (presumably so generated names avoid existing columns — TODO confirm).
    pub schema: &'a TypeEnvironment,
    /// `(synthetic name, original call AST)` pairs, in the order calls were extracted.
    pub extractions: Vec<(SimpleIdentifier, Expression)>,
    /// The synthetic names alone, pushed alongside each entry of `extractions`.
    pub synth_ids: Vec<SimpleIdentifier>,
}

impl MapExpressionAlgebra for ExtractSpecialFunctionsAlgebra<'_> {
    /// Rewrites a single function-application node.
    ///
    /// Ordinary calls are rebuilt around the already-mapped `children`. A special-position
    /// call is handled differently:
    /// - it is swapped for a reference to a newly named synthetic column;
    /// - its original, unmodified AST is recorded in `extractions` under that name;
    /// - the name is also appended to `synth_ids`.
    ///
    /// NOTE(review): for special-position calls the mapped `children` are discarded. If a
    /// special call's arguments themselves contained special calls, those inner calls would
    /// get their own extraction entries, which the recorded outer AST does not reference.
    /// This assumes children are mapped before `apply` runs. Confirm such nesting is rejected
    /// elsewhere.
    fn apply(
        &mut self,
        node: &TypedApply,
        expr: &TypedExpression,
        children: ParameterBinding<Arc<Expression>>,
    ) -> Arc<Expression> {
        // Ordinary call: keep it and splice in the rewritten arguments.
        if node.function_def.special_position().is_none() {
            return node.replace_children_ast(expr, children);
        }

        let column = self.name_gen.next(self.schema);
        let hoisted: Expression = expr.ast.as_ref().clone();
        self.extractions.push((column.clone(), hoisted));
        self.synth_ids.push(column.clone());

        // The call site now simply reads the synthetic column.
        Arc::new(field_ref(column).build())
    }
}