use hamelin_lib::err::{TranslationError, TranslationErrors};
use hamelin_lib::sql::expression::apply::Lambda as SQLLambda;
use hamelin_lib::sql::expression::apply::{BinaryOperatorApply, FunctionCallApply};
use hamelin_lib::sql::expression::identifier::{
Identifier as SQLIdentifier, SimpleIdentifier as SQLSimpleIdentifier,
};
use hamelin_lib::sql::expression::literal::StringLiteral;
use hamelin_lib::sql::expression::literal::{
ArrayLiteral as SQLArrayLiteral, RowLiteral as SQLRowLiteral,
};
use hamelin_lib::sql::expression::literal::{
BinaryLiteral as SQLBinaryLiteral, BooleanLiteral as SQLBooleanLiteral,
ColumnReference as SQLColumnReference, DecimalLiteral as SQLDecimalLiteral,
IntegerLiteral as SQLIntegerLiteral, IntervalLiteral as SQLIntervalLiteral,
NullLiteral as SQLNullLiteral, ScientificLiteral as SQLScientificLiteral,
StringLiteral as SQLStringLiteral, Unit as SQLIntervalUnit,
};
use hamelin_lib::sql::expression::operator::Operator;
use hamelin_lib::sql::expression::Cast as SQLCast;
use hamelin_lib::sql::expression::Dot;
use hamelin_lib::sql::expression::IndexLookup;
use hamelin_lib::sql::expression::{SQLExpression, TryCast};
use hamelin_lib::sql::types::{
SQLBaseType, SQLDecimalType, SQLRowType, SQLTimestampTzType, SQLType,
};
use hamelin_lib::translation::ExpressionTranslation;
use hamelin_lib::tree::ast::expression::{ExpressionKind, IntervalUnit, TruncUnit};
use hamelin_lib::tree::ast::identifier::ParsedSimpleIdentifier;
use hamelin_lib::tree::ast::identifier::SimpleIdentifier;
use hamelin_lib::tree::ast::node::Spannable;
use hamelin_lib::tree::typed_ast::expression::{
CastKind, FieldAccess, TypedApply, TypedArrayLiteral, TypedCast, TypedExpressionKind,
TypedFieldLookup, TypedFieldReference, TypedLambda, TypedStructLiteral, TypedTsTrunc,
TypedTupleLiteral, TypedVariantIndexAccess, VariantCastKind,
};
use hamelin_lib::types::Type;
use hamelin_translation::IRExpression;
use crate::context::TranslationContext;
/// Outcome of translating a single expression: the generated SQL expression on
/// success, or the `TranslationErrors` describing why translation failed.
pub type ExprTranslationResult = Result<SQLExpression, TranslationErrors>;
pub fn translate_expression(
ctx: &TranslationContext,
expr: &IRExpression,
) -> ExprTranslationResult {
let typed_expr = expr.inner();
match &typed_expr.kind {
TypedExpressionKind::Leaf => translate_leaf(ctx, expr),
TypedExpressionKind::FieldReference(col_ref) => translate_field_reference(col_ref, expr),
TypedExpressionKind::Cast(cast) => translate_cast(ctx, cast, expr),
TypedExpressionKind::TsTrunc(ts_trunc) => translate_ts_trunc(ctx, ts_trunc),
TypedExpressionKind::ArrayLiteral(arr) => translate_array_literal(ctx, arr),
TypedExpressionKind::TupleLiteral(tuple) => translate_tuple_literal(ctx, tuple),
TypedExpressionKind::StructLiteral(strct) => {
translate_struct_literal(ctx, strct, expr.resolved_type(), expr)
}
TypedExpressionKind::FieldLookup(field_lookup) => {
translate_field_lookup(ctx, field_lookup, expr)
}
TypedExpressionKind::VariantIndexAccess(variant_access) => {
translate_variant_index_access(ctx, variant_access, expr)
}
TypedExpressionKind::Lambda(lambda) => translate_lambda(ctx, lambda),
TypedExpressionKind::Apply(apply) => {
translate_apply(ctx, apply, expr.resolved_type(), expr)
}
TypedExpressionKind::BroadcastApply(_) => {
Err(TranslationError::msg(
expr,
"BroadcastApply should be lowered before translation (run lower_broadcast_apply pass)",
)
.single())
}
TypedExpressionKind::Error(err) => Err(err.error.as_ref().clone().single()),
}
}
fn translate_field_reference(
col_ref: &TypedFieldReference,
_span: &impl Spannable,
) -> ExprTranslationResult {
match &col_ref.field_name {
ParsedSimpleIdentifier::Valid(simple) => {
let sql_simple = SQLSimpleIdentifier::new(simple.as_str());
let sql_ident: SQLIdentifier = sql_simple.into();
Ok(SQLColumnReference::new(sql_ident).into())
}
ParsedSimpleIdentifier::Error(err) => Err(err.as_ref().clone().single()),
}
}
fn translate_cast(
ctx: &TranslationContext,
cast: &TypedCast,
span: &impl Spannable,
) -> ExprTranslationResult {
let inner_expr = IRExpression::new(cast.value.clone());
let inner_sql = translate_expression(ctx, &inner_expr)?;
match &cast.cast_kind {
CastKind::Identity => Ok(inner_sql),
CastKind::IntToDouble => Ok(try_cast(inner_sql, SQLBaseType::Double.into())),
CastKind::IntToDecimal => {
let sql_type = hamelin_type_to_sql_type(&cast.target_type, span)?;
Ok(try_cast(inner_sql, sql_type))
}
CastKind::DoubleToInt => Ok(try_cast(inner_sql, SQLBaseType::BigInt.into())),
CastKind::DoubleToDecimal => {
let sql_type = hamelin_type_to_sql_type(&cast.target_type, span)?;
Ok(try_cast(inner_sql, sql_type))
}
CastKind::DecimalToInt => Ok(try_cast(inner_sql, SQLBaseType::BigInt.into())),
CastKind::DecimalToDouble => Ok(try_cast(inner_sql, SQLBaseType::Double.into())),
CastKind::DecimalToDecimal => {
let sql_type = hamelin_type_to_sql_type(&cast.target_type, span)?;
Ok(try_cast(inner_sql, sql_type))
}
CastKind::IntToBoolean => Ok(try_cast(inner_sql, SQLBaseType::Boolean.into())),
CastKind::BooleanToInt => Ok(try_cast(inner_sql, SQLBaseType::BigInt.into())),
CastKind::ToStringFromInt
| CastKind::ToStringFromDouble
| CastKind::ToStringFromBoolean
| CastKind::ToStringFromTimestamp
| CastKind::ToStringFromBinary
| CastKind::ToStringFromDecimal
| CastKind::ToStringFromInterval
| CastKind::ToStringFromCalendarInterval => {
Ok(try_cast(inner_sql, SQLBaseType::VarChar.into()))
}
CastKind::StringToInt => Ok(try_cast(inner_sql, SQLBaseType::BigInt.into())),
CastKind::StringToDouble => Ok(try_cast(inner_sql, SQLBaseType::Double.into())),
CastKind::StringToBoolean => Ok(try_cast(inner_sql, SQLBaseType::Boolean.into())),
CastKind::StringToTimestamp => Ok(try_cast(inner_sql, SQLTimestampTzType::new(6).into())),
CastKind::StringToDecimal => {
let sql_type = hamelin_type_to_sql_type(&cast.target_type, span)?;
Ok(try_cast(inner_sql, sql_type))
}
CastKind::NullToType => {
let sql_type = hamelin_type_to_sql_type(&cast.target_type, span)?;
Ok(try_cast(inner_sql, sql_type))
}
CastKind::ToVariant(_) => {
Ok(try_cast(inner_sql, SQLBaseType::Json.into()))
}
CastKind::FromVariant(variant_kind) => {
translate_from_variant(inner_sql, variant_kind, span)
}
CastKind::ArrayElementCast(_)
| CastKind::TupleToStruct(_)
| CastKind::RangeElementCast(_)
| CastKind::IntervalToTimestampRange
| CastKind::TimestampToTimestampRange
| CastKind::IntervalRangeToTimestampRange
| CastKind::StructExpansion(_) => {
apply_cast_kind(inner_sql, &cast.cast_kind, &cast.target_type, span)
}
}
}
/// Wraps `expr` in a `TryCast` SQL node targeting the SQL type `to`.
fn try_cast(expr: SQLExpression, to: SQLType) -> SQLExpression {
    let node = TryCast::new(expr, to);
    node.into()
}
fn translate_from_variant(
inner_sql: SQLExpression,
variant_kind: &VariantCastKind,
span: &impl Spannable,
) -> ExprTranslationResult {
match variant_kind {
VariantCastKind::Int => {
let scalar = json_extract_scalar(inner_sql);
Ok(try_cast(scalar, SQLBaseType::BigInt.into()))
}
VariantCastKind::Double => {
let scalar = json_extract_scalar(inner_sql);
Ok(try_cast(scalar, SQLBaseType::Double.into()))
}
VariantCastKind::Decimal => {
let scalar = json_extract_scalar(inner_sql);
Ok(try_cast(scalar, SQLBaseType::Double.into()))
}
VariantCastKind::String => {
Ok(json_extract_scalar(inner_sql))
}
VariantCastKind::Boolean => {
let scalar = json_extract_scalar(inner_sql);
Ok(try_cast(scalar, SQLBaseType::Boolean.into()))
}
VariantCastKind::Array(element_kind) => {
translate_variant_to_array(inner_sql, element_kind, span)
}
VariantCastKind::Struct(fields) => translate_variant_to_struct(inner_sql, fields, span),
VariantCastKind::Map(key_kind, value_kind) => {
translate_variant_to_map(inner_sql, key_kind, value_kind, span)
}
VariantCastKind::Unknown => {
Ok(SQLNullLiteral {}.into())
}
VariantCastKind::Variant => {
Ok(inner_sql)
}
}
}
/// Builds a `json_extract_scalar(...)` call for `expr`.
///
/// If `expr` is itself a two-argument `json_extract(base, path)` call, the result is
/// `json_extract_scalar(base, path)` rather than a nested call. Otherwise the whole
/// value is extracted with the root path `'$'`.
fn json_extract_scalar(expr: SQLExpression) -> SQLExpression {
    let (base, path) = match try_extract_json_extract_args(&expr) {
        Some(args) => args,
        None => (expr, StringLiteral::new("$").into()),
    };
    FunctionCallApply::with_two("json_extract_scalar", base, path).into()
}
/// Returns clones of `(base, path)` when `expr` is a call to `json_extract` with
/// exactly two arguments, and `None` for any other expression.
fn try_extract_json_extract_args(expr: &SQLExpression) -> Option<(SQLExpression, SQLExpression)> {
    match expr {
        SQLExpression::FunctionCallApply(call)
            if call.function_name == "json_extract" && call.arguments.len() == 2 =>
        {
            let base = call.arguments[0].clone();
            let path = call.arguments[1].clone();
            Some((base, path))
        }
        _ => None,
    }
}
/// Converts a variant (JSON) expression into a typed array.
///
/// The value is first `TRY_CAST` to `ARRAY(JSON)`. Each element (bound to the lambda
/// parameter `e`) is then converted with `translate_variant_element_cast` inside a
/// `transform(...)` call.
fn translate_variant_to_array(
    inner_sql: SQLExpression,
    element_kind: &VariantCastKind,
    span: &impl Spannable,
) -> ExprTranslationResult {
    let param = SQLSimpleIdentifier::new("e");
    let element: SQLExpression = SQLColumnReference::new(param.clone().into()).into();
    let body = translate_variant_element_cast(element, element_kind, span)?;

    let json_array_type: SQLType =
        hamelin_lib::sql::types::SQLArrayType::new(SQLBaseType::Json.into()).into();
    let as_json_array: SQLExpression = TryCast::new(inner_sql, json_array_type).into();
    let mapper: SQLExpression = SQLLambda::new(vec![param], body).into();
    Ok(FunctionCallApply::with_two("transform", as_json_array, mapper).into())
}
fn translate_variant_element_cast(
element_expr: SQLExpression,
element_kind: &VariantCastKind,
span: &impl Spannable,
) -> ExprTranslationResult {
match element_kind {
VariantCastKind::Int => {
let scalar = json_extract_scalar(element_expr);
Ok(try_cast(scalar, SQLBaseType::BigInt.into()))
}
VariantCastKind::Double => {
let scalar = json_extract_scalar(element_expr);
Ok(try_cast(scalar, SQLBaseType::Double.into()))
}
VariantCastKind::Decimal => {
let scalar = json_extract_scalar(element_expr);
Ok(try_cast(scalar, SQLBaseType::Double.into()))
}
VariantCastKind::String => Ok(json_extract_scalar(element_expr)),
VariantCastKind::Boolean => {
let scalar = json_extract_scalar(element_expr);
Ok(try_cast(scalar, SQLBaseType::Boolean.into()))
}
VariantCastKind::Array(inner_element_kind) => {
translate_variant_to_array(element_expr, inner_element_kind, span)
}
VariantCastKind::Struct(fields) => translate_variant_to_struct(element_expr, fields, span),
VariantCastKind::Map(key_kind, value_kind) => {
translate_variant_to_map(element_expr, key_kind, value_kind, span)
}
VariantCastKind::Unknown => Ok(SQLNullLiteral {}.into()),
VariantCastKind::Variant => Ok(element_expr),
}
}
/// Casts a variant (JSON) expression to a SQL row type.
///
/// Each `(name, kind)` pair becomes a row field whose SQL type is computed by
/// `variant_cast_kind_to_sql_type`, and the whole value is wrapped in `TRY_CAST`.
fn translate_variant_to_struct(
    inner_sql: SQLExpression,
    fields: &[(String, VariantCastKind)],
    span: &impl Spannable,
) -> ExprTranslationResult {
    let bindings = fields
        .iter()
        .map(|(name, kind)| {
            let column = SQLSimpleIdentifier::new(name);
            variant_cast_kind_to_sql_type(kind, span).map(|sql_type| (column, sql_type))
        })
        .collect::<Result<ordermap::OrderMap<_, _>, _>>()?;
    let row_type: SQLType = SQLRowType::new(bindings).into();
    Ok(try_cast(inner_sql, row_type))
}
/// Casts a variant (JSON) expression to a SQL map.
///
/// The key and value SQL types come from `variant_cast_kind_to_sql_type`, and the
/// value is wrapped in `TRY_CAST`.
fn translate_variant_to_map(
    inner_sql: SQLExpression,
    key_kind: &VariantCastKind,
    value_kind: &VariantCastKind,
    span: &impl Spannable,
) -> ExprTranslationResult {
    let map_type = hamelin_lib::sql::types::SQLMapType::new(
        variant_cast_kind_to_sql_type(key_kind, span)?,
        variant_cast_kind_to_sql_type(value_kind, span)?,
    );
    Ok(try_cast(inner_sql, map_type.into()))
}
/// Reports whether `cast_kind` is a scalar conversion (including `Identity` and
/// `NullToType`), as opposed to a variant or structural cast.
///
/// `translate_array_element_cast` uses this to decide whether a whole array can be
/// converted with one cast instead of a per-element `transform`. The match is
/// exhaustive on purpose: adding a `CastKind` variant forces a decision here.
fn is_primitive_cast(cast_kind: &CastKind) -> bool {
    match cast_kind {
        CastKind::ToVariant(_)
        | CastKind::FromVariant(_)
        | CastKind::ArrayElementCast(_)
        | CastKind::TupleToStruct(_)
        | CastKind::RangeElementCast(_)
        | CastKind::IntervalToTimestampRange
        | CastKind::TimestampToTimestampRange
        | CastKind::IntervalRangeToTimestampRange
        | CastKind::StructExpansion(_) => false,
        CastKind::Identity
        | CastKind::IntToDouble
        | CastKind::IntToDecimal
        | CastKind::DoubleToInt
        | CastKind::DoubleToDecimal
        | CastKind::DecimalToInt
        | CastKind::DecimalToDouble
        | CastKind::DecimalToDecimal
        | CastKind::IntToBoolean
        | CastKind::BooleanToInt
        | CastKind::ToStringFromInt
        | CastKind::ToStringFromDouble
        | CastKind::ToStringFromBoolean
        | CastKind::ToStringFromTimestamp
        | CastKind::ToStringFromBinary
        | CastKind::ToStringFromDecimal
        | CastKind::ToStringFromInterval
        | CastKind::ToStringFromCalendarInterval
        | CastKind::StringToInt
        | CastKind::StringToDouble
        | CastKind::StringToBoolean
        | CastKind::StringToTimestamp
        | CastKind::StringToDecimal
        | CastKind::NullToType => true,
    }
}
/// Applies `inner_cast_kind` to every element of an array expression.
///
/// For primitive element casts the whole array is `TRY_CAST` to the SQL form of
/// `target_type` in one step. Otherwise the array is mapped with
/// `transform(array, x -> <cast of x>)`.
///
/// # Errors
/// Fails if `target_type` is not an array type, or if the element cast fails.
fn translate_array_element_cast(
    inner_sql: SQLExpression,
    inner_cast_kind: &CastKind,
    target_type: &Type,
    span: &impl Spannable,
) -> ExprTranslationResult {
    let element_type = if let Type::Array(array_type) = target_type {
        &array_type.element_type
    } else {
        return Err(
            TranslationError::msg(span, "ArrayElementCast target type is not an array").single(),
        );
    };

    if !is_primitive_cast(inner_cast_kind) {
        let param = SQLSimpleIdentifier::new("x");
        let element: SQLExpression = SQLColumnReference::new(param.clone().into()).into();
        let body = apply_cast_kind(element, inner_cast_kind, element_type, span)?;
        let mapper: SQLExpression = SQLLambda::new(vec![param], body).into();
        return Ok(FunctionCallApply::with_two("transform", inner_sql, mapper).into());
    }

    // Scalar element conversions can be expressed as a single cast of the array.
    let array_sql_type = hamelin_type_to_sql_type(target_type, span)?;
    Ok(try_cast(inner_sql, array_sql_type))
}
/// Applies `cast_kind` to an already-translated SQL expression.
///
/// Scalar conversions are `TRY_CAST` to the SQL form of `target_type`. Tuple-to-struct
/// uses a plain `CAST`. Array, struct, range, variant and timestamp-range casts are
/// delegated to their helpers.
fn apply_cast_kind(
    expr: SQLExpression,
    cast_kind: &CastKind,
    target_type: &Type,
    span: &impl Spannable,
) -> ExprTranslationResult {
    match cast_kind {
        CastKind::Identity => Ok(expr),
        CastKind::IntToDouble
        | CastKind::IntToDecimal
        | CastKind::DoubleToInt
        | CastKind::DoubleToDecimal
        | CastKind::DecimalToInt
        | CastKind::DecimalToDouble
        | CastKind::DecimalToDecimal
        | CastKind::IntToBoolean
        | CastKind::BooleanToInt
        | CastKind::ToStringFromInt
        | CastKind::ToStringFromDouble
        | CastKind::ToStringFromBoolean
        | CastKind::ToStringFromTimestamp
        | CastKind::ToStringFromBinary
        | CastKind::ToStringFromDecimal
        | CastKind::ToStringFromInterval
        | CastKind::ToStringFromCalendarInterval
        | CastKind::StringToInt
        | CastKind::StringToDouble
        | CastKind::StringToBoolean
        | CastKind::StringToTimestamp
        | CastKind::StringToDecimal
        | CastKind::NullToType => Ok(try_cast(expr, hamelin_type_to_sql_type(target_type, span)?)),
        CastKind::ArrayElementCast(element_cast) => {
            translate_array_element_cast(expr, element_cast, target_type, span)
        }
        CastKind::StructExpansion(field_casts) => {
            apply_struct_expansion(expr, field_casts, target_type, span)
        }
        CastKind::TupleToStruct(_) => {
            let row_type = hamelin_type_to_sql_type(target_type, span)?;
            Ok(SQLCast::new(expr, row_type).into())
        }
        CastKind::RangeElementCast(bound_cast) => {
            cast_range_bounds(expr, bound_cast, target_type, span)
        }
        CastKind::ToVariant(_) => Ok(try_cast(expr, SQLBaseType::Json.into())),
        CastKind::FromVariant(variant_kind) => translate_from_variant(expr, variant_kind, span),
        CastKind::IntervalToTimestampRange => translate_interval_to_timestamp_range(expr),
        CastKind::TimestampToTimestampRange => translate_timestamp_to_timestamp_range(expr),
        CastKind::IntervalRangeToTimestampRange => {
            translate_interval_range_to_timestamp_range(expr)
        }
    }
}

