use std::sync::Arc;
use hamelin_lib::{
err::TranslationError,
tree::{
ast::{clause::SortOrder, pattern::QuantifierKind, query::Query},
builder::{pipeline as pipeline_builder, query, window_command, ExpressionBuilder},
typed_ast::{
command::{TypedCommandKind, TypedMatchCommand},
context::StatementTranslationContext,
pattern::TypedPattern,
pipeline::TypedPipeline,
query::TypedStatement,
},
},
};
/// Entry point of the pass: rewrites every `match` command in `statement`
/// into an equivalent window/regex pipeline.
///
/// Statements without any `match` command are returned unchanged (the same
/// `Arc` is handed back, no re-typing happens). Otherwise the statement is
/// rebuilt from the transformed AST and re-typechecked against `ctx`.
pub fn lower_match(
    statement: Arc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
    // Fast path: nothing to lower, hand the original Arc straight back.
    if !statement_has_match(&statement)? {
        return Ok(statement);
    }
    let rewritten = transform_statement(&statement, ctx)?;
    let retyped = TypedStatement::from_ast_with_context(Arc::new(rewritten), ctx);
    Ok(Arc::new(retyped))
}
/// Returns whether any pipeline in the statement contains a `match` command.
///
/// Every pipeline is inspected even after a hit, so that a validation error
/// in a later pipeline still takes precedence over an earlier `true`.
fn statement_has_match(statement: &TypedStatement) -> Result<bool, Arc<TranslationError>> {
    let mut found = false;
    for pipeline in statement.iter() {
        if pipeline_has_match(pipeline)? {
            found = true;
        }
    }
    Ok(found)
}
/// Returns whether this pipeline contains a `match` command; errors if the
/// pipeline failed validation.
fn pipeline_has_match(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    let commands = &pipeline.valid_ref()?.commands;
    let has_match = commands
        .iter()
        .any(|cmd| matches!(cmd.kind, TypedCommandKind::Match(_)));
    Ok(has_match)
}
/// Rebuilds the whole statement, lowering `match` in every WITH clause and
/// in the main pipeline.
fn transform_statement(
    statement: &TypedStatement,
    ctx: &mut StatementTranslationContext,
) -> Result<Query, Arc<TranslationError>> {
    let mut builder = query();
    // Each CTE is transformed independently and merged back under its name.
    for with_clause in &statement.with_clauses {
        let cte_query = transform_pipeline(&with_clause.pipeline, ctx)?;
        let cte_name = with_clause.name.clone().valid()?;
        builder = builder.merge_as_cte(cte_query, cte_name);
    }
    let main_query = transform_pipeline(&statement.pipeline, ctx)?;
    Ok(builder.merge_as_main(main_query))
}
/// Lowers a single pipeline. Only a `match` command in *leading* position is
/// rewritten; any other pipeline shape is passed through untouched.
fn transform_pipeline(
    pipeline: &TypedPipeline,
    ctx: &mut StatementTranslationContext,
) -> Result<Query, Arc<TranslationError>> {
    let commands = &pipeline.valid_ref()?.commands;
    // Guard: bail out with the original AST unless the first command is a match.
    let match_cmd = match commands.first().map(|cmd| &cmd.kind) {
        Some(TypedCommandKind::Match(match_cmd)) => match_cmd,
        _ => return Ok(query().main(pipeline.ast.clone()).build()),
    };
    let lowered = lower_match_command(match_cmd, ctx)?;
    // Splice the lowered commands where the match used to be, then re-append
    // everything that followed it.
    let mut builder = pipeline_builder().at(pipeline.ast.span.clone());
    for cmd in lowered.commands {
        builder = builder.command(cmd);
    }
    for cmd in &commands[1..] {
        builder = builder.command(cmd.ast.clone());
    }
    Ok(query().main(builder.build()).build())
}
/// Lowers one `match` command into a plain AST pipeline.
///
/// Three derived artifacts drive the lowering: the flattened list of pattern
/// variables, the label regex describing valid row sequences, and the set of
/// labels a match may legally start on.
///
/// Fix: the call below previously read `®ex_pattern` — a mis-encoded
/// `&reg` HTML entity — which is not a valid Rust identifier and broke
/// compilation; restored to `&regex_pattern`.
fn lower_match_command(
    match_cmd: &TypedMatchCommand,
    _ctx: &mut StatementTranslationContext,
) -> Result<hamelin_lib::tree::ast::pipeline::Pipeline, Arc<TranslationError>> {
    let pattern_vars = extract_pattern_variables(&match_cmd.patterns)?;
    let regex_pattern = pattern_to_regex(&match_cmd.patterns, &pattern_vars)?;
    let starting_labels = find_starting_pattern_labels(&match_cmd.patterns, &pattern_vars)?;
    build_lowered_pipeline(match_cmd, &pattern_vars, &regex_pattern, &starting_labels)
}
/// A single quantified table reference extracted from a `match` pattern.
#[derive(Debug, Clone)]
struct PatternVariable {
    // Binding name used in the generated FROM clause.
    alias: String,
    // Underlying table the alias refers to.
    table: String,
    // Short regex label assigned by `LabelGenerator` (e.g. "a", "b", ... "aa").
    label: String,
    // How many consecutive rows this variable may consume.
    quantifier: PatternQuantifier,
}
/// Repetition attached to a pattern variable, mirroring regex quantifiers.
#[derive(Debug, Clone, Copy, PartialEq)]
enum PatternQuantifier {
    /// Exactly one occurrence (`a`, or `a{1}` collapsed).
    One,
    /// One or more occurrences (`a+`).
    OneOrMore,
    /// Zero or more occurrences (`a*`).
    ZeroOrMore,
    /// Zero or one occurrence (`a?`).
    ZeroOrOne,
    /// A fixed repetition count (`a{n}`, n != 1).
    Exactly(u32),
}

