use std::rc::Rc;
use hamelin_lib::{
err::TranslationError,
tree::{
ast::{clause::SortOrder, pattern::QuantifierKind, query::Query},
builder::{pipeline as pipeline_builder, query, window_command},
typed_ast::{
command::{TypedCommandKind, TypedMatchCommand},
context::StatementTranslationContext,
pattern::TypedPattern,
pipeline::TypedPipeline,
query::TypedStatement,
},
},
};
/// Rewrites every MATCH command in `statement` into plain pipeline commands.
///
/// A statement that contains no MATCH command is handed back untouched (the same `Rc`).
/// Otherwise the rewritten query AST is re-typed against `ctx` and wrapped in a new `Rc`.
///
/// # Errors
/// Propagates any `TranslationError` hit while inspecting or rewriting the pipelines.
pub fn lower_match(
    statement: Rc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Rc<TypedStatement>, Rc<TranslationError>> {
    if statement_has_match(&statement)? {
        let rewritten = Rc::new(transform_statement(&statement, ctx)?);
        let retyped = TypedStatement::from_ast_with_context(rewritten, ctx);
        Ok(Rc::new(retyped))
    } else {
        Ok(statement)
    }
}
/// Reports whether any pipeline yielded by `statement.iter()` contains a MATCH command.
///
/// Every pipeline is inspected (there is no early exit on the first hit). As a result, an
/// invalid pipeline surfaces as an error even when an earlier pipeline already contains a MATCH.
fn statement_has_match(statement: &TypedStatement) -> Result<bool, Rc<TranslationError>> {
    let mut found = false;
    for pipeline in statement.iter() {
        let this_one = pipeline_has_match(pipeline)?;
        found = this_one || found;
    }
    Ok(found)
}
/// Returns `true` when `pipeline` contains a MATCH command at any position.
///
/// # Errors
/// Returns the pipeline's own error when it is not valid.
fn pipeline_has_match(pipeline: &TypedPipeline) -> Result<bool, Rc<TranslationError>> {
    let valid = pipeline.valid_ref()?;
    let mut commands = valid.commands.iter();
    Ok(commands.any(|command| matches!(&command.kind, TypedCommandKind::Match(_))))
}
/// Rebuilds `statement` as a `Query`, lowering MATCH in every WITH clause and in the main
/// pipeline.
///
/// WITH clauses are processed in declaration order and merged as CTEs under their names.
///
/// # Errors
/// Propagates lowering errors and invalid CTE names.
fn transform_statement(
    statement: &TypedStatement,
    ctx: &mut StatementTranslationContext,
) -> Result<Query, Rc<TranslationError>> {
    let with_ctes = statement
        .with_clauses
        .iter()
        .try_fold(query(), |builder, clause| {
            let lowered = transform_pipeline(&clause.pipeline, ctx)?;
            let cte_name = clause.name.clone().valid()?;
            Ok::<_, Rc<TranslationError>>(builder.merge_as_cte(lowered, cte_name))
        })?;
    let main = transform_pipeline(&statement.pipeline, ctx)?;
    Ok(with_ctes.merge_as_main(main))
}
/// Wraps `pipeline` in a single-pipeline `Query`, expanding a MATCH command in first position.
///
/// When the first command is a MATCH, it is replaced by the commands produced by
/// `lower_match_command`, and the remaining commands are appended unchanged. Any other
/// pipeline, including one whose MATCH is not first, is passed through as-is.
/// NOTE(review): this presumes MATCH is only valid as a pipeline's first command — confirm.
///
/// # Errors
/// Returns the pipeline's error when it is invalid, or any error from lowering the MATCH.
fn transform_pipeline(
    pipeline: &TypedPipeline,
    ctx: &mut StatementTranslationContext,
) -> Result<Query, Rc<TranslationError>> {
    let commands = &pipeline.valid_ref()?.commands;
    let leading_match = commands
        .split_first()
        .and_then(|(head, rest)| match &head.kind {
            TypedCommandKind::Match(match_cmd) => Some((match_cmd, rest)),
            _ => None,
        });
    let lowered_query = match leading_match {
        None => query().main(pipeline.ast.clone()).build(),
        Some((match_cmd, rest)) => {
            let expanded = lower_match_command(match_cmd, ctx)?;
            let builder = expanded.commands.into_iter().fold(
                pipeline_builder().at(pipeline.ast.span.clone()),
                |builder, command| builder.command(command),
            );
            let builder = rest
                .iter()
                .fold(builder, |builder, command| builder.command(command.ast.clone()));
            query().main(builder.build()).build()
        }
    };
    Ok(lowered_query)
}
/// Lowers a single MATCH command into an equivalent plain pipeline.
///
/// Steps:
/// 1. flatten the (possibly nested) patterns into `PatternVariable`s, each with its own label;
/// 2. encode the element sequence as a regex over comma-joined labels;
/// 3. find the label of the first non-optional element, if any;
/// 4. emit the pipeline that evaluates the regex (see `build_lowered_pipeline`).
///
/// `_ctx` is accepted but not read here.
///
/// # Errors
/// Propagates errors from invalid patterns, FROM clauses, or GROUP BY identifiers.
fn lower_match_command(
    match_cmd: &TypedMatchCommand,
    _ctx: &mut StatementTranslationContext,
) -> Result<hamelin_lib::tree::ast::pipeline::Pipeline, Rc<TranslationError>> {
    let pattern_vars = extract_pattern_variables(&match_cmd.patterns)?;
    let regex_pattern = pattern_to_regex(&match_cmd.patterns, &pattern_vars);
    let first_required = find_first_required_pattern(&match_cmd.patterns, &pattern_vars);
    build_lowered_pipeline(match_cmd, &pattern_vars, &regex_pattern, first_required)
}
/// One quantified element of a MATCH pattern, resolved to its source and its regex label.
#[derive(Debug, Clone)]
struct PatternVariable {
    /// Name the element refers to its source by (the table name itself when no alias is given).
    alias: String,
    /// Table identifier taken from the element's FROM clause.
    table: String,
    /// Unique label from `LabelGenerator`; written into the per-row label column and used as
    /// a token in the generated regex.
    label: String,
    /// Normalized repetition operator for this element.
    quantifier: PatternQuantifier,
}
/// How many consecutive occurrences a single pattern element may match.
///
/// Produced by `convert_quantifier`, which folds an explicit count of 1 into `One`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum PatternQuantifier {
    /// Exactly one occurrence.
    One,
    /// One or more occurrences.
    OneOrMore,
    /// Zero or more occurrences.
    ZeroOrMore,
    /// Zero or one occurrence.
    ZeroOrOne,
    /// Exactly the given number of occurrences.
    Exactly(u32),
}