/// Casts both bounds of a range value (its `begin` and `end` fields) with `bound_cast`.
/// The result is rebuilt as `CAST(ROW(begin', end') AS ROW(begin T, end T))`, where
/// `T` is the SQL form of the target range's element type.
fn cast_range_bounds(
    range_expr: SQLExpression,
    bound_cast: &CastKind,
    target_type: &Type,
    span: &impl Spannable,
) -> ExprTranslationResult {
    let bound_type = match target_type {
        Type::Range(range) => &range.of,
        _ => {
            return Err(
                TranslationError::msg(span, "RangeElementCast target type is not a range")
                    .single(),
            )
        }
    };
    let bound_sql_type = hamelin_type_to_sql_type(bound_type, span)?;

    let begin_field: SQLExpression =
        Dot::new(range_expr.clone(), SQLSimpleIdentifier::new("begin").into()).into();
    let end_field: SQLExpression =
        Dot::new(range_expr, SQLSimpleIdentifier::new("end").into()).into();
    let begin = apply_cast_kind(begin_field, bound_cast, bound_type, span)?;
    let end = apply_cast_kind(end_field, bound_cast, bound_type, span)?;

    let bindings = ordermap::OrderMap::from_iter([
        (SQLSimpleIdentifier::new("begin"), bound_sql_type.clone()),
        (SQLSimpleIdentifier::new("end"), bound_sql_type),
    ]);
    let row_type: SQLType = SQLRowType::new(bindings).into();
    let row: SQLExpression = SQLRowLiteral::new(vec![begin, end]).into();
    Ok(SQLCast::new(row, row_type).into())
}
/// Rebuilds a struct value field by field according to `field_casts`.
///
/// Each listed field becomes one element of a row literal. `NullToType` fields become
/// a typed `CAST(NULL AS T)`. Array-of-struct fields whose element cast is itself a
/// struct expansion use the anonymous-row variant. All other fields apply their cast
/// to `source.field`. The row is finally `CAST` to a named row type built from the
/// target struct's field types.
///
/// # Errors
/// Fails if `target_type` is not a struct, if a listed field is missing from it,
/// or if any field type or field cast cannot be translated.
fn apply_struct_expansion(
    source_expr: SQLExpression,
    field_casts: &[(SimpleIdentifier, CastKind)],
    target_type: &Type,
    span: &impl Spannable,
) -> ExprTranslationResult {
    let target_struct = if let Type::Struct(s) = target_type {
        s
    } else {
        return Err(TranslationError::msg(
            span,
            &format!(
                "StructExpansion target type is not a struct: {:?}",
                target_type
            ),
        )
        .single());
    };

    let mut values = Vec::new();
    let mut column_types = ordermap::OrderMap::new();
    for (field_name, cast_kind) in field_casts {
        let column = SQLSimpleIdentifier::new(field_name.name());
        let field_type = target_struct.lookup(field_name).ok_or_else(|| {
            TranslationError::msg(
                span,
                &format!("Field '{}' must exist in target struct", field_name),
            )
            .single()
        })?;
        let column_type = hamelin_type_to_sql_type(field_type, span)?;

        let value: SQLExpression = if let CastKind::NullToType = cast_kind {
            // The field is absent from the source: emit a NULL of the target type.
            let null: SQLExpression = SQLNullLiteral {}.into();
            SQLCast::new(null, column_type.clone()).into()
        } else {
            let field_ref: SQLExpression =
                Dot::new(source_expr.clone(), column.clone().into()).into();
            match cast_kind {
                CastKind::ArrayElementCast(element_cast)
                    if matches!(element_cast.as_ref(), CastKind::StructExpansion(_)) =>
                {
                    translate_array_element_cast_with_anonymous_struct(
                        field_ref,
                        element_cast,
                        field_type,
                        span,
                    )?
                }
                other => apply_cast_kind(field_ref, other, field_type, span)?,
            }
        };

        values.push(value);
        column_types.insert(column, column_type);
    }

    let row: SQLExpression = SQLRowLiteral::new(values).into();
    let row_type: SQLType = SQLRowType::new(column_types).into();
    Ok(SQLCast::new(row, row_type).into())
}
/// Maps each element of an array through a struct expansion and produces plain
/// (un-cast) row literals: `transform(array, x -> ROW(...))`.
///
/// If `inner_cast_kind` is not a `StructExpansion`, this falls back to
/// `apply_cast_kind` on the whole array with `target_type`.
///
/// # Errors
/// Fails if `target_type` is not an array type, or if the element expansion fails.
fn translate_array_element_cast_with_anonymous_struct(
    inner_sql: SQLExpression,
    inner_cast_kind: &CastKind,
    target_type: &Type,
    span: &impl Spannable,
) -> ExprTranslationResult {
    let element_type = if let Type::Array(array_type) = target_type {
        &array_type.element_type
    } else {
        return Err(
            TranslationError::msg(span, "ArrayElementCast target type is not an array").single(),
        );
    };

    if let CastKind::StructExpansion(field_casts) = inner_cast_kind {
        let param = SQLSimpleIdentifier::new("x");
        let element: SQLExpression = SQLColumnReference::new(param.clone().into()).into();
        let body = build_anonymous_struct_expansion(element, field_casts, element_type, span)?;
        let mapper: SQLExpression = SQLLambda::new(vec![param], body).into();
        Ok(FunctionCallApply::with_two("transform", inner_sql, mapper).into())
    } else {
        apply_cast_kind(inner_sql, inner_cast_kind, target_type, span)
    }
}
fn build_anonymous_struct_expansion(
source_expr: SQLExpression,
field_casts: &[(SimpleIdentifier, CastKind)],
target_type: &Type,
span: &impl Spannable,
) -> ExprTranslationResult {
let target_struct = match target_type {
Type::Struct(s) => s,
_ => {
return Err(TranslationError::msg(
span,
&format!(
"StructExpansion target type is not a struct: {:?}",
target_type
),
)
.single())
}
};
let mut row_elements = Vec::new();
for (field_name, cast_kind) in field_casts {
let sql_name = SQLSimpleIdentifier::new(field_name.name());
let field_type = target_struct.lookup(field_name).ok_or_else(|| {
TranslationError::msg(
span,
&format!("Field '{}' must exist in target struct", field_name),
)
.single()
})?;
let field_expr = match cast_kind {
CastKind::NullToType => {
let sql_type = hamelin_type_to_sql_type(field_type, span)?;
let null_expr: SQLExpression = SQLNullLiteral {}.into();
SQLCast::new(null_expr, sql_type).into()
}
_ => {
let field_ref: SQLExpression =
Dot::new(source_expr.clone(), sql_name.clone().into()).into();
apply_cast_kind(field_ref, cast_kind, field_type, span)?
}
};
row_elements.push(field_expr);
}
Ok(SQLRowLiteral::new(row_elements).into())
}
/// Turns an interval into the timestamp range `[now(), now() + interval]`.
///
/// The same `now()` expression node is reused for both bounds. The row is cast to a
/// `begin`/`end` row of `SQLTimestampTzType::new(6)`.
fn translate_interval_to_timestamp_range(inner_sql: SQLExpression) -> ExprTranslationResult {
    let now: SQLExpression = FunctionCallApply::with_no_arguments("now").into();
    let end: SQLExpression =
        BinaryOperatorApply::new(Operator::Plus, now.clone(), inner_sql).into();

    let bound_type: SQLType = SQLTimestampTzType::new(6).into();
    let bindings = ordermap::OrderMap::from_iter([
        (SQLSimpleIdentifier::new("begin"), bound_type.clone()),
        (SQLSimpleIdentifier::new("end"), bound_type),
    ]);
    let row_type: SQLType = SQLRowType::new(bindings).into();
    let row: SQLExpression = SQLRowLiteral::new(vec![now, end]).into();
    Ok(SQLCast::new(row, row_type).into())
}
/// Widens a single timestamp into the degenerate range `[ts, ts]`.
///
/// Both bounds are the same SQL expression. The row is cast to a `begin`/`end` row of
/// `SQLTimestampTzType::new(6)`.
fn translate_timestamp_to_timestamp_range(inner_sql: SQLExpression) -> ExprTranslationResult {
    let bound_type: SQLType = SQLTimestampTzType::new(6).into();
    let bindings = ordermap::OrderMap::from_iter([
        (SQLSimpleIdentifier::new("begin"), bound_type.clone()),
        (SQLSimpleIdentifier::new("end"), bound_type),
    ]);
    let row_type: SQLType = SQLRowType::new(bindings).into();

    let begin = inner_sql.clone();
    let row: SQLExpression = SQLRowLiteral::new(vec![begin, inner_sql]).into();
    Ok(SQLCast::new(row, row_type).into())
}
/// Anchors a range of intervals at the current time: `[now() + r.begin, now() + r.end]`.
///
/// The same `now()` expression node is used for both bounds. The row is cast to a
/// `begin`/`end` row of `SQLTimestampTzType::new(6)`.
fn translate_interval_range_to_timestamp_range(inner_sql: SQLExpression) -> ExprTranslationResult {
    let now: SQLExpression = FunctionCallApply::with_no_arguments("now").into();
    let offset_of = |field: &str, range: SQLExpression| -> SQLExpression {
        Dot::new(range, SQLSimpleIdentifier::new(field).into()).into()
    };
    let begin_offset = offset_of("begin", inner_sql.clone());
    let end_offset = offset_of("end", inner_sql);

    let begin: SQLExpression =
        BinaryOperatorApply::new(Operator::Plus, now.clone(), begin_offset).into();
    let end: SQLExpression = BinaryOperatorApply::new(Operator::Plus, now, end_offset).into();

    let bound_type: SQLType = SQLTimestampTzType::new(6).into();
    let bindings = ordermap::OrderMap::from_iter([
        (SQLSimpleIdentifier::new("begin"), bound_type.clone()),
        (SQLSimpleIdentifier::new("end"), bound_type),
    ]);
    let row_type: SQLType = SQLRowType::new(bindings).into();
    let row: SQLExpression = SQLRowLiteral::new(vec![begin, end]).into();
    Ok(SQLCast::new(row, row_type).into())
}
/// Maps a variant cast target to the SQL type used when casting JSON into it.
///
/// Decimal maps to DOUBLE, matching how `translate_variant_element_cast` extracts
/// decimals. `Unknown` and `Variant` stay JSON. Arrays, structs and maps recurse
/// into their element/field/key/value kinds.
fn variant_cast_kind_to_sql_type(
    kind: &VariantCastKind,
    span: &impl Spannable,
) -> Result<SQLType, TranslationErrors> {
    let sql_type: SQLType = match kind {
        VariantCastKind::Int => SQLBaseType::BigInt.into(),
        VariantCastKind::Double | VariantCastKind::Decimal => SQLBaseType::Double.into(),
        VariantCastKind::String => SQLBaseType::VarChar.into(),
        VariantCastKind::Boolean => SQLBaseType::Boolean.into(),
        VariantCastKind::Unknown | VariantCastKind::Variant => SQLBaseType::Json.into(),
        VariantCastKind::Array(element_kind) => {
            let element_type = variant_cast_kind_to_sql_type(element_kind, span)?;
            hamelin_lib::sql::types::SQLArrayType::new(element_type).into()
        }
        VariantCastKind::Struct(fields) => {
            let bindings = fields
                .iter()
                .map(|(name, field_kind)| {
                    let column = SQLSimpleIdentifier::new(name);
                    variant_cast_kind_to_sql_type(field_kind, span).map(|t| (column, t))
                })
                .collect::<Result<ordermap::OrderMap<_, _>, _>>()?;
            SQLRowType::new(bindings).into()
        }
        VariantCastKind::Map(key_kind, value_kind) => {
            let key_type = variant_cast_kind_to_sql_type(key_kind, span)?;
            let value_type = variant_cast_kind_to_sql_type(value_kind, span)?;
            hamelin_lib::sql::types::SQLMapType::new(key_type, value_type).into()
        }
    };
    Ok(sql_type)
}
fn translate_ts_trunc(ctx: &TranslationContext, ts_trunc: &TypedTsTrunc) -> ExprTranslationResult {
let inner_expr = IRExpression::new(ts_trunc.expression.clone());
let inner_sql = translate_expression(ctx, &inner_expr)?;
let unit_str = match ts_trunc.unit {
TruncUnit::Second => "second",
TruncUnit::Minute => "minute",
TruncUnit::Hour => "hour",
TruncUnit::Day => "day",
TruncUnit::Week => "week",
TruncUnit::Month => "month",
TruncUnit::Quarter => "quarter",
TruncUnit::Year => "year",
};
Ok(
FunctionCallApply::with_two("date_trunc", StringLiteral::new(unit_str).into(), inner_sql)
.into(),
)
}
/// Translates a typed lambda into a SQL lambda node that keeps the same parameter
/// names and has a translated body.
fn translate_lambda(ctx: &TranslationContext, lambda: &TypedLambda) -> ExprTranslationResult {
    let mut params: Vec<SQLSimpleIdentifier> = Vec::new();
    for parameter in lambda.parameters.iter() {
        params.push(parameter.name.clone().into());
    }
    let body = IRExpression::new(lambda.body.clone());
    let body_sql = translate_expression(ctx, &body)?;
    let sql_lambda = SQLLambda::new(params, body_sql);
    Ok(sql_lambda.into())
}
/// Translates a function application.
///
/// Every bound argument is translated into an `ExpressionTranslation` (SQL, resolved
/// type and source span). The function registry then renders the call for
/// `resolved_type`. A registry failure is reported as a single
/// "Function translation failed" error at `span`.
fn translate_apply(
    ctx: &TranslationContext,
    apply: &TypedApply,
    resolved_type: &Type,
    span: &impl Spannable,
) -> ExprTranslationResult {
    let arguments = apply
        .parameter_binding
        .clone()
        .try_map(|argument| {
            let sql = translate_expression(ctx, &IRExpression::new(argument.clone()))?;
            Ok(ExpressionTranslation {
                sql,
                typ: (*argument.resolved_type).clone(),
                span: argument.ast.span.to_range(),
                special: None,
                nested_special: Vec::new(),
            })
        })
        // Pins the closure's error type for inference.
        .map_err(|e: TranslationErrors| e)?;

    let translation = ctx
        .registry
        .translate(
            apply.function_def.as_ref(),
            apply.function_def.name(),
            &ctx.fctx,
            arguments,
            resolved_type.clone(),
        )
        .map_err(|err| {
            TranslationError::msg(span, &format!("Function translation failed: {}", err)).single()
        })?;
    Ok(translation.sql)
}
/// Translates a field or element lookup on a typed value into SQL.
///
/// The looked-up value is translated first, then the access becomes:
/// - struct field: a dot access `value.field`;
/// - tuple element: a subscript `value[index + 1]` (the `+ 1` presumably converts the
///   0-based tuple index to SQL's 1-based row subscripts);
/// - variant field: `json_extract(value, '$["field"]')`. Any `"` or `\` in the field
///   name is backslash-escaped so the JSON path stays well-formed;
/// - range bound: `value.begin` / `value.end`.
///
/// Field names that failed to parse are reported as their stored error. The span
/// parameter is currently unused.
fn translate_field_lookup(
    ctx: &TranslationContext,
    field_lookup: &TypedFieldLookup,
    _span: &impl Spannable,
) -> ExprTranslationResult {
    let inner_expr = IRExpression::new(field_lookup.value.clone());
    let inner_sql = translate_expression(ctx, &inner_expr)?;
    match &field_lookup.access {
        FieldAccess::StructField(parsed_id) => match parsed_id {
            ParsedSimpleIdentifier::Valid(id) => {
                let sql_ident: SQLIdentifier = SQLSimpleIdentifier::new(id.as_str()).into();
                Ok(Dot::new(inner_sql, sql_ident).into())
            }
            ParsedSimpleIdentifier::Error(err) => Err(err.as_ref().clone().single()),
        },
        FieldAccess::TupleElement(index) => Ok(IndexLookup::new(
            inner_sql,
            SQLIntegerLiteral::from_int((index + 1) as i64).into(),
        )
        .to_sql_expression()),
        FieldAccess::VariantField(parsed_id) => match parsed_id {
            ParsedSimpleIdentifier::Valid(id) => {
                let path = json_member_path(id.as_str());
                Ok(FunctionCallApply::with_two(
                    "json_extract",
                    inner_sql,
                    StringLiteral::new(&path).into(),
                )
                .into())
            }
            ParsedSimpleIdentifier::Error(err) => Err(err.as_ref().clone().single()),
        },
        FieldAccess::RangeBegin => {
            let sql_ident: SQLIdentifier = SQLSimpleIdentifier::new("begin").into();
            Ok(Dot::new(inner_sql, sql_ident).into())
        }
        FieldAccess::RangeEnd => {
            let sql_ident: SQLIdentifier = SQLSimpleIdentifier::new("end").into();
            Ok(Dot::new(inner_sql, sql_ident).into())
        }
    }
}

