// hamelin_translation 0.9.3
//
// Lowering and IR for the Hamelin query language. Module documentation follows.
//! Pipeline pass: Deduplicate source rows before APPEND DISTINCT.
//!
//! When APPEND has a DISTINCT clause, the MERGE (or anti-join) only deduplicates
//! source rows against the target table. If the source itself contains duplicate keys,
//! Trino's MERGE will fail and DataFusion will insert all duplicates.
//!
//! This pass inserts WINDOW + WHERE + DROP before APPEND to keep exactly one row
//! per distinct key group in the source data.
//!
//! Example:
//! ```text
//! ... | APPEND target DISTINCT id
//! ```
//! becomes:
//! ```text
//! ... | WINDOW __dedup_rn = row_number() BY id
//!     | WHERE __dedup_rn = 1
//!     | DROP __dedup_rn
//!     | APPEND target DISTINCT id
//! ```

use std::sync::Arc;

use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::ast::identifier::Identifier;
use hamelin_lib::tree::ast::pipeline::Pipeline;
use hamelin_lib::tree::builder::{
    self, call, drop_command, eq, field_ref, where_command, window_command,
};
use hamelin_lib::tree::typed_ast::clause::ResolvedSelection;
use hamelin_lib::tree::typed_ast::command::TypedCommandKind;
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;

const DEDUP_RN_FIELD: &str = "__dedup_rn";

/// Deduplicate source rows before APPEND when DISTINCT is present.
///
/// Contract: `Arc<TypedPipeline> -> Result<Arc<TypedPipeline>, Arc<TranslationError>>`
/// Deduplicate source rows before APPEND when DISTINCT is present.
///
/// Contract: `Arc<TypedPipeline> -> Result<Arc<TypedPipeline>, Arc<TranslationError>>`
pub fn dedup_append_source(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    match find_append_with_distinct_by(&pipeline)? {
        // No APPEND with a DISTINCT clause: the pipeline passes through untouched.
        None => Ok(pipeline),
        // Rebuild the AST with the dedup prelude, then re-type-check it so the
        // returned pipeline carries fresh type information.
        Some((append_idx, distinct_by_ids)) => {
            let rewritten = build_deduped_pipeline(&pipeline, append_idx, &distinct_by_ids)?;
            let retyped = TypedPipeline::from_ast_with_context(Arc::new(rewritten), ctx);
            Ok(Arc::new(retyped))
        }
    }
}

/// Find APPEND with non-empty DISTINCT, returning its index and resolved identifiers.
/// Find the first APPEND command carrying a non-empty DISTINCT clause, returning
/// its index in the pipeline together with the resolved DISTINCT-key identifiers.
///
/// Returns `Ok(None)` when no such APPEND exists. Propagates the first
/// resolution error found among the DISTINCT selections, or any error from
/// `valid_ref()` when the pipeline itself failed type-checking.
fn find_append_with_distinct_by(
    pipeline: &TypedPipeline,
) -> Result<Option<(usize, Vec<Identifier>)>, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;

    for (idx, cmd) in valid.commands.iter().enumerate() {
        if let TypedCommandKind::Append(append_cmd) = &cmd.kind {
            // An APPEND without DISTINCT needs no dedup, but a later APPEND in
            // the same pipeline still might — keep scanning instead of bailing
            // out. (A previous version returned `Ok(None)` here, which silently
            // skipped every APPEND after the first one.)
            if append_cmd.distinct_by.is_empty() {
                continue;
            }

            // Collect the resolved identifiers, short-circuiting on the first
            // selection that failed to resolve.
            let ids: Vec<Identifier> = append_cmd
                .distinct_by
                .iter()
                .map(|sel| match &sel.resolved {
                    ResolvedSelection::Resolved(id) => Ok(id.clone()),
                    ResolvedSelection::Error(err) => Err(err.clone()),
                })
                .collect::<Result<_, _>>()?;

            return Ok(Some((idx, ids)));
        }
    }

    Ok(None)
}

