hamelin_sql 0.3.10

SQL generation utilities for Hamelin query language
Documentation
//! SQL translations for membership operators

use anyhow::bail;

use hamelin_lib::func::defs::{
    InArray, InMap, InRange, InTimestampInterval, InTimestampTimestamp, InTuple, NotInArray,
    NotInMap, NotInRange, NotInTimestampInterval, NotInTimestampTimestamp, NotInTuple,
    NumericRange, TimestampRange,
};
use hamelin_lib::sql::expression::apply::{
    BinaryOperatorApply, FunctionCallApply, UnaryOperatorApply,
};
use hamelin_lib::sql::expression::literal::TupleLiteral;
use hamelin_lib::sql::expression::operator::Operator;
use hamelin_lib::sql::expression::SQLExpression;
use hamelin_lib::types::Type;

use crate::range_builder::RangeBuilder;
use crate::utils::{
    interval_range_to_timestamp_range, interval_to_range, interval_to_timestamp,
    timestamp_to_range, within_range, within_range_expr,
};
use crate::TranslationRegistry;

/// Register all membership operator translations.
///
/// Every handler registered here is a closure that pulls exactly two bindings,
/// in order: the left operand first, then the right operand. Each binding
/// exposes the operand's SQL expression (`.sql`) and its Hamelin type (`.typ`).
///
/// # Errors
///
/// Each handler propagates any error returned by `bindings.take()`. The range
/// constructors additionally propagate failures from merging operand types and
/// from converting the resulting type to its SQL form. The tuple handlers fail
/// when the right operand is not a `SQLExpression::RowLiteral`.
pub fn register(registry: &mut TranslationRegistry) {
    // InArray: left IN right (array) -> contains(right, left)
    registry.register::<InArray>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        Ok(FunctionCallApply::with_two("contains", right.sql, left.sql).into())
    });

    // NotInArray: left NOT IN right (array) -> NOT contains(right, left)
    registry.register::<NotInArray>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        let res = UnaryOperatorApply::new(
            Operator::Not,
            FunctionCallApply::with_two("contains", right.sql, left.sql).into(),
        );
        Ok(res.into())
    });

    // InMap: left IN right (map) -> contains(map_keys(right), left)
    registry.register::<InMap>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        Ok(FunctionCallApply::with_two(
            "contains",
            FunctionCallApply::with_one("map_keys", right.sql).into(),
            left.sql,
        )
        .into())
    });

    // NotInMap: left NOT IN right (map) -> NOT contains(map_keys(right), left)
    registry.register::<NotInMap>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        let res = UnaryOperatorApply::new(
            Operator::Not,
            FunctionCallApply::with_two(
                "contains",
                FunctionCallApply::with_one("map_keys", right.sql).into(),
                left.sql,
            )
            .into(),
        );
        Ok(res.into())
    });

    // NumericRange: left..right -> RANGE(left, right)
    registry.register::<NumericRange>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        // Both bounds are tagged with the merged type of the two operands.
        let typ = left.typ.merge(right.typ.clone())?;
        let rb = RangeBuilder::default()
            .with_begin(left.sql, typ.clone().to_sql()?)
            .with_end(right.sql, typ.to_sql()?)
            .build();
        Ok(rb)
    });

    // TimestampRange: left..right -> RANGE(left, right) with timestamp/interval handling
    registry.register::<TimestampRange>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;

        // When one bound is a timestamp and the other an interval (plain or
        // calendar), the interval side is passed through `interval_to_timestamp`
        // and the range is typed as a timestamp range. Any other combination
        // keeps both bounds unchanged and takes the left operand's type.
        let (typ, begin_sql, end_sql) = match (&left.typ, &right.typ) {
            (Type::Interval | Type::CalendarInterval, Type::Timestamp) => {
                (Type::Timestamp, interval_to_timestamp(left.sql), right.sql)
            }
            (Type::Timestamp, Type::Interval | Type::CalendarInterval) => {
                (Type::Timestamp, left.sql, interval_to_timestamp(right.sql))
            }
            (other, _) => (other.clone(), left.sql, right.sql),
        };

        let rb = RangeBuilder::default()
            .with_begin(begin_sql, typ.clone().to_sql()?)
            .with_end(end_sql, typ.to_sql()?)
            .build();

        Ok(rb)
    });

    // InRange: left IN right (range) -> within_range(left, right)
    registry.register::<InRange>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        match right.typ {
            // A range of intervals is first turned into a timestamp range.
            // NOTE(review): only `Type::Interval` is matched here; a range of
            // `Type::CalendarInterval` takes the generic path below — confirm
            // that is intended.
            Type::Range(range) if *range.of == Type::Interval => Ok(within_range(
                left.sql,
                interval_range_to_timestamp_range(right.sql),
            )),
            // Prefer the builder form when the right side can be read back as a
            // range literal; otherwise fall back to the expression-based check.
            _ => Ok(RangeBuilder::from_literal(right.sql.clone())
                .map(|rb| within_range(left.sql.clone(), rb))
                .unwrap_or_else(|| within_range_expr(left.sql, right.sql))),
        }
    });

    // NotInRange: left NOT IN right (range) -> NOT within_range(left, right)
    registry.register::<NotInRange>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        // Same dispatch as InRange; the result is negated afterwards.
        let in_result = match right.typ {
            Type::Range(range) if *range.of == Type::Interval => {
                within_range(left.sql, interval_range_to_timestamp_range(right.sql))
            }
            _ => RangeBuilder::from_literal(right.sql.clone())
                .map(|rb| within_range(left.sql.clone(), rb))
                .unwrap_or_else(|| within_range_expr(left.sql, right.sql)),
        };
        let res = UnaryOperatorApply::new(Operator::Not, in_result);
        Ok(res.into())
    });

    // InTimestampInterval: timestamp IN interval -> within_range(timestamp, interval_to_range(interval))
    registry.register::<InTimestampInterval>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        Ok(within_range(left.sql, interval_to_range(right.sql)))
    });

    // NotInTimestampInterval: timestamp NOT IN interval -> NOT within_range(...)
    registry.register::<NotInTimestampInterval>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        let res = UnaryOperatorApply::new(
            Operator::Not,
            within_range(left.sql, interval_to_range(right.sql)),
        );
        Ok(res.into())
    });

    // InTimestampTimestamp: timestamp IN timestamp -> within_range(left, timestamp_to_range(right))
    registry.register::<InTimestampTimestamp>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        Ok(within_range(left.sql, timestamp_to_range(right.sql)))
    });

    // NotInTimestampTimestamp: timestamp NOT IN timestamp -> NOT within_range(...)
    registry.register::<NotInTimestampTimestamp>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        let res = UnaryOperatorApply::new(
            Operator::Not,
            within_range(left.sql, timestamp_to_range(right.sql)),
        );
        Ok(res.into())
    });

    // InTuple: left IN (a, b, c) -> left IN (a, b, c)
    registry.register::<InTuple>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        match right.sql {
            // The row literal's values are re-wrapped as a tuple literal so they
            // can serve as the right-hand side of the IN operator.
            SQLExpression::RowLiteral(row) => Ok(BinaryOperatorApply::new(
                Operator::In,
                left.sql,
                SQLExpression::TupleLiteral(TupleLiteral::new(row.values)),
            )
            .into()),
            x => bail!("found unexpected expression {x:?}"),
        }
    });

    // NotInTuple: left NOT IN (a, b, c) -> NOT (left IN (a, b, c))
    registry.register::<NotInTuple>(|_, mut bindings| {
        let left = bindings.take()?;
        let right = bindings.take()?;
        match right.sql {
            SQLExpression::RowLiteral(row) => {
                let res = UnaryOperatorApply::new(
                    Operator::Not,
                    BinaryOperatorApply::new(
                        Operator::In,
                        left.sql,
                        SQLExpression::TupleLiteral(TupleLiteral::new(row.values)),
                    )
                    .into(),
                );
                Ok(res.into())
            }
            x => bail!("found unexpected expression {x:?}"),
        }
    });
}