impl PatternQuantifier {
    /// True when the quantifier admits zero occurrences.
    fn is_optional(&self) -> bool {
        matches!(self, Self::ZeroOrMore | Self::ZeroOrOne)
    }
}
/// Walks every top-level pattern and flattens the quantified variables it
/// finds into a single ordered list, assigning each a fresh label.
fn extract_pattern_variables(
    patterns: &[TypedPattern],
) -> Result<Vec<PatternVariable>, Arc<TranslationError>> {
    let mut label_gen = LabelGenerator::new();
    let mut variables = Vec::new();
    patterns
        .iter()
        .try_for_each(|pattern| extract_from_pattern(pattern, &mut variables, &mut label_gen))?;
    Ok(variables)
}
fn extract_from_pattern(
pattern: &TypedPattern,
variables: &mut Vec<PatternVariable>,
label_gen: &mut LabelGenerator,
) -> Result<(), Arc<TranslationError>> {
match pattern {
TypedPattern::Quantified(quant) => {
let (alias, table) = extract_alias_and_table(quant)?;
let quantifier = convert_quantifier(&quant.quantifier);
variables.push(PatternVariable {
alias,
table,
label: label_gen.next(),
quantifier,
});
}
TypedPattern::Nested(nested) => {
for sub_pattern in &nested.patterns {
extract_from_pattern(sub_pattern, variables, label_gen)?;
}
}
TypedPattern::Error(err) => {
return Err(err.clone());
}
}
Ok(())
}
/// Resolves the `(alias, table)` pair of a quantified pattern's FROM clause.
///
/// A bare table reference uses the table name as its own alias.
fn extract_alias_and_table(
    quant: &hamelin_lib::tree::typed_ast::pattern::TypedQuantifiedPattern,
) -> Result<(String, String), Arc<TranslationError>> {
    use hamelin_lib::tree::typed_ast::clause::TypedFromClause;
    match &quant.typed_from {
        // `table as alias` — alias and table differ.
        TypedFromClause::Alias(aliased) => Ok((
            aliased.alias.valid_ref()?.to_string(),
            aliased.ast.table.identifier.valid_ref()?.to_string(),
        )),
        // Bare `table` — the name doubles as the alias.
        TypedFromClause::Reference(reference) => {
            let name = reference.ast.identifier.valid_ref()?.to_string();
            Ok((name.clone(), name))
        }
        TypedFromClause::Error(err) => Err(err.clone()),
    }
}
/// Maps an AST quantifier onto the internal `PatternQuantifier` form.
///
/// `{1}` collapses to `One`; an unparseable `{n}` count also falls back to
/// `One` rather than erroring.
fn convert_quantifier(
    quantifier: &Arc<hamelin_lib::tree::ast::pattern::Quantifier>,
) -> PatternQuantifier {
    match &quantifier.kind {
        QuantifierKind::AtLeastOne => PatternQuantifier::OneOrMore,
        QuantifierKind::AnyNumber => PatternQuantifier::ZeroOrMore,
        QuantifierKind::ZeroOrOne => PatternQuantifier::ZeroOrOne,
        QuantifierKind::Exactly(n) => match n.parse::<u32>() {
            Ok(1) | Err(_) => PatternQuantifier::One,
            Ok(count) => PatternQuantifier::Exactly(count),
        },
        QuantifierKind::Error(_) => PatternQuantifier::One,
    }
}
/// Produces short unique labels in sequence: `a`..`z`, then `A`..`Z`, then
/// `aa`, `ab`, and so on.
struct LabelGenerator {
    // Zero-based position of the next label to hand out.
    index: usize,
}