impl PatternQuantifier {
    /// Returns `true` when the element can match zero rows and so may be skipped entirely.
    fn is_optional(&self) -> bool {
        match *self {
            PatternQuantifier::One | PatternQuantifier::OneOrMore => false,
            PatternQuantifier::ZeroOrMore | PatternQuantifier::ZeroOrOne => true,
            PatternQuantifier::Exactly(count) => count == 0,
        }
    }
}
/// Flattens `patterns` depth-first, in source order, into one `PatternVariable` per
/// quantified element. Labels come from a fresh `LabelGenerator`.
///
/// # Errors
/// Returns the first error pattern or invalid FROM clause encountered.
fn extract_pattern_variables(
    patterns: &[TypedPattern],
) -> Result<Vec<PatternVariable>, Rc<TranslationError>> {
    let mut labels = LabelGenerator::new();
    let mut collected = Vec::new();
    patterns
        .iter()
        .try_for_each(|pattern| extract_from_pattern(pattern, &mut collected, &mut labels))?;
    Ok(collected)
}
fn extract_from_pattern(
pattern: &TypedPattern,
variables: &mut Vec<PatternVariable>,
label_gen: &mut LabelGenerator,
) -> Result<(), Rc<TranslationError>> {
match pattern {
TypedPattern::Quantified(quant) => {
let (alias, table) = extract_alias_and_table(quant)?;
let quantifier = convert_quantifier(&quant.quantifier);
variables.push(PatternVariable {
alias,
table,
label: label_gen.next(),
quantifier,
});
}
TypedPattern::Nested(nested) => {
for sub_pattern in &nested.patterns {
extract_from_pattern(sub_pattern, variables, label_gen)?;
}
}
TypedPattern::Error(err) => {
return Err(err.clone());
}
}
Ok(())
}
/// Resolves a quantified pattern's FROM clause to an `(alias, table)` pair.
///
/// An aliased source yields its alias and underlying table. A bare table reference uses the
/// table name for both.
///
/// # Errors
/// Returns the clause's error, or the error of an invalid alias/table identifier.
fn extract_alias_and_table(
    quant: &hamelin_lib::tree::typed_ast::pattern::TypedQuantifiedPattern,
) -> Result<(String, String), Rc<TranslationError>> {
    use hamelin_lib::tree::typed_ast::clause::TypedFromClause;
    let resolved = match &quant.typed_from {
        TypedFromClause::Error(err) => return Err(err.clone()),
        TypedFromClause::Reference(reference) => {
            let name = reference.ast.identifier.valid_ref()?.to_string();
            (name.clone(), name)
        }
        TypedFromClause::Alias(aliased) => (
            aliased.alias.valid_ref()?.to_string(),
            aliased.ast.table.identifier.valid_ref()?.to_string(),
        ),
    };
    Ok(resolved)
}
/// Normalizes a parsed quantifier into a `PatternQuantifier`.
///
/// An explicit count of 1 becomes `One`. A count that does not parse as `u32`, and an error
/// quantifier, also fall back to `One` instead of failing.
fn convert_quantifier(
    quantifier: &Rc<hamelin_lib::tree::ast::pattern::Quantifier>,
) -> PatternQuantifier {
    let kind = &quantifier.kind;
    match kind {
        QuantifierKind::Error(_) => PatternQuantifier::One,
        QuantifierKind::AtLeastOne => PatternQuantifier::OneOrMore,
        QuantifierKind::AnyNumber => PatternQuantifier::ZeroOrMore,
        QuantifierKind::ZeroOrOne => PatternQuantifier::ZeroOrOne,
        QuantifierKind::Exactly(count) => match count.parse::<u32>() {
            Ok(1) | Err(_) => PatternQuantifier::One,
            Ok(parsed) => PatternQuantifier::Exactly(parsed),
        },
    }
}
/// Hands out short, unique labels for pattern variables, in allocation order.
///
/// Labels are written into the per-row label column and also appear as literal tokens inside
/// the regex built from the pattern. Every label therefore consists solely of ASCII letters
/// (no regex metacharacters), and no label is ever repeated:
///
/// * indices `0..26`  → `a`..`z`
/// * indices `26..52` → `A`..`Z`
/// * indices `52..`   → lowercase bijective base-26 words: `aa`..`zz`, then `aaa`, `aab`, …
///
/// NOTE(review): multi-letter labels share prefixes with shorter ones (`a` / `aa`). The regex
/// built in `build_regex_from_elements` has no end anchor. Assuming `regexp_like` performs a
/// substring search, a final element labelled `a` could therefore also match the start of `aa`.
/// This is only reachable with more than 52 pattern variables.
struct LabelGenerator {
    index: usize,
}