/// Builds the JSON path `$["key"]` that selects member `key` of a JSON object.
///
/// Backslashes and double quotes in `key` are prefixed with a backslash. Left
/// unescaped, a `"` would end the quoted subscript early and a `\` would start an
/// escape sequence, giving a wrong or invalid path. Keys without those characters
/// produce exactly `$["key"]`.
fn json_member_path(key: &str) -> String {
    let mut path = String::with_capacity(key.len() + 6);
    path.push_str("$[\"");
    for ch in key.chars() {
        if ch == '"' || ch == '\\' {
            path.push('\\');
        }
        path.push(ch);
    }
    path.push_str("\"]");
    path
}
fn translate_variant_index_access(
ctx: &TranslationContext,
variant_access: &TypedVariantIndexAccess,
_span: &impl Spannable,
) -> ExprTranslationResult {
let inner_expr = IRExpression::new(variant_access.value.clone());
let inner_sql = translate_expression(ctx, &inner_expr)?;
let path = format!("$[{}]", variant_access.variant_index);
Ok(
FunctionCallApply::with_two("json_extract", inner_sql, StringLiteral::new(&path).into())
.into(),
)
}
/// Translates an array literal by translating each element in order. Stops at the
/// first element that fails.
fn translate_array_literal(
    ctx: &TranslationContext,
    arr: &TypedArrayLiteral,
) -> ExprTranslationResult {
    let elements = arr
        .elements
        .iter()
        .map(|element| translate_expression(ctx, &IRExpression::new(element.clone())))
        .collect::<Result<Vec<_>, _>>()?;
    Ok(SQLArrayLiteral::new(elements).into())
}
/// Translates a tuple literal into an anonymous SQL `ROW(...)` of its translated
/// elements, stopping at the first element that fails.
fn translate_tuple_literal(
    ctx: &TranslationContext,
    tuple: &TypedTupleLiteral,
) -> ExprTranslationResult {
    let values = tuple
        .elements
        .iter()
        .map(|element| translate_expression(ctx, &IRExpression::new(element.clone())))
        .collect::<Result<Vec<_>, _>>()?;
    Ok(SQLRowLiteral::new(values).into())
}
/// Translates a struct literal into `CAST(ROW(v1, v2, ...) AS <row type>)`.
///
/// Field values are translated in declaration order. The row type is the SQL form of
/// `resolved_type`, which gives the row its field names.
fn translate_struct_literal(
    ctx: &TranslationContext,
    strct: &TypedStructLiteral,
    resolved_type: &Type,
    span: &impl Spannable,
) -> ExprTranslationResult {
    let values = strct
        .fields
        .iter()
        .map(|(_, field_expr)| translate_expression(ctx, &IRExpression::new(field_expr.clone())))
        .collect::<Result<Vec<_>, _>>()?;
    let struct_type = hamelin_type_to_sql_type(resolved_type, span)?;
    let row: SQLExpression = SQLRowLiteral::new(values).into();
    Ok(SQLCast::new(row, struct_type).into())
}
fn hamelin_type_to_sql_type(
typ: &Type,
span: &impl Spannable,
) -> Result<SQLType, TranslationErrors> {
match typ {
Type::Int => Ok(SQLBaseType::BigInt.into()),
Type::Double => Ok(SQLBaseType::Double.into()),
Type::String => Ok(SQLBaseType::VarChar.into()),
Type::Boolean => Ok(SQLBaseType::Boolean.into()),
Type::Timestamp => Ok(SQLTimestampTzType::new(6).into()),
Type::Binary => Ok(SQLBaseType::VarBinary.into()),
Type::Interval => Ok(SQLBaseType::IntervalDayToSecond.into()),
Type::CalendarInterval => Ok(SQLBaseType::IntervalYearToMonth.into()),
Type::Variant => Ok(SQLBaseType::Json.into()),
Type::Decimal(decimal) => {
Ok(SQLDecimalType::new(decimal.precision as u64, decimal.scale as u64).into())
}
Type::Array(arr) => {
let elem_type = hamelin_type_to_sql_type(&arr.element_type, span)?;
Ok(hamelin_lib::sql::types::SQLArrayType::new(elem_type).into())
}
Type::Struct(s) => {
let mut bindings = ordermap::OrderMap::new();
for (name, field_type) in s.iter() {
let sql_name = SQLSimpleIdentifier::new(name.name());
let sql_type = hamelin_type_to_sql_type(field_type, span)?;
bindings.insert(sql_name, sql_type);
}
Ok(SQLRowType::new(bindings).into())
}
Type::Map(map) => {
let key_type = hamelin_type_to_sql_type(&map.key_type, span)?;
let value_type = hamelin_type_to_sql_type(&map.value_type, span)?;
Ok(hamelin_lib::sql::types::SQLMapType::new(key_type, value_type).into())
}
Type::Range(range) => {
let elem_type = hamelin_type_to_sql_type(&range.of, span)?;
let bindings = ordermap::OrderMap::from_iter([
(SQLSimpleIdentifier::new("begin"), elem_type.clone()),
(SQLSimpleIdentifier::new("end"), elem_type),
]);
Ok(SQLRowType::new(bindings).into())
}
Type::Tuple(tuple) => {
let mut elements = Vec::new();
for elem_type in tuple.elements.iter() {
elements.push(hamelin_type_to_sql_type(elem_type, span)?);
}
Ok(hamelin_lib::sql::types::SQLAnonRowType::new(elements).into())
}
Type::Unknown | Type::Rows | Type::Function(_) => Err(TranslationError::msg(
span,
&format!("Cannot convert type {} to SQL", typ),
)
.single()),
}
}
/// Translates a leaf (literal) expression into a SQL literal.
///
/// Decimal literals are rendered by `format_decimal`. Binary literals are rendered as
/// a quoted lowercase hex string. Intervals go through `translate_interval_literal`.
/// Unbound ranges become a `CAST(ROW(NULL, NULL) AS ROW(begin T, end T))`, using the
/// SQL `Unknown` type when the range element type is unknown. Non-leaf kinds and
/// `RowsLiteral` are reported as errors, and `Error` nodes surface their stored error.
fn translate_leaf(_ctx: &TranslationContext, expr: &IRExpression) -> ExprTranslationResult {
    let typed = expr.inner();
    let ast_kind = &typed.ast.kind;
    let sql: SQLExpression = match ast_kind {
        ExpressionKind::IntLiteral(lit) => SQLIntegerLiteral::from_int(lit.int).into(),
        ExpressionKind::DecimalLiteral(lit) => {
            let text = format_decimal(lit.unscaled_value, lit.scale);
            SQLDecimalLiteral::new(&text).into()
        }
        ExpressionKind::ScientificLiteral(lit) => {
            SQLScientificLiteral::new(&lit.value.to_string()).into()
        }
        ExpressionKind::BooleanLiteral(lit) => SQLBooleanLiteral::new(lit.value).into(),
        ExpressionKind::StringLiteral(lit) => SQLStringLiteral::new(&lit.value).into(),
        ExpressionKind::BinaryLiteral(lit) => {
            let hex: String = lit.value.iter().map(|byte| format!("{:02x}", byte)).collect();
            SQLBinaryLiteral::new(&format!("'{}'", hex)).into()
        }
        ExpressionKind::NullLiteral(_) => SQLNullLiteral::default().into(),
        ExpressionKind::IntervalLiteral(lit) => translate_interval_literal(lit.value, &lit.unit),
        ExpressionKind::UnboundRangeLiteral(_) => {
            let range_type = match typed.resolved_type.as_ref() {
                Type::Range(range) => range,
                _ => {
                    return Err(TranslationError::msg(
                        expr,
                        "UnboundRangeLiteral should have Range type",
                    )
                    .single())
                }
            };
            let bound_type: SQLType = if *range_type.of == Type::Unknown {
                SQLBaseType::Unknown.into()
            } else {
                hamelin_type_to_sql_type(&range_type.of, expr)?
            };
            let bindings = ordermap::OrderMap::from_iter([
                (SQLSimpleIdentifier::new("begin"), bound_type.clone()),
                (SQLSimpleIdentifier::new("end"), bound_type),
            ]);
            let row_type: SQLType = SQLRowType::new(bindings).into();
            let nulls: SQLExpression = SQLRowLiteral::new(vec![
                SQLNullLiteral::default().into(),
                SQLNullLiteral::default().into(),
            ])
            .into();
            SQLCast::new(nulls, row_type).into()
        }
        ExpressionKind::RowsLiteral(_) => {
            return Err(TranslationError::msg(
                expr,
                "RowsLiteral should have been lowered before translation",
            )
            .single())
        }
        ExpressionKind::Error(err) => return Err(err.error.as_ref().clone().single()),
        ExpressionKind::ArrayLiteral(_)
        | ExpressionKind::TupleLiteral(_)
        | ExpressionKind::PairLiteral(_)
        | ExpressionKind::StructLiteral(_)
        | ExpressionKind::FieldReference(_)
        | ExpressionKind::UnaryPrefixOperator(_)
        | ExpressionKind::UnaryPostfixOperator(_)
        | ExpressionKind::BinaryOperator(_)
        | ExpressionKind::FunctionCall(_)
        | ExpressionKind::IndexAccess(_)
        | ExpressionKind::FieldLookup(_)
        | ExpressionKind::Cast(_)
        | ExpressionKind::TsTrunc(_)
        | ExpressionKind::Lambda(_) => {
            return Err(TranslationError::msg(
                expr,
                &format!(
                    "Expected leaf literal, got {:?}",
                    std::mem::discriminant(ast_kind)
                ),
            )
            .single())
        }
    };
    Ok(sql)
}
/// Renders the decimal `unscaled_value / 10^scale` as a plain positional string.
///
/// At least one integer digit is always present (`0.05`, not `.05`), and a negative
/// value gets a leading `-`. A scale of 0 still produces a trailing `.0`
/// (`12345` -> `"12345.0"`), presumably so the text reads as a decimal rather than
/// an integer literal. `unsigned_abs` keeps `i128::MIN` from overflowing.
fn format_decimal(unscaled_value: i128, scale: u32) -> String {
    if scale == 0 {
        return format!("{}.0", unscaled_value);
    }
    let sign = if unscaled_value < 0 { "-" } else { "" };
    let digits = unscaled_value.unsigned_abs().to_string();
    let frac_len = scale as usize;
    // Left-pad with zeros so at least one digit remains before the point.
    let padded = if digits.len() > frac_len {
        digits
    } else {
        let mut zeros = "0".repeat(frac_len + 1 - digits.len());
        zeros.push_str(&digits);
        zeros
    };
    let (whole, fraction) = padded.split_at(padded.len() - frac_len);
    format!("{}{}.{}", sign, whole, fraction)
}
fn translate_interval_literal(value: i64, unit: &IntervalUnit) -> SQLExpression {
match unit {
IntervalUnit::Nanosecond => parse_duration(value, "ns"),
IntervalUnit::Microsecond => parse_duration(value, "us"),
IntervalUnit::Millisecond => parse_duration(value, "ms"),
IntervalUnit::Second => parse_duration(value, "s"),
IntervalUnit::Minute => parse_duration(value, "m"),
IntervalUnit::Hour => parse_duration(value, "h"),
IntervalUnit::Day => parse_duration(value, "d"),
IntervalUnit::Week => BinaryOperatorApply::new(
Operator::Asterisk,
SQLIntegerLiteral::new("7").into(),
SQLIntervalLiteral::new(value, SQLIntervalUnit::Day).into(),
)
.into(),
IntervalUnit::Month => SQLIntervalLiteral::new(value, SQLIntervalUnit::Month).into(),
IntervalUnit::Quarter => BinaryOperatorApply::new(
Operator::Asterisk,
SQLIntegerLiteral::new("3").into(),
SQLIntervalLiteral::new(value, SQLIntervalUnit::Month).into(),
)
.into(),
IntervalUnit::Year => SQLIntervalLiteral::new(value, SQLIntervalUnit::Year).into(),
}
}
fn parse_duration(value: i64, unit: &str) -> SQLExpression {
FunctionCallApply::with_one(
"parse_duration",
SQLStringLiteral::new(&format!("{}{}", value, unit)).into(),
)
.into()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_decimal() {
        // (unscaled value, scale, expected rendering)
        let cases: [(i128, u32, &str); 5] = [
            (12345, 2, "123.45"),
            (12345, 0, "12345.0"),
            (123, 5, "0.00123"),
            (-12345, 2, "-123.45"),
            (0, 2, "0.00"),
        ];
        for &(unscaled, scale, expected) in cases.iter() {
            assert_eq!(
                format_decimal(unscaled, scale),
                expected,
                "format_decimal({}, {})",
                unscaled,
                scale
            );
        }
    }
}