hamelin_sql 0.7.1

SQL generation utilities for Hamelin query language
Documentation
//! Translation implementations for aggregate functions

use crate::utils::direct_function_translation;
use crate::TranslationRegistry;
use hamelin_lib::{
    func::defs::{
        AggAll, AggAny, AggMax, AggMin, AnyValue, ApproxDistinct, ApproxPercentile, ArrayAgg, Avg,
        CountAny, CountDistinct, CountIf, CountStar, MapAgg, MultimapAgg, Stddev, Sum,
    },
    sql::expression::{apply::FunctionCallApply, Cast},
    sql::types::SQLBaseType,
    types::Type,
};

/// Register all aggregate function translations.
pub fn register(registry: &mut TranslationRegistry) {
    // count() - pass through
    registry.register::<CountStar>(direct_function_translation);

    // count(x) - pass through
    registry.register::<CountAny>(direct_function_translation);

    // count_distinct(x) - COUNT(DISTINCT x)
    registry.register::<CountDistinct>(|_, mut bindings| {
        let what = bindings.take()?;
        Ok(FunctionCallApply::with_one("count", what.sql)
            .with_distinct()
            .into())
    });

    // approx_distinct(x) - pass through
    registry.register::<ApproxDistinct>(direct_function_translation);

    // count_if(condition) - pass through
    registry.register::<CountIf>(direct_function_translation);

    // sum(x) - pass through
    registry.register::<Sum>(direct_function_translation);

    // avg(x) - cast integer input to double so all backends return double
    // (Trino does integer division for avg(int), DataFusion returns double)
    registry.register::<Avg>(|_, mut bindings| {
        let x = bindings.take()?;
        let arg = if x.typ == Type::Int {
            Cast::new(x.sql, SQLBaseType::Double.into()).into()
        } else {
            x.sql
        };
        Ok(FunctionCallApply::with_one("avg", arg).into())
    });

    // stddev(x) - pass through
    registry.register::<Stddev>(direct_function_translation);

    // approx_percentile(x, percentile) - pass through
    registry.register::<ApproxPercentile>(direct_function_translation);

    // min(x) - pass through
    registry.register::<AggMin>(direct_function_translation);

    // max(x) - pass through
    registry.register::<AggMax>(direct_function_translation);

    // any_value(x) - pass through
    registry.register::<AnyValue>(direct_function_translation);

    // array_agg(x) - pass through
    registry.register::<ArrayAgg>(direct_function_translation);

    // map_agg(key, value) - pass through
    registry.register::<MapAgg>(direct_function_translation);

    // multimap_agg(key, value) - pass through
    registry.register::<MultimapAgg>(direct_function_translation);

    // any(x) - translate to bool_or()
    registry.register::<AggAny>(|_, mut bindings| {
        let x = bindings.take()?;
        Ok(FunctionCallApply::with_one("bool_or", x.sql).into())
    });

    // all(x) - translate to bool_and()
    registry.register::<AggAll>(|_, mut bindings| {
        let x = bindings.take()?;
        Ok(FunctionCallApply::with_one("bool_and", x.sql).into())
    });
}