use std::rc::Rc;
use std::sync::Arc;
use crate::ast::assignment_clause::HamelinAssignmentClause;
use crate::ast::command::from;
use crate::ast::expression::HamelinExpression;
use crate::ast::from_clause::HamelinFromClause;
use crate::ast::pipeline::HamelinPipeline;
use crate::ast::sort_expression::HamelinSortExpression;
use crate::ast::QueryTranslationContext;
use crate::env::Environment;
use crate::translation::projection_builder::ProjectionBuilder;
use crate::translation::PendingQuery;
use antlr_rust::tree::ParseTree;
use hamelin_lib::antlr::hamelinparser::{
ExactlyContextAttrs, FromClauseContextAll, GroupClauseContextAttrs, MatchCommandContext,
MatchCommandContextAttrs, NestedContextAttrs, PatternContextAll, QuantifiedContextAttrs,
QuantifierContextAll,
};
use hamelin_lib::err::{TranslationError, TranslationErrors};
use hamelin_lib::func::def::{FunctionTranslationContext, SpecialPosition};
use hamelin_lib::func::registry::FunctionRegistry;
use hamelin_lib::sql::expression::apply::{BinaryOperatorApply, FunctionCallApply, Lambda};
use hamelin_lib::sql::expression::identifier::{Identifier, SimpleIdentifier};
use hamelin_lib::sql::expression::literal::{
ColumnReference, IntegerLiteral, NullLiteral, StringLiteral,
};
use hamelin_lib::sql::expression::operator::Operator;
use hamelin_lib::sql::expression::{Case, Leaf};
use hamelin_lib::sql::expression::{Direction, OrderByExpression, SQLExpression};
use hamelin_lib::sql::query::cte::CTE;
use hamelin_lib::sql::query::projection::{Binding, ColumnProjection, Projection};
use hamelin_lib::sql::query::window::{
FrameBoundary, FrameType, WindowExpression, WindowFrame, WindowSpecification,
};
use hamelin_lib::sql::query::{BadPatternVariable, PatternVariable, SQLQuery, TableReference};
use hamelin_lib::sql::types::{SQLBaseType, SQLType};
use hamelin_lib::translation::ExpressionTranslation;
use hamelin_lib::types::{Type, CALENDAR_INTERVAL, INTERVAL, TIMESTAMP};
use ordermap::OrderMap;
use super::from::{FromAliasing, FromAnonymous};
use hamelin_lib::sql::expression::apply::UnaryOperatorApply;
/// Reports whether `sql` is, at its outermost node, a unary application of the minus operator.
///
/// Only the top-level node is inspected: operands are not visited, and every other
/// expression kind yields `false`.
fn contains_negative_interval(sql: &SQLExpression) -> bool {
    matches!(
        sql,
        SQLExpression::UnaryOperatorApply(UnaryOperatorApply { operator, .. })
            if *operator == Operator::Minus
    )
}
/// One translated assignment from the MATCH command's assignment-clause list.
#[derive(Debug, Clone)]
struct AggBinding {
    /// Identifier the assignment binds its result to.
    identifier: Identifier,
    /// Translation of the assigned expression (its SQL and its type).
    translation: ExpressionTranslation,
}
/// A single step of a flattened MATCH pattern: a pattern variable plus how often it may repeat.
#[derive(Debug, Clone)]
pub struct PatternElement {
    /// Name of the pattern variable (lowercased by `pattern_to_regex` when rendered).
    pub variable: String,
    /// Repetition constraint attached to the variable.
    pub quantifier: PatternQuantifier,
}
/// Repetition constraint on a pattern variable.
#[derive(Debug, Clone, PartialEq)]
pub enum PatternQuantifier {
    /// Exactly one occurrence (used when the pattern carries no quantifier).
    One,
    /// One or more occurrences.
    OneOrMore,
    /// Zero or more occurrences.
    ZeroOrMore,
    /// Zero or one occurrence.
    ZeroOrOne,
    /// Exactly the given number of occurrences; `Exactly(0)` is treated as optional
    /// and renders to nothing in `pattern_to_regex`.
    Exactly(usize),
}
impl PatternElement {
pub fn new(variable: String, quantifier: PatternQuantifier) -> Self {
Self {
variable,
quantifier,
}
}
}
/// Renders a flat pattern as a `^`-anchored regular expression over a comma-joined
/// string of lowercased pattern-variable labels.
///
/// Commas are attached to whichever side keeps the expression valid:
/// * while something mandatory still follows, an element carries its own trailing comma
///   (`a,`, `(a,)+`, `(a,)*`, `(a,)?`);
/// * once everything after an element is optional (or it is the last element), mandatory
///   elements end without a comma and optional ones bring a leading separator instead.
///
/// An empty slice yields an empty string (without the anchor). `Exactly(0)` contributes
/// nothing to the output.
pub fn pattern_to_regex(elements: &[PatternElement]) -> String {
    if elements.is_empty() {
        return String::new();
    }

    let can_be_absent = |q: &PatternQuantifier| {
        matches!(
            q,
            PatternQuantifier::ZeroOrMore
                | PatternQuantifier::ZeroOrOne
                | PatternQuantifier::Exactly(0)
        )
    };

    let mut regex = String::from("^");
    // Stays true until the first element that must match at least once has been emitted.
    let mut only_optional_so_far = true;

    for (idx, element) in elements.iter().enumerate() {
        let name = element.variable.to_lowercase();

        // True for the last element too: the remaining slice is then empty.
        let nothing_required_after = elements[idx + 1..]
            .iter()
            .all(|later| can_be_absent(&later.quantifier));

        // Separator an optional element in tail position puts in front of itself. It is
        // computed before this element updates `only_optional_so_far`.
        let separator = if idx == 0 {
            ""
        } else if only_optional_so_far {
            ",?"
        } else {
            ","
        };

        let fragment = match &element.quantifier {
            PatternQuantifier::One => {
                only_optional_so_far = false;
                if nothing_required_after {
                    name
                } else {
                    format!("{},", name)
                }
            }
            PatternQuantifier::OneOrMore => {
                only_optional_so_far = false;
                if nothing_required_after {
                    format!("{}(,{})*", name, name)
                } else {
                    format!("({},)+", name)
                }
            }
            PatternQuantifier::ZeroOrMore => {
                if nothing_required_after {
                    format!("({}{}(,{})*)?", separator, name, name)
                } else {
                    format!("({},)*", name)
                }
            }
            PatternQuantifier::ZeroOrOne => {
                if nothing_required_after {
                    format!("({}{})?", separator, name)
                } else {
                    format!("({},)?", name)
                }
            }
            PatternQuantifier::Exactly(count) => {
                let count = *count;
                if count > 0 {
                    only_optional_so_far = false;
                }
                if count == 0 {
                    String::new()
                } else if nothing_required_after {
                    vec![name; count].join(",")
                } else {
                    format!("{},", name).repeat(count)
                }
            }
        };

        regex.push_str(&fragment);
    }

    regex
}
pub fn extract_pattern_elements(
patterns: &[Rc<PatternContextAll<'static>>],
cte_env: &Arc<Environment>,
statement_context: &Rc<QueryTranslationContext>,
) -> Result<Vec<PatternElement>, TranslationErrors> {
let mut result = Vec::new();
for pctx in patterns {
match pctx.as_ref() {
PatternContextAll::QuantifiedContext(qctx) => {
let fctx = TranslationErrors::expect(qctx, qctx.fromClause())?;
let fc = HamelinFromClause::new(
fctx.clone(),
cte_env.clone(),
statement_context.clone(),
);
let maybe_si = fc.table_alias().map(|si| si.to_sql()).transpose()?;
let si = match maybe_si {
Some(s) => s,
None => fc.table_identifier()?.last().clone(),
};
let pattern_var = PatternVariable::try_from(si.clone()).map_err(
|BadPatternVariable { name }| {
TranslationError::msg(
qctx,
&format!(
"Pattern variable '{}' contains special characters. \
Pattern variables must be simple identifiers without backticks.",
name.name
),
)
.single()
},
)?;
let var_name = pattern_var.name.to_uppercase();
let quantifier = if let Some(q) = qctx.quantifier() {
match q.as_ref() {
QuantifierContextAll::AtLeastOneContext(_) => PatternQuantifier::OneOrMore,
QuantifierContextAll::AnyNumberContext(_) => PatternQuantifier::ZeroOrMore,
QuantifierContextAll::ZeroOrOneContext(_) => PatternQuantifier::ZeroOrOne,
QuantifierContextAll::ExactlyContext(ectx) => {
let text =
TranslationErrors::expect(ectx, ectx.INTEGER_VALUE())?.get_text();
let n: usize = text.parse().map_err(|_| {
TranslationError::msg(
ectx,
&format!("Invalid quantifier value '{}': expected a positive integer", text),
)
.single()
})?;
PatternQuantifier::Exactly(n)
}
QuantifierContextAll::Error(e) => {
return Err(TranslationError::msg(e, "parse error").single());
}
}
} else {
PatternQuantifier::One
};
result.push(PatternElement::new(var_name, quantifier));
}
PatternContextAll::NestedContext(nctx) => {
if nctx.quantifier().is_some() {
return Err(TranslationError::msg(
nctx,
"Nested pattern groups with quantifiers (e.g., (a b)+) are not supported. \
The CTE-based pattern matching uses regex on state labels and does not \
support capturing groups. Use flat patterns instead (e.g., a b+ or a+ b+).",
)
.single());
}
let nested =
extract_pattern_elements(&nctx.pattern_all(), cte_env, statement_context)?;
result.extend(nested);
}
PatternContextAll::Error(e) => {
return Err(TranslationError::msg(e, "parse error").single());
}
}
}
Ok(result)
}
/// Builds the expression that labels a row with the pattern variable it belongs to.
///
/// The result is a chain of nested `IF(<var> IS NOT NULL, '<var>', <rest>)` calls ending in
/// `NULL`. The chain is assembled from the last key to the first, so the first key of
/// `from_clauses` ends up as the outermost test. Labels are the lowercased variable names.
fn build_pattern_label_expr(
    from_clauses: &OrderMap<SimpleIdentifier, Rc<FromClauseContextAll<'static>>>,
) -> SQLExpression {
    let no_label: SQLExpression = NullLiteral::default().into();

    from_clauses
        .keys()
        .rev()
        .fold(no_label, |otherwise, si| {
            let is_present: SQLExpression = BinaryOperatorApply::new(
                Operator::IsNot,
                ColumnReference::new(si.clone().into()).into(),
                NullLiteral::default().into(),
            )
            .into();
            let label = si.name.to_lowercase();

            FunctionCallApply::with_three(
                "IF",
                is_present,
                StringLiteral::new(&label).into(),
                otherwise,
            )
            .into()
        })
}
pub fn translate_with_cte(
ctx: &MatchCommandContext<'static>,
pipeline: &HamelinPipeline,
previous: &PendingQuery,
) -> Result<PendingQuery, TranslationErrors> {
let cte = Arc::new(pipeline.cte.clone());
let from_clauses =
extract_from_clauses_from_pattern(&ctx.pattern_all(), &cte, &pipeline.context)?;
let from = from::uber_from(
from_clauses.values().cloned().collect(),
pipeline,
previous,
FromAliasing::AliasingAllowed,
FromAnonymous::DefaultAlias,
)?;
let previous_env = Arc::new(from.env.clone());
let mut defines_env = previous_env.remove_substructure();
for (si, fctx) in from_clauses.iter() {
defines_env = defines_env.with_pattern_variable(
si.clone().into(),
previous_env
.lookup(&si.clone().into())
.map_err(|e| TranslationError::wrap_box(fctx.as_ref(), e.into()))?
.try_unwrap_struct()
.map_err(|e| TranslationError::wrap_box(fctx.as_ref(), e.into()))?
.remove_substructure()
.into(),
);
}
let pattern_elements = extract_pattern_elements(&ctx.pattern_all(), &cte, &pipeline.context)?;
let regex_pattern = pattern_to_regex(&pattern_elements);
let within_translation = ctx
.within
.as_ref()
.map(|within_ctx| {
HamelinExpression::new(
within_ctx.clone(),
pipeline.context.expression_translation_context(
&defines_env,
FunctionTranslationContext::default()
.with_special_allowed(SpecialPosition::Agg)
.with_special_allowed(SpecialPosition::Match),
),
)
.translate()
})
.transpose()?
.ok_or_else(|| {
TranslationError::msg(
ctx,
"CTE-based MATCH requires a WITHIN clause to define the time window",
)
.single()
})?;
let within_expression = within_translation.sql;
let within_type = within_translation.typ;
if contains_negative_interval(&within_expression) {
return Err(TranslationError::msg(ctx, "WITHIN interval cannot be negative").single());
}
let (order_by, sort_col_name, sort_col_binding, sort_type) = if ctx
.sortExpression_all()
.is_empty()
{
let timestamp_ident: Identifier = SimpleIdentifier::new("timestamp").into();
let ts_type = from.env.lookup(×tamp_ident).ok().ok_or_else(|| {
TranslationError::msg(
ctx,
"No 'timestamp' field in the environment. You need to specify SORT BY.",
)
.single()
})?;
if ts_type != TIMESTAMP {
return Err(TranslationError::msg(
ctx,
&format!(
"Field 'timestamp' has type {:?}, expected TIMESTAMP. You need to specify SORT BY.",
ts_type
),
)
.single());
}
(
vec![OrderByExpression::new(
ColumnReference::new(timestamp_ident.clone()).into(),
Direction::ASC,
)],
SimpleIdentifier::new("timestamp"),
None, TIMESTAMP,
)
} else {
let translated = HamelinSortExpression::new(
ctx.sortExpression_all(),
pipeline
.context
.default_expression_translation_context(&previous_env),
)
.translate()?;
if translated.0.len() != 1 {
return Err(TranslationError::msg(
ctx,
"MATCH with WITHIN requires exactly one SORT expression",
)
.single());
}
if translated.0[0].direction == Direction::DESC {
return Err(TranslationError::msg(
ctx,
"MATCH with WITHIN requires ascending sort order (ASC). \
DESC is not supported because the forward-looking window would look backwards.",
)
.single());
}
let sort_type = translated.1[0].clone();
let is_simple_col_ref = matches!(
translated.0[0].expression.as_ref(),
SQLExpression::Leaf(Leaf::ColumnReference(_))
);
if is_simple_col_ref {
let col_name = if let SQLExpression::Leaf(Leaf::ColumnReference(cr)) =
translated.0[0].expression.as_ref()
{
cr.identifier.last().clone()
} else {
unreachable!()
};
(translated.0, col_name, None, sort_type)
} else {
let sort_alias = SimpleIdentifier::new("__sort_col");
let sort_expr = (*translated.0[0].expression).clone();
let direction = translated.0[0].direction.clone();
let new_order_by = vec![OrderByExpression::new(
ColumnReference::new(sort_alias.clone().into()).into(),
direction,
)];
(new_order_by, sort_alias, Some(sort_expr), sort_type)
}
};
let types_compatible = if sort_type == TIMESTAMP {
within_type == INTERVAL || within_type == CALENDAR_INTERVAL
} else {
sort_type == within_type
};
if !types_compatible {
return Err(TranslationError::msg(
ctx,
&format!(
"MATCH with WITHIN: SORT column type {:?} is incompatible with WITHIN type {:?}",
sort_type, within_type
),
)
.single());
}
let mut partition_exprs = Vec::new();
let mut by_clause_bindings = Vec::new();
for gc in ctx.groupClause_all() {
let clause = TranslationErrors::expect(gc.as_ref(), gc.assignmentClause())?;
let (identifier, translation) = HamelinAssignmentClause::new(
clause.clone(),
pipeline
.context
.default_expression_translation_context(&previous_env),
)
.to_sql()?;
let full_name = identifier
.simples()
.iter()
.map(|si| si.name.as_str())
.collect::<Vec<_>>()
.join("__");
let prefixed_name = SimpleIdentifier::new(&format!("__by_{}", full_name));
partition_exprs.push((prefixed_name.clone(), translation.sql.clone()));
by_clause_bindings.push(ByClauseBinding {
identifier,
prefixed_name,
typ: translation.typ,
});
}
let base_columns: Vec<SimpleIdentifier> = from.env.flatten().fields.keys().cloned().collect();
let agg_bindings: Vec<AggBinding> = ctx
.assignmentClause_all()
.into_iter()
.map(|clause| {
HamelinAssignmentClause::new(
clause.clone(),
pipeline.context.expression_translation_context(
&defines_env,
FunctionTranslationContext::default()
.with_special_allowed(SpecialPosition::Agg)
.with_special_allowed(SpecialPosition::Match),
),
)
.to_sql()
.map(|(identifier, translation)| AggBinding {
identifier,
translation,
})
})
.collect::<Result<Vec<_>, _>>()?;
validate_supported_aggregates(&agg_bindings, ctx, &pipeline.context.registry)?;
let first_pattern_vars: Vec<String> = {
let mut vars = Vec::new();
for elem in &pattern_elements {
let is_zero = matches!(elem.quantifier, PatternQuantifier::Exactly(0));
if !is_zero {
vars.push(elem.variable.to_lowercase());
}
if !matches!(
elem.quantifier,
PatternQuantifier::ZeroOrOne
| PatternQuantifier::ZeroOrMore
| PatternQuantifier::Exactly(0)
) {
break;
}
}
vars
};
if first_pattern_vars.is_empty() {
return Err(TranslationError::msg(ctx, "MATCH pattern cannot be empty").single());
}
let ordered_cte_name = SimpleIdentifier::new("__ordered_events");
let state_cte_name = SimpleIdentifier::new("__with_state");
let pattern_label_expr = build_pattern_label_expr(&from_clauses);
let ordered_query = build_ordered_events_cte(
&from.query,
&base_columns,
pattern_label_expr,
&partition_exprs,
sort_col_binding.as_ref().map(|expr| (&sort_col_name, expr)),
)?;
let has_agg = !agg_bindings.is_empty();
let agg_value_columns = extract_agg_value_columns(&agg_bindings);
let state_query = build_state_cte(
&ordered_cte_name,
&base_columns,
&partition_exprs,
&order_by,
&within_expression,
sort_col_binding.as_ref().map(|_| &sort_col_name),
has_agg,
&agg_value_columns,
)?;
let final_query = build_final_query(
&state_cte_name,
®ex_pattern,
&first_pattern_vars,
&base_columns,
&by_clause_bindings,
&agg_bindings,
&order_by,
&from.env,
)?;
let cte_chain = CTE::default()
.with(ordered_cte_name, ordered_query)
.with(state_cte_name, state_query);
let statement = final_query.cte(cte_chain);
let mut output_builder = ProjectionBuilder::default();
let by_first_names: Vec<&SimpleIdentifier> = by_clause_bindings
.iter()
.map(|b| b.identifier.first())
.collect();
let agg_first_names: Vec<&SimpleIdentifier> =
agg_bindings.iter().map(|b| b.identifier.first()).collect();
for col in &base_columns {
if by_first_names.contains(&col) || agg_first_names.contains(&col) {
continue;
}
let col_type = from
.env
.lookup(&col.clone().into())
.map_err(|e| TranslationError::wrap(ctx, e).single())?;
output_builder.bind(
col.clone().into(),
ColumnReference::new(col.clone().into()).into(),
col_type,
);
}
for binding in &by_clause_bindings {
output_builder.bind(
binding.identifier.clone(),
ColumnReference::new(binding.identifier.clone()).into(),
binding.typ.clone(),
);
}
for binding in &agg_bindings {
let typ = get_agg_output_type(binding);
output_builder.bind(
binding.identifier.clone(),
ColumnReference::new(binding.identifier.clone()).into(),
typ,
);
}
Ok(PendingQuery::new(
statement.into(),
Environment::new(output_builder.build_hamelin_type()),
))
}
/// Builds the first-stage query: the source query re-projected with the helper columns
/// the later stages rely on.
///
/// The select list is, in order: every base column, `__pattern_label`, one binding per
/// entry of `by_clause_exprs`, and the optional sort binding. The source query is cloned
/// and its own CTE list is reset to the default before the projections are applied.
fn build_ordered_events_cte(
    from_query: &SQLQuery,
    base_columns: &[SimpleIdentifier],
    pattern_label_expr: SQLExpression,
    by_clause_exprs: &[(SimpleIdentifier, SQLExpression)],
    sort_col_binding: Option<(&SimpleIdentifier, &SQLExpression)>,
) -> Result<SQLQuery, TranslationErrors> {
    let mut select_list: Vec<Projection> =
        Vec::with_capacity(base_columns.len() + by_clause_exprs.len() + 2);

    for column in base_columns {
        select_list.push(
            ColumnProjection {
                identifier: column.clone().into(),
            }
            .into(),
        );
    }

    let label: Projection =
        Binding::new(SimpleIdentifier::new("__pattern_label"), pattern_label_expr).into();
    select_list.push(label);

    select_list.extend(
        by_clause_exprs
            .iter()
            .map(|(alias, expr)| -> Projection { Binding::new(alias.clone(), expr.clone()).into() }),
    );

    if let Some((alias, expr)) = sort_col_binding {
        select_list.push(Binding::new(alias.clone(), expr.clone()).into());
    }

    let mut source = from_query.clone();
    source.cte = CTE::default();
    Ok(source.select(select_list))
}
/// Builds the second-stage query, which attaches the window-derived match state to each row.
///
/// Every window aggregate shares one specification: partitioned by the `__by_*` columns,
/// ordered by `order_by`, with a RANGE frame from the current row to `within_expr`
/// following. The select list is, in order: base columns, `__pattern_label`, the partition
/// columns, the optional sort column, `__state` (labels joined with `,`) and, when
/// `has_agg` is set, `__labels_array` plus one value array per aggregate input column.
/// Rows whose `__pattern_label` is NULL are filtered out.
fn build_state_cte(
    ordered_cte_name: &SimpleIdentifier,
    base_columns: &[SimpleIdentifier],
    partition_by: &[(SimpleIdentifier, SQLExpression)],
    order_by: &[OrderByExpression],
    within_expr: &SQLExpression,
    sort_col_name: Option<&SimpleIdentifier>,
    has_agg: bool,
    agg_value_columns: &[SimpleIdentifier],
) -> Result<SQLQuery, TranslationErrors> {
    let partition_exprs: Vec<SQLExpression> = partition_by
        .iter()
        .map(|(key, _)| ColumnReference::new(key.clone().into()).into())
        .collect();
    let window_spec = WindowSpecification {
        partition_by: partition_exprs,
        order_by: order_by.to_vec(),
        frame: Some(WindowFrame {
            frame_type: FrameType::RANGE,
            preceding: FrameBoundary::CurrentRowBoundary,
            following: FrameBoundary::BoundaryExpression(Box::new(within_expr.clone())),
        }),
    };

    // ARRAY_AGG(<column>) evaluated over the shared window.
    let array_agg_in_window = |column: SimpleIdentifier| -> SQLExpression {
        WindowExpression::new(
            FunctionCallApply::with_one("ARRAY_AGG", ColumnReference::new(column.into()).into())
                .into(),
            window_spec.clone().into(),
        )
        .into()
    };
    // Plain pass-through projection of an existing column.
    let passthrough = |column: &SimpleIdentifier| -> Projection {
        ColumnProjection {
            identifier: column.clone().into(),
        }
        .into()
    };

    let label_column = SimpleIdentifier::new("__pattern_label");
    let labels_array = array_agg_in_window(label_column.clone());
    let state_expr: SQLExpression = FunctionCallApply::with_two(
        "ARRAY_JOIN",
        labels_array.clone(),
        StringLiteral::new(",").into(),
    )
    .into();

    let mut projections: Vec<Projection> =
        base_columns.iter().map(|column| passthrough(column)).collect();
    projections.push(passthrough(&label_column));
    projections.extend(partition_by.iter().map(|(key, _)| passthrough(key)));
    if let Some(sort_name) = sort_col_name {
        projections.push(passthrough(sort_name));
    }
    projections.push(Binding::new(SimpleIdentifier::new("__state"), state_expr).into());

    if has_agg {
        projections
            .push(Binding::new(SimpleIdentifier::new("__labels_array"), labels_array).into());
        for column in agg_value_columns {
            let values = array_agg_in_window(column.clone());
            projections.push(Binding::new(agg_array_column_name(column), values).into());
        }
    }

    let has_label: SQLExpression = BinaryOperatorApply::new(
        Operator::IsNot,
        ColumnReference::new(label_column.into()).into(),
        NullLiteral::default().into(),
    )
    .into();

    Ok(SQLQuery::default()
        .from(TableReference::new(ordered_cte_name.clone().into()).into())
        .select(projections)
        .where_(has_label))
}
/// A translated group-clause assignment of the MATCH command.
struct ByClauseBinding {
    /// Identifier the clause binds to in the output.
    identifier: Identifier,
    /// Internal column name (`__by_` followed by the identifier's segments joined with `__`)
    /// that carries the value through the CTE stages.
    prefixed_name: SimpleIdentifier,
    /// Type of the translated clause expression.
    typ: Type,
}
/// Reports whether `func_name` (matched case-insensitively via lowercasing) has at least
/// one registered definition whose special position is `SpecialPosition::Match`.
///
/// Names missing from the registry yield `false`.
fn is_valid_cte_aggregate(func_name: &str, registry: &FunctionRegistry) -> bool {
    registry
        .function_defs
        .get(&func_name.to_lowercase())
        .map_or(false, |defs| {
            defs.iter()
                .any(|def| matches!(def.special_position(), Some(SpecialPosition::Match)))
        })
}
/// Checks that every aggregate binding can be expressed by the CTE-based translation.
///
/// Each binding must be a function call whose name passes [`is_valid_cte_aggregate`], and
/// its first argument, when present, must be a simple column reference. The first
/// violation found is returned as an error attached to `ctx`.
fn validate_supported_aggregates(
    agg_bindings: &[AggBinding],
    ctx: &MatchCommandContext<'static>,
    registry: &FunctionRegistry,
) -> Result<(), TranslationErrors> {
    for binding in agg_bindings {
        let call = match &binding.translation.sql {
            SQLExpression::FunctionCallApply(call) => call,
            other => {
                return Err(TranslationError::msg(
                    ctx,
                    &format!(
                        "MATCH AGG requires aggregate function calls. '{}' is not an aggregate function.",
                        other
                    ),
                )
                .single());
            }
        };

        if !is_valid_cte_aggregate(&call.function_name, registry) {
            return Err(TranslationError::msg(
                ctx,
                &format!(
                    "CTE-based MATCH does not support {}() in AGG. \
                     Supported aggregates are those with array function equivalents \
                     (count, sum, avg, min, max, first, last).",
                    call.function_name
                ),
            )
            .single());
        }

        // Only the first argument is inspected; a call without arguments passes.
        if let Some(arg) = call.arguments.first() {
            if extract_simple_column_name(arg).is_none() {
                return Err(TranslationError::msg(
                    ctx,
                    &format!(
                        "CTE-based MATCH AGG only supports simple column references as arguments. \
                         Expression '{}' in {}() is not a simple column.",
                        arg, call.function_name
                    ),
                )
                .single());
            }
        }
    }
    Ok(())
}
/// Collects the columns referenced by aggregate calls that take arguments.
///
/// Each column is reduced to the last segment of its identifier and reported once, in
/// order of first appearance. Bindings that are not function calls, or whose call has no
/// arguments, contribute nothing.
fn extract_agg_value_columns(agg_bindings: &[AggBinding]) -> Vec<SimpleIdentifier> {
    let mut already_listed = std::collections::HashSet::new();
    let mut ordered = Vec::new();

    for binding in agg_bindings {
        let takes_arguments = matches!(
            &binding.translation.sql,
            SQLExpression::FunctionCallApply(call) if !call.arguments.is_empty()
        );
        if !takes_arguments {
            continue;
        }
        for reference in binding.translation.sql.get_column_references() {
            let column = reference.identifier.last().clone();
            if already_listed.insert(column.name.clone()) {
                ordered.push(column);
            }
        }
    }

    ordered
}
/// Returns the name of the helper column holding the windowed value array for `col`:
/// `__agg_values_` followed by the column name.
fn agg_array_column_name(col: &SimpleIdentifier) -> SimpleIdentifier {
    let array_name = format!("__agg_values_{}", col.name);
    SimpleIdentifier::new(&array_name)
}
/// Result of rewriting an aggregate for the CTE-based MATCH translation.
struct TransformedAgg {
    /// SQL expression to project in place of the original aggregate call.
    sql: SQLExpression,
}
/// Returns the output type of an aggregate binding.
///
/// Calls named `SUM` or `AVG` (compared case-insensitively via uppercasing) are typed
/// `Type::Double`; everything else keeps the type recorded in the binding's translation.
fn get_agg_output_type(binding: &AggBinding) -> Type {
    match &binding.translation.sql {
        SQLExpression::FunctionCallApply(call)
            if matches!(call.function_name.to_uppercase().as_str(), "SUM" | "AVG") =>
        {
            Type::Double
        }
        _ => binding.translation.typ.clone(),
    }
}
fn transform_agg_for_match_length(
agg_expr: &SQLExpression,
match_length_expr: &SQLExpression,
) -> TransformedAgg {
if let SQLExpression::FunctionCallApply(fc) = agg_expr {
let func_name = fc.function_name.to_uppercase();
match func_name.as_str() {
"COUNT" => {
if fc.arguments.is_empty() {
return TransformedAgg {
sql: FunctionCallApply::with_one(
"CARDINALITY",
FunctionCallApply::with_three(
"SLICE",
ColumnReference::new(
SimpleIdentifier::new("__labels_array").into(),
)
.into(),
IntegerLiteral::from_int(1).into(),
match_length_expr.clone(),
)
.into(),
)
.into(),
};
} else if let Some(arg) = fc.arguments.first() {
if let Some(col_name) = extract_simple_column_name(arg) {
let array_col = agg_array_column_name(&col_name);
let sliced_array: SQLExpression = FunctionCallApply::with_three(
"SLICE",
ColumnReference::new(array_col.into()).into(),
IntegerLiteral::from_int(1).into(),
match_length_expr.clone(),
)
.into();
let e = SimpleIdentifier::new("e");
let filtered_array: SQLExpression = FunctionCallApply::with_two(
"filter",
sliced_array,
Lambda::new(
vec![e.clone()],
BinaryOperatorApply::new(
Operator::IsNot,
ColumnReference::new(e.into()).into(),
NullLiteral::default().into(),
)
.into(),
)
.into(),
)
.into();
return TransformedAgg {
sql: FunctionCallApply::with_one("CARDINALITY", filtered_array).into(),
};
}
}
}
"MAX" | "MIN" => {
if let Some(arg) = fc.arguments.first() {
if let Some(col_name) = extract_simple_column_name(arg) {
let array_col = agg_array_column_name(&col_name);
let sliced_array = FunctionCallApply::with_three(
"SLICE",
ColumnReference::new(array_col.into()).into(),
IntegerLiteral::from_int(1).into(),
match_length_expr.clone(),
);
let s = SimpleIdentifier::new("s");
let n = SimpleIdentifier::new("n");
let cmp_op = if func_name == "MAX" {
Operator::Gt
} else {
Operator::Lt
};
return TransformedAgg {
sql: FunctionCallApply::with_four(
"reduce",
sliced_array.into(),
NullLiteral::default().into(),
Lambda::new(
vec![s.clone(), n.clone()],
Case::new(
&[(
BinaryOperatorApply::new(
Operator::Is,
ColumnReference::new(s.clone().into()).into(),
NullLiteral::default().into(),
)
.into(),
ColumnReference::new(n.clone().into()).into(),
)],
Some(
Case::new(
&[(
BinaryOperatorApply::new(
cmp_op,
ColumnReference::new(n.clone().into())
.into(),
ColumnReference::new(s.clone().into())
.into(),
)
.into(),
ColumnReference::new(n.clone().into()).into(),
)],
Some(ColumnReference::new(s.clone().into()).into()),
)
.into(),
),
)
.into(),
)
.into(),
Lambda::new(
vec![s.clone()],
ColumnReference::new(s.clone().into()).into(),
)
.into(),
)
.into(),
};
}
}
}
"FIRST" => {
if let Some(arg) = fc.arguments.first() {
if let Some(col_name) = extract_simple_column_name(arg) {
let array_col = agg_array_column_name(&col_name);
return TransformedAgg {
sql: FunctionCallApply::with_two(
"element_at",
ColumnReference::new(array_col.into()).into(),
IntegerLiteral::from_int(1).into(),
)
.into(),
};
}
}
}
"LAST" => {
if let Some(arg) = fc.arguments.first() {
if let Some(col_name) = extract_simple_column_name(arg) {
let array_col = agg_array_column_name(&col_name);
return TransformedAgg {
sql: FunctionCallApply::with_two(
"element_at",
ColumnReference::new(array_col.into()).into(),
match_length_expr.clone(),
)
.into(),
};
}
}
}
"SUM" => {
if let Some(arg) = fc.arguments.first() {
if let Some(col_name) = extract_simple_column_name(arg) {
let array_col = agg_array_column_name(&col_name);
let sliced_array = FunctionCallApply::with_three(
"SLICE",
ColumnReference::new(array_col.into()).into(),
IntegerLiteral::from_int(1).into(),
match_length_expr.clone(),
);
let e = SimpleIdentifier::new("e");
let filtered_array: SQLExpression = FunctionCallApply::with_two(
"filter",
sliced_array.into(),
Lambda::new(
vec![e.clone()],
BinaryOperatorApply::new(
Operator::IsNot,
ColumnReference::new(e.into()).into(),
NullLiteral::default().into(),
)
.into(),
)
.into(),
)
.into();
let s = SimpleIdentifier::new("s");
let n = SimpleIdentifier::new("n");
let zero_double: SQLExpression = IntegerLiteral::from_int(0)
.to_sql_expression()
.cast(SQLType::SQLBaseType(SQLBaseType::Double));
let sum_expr: SQLExpression = FunctionCallApply::with_four(
"reduce",
filtered_array.clone(),
zero_double,
Lambda::new(
vec![s.clone(), n.clone()],
BinaryOperatorApply::new(
Operator::Plus,
ColumnReference::new(s.clone().into()).into(),
ColumnReference::new(n.clone().into()).into(),
)
.into(),
)
.into(),
Lambda::new(
vec![s.clone()],
ColumnReference::new(s.clone().into()).into(),
)
.into(),
)
.into();
let count_expr: SQLExpression =
FunctionCallApply::with_one("CARDINALITY", filtered_array).into();
return TransformedAgg {
sql: FunctionCallApply::with_three(
"IF",
BinaryOperatorApply::new(
Operator::Eq,
count_expr,
IntegerLiteral::from_int(0).into(),
)
.into(),
NullLiteral::default().into(),
sum_expr,
)
.into(),
};
}
}
}
"AVG" => {
if let Some(arg) = fc.arguments.first() {
if let Some(col_name) = extract_simple_column_name(arg) {
let array_col = agg_array_column_name(&col_name);
let sliced_array: SQLExpression = FunctionCallApply::with_three(
"SLICE",
ColumnReference::new(array_col.into()).into(),
IntegerLiteral::from_int(1).into(),
match_length_expr.clone(),
)
.into();
let e = SimpleIdentifier::new("e");
let filtered_array: SQLExpression = FunctionCallApply::with_two(
"filter",
sliced_array,
Lambda::new(
vec![e.clone()],
BinaryOperatorApply::new(
Operator::IsNot,
ColumnReference::new(e.into()).into(),
NullLiteral::default().into(),
)
.into(),
)
.into(),
)
.into();
let non_null_count: SQLExpression =
FunctionCallApply::with_one("CARDINALITY", filtered_array.clone())
.into();
let s = SimpleIdentifier::new("s");
let n = SimpleIdentifier::new("n");
let zero_double: SQLExpression = IntegerLiteral::from_int(0)
.to_sql_expression()
.cast(SQLType::SQLBaseType(SQLBaseType::Double));
let sum_expr: SQLExpression = FunctionCallApply::with_four(
"reduce",
filtered_array,
zero_double,
Lambda::new(
vec![s.clone(), n.clone()],
BinaryOperatorApply::new(
Operator::Plus,
ColumnReference::new(s.clone().into()).into(),
ColumnReference::new(n.clone().into()).into(),
)
.into(),
)
.into(),
Lambda::new(
vec![s.clone()],
ColumnReference::new(s.clone().into()).into(),
)
.into(),
)
.into();
return TransformedAgg {
sql: BinaryOperatorApply::new(
Operator::Slash,
sum_expr,
FunctionCallApply::with_two(
"NULLIF",
non_null_count,
IntegerLiteral::from_int(0).into(),
)
.into(),
)
.into(),
};
}
}
}
_ => {}
}
}
TransformedAgg {
sql: agg_expr.clone(),
}
}
/// Returns the identifier when `expr` is a column reference with a simple (single-part)
/// identifier; compound identifiers and every other expression kind yield `None`.
fn extract_simple_column_name(expr: &SQLExpression) -> Option<SimpleIdentifier> {
    match expr {
        SQLExpression::Leaf(Leaf::ColumnReference(cr)) => match &cr.identifier {
            Identifier::Simple(simple) => Some(simple.clone()),
            Identifier::Compound(_) => None,
        },
        _ => None,
    }
}
/// Builds the last-stage query over the state CTE.
///
/// Projections, in order: base columns not shadowed by the first segment of a BY or
/// aggregate identifier (a failed type lookup falls back to `Type::Unknown`), the BY
/// bindings read from their `__by_*` columns, and the aggregate bindings rewritten by
/// [`transform_agg_for_match_length`] using the number of labels in the regex match as the
/// match length.
///
/// Rows are kept when `__pattern_label` is one of `first_pattern_vars` and `__state`
/// matches `regex_pattern`; the result is ordered by `order_by`.
fn build_final_query(
    state_cte_name: &SimpleIdentifier,
    regex_pattern: &str,
    first_pattern_vars: &[String],
    base_columns: &[SimpleIdentifier],
    by_clause_bindings: &[ByClauseBinding],
    agg_bindings: &[AggBinding],
    order_by: &[OrderByExpression],
    from_env: &Environment,
) -> Result<SQLQuery, TranslationErrors> {
    let label_column = || -> SQLExpression {
        ColumnReference::new(SimpleIdentifier::new("__pattern_label").into()).into()
    };
    let state_column = || -> SQLExpression {
        ColumnReference::new(SimpleIdentifier::new("__state").into()).into()
    };

    // First segments of every BY / aggregate identifier; base columns with these names
    // are replaced by the binding of the same name.
    let shadowing: Vec<&SimpleIdentifier> = by_clause_bindings
        .iter()
        .map(|b| b.identifier.first())
        .chain(agg_bindings.iter().map(|b| b.identifier.first()))
        .collect();

    let mut output = ProjectionBuilder::default();

    for col in base_columns.iter().filter(|col| !shadowing.contains(col)) {
        let col_type = from_env
            .lookup(&col.clone().into())
            .unwrap_or(Type::Unknown);
        output.bind(
            col.clone().into(),
            ColumnReference::new(col.clone().into()).into(),
            col_type,
        );
    }

    for by in by_clause_bindings {
        output.bind(
            by.identifier.clone(),
            ColumnReference::new(by.prefixed_name.clone().into()).into(),
            by.typ.clone(),
        );
    }

    if !agg_bindings.is_empty() {
        // CARDINALITY(SPLIT(REGEXP_EXTRACT(__state, <regex>), ','))
        let matched_prefix: SQLExpression = FunctionCallApply::with_two(
            "REGEXP_EXTRACT",
            state_column(),
            StringLiteral::new(regex_pattern).into(),
        )
        .into();
        let matched_labels: SQLExpression =
            FunctionCallApply::with_two("SPLIT", matched_prefix, StringLiteral::new(",").into())
                .into();
        let match_length_expr: SQLExpression =
            FunctionCallApply::with_one("CARDINALITY", matched_labels).into();

        for agg in agg_bindings {
            let rewritten = transform_agg_for_match_length(&agg.translation.sql, &match_length_expr);
            let typ = get_agg_output_type(agg);
            output.bind(agg.identifier.clone(), rewritten.sql, typ);
        }
    }

    let projections = output
        .build_projections()
        .map_err(|e| TranslationError::fatal("build_final_query", e.into()).single())?;

    // A single candidate uses `=`, several use `IN (...)`.
    let pattern_filter: SQLExpression = match first_pattern_vars {
        [only] => BinaryOperatorApply::new(
            Operator::Eq,
            label_column(),
            StringLiteral::new(only).into(),
        )
        .into(),
        several => BinaryOperatorApply::new(
            Operator::In,
            label_column(),
            FunctionCallApply::with_positional(
                "",
                several
                    .iter()
                    .map(|v| StringLiteral::new(v).into())
                    .collect(),
            )
            .into(),
        )
        .into(),
    };
    let regex_filter: SQLExpression = FunctionCallApply::with_two(
        "REGEXP_LIKE",
        state_column(),
        StringLiteral::new(regex_pattern).into(),
    )
    .into();
    let filter: SQLExpression =
        BinaryOperatorApply::new(Operator::And, pattern_filter, regex_filter).into();

    Ok(SQLQuery::default()
        .from(TableReference::new(state_cte_name.clone().into()).into())
        .select(projections)
        .where_(filter)
        .order_by(order_by.to_vec()))
}
/// Collects the from clause behind every pattern variable, keyed by variable name and
/// kept in order of first appearance. Nested groups are visited recursively.
///
/// The variable name is the from-clause alias when one is given; the table identifier is
/// consulted only when there is no alias. A variable may appear several times as long as
/// every occurrence names the same table; the first from clause seen is the one kept.
///
/// # Errors
/// * a parse-error node anywhere in the pattern,
/// * the same variable bound to two different tables,
/// * any error surfaced while reading a from clause.
fn extract_from_clauses_from_pattern(
    ctx: &[Rc<PatternContextAll<'static>>],
    cte_env: &Arc<Environment>,
    statement_context: &Rc<QueryTranslationContext>,
) -> Result<OrderMap<SimpleIdentifier, Rc<FromClauseContextAll<'static>>>, TranslationErrors> {
    let mut result: OrderMap<SimpleIdentifier, Rc<FromClauseContextAll<'static>>> = OrderMap::new();
    for pctx in ctx {
        match pctx.as_ref() {
            PatternContextAll::QuantifiedContext(qctx) => {
                let fctx = TranslationErrors::expect(qctx, qctx.fromClause())?;
                let fc = HamelinFromClause::new(
                    fctx.clone(),
                    cte_env.clone(),
                    statement_context.clone(),
                );
                // Resolve the fallback lazily: with an alias present the table identifier
                // is not needed here, so a failure to compute it must not abort translation.
                let si = match fc.table_alias().map(|si| si.to_sql()).transpose()? {
                    Some(alias) => alias,
                    None => fc.table_identifier()?.last().clone(),
                };
                if let Some(existing_fctx) = result.get(&si) {
                    let existing_fc = HamelinFromClause::new(
                        existing_fctx.clone(),
                        cte_env.clone(),
                        statement_context.clone(),
                    );
                    let existing_table = existing_fc.table_identifier()?;
                    let new_table = fc.table_identifier()?;
                    if existing_table != new_table {
                        return Err(TranslationError::msg(
                            qctx,
                            &format!(
                                "Duplicate pattern variable '{}' bound to different data sources",
                                si.name
                            ),
                        )
                        .single());
                    }
                } else {
                    result.insert(si, fctx);
                }
            }
            PatternContextAll::NestedContext(nctx) => {
                let nested = extract_from_clauses_from_pattern(
                    &nctx.pattern_all(),
                    cte_env,
                    statement_context,
                )?;
                // Merge the nested group's variables, applying the same duplicate rule.
                for (si, fctx) in nested {
                    if let Some(existing_fctx) = result.get(&si) {
                        let existing_fc = HamelinFromClause::new(
                            existing_fctx.clone(),
                            cte_env.clone(),
                            statement_context.clone(),
                        );
                        let new_fc = HamelinFromClause::new(
                            fctx.clone(),
                            cte_env.clone(),
                            statement_context.clone(),
                        );
                        let existing_table = existing_fc.table_identifier()?;
                        let new_table = new_fc.table_identifier()?;
                        if existing_table != new_table {
                            return Err(TranslationError::msg(
                                nctx,
                                &format!(
                                    "Duplicate pattern variable '{}' bound to different data sources",
                                    si.name
                                ),
                            )
                            .single());
                        }
                    } else {
                        result.insert(si, fctx);
                    }
                }
            }
            PatternContextAll::Error(e) => {
                return Err(TranslationError::msg(e, "parse error").single())
            }
        }
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds pattern elements from `(variable, quantifier)` pairs, renders them with
    /// `pattern_to_regex`, and asserts the result equals `expected`.
    ///
    /// `#[track_caller]` makes a failing assertion point at the calling test.
    #[track_caller]
    fn assert_regex(spec: Vec<(&str, PatternQuantifier)>, expected: &str) {
        let elements: Vec<_> = spec
            .into_iter()
            .map(|(variable, quantifier)| PatternElement::new(variable.to_string(), quantifier))
            .collect();
        assert_eq!(pattern_to_regex(&elements), expected);
    }

    // Variables are supplied in upper case; the expected regexes below use their
    // lower-case form.

    #[test]
    fn test_pattern_to_regex_simple_sequence() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::One),
                ("B", PatternQuantifier::One),
            ],
            "^a,b",
        );
    }

    #[test]
    fn test_pattern_to_regex_one_or_more_middle() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::One),
                ("B", PatternQuantifier::OneOrMore),
                ("A", PatternQuantifier::One),
            ],
            "^a,(b,)+a",
        );
    }

    #[test]
    fn test_pattern_to_regex_one_or_more_end() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::One),
                ("B", PatternQuantifier::OneOrMore),
            ],
            "^a,b(,b)*",
        );
    }

    #[test]
    fn test_pattern_to_regex_optional_at_start() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::ZeroOrOne),
                ("B", PatternQuantifier::One),
            ],
            "^(a,)?b",
        );
    }

    #[test]
    fn test_pattern_to_regex_optional_at_end() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::One),
                ("B", PatternQuantifier::ZeroOrOne),
            ],
            "^a(,b)?",
        );
    }

    #[test]
    fn test_pattern_to_regex_zero_or_more_at_end() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::One),
                ("B", PatternQuantifier::ZeroOrMore),
            ],
            "^a(,b(,b)*)?",
        );
    }

    #[test]
    fn test_pattern_to_regex_multiple_optional_at_end() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::One),
                ("B", PatternQuantifier::ZeroOrOne),
                ("C", PatternQuantifier::ZeroOrOne),
            ],
            "^a(,b)?(,c)?",
        );
    }

    #[test]
    fn test_pattern_to_regex_one_or_more_followed_by_optional() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::OneOrMore),
                ("B", PatternQuantifier::ZeroOrOne),
            ],
            "^a(,a)*(,b)?",
        );
    }

    #[test]
    fn test_pattern_to_regex_complex() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::One),
                ("B", PatternQuantifier::OneOrMore),
                ("A", PatternQuantifier::One),
                ("B", PatternQuantifier::One),
            ],
            "^a,(b,)+a,b",
        );
    }

    #[test]
    fn test_pattern_to_regex_required_optional_required() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::One),
                ("B", PatternQuantifier::ZeroOrOne),
                ("C", PatternQuantifier::One),
            ],
            "^a,(b,)?c",
        );
    }

    #[test]
    fn test_pattern_to_regex_all_optional_from_start() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::ZeroOrOne),
                ("B", PatternQuantifier::ZeroOrOne),
            ],
            "^(a)?(,?b)?",
        );
    }

    #[test]
    fn test_pattern_to_regex_zero_or_more_at_start() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::ZeroOrMore),
                ("B", PatternQuantifier::One),
            ],
            "^(a,)*b",
        );
    }

    #[test]
    fn test_pattern_to_regex_all_zero_or_more() {
        assert_regex(
            vec![
                ("A", PatternQuantifier::ZeroOrMore),
                ("B", PatternQuantifier::ZeroOrMore),
            ],
            "^(a(,a)*)?(,?b(,b)*)?",
        );
    }
}