hamelin_translation 0.3.10

Lowering and IR for Hamelin query language
Documentation
//! Pipeline pass: PARSE lowering.
//!
//! Lowers PARSE commands to LET + WHERE:
//!
//! ```text
//! PARSE source "prefix-*-suffix" AS field1
//! ```
//! becomes:
//! ```text
//! LET field1 = regexp_extract(source, "(?s)prefix\-(.*?)\-suffix", 1)
//! | WHERE regexp_count(source, "(?s)prefix\-(.*?)\-suffix") > 0
//! ```
//!
//! With NODROP flag (keep non-matching rows):
//! ```text
//! PARSE source "prefix-*-suffix" AS field1 NODROP
//! ```
//! becomes:
//! ```text
//! LET field1 = regexp_extract(source, "(?s)prefix\-(.*?)\-suffix", 1)
//! ```
//!
//! For multiple capture groups:
//! ```text
//! PARSE source "a-*-b-*-c" AS field1, field2
//! ```
//! becomes:
//! ```text
//! LET field1 = regexp_extract(source, "(?s)a\-(.*?)\-b\-(.*?)\-c", 1),
//!     field2 = regexp_extract(source, "(?s)a\-(.*?)\-b\-(.*?)\-c", 2)
//! | WHERE regexp_count(source, "(?s)a\-(.*?)\-b\-(.*?)\-c") > 0
//! ```
//!
//! Throwaway columns (`_`) are skipped in extraction.

use std::rc::Rc;

use hamelin_lib::err::TranslationError;
use hamelin_lib::tree::{
    ast::command::Command,
    builder::{self, call, gt, let_command, string, where_command, IntoExpressionBuilder},
    typed_ast::{
        command::{TypedCommand, TypedCommandKind, TypedParseCommand},
        context::StatementTranslationContext,
        pipeline::TypedPipeline,
    },
};