impl LabelGenerator {
    fn new() -> Self {
        Self { index: 0 }
    }

    /// Returns the next label and advances the generator.
    // NOTE(review): the two-letter scheme only covers indices 52..52+26*26;
    // beyond that the arithmetic walks past 'z' — assumed unreachable for
    // realistic pattern sizes, confirm if patterns can exceed ~700 variables.
    fn next(&mut self) -> String {
        let i = self.index;
        self.index += 1;
        if i < 26 {
            char::from(b'a' + i as u8).to_string()
        } else if i < 52 {
            char::from(b'A' + (i - 26) as u8).to_string()
        } else {
            let rest = i - 52;
            format!(
                "{}{}",
                char::from(b'a' + (rest / 26) as u8),
                char::from(b'a' + (rest % 26) as u8)
            )
        }
    }
}
/// Converts the pattern list into the anchored label regex used to validate
/// candidate row sequences.
fn pattern_to_regex(
    patterns: &[TypedPattern],
    variables: &[PatternVariable],
) -> Result<String, Arc<TranslationError>> {
    collect_pattern_elements(patterns, variables)
        .map(|elements| build_regex_from_elements(&elements))
}
/// One positional element of the regex: a label plus its repetition.
struct PatternElement {
    // Label assigned to the underlying pattern variable.
    label: String,
    // Repetition of that label in the matched row sequence.
    quantifier: PatternQuantifier,
}
/// Flattens the patterns into ordered `PatternElement`s, pairing each
/// quantified pattern with its variable by position.
fn collect_pattern_elements(
    patterns: &[TypedPattern],
    variables: &[PatternVariable],
) -> Result<Vec<PatternElement>, Arc<TranslationError>> {
    let mut elements = Vec::new();
    // Cursor into `variables`; advances once per quantified pattern so the
    // two walks stay aligned.
    let mut cursor = 0usize;
    for pattern in patterns {
        collect_from_pattern(pattern, variables, &mut cursor, &mut elements)?;
    }
    Ok(elements)
}
fn collect_from_pattern(
pattern: &TypedPattern,
variables: &[PatternVariable],
var_index: &mut usize,
elements: &mut Vec<PatternElement>,
) -> Result<(), Arc<TranslationError>> {
match pattern {
TypedPattern::Quantified(_) => {
if let Some(var) = variables.get(*var_index) {
elements.push(PatternElement {
label: var.label.clone(),
quantifier: var.quantifier,
});
*var_index += 1;
}
}
TypedPattern::Nested(_) => {
return Err(TranslationError::fatal(
"lower_match",
"nested pattern groups with quantifiers (e.g., (a b)+) are not yet supported \
— use flat patterns instead (e.g., a+ b+)"
.into(),
)
.into());
}
TypedPattern::Error(_) => {}
}
Ok(())
}
/// Renders the collected elements into an anchored regex over the
/// comma-joined label string (e.g. `a+ b` becomes `^(a,)+b`).
///
/// Matched rows contribute their labels to a comma-separated state string,
/// so the regex alternates labels and commas. An element in tail position —
/// literally last, or followed only by optional elements — omits its
/// trailing comma; interior elements carry one.
fn build_regex_from_elements(elements: &[PatternElement]) -> String {
    if elements.is_empty() {
        return String::new();
    }
    // True when every element at or after `from_idx` can match zero rows,
    // i.e. the element just before `from_idx` may end the state string.
    fn all_remaining_optional(elements: &[PatternElement], from_idx: usize) -> bool {
        elements[from_idx..]
            .iter()
            .all(|e| e.quantifier.is_optional())
    }
    let mut parts = Vec::new();
    for (i, elem) in elements.iter().enumerate() {
        let is_last = i == elements.len() - 1;
        let next_all_optional = !is_last && all_remaining_optional(elements, i + 1);
        let label = &elem.label;
        let part = match elem.quantifier {
            // `a` -> tail: `a`; interior: `a,`.
            PatternQuantifier::One => {
                if is_last || next_all_optional {
                    label.clone()
                } else {
                    format!("{},", label)
                }
            }
            // `a+` -> tail: `a(,a)*`; interior: `(a,)+`.
            PatternQuantifier::OneOrMore => {
                if is_last || next_all_optional {
                    format!("{}(,{})*", label, label)
                } else {
                    format!("({},)+", label)
                }
            }
            // `a*` -> tail: `(,a(,a)*)?`; interior: `(a,)*`.
            // NOTE(review): the tail form assumes a preceding element already
            // emitted a label (it starts with a comma), so a pattern that
            // BEGINS with an optional element would demand a leading comma —
            // confirm leading-optional patterns are rejected or normalized
            // upstream.
            PatternQuantifier::ZeroOrMore => {
                if is_last || next_all_optional {
                    format!("(,{}(,{})*)?", label, label)
                } else {
                    format!("({},)*", label)
                }
            }
            // `a?` -> tail: `(,a)?`; interior: `(a,)?`.
            PatternQuantifier::ZeroOrOne => {
                if is_last || next_all_optional {
                    format!("(,{})?", label)
                } else {
                    format!("({},)?", label)
                }
            }
            // `a{n}` -> tail: `a,a,...,a` (n labels); interior: `a,a,...,a,`.
            // `a{0}` contributes nothing.
            PatternQuantifier::Exactly(n) => {
                if n == 0 {
                    String::new()
                } else if is_last || next_all_optional {
                    (0..n)
                        .map(|i| {
                            if i == 0 {
                                label.clone()
                            } else {
                                format!(",{}", label)
                            }
                        })
                        .collect::<Vec<_>>()
                        .join("")
                } else {
                    (0..n)
                        .map(|_| format!("{},", label))
                        .collect::<Vec<_>>()
                        .join("")
                }
            }
        };
        if !part.is_empty() {
            parts.push(part);
        }
    }
    // Anchored at the start only: callers match a prefix of the state string.
    format!("^{}", parts.join(""))
}
/// Labels a match may legally start on: every leading optional element plus
/// the first required one (if any).
fn find_starting_pattern_labels(
    patterns: &[TypedPattern],
    variables: &[PatternVariable],
) -> Result<Vec<String>, Arc<TranslationError>> {
    let elements = collect_pattern_elements(patterns, variables)?;
    // Leading run of optional elements can each be the first matched row...
    let mut labels: Vec<String> = elements
        .iter()
        .take_while(|elem| elem.quantifier.is_optional())
        .map(|elem| elem.label.clone())
        .collect();
    // ...and the first non-optional element (which immediately follows that
    // run, if present) closes the set.
    if let Some(required) = elements.iter().find(|elem| !elem.quantifier.is_optional()) {
        labels.push(required.label.clone());
    }
    Ok(labels)
}
/// One aggregation requested by the `match` command's agg clause.
struct AggInfo {
    // Output field name the result is assigned to.
    name: String,
    // Lower-cased aggregate function name (e.g. "count", "sum").
    func_name: String,
    // Column the aggregate reads, when its first argument is a plain column
    // reference; `None` otherwise (e.g. `count(*)`-style calls).
    column: Option<String>,
}
fn extract_agg_info(match_cmd: &TypedMatchCommand) -> Result<Vec<AggInfo>, Arc<TranslationError>> {
use hamelin_lib::tree::typed_ast::expression::TypedExpressionKind;
let mut agg_infos = Vec::new();
for assignment in &match_cmd.agg.assignments {
let name = assignment.identifier.valid_ref()?.to_string();
let (func_name, column) = match &assignment.expression.kind {
TypedExpressionKind::Apply(apply) => {
let func = apply.function_def.name().to_lowercase();
let col = apply
.parameter_binding
.get_by_index(0)
.ok()
.and_then(|arg| extract_column_name_from_expr(arg));
(func, col)
}
_ => continue, };
agg_infos.push(AggInfo {
name,
func_name,
column,
});
}
Ok(agg_infos)
}
/// Extracts the column name when the expression is a plain column reference;
/// `None` for every other expression kind or an invalid name.
fn extract_column_name_from_expr(
    expr: &hamelin_lib::tree::typed_ast::expression::TypedExpression,
) -> Option<String> {
    use hamelin_lib::tree::typed_ast::expression::TypedExpressionKind;
    if let TypedExpressionKind::ColumnReference(col_ref) = &expr.kind {
        col_ref.column_name.valid_ref().ok().map(|name| name.to_string())
    } else {
        None
    }
}
/// Distinct columns read by the aggregations, in first-seen order.
fn get_agg_columns(agg_infos: &[AggInfo]) -> Vec<String> {
    let mut seen = std::collections::HashSet::new();
    agg_infos
        .iter()
        .filter_map(|info| info.column.as_ref())
        // `insert` returns false for duplicates, keeping first occurrences.
        .filter(|col| seen.insert((*col).clone()))
        .cloned()
        .collect()
}
/// Assembles the rewritten pipeline that emulates the `match` command:
/// FROM all pattern tables, label each row, aggregate the labels per window
/// into a comma-joined state string, and keep only rows whose label starts a
/// sequence accepted by `regex_pattern`. When aggregations are requested,
/// per-column value arrays are carried alongside so the aggregate can be
/// computed over exactly the matched prefix.
fn build_lowered_pipeline(
    match_cmd: &TypedMatchCommand,
    pattern_vars: &[PatternVariable],
    regex_pattern: &str,
    starting_labels: &[String],
) -> Result<hamelin_lib::tree::ast::pipeline::Pipeline, Arc<TranslationError>> {
    use hamelin_lib::tree::builder::{
        and, call, cast, column_ref, eq, in_, int, is_not_null, pair, pipeline, sort_command,
        string, subtract, tuple,
    };
    use hamelin_lib::types::{array::Array, Type};
    let agg_infos = extract_agg_info(match_cmd)?;
    let agg_columns = get_agg_columns(&agg_infos);
    let has_agg = !agg_infos.is_empty();
    let mut pipe = pipeline();
    // FROM: one aliased table per pattern variable.
    pipe = pipe.from(|f| {
        let mut from_builder = f;
        for var in pattern_vars {
            from_builder = from_builder.table_alias(var.alias.as_str(), var.table.as_str());
        }
        from_builder
    });
    // __pattern_label: CASE over the aliases — the non-null alias tags the
    // row with that variable's label.
    let mut case_call = call("case");
    for var in pattern_vars {
        case_call = case_call.arg(pair(
            is_not_null(column_ref(var.alias.as_str())),
            string(&var.label),
        ));
    }
    pipe = pipe.let_cmd(|l| l.named_field("__pattern_label", case_call));
    // Rows matching no pattern variable are irrelevant to the match.
    pipe = pipe.where_cmd(is_not_null(column_ref("__pattern_label")));
    // Window: with aggregations, keep the raw label array (joined later, and
    // also used for count); without, join into the __state string directly.
    let mut window_builder = if has_agg {
        let labels_array = call("array_agg").arg(column_ref("__pattern_label"));
        self::window_command().named_field("__labels_array", labels_array)
    } else {
        let state_expr = call("array_join")
            .arg(call("array_agg").arg(column_ref("__pattern_label")))
            .arg(string(","));
        self::window_command().named_field("__state", state_expr)
    };
    // Carry one value array per aggregated column, aligned with the labels.
    if has_agg {
        for col in &agg_columns {
            let array_col_name = format!("__agg_values_{}", col);
            let value_array = call("array_agg").arg(column_ref(col.as_str()));
            window_builder = window_builder.named_field(array_col_name, value_array);
        }
    }
    // Propagate the match command's WITHIN / GROUP BY / SORT onto the window.
    if let Some(within) = &match_cmd.within {
        window_builder = window_builder.within(within.ast.as_ref().clone());
    }
    for assignment in &match_cmd.group_by.assignments {
        // Invalid group-by identifiers are silently dropped here; they are
        // assumed to have been reported during type checking.
        if let Ok(id) = assignment.identifier.valid_ref() {
            window_builder =
                window_builder.group_by(id.clone(), assignment.expression.ast.as_ref().clone());
        }
    }
    if !match_cmd.sort.is_empty() {
        let mut sort_builder = sort_command();
        for sort_expr in &match_cmd.sort {
            let expr = sort_expr.ast.expression.as_ref().clone();
            sort_builder = match sort_expr.order {
                SortOrder::Asc => sort_builder.asc(expr),
                SortOrder::Desc => sort_builder.desc(expr),
            };
        }
        window_builder = window_builder.sort(sort_builder);
    }
    pipe = pipe.window(|_| window_builder);
    if has_agg {
        // __state: the label array joined with commas, as in the no-agg path.
        let state_expr = call("array_join")
            .arg(column_ref("__labels_array"))
            .arg(string(","));
        pipe = pipe.let_cmd(|l| l.named_field("__state", state_expr));
        // __match_length: number of labels in the matched prefix — extract
        // the portion of __state matching the regex and count its
        // comma-separated parts.
        // NOTE(review): assumes regexp_extract returns the matched prefix
        // and split of an empty string yields a sensible count — confirm
        // against the target dialect's semantics.
        let match_length_expr = call("len").arg(
            call("split")
                .arg(
                    call("regexp_extract")
                        .arg(column_ref("__state"))
                        .arg(string(regex_pattern)),
                )
                .arg(string(",")),
        );
        pipe = pipe.let_cmd(|l| l.named_field("__match_length", match_length_expr));
        // Rewrite each aggregate to operate on the matched prefix only, by
        // slicing the carried arrays to __match_length elements.
        let mut agg_let_builder: Option<hamelin_lib::tree::builder::LetCommandBuilder> = None;
        for info in &agg_infos {
            let transformed_expr: Arc<hamelin_lib::tree::ast::expression::Expression> =
                match info.func_name.as_str() {
                    // count -> length of the sliced label array.
                    "count" => {
                        Arc::new(
                            call("len")
                                .arg(
                                    call("slice")
                                        .arg(column_ref("__labels_array"))
                                        .arg(int(0))
                                        .arg(column_ref("__match_length")),
                                )
                                .build(),
                        )
                    }
                    // sum -> sum over the sliced value array, cast to double.
                    "sum" => {
                        let col = info.column.as_ref().expect("sum requires column");
                        let array_col = format!("__agg_values_{}", col);
                        let slice_expr = call("slice")
                            .arg(column_ref(array_col.as_str()))
                            .arg(int(0))
                            .arg(column_ref("__match_length"));
                        let slice_double = cast(slice_expr, Type::Array(Array::new(Type::Double)));
                        Arc::new(call("sum").arg(slice_double).build())
                    }
                    // avg -> average over the sliced value array, cast to double.
                    "avg" => {
                        let col = info.column.as_ref().expect("avg requires column");
                        let array_col = format!("__agg_values_{}", col);
                        let slice_expr = call("slice")
                            .arg(column_ref(array_col.as_str()))
                            .arg(int(0))
                            .arg(column_ref("__match_length"));
                        let slice_double = cast(slice_expr, Type::Array(Array::new(Type::Double)));
                        Arc::new(call("avg").arg(slice_double).build())
                    }
                    // max/min -> applied to the sliced value array, no cast.
                    "max" | "min" => {
                        let col = info.column.as_ref().expect("max/min require column");
                        let array_col = format!("__agg_values_{}", col);
                        let slice_expr = call("slice")
                            .arg(column_ref(array_col.as_str()))
                            .arg(int(0))
                            .arg(column_ref("__match_length"));
                        Arc::new(call(&info.func_name).arg(slice_expr).build())
                    }
                    // first -> element 0 of the sliced value array.
                    "first" => {
                        let col = info.column.as_ref().expect("first requires column");
                        let array_col = format!("__agg_values_{}", col);
                        let slice_expr = call("slice")
                            .arg(column_ref(array_col.as_str()))
                            .arg(int(0))
                            .arg(column_ref("__match_length"));
                        Arc::new(call("get").arg(slice_expr).arg(int(0)).build())
                    }
                    // last -> element __match_length - 1 of the sliced array.
                    "last" => {
                        let col = info.column.as_ref().expect("last requires column");
                        let array_col = format!("__agg_values_{}", col);
                        let slice_expr = call("slice")
                            .arg(column_ref(array_col.as_str()))
                            .arg(int(0))
                            .arg(column_ref("__match_length"));
                        Arc::new(
                            call("get")
                                .arg(slice_expr)
                                .arg(subtract(column_ref("__match_length"), int(1)))
                                .build(),
                        )
                    }
                    // Unsupported aggregate functions are silently skipped.
                    _ => {
                        continue;
                    }
                };
            // Accumulate all rewritten aggregates into a single LET command.
            agg_let_builder = Some(match agg_let_builder {
                None => hamelin_lib::tree::builder::let_command()
                    .named_field(info.name.as_str(), transformed_expr),
                Some(builder) => builder.named_field(info.name.as_str(), transformed_expr),
            });
        }
        if let Some(builder) = agg_let_builder {
            pipe = pipe.let_cmd(|_| builder);
        }
    }
    // Keep only rows whose window state matches the pattern regex, and —
    // when the starting set is known — whose own label can begin a match.
    let regexp_filter = call("regexp_like")
        .arg(column_ref("__state"))
        .arg(string(regex_pattern));
    pipe = if starting_labels.len() == 1 {
        pipe.where_cmd(and(
            eq(column_ref("__pattern_label"), string(&starting_labels[0])),
            regexp_filter,
        ))
    } else if starting_labels.len() > 1 {
        let mut tup = tuple();
        for label in starting_labels {
            tup = tup.element(string(label));
        }
        pipe.where_cmd(and(in_(column_ref("__pattern_label"), tup), regexp_filter))
    } else {
        pipe.where_cmd(regexp_filter)
    };
    // Drop every internal helper column before the pipeline is returned.
    let mut drop_builder = pipe.drop(|d| d.field("__pattern_label").field("__state"));
    if has_agg {
        drop_builder = drop_builder.drop(|d| {
            let mut db = d.field("__labels_array").field("__match_length");
            for col in &agg_columns {
                let array_col_name = format!("__agg_values_{}", col);
                db = db.field(array_col_name.as_str());
            }
            db
        });
    }
    Ok(drop_builder.build())
}
#[cfg(test)]
mod tests {
    use super::*;

