use std::rc::Rc;
use hamelin_lib::tree::{
ast::{command::Command, expression::Expression},
builder::{
cast, column_ref, drop_command, field, let_command, null, struct_literal, ExpressionBuilder,
},
typed_ast::expression::{TypedExpression, TypedExpressionKind, TypedStructLiteral},
};
use hamelin_lib::types::{struct_type::Struct, Type};
use super::unique::UniqueNameGenerator;
/// Returns `true` when `expr` is a `ColumnReference`, or a chain of `FieldLookup`s whose
/// innermost value is a `ColumnReference`. Any other expression kind anywhere in the chain
/// yields `false`.
pub fn is_column_reference_chain(expr: &TypedExpression) -> bool {
    let mut cursor = expr;
    loop {
        match &cursor.kind {
            TypedExpressionKind::ColumnReference(_) => return true,
            TypedExpressionKind::FieldLookup(lookup) => {
                // Step one level inward; the explicit type lets deref coercion apply.
                let inner: &TypedExpression = &lookup.value;
                cursor = inner;
            }
            _ => return false,
        }
    }
}
/// Counts the target fields that also exist in `source`, recursing into fields that are
/// structs on both sides.
///
/// Each shared field contributes 1. A shared field that is a struct on both sides also adds
/// the count of its own shared sub-fields. The result is therefore an upper bound rather than
/// an exact number of emitted accesses.
pub fn count_needed_field_accesses(source: &Struct, target: &Struct) -> usize {
    target
        .fields
        .iter()
        .filter_map(|(name, target_ty)| {
            source.fields.get(name).map(|source_ty| match (source_ty, target_ty) {
                (Type::Struct(inner_src), Type::Struct(inner_dst)) => {
                    1 + count_needed_field_accesses(inner_src, inner_dst)
                }
                _ => 1,
            })
        })
        .sum()
}
/// Rewrites `expr` (of struct type `source_type`) into an expression of struct type
/// `target_type`.
///
/// Returns `(expression, before, after)`. `before` holds commands meant to run ahead of the
/// expression, and `after` holds cleanup commands meant to run once it has been used. Both
/// are empty unless the expression had to be hoisted into a temporary column. In that case
/// `before` is a `LET` binding the temporary and `after` is a `DROP` of it.
///
/// Strategy:
/// - identical types: the original AST is reused untouched;
/// - struct literals: rebuilt field-by-field (see `expand_struct_literal`);
/// - column / field-lookup chains: field accesses are emitted directly, never hoisted;
/// - anything else: depends on `count_needed_field_accesses`. With 0 the result is a NULL
///   cast to `target_type` and `expr` is not used. With 1 the expression is accessed inline.
///   With more it is bound once to a generated name and accessed through that column.
pub fn expand_struct_to_type(
    expr: &Rc<TypedExpression>,
    source_type: &Struct,
    target_type: &Struct,
    name_gen: &mut UniqueNameGenerator,
) -> (Rc<Expression>, Vec<Command>, Vec<Command>) {
    if source_type == target_type {
        return (expr.ast.clone(), Vec::new(), Vec::new());
    }

    if let TypedExpressionKind::StructLiteral(lit) = &expr.kind {
        let (rebuilt, before, after) = expand_struct_literal(lit, target_type, name_gen);
        return (Rc::new(rebuilt), before, after);
    }

    if is_column_reference_chain(expr) {
        let rebuilt = build_expanded_struct(expr, source_type, target_type);
        return (Rc::new(rebuilt), Vec::new(), Vec::new());
    }

    match count_needed_field_accesses(source_type, target_type) {
        0 => (
            Rc::new(cast(null(), target_type.clone().into()).build()),
            Vec::new(),
            Vec::new(),
        ),
        1 => (
            Rc::new(build_expanded_struct(expr, source_type, target_type)),
            Vec::new(),
            Vec::new(),
        ),
        _ => {
            // Bind the expression once so the expansion references a column instead of
            // repeating `expr` for every field.
            let temp = name_gen.next();
            let bind = let_command()
                .named_field(temp.clone(), expr.ast.clone())
                .build();
            let cleanup = drop_command().field(temp.clone()).build();
            let rebuilt =
                build_expanded_struct_from_column(temp.as_str(), source_type, target_type);
            (Rc::new(rebuilt), vec![bind], vec![cleanup])
        }
    }
}
/// Rebuilds the struct literal `lit` so its fields follow `target_type`, in target order.
///
/// - Literal fields not named in `target_type` are dropped.
/// - Target fields with no matching literal field become a NULL cast to the field's type.
/// - When both the target field type and the literal value's resolved type are structs, the
///   value is widened recursively via `expand_struct_to_type`. Any before/after commands it
///   produces are accumulated in call order.
/// - Otherwise the literal value's AST is reused as-is.
fn expand_struct_literal(
    lit: &TypedStructLiteral,
    target_type: &Struct,
    name_gen: &mut UniqueNameGenerator,
) -> (Expression, Vec<Command>, Vec<Command>) {
    let mut hoists = Vec::new();
    let mut cleanups = Vec::new();
    let mut literal = struct_literal();

    for (target_name, target_field_type) in target_type.fields.iter() {
        let matching_value = lit
            .fields
            .iter()
            .find(|(lit_name, _)| {
                lit_name
                    .valid_ref()
                    .map(|s| s.as_str() == target_name.name.as_str())
                    .unwrap_or(false)
            })
            .map(|(_, value)| value);

        literal = match matching_value {
            None => literal.field(target_name.clone(), cast(null(), target_field_type.clone())),
            Some(value) => match (target_field_type, value.resolved_type.as_ref()) {
                (Type::Struct(target_nested), Type::Struct(source_nested)) => {
                    let (expanded, before, after) =
                        expand_struct_to_type(value, source_nested, target_nested, name_gen);
                    hoists.extend(before);
                    cleanups.extend(after);
                    literal.field(target_name.clone(), expanded)
                }
                _ => literal.field(target_name.clone(), value.ast.clone()),
            },
        };
    }

    (literal.build(), hoists, cleanups)
}
fn build_expanded_struct(
source_expr: &Rc<TypedExpression>,
source_type: &Struct,
target_type: &Struct,
) -> Expression {
let mut builder = struct_literal();
for (field_name, field_type) in target_type.fields.iter() {
if let Some(source_field_type) = source_type.fields.get(field_name) {
let field_access = field(source_expr.ast.clone(), field_name.name.as_str());
if let (Type::Struct(source_nested), Type::Struct(target_nested)) =
(source_field_type, field_type)
{
if source_nested != target_nested {
let expanded = build_expanded_struct_fields(
field_access.build(),
source_nested,
target_nested,
);
builder = builder.field(field_name.clone(), expanded);
continue;
}
}
builder = builder.field(field_name.clone(), field_access);
} else {
builder = builder.field(field_name.clone(), cast(null(), field_type.clone()));
}
}
builder.build()
}
/// Builds a struct literal shaped like `target_type` whose values are read from the column
/// `column_name`, which is assumed to have type `source_type`.
///
/// Shared fields become `column_name.<field>` accesses. Nested structs that differ are
/// widened recursively, and target fields missing from the source become a NULL cast to the
/// field's type.
fn build_expanded_struct_from_column(
    column_name: &str,
    source_type: &Struct,
    target_type: &Struct,
) -> Expression {
    // Same expansion as `build_expanded_struct_fields`, rooted at a column reference. This
    // mirrors how `build_widening_expression` widens a column, so all three paths share one
    // recursion.
    build_expanded_struct_fields(column_ref(column_name).build(), source_type, target_type)
}
/// Builds an expression that reads column `field_name` and presents it as `target_type`.
///
/// - `source_type == None` (the column does not exist on this side): a NULL cast to
///   `target_type`.
/// - Both types are structs and differ: a struct literal widened field-by-field from the
///   column.
/// - Otherwise: a plain reference to the column.
pub fn build_widening_expression(
    field_name: &str,
    source_type: Option<&Type>,
    target_type: &Type,
) -> Expression {
    let source_t = match source_type {
        Some(t) => t,
        None => return cast(null(), target_type.clone()).build(),
    };

    match (source_t, target_type) {
        (Type::Struct(src), Type::Struct(dst)) if src != dst => {
            build_expanded_struct_fields(column_ref(field_name).build(), src, dst)
        }
        _ => column_ref(field_name).build(),
    }
}
/// Builds a struct literal shaped like `target_type` from `source_expr`, which is assumed to
/// have type `source_type`.
///
/// Fields are emitted in `target_type` order:
/// - present in both, and structs on both sides that differ: recursively widened from
///   `source_expr.<field>`;
/// - present in both otherwise: `source_expr.<field>` as-is;
/// - missing from the source: a NULL cast to the target field's type.
///
/// `source_expr` is cloned into every emitted field access.
fn build_expanded_struct_fields(
    source_expr: Expression,
    source_type: &Struct,
    target_type: &Struct,
) -> Expression {
    let mut literal = struct_literal();

    for (name, wanted_ty) in target_type.fields.iter() {
        literal = match source_type.fields.get(name) {
            None => literal.field(name.clone(), cast(null(), wanted_ty.clone())),
            Some(have_ty) => {
                let access = field(source_expr.clone(), name.name.as_str());
                match (have_ty, wanted_ty) {
                    (Type::Struct(inner_src), Type::Struct(inner_dst)) if inner_src != inner_dst => {
                        let widened =
                            build_expanded_struct_fields(access.build(), inner_src, inner_dst);
                        literal.field(name.clone(), widened)
                    }
                    _ => literal.field(name.clone(), access),
                }
            }
        };
    }

    literal.build()
}
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::sql::expression::identifier::SimpleIdentifier;
    use hamelin_lib::types::INT;

    /// Builds a `Struct` from `(name, type)` pairs.
    fn make_struct(fields: &[(&str, Type)]) -> Struct {
        Struct::new(
            fields
                .iter()
                .map(|(name, typ)| (SimpleIdentifier::new(name), typ.clone())),
        )
    }

    // The next three tests pin the `Struct` equality semantics that the
    // `source_type == target_type` fast path in `expand_struct_to_type` relies on:
    // equality requires the same fields in the same order.

    #[test]
    fn test_types_match_exactly_same_order() {
        let s1 = make_struct(&[("a", INT), ("b", INT)]);
        let s2 = make_struct(&[("a", INT), ("b", INT)]);
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_types_match_exactly_different_order() {
        let s1 = make_struct(&[("a", INT), ("b", INT)]);
        let s2 = make_struct(&[("b", INT), ("a", INT)]);
        assert_ne!(s1, s2);
    }

    #[test]
    fn test_types_match_exactly_missing_field() {
        let s1 = make_struct(&[("a", INT)]);
        let s2 = make_struct(&[("a", INT), ("b", INT)]);
        assert_ne!(s1, s2);
    }

    #[test]
    fn test_count_needed_field_accesses() {
        let source = make_struct(&[("a", INT), ("b", INT)]);
        let target = make_struct(&[("a", INT), ("b", INT), ("c", INT)]);
        assert_eq!(count_needed_field_accesses(&source, &target), 2);
    }

    #[test]
    fn test_count_nested_field_accesses() {
        let inner_source = make_struct(&[("x", INT)]);
        let inner_target = make_struct(&[("x", INT), ("y", INT)]);
        let source = make_struct(&[("nested", Type::Struct(inner_source.clone()))]);
        let target = make_struct(&[("nested", Type::Struct(inner_target.clone()))]);
        assert_eq!(count_needed_field_accesses(&source, &target), 2);
    }

    /// No shared fields → 0, which is the count that selects the NULL-cast branch of
    /// `expand_struct_to_type`.
    #[test]
    fn test_count_no_shared_fields() {
        let source = make_struct(&[("a", INT)]);
        let target = make_struct(&[("b", INT), ("c", INT)]);
        assert_eq!(count_needed_field_accesses(&source, &target), 0);
    }

    /// An empty target never needs any access, regardless of the source.
    #[test]
    fn test_count_empty_target() {
        let source = make_struct(&[("a", INT), ("b", INT)]);
        let target = make_struct(&[]);
        assert_eq!(count_needed_field_accesses(&source, &target), 0);
    }
}