use std::collections::HashSet;
use std::rc::Rc;
use ordermap::OrderMap;
use hamelin_lib::{
err::TranslationError,
tree::{
ast::{
command::Command,
expression::Expression,
identifier::{Identifier, SimpleIdentifier},
node::Span,
},
builder::{self, column_ref, field, select_command, ExpressionBuilder},
typed_ast::{
clause::Projections,
command::{
TypedCommand, TypedCommandKind, TypedDropCommand, TypedLetCommand,
TypedSelectCommand,
},
context::StatementTranslationContext,
expression::{TypedExpression, TypedExpressionKind},
pipeline::TypedPipeline,
},
},
};
/// Fuses adjacent projection commands (`SELECT`, `LET`, `DROP`) of a typed
/// pipeline into as few `SELECT` commands as possible.
///
/// If the pipeline contains no `LET` or `DROP` command the input `Rc` is
/// returned untouched. Otherwise a new pipeline AST is assembled from the fused
/// commands (keeping the original pipeline span) and type-checked again through
/// [`TypedPipeline::from_ast_with_context`] using `ctx`.
///
/// # Errors
///
/// Returns the error from `pipeline.valid_ref()` when the pipeline is not
/// valid, or any error raised while fusing the individual commands.
pub fn fuse_projections(
    pipeline: Rc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Rc<TypedPipeline>, Rc<TranslationError>> {
    // Fast path: nothing to rewrite, hand back the very same pipeline.
    if !pipeline
        .valid_ref()?
        .commands
        .iter()
        .any(command_needs_fusion)
    {
        return Ok(pipeline);
    }

    let commands = fuse_commands(&pipeline.valid_ref()?.commands)?;

    // Rebuild the pipeline AST at the original span, appending the fused
    // commands in order.
    let fused_ast = commands
        .into_iter()
        .fold(
            builder::pipeline().at(pipeline.ast.span.clone()),
            |acc, cmd| acc.command(cmd),
        )
        .build();

    // The fused AST is re-typed from scratch so that schemas reflect the
    // rewritten commands.
    let retyped = TypedPipeline::from_ast_with_context(Rc::new(fused_ast), ctx);
    Ok(Rc::new(retyped))
}
/// Returns `true` when `cmd` is a `LET` or `DROP` command, i.e. a command that
/// this pass rewrites into a `SELECT`.
fn command_needs_fusion(cmd: &Rc<TypedCommand>) -> bool {
    matches!(
        cmd.kind,
        TypedCommandKind::Let(_) | TypedCommandKind::Drop(_)
    )
}
fn fuse_commands(commands: &[Rc<TypedCommand>]) -> Result<Vec<Command>, Rc<TranslationError>> {
let mut result: Vec<Command> = Vec::new();
let mut pending: Option<PendingSelect> = None;
for command in commands {
match &command.kind {
TypedCommandKind::Select(select_cmd) => {
if let Some(p) = pending.take() {
result.push(p.emit());
}
pending = Some(PendingSelect::from_select(command, select_cmd)?);
}
TypedCommandKind::Let(let_cmd) => {
let refs = extract_column_references_from_projections(&let_cmd.projections);
if let Some(ref mut p) = pending {
if refs
.iter()
.any(|r| p.assigned.iter().any(|a| identifiers_overlap(r, a)))
{
result.push(pending.take().unwrap().emit());
pending = Some(PendingSelect::from_let(command, let_cmd)?);
} else {
p.merge_let(command, let_cmd)?;
}
} else {
pending = Some(PendingSelect::from_let(command, let_cmd)?);
}
}
TypedCommandKind::Drop(drop_cmd) => {
if let Some(ref mut p) = pending {
p.merge_drop(command, drop_cmd);
} else {
pending = Some(PendingSelect::from_drop(command, drop_cmd));
}
}
_ => {
if let Some(p) = pending.take() {
result.push(p.emit());
}
result.push(command.ast.as_ref().clone());
}
}
}
if let Some(p) = pending {
result.push(p.emit());
}
Ok(result)
}
/// A `SELECT` command under construction, accumulating the effect of one or
/// more consecutive `SELECT` / `LET` / `DROP` commands until it is emitted.
struct PendingSelect {
/// Output columns in emission order, each mapped to the expression AST that
/// produces it. `emit` turns every entry into one named field of the SELECT.
assignments: OrderMap<Identifier, Rc<Expression>>,
/// Identifiers explicitly assigned by the SELECT/LET commands merged so far.
/// Passthrough columns synthesized from the input schema are not recorded
/// here, and entries are removed again when a DROP removes the column.
/// `fuse_commands` checks a following LET's column reads against this set.
assigned: HashSet<Identifier>,
/// Source span of the emitted SELECT; starts as the span of the first command
/// and is widened as further commands are merged in.
span: Span,
}
impl PendingSelect {
fn from_select(
command: &TypedCommand,
select_cmd: &TypedSelectCommand,
) -> Result<Self, Rc<TranslationError>> {
let mut assignments = ordermap::OrderMap::new();
let mut assigned = HashSet::new();
for assignment in &select_cmd.projections.assignments {
let id = assignment.identifier.clone().valid()?;
assignments.insert(id.clone(), assignment.expression.ast.clone());
assigned.insert(id);
}
Ok(Self {
assignments,
assigned,
span: command.ast.span,
})
}
fn from_let(
command: &TypedCommand,
let_cmd: &TypedLetCommand,
) -> Result<Self, Rc<TranslationError>> {
let mut assignments = ordermap::OrderMap::new();
let mut assigned = HashSet::new();
for assignment in &let_cmd.projections.assignments {
let id = assignment.identifier.clone().valid()?;
assignments.insert(id.clone(), assignment.expression.ast.clone());
assigned.insert(id);
}
let input_struct = command.input_schema.flatten();
for (field_name, _field_type) in input_struct.fields.iter() {
let identifier: Identifier = SimpleIdentifier::from(field_name.clone()).into();
if !assignments.contains_key(&identifier) {
let passthrough_expr = synthesize_passthrough_ast(&identifier);
assignments.insert(identifier, Rc::new(passthrough_expr));
}
}
Ok(Self {
assignments,
assigned,
span: command.ast.span,
})
}
fn from_drop(command: &TypedCommand, drop_cmd: &TypedDropCommand) -> Self {
let mut assignments = ordermap::OrderMap::new();
let dropped_set: HashSet<_> = drop_cmd.dropped_fields.iter().cloned().collect();
let input_struct = command.input_schema.flatten();
for (field_name, _field_type) in input_struct.fields.iter() {
let identifier: Identifier = SimpleIdentifier::from(field_name.clone()).into();
if dropped_set.contains(&identifier) {
continue;
}
let passthrough_expr = synthesize_passthrough_ast(&identifier);
assignments.insert(identifier, Rc::new(passthrough_expr));
}
Self {
assignments,
assigned: HashSet::new(), span: command.ast.span,
}
}
fn merge_let(
&mut self,
command: &TypedCommand,
let_cmd: &TypedLetCommand,
) -> Result<(), Rc<TranslationError>> {
let mut new_assignments = OrderMap::new();
for assignment in &let_cmd.projections.assignments {
let id = assignment.identifier.clone().valid()?;
new_assignments.insert(id.clone(), assignment.expression.ast.clone());
self.assigned.insert(id);
}
for (id, expr) in self.assignments.drain(..) {
if !new_assignments.contains_key(&id) {
new_assignments.insert(id, expr);
}
}
self.assignments = new_assignments;
self.expand_span(&command.ast.span);
Ok(())
}
fn merge_drop(&mut self, command: &TypedCommand, drop_cmd: &TypedDropCommand) {
for dropped in &drop_cmd.dropped_fields {
self.assignments.remove(dropped);
self.assigned.remove(dropped);
}
self.expand_span(&command.ast.span);
}
fn expand_span(&mut self, other: &Span) {
if other.is_none() {
return;
}
if self.span.is_none() {
self.span = other.clone();
return;
}
let self_start = self.span.start().unwrap();
let self_end = self.span.end().unwrap();
let other_start = other.start().unwrap();
let other_end = other.end().unwrap();
let new_start = self_start.min(other_start);
let new_end = self_end.max(other_end);
self.span = Span::new(new_start, new_end);
}
fn emit(self) -> Command {
let mut builder = select_command();
for (identifier, expr) in self.assignments {
builder = builder.named_field(identifier, expr);
}
builder.at(self.span).build()
}
}
/// Builds the expression AST that reads `identifier` unchanged from the input.
///
/// A simple identifier becomes a plain column reference. A compound identifier
/// becomes a column reference to its first part wrapped in one field access
/// per remaining part, innermost first.
///
/// # Panics
///
/// Panics if a compound identifier has no parts.
fn synthesize_passthrough_ast(identifier: &Identifier) -> Expression {
    match identifier {
        Identifier::Simple(simple) => column_ref(simple.as_str()).build(),
        Identifier::Compound(compound) => {
            let parts = &compound.parts;
            assert!(!parts.is_empty());

            // Start at the root column, then nest a field access per segment.
            let root: Box<dyn ExpressionBuilder> = Box::new(column_ref(parts[0].as_str()));
            let access = parts[1..].iter().fold(root, |inner, segment| {
                Box::new(field(inner, segment.as_str())) as Box<dyn ExpressionBuilder>
            });
            access.build()
        }
    }
}
/// Returns `true` when the two identifiers are equal or when either one has
/// the other as a prefix (as reported by `Identifier::has_prefix`).
fn identifiers_overlap(a: &Identifier, b: &Identifier) -> bool {
    if a == b {
        return true;
    }
    a.has_prefix(b) || b.has_prefix(a)
}
/// Collects the identifiers of all column references that appear in the
/// expressions of `projections`.
fn extract_column_references_from_projections(projections: &Projections) -> HashSet<Identifier> {
    let mut referenced: HashSet<Identifier> = HashSet::new();
    for projection in &projections.assignments {
        extract_column_references_from_expression(&projection.expression, &mut referenced);
    }
    referenced
}
/// Adds to `refs` the identifier of every column reference visited while
/// running `expr.find`.
///
/// The visitor always answers `false`, so it never reports a match to `find`;
/// it is used purely for its side effect of recording column names. Column
/// references whose name fails `valid()` are skipped silently.
fn extract_column_references_from_expression(
    expr: &TypedExpression,
    refs: &mut HashSet<Identifier>,
) {
    expr.find(&mut |node| {
        match &node.kind {
            TypedExpressionKind::ColumnReference(column) => {
                if let Ok(name) = column.column_name.clone().valid() {
                    refs.insert(name.into());
                }
            }
            _ => {}
        }
        false
    });
}
// Table-driven tests for `fuse_projections`: each case builds an input
// pipeline, the pipeline expected after fusion, and the expected output schema.
#[cfg(test)]
mod tests {
use super::*;
use hamelin_lib::{
tree::ast::identifier::CompoundIdentifier,
tree::{
ast::{pipeline::Pipeline, IntoTyped, TypeCheckExecutor},
builder::{add, column_ref, drop_command, let_command, pipeline, select_command},
},
types::{struct_type::Struct, INT},
};
use pretty_assertions::assert_eq;
use rstest::rstest;
use std::rc::Rc;
// Case arguments, in order: input pipeline, expected fused pipeline, expected
// flattened output schema.
// Cases named `barrier_*` cover a LET that reads a column assigned by the
// pending SELECT, which forces the pending SELECT to be emitted separately.
// Note on ordering: when a LET is merged, its assignments are placed before
// the columns accumulated so far, hence the reversed field order below.
#[rstest]
#[case::no_fusion_needed(
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.build(),
Struct::default().with_str("a", INT).with_str("b", INT)
)]
#[case::select_let_fused(
pipeline()
.command(select_command().named_field("a", 1).build())
.command(let_command().named_field("b", 2).build())
.build(),
pipeline()
.command(select_command().named_field("b", 2).named_field("a", 1).build())
.build(),
Struct::default().with_str("b", INT).with_str("a", INT)
)]
#[case::select_drop_fused(
pipeline()
.command(select_command().named_field("a", 1).named_field("b", 2).build())
.command(drop_command().field("b").build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).build())
.build(),
Struct::default().with_str("a", INT)
)]
#[case::select_multiple_lets_fused(
pipeline()
.command(select_command().named_field("a", 1).build())
.command(let_command().named_field("b", 2).build())
.command(let_command().named_field("c", 3).build())
.build(),
pipeline()
.command(select_command()
.named_field("c", 3)
.named_field("b", 2)
.named_field("a", 1)
.build())
.build(),
Struct::default().with_str("c", INT).with_str("b", INT).with_str("a", INT)
)]
#[case::barrier_let_refs_select_field(
pipeline()
.command(select_command().named_field("a", 1).build())
.command(let_command().named_field("b", add(column_ref("a"), 1)).build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).build())
.command(select_command()
.named_field("b", add(column_ref("a"), 1))
.named_field("a", column_ref("a"))
.build())
.build(),
Struct::default().with_str("b", INT).with_str("a", INT)
)]
#[case::barrier_let_refs_let_field(
pipeline()
.command(select_command().named_field("a", 1).build())
.command(let_command().named_field("b", 2).build())
.command(let_command().named_field("c", add(column_ref("b"), 1)).build())
.build(),
pipeline()
.command(select_command()
.named_field("b", 2)
.named_field("a", 1)
.build())
.command(select_command()
.named_field("c", add(column_ref("b"), 1))
.named_field("b", column_ref("b"))
.named_field("a", column_ref("a"))
.build())
.build(),
Struct::default().with_str("c", INT).with_str("b", INT).with_str("a", INT)
)]
#[case::barrier_chained_dependencies(
pipeline()
.command(select_command().named_field("a", 1).build())
.command(let_command().named_field("b", add(column_ref("a"), 1)).build())
.command(let_command().named_field("c", add(column_ref("b"), 1)).build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).build())
.command(select_command()
.named_field("b", add(column_ref("a"), 1))
.named_field("a", column_ref("a"))
.build())
.command(select_command()
.named_field("c", add(column_ref("b"), 1))
.named_field("b", column_ref("b"))
.named_field("a", column_ref("a"))
.build())
.build(),
Struct::default().with_str("c", INT).with_str("b", INT).with_str("a", INT)
)]
#[case::no_barrier_overwrite_without_ref(
pipeline()
.command(let_command().named_field("a", 1).build())
.command(let_command().named_field("a", 2).build())
.build(),
pipeline()
// Second LET's assignment wins (last write)
.command(select_command().named_field("a", 2).build())
.build(),
Struct::default().with_str("a", INT)
)]
#[case::barrier_self_reference(
pipeline()
.command(select_command().named_field("a", 1).build())
.command(let_command().named_field("a", add(column_ref("a"), 1)).build())
.build(),
pipeline()
.command(select_command().named_field("a", 1).build())
.command(select_command().named_field("a", add(column_ref("a"), 1)).build())
.build(),
Struct::default().with_str("a", INT)
)]
#[case::no_barrier_independent_lets(
pipeline()
.command(select_command().named_field("a", 1).build())
.command(let_command().named_field("b", 2).build())
.command(let_command().named_field("c", 3).build())
.build(),
pipeline()
.command(select_command()
.named_field("c", 3)
.named_field("b", 2)
.named_field("a", 1)
.build())
.build(),
Struct::default().with_str("c", INT).with_str("b", INT).with_str("a", INT)
)]
#[case::three_lets_prepend_order(
pipeline()
.command(let_command().named_field("a", 1).build())
.command(let_command().named_field("b", 2).build())
.command(let_command().named_field("c", 3).build())
.build(),
pipeline()
.command(select_command()
.named_field("c", 3)
.named_field("b", 2)
.named_field("a", 1)
.build())
.build(),
Struct::default().with_str("c", INT).with_str("b", INT).with_str("a", INT)
)]
#[case::compound_let_preserved(
pipeline()
.command(let_command()
.named_field(
CompoundIdentifier::new("x".into(), "a".into(), vec![]),
1,
)
.build())
.command(let_command()
.named_field(
CompoundIdentifier::new("x".into(), "b".into(), vec![]),
2,
)
.build())
.build(),
pipeline()
.command(select_command()
.named_field(
CompoundIdentifier::new("x".into(), "b".into(), vec![]),
2,
)
.named_field(
CompoundIdentifier::new("x".into(), "a".into(), vec![]),
1,
)
.build())
.build(),
Struct::default()
.with_str("x", Struct::default().with_str("b", INT).with_str("a", INT).into())
)]
// Type-checks both pipelines, runs `fuse_projections` on the input with a
// default translation context, then asserts that the fused AST equals the
// expected AST and that the flattened result environment equals the expected
// schema.
fn test_fuse_projections(
#[case] input: Pipeline,
#[case] expected: Pipeline,
#[case] expected_output_schema: Struct,
) {
let input_typed = input.typed_with().typed();
let expected_typed = expected.typed_with().typed();
let mut ctx = StatementTranslationContext::default();
let result = fuse_projections(Rc::new(input_typed), &mut ctx).unwrap();
assert_eq!(result.ast, expected_typed.ast);
let result_schema = result.environment().flatten();
assert_eq!(result_schema, expected_output_schema);
}
}