hamelin_legacy 0.4.4

Legacy AST translation code for Hamelin (to be deprecated)
Documentation
use std::rc::Rc;

use antlr_rust::token::Token;

use hamelin_lib::{
    antlr::hamelinparser::{CastContext, CastContextAttrs, ExpressionContextAll},
    err::{InvalidCast, TranslationError, TranslationErrors},
    operator::Operator,
    sql::{
        expression::{
            apply::{FunctionCallApply, Lambda},
            identifier::SimpleIdentifier,
            literal::ColumnReference,
            Cast, SQLExpression, TryCast,
        },
        types::SQLBaseType,
    },
    translation::ExpressionTranslation,
    types::{
        array::Array,
        matcher::{BaseMatcher, Matcher},
        Type, CALENDAR_INTERVAL, INTERVAL, TIMESTAMP, UNKNOWN, VARIANT,
    },
};
use hamelin_sql::{
    range_builder::RangeBuilder,
    utils::{interval_range_to_timestamp_range, interval_to_timestamp, wrap_timestamp},
};

use crate::ast::{expression::HamelinExpression, ExpressionTranslationContext};

pub fn translate_hamelin_cast(
    ctx: &CastContext<'static>,
    expression_translation_context: Rc<ExpressionTranslationContext>,
) -> Result<ExpressionTranslation, TranslationErrors> {
    let hmln_expression =
        HamelinExpression::new(ctx.expression().unwrap(), expression_translation_context);
    let expression = hmln_expression.translate()?;
    let hamelintype = ctx.hamelintype().unwrap();
    let to = Type::from_parse_tree(hamelintype.as_ref()).map_err(|e| {
        TranslationError::msg(
            hamelintype.as_ref(),
            "cannot cast into this type (AS means *cast* in Hamelin)",
        )
        .with_source_boxed(e.into())
    })?;

    if expression.typ == to {
        return Ok(expression);
    }
    let from = expression.typ;

    let default_result: SQLExpression = TryCast::new(
        expression.sql.clone(),
        to.clone()
            .to_sql()
            .map_err(|e| TranslationError::wrap_box(hamelintype.as_ref(), e.into()).single())?,
    )
    .into();
    let timestamp: Type = TIMESTAMP.into();

    let res = match (from, &to) {
        (Type::Array(_), Type::Array(Array { element_type })) => {
            let inner_type = element_type.clone();
            Ok(FunctionCallApply::with_two(
                "transform",
                default_result,
                Lambda::new(
                    vec![SimpleIdentifier::new("e")],
                    TryCast::new(
                        ColumnReference::new(SimpleIdentifier::new("e").into()).into(),
                        inner_type
                            .to_sql()
                            .map_err(|e| TranslationError::wrap_box(ctx, e.into()).single())?,
                    )
                    .into(),
                )
                .into(),
            )
            .into())
        }
        (Type::Variant, Type::Array(Array { element_type })) => {
            let array_of_variant: Type = Array::new(VARIANT).into();
            Ok(FunctionCallApply::with_two(
                "transform",
                TryCast::new(
                    expression.sql.clone(),
                    array_of_variant
                        .to_sql()
                        .map_err(|e| TranslationError::wrap_box(ctx, e.into()).single())?,
                )
                .into(),
                Lambda::from_single_argument(
                    SimpleIdentifier::new("e"),
                    TryCast::new(
                        ColumnReference::new(SimpleIdentifier::new("e").into()).into(),
                        element_type
                            .clone()
                            .to_sql()
                            .map_err(|e| TranslationError::wrap_box(ctx, e.into()).single())?,
                    )
                    .into(),
                )
                .into(),
            )
            .into())
        }
        (Type::Variant, Type::Timestamp) => Ok(wrap_timestamp(
            Cast::new(expression.sql, SQLBaseType::VarChar.into()).into(),
        )),
        (Type::Interval | Type::CalendarInterval, Type::Range(range)) if *range.of == TIMESTAMP => {
            let builder =
                match hmln_expression.tree() {
                    ExpressionContextAll::UnaryPrefixOperatorContext(ctx)
                        if ctx.operator.clone().map(|op| op.get_text().to_string())
                            == Some(Operator::Minus.to_string()) =>
                    {
                        RangeBuilder::default()
                            .with_begin(
                                interval_to_timestamp(expression.sql.clone()),
                                timestamp.clone().to_sql().map_err(|e| {
                                    TranslationError::wrap_box(ctx, e.into()).single()
                                })?,
                            )
                            .with_end(
                                FunctionCallApply::with_no_arguments("now").into(),
                                timestamp.clone().to_sql().map_err(|e| {
                                    TranslationError::wrap_box(ctx, e.into()).single()
                                })?,
                            )
                    }
                    _ => {
                        RangeBuilder::default()
                            .with_begin(
                                FunctionCallApply::with_no_arguments("now").into(),
                                timestamp.clone().to_sql().map_err(|e| {
                                    TranslationError::wrap_box(ctx, e.into()).single()
                                })?,
                            )
                            .with_end(
                                interval_to_timestamp(expression.sql.clone()),
                                timestamp.clone().to_sql().map_err(|e| {
                                    TranslationError::wrap_box(ctx, e.into()).single()
                                })?,
                            )
                    }
                };

            Ok(builder.build())
        }
        (Type::Timestamp, Type::Range(range)) if *range.of == TIMESTAMP => {
            let timestamp: Type = TIMESTAMP.into();
            let rb = RangeBuilder::default()
                .with_begin(
                    expression.sql,
                    timestamp
                        .clone()
                        .to_sql()
                        .map_err(|e| TranslationError::wrap_box(ctx, e.into()).single())?,
                )
                .with_end(
                    FunctionCallApply::with_no_arguments("now").into(),
                    timestamp
                        .to_sql()
                        .map_err(|e| TranslationError::wrap_box(ctx, e.into()).single())?,
                )
                .build()
                .into();
            Ok(rb)
        }
        (Type::Range(from_range), Type::Range(to_range))
            if (*from_range.of == INTERVAL || *from_range.of == CALENDAR_INTERVAL)
                && *to_range.of == TIMESTAMP =>
        {
            Ok(interval_range_to_timestamp_range(expression.sql).build())
        }
        (Type::Range(from_range), Type::Range(_)) if *from_range.of == UNKNOWN => {
            Ok(default_result)
        }
        (Type::Tuple(t), Type::Struct(s))
            if t.elements
                .iter()
                .zip(s.fields.values())
                .all(|(l, r)| l == r) =>
        {
            Ok(default_result)
        }
        (_, Type::Variant) => Ok(default_result),
        (Type::Variant, _) => Ok(default_result),
        // Doing it explicitly since we don't allow casting to a map yet?
        (Type::Unknown, Type::Struct(_)) => Ok(default_result),
        (Type::Unknown, Type::Array(_)) => Ok(default_result),
        (Type::Unknown, Type::Map(_)) => Ok(default_result),
        (f, t) if BaseMatcher::default().matches(&f) && BaseMatcher::default().matches(&t) => {
            Ok(default_result)
        }
        (f, t) => TranslationError::wrap(ctx, InvalidCast::new(f, t.clone())).single_result(),
    };

    res.map(|sql| ExpressionTranslation::with_defaults(to, sql))
}