use std::sync::Arc;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
ast::command::Command,
builder::{self, column_ref, select_command},
typed_ast::{
command::{TypedCommand, TypedCommandKind, TypedNestCommand},
context::StatementTranslationContext,
pipeline::TypedPipeline,
},
};
/// Replaces every `Nest` command in `pipeline` with its lowered form.
///
/// When the pipeline holds no `Nest` command the input `Arc` is handed back
/// unchanged. Otherwise each command is passed through [`lower_command`], the
/// results are assembled into a new pipeline AST carrying the original
/// pipeline's span, and that AST is re-typed via
/// `TypedPipeline::from_ast_with_context`.
///
/// # Errors
///
/// Propagates the error from `valid_ref()` on the pipeline, and any error
/// produced while lowering an individual command.
pub fn lower_nest(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    // Fast path: nothing to rewrite, so return the very same pipeline.
    let nest_free = pipeline
        .valid_ref()?
        .commands
        .iter()
        .all(|command| !matches!(&command.kind, TypedCommandKind::Nest(_)));
    if nest_free {
        return Ok(pipeline);
    }

    // Thread the pipeline builder through the commands, stopping at the first
    // command that fails to lower.
    let assembled = pipeline
        .valid_ref()?
        .commands
        .iter()
        .try_fold(builder::pipeline(), |acc, command| {
            lower_command(command, ctx).map(|lowered| acc.command(lowered))
        })?;

    let rebuilt_ast = assembled.build().at(pipeline.ast.span);
    let retyped = TypedPipeline::from_ast_with_context(Arc::new(rebuilt_ast), ctx);
    Ok(Arc::new(retyped))
}
/// Lowers a single typed command back into an untyped AST command.
///
/// A `Nest` command is delegated to [`lower_nest_command`]; any other command
/// yields a clone of its original AST node.
fn lower_command(
    cmd: &Arc<TypedCommand>,
    ctx: &mut StatementTranslationContext,
) -> Result<Command, Arc<TranslationError>> {
    match &cmd.kind {
        TypedCommandKind::Nest(nest_cmd) => lower_nest_command(nest_cmd, cmd, ctx),
        // Everything that is not a nest passes through untouched.
        _ => Ok(cmd.ast.as_ref().clone()),
    }
}
/// Builds the `select` command that stands in for a `Nest` command.
///
/// For each entry of `nest_cmd.nested_type.fields`, the resulting select gets
/// one named field: its name is the nest identifier combined (via `+`) with
/// the field name, and its value is a column reference to that field. The
/// select carries the span of the original command's AST node.
///
/// `_ctx` is accepted but not used here.
///
/// # Errors
///
/// Propagates the error from `valid_ref()` on the nest identifier.
fn lower_nest_command(
    nest_cmd: &TypedNestCommand,
    cmd: &TypedCommand,
    _ctx: &mut StatementTranslationContext,
) -> Result<Command, Arc<TranslationError>> {
    use hamelin_lib::tree::ast::identifier::SimpleIdentifier;

    let target = nest_cmd.identifier.valid_ref()?;

    // Accumulate one named field per nested field onto the select builder.
    let projection = nest_cmd.nested_type.fields.iter().fold(
        select_command().at(cmd.ast.span),
        |acc, (field_name, _)| {
            let leaf: SimpleIdentifier = field_name.clone().into();
            let qualified = target.clone() + leaf.into();
            acc.named_field(qualified, column_ref(field_name.clone()))
        },
    );

    Ok(projection.build())
}
// Unit tests for `lower_nest`: each case supplies an input pipeline, the
// pipeline expected after lowering, and the expected output schema.
#[cfg(test)]
mod tests {
use super::*;
use hamelin_lib::{
tree::{
ast::{pipeline::Pipeline, IntoTyped, TypeCheckExecutor},
builder::{column_ref, ident, nest_command, pipeline, select_command},
},
types::{struct_type::Struct, INT, STRING},
};
use pretty_assertions::assert_eq;
use rstest::rstest;
use std::sync::Arc;
#[rstest]
// No nest command present: the expected pipeline is identical to the input.
#[case::no_nest_passthrough(
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
Struct::default().with_str("a", INT).with_str("b", INT)
)]
// Nest under a simple identifier: the nest command is expected to become a
// select assigning `user.a` and `user.b` from columns `a` and `b`.
#[case::simple_nest(
pipeline()
.command(select_command().named_field("a", 1).named_field("b", "hello").build())
.command(nest_command("user").build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).named_field("b", "hello").build())
.command(select_command()
.named_field(ident("user").dot("a"), column_ref("a"))
.named_field(ident("user").dot("b"), column_ref("b"))
.build())
.build(),
Struct::default().with_str("user", Struct::default().with_str("a", INT).with_str("b", STRING).into())
)]
// Nest under a compound identifier (`user.address`): the field name is
// appended after the full compound identifier.
#[case::compound_nest(
pipeline()
.command(select_command().named_field("x", 42).build())
.command(nest_command(ident("user").dot("address")).build())
.build(),
pipeline()
.command(select_command().named_field("x", 42).build())
.command(select_command()
.named_field(ident("user").dot("address").dot("x"), column_ref("x"))
.build())
.build(),
Struct::default().with_str("user",
Struct::default().with_str("address",
Struct::default().with_str("x", INT).into()).into())
)]
// Types both the input and the expected pipeline, runs `lower_nest` on the
// input, then checks (1) the lowered AST equals the expected AST and (2) the
// flattened environment of the result equals the expected schema.
fn test_lower_nest(
#[case] input: Pipeline,
#[case] expected: Pipeline,
#[case] expected_output_schema: Struct,
) {
let input_typed = input.typed_with().typed();
let expected_typed = expected.typed_with().typed();
let mut ctx = StatementTranslationContext::default();
// Lowering is expected to succeed for every case, hence the unwrap.
let result = lower_nest(Arc::new(input_typed), &mut ctx).unwrap();
assert_eq!(result.ast, expected_typed.ast);
let result_schema = result.environment().flatten();
assert_eq!(result_schema, expected_output_schema);
}
}