use std::rc::Rc;
use antlr_rust::token::Token;
use hamelin_lib::{
antlr::hamelinparser::{CastContext, CastContextAttrs, ExpressionContextAll},
err::{InvalidCast, TranslationError, TranslationErrors},
operator::Operator,
sql::{
expression::{
apply::{FunctionCallApply, Lambda},
identifier::SimpleIdentifier,
literal::ColumnReference,
Cast, SQLExpression, TryCast,
},
types::SQLBaseType,
},
translation::ExpressionTranslation,
types::{
array::Array,
matcher::{BaseMatcher, Matcher},
Type, CALENDAR_INTERVAL, INTERVAL, TIMESTAMP, UNKNOWN, VARIANT,
},
};
use hamelin_sql::{
range_builder::RangeBuilder,
utils::{interval_range_to_timestamp_range, interval_to_timestamp, wrap_timestamp},
};
use crate::ast::{expression::HamelinExpression, ExpressionTranslationContext};
/// Translates a Hamelin cast (`<expression> AS <type>`) into a SQL expression.
///
/// The operand is translated first and the target type is read from the parse
/// tree. When the operand already has the target type, its translation is
/// returned unchanged. Otherwise the `(from, to)` type pair selects the SQL
/// that is emitted; most pairs use a plain `TryCast` of the operand to the
/// target SQL type, while arrays, timestamps and timestamp ranges get
/// dedicated handling (see the comments on the individual match arms).
///
/// # Errors
///
/// Returns `TranslationErrors` when
/// * translating the operand fails,
/// * the target type cannot be built from the parse tree,
/// * a Hamelin type involved in the cast cannot be converted to a SQL type, or
/// * no rule allows casting between the two types (`InvalidCast`).
///
/// # Panics
///
/// Panics if the cast node has no expression child or no type child, because
/// both are unwrapped.
pub fn translate_hamelin_cast(
    ctx: &CastContext<'static>,
    expression_translation_context: Rc<ExpressionTranslationContext>,
) -> Result<ExpressionTranslation, TranslationErrors> {
    let hmln_expression =
        HamelinExpression::new(ctx.expression().unwrap(), expression_translation_context);
    let expression = hmln_expression.translate()?;

    let hamelintype = ctx.hamelintype().unwrap();
    let to = Type::from_parse_tree(hamelintype.as_ref()).map_err(|e| {
        TranslationError::msg(
            hamelintype.as_ref(),
            "cannot cast into this type (AS means *cast* in Hamelin)",
        )
        .with_source_boxed(e.into())
    })?;

    // Casting to the type the operand already has is a no-op.
    if expression.typ == to {
        return Ok(expression);
    }
    let from = expression.typ;

    // Fallback used by most arms below: TRY_CAST(<operand> AS <target type>).
    let default_result: SQLExpression = TryCast::new(
        expression.sql.clone(),
        to.clone()
            .to_sql()
            .map_err(|e| TranslationError::wrap_box(hamelintype.as_ref(), e.into()).single())?,
    )
    .into();

    // Produces the SQL type for TIMESTAMP on demand; shared by the range-building arms so the
    // conversion and its error mapping are written once.
    let timestamp: Type = TIMESTAMP.into();
    let timestamp_sql = || {
        timestamp
            .clone()
            .to_sql()
            .map_err(|e| TranslationError::wrap_box(ctx, e.into()).single())
    };

    let res = match (from, &to) {
        // Array -> Array: cast the whole array, then additionally cast every element through
        // `transform(..., e -> try_cast(e))`.
        (Type::Array(_), Type::Array(Array { element_type })) => {
            let inner_type = element_type.clone();
            Ok(FunctionCallApply::with_two(
                "transform",
                default_result,
                Lambda::new(
                    vec![SimpleIdentifier::new("e")],
                    TryCast::new(
                        ColumnReference::new(SimpleIdentifier::new("e").into()).into(),
                        inner_type
                            .to_sql()
                            .map_err(|e| TranslationError::wrap_box(ctx, e.into()).single())?,
                    )
                    .into(),
                )
                .into(),
            )
            .into())
        }
        // Variant -> Array: first cast to an array of variants, then cast each element to the
        // requested element type.
        (Type::Variant, Type::Array(Array { element_type })) => {
            let array_of_variant: Type = Array::new(VARIANT).into();
            Ok(FunctionCallApply::with_two(
                "transform",
                TryCast::new(
                    expression.sql.clone(),
                    array_of_variant
                        .to_sql()
                        .map_err(|e| TranslationError::wrap_box(ctx, e.into()).single())?,
                )
                .into(),
                Lambda::from_single_argument(
                    SimpleIdentifier::new("e"),
                    TryCast::new(
                        ColumnReference::new(SimpleIdentifier::new("e").into()).into(),
                        element_type
                            .clone()
                            .to_sql()
                            .map_err(|e| TranslationError::wrap_box(ctx, e.into()).single())?,
                    )
                    .into(),
                )
                .into(),
            )
            .into())
        }
        // Variant -> Timestamp: go through VARCHAR and hand the result to `wrap_timestamp`.
        (Type::Variant, Type::Timestamp) => Ok(wrap_timestamp(
            Cast::new(expression.sql, SQLBaseType::VarChar.into()).into(),
        )),
        // Interval -> Range(Timestamp): the range spans between `now()` and the timestamp
        // derived from the interval. When the operand is written with a unary minus the
        // derived timestamp is the beginning and `now()` the end; otherwise the order is
        // reversed.
        (Type::Interval | Type::CalendarInterval, Type::Range(range))
            if *range.of == TIMESTAMP =>
        {
            let builder = match hmln_expression.tree() {
                ExpressionContextAll::UnaryPrefixOperatorContext(unary)
                    if unary.operator.clone().map(|op| op.get_text().to_string())
                        == Some(Operator::Minus.to_string()) =>
                {
                    RangeBuilder::default()
                        .with_begin(
                            interval_to_timestamp(expression.sql.clone()),
                            timestamp_sql()?,
                        )
                        .with_end(
                            FunctionCallApply::with_no_arguments("now").into(),
                            timestamp_sql()?,
                        )
                }
                _ => RangeBuilder::default()
                    .with_begin(
                        FunctionCallApply::with_no_arguments("now").into(),
                        timestamp_sql()?,
                    )
                    .with_end(
                        interval_to_timestamp(expression.sql.clone()),
                        timestamp_sql()?,
                    ),
            };
            Ok(builder.build())
        }
        // Timestamp -> Range(Timestamp): from the given timestamp until `now()`.
        (Type::Timestamp, Type::Range(range)) if *range.of == TIMESTAMP => {
            Ok(RangeBuilder::default()
                .with_begin(expression.sql, timestamp_sql()?)
                .with_end(
                    FunctionCallApply::with_no_arguments("now").into(),
                    timestamp_sql()?,
                )
                .build())
        }
        // Range(Interval) -> Range(Timestamp): delegated to the shared helper.
        (Type::Range(from_range), Type::Range(to_range))
            if (*from_range.of == INTERVAL || *from_range.of == CALENDAR_INTERVAL)
                && *to_range.of == TIMESTAMP =>
        {
            Ok(interval_range_to_timestamp_range(expression.sql).build())
        }
        // A range whose element type is still unknown may be cast to any range.
        (Type::Range(from_range), Type::Range(_)) if *from_range.of == UNKNOWN => {
            Ok(default_result)
        }
        // Tuple -> Struct is allowed only when the tuple has exactly as many elements as the
        // struct has fields and they match position by position. The explicit length check
        // is required because `zip` stops at the shorter side and would otherwise accept a
        // tuple that is merely a prefix of the struct (or the other way round).
        (Type::Tuple(t), Type::Struct(s))
            if t.elements.len() == s.fields.len()
                && t.elements
                    .iter()
                    .zip(s.fields.values())
                    .all(|(l, r)| l == r) =>
        {
            Ok(default_result)
        }
        // Anything may be cast to or from a variant.
        (_, Type::Variant) => Ok(default_result),
        (Type::Variant, _) => Ok(default_result),
        // Values of unknown type may be cast into the container types.
        (Type::Unknown, Type::Struct(_)) => Ok(default_result),
        (Type::Unknown, Type::Array(_)) => Ok(default_result),
        (Type::Unknown, Type::Map(_)) => Ok(default_result),
        // Both sides are accepted by `BaseMatcher`: a plain TRY_CAST is sufficient.
        (f, t) if BaseMatcher::default().matches(&f) && BaseMatcher::default().matches(&t) => {
            Ok(default_result)
        }
        // Every other combination is rejected.
        (f, t) => TranslationError::wrap(ctx, InvalidCast::new(f, t.clone())).single_result(),
    };

    res.map(|sql| ExpressionTranslation::with_defaults(to, sql))
}