use std::sync::Arc;
use crate::udf::{cast_to_variant_udf, from_variant_udf, variant_get_udf};
use datafusion::common::ScalarValue;
use datafusion::logical_expr::{ident, lit, BinaryExpr, Cast, Expr, Operator as DFOperator};
use datafusion_functions::core::expr_fn as core_fn;
use datafusion_functions::datetime::expr_fn as datetime_fn;
use datafusion_functions_nested::expr_fn as array_fn;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::expression::{
ExpressionKind, IntervalLiteral, IntervalUnit, TruncUnit,
};
use hamelin_lib::tree::ast::identifier::SimpleIdentifier;
use hamelin_lib::tree::ast::node::Spannable;
use hamelin_lib::tree::typed_ast::expression::{
CastKind, FieldAccess, TypedApply, TypedArrayLiteral, TypedCast, TypedExpression,
TypedExpressionKind, TypedFieldLookup, TypedFieldReference, TypedStructLiteral, TypedTsTrunc,
TypedTupleLiteral, TypedVariantIndexAccess,
};
use hamelin_lib::types::Type;
use crate::func::{DFTranslation, DataFusionTranslationRegistry};
use crate::struct_expansion::{hamelin_type_to_arrow, typed_null_scalar};
use crate::udf::{array_cast_udf, CastDescriptor};
/// State shared by the expression translation functions in this module.
pub struct ExprTranslationContext {
/// Registry that `translate_apply` asks to translate function applications.
pub registry: Arc<DataFusionTranslationRegistry>,
}
impl Default for ExprTranslationContext {
    /// Builds a context around a freshly constructed default registry.
    fn default() -> Self {
        let registry = Arc::new(DataFusionTranslationRegistry::default());
        Self { registry }
    }
}
pub fn translate_expr_with_ctx(
expr: &TypedExpression,
ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
match &expr.kind {
TypedExpressionKind::Leaf => translate_leaf(expr),
TypedExpressionKind::FieldReference(col_ref) => translate_field_reference(col_ref),
TypedExpressionKind::ArrayLiteral(arr_lit) => {
translate_array_literal(arr_lit, &expr.resolved_type, ctx)
}
TypedExpressionKind::TupleLiteral(tuple_lit) => translate_tuple_literal(tuple_lit, ctx),
TypedExpressionKind::StructLiteral(struct_lit) => translate_struct_literal(struct_lit, ctx),
TypedExpressionKind::Apply(apply) => translate_apply(apply, expr, ctx),
TypedExpressionKind::BroadcastApply(_) => Err(TranslationError::msg(
expr,
"BroadcastApply should be lowered before DataFusion translation",
)
.into()),
TypedExpressionKind::VariantIndexAccess(via) => {
translate_variant_index_access(via, expr, ctx)
}
TypedExpressionKind::FieldLookup(field_lookup) => {
translate_field_lookup(field_lookup, expr, ctx)
}
TypedExpressionKind::Cast(cast) => translate_cast(cast, expr, ctx),
TypedExpressionKind::TsTrunc(ts_trunc) => translate_ts_trunc(ts_trunc, expr, ctx),
TypedExpressionKind::Lambda(_) => Err(TranslationError::msg(
expr,
"Lambda expressions should be lowered before DataFusion translation",
)
.into()),
TypedExpressionKind::Error(err) => Err((*err.error).clone().into()),
}
}
/// Test-only shortcut: translates `expr` using a default translation context.
#[cfg(test)]
fn translate_expr(expr: &TypedExpression) -> Result<Expr, Arc<TranslationError>> {
    let ctx = ExprTranslationContext::default();
    translate_expr_with_ctx(expr, &ctx)
}
fn translate_leaf(expr: &TypedExpression) -> Result<Expr, Arc<TranslationError>> {
match &expr.ast.kind {
ExpressionKind::IntLiteral(int_lit) => Ok(lit(ScalarValue::Int64(Some(int_lit.int)))),
ExpressionKind::ScientificLiteral(sci_lit) => {
Ok(lit(ScalarValue::Float64(Some(sci_lit.value))))
}
ExpressionKind::BooleanLiteral(bool_lit) => {
Ok(lit(ScalarValue::Boolean(Some(bool_lit.value))))
}
ExpressionKind::StringLiteral(str_lit) => {
Ok(lit(ScalarValue::Utf8(Some(str_lit.value.clone()))))
}
ExpressionKind::NullLiteral(_) => Ok(lit(typed_null(&expr.resolved_type))),
ExpressionKind::DecimalLiteral(dec_lit) => {
Ok(lit(ScalarValue::Decimal128(
Some(dec_lit.unscaled_value),
dec_lit.precision as u8,
dec_lit.scale as i8,
)))
}
ExpressionKind::IntervalLiteral(interval_lit) => translate_interval(interval_lit),
ExpressionKind::BinaryLiteral(bin_lit) => {
Ok(lit(ScalarValue::Binary(Some(bin_lit.value.clone()))))
}
ExpressionKind::ArrayLiteral(_)
| ExpressionKind::TupleLiteral(_)
| ExpressionKind::PairLiteral(_)
| ExpressionKind::StructLiteral(_) => Err(TranslationError::msg(
expr,
"Complex literals should not be Leaf expressions",
)
.into()),
ExpressionKind::RowsLiteral(_) => Err(TranslationError::msg(
expr,
"ROWS literals cannot be translated to expressions",
)
.into()),
ExpressionKind::UnboundRangeLiteral(_) => {
Ok(datafusion_functions::core::expr_fn::named_struct(vec![
lit("begin"),
lit(ScalarValue::Null),
lit("end"),
lit(ScalarValue::Null),
]))
}
ExpressionKind::FieldReference(_)
| ExpressionKind::UnaryPrefixOperator(_)
| ExpressionKind::UnaryPostfixOperator(_)
| ExpressionKind::BinaryOperator(_)
| ExpressionKind::FunctionCall(_)
| ExpressionKind::IndexAccess(_)
| ExpressionKind::FieldLookup(_)
| ExpressionKind::Cast(_)
| ExpressionKind::TsTrunc(_)
| ExpressionKind::Lambda(_)
| ExpressionKind::Error(_) => {
Err(TranslationError::msg(expr, "Unexpected AST kind for Leaf expression").into())
}
}
}
fn translate_interval(interval: &IntervalLiteral) -> Result<Expr, Arc<TranslationError>> {
match interval.unit {
IntervalUnit::Nanosecond => {
let total_millis = interval.value / 1_000_000;
Ok(lit(interval_day_time_from_millis(total_millis)))
}
IntervalUnit::Microsecond => {
let total_millis = interval.value / 1_000;
Ok(lit(interval_day_time_from_millis(total_millis)))
}
IntervalUnit::Millisecond => Ok(lit(interval_day_time_from_millis(interval.value))),
IntervalUnit::Second => {
let total_millis = interval.value * 1_000;
Ok(lit(interval_day_time_from_millis(total_millis)))
}
IntervalUnit::Minute => {
let total_millis = interval.value * 60_000;
Ok(lit(interval_day_time_from_millis(total_millis)))
}
IntervalUnit::Hour => {
let total_millis = interval.value * 3_600_000;
Ok(lit(interval_day_time_from_millis(total_millis)))
}
IntervalUnit::Day => Ok(lit(ScalarValue::IntervalDayTime(Some(
datafusion::arrow::datatypes::IntervalDayTime::new(interval.value as i32, 0),
)))),
IntervalUnit::Week => {
let days = interval.value * 7;
Ok(lit(ScalarValue::IntervalDayTime(Some(
datafusion::arrow::datatypes::IntervalDayTime::new(days as i32, 0),
))))
}
IntervalUnit::Month => {
Ok(lit(ScalarValue::IntervalYearMonth(Some(
interval.value as i32,
))))
}
IntervalUnit::Quarter => {
let months = interval.value * 3;
Ok(lit(ScalarValue::IntervalYearMonth(Some(months as i32))))
}
IntervalUnit::Year => {
let months = interval.value * 12;
Ok(lit(ScalarValue::IntervalYearMonth(Some(months as i32))))
}
}
}
const MILLIS_PER_DAY: i64 = 86_400_000;
fn interval_day_time_from_millis(total_millis: i64) -> ScalarValue {
let days = (total_millis / MILLIS_PER_DAY) as i32;
let millis = (total_millis % MILLIS_PER_DAY) as i32;
ScalarValue::IntervalDayTime(Some(datafusion::arrow::datatypes::IntervalDayTime::new(
days, millis,
)))
}
/// Translates a field reference into a DataFusion identifier expression.
///
/// # Errors
///
/// Propagates the error from `valid_ref` when the field name is not valid.
fn translate_field_reference(col_ref: &TypedFieldReference) -> Result<Expr, Arc<TranslationError>> {
    let name = col_ref.field_name.valid_ref()?;
    let column = ident(name.as_str());
    Ok(column)
}
/// Translates a tuple literal into a DataFusion `struct(...)` call whose
/// arguments are the translated elements, in order.
///
/// # Errors
///
/// Stops at, and returns, the first element translation error.
fn translate_tuple_literal(
    tuple_lit: &TypedTupleLiteral,
    ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
    let members = tuple_lit
        .elements
        .iter()
        .map(|member| translate_expr_with_ctx(member, ctx))
        .collect::<Result<Vec<Expr>, Arc<TranslationError>>>()?;
    Ok(core_fn::r#struct(members))
}
/// Translates an array literal into a DataFusion `make_array(...)` call whose
/// arguments are the translated elements, in order.
///
/// `_resolved_type` is accepted but not used.
///
/// # Errors
///
/// Stops at, and returns, the first element translation error.
fn translate_array_literal(
    arr_lit: &TypedArrayLiteral,
    _resolved_type: &Type,
    ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
    let mut items: Vec<Expr> = Vec::new();
    for item in arr_lit.elements.iter() {
        items.push(translate_expr_with_ctx(item, ctx)?);
    }
    Ok(array_fn::make_array(items))
}
/// Translates a struct literal into a DataFusion `named_struct(...)` call.
///
/// The argument list alternates field-name literal and translated field value.
///
/// # Errors
///
/// Returns the first error raised while translating a field value.
fn translate_struct_literal(
    struct_lit: &TypedStructLiteral,
    ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
    use hamelin_lib::tree::ast::identifier::ParsedSimpleIdentifier;

    let mut name_value_pairs = Vec::with_capacity(struct_lit.fields.len() * 2);
    for (name, value) in &struct_lit.fields {
        // NOTE(review): a field whose name failed to parse is dropped without an
        // error, so the resulting struct simply lacks it — confirm this is intended.
        if let ParsedSimpleIdentifier::Valid(id) = name {
            name_value_pairs.push(lit(id.as_str().to_string()));
            name_value_pairs.push(translate_expr_with_ctx(value, ctx)?);
        }
    }
    Ok(core_fn::named_struct(name_value_pairs))
}
/// Translates a function application.
///
/// Every bound argument is translated and paired with its resolved type in a
/// `DFTranslation`; the context's registry then produces the final expression
/// for the applied function definition.
///
/// # Errors
///
/// * The first argument translation error.
/// * A registry error, wrapped with `span` via `TranslationError::wrap`.
fn translate_apply(
    apply: &TypedApply,
    span: &impl Spannable,
    ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
    let translated_binding = apply.parameter_binding.clone().try_map(|argument| {
        let argument_type = argument.resolved_type.clone();
        translate_expr_with_ctx(&argument, ctx)
            .map(|translated| DFTranslation::new(translated, argument_type))
    })?;

    let function = apply.function_def.as_ref();
    ctx.registry
        .translate(function, translated_binding)
        .map_err(|cause| Arc::new(TranslationError::wrap(span, cause)))
}
fn translate_field_lookup(
field_lookup: &TypedFieldLookup,
span: &impl Spannable,
ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
match &field_lookup.access {
FieldAccess::StructField(field_name) => {
let base_expr = translate_expr_with_ctx(&field_lookup.value, ctx)?;
let name = match field_name {
hamelin_lib::tree::ast::identifier::ParsedSimpleIdentifier::Valid(id) => {
id.as_str().to_string()
}
hamelin_lib::tree::ast::identifier::ParsedSimpleIdentifier::Error(_) => {
return Err(Arc::new(TranslationError::msg(
span,
"Invalid field name in field lookup",
)));
}
};
Ok(core_fn::get_field(base_expr, name))
}
FieldAccess::TupleElement(index) => {
let base_expr = translate_expr_with_ctx(&field_lookup.value, ctx)?;
let field_name = format!("c{}", index);
Ok(core_fn::get_field(base_expr, field_name))
}
FieldAccess::VariantField(field_name) => {
let name = match field_name {
hamelin_lib::tree::ast::identifier::ParsedSimpleIdentifier::Valid(id) => {
id.as_str().to_string()
}
hamelin_lib::tree::ast::identifier::ParsedSimpleIdentifier::Error(_) => {
return Err(Arc::new(TranslationError::msg(
span,
"Invalid field name in variant field lookup",
)));
}
};
let (base_expr, path) =
collect_variant_path(&field_lookup.value, vec![PathSegment::Field(name)]);
let base_df_expr = translate_expr_with_ctx(base_expr, ctx)?;
let path_string = build_path_string(&path);
Ok(variant_get_udf().call(vec![base_df_expr, lit(path_string)]))
}
FieldAccess::RangeBegin => {
let base_expr = translate_expr_with_ctx(&field_lookup.value, ctx)?;
Ok(core_fn::get_field(base_expr, "begin"))
}
FieldAccess::RangeEnd => {
let base_expr = translate_expr_with_ctx(&field_lookup.value, ctx)?;
Ok(core_fn::get_field(base_expr, "end"))
}
}
}
/// Translates an index access on a variant value into a single
/// `variant_get(root, path)` call, folding any enclosing chain of variant
/// accesses into the path string.
///
/// `_span` is accepted but not used.
///
/// # Errors
///
/// Propagates the error from translating the root expression.
fn translate_variant_index_access(
    via: &TypedVariantIndexAccess,
    _span: &impl Spannable,
    ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
    let last_segment = PathSegment::Index(via.variant_index);
    let (root, segments) = collect_variant_path(&via.value, vec![last_segment]);
    let root_expr = translate_expr_with_ctx(root, ctx)?;
    let path = lit(build_path_string(&segments));
    Ok(variant_get_udf().call(vec![root_expr, path]))
}
/// One step of a variant access path, as rendered by `build_path_string`.
enum PathSegment {
/// Access by field name; rendered as the bare name.
Field(String),
/// Access by position; rendered as `[index]`.
Index(usize),
}
/// Walks down a chain of variant accesses, collecting the path from the root.
///
/// Starting at `expr`, every variant field lookup (with a valid name) or
/// variant index access is prepended to `path_segments`, and the walk continues
/// with the accessed value. The walk stops at the first expression that is
/// neither of those.
///
/// Returns the expression the walk stopped at together with the accumulated
/// segments, ordered outermost-first.
fn collect_variant_path(
    expr: &TypedExpression,
    mut path_segments: Vec<PathSegment>,
) -> (&TypedExpression, Vec<PathSegment>) {
    use hamelin_lib::tree::ast::identifier::ParsedSimpleIdentifier;

    match &expr.kind {
        TypedExpressionKind::FieldLookup(lookup) => match &lookup.access {
            FieldAccess::VariantField(ParsedSimpleIdentifier::Valid(id)) => {
                path_segments.insert(0, PathSegment::Field(id.as_str().to_string()));
                collect_variant_path(&lookup.value, path_segments)
            }
            // Non-variant lookups and invalid names end the chain here.
            _ => (expr, path_segments),
        },
        TypedExpressionKind::VariantIndexAccess(access) => {
            path_segments.insert(0, PathSegment::Index(access.variant_index));
            collect_variant_path(&access.value, path_segments)
        }
        _ => (expr, path_segments),
    }
}
/// Renders variant path segments as a path string such as `a.b[0].c`.
///
/// A field segment is written as its name, preceded by `.` unless it is the
/// first segment. An index segment is written as `[index]` with no separator.
/// An empty slice yields an empty string.
fn build_path_string(segments: &[PathSegment]) -> String {
    let mut path = String::new();
    for (position, segment) in segments.iter().enumerate() {
        match segment {
            PathSegment::Field(name) => {
                // Every field except the very first segment gets a separator,
                // whether it follows another field or an index.
                if position > 0 {
                    path.push('.');
                }
                path.push_str(name);
            }
            PathSegment::Index(index) => {
                path.push('[');
                path.push_str(&index.to_string());
                path.push(']');
            }
        }
    }
    path
}
fn translate_cast(
cast: &TypedCast,
span: &impl Spannable,
ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
let target_type = hamelin_type_to_arrow(&cast.target_type);
match &cast.cast_kind {
CastKind::Identity => translate_expr_with_ctx(&cast.value, ctx),
CastKind::NullToType => Ok(lit(typed_null_scalar(&cast.target_type))),
CastKind::IntToDouble
| CastKind::IntToDecimal
| CastKind::DoubleToInt
| CastKind::DoubleToDecimal
| CastKind::DecimalToInt
| CastKind::DecimalToDouble
| CastKind::DecimalToDecimal => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
Ok(Expr::Cast(Cast::new(Box::new(value_expr), target_type)))
}
CastKind::IntToBoolean | CastKind::BooleanToInt => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
Ok(Expr::Cast(Cast::new(Box::new(value_expr), target_type)))
}
CastKind::ToStringFromInt
| CastKind::ToStringFromDouble
| CastKind::ToStringFromBoolean
| CastKind::ToStringFromTimestamp
| CastKind::ToStringFromBinary
| CastKind::ToStringFromDecimal
| CastKind::ToStringFromInterval
| CastKind::ToStringFromCalendarInterval => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
Ok(Expr::Cast(Cast::new(Box::new(value_expr), target_type)))
}
CastKind::StringToInt
| CastKind::StringToDouble
| CastKind::StringToBoolean
| CastKind::StringToTimestamp
| CastKind::StringToDecimal => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
Ok(Expr::Cast(Cast::new(Box::new(value_expr), target_type)))
}
CastKind::ToVariant(_) => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
Ok(cast_to_variant_udf().call(vec![value_expr]))
}
CastKind::FromVariant(_) => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
Ok(from_variant_udf(target_type).call(vec![value_expr]))
}
CastKind::ArrayElementCast(inner) => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
if needs_array_cast_udf(inner) {
let descriptor = cast_kind_to_descriptor(&cast.cast_kind, &cast.target_type, span)?;
let udf = array_cast_udf(target_type.clone(), descriptor);
Ok(udf.call(vec![value_expr]))
} else {
Ok(Expr::Cast(Cast::new(Box::new(value_expr), target_type)))
}
}
CastKind::TupleToStruct(_) => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
Ok(Expr::Cast(Cast::new(Box::new(value_expr), target_type)))
}
CastKind::RangeElementCast(_) => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
Ok(Expr::Cast(Cast::new(Box::new(value_expr), target_type)))
}
CastKind::IntervalToTimestampRange => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
let now = datetime_fn::now();
let end = Expr::BinaryExpr(BinaryExpr::new(
Box::new(now.clone()),
DFOperator::Plus,
Box::new(value_expr),
));
Ok(core_fn::named_struct(vec![
lit("begin"),
now,
lit("end"),
end,
]))
}
CastKind::TimestampToTimestampRange => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
Ok(core_fn::named_struct(vec![
lit("begin"),
value_expr,
lit("end"),
datetime_fn::now(),
]))
}
CastKind::IntervalRangeToTimestampRange => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
let begin_interval = core_fn::get_field(value_expr.clone(), "begin");
let end_interval = core_fn::get_field(value_expr, "end");
let now = datetime_fn::now();
let begin_ts = Expr::BinaryExpr(BinaryExpr::new(
Box::new(now.clone()),
DFOperator::Plus,
Box::new(begin_interval),
));
let end_ts = Expr::BinaryExpr(BinaryExpr::new(
Box::new(now),
DFOperator::Plus,
Box::new(end_interval),
));
Ok(core_fn::named_struct(vec![
lit("begin"),
begin_ts,
lit("end"),
end_ts,
]))
}
CastKind::StructExpansion(field_casts) => {
let value_expr = translate_expr_with_ctx(&cast.value, ctx)?;
translate_struct_expansion(value_expr, field_casts, &cast.target_type, span)
}
}
}
fn translate_struct_expansion(
source_expr: Expr,
field_casts: &[(SimpleIdentifier, CastKind)],
target_type: &Type,
span: &impl Spannable,
) -> Result<Expr, Arc<TranslationError>> {
let target_struct = match target_type {
Type::Struct(s) => s,
_ => {
return Err(Arc::new(TranslationError::msg(
span,
&format!(
"StructExpansion target type is not a struct: {}",
target_type
),
)))
}
};
let mut args = Vec::new();
for (field_name, cast_kind) in field_casts {
let field_type = target_struct.lookup(field_name).ok_or_else(|| {
Arc::new(TranslationError::msg(
span,
&format!("Field '{}' not found in target struct", field_name.name()),
))
})?;
args.push(lit(field_name.name()));
let field_value = match cast_kind {
CastKind::NullToType => {
lit(typed_null_scalar(field_type))
}
CastKind::Identity => {
core_fn::get_field(source_expr.clone(), field_name.name())
}
CastKind::StructExpansion(nested_field_casts) => {
let field_ref = core_fn::get_field(source_expr.clone(), field_name.name());
translate_struct_expansion(field_ref, nested_field_casts, field_type, span)?
}
CastKind::ArrayElementCast(inner) if needs_array_cast_udf(inner) => {
let field_ref = core_fn::get_field(source_expr.clone(), field_name.name());
let descriptor = cast_kind_to_descriptor(cast_kind, field_type, span)?;
let arrow_type = hamelin_type_to_arrow(field_type);
let udf = array_cast_udf(arrow_type, descriptor);
udf.call(vec![field_ref])
}
CastKind::ToVariant(_) => {
let field_ref = core_fn::get_field(source_expr.clone(), field_name.name());
cast_to_variant_udf().call(vec![field_ref])
}
CastKind::FromVariant(_) => {
let field_ref = core_fn::get_field(source_expr.clone(), field_name.name());
let arrow_type = hamelin_type_to_arrow(field_type);
from_variant_udf(arrow_type).call(vec![field_ref])
}
_ => {
let field_ref = core_fn::get_field(source_expr.clone(), field_name.name());
let arrow_type = hamelin_type_to_arrow(field_type);
Expr::Cast(Cast::new(Box::new(field_ref), arrow_type))
}
};
args.push(field_value);
}
Ok(core_fn::named_struct(args))
}
/// Translates a timestamp truncation into `date_trunc(<unit>, <timestamp>)`.
///
/// `_span` is accepted but not used.
///
/// # Errors
///
/// Propagates the error from translating the timestamp expression.
fn translate_ts_trunc(
    ts_trunc: &TypedTsTrunc,
    _span: &impl Spannable,
    ctx: &ExprTranslationContext,
) -> Result<Expr, Arc<TranslationError>> {
    let timestamp = translate_expr_with_ctx(&ts_trunc.expression, ctx)?;
    // Unit name as passed to `date_trunc`.
    let unit_name = match ts_trunc.unit {
        TruncUnit::Year => "year",
        TruncUnit::Quarter => "quarter",
        TruncUnit::Month => "month",
        TruncUnit::Week => "week",
        TruncUnit::Day => "day",
        TruncUnit::Hour => "hour",
        TruncUnit::Minute => "minute",
        TruncUnit::Second => "second",
    };
    let truncated = datetime_fn::date_trunc(lit(unit_name), timestamp);
    Ok(truncated)
}
/// Returns the null scalar for `hamelin_type`; a thin wrapper that forwards to
/// `typed_null_scalar`.
fn typed_null(hamelin_type: &Type) -> ScalarValue {
typed_null_scalar(hamelin_type)
}
/// Reports whether casting array elements with `cast_kind` requires the
/// array-cast UDF instead of a plain DataFusion `Cast`.
///
/// Struct expansion and variant conversions require it; array and range
/// element casts require it exactly when their inner cast does; every other
/// kind does not.
fn needs_array_cast_udf(cast_kind: &CastKind) -> bool {
    match cast_kind {
        CastKind::StructExpansion(_) | CastKind::ToVariant(_) | CastKind::FromVariant(_) => true,

        // Wrappers defer to the cast applied to their elements.
        CastKind::ArrayElementCast(element_cast) => needs_array_cast_udf(element_cast),
        CastKind::RangeElementCast(bound_cast) => needs_array_cast_udf(bound_cast),

        CastKind::Identity
        | CastKind::NullToType
        | CastKind::IntToDouble
        | CastKind::IntToDecimal
        | CastKind::DoubleToInt
        | CastKind::DoubleToDecimal
        | CastKind::DecimalToInt
        | CastKind::DecimalToDouble
        | CastKind::DecimalToDecimal
        | CastKind::IntToBoolean
        | CastKind::BooleanToInt
        | CastKind::ToStringFromInt
        | CastKind::ToStringFromDouble
        | CastKind::ToStringFromBoolean
        | CastKind::ToStringFromTimestamp
        | CastKind::ToStringFromBinary
        | CastKind::ToStringFromDecimal
        | CastKind::ToStringFromInterval
        | CastKind::ToStringFromCalendarInterval
        | CastKind::StringToInt
        | CastKind::StringToDouble
        | CastKind::StringToBoolean
        | CastKind::StringToTimestamp
        | CastKind::StringToDecimal
        | CastKind::TupleToStruct(_)
        | CastKind::IntervalToTimestampRange
        | CastKind::TimestampToTimestampRange
        | CastKind::IntervalRangeToTimestampRange => false,
    }
}
/// Converts a `CastKind` into the `CastDescriptor` handed to the array-cast UDF.
///
/// `target_type` is the type the cast produces. For array and range element
/// casts the recursion continues with the element type (or with `target_type`
/// itself when it is not an array / range); for struct expansion each field is
/// described against its type in the target struct.
///
/// # Errors
///
/// Fails when a struct expansion targets a non-struct type or names a field
/// missing from the target struct. `span` locates the error.
fn cast_kind_to_descriptor(
    cast_kind: &CastKind,
    target_type: &Type,
    span: &impl Spannable,
) -> Result<CastDescriptor, Arc<TranslationError>> {
    let descriptor = match cast_kind {
        CastKind::Identity => CastDescriptor::Identity,
        CastKind::NullToType => CastDescriptor::NullToType,
        CastKind::ToVariant(_) => CastDescriptor::ToVariant,
        CastKind::FromVariant(_) => CastDescriptor::FromVariant(hamelin_type_to_arrow(target_type)),

        // Kinds described as a plain Arrow cast.
        CastKind::IntToDouble
        | CastKind::IntToDecimal
        | CastKind::DoubleToInt
        | CastKind::DoubleToDecimal
        | CastKind::DecimalToInt
        | CastKind::DecimalToDouble
        | CastKind::DecimalToDecimal
        | CastKind::IntToBoolean
        | CastKind::BooleanToInt
        | CastKind::ToStringFromInt
        | CastKind::ToStringFromDouble
        | CastKind::ToStringFromBoolean
        | CastKind::ToStringFromTimestamp
        | CastKind::ToStringFromBinary
        | CastKind::ToStringFromDecimal
        | CastKind::ToStringFromInterval
        | CastKind::ToStringFromCalendarInterval
        | CastKind::StringToInt
        | CastKind::StringToDouble
        | CastKind::StringToBoolean
        | CastKind::StringToTimestamp
        | CastKind::StringToDecimal
        | CastKind::TupleToStruct(_)
        | CastKind::IntervalToTimestampRange
        | CastKind::TimestampToTimestampRange
        | CastKind::IntervalRangeToTimestampRange => CastDescriptor::ArrowCast,

        CastKind::ArrayElementCast(element_cast) => {
            let element_type = match target_type {
                Type::Array(array) => array.element_type.as_ref(),
                _ => target_type,
            };
            let element_descriptor = cast_kind_to_descriptor(element_cast, element_type, span)?;
            CastDescriptor::ArrayElementCast(Box::new(element_descriptor))
        }

        CastKind::RangeElementCast(bound_cast) => {
            let bound_type = match target_type {
                Type::Range(range) => range.of.as_ref(),
                _ => target_type,
            };
            let bound_descriptor = cast_kind_to_descriptor(bound_cast, bound_type, span)?;
            CastDescriptor::RangeElementCast(Box::new(bound_descriptor))
        }

        CastKind::StructExpansion(field_casts) => {
            let target_struct = if let Type::Struct(definition) = target_type {
                definition
            } else {
                return Err(Arc::new(TranslationError::msg(
                    span,
                    &format!(
                        "StructExpansion target type is not a struct: {:?}",
                        target_type
                    ),
                )));
            };

            // One (name, Arrow type, descriptor) triple per expanded field.
            let mut per_field = Vec::new();
            for (name, field_cast) in field_casts.iter() {
                let field_type = match target_struct.lookup(name) {
                    Some(found) => found,
                    None => {
                        return Err(Arc::new(TranslationError::msg(
                            span,
                            &format!("Field '{}' not found in target struct", name.name()),
                        )));
                    }
                };
                let arrow_type = hamelin_type_to_arrow(field_type);
                let field_descriptor = cast_kind_to_descriptor(field_cast, field_type, span)?;
                per_field.push((name.name().to_owned(), arrow_type, field_descriptor));
            }
            CastDescriptor::StructExpansion(per_field)
        }
    };
    Ok(descriptor)
}
// Unit tests: parse and type-check a Hamelin expression from source text, run
// it through `translate_expr`, and check the resulting DataFusion `Expr`.
#[cfg(test)]
mod tests {
use super::*;
use hamelin_lib::tree::ast::expression::Expression;
use hamelin_lib::tree::ast::ParseWithErrors;
use hamelin_lib::tree::options::ExpressionTypeCheckOptions;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::type_check_expression;
/// Parses `s` and type-checks it with default options, returning the typed
/// expression from the type-check output.
fn parse_expr(s: &str) -> TypedExpression {
type_check_expression(
Expression::parse(s),
ExpressionTypeCheckOptions::builder().build(),
)
.output
}
/// Same as `parse_expr`, but type-checks against the given `bindings`.
fn parse_expr_with_bindings(s: &str, bindings: Arc<TypeEnvironment>) -> TypedExpression {
type_check_expression(
Expression::parse(s),
ExpressionTypeCheckOptions::builder()
.bindings(bindings)
.build(),
)
.output
}
// An integer literal becomes an Int64 scalar literal.
#[test]
fn test_int_literal() {
let expr = parse_expr("42");
let result = translate_expr(&expr).unwrap();
assert_eq!(result, lit(ScalarValue::Int64(Some(42))));
}
// `-123` is expected to translate to a `Negative` expression rather than a
// negative literal.
#[test]
fn test_negative_int_literal() {
use datafusion::logical_expr::Expr as DFExpr;
let expr = parse_expr("-123");
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::Negative(_)));
}
// A scientific-notation literal becomes a Float64 scalar literal.
#[test]
fn test_float_literal() {
let expr = parse_expr("3.14e0");
let result = translate_expr(&expr).unwrap();
assert_eq!(result, lit(ScalarValue::Float64(Some(3.14))));
}
// `true` / `false` become Boolean scalar literals.
#[test]
fn test_boolean_literal() {
let true_expr = parse_expr("true");
let false_expr = parse_expr("false");
assert_eq!(
translate_expr(&true_expr).unwrap(),
lit(ScalarValue::Boolean(Some(true)))
);
assert_eq!(
translate_expr(&false_expr).unwrap(),
lit(ScalarValue::Boolean(Some(false)))
);
}
// A quoted string becomes a Utf8 scalar literal without the quotes.
#[test]
fn test_string_literal() {
let expr = parse_expr("'hello world'");
let result = translate_expr(&expr).unwrap();
assert_eq!(
result,
lit(ScalarValue::Utf8(Some("hello world".to_string())))
);
}
// A bare `null` (no type context) becomes the untyped `ScalarValue::Null`.
#[test]
fn test_null_literal() {
let expr = parse_expr("null");
let result = translate_expr(&expr).unwrap();
assert_eq!(result, lit(ScalarValue::Null));
}
// A reference to a bound column becomes an identifier expression.
#[test]
fn test_field_reference() {
let mut bindings = TypeEnvironment::default();
bindings.bind("my_column".into(), hamelin_lib::types::INT);
let expr = parse_expr_with_bindings("my_column", Arc::new(bindings));
let result = translate_expr(&expr).unwrap();
assert_eq!(result, ident("my_column"));
}
// `+ - * / %` on two INT columns each produce a `BinaryExpr`.
#[test]
fn test_arithmetic_operators() {
use datafusion::logical_expr::Expr as DFExpr;
let mut bindings = TypeEnvironment::default();
bindings.bind("a".into(), hamelin_lib::types::INT);
bindings.bind("b".into(), hamelin_lib::types::INT);
let bindings = Arc::new(bindings);
let expr = parse_expr_with_bindings("a + b", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
let expr = parse_expr_with_bindings("a - b", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
let expr = parse_expr_with_bindings("a * b", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
let expr = parse_expr_with_bindings("a / b", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
let expr = parse_expr_with_bindings("a % b", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
}
// `== != < <= > >=` on two INT columns each produce a `BinaryExpr`.
#[test]
fn test_comparison_operators() {
use datafusion::logical_expr::Expr as DFExpr;
let mut bindings = TypeEnvironment::default();
bindings.bind("a".into(), hamelin_lib::types::INT);
bindings.bind("b".into(), hamelin_lib::types::INT);
let bindings = Arc::new(bindings);
let expr = parse_expr_with_bindings("a == b", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
let expr = parse_expr_with_bindings("a != b", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
let expr = parse_expr_with_bindings("a < b", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
let expr = parse_expr_with_bindings("a <= b", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
let expr = parse_expr_with_bindings("a > b", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
let expr = parse_expr_with_bindings("a >= b", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
}
// `AND` / `OR` produce a `BinaryExpr`; `NOT` produces `Expr::Not`.
#[test]
fn test_logical_operators() {
use datafusion::logical_expr::Expr as DFExpr;
let mut bindings = TypeEnvironment::default();
bindings.bind("p".into(), hamelin_lib::types::BOOLEAN);
bindings.bind("q".into(), hamelin_lib::types::BOOLEAN);
let bindings = Arc::new(bindings);
let expr = parse_expr_with_bindings("p AND q", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
let expr = parse_expr_with_bindings("p OR q", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::BinaryExpr(_)));
let expr = parse_expr_with_bindings("NOT p", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::Not(_)));
}
// `IS NULL` / `IS NOT NULL` map to `Expr::IsNull` / `Expr::IsNotNull`.
#[test]
fn test_is_null_operators() {
use datafusion::logical_expr::Expr as DFExpr;
let mut bindings = TypeEnvironment::default();
bindings.bind("x".into(), hamelin_lib::types::INT);
let bindings = Arc::new(bindings);
let expr = parse_expr_with_bindings("x IS NULL", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::IsNull(_)));
let expr = parse_expr_with_bindings("x IS NOT NULL", bindings.clone());
let result = translate_expr(&expr).unwrap();
assert!(matches!(result, DFExpr::IsNotNull(_)));
}
// `sum(x)`, type-checked with the Agg and Window special positions allowed,
// translates to an `AggregateFunction` expression.
#[test]
fn test_sum_aggregate() {
use datafusion::logical_expr::Expr as DFExpr;
use hamelin_lib::func::def::{FunctionTranslationContext, SpecialPosition};
let mut bindings = TypeEnvironment::default();
bindings.bind("x".into(), hamelin_lib::types::INT);
let bindings = Arc::new(bindings);
let fctx = FunctionTranslationContext::default()
.with_special_allowed(SpecialPosition::Agg)
.with_special_allowed(SpecialPosition::Window);
let parsed = Expression::parse("sum(x)");
let expr = type_check_expression(
parsed,
ExpressionTypeCheckOptions::builder()
.bindings(bindings)
.fctx(fctx)
.build(),
)
.output;
let result = translate_expr(&expr).unwrap();
assert!(
matches!(result, DFExpr::AggregateFunction(_)),
"Expected AggregateFunction, got {:?}",
result
);
}
}