use std::sync::Arc;
use hamelin_lib::tree::{
ast::{command::Command, expression::Expression, identifier::ParsedSimpleIdentifier},
builder::{cast, drop_command, field_ref, set_command, struct_literal, ExpressionBuilder},
typed_ast::{
environment::TypeEnvironment,
expression::{TypedExpression, TypedExpressionKind, TypedStructLiteral},
},
};
#[cfg(test)]
use hamelin_lib::type_check_expression;
use hamelin_lib::types::{array::Array, struct_type::Struct, Type};
use crate::unique::UniqueNameGenerator;
/// Returns `true` when `expr` is a bare field reference, or a sequence of
/// field lookups whose innermost value is a field reference.
///
/// Any other expression kind anywhere along the chain yields `false`.
pub fn is_field_reference_chain(expr: &TypedExpression) -> bool {
    // Walk down through the lookups iteratively until something that is not a
    // lookup is reached; that innermost node decides the answer.
    let mut current = expr;
    loop {
        match &current.kind {
            TypedExpressionKind::FieldLookup(lookup) => current = &*lookup.value,
            TypedExpressionKind::FieldReference(_) => return true,
            _ => return false,
        }
    }
}
/// Decides whether `expr` is "simple" for the purposes of struct expansion.
///
/// An expression is simple when it is:
/// - a field reference or a `Leaf` node,
/// - a field lookup whose value is a field-reference chain,
/// - a cast whose operand is itself simple,
/// - a struct literal whose field values are all simple, or
/// - an array literal whose elements are all simple.
///
/// Every other expression kind is not simple.
fn is_simple_expression(expr: &TypedExpression) -> bool {
    match &expr.kind {
        TypedExpressionKind::FieldReference(_) | TypedExpressionKind::Leaf => true,
        TypedExpressionKind::FieldLookup(access) => is_field_reference_chain(&access.value),
        TypedExpressionKind::Cast(inner) => is_simple_expression(&inner.value),
        TypedExpressionKind::StructLiteral(literal) => {
            // Stop at the first field value that is not simple.
            for (_, value) in literal.fields.iter() {
                if !is_simple_expression(value) {
                    return false;
                }
            }
            true
        }
        TypedExpressionKind::ArrayLiteral(array) => !array
            .elements
            .iter()
            .any(|element| !is_simple_expression(element)),
        _ => false,
    }
}
/// Produces an expression of struct type `target_type` from `expr`, whose
/// struct type is `source_type`.
///
/// `transformed_ast`, when present, is used in place of `expr.ast` as the AST
/// to emit; `expr` is still consulted for its typed kind.
///
/// Returns `(expression, set_commands, drop_commands)`:
/// - If the two struct types are equal, the AST is returned unchanged with no
///   commands.
/// - A struct literal is rebuilt field by field via `expand_struct_literal`.
/// - A field-reference chain or otherwise simple expression is wrapped in a
///   cast to `target_type`, with no commands.
/// - Anything else is hoisted: a SET command binds the expression to a fresh
///   name obtained from `name_gen`, the result is a cast of a reference to that
///   name, and a DROP command for the same name is returned in the last vector.
pub fn expand_struct_to_type_with_ast(
    expr: &Arc<TypedExpression>,
    transformed_ast: Option<&Arc<Expression>>,
    source_type: &Struct,
    target_type: &Struct,
    name_gen: &mut UniqueNameGenerator,
    schema: &TypeEnvironment,
) -> (Arc<Expression>, Vec<Command>, Vec<Command>) {
    let effective_ast = match transformed_ast {
        Some(rewritten) => rewritten,
        None => &expr.ast,
    };

    // Identical struct types: nothing to widen.
    if source_type == target_type {
        return (effective_ast.clone(), Vec::new(), Vec::new());
    }

    // Struct literals are rebuilt in place against the target's field list.
    if let TypedExpressionKind::StructLiteral(literal) = &expr.kind {
        // Only hand over transformed fields when the transformed AST is itself
        // a struct literal.
        let rewritten_fields = match transformed_ast.map(|node| &node.kind) {
            Some(hamelin_lib::tree::ast::expression::ExpressionKind::StructLiteral(parsed)) => {
                Some(&parsed.fields)
            }
            _ => None,
        };
        let (rebuilt, setup, teardown) =
            expand_struct_literal(literal, rewritten_fields, target_type, name_gen, schema);
        return (Arc::new(rebuilt), setup, teardown);
    }

    // Simple expressions are cast inline without hoisting.
    if is_field_reference_chain(expr) || is_simple_expression(expr) {
        let inline_cast = cast(effective_ast.clone(), target_type.clone().into()).build();
        return (Arc::new(inline_cast), Vec::new(), Vec::new());
    }

    // Everything else: bind to a temporary field, cast a reference to it, and
    // return the matching DROP for that temporary.
    let temp_name = name_gen.next(schema);
    let bind_temp = set_command()
        .named_field(temp_name.clone(), effective_ast.clone())
        .build();
    let release_temp = drop_command().field(temp_name.clone()).build();
    let cast_of_temp = cast(field_ref(temp_name.as_str()), target_type.clone().into()).build();
    (Arc::new(cast_of_temp), vec![bind_temp], vec![release_temp])
}
/// Rebuilds the struct literal `lit` so that it has exactly the fields of
/// `target_type`, in `target_type`'s iteration order.
///
/// For each target field:
/// - absent from the literal: emitted as a `null` cast to the field's type;
/// - present, with both source and target being structs that differ: expanded
///   recursively through `expand_struct_to_type_with_ast`, accumulating any
///   commands it returns;
/// - present, with both being arrays whose element types are differing
///   structs: wrapped in a cast to an array of the target element type;
/// - otherwise: passed through unchanged.
///
/// `transformed_fields`, when given, supplies replacement ASTs for the
/// literal's fields. It is indexed by each field's position in `lit.fields`,
/// so it is assumed to be positionally aligned with `lit.fields`; a shorter
/// vector would panic on indexing.
///
/// Literal fields that do not appear in `target_type` are not emitted.
fn expand_struct_literal(
    lit: &TypedStructLiteral,
    transformed_fields: Option<&Vec<(ParsedSimpleIdentifier, Arc<Expression>)>>,
    target_type: &Struct,
    name_gen: &mut UniqueNameGenerator,
    schema: &TypeEnvironment,
) -> (Expression, Vec<Command>, Vec<Command>) {
    let mut literal = struct_literal();
    let mut setup_commands = Vec::new();
    let mut teardown_commands = Vec::new();

    for (target_name, target_field_type) in target_type.iter() {
        // First literal field whose (valid) name equals the target field name.
        let located = lit.fields.iter().enumerate().find(|(_, (candidate, _))| {
            candidate
                .valid_ref()
                .map(|valid| valid.as_str() == target_name.name())
                .unwrap_or(false)
        });

        match located {
            None => {
                // The literal has no such field: fill the slot with a typed null.
                literal = literal.field(
                    target_name.name(),
                    cast(hamelin_lib::tree::builder::null(), target_field_type.clone()),
                );
            }
            Some((position, (_, field_expr))) => {
                let transformed_field_ast = transformed_fields.map(|fields| &fields[position].1);
                // AST to emit when the field needs no structural change.
                let current_ast = || transformed_field_ast.unwrap_or(&field_expr.ast).clone();

                match (target_field_type, field_expr.resolved_type.as_ref()) {
                    (Type::Struct(target_nested), Type::Struct(source_nested))
                        if source_nested != target_nested =>
                    {
                        let (expanded, before, after) = expand_struct_to_type_with_ast(
                            field_expr,
                            transformed_field_ast,
                            source_nested,
                            target_nested,
                            name_gen,
                            schema,
                        );
                        setup_commands.extend(before);
                        teardown_commands.extend(after);
                        literal = literal.field(target_name.name(), expanded);
                    }
                    (Type::Array(target_arr), Type::Array(source_arr)) => {
                        match (
                            source_arr.element_type.as_ref(),
                            target_arr.element_type.as_ref(),
                        ) {
                            (Type::Struct(source_elem), Type::Struct(target_elem))
                                if source_elem != target_elem =>
                            {
                                let field_ast = current_ast();
                                let target_array_type: Type =
                                    Array::new(target_elem.clone().into()).into();
                                let cast_expr = cast(field_ast, target_array_type).build();
                                literal = literal.field(target_name.name(), cast_expr);
                            }
                            _ => {
                                literal = literal.field(target_name.name(), current_ast());
                            }
                        }
                    }
                    _ => {
                        literal = literal.field(target_name.name(), current_ast());
                    }
                }
            }
        }
    }

    (literal.build(), setup_commands, teardown_commands)
}
/// Builds the expression that yields field `field_name` with type
/// `target_type`.
///
/// - `source_type` is `None`: a `null` cast to `target_type`.
/// - Source and target are both structs and differ: a cast of the field
///   reference to `target_type`.
/// - Source and target are both arrays whose element types are both structs
///   and differ: a cast of the field reference to `target_type`.
/// - Anything else: a plain field reference.
pub fn build_widening_expression(
    field_name: &str,
    source_type: Option<&Type>,
    target_type: &Type,
) -> Expression {
    let source = match source_type {
        Some(existing) => existing,
        // No source field at all: synthesise a typed null.
        None => return cast(hamelin_lib::tree::builder::null(), target_type.clone()).build(),
    };

    // A cast is only needed when struct shapes differ, either directly or as
    // the element type of an array.
    let needs_cast = match (source, target_type) {
        (Type::Struct(from), Type::Struct(to)) => from != to,
        (Type::Array(from), Type::Array(to)) => {
            match (from.element_type.as_ref(), to.element_type.as_ref()) {
                (Type::Struct(from_elem), Type::Struct(to_elem)) => from_elem != to_elem,
                _ => false,
            }
        }
        _ => false,
    };

    if needs_cast {
        cast(field_ref(field_name), target_type.clone()).build()
    } else {
        field_ref(field_name).build()
    }
}
// Unit tests for the struct-expansion helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use hamelin_lib::tree::ast::expression::Expression;
use hamelin_lib::tree::builder::{call, cast, field_ref, null};
use hamelin_lib::tree::options::ExpressionTypeCheckOptions;
use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
use hamelin_lib::tree::typed_ast::expression::TypedExpression;
use hamelin_lib::types::array::Array;
use hamelin_lib::types::INT;
use pretty_assertions::assert_eq;
use rstest::rstest;
// Table-driven check of `build_widening_expression`. Each case supplies
// (field name, optional source type, target type, expected expression):
// - no source type             -> null cast to the target type
// - source type equals target  -> plain field reference
// - struct gaining a field     -> field reference cast to the target struct
// - array of widened structs   -> field reference cast to the target array
#[rstest]
#[case::missing_field(
"missing_field",
None,
INT,
cast(null(), INT).build()
)]
#[case::existing_field(
"existing_field",
Some(INT),
INT,
field_ref("existing_field").build()
)]
#[case::nested_struct_widening(
"nested",
Some(Struct::default().with_str("a", INT).into()),
Struct::default().with_str("a", INT).with_str("b", INT).into(),
cast(
field_ref("nested"),
Struct::default().with_str("a", INT).with_str("b", INT).into(),
).build()
)]
#[case::array_of_structs(
"items",
Some(Array::new(Struct::default().with_str("a", INT).into()).into()),
Array::new(Struct::default().with_str("a", INT).with_str("b", INT).into()).into(),
cast(
field_ref("items"),
Array::new(Struct::default().with_str("a", INT).with_str("b", INT).into()).into(),
).build()
)]
fn test_build_widening_expression(
#[case] field_name: &str,
#[case] source_type: Option<Type>,
#[case] target_type: Type,
#[case] expected: Expression,
) {
let result = build_widening_expression(field_name, source_type.as_ref(), &target_type);
assert_eq!(result, expected);
}
// Type-checks `expr` against `bindings` and returns the `output` of the
// type-check result.
fn type_check_expr(expr: Expression, bindings: Arc<TypeEnvironment>) -> TypedExpression {
type_check_expression(
expr,
ExpressionTypeCheckOptions::builder()
.bindings(bindings)
.build(),
)
.output
}
// Environment shared by the field-reference-chain cases:
// `col` is an INT and `s` is a struct of shape {field: {inner: INT}}.
fn test_bindings() -> Arc<TypeEnvironment> {
let nested_struct: Type = Struct::default().with_str("inner", INT).into();
let outer_struct: Type = Struct::default().with_str("field", nested_struct).into();
Arc::new(
TypeEnvironment::default()
.with(hamelin_lib::tree::builder::ident("col").into(), INT)
.with(hamelin_lib::tree::builder::ident("s").into(), outer_struct),
)
}
// Column references and (nested) field accesses on `s` are expected to count
// as field-reference chains; a function call and a binary operation are not.
#[rstest]
#[case::simple_column(field_ref("col").build(), true)]
#[case::one_field_access(hamelin_lib::tree::builder::field(field_ref("s"), "field").build(), true)]
#[case::nested_field_access(hamelin_lib::tree::builder::field(hamelin_lib::tree::builder::field(field_ref("s"), "field"), "inner").build(), true)]
#[case::function_call(call("coalesce").arg(field_ref("col")).arg(0).build(), false)]
#[case::binary_operation(hamelin_lib::tree::builder::add(field_ref("col"), 1).build(), false)]
fn test_is_field_reference_chain(#[case] expr: Expression, #[case] expected: bool) {
let bindings = test_bindings();
let typed_expr = type_check_expr(expr, bindings);
assert_eq!(is_field_reference_chain(&typed_expr), expected);
}
// Shorthand for constructing a `SimpleIdentifier` from a string slice.
fn ident(s: &str) -> hamelin_lib::tree::ast::identifier::SimpleIdentifier {
hamelin_lib::tree::ast::identifier::SimpleIdentifier::new(s)
}
// Table-driven check of `expand_struct_to_type_with_ast`. Each case supplies
// (expression, source struct, target struct, expected number of commands in
// each returned vector, expected resulting expression). The column `data` is
// bound to the source struct type before type-checking.
#[rstest]
#[case::field_ref_emits_cast(
// Source: {a: INT, b: INT}, column reference path
field_ref("data").build(),
Struct::default().with_str("a", INT).with_str("b", INT),
Struct::default().with_str("a", INT).with_str("b", INT).with_str("c", INT),
0, // no hoisting for column refs
cast(
field_ref("data"),
Struct::default().with_str("a", INT).with_str("b", INT).with_str("c", INT).into(),
).build()
)]
#[case::complex_expr_hoists_then_casts(
// Complex expr triggers hoisting
call("coalesce").arg(field_ref("data")).arg(field_ref("data")).build(),
Struct::default().with_str("a", INT).with_str("b", INT),
Struct::default().with_str("a", INT).with_str("b", INT).with_str("c", INT),
1, // hoisting needed for complex expr
cast(
field_ref("__test_0"),
Struct::default().with_str("a", INT).with_str("b", INT).with_str("c", INT).into(),
).build()
)]
fn test_expand_struct_to_type(
#[case] expr: Expression,
#[case] source_type: Struct,
#[case] target_type: Struct,
#[case] expected_hoisted_count: usize,
#[case] expected: Expression,
) {
let bindings = Arc::new(
TypeEnvironment::default().with(ident("data").into(), source_type.clone().into()),
);
let typed_expr = Arc::new(
type_check_expression(
expr,
ExpressionTypeCheckOptions::builder()
.bindings(bindings.clone())
.build(),
)
.output,
);
// The hoisting case expects the first name produced by this generator to be
// `__test_0`.
let mut name_gen = UniqueNameGenerator::new("__test");
let (result, before, after) = expand_struct_to_type_with_ast(
&typed_expr,
None,
&source_type,
&target_type,
&mut name_gen,
&bindings,
);
// Both returned command vectors must have the same, expected length.
assert_eq!(before.len(), expected_hoisted_count);
assert_eq!(after.len(), expected_hoisted_count);
assert_eq!(*result, expected);
}
}