use std::sync::Arc;
use hamelin_lib::{
err::TranslationError,
tree::{
ast::query::Query,
builder::{self, join_table_reference, lookup_table_reference, query, ExpressionBuilder},
typed_ast::{
command::TypedCommandKind, context::StatementTranslationContext,
pipeline::TypedPipeline, query::TypedStatement,
},
},
};
use super::super::unique::UniqueNameGenerator;
use hamelin_lib::tree::builder::pipeline as pipeline_builder;
/// Lowers every `Join`/`Lookup` command of `statement` onto a generated CTE.
///
/// If no pipeline of the statement contains a join-like command, the very same
/// `Arc` that was passed in is returned. Otherwise the statement is rebuilt as
/// a new [`Query`], with CTE names drawn from a `"__join"`-prefixed
/// [`UniqueNameGenerator`], and turned back into a [`TypedStatement`] via
/// `TypedStatement::from_ast_with_context` using `ctx`.
///
/// # Errors
///
/// Propagates any [`TranslationError`] raised while inspecting or rewriting
/// the statement's pipelines.
pub fn lower_joins(
    statement: Arc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
    if statement_has_joins(&statement)? {
        let mut cte_names = UniqueNameGenerator::new("__join");
        let lowered_ast = Arc::new(transform_statement(&statement, &mut cte_names)?);
        let lowered = TypedStatement::from_ast_with_context(lowered_ast, ctx);
        Ok(Arc::new(lowered))
    } else {
        // Nothing to lower: hand the caller's statement back as-is.
        Ok(statement)
    }
}
/// Reports whether any pipeline yielded by `statement.iter()` contains a
/// `Join` or `Lookup` command.
///
/// Every pipeline is inspected, even after a match has been found, so the
/// first pipeline that fails inspection always surfaces as an error.
fn statement_has_joins(statement: &TypedStatement) -> Result<bool, Arc<TranslationError>> {
    let mut found = false;
    for pipeline in statement.iter() {
        // Deliberately no early exit on `true`: later pipelines must still be checked.
        let has_joins = pipeline_has_joins(pipeline)?;
        found = has_joins || found;
    }
    Ok(found)
}
/// Reports whether `pipeline` contains at least one `Join` or `Lookup` command.
///
/// # Errors
///
/// Returns the error produced by `pipeline.valid_ref()` when that call fails.
fn pipeline_has_joins(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    for cmd in pipeline.valid_ref()?.commands.iter() {
        if matches!(
            &cmd.kind,
            TypedCommandKind::Join(_) | TypedCommandKind::Lookup(_)
        ) {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Rebuilds `statement` as a [`Query`] in which every pipeline has been passed
/// through [`transform_pipeline`].
///
/// Each existing `WITH` clause is transformed and merged into the result as a
/// CTE under its original name; the main pipeline is transformed last and
/// merged in as the main query. All pipelines share `name_gen`.
///
/// # Errors
///
/// Fails if any pipeline cannot be transformed or if a `WITH` clause name
/// fails its `valid()` check.
fn transform_statement(
    statement: &TypedStatement,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Arc<TranslationError>> {
    let mut assembled = query();

    for clause in &statement.with_clauses {
        // Transform first, then validate the name, so error precedence is stable.
        let lowered = transform_pipeline(&clause.pipeline, statement, name_gen)?;
        assembled = assembled.merge_as_cte(lowered, clause.name.clone().valid()?);
    }

    transform_pipeline(&statement.pipeline, statement, name_gen)
        .map(|main| assembled.merge_as_main(main))
}
/// Rewrites a single pipeline so that each `Join`/`Lookup` command refers to a
/// freshly generated CTE instead of its original right-hand table.
///
/// For every join-like command:
/// 1. a CTE is created that reads the right-hand table and nests it under the
///    command's alias,
/// 2. the CTE is registered under a name obtained from `name_gen.next(statement)`,
/// 3. the command is replaced by a join/lookup on that CTE name, keeping the
///    original condition or using the literal `true` when there is none.
///
/// All other commands are copied over unchanged. The returned [`Query`] holds
/// the generated CTEs plus the rewritten pipeline as its main query.
///
/// # Errors
///
/// Fails if the pipeline, a right-hand alias, or a right-hand table identifier
/// does not pass its `valid_ref()` check.
fn transform_pipeline(
    pipeline: &TypedPipeline,
    statement: &TypedStatement,
    name_gen: &mut UniqueNameGenerator,
) -> Result<Query, Arc<TranslationError>> {
    let mut ctes = query();
    let mut rewritten = pipeline_builder().at(pipeline.ast.span.clone());

    for cmd in &pipeline.valid_ref()?.commands {
        if let TypedCommandKind::Join(join) = &cmd.kind {
            let right = &join.right;
            let alias = right.alias.valid_ref()?;
            let source_table = right.ast.table.identifier.valid_ref()?.clone();
            // The name is only drawn once both validity checks have passed.
            let generated = name_gen.next(statement);

            let nested_source = builder::pipeline()
                .from(|f| f.table_reference(source_table))
                .nest(alias.clone())
                .build();
            ctes = ctes.with(generated.clone(), nested_source);

            // A join without a condition is lowered to `ON true`.
            let on_clause = match join.condition.as_ref() {
                Some(cond) => cond.ast.as_ref().clone(),
                None => builder::boolean(true).build(),
            };
            let lowered = join_table_reference(generated).on(on_clause).build();
            rewritten = rewritten.command(lowered);
        } else if let TypedCommandKind::Lookup(lookup) = &cmd.kind {
            let right = &lookup.right;
            let alias = right.alias.valid_ref()?;
            let source_table = right.ast.table.identifier.valid_ref()?.clone();
            // The name is only drawn once both validity checks have passed.
            let generated = name_gen.next(statement);

            let nested_source = builder::pipeline()
                .from(|f| f.table_reference(source_table))
                .nest(alias.clone())
                .build();
            ctes = ctes.with(generated.clone(), nested_source);

            // A lookup without a condition is lowered to `ON true`.
            let on_clause = match lookup.condition.as_ref() {
                Some(cond) => cond.ast.as_ref().clone(),
                None => builder::boolean(true).build(),
            };
            let lowered = lookup_table_reference(generated).on(on_clause).build();
            rewritten = rewritten.command(lowered);
        } else {
            // Not join-like: keep the original command AST untouched.
            rewritten = rewritten.command(cmd.ast.clone());
        }
    }

    Ok(ctes.main(rewritten.build()).build())
}
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::EnvironmentProvider,
        sql::{expression::identifier::Identifier as SqlIdentifier, query::TableReference},
        tree::{
            ast::{IntoTyped, TypeCheckExecutor},
            builder::{column_ref, eq, field, query, HasMain, QueryBuilder},
        },
        types::{struct_type::Struct, INT},
    };
    use std::sync::Arc;

    /// Environment provider that knows exactly two tables, `events` and `users`.
    #[derive(Debug)]
    struct MockProvider;

    impl EnvironmentProvider for MockProvider {
        /// Returns the column set of `events` or `users`; any other table is an error.
        fn reflect_columns(&self, table: TableReference) -> anyhow::Result<Struct> {
            let mut fields = Struct::default();
            let events: SqlIdentifier = "events".parse().unwrap();
            let users: SqlIdentifier = "users".parse().unwrap();
            if table.name == events {
                fields.fields.insert("timestamp".parse().unwrap(), INT);
                fields.fields.insert("user_id".parse().unwrap(), INT);
                Ok(fields)
            } else if table.name == users {
                fields.fields.insert("id".parse().unwrap(), INT);
                fields.fields.insert("name".parse().unwrap(), INT);
                Ok(fields)
            } else {
                anyhow::bail!("Table not found: {}", table.name)
            }
        }

        /// The mock exposes no datasets.
        fn reflect_datasets(&self) -> anyhow::Result<Vec<SqlIdentifier>> {
            Ok(vec![])
        }
    }

    /// Builds the query and type-checks it against [`MockProvider`].
    fn typed_query(builder: QueryBuilder<HasMain>) -> TypedStatement {
        builder
            .build()
            .typed_with()
            .with_provider(Arc::new(MockProvider))
            .typed()
    }

    /// Creates a translation context backed by the default function registry
    /// and [`MockProvider`].
    fn translation_context() -> StatementTranslationContext {
        let registry = Arc::new(FunctionRegistry::default());
        let provider = Arc::new(MockProvider);
        StatementTranslationContext::new(registry, provider)
    }

    /// Collects the names of all `WITH` clauses of `statement`, in order.
    ///
    /// Panics if a clause name is not valid, which is a test failure.
    fn cte_names(statement: &TypedStatement) -> Vec<String> {
        statement
            .with_clauses
            .iter()
            .map(|clause| clause.name.valid_ref().unwrap().to_string())
            .collect()
    }

    #[test]
    fn test_no_joins_passthrough() -> Result<(), Arc<TranslationError>> {
        let q = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .where_cmd(eq(column_ref("timestamp"), 10)),
        );
        let statement = Arc::new(typed_query(q));
        assert!(!statement_has_joins(&statement)?);

        // The pass itself must hand back the very same statement, not a rebuilt copy.
        let mut ctx = translation_context();
        let lowered = lower_joins(Arc::clone(&statement), &mut ctx)?;
        assert!(Arc::ptr_eq(&statement, &lowered));
        Ok(())
    }

    #[test]
    fn test_join_generates_cte() -> Result<(), Arc<TranslationError>> {
        let q = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .join(
                    "users",
                    eq(column_ref("user_id"), field(column_ref("users"), "id")),
                ),
        );
        let statement = typed_query(q);
        assert!(statement_has_joins(&statement)?);

        let mut ctx = translation_context();
        let transformed = lower_joins(Arc::new(statement), &mut ctx)?;
        assert_eq!(cte_names(&transformed), ["__join_0"]);
        Ok(())
    }

    #[test]
    fn test_lookup_generates_cte() -> Result<(), Arc<TranslationError>> {
        let q = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .lookup("users", |l| {
                    l.on(eq(column_ref("user_id"), field(column_ref("users"), "id")))
                }),
        );
        let statement = typed_query(q);
        assert!(statement_has_joins(&statement)?);

        let mut ctx = translation_context();
        let transformed = lower_joins(Arc::new(statement), &mut ctx)?;
        assert_eq!(cte_names(&transformed), ["__join_0"]);
        Ok(())
    }

    #[test]
    fn test_multiple_joins_generate_multiple_ctes() -> Result<(), Arc<TranslationError>> {
        let q = query().main(
            pipeline_builder()
                .from(|f| f.table_reference("events"))
                .join(
                    "users",
                    eq(column_ref("user_id"), field(column_ref("users"), "id")),
                )
                .join(
                    "users",
                    eq(column_ref("user_id"), field(column_ref("users"), "id")),
                ),
        );
        let statement = typed_query(q);
        assert!(statement_has_joins(&statement)?);

        let mut ctx = translation_context();
        let transformed = lower_joins(Arc::new(statement), &mut ctx)?;
        // One CTE per join, numbered in command order.
        assert_eq!(cte_names(&transformed), ["__join_0", "__join_1"]);
        Ok(())
    }
}