/// Build a new pipeline with WINDOW/WHERE/DROP inserted before APPEND.
/// Build a new pipeline AST with WINDOW/WHERE/DROP inserted immediately before
/// the APPEND at `append_idx`, numbering rows per `distinct_by` group and
/// keeping only the first row of each group.
fn build_deduped_pipeline(
    pipeline: &TypedPipeline,
    append_idx: usize,
    distinct_by: &[Identifier],
) -> Result<Pipeline, Arc<TranslationError>> {
    let valid = pipeline.valid_ref()?;
    let mut out = builder::pipeline().at(pipeline.ast.span.clone());

    for (idx, cmd) in valid.commands.iter().enumerate() {
        if idx == append_idx {
            // WINDOW __dedup_rn = row_number() BY <key1>, <key2>, ...
            let window = distinct_by.iter().fold(
                window_command().named_field(DEDUP_RN_FIELD, call("row_number")),
                |w, id| w.group_by(id.clone(), id.clone()),
            );
            out = out
                .command(window)
                // Keep exactly one row per key group, then discard the helper column.
                .command(where_command(eq(field_ref(DEDUP_RN_FIELD), 1)))
                .command(drop_command().field(DEDUP_RN_FIELD));
        }
        out = out.command(cmd.ast.clone());
    }

    Ok(out.build())
}

#[cfg(test)]
mod tests {
    use super::*;
    use hamelin_lib::{
        func::registry::FunctionRegistry,
        provider::EnvironmentProvider,
        tree::{
            ast::{
                identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
                pipeline::Pipeline,
            },
            builder::{append_command, pipeline, set_command},
            options::TypeCheckOptions,
        },
        type_check_with_options,
        types::{struct_type::Struct, INT, STRING},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::sync::Arc;

    /// Minimal environment provider exposing a single table `target(id INT, name STRING)`
    /// so APPEND commands in the test pipelines can resolve their destination schema.
    #[derive(Debug)]
    struct TestProvider;

    impl EnvironmentProvider for TestProvider {
        fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
            let target: Identifier = AstSimpleIdentifier::new("target").into();

            if *name == target {
                Ok(Struct::default()
                    .with_str("id", INT)
                    .with_str("name", STRING))
            } else {
                anyhow::bail!("Table not found: {name}")
            }
        }

        // No discoverable datasets needed for these tests.
        fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
            Ok(vec![])
        }
    }

    /// Type-check `input`, run the dedup pass on it, and assert its AST equals
    /// the (independently type-checked) `expected` pipeline's AST.
    fn run_test(input: Pipeline, expected: Pipeline) -> Result<(), Arc<TranslationError>> {
        let provider = Arc::new(TestProvider);
        let registry = Arc::new(FunctionRegistry::default());
        // Closure so both input and expected pipelines get identical options.
        let tc_opts = || {
            TypeCheckOptions::builder()
                .registry(registry.clone())
                .provider(provider.clone())
                .build()
        };
        let input_typed = type_check_with_options(input, tc_opts()).output;
        let expected_typed = type_check_with_options(expected, tc_opts()).output;

        let mut ctx = StatementTranslationContext::new(registry, provider);
        let result = dedup_append_source(Arc::new(input_typed), &mut ctx)?;

        // Compare ASTs, not typed trees: the pass is defined by the AST it rewrites.
        assert_eq!(result.ast, expected_typed.ast);

        Ok(())
    }

    // Case 1: no APPEND at all — the pass must be a no-op.
    #[rstest]
    #[case::no_append_passthrough(
        pipeline()
            .command(set_command().named_field("id", 1).named_field("name", "alice"))
            .build(),
        pipeline()
            .command(set_command().named_field("id", 1).named_field("name", "alice"))
            .build()
    )]
    // Case 2: APPEND without DISTINCT — no dedup needed, the pass must be a no-op.
    #[case::append_without_distinct_by_passthrough(
        pipeline()
            .command(set_command().named_field("id", 1).named_field("name", "alice"))
            .command(append_command("target"))
            .build(),
        pipeline()
            .command(set_command().named_field("id", 1).named_field("name", "alice"))
            .command(append_command("target"))
            .build()
    )]
    // Case 3: APPEND with DISTINCT id — WINDOW/WHERE/DROP must be inserted
    // immediately before the APPEND, keyed on `id`.
    #[case::append_with_distinct_by_inserts_dedup(
        pipeline()
            .command(set_command().named_field("id", 1).named_field("name", "alice"))
            .command(append_command("target").distinct_by("id"))
            .build(),
        pipeline()
            .command(set_command().named_field("id", 1).named_field("name", "alice"))
            .command(window_command()
                .named_field(DEDUP_RN_FIELD, call("row_number"))
                .group_by("id", field_ref("id")))
            .command(where_command(eq(field_ref(DEDUP_RN_FIELD), 1)))
            .command(drop_command().field(DEDUP_RN_FIELD))
            .command(append_command("target").distinct_by("id"))
            .build()
    )]
    fn test_dedup_append_source(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
    ) -> Result<(), Arc<TranslationError>> {
        run_test(input, expected)
    }
}