    // Labels advance a..z, then A..Z, then aa, ab, ...
    #[test]
    fn test_label_generator() {
        let mut gen = LabelGenerator::new();
        assert_eq!(gen.next(), "a");
        assert_eq!(gen.next(), "b");
        for _ in 2..26 {
            gen.next();
        }
        assert_eq!(gen.next(), "A");
        for _ in 27..52 {
            gen.next();
        }
        assert_eq!(gen.next(), "aa");
        assert_eq!(gen.next(), "ab");
    }

    // A single `a+` in tail position renders without a trailing comma.
    #[test]
    fn test_pattern_to_regex_single() {
        let elements = vec![PatternElement {
            label: "a".to_string(),
            quantifier: PatternQuantifier::OneOrMore,
        }];
        let regex = build_regex_from_elements(&elements);
        assert_eq!(regex, "^a(,a)*");
    }

    // `a+ b+`: interior element keeps the comma, tail element drops it.
    #[test]
    fn test_pattern_to_regex_sequence() {
        let elements = vec![
            PatternElement {
                label: "a".to_string(),
                quantifier: PatternQuantifier::OneOrMore,
            },
            PatternElement {
                label: "b".to_string(),
                quantifier: PatternQuantifier::OneOrMore,
            },
        ];
        let regex = build_regex_from_elements(&elements);
        assert_eq!(regex, "^(a,)+b(,b)*");
    }