impl LabelGenerator {
    fn new() -> Self {
        Self { index: 0 }
    }

    /// Returns the next label and advances the generator.
    fn next(&mut self) -> String {
        let label = match self.index {
            i @ 0..=25 => char::from(b'a' + i as u8).to_string(),
            i @ 26..=51 => char::from(b'A' + (i - 26) as u8).to_string(),
            // Subtracting 26 (not 52) skips the 26 single-letter words of the bijective
            // encoding, so index 52 maps to "aa". From there the sequence continues
            // "ab".."zz", "aaa", ... and never leaves `a`..`z`. The old two-character formula
            // ran past 'z' into '{', '|', ... (and eventually overflowed u8) after "zz".
            i => Self::lowercase_word(i - 26),
        };
        self.index += 1;
        label
    }

    /// Spreadsheet-column style (bijective base-26) word over `a`..`z`:
    /// 0 → `a`, 25 → `z`, 26 → `aa`, 701 → `zz`, 702 → `aaa`.
    fn lowercase_word(mut n: usize) -> String {
        let mut letters = Vec::new();
        loop {
            letters.push(char::from(b'a' + (n % 26) as u8));
            if n < 26 {
                break;
            }
            n = n / 26 - 1;
        }
        letters.iter().rev().collect()
    }
}
/// Encodes the flattened pattern sequence as a `^`-anchored regex over comma-joined labels.
fn pattern_to_regex(patterns: &[TypedPattern], variables: &[PatternVariable]) -> String {
    build_regex_from_elements(&collect_pattern_elements(patterns, variables))
}
/// A flattened pattern element: the regex label of one pattern variable and how many times
/// it may repeat.
struct PatternElement {
    /// Label copied from the corresponding `PatternVariable`.
    label: String,
    /// Quantifier copied from the corresponding `PatternVariable`.
    quantifier: PatternQuantifier,
}
/// Re-walks `patterns` in the same depth-first order used by `extract_pattern_variables`.
/// Each quantified pattern is paired with the next entry of `variables`.
fn collect_pattern_elements(
    patterns: &[TypedPattern],
    variables: &[PatternVariable],
) -> Vec<PatternElement> {
    let mut next_variable = 0;
    let mut collected = Vec::new();
    patterns.iter().for_each(|pattern| {
        collect_from_pattern(pattern, variables, &mut next_variable, &mut collected)
    });
    collected
}
/// Appends one `PatternElement` per quantified pattern reachable from `pattern`.
///
/// Label and quantifier are copied from `variables[*var_index]`, and `var_index` then advances.
/// A quantified pattern with no corresponding variable is skipped without advancing. Error
/// patterns contribute nothing.
fn collect_from_pattern(
    pattern: &TypedPattern,
    variables: &[PatternVariable],
    var_index: &mut usize,
    elements: &mut Vec<PatternElement>,
) {
    match pattern {
        TypedPattern::Error(_) => {}
        TypedPattern::Nested(nested) => nested
            .patterns
            .iter()
            .for_each(|inner| collect_from_pattern(inner, variables, var_index, elements)),
        TypedPattern::Quantified(_) => {
            if let Some(variable) = variables.get(*var_index) {
                *var_index += 1;
                elements.push(PatternElement {
                    label: variable.label.clone(),
                    quantifier: variable.quantifier,
                });
            }
        }
    }
}
/// Builds a regex over a comma-joined label sequence (e.g. `a,a,b`).
///
/// Each element contributes a fragment chosen by its quantifier and by whether anything
/// required still follows it:
/// * while a required element follows, the fragment ends with its own `,` separator
///   (`a,`, `(a,)+`, `(a,)*`, `(a,)?`, `a,a,`);
/// * once only optional elements (or nothing) follow, the fragment has no trailing separator,
///   and optional repeats carry a leading `,` instead (`a`, `a(,a)*`, `(,a(,a)*)?`, `(,a)?`,
///   `a,a`).
///
/// `Exactly(0)` contributes nothing. The result starts with `^` and has no end anchor. If
/// every element is optional, the regex can match an empty prefix and so accepts any state
/// string. An empty `elements` slice yields an empty string (not `^`).
fn build_regex_from_elements(elements: &[PatternElement]) -> String {
    if elements.is_empty() {
        return String::new();
    }
    let mut regex = String::from("^");
    for (position, element) in elements.iter().enumerate() {
        // Also `true` for the last element: an empty tail is vacuously all-optional.
        let only_optional_after = elements[position + 1..]
            .iter()
            .all(|rest| rest.quantifier.is_optional());
        let label = element.label.as_str();
        let fragment = match (element.quantifier, only_optional_after) {
            (PatternQuantifier::One, true) => label.to_string(),
            (PatternQuantifier::One, false) => format!("{},", label),
            (PatternQuantifier::OneOrMore, true) => format!("{0}(,{0})*", label),
            (PatternQuantifier::OneOrMore, false) => format!("({},)+", label),
            (PatternQuantifier::ZeroOrMore, true) => format!("(,{0}(,{0})*)?", label),
            (PatternQuantifier::ZeroOrMore, false) => format!("({},)*", label),
            (PatternQuantifier::ZeroOrOne, true) => format!("(,{})?", label),
            (PatternQuantifier::ZeroOrOne, false) => format!("({},)?", label),
            (PatternQuantifier::Exactly(count), true) => vec![label; count as usize].join(","),
            (PatternQuantifier::Exactly(count), false) => {
                format!("{},", label).repeat(count as usize)
            }
        };
        regex.push_str(&fragment);
    }
    regex
}
/// Returns the label of the first element (in pattern order) that is not optional.
/// Returns `None` when every element can be skipped.
fn find_first_required_pattern(
    patterns: &[TypedPattern],
    variables: &[PatternVariable],
) -> Option<String> {
    collect_pattern_elements(patterns, variables)
        .into_iter()
        .find(|element| !element.quantifier.is_optional())
        .map(|element| element.label)
}
/// Assembles the plain pipeline that evaluates a MATCH command.
///
/// Emitted commands, in order:
/// 1. FROM every pattern source under its alias;
/// 2. LET `__pattern_label` = a `case` call pairing `<alias> IS NOT NULL` with each alias's
///    label, then WHERE `__pattern_label` is not null;
/// 3. WINDOW `__state` = `array_join(array_agg(__pattern_label), ',')`, using the MATCH
///    command's WITHIN, GROUP BY and SORT settings when present;
/// 4. WHERE `regexp_like(__state, regex_pattern)`, additionally requiring
///    `__pattern_label = first_required` when a required label exists;
/// 5. DROP the `__pattern_label` and `__state` helper columns.
///
/// # Errors
/// Returns an error if a GROUP BY assignment has an invalid identifier.
fn build_lowered_pipeline(
    match_cmd: &TypedMatchCommand,
    pattern_vars: &[PatternVariable],
    regex_pattern: &str,
    first_required: Option<String>,
) -> Result<hamelin_lib::tree::ast::pipeline::Pipeline, Rc<TranslationError>> {
    use hamelin_lib::tree::builder::{
        and, call, column_ref, eq, is_not_null, pair, pipeline, sort_command, string,
    };

    // FROM: one aliased source per pattern variable.
    let mut pipe = pipeline().from(|from| {
        pattern_vars.iter().fold(from, |sources, var| {
            sources.table_alias(var.alias.as_str(), var.table.as_str())
        })
    });

    // LET + WHERE: tag each row with its pattern label and keep only tagged rows.
    let label_expr = pattern_vars.iter().fold(call("case"), |branches, var| {
        branches.arg(pair(
            is_not_null(column_ref(var.alias.as_str())),
            string(&var.label),
        ))
    });
    pipe = pipe
        .let_cmd(|l| l.named_field("__pattern_label", label_expr))
        .where_cmd(is_not_null(column_ref("__pattern_label")));

    // WINDOW: join the labels aggregated over the window into one comma-separated string.
    let joined_labels = call("array_join")
        .arg(call("array_agg").arg(column_ref("__pattern_label")))
        .arg(string(","));
    let mut window = self::window_command().named_field("__state", joined_labels);
    if let Some(within) = &match_cmd.within {
        window = window.within(within.ast.as_ref().clone());
    }
    for assignment in &match_cmd.group_by.assignments {
        let key = assignment.identifier.valid_ref()?;
        window = window.group_by(key.clone(), assignment.expression.ast.as_ref().clone());
    }
    if !match_cmd.sort.is_empty() {
        let ordering = match_cmd.sort.iter().fold(sort_command(), |sorter, sort_expr| {
            let key = sort_expr.ast.expression.as_ref().clone();
            match sort_expr.order {
                SortOrder::Asc => sorter.asc(key),
                SortOrder::Desc => sorter.desc(key),
            }
        });
        window = window.sort(ordering);
    }
    pipe = pipe.window(|_| window);

    // WHERE: keep rows whose window state matches the pattern regex.
    let state_matches = call("regexp_like")
        .arg(column_ref("__state"))
        .arg(string(regex_pattern));
    pipe = match first_required {
        Some(first_label) => pipe.where_cmd(and(
            eq(column_ref("__pattern_label"), string(&first_label)),
            state_matches,
        )),
        None => pipe.where_cmd(state_matches),
    };

    // DROP: remove the helper columns before returning the pipeline.
    Ok(pipe
        .drop(|d| d.field("__pattern_label").field("__state"))
        .build())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `PatternElement` in regex tests.
    fn element(label: &str, quantifier: PatternQuantifier) -> PatternElement {
        PatternElement {
            label: label.to_string(),
            quantifier,
        }
    }

    #[test]
    fn test_label_generator() {
        // `gen` is a reserved keyword in Rust 2024, so the generator is bound to `labels`.
        let mut labels = LabelGenerator::new();
        assert_eq!(labels.next(), "a");
        assert_eq!(labels.next(), "b");
        for _ in 2..26 {
            labels.next();
        }
        assert_eq!(labels.next(), "A");
        for _ in 27..52 {
            labels.next();
        }
        assert_eq!(labels.next(), "aa");
        assert_eq!(labels.next(), "ab");
    }

    #[test]
    fn test_quantifier_optionality() {
        assert!(!PatternQuantifier::One.is_optional());
        assert!(!PatternQuantifier::OneOrMore.is_optional());
        assert!(!PatternQuantifier::Exactly(2).is_optional());
        assert!(PatternQuantifier::ZeroOrMore.is_optional());
        assert!(PatternQuantifier::ZeroOrOne.is_optional());
        assert!(PatternQuantifier::Exactly(0).is_optional());
    }

    #[test]
    fn test_pattern_to_regex_single() {
        let elements = vec![PatternElement {
            label: "a".to_string(),
            quantifier: PatternQuantifier::OneOrMore,
        }];
        let regex = build_regex_from_elements(&elements);
        assert_eq!(regex, "^a(,a)*");
    }

    #[test]
    fn test_pattern_to_regex_sequence() {
        let elements = vec![
            PatternElement {
                label: "a".to_string(),
                quantifier: PatternQuantifier::OneOrMore,
            },
            PatternElement {
                label: "b".to_string(),
                quantifier: PatternQuantifier::OneOrMore,
            },
        ];
        let regex = build_regex_from_elements(&elements);
        assert_eq!(regex, "^(a,)+b(,b)*");
    }

    #[test]
    fn test_pattern_to_regex_optional_end() {
        let elements = vec![
            PatternElement {
                label: "a".to_string(),
                quantifier: PatternQuantifier::One,
            },
            PatternElement {
                label: "b".to_string(),
                quantifier: PatternQuantifier::ZeroOrOne,
            },
        ];
        let regex = build_regex_from_elements(&elements);
        assert_eq!(regex, "^a(,b)?");
    }

    #[test]
    fn test_pattern_to_regex_optional_start() {
        let elements = vec![
            element("a", PatternQuantifier::ZeroOrOne),
            element("b", PatternQuantifier::One),
        ];
        assert_eq!(build_regex_from_elements(&elements), "^(a,)?b");
    }

    #[test]
    fn test_pattern_to_regex_zero_or_more_end() {
        let elements = vec![
            element("a", PatternQuantifier::One),
            element("b", PatternQuantifier::ZeroOrMore),
        ];
        assert_eq!(build_regex_from_elements(&elements), "^a(,b(,b)*)?");
    }

    #[test]
    fn test_pattern_to_regex_exactly() {
        let trailing = vec![element("a", PatternQuantifier::Exactly(3))];
        assert_eq!(build_regex_from_elements(&trailing), "^a,a,a");

        let leading = vec![
            element("a", PatternQuantifier::Exactly(2)),
            element("b", PatternQuantifier::One),
        ];
        assert_eq!(build_regex_from_elements(&leading), "^a,a,b");

        let zero = vec![
            element("a", PatternQuantifier::One),
            element("b", PatternQuantifier::Exactly(0)),
            element("c", PatternQuantifier::One),
        ];
        assert_eq!(build_regex_from_elements(&zero), "^a,c");
    }

    #[test]
    fn test_pattern_to_regex_empty() {
        assert_eq!(build_regex_from_elements(&[]), "");
    }

    #[test]
    fn test_no_match_passthrough() -> Result<(), Rc<TranslationError>> {
        use hamelin_lib::{
            provider::EnvironmentProvider,
            sql::{expression::identifier::Identifier as SqlIdentifier, query::TableReference},
            tree::{
                ast::{IntoTyped, TypeCheckExecutor},
                builder::{column_ref, eq, pipeline, query, HasMain, QueryBuilder},
            },
            types::{struct_type::Struct, INT},
        };
        use std::sync::Arc;

        #[derive(Debug)]
        struct MockProvider;

        impl EnvironmentProvider for MockProvider {
            fn reflect_columns(&self, table: TableReference) -> anyhow::Result<Struct> {
                let mut fields = Struct::default();
                let events: SqlIdentifier = "events".parse().unwrap();
                if table.name == events {
                    fields.fields.insert("timestamp".parse().unwrap(), INT);
                    fields.fields.insert("value".parse().unwrap(), INT);
                    Ok(fields)
                } else {
                    anyhow::bail!("Table not found: {}", table.name)
                }
            }

            fn reflect_datasets(&self) -> anyhow::Result<Vec<SqlIdentifier>> {
                Ok(vec![])
            }
        }

        fn typed_query(builder: QueryBuilder<HasMain>) -> TypedStatement {
            builder
                .build()
                .typed_with()
                .with_provider(Arc::new(MockProvider))
                .typed()
        }

        let q = query().main(
            pipeline()
                .from(|f| f.table_reference("events"))
                .where_cmd(eq(column_ref("value"), 10)),
        );
        let statement = typed_query(q);
        assert!(!statement_has_match(&statement)?);
        Ok(())
    }
}