/// Pipeline pass that rewrites every PARSE command into LET (+ WHERE).
///
/// A pipeline containing no PARSE command is handed back untouched (the same
/// `Rc`). Otherwise each command is lowered, a new pipeline AST is assembled
/// from the results at the original pipeline's span, and that AST is
/// type-checked again against `ctx`.
///
/// Contract: `Rc<TypedPipeline> -> Result<Rc<TypedPipeline>, ...>`
///
/// # Errors
///
/// Propagates the error returned by `valid_ref()` on the input pipeline and
/// any error raised while lowering an individual command.
pub fn lower_parse(
    pipeline: Rc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Rc<TypedPipeline>, Rc<TranslationError>> {
    // Fast path: nothing to lower.
    let contains_parse = pipeline
        .valid_ref()?
        .commands
        .iter()
        .any(|command| matches!(&command.kind, TypedCommandKind::Parse(_)));
    if !contains_parse {
        return Ok(pipeline);
    }

    let valid = pipeline.valid_ref()?;

    // Lower every command first; a single PARSE may expand to several commands.
    let mut lowered = Vec::new();
    for command in &valid.commands {
        lowered.extend(lower_command(command, ctx)?);
    }

    // Feed the lowered commands into a fresh pipeline builder, in order.
    let rebuilt_ast = lowered
        .into_iter()
        .fold(builder::pipeline(), |acc, command| acc.command(command))
        .build()
        .at(pipeline.ast.span);

    // The AST changed, so types must be derived again.
    let retyped = TypedPipeline::from_ast_with_context(Rc::new(rebuilt_ast), ctx);
    Ok(Rc::new(retyped))
}

/// Lowers one command.
///
/// A PARSE command is expanded by `lower_parse_command`; any other command is
/// returned as a one-element vector holding a clone of its original AST.
fn lower_command(
    cmd: &Rc<TypedCommand>,
    ctx: &mut StatementTranslationContext,
) -> Result<Vec<Command>, Rc<TranslationError>> {
    match &cmd.kind {
        TypedCommandKind::Parse(parse_cmd) => lower_parse_command(parse_cmd, cmd, ctx),
        // Everything else passes through unchanged.
        _ => Ok(vec![cmd.ast.as_ref().clone()]),
    }
}

/// Lower a PARSE command to LET + WHERE.
///
/// The anchor pattern is converted to a regex once and reused for every
/// generated expression. The source expression is the command's explicit
/// source, or `ctx.message_field` when none was given.
///
/// Output, in order:
/// - a LET with one `regexp_extract(source, regex, group)` per non-throwaway
///   identifier (omitted when every identifier is `_`). Group numbers follow
///   the identifier's position (1-based), so a throwaway still consumes its
///   group index.
/// - unless NODROP is set, a `WHERE regexp_count(source, regex) > 0` that drops
///   non-matching rows. This filter is emitted even when all identifiers are
///   throwaway, so such a PARSE still acts as a row filter instead of silently
///   becoming a no-op.
///
/// Both generated commands carry the span of the original PARSE command.
///
/// # Errors
///
/// Propagates the error from `valid()` on any identifier.
fn lower_parse_command(
    parse_cmd: &TypedParseCommand,
    cmd: &TypedCommand,
    ctx: &mut StatementTranslationContext,
) -> Result<Vec<Command>, Rc<TranslationError>> {
    // Convert anchor pattern to regex pattern
    let regex_pattern = anchor_pattern_to_regex(&parse_cmd.pattern);

    // Get the source expression (or default to message_field if not specified)
    let source_expr = match &parse_cmd.source {
        Some(expr) => expr.ast.as_ref().clone(),
        None => ctx.message_field.clone().into_expression_builder().build(),
    };

    // Build LET command with regexp_extract for each non-throwaway identifier
    let mut let_builder = let_command().at(cmd.ast.span);
    let mut has_fields = false;

    for (i, id) in parse_cmd.identifiers.iter().enumerate() {
        let field_id = id.clone().valid()?;

        // Skip throwaway columns
        if field_id.to_string() == "_" {
            continue;
        }

        let group_num = (i + 1) as i64; // 1-indexed capture groups

        // regexp_extract(source, pattern, group)
        let extract_expr = call("regexp_extract")
            .arg(source_expr.clone())
            .arg(string(&regex_pattern))
            .arg(group_num);

        let_builder = let_builder.named_field(field_id, extract_expr);
        has_fields = true;
    }

    let mut lowered = Vec::with_capacity(2);

    // An all-throwaway PARSE extracts nothing, so there is no LET to emit.
    if has_fields {
        lowered.push(let_builder.build());
    }

    // Without NODROP, non-matching rows must be filtered out regardless of
    // whether any field was extracted.
    // Filter: regexp_count(source, pattern) > 0
    if !parse_cmd.nodrop {
        let where_cmd = where_command(gt(
            call("regexp_count")
                .arg(source_expr)
                .arg(string(&regex_pattern)),
            0,
        ))
        .at(cmd.ast.span)
        .build();
        lowered.push(where_cmd);
    }

    Ok(lowered)
}

/// Translate an anchor pattern (literal text with `*` wildcards) into a regex.
///
/// Rules, mirroring the logic from parse.rs:
/// - a pattern ending in exactly one `*` (not `**`) is treated as if it ended
///   in `**`;
/// - each of ``- [ ] { } ( ) + ? . , \ ^ $ | #`` is emitted with a leading
///   backslash;
/// - each `*` becomes the non-greedy group `(.*?)`;
/// - every adjacent pair `(.*?)(.*?)` is then collapsed into one greedy `(.*)`;
/// - the result is prefixed with `(?s)` so that `.` also matches newlines.
fn anchor_pattern_to_regex(pattern: &str) -> String {
    // A lone trailing wildcard gets a twin so the collapse below makes it greedy.
    let needs_extra_star = pattern.ends_with('*') && !pattern.ends_with("**");

    // Single pass: escape metacharacters and expand wildcards as we go.
    let mut body = String::with_capacity(pattern.len() * 2);
    for ch in pattern.chars().chain(needs_extra_star.then_some('*')) {
        match ch {
            '*' => body.push_str("(.*?)"),
            '-' | '[' | ']' | '{' | '}' | '(' | ')' | '+' | '?' | '.' | ',' | '\\' | '^' | '$'
            | '|' | '#' => {
                body.push('\\');
                body.push(ch);
            }
            other => body.push(other),
        }
    }

    // Adjacent non-greedy groups merge into a single greedy one.
    let body = body.replace("(.*?)(.*?)", "(.*)");

    format!("(?s){body}")
}

#[cfg(test)]
mod tests {
    use super::*;

    // Unit tests for anchor_pattern_to_regex helper
    #[test]
    fn test_anchor_pattern_to_regex_simple() {
        // Hyphens are escaped; the interior wildcard is non-greedy.
        assert_eq!(
            anchor_pattern_to_regex("prefix-*-suffix"),
            "(?s)prefix\\-(.*?)\\-suffix"
        );
    }

    #[test]
    fn test_anchor_pattern_to_regex_multiple() {
        // Wildcards separated by literal text each stay non-greedy: (.*?)
        assert_eq!(
            anchor_pattern_to_regex("a-*-b-*-c"),
            "(?s)a\\-(.*?)\\-b\\-(.*?)\\-c"
        );
    }

    #[test]
    fn test_anchor_pattern_to_regex_trailing_star() {
        // Trailing * gets doubled, then becomes (.*)
        assert_eq!(anchor_pattern_to_regex("prefix-*"), "(?s)prefix\\-(.*)");
    }

    #[test]
    fn test_anchor_pattern_to_regex_escapes_metacharacters() {
        // `.` is escaped; trailing * gets doubled then collapsed to (.*)
        assert_eq!(
            anchor_pattern_to_regex("user.name=*"),
            "(?s)user\\.name=(.*)"
        );
    }

    #[test]
    fn test_anchor_pattern_to_regex_no_wildcards() {
        // A pattern without wildcards only gains the (?s) prefix.
        assert_eq!(anchor_pattern_to_regex("literal"), "(?s)literal");
    }

    #[test]
    fn test_consecutive_captures_collapse() {
        // Two consecutive wildcards should become greedy
        assert_eq!(anchor_pattern_to_regex("**"), "(?s)(.*)");
    }

    // Pipeline pass integration tests
    use hamelin_lib::{
        tree::{
            ast::{pipeline::Pipeline, IntoTyped, TypeCheckExecutor},
            builder::{
                call, column_ref, gt, let_command, parse_command, pipeline, select_command, string,
                where_command,
            },
        },
        types::{struct_type::Struct, INT, STRING},
    };
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::rc::Rc;

    // Each case supplies: the input pipeline, the pipeline expected after
    // lowering, and the expected flattened output schema.
    #[rstest]
    // Case 1: No PARSE commands - passes through unchanged
    #[case::no_parse_passthrough(
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        pipeline()
            .command(select_command().named_field("a", 1).named_field("b", 2).build())
            .build(),
        Struct::default().with_str("a", INT).with_str("b", INT)
    )]
    // Case 2: PARSE with explicit source → LET + WHERE
    #[case::parse_with_source(
        pipeline()
            .command(select_command().named_field("msg", "prefix-hello-suffix").build())
            .command(parse_command().pattern("prefix-*-suffix").identifier("value").source(column_ref("msg")).build())
            .build(),
        pipeline()
            .command(select_command().named_field("msg", "prefix-hello-suffix").build())
            .command(let_command()
                .named_field("value", call("regexp_extract")
                    .arg(column_ref("msg"))
                    .arg(string("(?s)prefix\\-(.*?)\\-suffix"))
                    .arg(1))
                .build())
            .command(where_command(gt(
                call("regexp_count")
                    .arg(column_ref("msg"))
                    .arg(string("(?s)prefix\\-(.*?)\\-suffix")),
                0
            )).build())
            .build(),
        Struct::default().with_str("value", STRING).with_str("msg", STRING)
    )]
    // Case 3: PARSE with NODROP → LET only (no WHERE)
    #[case::parse_with_nodrop(
        pipeline()
            .command(select_command().named_field("msg", "user=test").build())
            .command(parse_command().pattern("user=*").identifier("user").source(column_ref("msg")).nodrop(true).build())
            .build(),
        pipeline()
            .command(select_command().named_field("msg", "user=test").build())
            .command(let_command()
                .named_field("user", call("regexp_extract")
                    .arg(column_ref("msg"))
                    .arg(string("(?s)user=(.*)"))
                    .arg(1))
                .build())
            .build(),
        Struct::default().with_str("user", STRING).with_str("msg", STRING)
    )]
    // Case 4: PARSE with multiple capture groups
    #[case::parse_multiple_captures(
        pipeline()
            .command(select_command().named_field("msg", "a-val1-b-val2-c").build())
            .command(parse_command().pattern("a-*-b-*-c").identifier("x").identifier("y").source(column_ref("msg")).build())
            .build(),
        pipeline()
            .command(select_command().named_field("msg", "a-val1-b-val2-c").build())
            .command(let_command()
                .named_field("x", call("regexp_extract")
                    .arg(column_ref("msg"))
                    .arg(string("(?s)a\\-(.*?)\\-b\\-(.*?)\\-c"))
                    .arg(1))
                .named_field("y", call("regexp_extract")
                    .arg(column_ref("msg"))
                    .arg(string("(?s)a\\-(.*?)\\-b\\-(.*?)\\-c"))
                    .arg(2))
                .build())
            .command(where_command(gt(
                call("regexp_count")
                    .arg(column_ref("msg"))
                    .arg(string("(?s)a\\-(.*?)\\-b\\-(.*?)\\-c")),
                0
            )).build())
            .build(),
        Struct::default().with_str("x", STRING).with_str("y", STRING).with_str("msg", STRING)
    )]
    // Case 5: PARSE with throwaway column (_) - skipped in extraction
    // (the kept identifier still uses its positional group number, 2)
    #[case::parse_throwaway_column(
        pipeline()
            .command(select_command().named_field("msg", "skip-keep-end").build())
            .command(parse_command().pattern("*-*-end").identifier("_").identifier("val").source(column_ref("msg")).build())
            .build(),
        pipeline()
            .command(select_command().named_field("msg", "skip-keep-end").build())
            .command(let_command()
                .named_field("val", call("regexp_extract")
                    .arg(column_ref("msg"))
                    .arg(string("(?s)(.*?)\\-(.*?)\\-end"))
                    .arg(2))
                .build())
            .command(where_command(gt(
                call("regexp_count")
                    .arg(column_ref("msg"))
                    .arg(string("(?s)(.*?)\\-(.*?)\\-end")),
                0
            )).build())
            .build(),
        Struct::default().with_str("val", STRING).with_str("msg", STRING)
    )]
    fn test_lower_parse(
        #[case] input: Pipeline,
        #[case] expected: Pipeline,
        #[case] expected_output_schema: Struct,
    ) {
        // Type-check both the input and the hand-written expected pipeline.
        let input_typed = input.typed_with().typed();
        let expected_typed = expected.typed_with().typed();

        // Run the pass with a default translation context.
        let mut ctx = StatementTranslationContext::default();
        let result = lower_parse(Rc::new(input_typed), &mut ctx).unwrap();

        // Compare ASTs
        assert_eq!(result.ast, expected_typed.ast);

        // Verify output schema
        let result_schema = result.environment().flatten();
        assert_eq!(result_schema, expected_output_schema);
    }
}