use std::sync::Arc;
use hamelin_lib::tree::{
ast::{command::Command, expression::Expression},
builder::{
call, cast, column_ref, drop_command, field, lambda1, let_command, null, struct_literal,
ExpressionBuilder,
},
typed_ast::{
environment::TypeEnvironment,
expression::{TypedExpression, TypedExpressionKind, TypedStructLiteral},
},
};
use hamelin_lib::types::{struct_type::Struct, Type};
use super::unique::UniqueNameGenerator;
/// Reports whether `expr` is a bare column reference, optionally followed by a
/// chain of field lookups (e.g. `col`, `col.a`, `col.a.b`).
///
/// [`expand_struct_to_type`] treats such expressions as cheap enough to repeat
/// inline, so it never hoists them into a temporary column.
pub fn is_column_reference_chain(expr: &TypedExpression) -> bool {
    if let TypedExpressionKind::FieldLookup(lookup) = &expr.kind {
        // Peel one `.field` off and keep walking toward the root of the chain.
        return is_column_reference_chain(&lookup.value);
    }
    matches!(expr.kind, TypedExpressionKind::ColumnReference(_))
}
/// Counts how many field accesses into a `source`-typed value are needed to
/// rebuild it as `target`.
///
/// Every field of `target` that also exists in `source` contributes one. When
/// both sides of such a field are structs, the nested pair is counted
/// recursively as well. Fields missing from `source` contribute nothing, and
/// array element types are not descended into.
pub fn count_needed_field_accesses(source: &Struct, target: &Struct) -> usize {
    target
        .fields
        .iter()
        .filter_map(|(name, target_ty)| {
            source
                .fields
                .get(name)
                .map(|source_ty| match (source_ty, target_ty) {
                    (Type::Struct(s), Type::Struct(t)) => 1 + count_needed_field_accesses(s, t),
                    _ => 1,
                })
        })
        .sum()
}
/// Rewrites `expr` (of type `source_type`) into an expression of type `target_type`.
///
/// Returns `(expression, before, after)`. `before` holds commands to run ahead of
/// the expression, and `after` holds commands to run once it has been consumed.
/// Both are empty unless some value had to be hoisted into a temporary column.
///
/// Strategy, in order:
/// * identical types: `expr`'s AST is returned unchanged;
/// * struct literals: rebuilt field by field (see `expand_struct_literal`);
/// * column reference chains: field accesses are inlined directly on `expr`;
/// * any other expression: hoisted into a temporary column (let command before,
///   drop command after) when more than one field access is needed, inlined
///   when exactly one is needed, and replaced by a NULL cast to `target_type`
///   when none is needed.
pub fn expand_struct_to_type(
    expr: &Arc<TypedExpression>,
    source_type: &Struct,
    target_type: &Struct,
    name_gen: &mut UniqueNameGenerator,
    schema: &TypeEnvironment,
) -> (Arc<Expression>, Vec<Command>, Vec<Command>) {
    // Nothing to widen: hand back the original AST untouched.
    if source_type == target_type {
        return (expr.ast.clone(), Vec::new(), Vec::new());
    }

    // A struct literal can be rewritten in place, field by field.
    if let TypedExpressionKind::StructLiteral(lit) = &expr.kind {
        let (rebuilt, before, after) = expand_struct_literal(lit, target_type, name_gen, schema);
        return (Arc::new(rebuilt), before, after);
    }

    // Emits one `expr.field` access per needed field, directly on `expr`.
    let inline_expansion = || Arc::new(build_expanded_struct(expr, source_type, target_type));

    // Column reference chains are cheap to repeat, so they are always inlined.
    if is_column_reference_chain(expr) {
        return (inline_expansion(), Vec::new(), Vec::new());
    }

    match count_needed_field_accesses(source_type, target_type) {
        // No source field is carried into the target.
        // NOTE(review): this yields a NULL *struct*, whereas the inline paths
        // yield a struct whose fields are all NULL. Confirm the difference is
        // intended.
        0 => (
            Arc::new(cast(null(), target_type.clone().into()).build()),
            Vec::new(),
            Vec::new(),
        ),
        // A single access references `expr` at most once, so inlining is safe.
        1 => (inline_expansion(), Vec::new(), Vec::new()),
        // Several accesses would repeat `expr` each time. Materialize it once
        // into a temporary column, read the fields from there, and drop the
        // column afterwards.
        _ => {
            let temp_name = name_gen.next(schema);
            let let_cmd = let_command()
                .named_field(temp_name.clone(), expr.ast.clone())
                .build();
            let drop_cmd = drop_command().field(temp_name.clone()).build();
            let expanded = build_expanded_struct_fields_from_column(
                temp_name.as_str(),
                source_type,
                target_type,
            );
            (Arc::new(expanded), vec![let_cmd], vec![drop_cmd])
        }
    }
}
fn expand_struct_literal(
lit: &TypedStructLiteral,
target_type: &Struct,
name_gen: &mut UniqueNameGenerator,
schema: &TypeEnvironment,
) -> (Expression, Vec<Command>, Vec<Command>) {
let mut builder = struct_literal();
let mut all_before = Vec::new();
let mut all_after = Vec::new();
for (field_name, field_type) in target_type.fields.iter() {
let existing = lit.fields.iter().find(|(n, _)| {
n.valid_ref()
.map(|s| s.as_str() == field_name.name.as_str())
.unwrap_or(false)
});
if let Some((_, field_expr)) = existing {
if let Type::Struct(target_nested) = field_type {
if let Type::Struct(source_nested) = field_expr.resolved_type.as_ref() {
let (expanded, before, after) = expand_struct_to_type(
field_expr,
source_nested,
target_nested,
name_gen,
schema,
);
all_before.extend(before);
all_after.extend(after);
builder = builder.field(field_name.clone(), expanded);
continue;
}
}
if let Type::Array(target_arr) = field_type {
if let Type::Array(source_arr) = field_expr.resolved_type.as_ref() {
if let (Type::Struct(source_elem), Type::Struct(target_elem)) = (
source_arr.element_type.as_ref(),
target_arr.element_type.as_ref(),
) {
if source_elem != target_elem {
let lambda_body = build_expanded_struct_fields(
column_ref("__item").build(),
source_elem,
target_elem,
);
let transformed = call("transform")
.arg(field_expr.ast.clone())
.arg(lambda1("__item").body(lambda_body))
.build();
builder = builder.field(field_name.clone(), Arc::new(transformed));
continue;
}
}
}
}
builder = builder.field(field_name.clone(), field_expr.ast.clone());
} else {
builder = builder.field(field_name.clone(), cast(null(), field_type.clone()));
}
}
(builder.build(), all_before, all_after)
}
fn build_expanded_struct(
source_expr: &Arc<TypedExpression>,
source_type: &Struct,
target_type: &Struct,
) -> Expression {
let mut builder = struct_literal();
for (field_name, field_type) in target_type.fields.iter() {
if let Some(source_field_type) = source_type.fields.get(field_name) {
let field_access = field(source_expr.ast.clone(), field_name.name.as_str());
if let (Type::Struct(source_nested), Type::Struct(target_nested)) =
(source_field_type, field_type)
{
if source_nested != target_nested {
let expanded = build_expanded_struct_fields(
field_access.build(),
source_nested,
target_nested,
);
builder = builder.field(field_name.clone(), expanded);
continue;
}
}
if let (Type::Array(source_arr), Type::Array(target_arr)) =
(source_field_type, field_type)
{
if let (Type::Struct(source_elem), Type::Struct(target_elem)) = (
source_arr.element_type.as_ref(),
target_arr.element_type.as_ref(),
) {
if source_elem != target_elem {
let lambda_body = build_expanded_struct_fields_from_column(
"__item",
source_elem,
target_elem,
);
let transformed = call("transform")
.arg(field_access)
.arg(lambda1("__item").body(lambda_body))
.build();
builder = builder.field(field_name.clone(), transformed);
continue;
}
}
}
builder = builder.field(field_name.clone(), field_access);
} else {
builder = builder.field(field_name.clone(), cast(null(), field_type.clone()));
}
}
builder.build()
}
/// Builds the expression that yields column `field_name` widened to `target_type`.
///
/// * `source_type == None` means the column is absent on the source side. The
///   result is then a NULL cast to `target_type`.
/// * A struct column whose type differs from the target is rebuilt field by
///   field.
/// * An array of structs whose element type differs from the target's is
///   rewritten as `transform(field_name, __item -> …)`.
/// * Anything else is passed through as a plain column reference.
pub fn build_widening_expression(
    field_name: &str,
    source_type: Option<&Type>,
    target_type: &Type,
) -> Expression {
    let source_t = match source_type {
        Some(t) => t,
        None => return cast(null(), target_type.clone()).build(),
    };

    match (source_t, target_type) {
        (Type::Struct(from), Type::Struct(to)) if from != to => {
            build_expanded_struct_fields(column_ref(field_name).build(), from, to)
        }
        (Type::Array(from_arr), Type::Array(to_arr)) => {
            match (from_arr.element_type.as_ref(), to_arr.element_type.as_ref()) {
                (Type::Struct(from_elem), Type::Struct(to_elem)) if from_elem != to_elem => {
                    let per_item =
                        build_expanded_struct_fields_from_column("__item", from_elem, to_elem);
                    call("transform")
                        .arg(column_ref(field_name))
                        .arg(lambda1("__item").body(per_item))
                        .build()
                }
                _ => column_ref(field_name).build(),
            }
        }
        _ => column_ref(field_name).build(),
    }
}
/// Builds a struct literal of type `target_type` whose fields are read out of
/// `source_expr` (a value of type `source_type`).
///
/// For each field of `target_type`, in target order:
/// * absent from `source_type`: a NULL cast to the target field type;
/// * both sides structs of different shape: expanded recursively;
/// * both sides arrays of differently-shaped structs: rewritten as
///   `transform(source_expr.field, __item -> …)`, expanding each element;
/// * otherwise: the plain access `source_expr.field`.
///
/// `source_expr` is cloned into every field access. This is why
/// `expand_struct_to_type` hoists non-trivial expressions into a temporary
/// column before building from them.
fn build_expanded_struct_fields(
    source_expr: Expression,
    source_type: &Struct,
    target_type: &Struct,
) -> Expression {
    let mut literal = struct_literal();
    for (name, wanted) in target_type.fields.iter() {
        literal = match source_type.fields.get(name) {
            None => literal.field(name.clone(), cast(null(), wanted.clone())),
            Some(have) => {
                let access = field(source_expr.clone(), name.name.as_str());
                match (have, wanted) {
                    (Type::Struct(have_inner), Type::Struct(want_inner))
                        if have_inner != want_inner =>
                    {
                        let nested =
                            build_expanded_struct_fields(access.build(), have_inner, want_inner);
                        literal.field(name.clone(), nested)
                    }
                    (Type::Array(have_arr), Type::Array(want_arr)) => {
                        match (have_arr.element_type.as_ref(), want_arr.element_type.as_ref()) {
                            (Type::Struct(have_elem), Type::Struct(want_elem))
                                if have_elem != want_elem =>
                            {
                                let per_item = build_expanded_struct_fields_from_column(
                                    "__item", have_elem, want_elem,
                                );
                                let mapped = call("transform")
                                    .arg(access)
                                    .arg(lambda1("__item").body(per_item))
                                    .build();
                                literal.field(name.clone(), mapped)
                            }
                            _ => literal.field(name.clone(), access),
                        }
                    }
                    _ => literal.field(name.clone(), access),
                }
            }
        };
    }
    literal.build()
}
/// Convenience wrapper around `build_expanded_struct_fields` for the common
/// case where the source value is a plain column (or lambda parameter) named
/// `column_name`.
fn build_expanded_struct_fields_from_column(
    column_name: &str,
    source_type: &Struct,
    target_type: &Struct,
) -> Expression {
    let column = column_ref(column_name).build();
    build_expanded_struct_fields(column, source_type, target_type)
}
// Unit tests for the struct-widening helpers above. Expected outputs are built
// with the same AST builders the production code uses and compared with
// `assert_eq!` on `Expression` values.
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::tree::ast::expression::Expression;
    use hamelin_lib::tree::ast::{IntoTyped, TypeCheckExecutor};
    use hamelin_lib::tree::builder::{
        call, cast, column_ref, field, ident, lambda1, null, struct_literal,
    };
    use hamelin_lib::tree::typed_ast::environment::TypeEnvironment;
    use hamelin_lib::tree::typed_ast::expression::TypedExpression;
    use hamelin_lib::types::array::Array;
    use hamelin_lib::types::INT;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    // Each case: (source struct, target struct, expected access count).
    #[rstest]
    #[case::flat_struct(
        Struct::default().with_str("a", INT).with_str("b", INT),
        Struct::default().with_str("a", INT).with_str("b", INT).with_str("c", INT),
        2 // a and b exist in source, c doesn't
    )]
    #[case::nested_struct(
        Struct::default().with_str("nested", Struct::default().with_str("x", INT).into()),
        Struct::default().with_str("nested", Struct::default().with_str("x", INT).with_str("y", INT).into()),
        2 // 1 for nested field + 1 for x inside nested
    )]
    fn test_count_needed_field_accesses(
        #[case] source: Struct,
        #[case] target: Struct,
        #[case] expected: usize,
    ) {
        assert_eq!(count_needed_field_accesses(&source, &target), expected);
    }

    // Each case: (column name, source type or None if absent, target type,
    // expected widening expression).
    #[rstest]
    #[case::missing_field(
        "missing_field",
        None,
        INT,
        cast(null(), INT).build()
    )]
    #[case::existing_field(
        "existing_field",
        Some(INT),
        INT,
        column_ref("existing_field").build()
    )]
    #[case::array_of_structs(
        "items",
        Some(Array::new(Struct::default().with_str("a", INT).into()).into()),
        Array::new(Struct::default().with_str("a", INT).with_str("b", INT).into()).into(),
        call("transform")
            .arg(column_ref("items"))
            .arg(lambda1("__item").body(
                struct_literal()
                    .field("a", field(column_ref("__item"), "a"))
                    .field("b", cast(null(), INT)),
            ))
            .build()
    )]
    fn test_build_widening_expression(
        #[case] field_name: &str,
        #[case] source_type: Option<Type>,
        #[case] target_type: Type,
        #[case] expected: Expression,
    ) {
        let result = build_widening_expression(field_name, source_type.as_ref(), &target_type);
        assert_eq!(result, expected);
    }

    // Expansion reading from a column named `source`: covers target-order field
    // reordering, nested structs and arrays of structs.
    #[rstest]
    #[case::reorders_fields(
        Struct::default().with_str("b", INT).with_str("a", INT),
        Struct::default().with_str("a", INT).with_str("b", INT).with_str("c", INT),
        struct_literal()
            .field("a", field(column_ref("source"), "a"))
            .field("b", field(column_ref("source"), "b"))
            .field("c", cast(null(), INT))
            .build()
    )]
    #[case::nested_struct(
        Struct::default().with_str("nested", Struct::default().with_str("x", INT).into()),
        Struct::default().with_str("nested", Struct::default().with_str("x", INT).with_str("y", INT).into()),
        struct_literal()
            .field(
                "nested",
                struct_literal()
                    .field("x", field(field(column_ref("source"), "nested"), "x"))
                    .field("y", cast(null(), INT)),
            )
            .build()
    )]
    #[case::nested_array(
        Struct::default().with_str("items", Array::new(Struct::default().with_str("a", INT).into()).into()),
        Struct::default().with_str("items", Array::new(Struct::default().with_str("a", INT).with_str("b", INT).into()).into()),
        struct_literal()
            .field(
                "items",
                call("transform")
                    .arg(field(column_ref("source"), "items"))
                    .arg(lambda1("__item").body(
                        struct_literal()
                            .field("a", field(column_ref("__item"), "a"))
                            .field("b", cast(null(), INT)),
                    )),
            )
            .build()
    )]
    fn test_build_expanded_struct_fields(
        #[case] source: Struct,
        #[case] target: Struct,
        #[case] expected: Expression,
    ) {
        let result = build_expanded_struct_fields(column_ref("source").build(), &source, &target);
        assert_eq!(result, expected);
    }

    // Type-checks an untyped expression against the given bindings.
    fn type_check_expr(expr: Expression, bindings: Arc<TypeEnvironment>) -> TypedExpression {
        expr.typed_with().with_bindings(bindings).typed()
    }

    // Environment with `col: INT` and `s: {field: {inner: INT}}`.
    fn test_bindings() -> Arc<TypeEnvironment> {
        let nested_struct: Type = Struct::default().with_str("inner", INT).into();
        let outer_struct: Type = Struct::default().with_str("field", nested_struct).into();
        Arc::new(
            TypeEnvironment::default()
                .with_base(ident("col").into(), INT)
                .unwrap()
                .with_base(ident("s").into(), outer_struct)
                .unwrap(),
        )
    }

    // Plain column refs and field-lookup chains qualify. Calls and binary
    // operations do not.
    #[rstest]
    #[case::simple_column(column_ref("col").build(), true)]
    #[case::one_field_access(field(column_ref("s"), "field").build(), true)]
    #[case::nested_field_access(field(field(column_ref("s"), "field"), "inner").build(), true)]
    #[case::function_call(call("coalesce").arg(column_ref("col")).arg(0).build(), false)]
    #[case::binary_operation(hamelin_lib::tree::builder::add(column_ref("col"), 1).build(), false)]
    fn test_is_column_reference_chain(#[case] expr: Expression, #[case] expected: bool) {
        let bindings = test_bindings();
        let typed_expr = type_check_expr(expr, bindings);
        assert_eq!(is_column_reference_chain(&typed_expr), expected);
    }

    // End-to-end cases for `expand_struct_to_type`: checks both the number of
    // hoisting commands emitted before/after and the resulting expression.
    #[rstest]
    #[case::column_ref_array_of_structs(
        // Source: {items: Array<{a: INT}>}, column reference path
        column_ref("data").build(),
        Struct::default().with_str("items", Array::new(Struct::default().with_str("a", INT).into()).into()),
        Struct::default().with_str("items", Array::new(Struct::default().with_str("a", INT).with_str("b", INT).into()).into()),
        0, // no hoisting for column refs
        struct_literal()
            .field(
                "items",
                call("transform")
                    .arg(field(column_ref("data"), "items"))
                    .arg(lambda1("__item").body(
                        struct_literal()
                            .field("a", field(column_ref("__item"), "a"))
                            .field("b", cast(null(), INT)),
                    )),
            )
            .build()
    )]
    #[case::hoisted_array_of_structs(
        // Source: {x: INT, items: Array<{a: INT}>}, complex expr triggers hoisting
        call("coalesce").arg(column_ref("data")).arg(column_ref("data")).build(),
        Struct::default().with_str("x", INT).with_str("items", Array::new(Struct::default().with_str("a", INT).into()).into()),
        Struct::default().with_str("x", INT).with_str("items", Array::new(Struct::default().with_str("a", INT).with_str("b", INT).into()).into()),
        1, // hoisting needed for complex expr with 2+ field accesses
        // The expected output assumes the generator's first name is `__test_0`.
        struct_literal()
            .field("x", field(column_ref("__test_0"), "x"))
            .field(
                "items",
                call("transform")
                    .arg(field(column_ref("__test_0"), "items"))
                    .arg(lambda1("__item").body(
                        struct_literal()
                            .field("a", field(column_ref("__item"), "a"))
                            .field("b", cast(null(), INT)),
                    )),
            )
            .build()
    )]
    fn test_expand_struct_to_type(
        #[case] expr: Expression,
        #[case] source_type: Struct,
        #[case] target_type: Struct,
        #[case] expected_hoisted_count: usize,
        #[case] expected: Expression,
    ) {
        // Bind `data` to the source type so `expr` type-checks.
        let bindings = Arc::new(
            TypeEnvironment::default()
                .with_base(ident("data").into(), source_type.clone().into())
                .unwrap(),
        );
        let typed_expr = Arc::new(expr.typed_with().with_bindings(bindings.clone()).typed());
        let mut name_gen = UniqueNameGenerator::new("__test");
        let (result, before, after) = expand_struct_to_type(
            &typed_expr,
            &source_type,
            &target_type,
            &mut name_gen,
            &bindings,
        );
        assert_eq!(before.len(), expected_hoisted_count);
        assert_eq!(after.len(), expected_hoisted_count);
        assert_eq!(*result, expected);
    }
}