use std::collections::HashMap;
use std::sync::Arc;
use hamelin_lib::{
err::TranslationError,
func::def::{ParameterBinding, SpecialPosition},
tree::{
ast::{
clause::SortOrder,
expression::{Expression, ExpressionKind},
identifier::Identifier,
pattern::QuantifierKind,
query::Query,
},
builder::{pipeline as pipeline_builder, query, window_command, ExpressionBuilder},
typed_ast::{
command::{TypedCommandKind, TypedMatchCommand},
context::StatementTranslationContext,
expression::{MapExpressionAlgebra, TypedApply, TypedExpression, TypedExpressionKind},
pattern::TypedPattern,
pipeline::TypedPipeline,
query::TypedStatement,
},
},
};
/// Entry point of the pass: lowers every `MATCH` command in `statement`
/// into plain relational commands (window aggregation + regex filtering
/// over a label-sequence state).
///
/// Returns the input statement untouched when it contains no `MATCH`;
/// otherwise rebuilds the query AST and re-type-checks it through
/// `TypedStatement::from_ast_with_context`.
pub fn lower_match(
    statement: Arc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedStatement>, Arc<TranslationError>> {
    // Fast path: nothing to do if no pipeline contains a MATCH command.
    if !statement_has_match(&statement)? {
        return Ok(statement);
    }
    let new_query = transform_statement(&statement, ctx)?;
    // Re-run type checking on the rewritten AST so downstream stages see a
    // fully typed statement.
    Ok(Arc::new(TypedStatement::from_ast_with_context(
        Arc::new(new_query),
        ctx,
    )))
}
/// Returns `true` when any pipeline of the statement (WITH clauses or the
/// main pipeline) contains a `MATCH` command.
///
/// Deliberately checks *every* pipeline instead of short-circuiting on the
/// first hit, so a validity error in a later pipeline still surfaces.
fn statement_has_match(statement: &TypedStatement) -> Result<bool, Arc<TranslationError>> {
    let mut found = false;
    for pipeline in statement.iter() {
        if pipeline_has_match(pipeline)? {
            found = true;
        }
    }
    Ok(found)
}
/// Returns `true` when the (validated) pipeline contains at least one
/// `MATCH` command; propagates the pipeline's validity error otherwise.
fn pipeline_has_match(pipeline: &TypedPipeline) -> Result<bool, Arc<TranslationError>> {
    for command in &pipeline.valid_ref()?.commands {
        if matches!(&command.kind, TypedCommandKind::Match(_)) {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Rebuilds the whole query: each WITH-clause pipeline and the main
/// pipeline are transformed independently, so a MATCH in any of them gets
/// lowered.
fn transform_statement(
    statement: &TypedStatement,
    ctx: &mut StatementTranslationContext,
) -> Result<Query, Arc<TranslationError>> {
    let mut query_builder = query();
    for with_clause in &statement.with_clauses {
        let transformed = transform_pipeline(&with_clause.pipeline, ctx)?;
        let valid_name = with_clause.name.clone().valid()?;
        // Re-attach the transformed pipeline under its original CTE name.
        query_builder = query_builder.merge_as_cte(transformed, valid_name);
    }
    let main_query = transform_pipeline(&statement.pipeline, ctx)?;
    Ok(query_builder.merge_as_main(main_query))
}
/// Lowers one pipeline. Only a `MATCH` appearing as the *first* command is
/// rewritten; the lowered commands are spliced in front of the remaining
/// original commands.
///
/// NOTE(review): a MATCH that is not the first command is left untouched
/// here — confirm that earlier validation rejects that placement.
fn transform_pipeline(
    pipeline: &TypedPipeline,
    ctx: &mut StatementTranslationContext,
) -> Result<Query, Arc<TranslationError>> {
    let commands = &pipeline.valid_ref()?.commands;
    if let Some(first_cmd) = commands.first() {
        if let TypedCommandKind::Match(match_cmd) = &first_cmd.kind {
            let lowered_pipeline = lower_match_command(match_cmd, ctx)?;
            // Preserve the original source span on the rebuilt pipeline.
            let mut pipe_builder = pipeline_builder().at(pipeline.ast.span.clone());
            for cmd in lowered_pipeline.commands {
                pipe_builder = pipe_builder.command(cmd);
            }
            // Re-attach every command after the MATCH verbatim.
            for cmd in commands.iter().skip(1) {
                pipe_builder = pipe_builder.command(cmd.ast.clone());
            }
            return Ok(query().main(pipe_builder.build()).build());
        }
    }
    // No leading MATCH: reuse the original AST unchanged.
    Ok(query().main(pipeline.ast.clone()).build())
}
/// Lowers one `MATCH` command into an equivalent pipeline of ordinary
/// commands: extract the quantified pattern variables, build the anchored
/// regex over their labels, determine which labels may start a match, and
/// assemble the replacement pipeline.
fn lower_match_command(
    match_cmd: &TypedMatchCommand,
    _ctx: &mut StatementTranslationContext,
) -> Result<hamelin_lib::tree::ast::pipeline::Pipeline, Arc<TranslationError>> {
    let pattern_vars = extract_pattern_variables(&match_cmd.patterns)?;
    let regex_pattern = pattern_to_regex(&match_cmd.patterns, &pattern_vars)?;
    let starting_labels = find_starting_pattern_labels(&match_cmd.patterns, &pattern_vars)?;
    // Fixed: the third argument read `®ex_pattern` (mojibake corruption of
    // `&reg…`) — restored to `&regex_pattern`.
    let pipeline =
        build_lowered_pipeline(match_cmd, &pattern_vars, &regex_pattern, &starting_labels)?;
    Ok(pipeline)
}
/// One quantified pattern variable (e.g. `a=events+`) extracted from a
/// MATCH clause.
#[derive(Debug, Clone)]
struct PatternVariable {
    // Alias the pattern binds (`a` in `a=events+`); falls back to the table
    // name when no alias was written.
    alias: String,
    // Source table the alias refers to.
    table: String,
    // Short synthetic label used in the generated label-sequence regex.
    label: String,
    // How many rows the variable may match.
    quantifier: PatternQuantifier,
}
/// Internal repetition count of a pattern variable.
///
/// `Eq` is derived alongside `PartialEq` (clippy
/// `derive_partial_eq_without_eq`) — all variants have total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PatternQuantifier {
    /// Exactly one row (also the normalized form of `{1}`).
    One,
    /// `+` — one or more rows.
    OneOrMore,
    /// `*` — zero or more rows.
    ZeroOrMore,
    /// `?` — zero or one row.
    ZeroOrOne,
    /// `{n}` for `n != 1` — a fixed repetition count.
    Exactly(u32),
}

impl PatternQuantifier {
    /// Whether this quantifier can match zero rows (i.e. the element may be
    /// entirely absent from a match).
    fn is_optional(&self) -> bool {
        matches!(
            self,
            PatternQuantifier::ZeroOrMore | PatternQuantifier::ZeroOrOne
        )
    }
}
/// Walks all MATCH patterns depth-first and produces one `PatternVariable`
/// (with a freshly generated label) per quantified leaf.
fn extract_pattern_variables(
    patterns: &[TypedPattern],
) -> Result<Vec<PatternVariable>, Arc<TranslationError>> {
    let mut label_gen = LabelGenerator::new();
    let mut variables = Vec::new();
    patterns
        .iter()
        .try_for_each(|pattern| extract_from_pattern(pattern, &mut variables, &mut label_gen))?;
    Ok(variables)
}
/// Recursive step of pattern-variable extraction: a quantified leaf appends
/// one `PatternVariable`, a nested group is flattened by recursion, and an
/// error pattern aborts the walk.
fn extract_from_pattern(
    pattern: &TypedPattern,
    variables: &mut Vec<PatternVariable>,
    label_gen: &mut LabelGenerator,
) -> Result<(), Arc<TranslationError>> {
    match pattern {
        TypedPattern::Quantified(quant) => {
            let (alias, table) = extract_alias_and_table(quant)?;
            let quantifier = convert_quantifier(&quant.quantifier);
            variables.push(PatternVariable {
                alias,
                table,
                // Labels are assigned in traversal order (a, b, c, …).
                label: label_gen.next(),
                quantifier,
            });
        }
        TypedPattern::Nested(nested) => {
            for sub_pattern in &nested.patterns {
                extract_from_pattern(sub_pattern, variables, label_gen)?;
            }
        }
        TypedPattern::Error(err) => {
            return Err(err.clone());
        }
    }
    Ok(())
}
/// Resolves the `(alias, table)` pair for one quantified pattern. A bare
/// table reference without an alias uses the table name as its own alias.
fn extract_alias_and_table(
    quant: &hamelin_lib::tree::typed_ast::pattern::TypedQuantifiedPattern,
) -> Result<(String, String), Arc<TranslationError>> {
    use hamelin_lib::tree::typed_ast::clause::TypedFromClause;
    match &quant.typed_from {
        TypedFromClause::Alias(alias_clause) => {
            let alias = alias_clause.alias.valid_ref()?.to_string();
            let table = alias_clause.ast.table.identifier.valid_ref()?.to_string();
            Ok((alias, table))
        }
        TypedFromClause::Reference(table_ref) => {
            // No explicit alias: the table name doubles as the alias.
            let table = table_ref.ast.identifier.valid_ref()?.to_string();
            Ok((table.clone(), table))
        }
        TypedFromClause::Error(err) => Err(err.clone()),
    }
}
/// Maps an AST quantifier onto the internal representation. `{1}` is
/// normalized to `One`.
///
/// NOTE(review): an unparsable `Exactly` count (and the `Error` variant)
/// silently falls back to `One` — confirm the parser guarantees a numeric
/// token here, otherwise this can mask upstream problems.
fn convert_quantifier(
    quantifier: &Arc<hamelin_lib::tree::ast::pattern::Quantifier>,
) -> PatternQuantifier {
    match &quantifier.kind {
        QuantifierKind::AtLeastOne => PatternQuantifier::OneOrMore,
        QuantifierKind::AnyNumber => PatternQuantifier::ZeroOrMore,
        QuantifierKind::ZeroOrOne => PatternQuantifier::ZeroOrOne,
        QuantifierKind::Exactly(n) => {
            if let Ok(num) = n.parse::<u32>() {
                if num == 1 {
                    PatternQuantifier::One
                } else {
                    PatternQuantifier::Exactly(num)
                }
            } else {
                PatternQuantifier::One
            }
        }
        QuantifierKind::Error(_) => PatternQuantifier::One,
    }
}
/// Produces short, unique labels for pattern variables in a fixed order:
/// `a`..`z`, then `A`..`Z`, then two-letter `aa`, `ab`, …
struct LabelGenerator {
    // Next label index; advances by one per `next()` call.
    index: usize,
}

impl LabelGenerator {
    fn new() -> Self {
        LabelGenerator { index: 0 }
    }

    /// Returns the next label and advances the generator.
    fn next(&mut self) -> String {
        let i = self.index;
        self.index += 1;
        match i {
            // 0..26 → lowercase single letters.
            0..=25 => char::from(b'a' + i as u8).to_string(),
            // 26..52 → uppercase single letters.
            26..=51 => char::from(b'A' + (i - 26) as u8).to_string(),
            // 52.. → two lowercase letters, base-26 style.
            _ => {
                let rest = i - 52;
                format!(
                    "{}{}",
                    char::from(b'a' + (rest / 26) as u8),
                    char::from(b'a' + (rest % 26) as u8)
                )
            }
        }
    }
}
/// Builds the anchored regex that a window's comma-joined label state must
/// match for the MATCH pattern to succeed.
fn pattern_to_regex(
    patterns: &[TypedPattern],
    variables: &[PatternVariable],
) -> Result<String, Arc<TranslationError>> {
    let elements = collect_pattern_elements(patterns, variables)?;
    Ok(build_regex_from_elements(&elements))
}
/// A `(label, quantifier)` pair in pattern order — the unit the regex
/// builder consumes.
struct PatternElement {
    label: String,
    quantifier: PatternQuantifier,
}
/// Flattens the patterns into ordered `PatternElement`s, pairing each
/// quantified leaf with the variable previously extracted for it (same
/// depth-first order as `extract_pattern_variables`).
fn collect_pattern_elements(
    patterns: &[TypedPattern],
    variables: &[PatternVariable],
) -> Result<Vec<PatternElement>, Arc<TranslationError>> {
    let mut elements = Vec::new();
    // Cursor into `variables`, advanced per quantified leaf.
    let mut var_index = 0;
    for pattern in patterns {
        collect_from_pattern(pattern, variables, &mut var_index, &mut elements)?;
    }
    Ok(elements)
}
/// One step of the flattening walk. `var_index` tracks the position in
/// `variables`, which was produced by the same traversal order.
fn collect_from_pattern(
    pattern: &TypedPattern,
    variables: &[PatternVariable],
    var_index: &mut usize,
    elements: &mut Vec<PatternElement>,
) -> Result<(), Arc<TranslationError>> {
    match pattern {
        TypedPattern::Quantified(_) => {
            if let Some(var) = variables.get(*var_index) {
                elements.push(PatternElement {
                    label: var.label.clone(),
                    quantifier: var.quantifier,
                });
                *var_index += 1;
            }
        }
        TypedPattern::Nested(_) => {
            // Quantified groups would need group-level repetition in the
            // label regex, which the current encoding cannot express.
            return Err(TranslationError::fatal(
                "lower_match",
                "nested pattern groups with quantifiers (e.g., (a b)+) are not yet supported \
                — use flat patterns instead (e.g., a+ b+)"
                    .into(),
            )
            .into());
        }
        // Error patterns are skipped here; `extract_from_pattern` already
        // reports them during variable extraction.
        TypedPattern::Error(_) => {}
    }
    Ok(())
}
/// Renders the label elements into an anchored regex over the comma-joined
/// label state (e.g. `"a,a,b"`).
///
/// The subtle part is comma ownership: a non-final element owns its
/// *trailing* comma (`(a,)+`), while the last element — or one followed
/// only by optional elements — emits no trailing comma, and optional
/// elements in that tail position absorb their *leading* comma instead
/// (`(,b)?`). This keeps the regex valid whether or not the optional tail
/// is present in the state string.
///
/// NOTE(review): a tail-position optional element that is also the *first*
/// element still emits a leading comma (a lone `a*` yields `^(,a(,a)*)?`) —
/// confirm a MATCH consisting solely of optional patterns is rejected
/// upstream.
fn build_regex_from_elements(elements: &[PatternElement]) -> String {
    if elements.is_empty() {
        return String::new();
    }
    // True when every element at or after `from_idx` can match zero rows.
    fn all_remaining_optional(elements: &[PatternElement], from_idx: usize) -> bool {
        elements[from_idx..]
            .iter()
            .all(|e| e.quantifier.is_optional())
    }
    let mut parts = Vec::new();
    for (i, elem) in elements.iter().enumerate() {
        let is_last = i == elements.len() - 1;
        let next_all_optional = !is_last && all_remaining_optional(elements, i + 1);
        let label = &elem.label;
        let part = match elem.quantifier {
            PatternQuantifier::One => {
                if is_last || next_all_optional {
                    label.clone()
                } else {
                    format!("{},", label)
                }
            }
            PatternQuantifier::OneOrMore => {
                if is_last || next_all_optional {
                    format!("{}(,{})*", label, label)
                } else {
                    format!("({},)+", label)
                }
            }
            PatternQuantifier::ZeroOrMore => {
                if is_last || next_all_optional {
                    format!("(,{}(,{})*)?", label, label)
                } else {
                    format!("({},)*", label)
                }
            }
            PatternQuantifier::ZeroOrOne => {
                if is_last || next_all_optional {
                    format!("(,{})?", label)
                } else {
                    format!("({},)?", label)
                }
            }
            PatternQuantifier::Exactly(n) => {
                // Fixed repetition spelled out literally, with the same
                // comma-ownership rules as above.
                if n == 0 {
                    String::new()
                } else if is_last || next_all_optional {
                    (0..n)
                        .map(|i| {
                            if i == 0 {
                                label.clone()
                            } else {
                                format!(",{}", label)
                            }
                        })
                        .collect::<Vec<_>>()
                        .join("")
                } else {
                    (0..n)
                        .map(|_| format!("{},", label))
                        .collect::<Vec<_>>()
                        .join("")
                }
            }
        };
        if !part.is_empty() {
            parts.push(part);
        }
    }
    // Anchor at the start of the state string only; trailing state after
    // the match is allowed.
    format!("^{}", parts.join(""))
}
/// Labels that may begin a match: every leading optional element plus the
/// first element that must consume at least one row.
fn find_starting_pattern_labels(
    patterns: &[TypedPattern],
    variables: &[PatternVariable],
) -> Result<Vec<String>, Arc<TranslationError>> {
    let elements = collect_pattern_elements(patterns, variables)?;
    // Position of the first required (non-optional) element, if any.
    let first_required = elements.iter().position(|e| !e.quantifier.is_optional());
    // Include everything up to and including that element; when all
    // elements are optional, include them all.
    let end = first_required.map_or(elements.len(), |i| i + 1);
    Ok(elements[..end].iter().map(|e| e.label.clone()).collect())
}
/// A synthetic `array_agg` column added to the window command to carry the
/// values an AGG assignment aggregates over.
#[derive(Debug, Clone)]
struct SyntheticAggColumn {
    // Generated column name (`__agg_values_<suffix>`).
    name: String,
    // The original aggregate-argument expression.
    source_expr: Arc<Expression>,
}
/// Mutable state threaded through the AGG rewrite.
#[derive(Default)]
struct AggLoweringState {
    // Synthetic columns created so far, in creation order.
    columns: Vec<SyntheticAggColumn>,
    // Dedup map: printed form of an argument expression → index in `columns`.
    seen: HashMap<String, usize>,
    // Next numeric fallback suffix to try.
    counter: usize,
    // Errors collected during the rewrite; the rewrite algebra itself is
    // infallible, so failures accumulate here.
    errors: Vec<Arc<TranslationError>>,
}
impl AggLoweringState {
    /// Returns the synthetic column name for `arg_ast`, creating the column
    /// on first sight and reusing it for structurally identical expressions
    /// (deduplicated on the expression's printed form).
    fn column_for(&mut self, arg_ast: &Arc<Expression>) -> String {
        // `to_string()` over `format!("{}", …)` — same Display output,
        // idiomatic form (clippy `useless_format`).
        let key = arg_ast.to_string();
        if let Some(&idx) = self.seen.get(&key) {
            return self.columns[idx].name.clone();
        }
        // Prefer a readable suffix derived from the field name; otherwise
        // fall back to the first unused counter value.
        let candidate_suffix = infer_readable_suffix(arg_ast.as_ref());
        let suffix = match candidate_suffix {
            Some(s) if !self.suffix_taken(&s) => s,
            _ => loop {
                let candidate = self.counter.to_string();
                self.counter += 1;
                if !self.suffix_taken(&candidate) {
                    break candidate;
                }
            },
        };
        let name = format!("__agg_values_{}", suffix);
        let idx = self.columns.len();
        self.columns.push(SyntheticAggColumn {
            name: name.clone(),
            source_expr: arg_ast.clone(),
        });
        self.seen.insert(key, idx);
        name
    }

    /// Whether a synthetic column with this suffix already exists.
    fn suffix_taken(&self, suffix: &str) -> bool {
        let target = format!("__agg_values_{}", suffix);
        self.columns.iter().any(|c| c.name == target)
    }
}
/// Tries to derive a human-readable column suffix from an aggregate
/// argument. Only plain field references/lookups qualify; anything else
/// (literals, computed expressions) returns `None` and gets a counter
/// suffix instead.
fn infer_readable_suffix(arg_ast: &Expression) -> Option<String> {
    match &arg_ast.kind {
        ExpressionKind::FieldReference(_) | ExpressionKind::FieldLookup(_) => {}
        _ => return None,
    }
    let id = Identifier::infer_from_expression(arg_ast)?.valid().ok()?;
    Some(match id {
        Identifier::Simple(s) => s.as_str().to_string(),
        // Compound lookups (`a.b.c`) become `a__b__c`.
        Identifier::Compound(c) => c
            .parts()
            .iter()
            .map(|p| p.as_str().to_string())
            .collect::<Vec<_>>()
            .join("__"),
    })
}
/// Expression-rewriting algebra that replaces MATCH aggregate calls with
/// array-slice expressions, recording new synthetic columns and errors in
/// the shared `state`.
struct AggRewriteAlgebra<'a> {
    state: &'a mut AggLoweringState,
}
impl MapExpressionAlgebra for AggRewriteAlgebra<'_> {
    /// Visits one `apply` node bottom-up. Non-aggregate calls are rebuilt
    /// with their (already-visited) children; MATCH aggregates are replaced
    /// by their lowered form. Errors are pushed into `self.state.errors`
    /// because this algebra's signature is infallible.
    fn apply(
        &mut self,
        node: &TypedApply,
        expr: &TypedExpression,
        children: ParameterBinding<Arc<Expression>>,
    ) -> Arc<Expression> {
        if !is_match_aggregate(node) {
            return node.replace_children_ast(expr, children);
        }
        // Reject aggregates nested inside another aggregate's argument.
        if let Ok(arg) = node.parameter_binding.get_by_index(0) {
            if arg
                .find(&mut |expr: &TypedExpression| {
                    matches!(
                        &expr.kind,
                        TypedExpressionKind::Apply(apply) if is_match_aggregate(apply)
                    )
                })
                .is_some()
            {
                self.state.errors.push(Arc::new(TranslationError::msg(
                    expr,
                    "aggregates nested inside another aggregate are not supported in MATCH AGG",
                )));
                // Return the original AST so traversal can continue.
                return expr.ast.clone();
            }
        }
        let func_name = node.function_def.name().to_lowercase();
        let arg_ast = node
            .parameter_binding
            .get_by_index(0)
            .ok()
            .map(|te| te.ast.clone());
        match rewrite_aggregate(&func_name, arg_ast, self.state) {
            Ok(rewritten) => rewritten,
            // Attach the source location of the whole expression to the
            // plain-string error from `rewrite_aggregate`.
            Err(msg) => {
                self.state
                    .errors
                    .push(Arc::new(TranslationError::msg(expr, &msg)));
                expr.ast.clone()
            }
        }
    }
}
/// True when the applied function is declared with the MATCH or AGG
/// special position, i.e. it must be lowered by this pass.
fn is_match_aggregate(node: &TypedApply) -> bool {
    node.function_def
        .special_position()
        .map_or(false, |pos| {
            matches!(pos, SpecialPosition::Match | SpecialPosition::Agg)
        })
}
/// Rewrites one MATCH AGG function call into an expression over the
/// synthetic per-window arrays.
///
/// Every rewrite operates on `slice(col, 0, __match_length)` — the prefix
/// of the window's value array that belongs to the regex match. Returns a
/// plain `String` error so the caller can attach source-location context.
fn rewrite_aggregate(
    func_name: &str,
    arg_ast: Option<Arc<Expression>>,
    state: &mut AggLoweringState,
) -> Result<Arc<Expression>, String> {
    use hamelin_lib::tree::builder::{call, cast, field_ref, int, subtract};
    use hamelin_lib::types::{array::Array, Type};
    // Prefix of a value array limited to the matched rows.
    let slice_of = |col: &str| {
        call("slice")
            .arg(field_ref(col))
            .arg(int(0))
            .arg(field_ref("__match_length"))
    };
    // Registering the argument here creates the synthetic column on demand.
    let agg_col = arg_ast.as_ref().map(|ast| state.column_for(ast));
    match (func_name, agg_col.as_deref()) {
        // count() without an argument counts matched labels.
        ("count", None) => Ok(Arc::new(
            call("len").arg(slice_of("__labels_array")).build(),
        )),
        // count(x) ignores nulls, mirroring SQL COUNT semantics.
        ("count", Some(col)) => Ok(Arc::new(
            call("len")
                .arg(call("filter_null").arg(slice_of(col)))
                .build(),
        )),
        // sum/avg promote the slice to double before aggregating.
        ("sum", Some(col)) => {
            let slice_double = cast(slice_of(col), Type::Array(Array::new(Type::Double)));
            Ok(Arc::new(call("sum").arg(slice_double).build()))
        }
        ("avg", Some(col)) => {
            let slice_double = cast(slice_of(col), Type::Array(Array::new(Type::Double)));
            Ok(Arc::new(call("avg").arg(slice_double).build()))
        }
        ("max", Some(col)) | ("min", Some(col)) => {
            Ok(Arc::new(call(func_name).arg(slice_of(col)).build()))
        }
        // first/last pick the boundary elements of the matched prefix.
        ("first", Some(col)) => Ok(Arc::new(call("get").arg(slice_of(col)).arg(int(0)).build())),
        ("last", Some(col)) => Ok(Arc::new(
            call("get")
                .arg(slice_of(col))
                .arg(subtract(field_ref("__match_length"), int(1)))
                .build(),
        )),
        ("array_agg", Some(col)) => Ok(Arc::new(slice_of(col).build())),
        ("count_distinct", Some(col)) => Ok(Arc::new(
            call("len")
                .arg(call("array_distinct").arg(call("filter_null").arg(slice_of(col))))
                .build(),
        )),
        (_, None) => Err(format!("{} requires an argument", func_name)),
        (name, _) => Err(format!("unsupported MATCH AGG function: {}", name)),
    }
}
/// Result of lowering all AGG assignments: the synthetic value columns to
/// add to the window command, and the rewritten `(target, expression)`
/// pairs.
struct LoweredAgg {
    columns: Vec<SyntheticAggColumn>,
    assignments: Vec<(Identifier, Arc<Expression>)>,
}
/// Rewrites every AGG assignment of the MATCH command through
/// `AggRewriteAlgebra`. All assignments are processed before any
/// accumulated error is surfaced; only the first error is reported.
fn lower_agg_assignments(
    match_cmd: &TypedMatchCommand,
) -> Result<LoweredAgg, Arc<TranslationError>> {
    let mut state = AggLoweringState::default();
    let mut assignments = Vec::new();
    for assignment in &match_cmd.agg.assignments {
        let target = assignment.identifier.clone().valid()?;
        let mut alg = AggRewriteAlgebra { state: &mut state };
        // Catamorphism: rewrite the expression tree bottom-up.
        let rewritten = assignment.expression.cata(&mut alg);
        assignments.push((target, rewritten));
    }
    if let Some(err) = state.errors.into_iter().next() {
        return Err(err);
    }
    Ok(LoweredAgg {
        columns: state.columns,
        assignments,
    })
}
/// Assembles the replacement pipeline for one MATCH command.
///
/// The generated pipeline:
/// 1. sources all pattern tables under their aliases,
/// 2. tags each row with a `__pattern_label` via a CASE over the aliases,
/// 3. aggregates labels per window (as an array when AGG is present, or
///    directly into the comma-joined `__state` string otherwise),
/// 4. (AGG path) derives `__state` and `__match_length`, then evaluates the
///    rewritten aggregate assignments,
/// 5. keeps only rows whose label can start a match and whose window state
///    matches the pattern regex,
/// 6. drops all synthetic helper columns.
fn build_lowered_pipeline(
    match_cmd: &TypedMatchCommand,
    pattern_vars: &[PatternVariable],
    regex_pattern: &str,
    starting_labels: &[String],
) -> Result<hamelin_lib::tree::ast::pipeline::Pipeline, Arc<TranslationError>> {
    use hamelin_lib::tree::builder::{
        and, call, eq, field_ref, in_, is_not_null, pair, pipeline, sort_command, string, tuple,
    };
    let lowered_agg = lower_agg_assignments(match_cmd)?;
    let has_agg = !lowered_agg.assignments.is_empty();
    let mut pipe = pipeline();
    // FROM: one aliased table per pattern variable.
    pipe = pipe.from(|f| {
        let mut from_builder = f;
        for var in pattern_vars {
            from_builder = from_builder.table_alias(var.alias.as_str(), var.table.as_str());
        }
        from_builder
    });
    // CASE over the aliases: the first non-null alias decides the label.
    let mut case_call = call("case");
    for var in pattern_vars {
        case_call = case_call.arg(pair(
            is_not_null(field_ref(var.alias.as_str())),
            string(&var.label),
        ));
    }
    pipe = pipe.let_cmd(|l| l.named_field("__pattern_label", case_call));
    // Rows matching no pattern variable are irrelevant to the match.
    pipe = pipe.where_cmd(is_not_null(field_ref("__pattern_label")));
    // With AGG we keep the raw label array (needed later for per-match
    // slicing); without AGG we can join into `__state` immediately.
    let mut window_builder = if has_agg {
        let labels_array = call("array_agg").arg(field_ref("__pattern_label"));
        self::window_command().named_field("__labels_array", labels_array)
    } else {
        let state_expr = call("array_join")
            .arg(call("array_agg").arg(field_ref("__pattern_label")))
            .arg(string(","));
        self::window_command().named_field("__state", state_expr)
    };
    if has_agg {
        // One synthetic value array per distinct aggregate argument.
        for col in &lowered_agg.columns {
            let value_array = call("array_agg").arg(col.source_expr.clone());
            window_builder = window_builder.named_field(col.name.as_str(), value_array);
        }
    }
    // Carry over WITHIN, BY (group-by) and SORT clauses from the MATCH.
    if let Some(within) = &match_cmd.within {
        window_builder = window_builder.within(within.ast.as_ref().clone());
    }
    for assignment in &match_cmd.group_by.assignments {
        if let Ok(id) = assignment.identifier.valid_ref() {
            window_builder =
                window_builder.group_by(id.clone(), assignment.expression.ast.as_ref().clone());
        }
    }
    if !match_cmd.sort.is_empty() {
        let mut sort_builder = sort_command();
        for sort_expr in &match_cmd.sort {
            let expr = sort_expr.ast.expression.as_ref().clone();
            sort_builder = match sort_expr.order {
                SortOrder::Asc => sort_builder.asc(expr),
                SortOrder::Desc => sort_builder.desc(expr),
            };
        }
        window_builder = window_builder.sort(sort_builder);
    }
    pipe = pipe.window(|_| window_builder);
    if has_agg {
        // Derive `__state` from the label array, then `__match_length` as
        // the number of labels consumed by the regex match.
        let state_expr = call("array_join")
            .arg(field_ref("__labels_array"))
            .arg(string(","));
        pipe = pipe.let_cmd(|l| l.named_field("__state", state_expr));
        let match_length_expr = call("len").arg(
            call("split")
                .arg(
                    call("regexp_extract")
                        .arg(field_ref("__state"))
                        .arg(string(regex_pattern)),
                )
                .arg(string(",")),
        );
        pipe = pipe.let_cmd(|l| l.named_field("__match_length", match_length_expr));
        // Emit all rewritten AGG assignments in a single LET command.
        let mut agg_let_builder: Option<hamelin_lib::tree::builder::LetCommandBuilder> = None;
        for (target, rewritten) in &lowered_agg.assignments {
            agg_let_builder = Some(match agg_let_builder {
                None => hamelin_lib::tree::builder::let_command()
                    .named_field(target.clone(), rewritten.clone()),
                Some(builder) => builder.named_field(target.clone(), rewritten.clone()),
            });
        }
        if let Some(builder) = agg_let_builder {
            pipe = pipe.let_cmd(|_| builder);
        }
    }
    // Keep only rows that can start a match inside a matching state.
    let regexp_filter = call("regexp_like")
        .arg(field_ref("__state"))
        .arg(string(regex_pattern));
    pipe = if starting_labels.len() == 1 {
        pipe.where_cmd(and(
            eq(field_ref("__pattern_label"), string(&starting_labels[0])),
            regexp_filter,
        ))
    } else if starting_labels.len() > 1 {
        let mut tup = tuple();
        for label in starting_labels {
            tup = tup.element(string(label));
        }
        pipe.where_cmd(and(in_(field_ref("__pattern_label"), tup), regexp_filter))
    } else {
        pipe.where_cmd(regexp_filter)
    };
    // Drop every synthetic helper column before handing the pipeline back.
    let mut drop_builder = pipe.drop(|d| d.field("__pattern_label").field("__state"));
    if has_agg {
        drop_builder = drop_builder.drop(|d| {
            let mut db = d.field("__labels_array").field("__match_length");
            for col in &lowered_agg.columns {
                db = db.field(col.name.as_str());
            }
            db
        });
    }
    Ok(drop_builder.build())
}
#[cfg(test)]
mod tests {
    use hamelin_lib::type_check_with_provider;
    use super::*;

    // Labels must advance a..z, then A..Z, then two-letter aa, ab, …
    #[test]
    fn test_label_generator() {
        let mut gen = LabelGenerator::new();
        assert_eq!(gen.next(), "a");
        assert_eq!(gen.next(), "b");
        // Skip the rest of the lowercase range (indices 2..26).
        for _ in 2..26 {
            gen.next();
        }
        assert_eq!(gen.next(), "A");
        // Skip the rest of the uppercase range (indices 27..52).
        for _ in 27..52 {
            gen.next();
        }
        assert_eq!(gen.next(), "aa");
        assert_eq!(gen.next(), "ab");
    }

    // `a+` alone: the trailing element takes the `a(,a)*` form.
    #[test]
    fn test_pattern_to_regex_single() {
        let elements = vec![PatternElement {
            label: "a".to_string(),
            quantifier: PatternQuantifier::OneOrMore,
        }];
        let regex = build_regex_from_elements(&elements);
        assert_eq!(regex, "^a(,a)*");
    }

    // `a+ b+`: a non-final element keeps its trailing comma in the group.
    #[test]
    fn test_pattern_to_regex_sequence() {
        let elements = vec![
            PatternElement {
                label: "a".to_string(),
                quantifier: PatternQuantifier::OneOrMore,
            },
            PatternElement {
                label: "b".to_string(),
                quantifier: PatternQuantifier::OneOrMore,
            },
        ];
        let regex = build_regex_from_elements(&elements);
        assert_eq!(regex, "^(a,)+b(,b)*");
    }

    // `a b?`: the optional tail absorbs its own leading comma.
    #[test]
    fn test_pattern_to_regex_optional_end() {
        let elements = vec![
            PatternElement {
                label: "a".to_string(),
                quantifier: PatternQuantifier::One,
            },
            PatternElement {
                label: "b".to_string(),
                quantifier: PatternQuantifier::ZeroOrOne,
            },
        ];
        let regex = build_regex_from_elements(&elements);
        assert_eq!(regex, "^a(,b)?");
    }

    // Counter-based fallback suffixes must skip suffixes already taken by
    // readable (field-derived) names.
    #[test]
    fn agg_column_counter_suffix_avoids_collision_with_readable_numeric_field() {
        use std::sync::Arc;
        use hamelin_lib::tree::builder::{field_ref, int};
        let mut state = AggLoweringState::default();
        // A field literally named "1" claims the readable suffix "1".
        let field_named_1 = Arc::new(field_ref("1").build());
        assert_eq!(state.column_for(&field_named_1), "__agg_values_1");
        let lit0 = Arc::new(int(0).build());
        assert_eq!(state.column_for(&lit0), "__agg_values_0");
        // "1" is taken by the readable field above, so the counter skips to 2.
        let lit1 = Arc::new(int(1).build());
        assert_eq!(state.column_for(&lit1), "__agg_values_2");
    }

    // A statement without MATCH is detected as such (lower_match fast path).
    #[test]
    fn test_no_match_passthrough() -> Result<(), Arc<TranslationError>> {
        use hamelin_lib::{
            provider::EnvironmentProvider,
            tree::{
                ast::identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
                builder::{eq, field_ref, pipeline, query, QueryBuilderWithMain},
            },
            types::{struct_type::Struct, INT},
        };
        use std::sync::Arc;
        // Minimal provider exposing a single `events` table.
        #[derive(Debug)]
        struct MockProvider;
        impl EnvironmentProvider for MockProvider {
            fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
                let events: Identifier = AstSimpleIdentifier::new("events").into();
                if name == &events {
                    Ok(Struct::default()
                        .with_str("timestamp", INT)
                        .with_str("value", INT))
                } else {
                    anyhow::bail!("Table not found: {}", name)
                }
            }
            fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
                Ok(vec![])
            }
        }
        fn typed_query(builder: QueryBuilderWithMain) -> TypedStatement {
            type_check_with_provider(builder.build(), Arc::new(MockProvider)).output
        }
        let q = query().main(
            pipeline()
                .from(|f| f.table_reference("events"))
                .where_cmd(eq(field_ref("value"), 10)),
        );
        let statement = typed_query(q);
        assert!(!statement_has_match(&statement)?);
        Ok(())
    }

    // End-to-end: compound AGG targets (`grp.lo`, `grp.hi`) must surface as
    // a struct column `grp` in the lowered statement's output schema.
    #[test]
    fn match_agg_compound_targets_grp_struct_in_output_schema() {
        use hamelin_lib::{
            func::registry::FunctionRegistry,
            parse,
            provider::EnvironmentProvider,
            tree::ast::identifier::{Identifier, SimpleIdentifier as AstSimpleIdentifier},
            type_check_with_provider,
            types::{struct_type::Struct, Type, STRING, TIMESTAMP},
        };
        // Minimal provider exposing a single `test_t` table.
        #[derive(Debug)]
        struct MatchAggProvider;
        impl EnvironmentProvider for MatchAggProvider {
            fn reflect_columns(&self, name: &Identifier) -> anyhow::Result<Struct> {
                let test_t: Identifier = AstSimpleIdentifier::new("test_t").into();
                if name == &test_t {
                    Ok(Struct::default()
                        .with_str("timestamp", TIMESTAMP)
                        .with_str("host", STRING))
                } else {
                    anyhow::bail!("Table not found: {}", name)
                }
            }
            fn reflect_datasets(&self) -> anyhow::Result<Vec<Identifier>> {
                Ok(vec![])
            }
        }
        // NOTE: raw-string continuation lines are kept unindented so the
        // fixture query text stays exactly as written.
        let src = r#"MATCH a=test_t+ b=test_t+
AGG grp.lo = min(timestamp), grp.hi = max(timestamp)
BY host
WITHIN 5s"#;
        let query = parse(src)
            .into_result()
            .expect("MATCH + compound AGG fixture should parse");
        let provider = Arc::new(MatchAggProvider);
        let typed = type_check_with_provider(query, provider.clone())
            .into_result()
            .expect("MATCH + compound AGG fixture should type-check");
        let mut ctx =
            StatementTranslationContext::new(Arc::new(FunctionRegistry::default()), provider);
        let lowered = lower_match(Arc::new(typed), &mut ctx)
            .expect("lower_match should succeed for compound AGG targets");
        let schema = lowered.pipeline.schema();
        let grp_ty = schema
            .lookup(&AstSimpleIdentifier::new("grp"))
            .unwrap_or_else(|| {
                panic!(
                    "output schema should include struct `grp`; have columns {:?}",
                    schema
                        .iter()
                        .map(|(k, _)| k.as_str().to_string())
                        .collect::<Vec<_>>()
                )
            });
        let Type::Struct(grp_struct) = grp_ty else {
            panic!(
                "`grp` should be a struct in the output schema, got {:?}",
                grp_ty
            );
        };
        assert!(
            grp_struct.lookup(&AstSimpleIdentifier::new("lo")).is_some(),
            "`grp` should contain field `lo`, got {:?}",
            grp_struct
                .iter()
                .map(|(k, _)| k.as_str().to_string())
                .collect::<Vec<_>>()
        );
        assert!(
            grp_struct.lookup(&AstSimpleIdentifier::new("hi")).is_some(),
            "`grp` should contain field `hi`, got {:?}",
            grp_struct
                .iter()
                .map(|(k, _)| k.as_str().to_string())
                .collect::<Vec<_>>()
        );
    }
}