use std::sync::Arc;
use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::identifier::Identifier;
use hamelin_lib::tree::ast::pipeline::Pipeline;
use hamelin_lib::tree::builder::{
self, call, drop_command, eq, field_ref, where_command, window_command,
};
use hamelin_lib::tree::typed_ast::clause::ResolvedSelection;
use hamelin_lib::tree::typed_ast::command::TypedCommandKind;
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;
/// Name of the temporary row-number column used by the dedup rewrite.
/// It is added by the inserted window command and dropped again before the
/// append command, so it never reaches the append target.
const DEDUP_RN_FIELD: &str = "__dedup_rn";
/// Rewrites `pipeline` so that rows fed into an `append … distinct_by` command
/// are de-duplicated first.
///
/// When the pipeline contains an append command with a non-empty `distinct_by`
/// list, a window/where/drop prologue is spliced in ahead of it and the
/// rewritten AST is re-type-checked through `ctx`. Pipelines without such a
/// command are returned unchanged.
///
/// # Errors
/// Propagates errors from inspecting the typed pipeline (including unresolved
/// `distinct_by` selections) and from rebuilding the AST.
pub fn dedup_append_source(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    match find_append_with_distinct_by(&pipeline)? {
        // Nothing to dedup: hand the original pipeline back untouched.
        None => Ok(pipeline),
        Some((append_idx, distinct_by_ids)) => {
            let rewritten = build_deduped_pipeline(&pipeline, append_idx, &distinct_by_ids)?;
            // The rewritten AST must go through type checking again so the
            // inserted commands get typed like any other input.
            Ok(Arc::new(TypedPipeline::from_ast_with_context(
                Arc::new(rewritten),
                ctx,
            )))
        }
    }
}
fn find_append_with_distinct_by(
pipeline: &TypedPipeline,
) -> Result<Option<(usize, Vec<Identifier>)>, Arc<TranslationError>> {
let valid = pipeline.valid_ref()?;
for (idx, cmd) in valid.commands.iter().enumerate() {
if let TypedCommandKind::Append(append_cmd) = &cmd.kind {
if append_cmd.distinct_by.is_empty() {
return Ok(None);
}
let ids: Vec<Identifier> = append_cmd
.distinct_by
.iter()
.map(|sel| match &sel.resolved {
ResolvedSelection::Resolved(id) => Ok(id.clone()),
ResolvedSelection::Error(err) => Err(err.clone()),
})
.collect::<Result<_, _>>()?;
return Ok(Some((idx, ids)));
}
}
Ok(None)
}
/// Rebuilds the pipeline's AST, splicing a dedup prologue directly before the
/// command at `append_idx`; every other command is copied over verbatim.
///
/// The prologue numbers rows per `distinct_by` group with `row_number`, keeps
/// only the first row of each group, then drops the helper column again.
fn build_deduped_pipeline(
    pipeline: &TypedPipeline,
    append_idx: usize,
    distinct_by: &[Identifier],
) -> Result<Pipeline, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;
    let mut out = builder::pipeline().at(pipeline.ast.span.clone());
    for (idx, cmd) in valid.commands.iter().enumerate() {
        if idx == append_idx {
            // Assign a per-group row number over the distinct_by identifiers.
            let window = distinct_by.iter().fold(
                window_command().named_field(DEDUP_RN_FIELD, call("row_number")),
                |w, id| w.group_by(id.clone(), id.clone()),
            );
            out = out
                .command(window)
                // Keep only the first row of each group …
                .command(where_command(eq(field_ref(DEDUP_RN_FIELD), 1)))
                // … and remove the helper column before the append runs.
                .command(drop_command().field(DEDUP_RN_FIELD));
        }
        out = out.command(cmd.ast.clone());
    }
    Ok(out.build())
}
#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::EnvironmentProvider,
        tree::{
            ast::{
                identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
                pipeline::Pipeline,
            },
            builder::{append_command, let_command, pipeline},
            options::TypeCheckOptions,
        },
        type_check_with_options,
        types::{struct_type::Struct, INT, STRING},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    /// Minimal environment provider for the tests: exposes exactly one table,
    /// `target(id: INT, name: STRING)`, and no datasets.
    #[derive(Debug)]
    struct TestProvider;

    impl EnvironmentProvider for TestProvider {
        fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
            let target: Identifier = AstSimpleIdentifier::new("target").into();
            if *name == target {
                Ok(Struct::default()
                    .with_str("id", INT)
                    .with_str("name", STRING))
            } else {
                // Any other table lookup is a hard failure in these tests.
                anyhow::bail!("Table not found: {name}")
            }
        }

        fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
            Ok(vec![])
        }
    }

    /// Type-checks `input` and `expected` under the same registry/provider,
    /// runs `dedup_append_source` on the typed input, and asserts the rewritten
    /// AST equals the typed expected AST.
    fn run_test(input: Pipeline, expected: Pipeline) -> Result<(), Arc<TranslationError>> {
        let provider = Arc::new(TestProvider);
        let registry = Arc::new(FunctionRegistry::default());
        // Fresh options per call since the builder consumes its inputs.
        let tc_opts = || {
            TypeCheckOptions::builder()
                .registry(registry.clone())
                .provider(provider.clone())
                .build()
        };
        let input_typed = type_check_with_options(input, tc_opts()).output;
        let expected_typed = type_check_with_options(expected, tc_opts()).output;
        let mut ctx = StatementTranslationContext::new(registry, provider);
        let result = dedup_append_source(Arc::new(input_typed), &mut ctx)?;
        // Compare ASTs (not typed trees) — the rewrite is an AST transform.
        assert_eq!(result.ast, expected_typed.ast);
        Ok(())
    }

    #[rstest]
    // No append command at all: pipeline passes through unchanged.
    #[case::no_append_passthrough(
        pipeline()
            .command(let_command().named_field("id", 1).named_field("name", "alice"))
            .build(),
        pipeline()
            .command(let_command().named_field("id", 1).named_field("name", "alice"))
            .build()
    )]
    // Append without distinct_by: no dedup needed, pipeline unchanged.
    #[case::append_without_distinct_by_passthrough(
        pipeline()
            .command(let_command().named_field("id", 1).named_field("name", "alice"))
            .command(append_command("target"))
            .build(),
        pipeline()
            .command(let_command().named_field("id", 1).named_field("name", "alice"))
            .command(append_command("target"))
            .build()
    )]
    // Append with distinct_by("id"): expect the window/where/drop prologue to
    // be inserted immediately before the (otherwise unchanged) append.
    #[case::append_with_distinct_by_inserts_dedup(
        pipeline()
            .command(let_command().named_field("id", 1).named_field("name", "alice"))
            .command(append_command("target").distinct_by("id"))
            .build(),
        pipeline()
            .command(let_command().named_field("id", 1).named_field("name", "alice"))
            .command(window_command()
                .named_field(DEDUP_RN_FIELD, call("row_number"))
                .group_by("id", field_ref("id")))
            .command(where_command(eq(field_ref(DEDUP_RN_FIELD), 1)))
            .command(drop_command().field(DEDUP_RN_FIELD))
            .command(append_command("target").distinct_by("id"))
            .build()
    )]
    fn test_dedup_append_source(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
    ) -> Result<(), Arc<TranslationError>> {
        run_test(input, expected)
    }
}