    // `a b?`: trailing optional element is rendered as `(,b)?`.
    #[test]
    fn test_pattern_to_regex_optional_end() {
        let elements = vec![
            PatternElement {
                label: "a".to_string(),
                quantifier: PatternQuantifier::One,
            },
            PatternElement {
                label: "b".to_string(),
                quantifier: PatternQuantifier::ZeroOrOne,
            },
        ];
        let regex = build_regex_from_elements(&elements);
        assert_eq!(regex, "^a(,b)?");
    }

    // A statement without a match command is recognized as such (and would
    // be passed through by `lower_match` unchanged).
    #[test]
    fn test_no_match_passthrough() -> Result<(), Arc<TranslationError>> {
        use hamelin_lib::{
            provider::EnvironmentProvider,
            sql::{expression::identifier::Identifier as SqlIdentifier, query::TableReference},
            tree::{
                ast::{IntoTyped, TypeCheckExecutor},
                builder::{column_ref, eq, pipeline, query, HasMain, QueryBuilder},
            },
            types::{struct_type::Struct, INT},
        };
        use std::sync::Arc;

        // Minimal schema provider: knows only an `events(timestamp, value)` table.
        #[derive(Debug)]
        struct MockProvider;
        impl EnvironmentProvider for MockProvider {
            fn reflect_columns(&self, table: TableReference) -> anyhow::Result<Struct> {
                let mut fields = Struct::default();
                let events: SqlIdentifier = "events".parse().unwrap();
                if table.name == events {
                    fields.fields.insert("timestamp".parse().unwrap(), INT);
                    fields.fields.insert("value".parse().unwrap(), INT);
                    Ok(fields)
                } else {
                    anyhow::bail!("Table not found: {}", table.name)
                }
            }
            fn reflect_datasets(&self) -> anyhow::Result<Vec<SqlIdentifier>> {
                Ok(vec![])
            }
        }

        // Type-checks a built query against the mock provider.
        fn typed_query(builder: QueryBuilder<HasMain>) -> TypedStatement {
            builder
                .build()
                .typed_with()
                .with_provider(Arc::new(MockProvider))
                .typed()
        }

        let q = query().main(
            pipeline()
                .from(|f| f.table_reference("events"))
                .where_cmd(eq(column_ref("value"), 10)),
        );
        let statement = typed_query(q);
        assert!(!statement_has_match(&statement)?);
        Ok(())
    }
}