mod compound_lowering;
mod expand_struct;
mod pipeline_normalizers;
mod statement_normalizers;
use std::sync::Arc;
use hamelin_lib::err::{ContextualTranslationErrors, Stage, TranslationError, TranslationErrors};
use hamelin_lib::parse_with_options;
use hamelin_lib::tree::ast::query::DefBody;
use hamelin_lib::tree::ast::query::Query;
use hamelin_lib::tree::builder::query;
use hamelin_lib::tree::options::{ParseOptions, TypeCheckOptions};
use hamelin_lib::tree::typed_ast::context::StatementTranslationContext;
use hamelin_lib::tree::typed_ast::pipeline::TypedPipeline;
use hamelin_lib::tree::typed_ast::query::TypedStatement;
use hamelin_lib::type_check_with_options;
use pipeline_normalizers::{
dedup_append_source, desugar_in_array_literals, desugar_map_from_arrays, expand_array_literals,
extract_window_aggregates, fuse_projections, lower_broadcast_apply, lower_distinct, lower_nest,
lower_parse, lower_transform, lower_ts_trunc_interval, lower_unnest, normalize_agg,
normalize_explode, normalize_window, normalize_within, recognize_transform_values,
rewrite_transform_values,
};
use statement_normalizers::{expand_union_schemas, from_to_union, lower_match, nest_from_aliases};
use pipeline_normalizers::align_append_schema;
use crate::normalize::pipeline_normalizers::lower_filter;
/// Runs the statement-level normalization passes, rebuilds the query from the
/// normalized parts, and re-type-checks the rebuilt query.
///
/// Flow:
/// 1. Unless `ctx.skip_statement_passes` is set, apply the statement passes
///    (`lower_match`, `nest_from_aliases`, `from_to_union`,
///    `expand_union_schemas`) in that order.
/// 2. Re-emit every DEF into a fresh query builder, normalizing each pipeline
///    DEF through [`normalize_pipeline`] and registering it as a CTE on `ctx`.
///    When the AST is valid, DEFs are emitted in original AST order by walking
///    `valid.defs` and consuming `scalar_defs`/`pipeline_defs` in lockstep
///    (length mismatches are reported as internal errors). When the AST is
///    invalid, all scalar defs are emitted first, then all pipeline defs, as a
///    best effort.
/// 3. Type-check the rebuilt query; on failure, delegate to
///    [`normalize_typecheck_failure`] and keep only the first error group
///    (`.take(1)`).
///
/// # Errors
/// Returns `ContextualTranslationErrors` (stage = `Normalization`) if any
/// pass fails, a DEF name is invalid, the def lists disagree with the AST, or
/// the final type check fails.
pub fn normalize_statement(
    statement: Arc<TypedStatement>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedStatement>, ContextualTranslationErrors> {
    let statement = if ctx.skip_statement_passes {
        statement
    } else {
        // Statement passes run in a fixed order; each wraps its errors with
        // the pre-pass source text via `with_statement_src`.
        let statement = with_statement_src(statement, |s| lower_match(s, ctx))?;
        let statement = with_statement_src(statement, |s| nest_from_aliases(s, ctx))?;
        let statement = with_statement_src(statement, |s| from_to_union(s, ctx))?;
        with_statement_src(statement, |s| expand_union_schemas(s, ctx))?
    };
    let mut builder = query();
    match statement.ast.valid_ref() {
        Ok(valid) => {
            // Valid AST: replay defs in AST order, pulling the matching typed
            // def from the appropriate iterator as each AST def is visited.
            let mut scalars = statement.scalar_defs.iter();
            let mut pipelines = statement.pipeline_defs.iter();
            for def in &valid.defs {
                match &def.body {
                    DefBody::Expression(_) => {
                        let sd = scalars
                            .next()
                            .ok_or_else(|| {
                                Arc::new(TranslationError::msg(
                                    statement.ast.as_ref(),
                                    "internal: scalar_defs shorter than expression defs in AST",
                                ))
                            })
                            .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                        let name = sd
                            .name
                            .valid_ref()
                            .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                        builder = builder.def_expression(name.clone(), sd.expression.ast.clone());
                    }
                    DefBody::Pipeline(_) => {
                        let pd = pipelines
                            .next()
                            .ok_or_else(|| {
                                Arc::new(TranslationError::msg(
                                    statement.ast.as_ref(),
                                    "internal: pipeline_defs shorter than pipeline defs in AST",
                                ))
                            })
                            .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                        let name = pd
                            .name
                            .valid_ref()
                            .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                        let normalized = normalize_pipeline(pd.pipeline.clone(), ctx)
                            .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                        // Later defs and the main pipeline may refer to this
                        // def by name, so register it as a CTE before moving on.
                        ctx.register_cte(name.clone(), normalized.environment());
                        builder = builder.def_pipeline(name.clone(), normalized);
                    }
                }
            }
            // Both iterators must be exhausted; leftovers mean the typed def
            // lists disagree with the AST, which is an internal invariant bug.
            if scalars.next().is_some() || pipelines.next().is_some() {
                return Err(contextual_from_arc(
                    statement.ast.to_string(),
                    Arc::new(TranslationError::msg(
                        statement.ast.as_ref(),
                        "internal: leftover DEF entries after rebuilding query",
                    )),
                ));
            }
        }
        Err(_) => {
            // Invalid AST: original def ordering is unavailable, so emit all
            // scalar defs first, then all pipeline defs.
            for sd in &statement.scalar_defs {
                let name = sd
                    .name
                    .valid_ref()
                    .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                builder = builder.def_expression(name.clone(), sd.expression.ast.clone());
            }
            for pd in &statement.pipeline_defs {
                let name = pd
                    .name
                    .valid_ref()
                    .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                let normalized = normalize_pipeline(pd.pipeline.clone(), ctx)
                    .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
                ctx.register_cte(name.clone(), normalized.environment());
                builder = builder.def_pipeline(name.clone(), normalized);
            }
        }
    }
    // Normalize the main pipeline last so it can see every registered CTE.
    let normalized_main = normalize_pipeline(statement.pipeline.clone(), ctx)
        .map_err(|e| contextual_from_arc(statement.ast.to_string(), e))?;
    let built = builder.main(normalized_main).build();
    let opts = TypeCheckOptions::builder()
        .registry(ctx.registry.clone())
        .provider(ctx.provider.clone())
        .interner(ctx.interner.clone())
        .maybe_timestamp_field(Some(ctx.timestamp_field.clone()))
        .maybe_message_field(Some(ctx.message_field.clone()))
        .build();
    let built_arc = Arc::new(built);
    // Re-type-check the rebuilt query to produce the typed output statement.
    let wte = type_check_with_options::<Query>(built_arc.clone(), opts.clone());
    if wte.errors.is_empty() {
        Ok(Arc::new(wte.output))
    } else {
        Err(normalize_typecheck_failure((*built_arc).clone(), wte.errors, opts.clone()).take(1))
    }
}
/// Pairs a single [`TranslationError`] with the source text it refers to,
/// tagging it with the normalization stage.
fn contextual_from_arc(
    hamelin: impl Into<String>,
    err: Arc<TranslationError>,
) -> ContextualTranslationErrors {
    // Clone out of the Arc so the error can be mutated with its stage tag.
    let staged = (*err).clone().with_stage(Stage::Normalization);
    ContextualTranslationErrors::new(hamelin.into(), staged.single())
}
/// Runs one statement pass `f`, capturing the statement's printed source
/// beforehand so a failure can be reported against that text.
fn with_statement_src(
    statement: Arc<TypedStatement>,
    f: impl FnOnce(Arc<TypedStatement>) -> Result<Arc<TypedStatement>, Arc<TranslationError>>,
) -> Result<Arc<TypedStatement>, ContextualTranslationErrors> {
    // Print before handing ownership of `statement` to the pass.
    let src = statement.ast.to_string();
    match f(statement) {
        Ok(out) => Ok(out),
        Err(e) => Err(contextual_from_arc(src, e)),
    }
}
fn with_normalization_stage(errors: TranslationErrors) -> TranslationErrors {
TranslationErrors(
errors
.0
.into_iter()
.map(|e| e.with_stage(Stage::Normalization))
.collect(),
)
}
/// Builds contextual errors for a query that failed the post-normalization
/// type check, trying to report them against the *printed* normalized query.
///
/// Strategy: print the rebuilt query, re-parse it (sharing the interner when
/// one is configured), and re-run the type check on the parse result so error
/// spans line up with the printed text. If the printed text fails to re-parse,
/// or the retry unexpectedly succeeds, fall back to the original errors with a
/// note warning that their spans may not match the printed source.
fn normalize_typecheck_failure(
    built: Query,
    initial_errors: TranslationErrors,
    opts: TypeCheckOptions,
) -> ContextualTranslationErrors {
    let synthetic_source = built.to_string();
    if initial_errors.0.is_empty() {
        // Defensive: callers should pass non-empty errors, but emit a generic
        // message rather than an empty error list.
        return ContextualTranslationErrors::new(
            synthetic_source,
            TranslationError::msg(
                &built,
                "type check failed after normalization (no error details)",
            )
            .with_stage(Stage::Normalization)
            .single(),
        );
    }
    // Share the interner with the original type check when available so the
    // re-parsed query resolves the same identifiers.
    let parse_opts = match &opts.interner {
        Some(interner) => ParseOptions::builder().interner(interner.clone()).build(),
        None => ParseOptions::default(),
    };
    let parse_wte = parse_with_options(&synthetic_source, parse_opts);
    if !parse_wte.errors.is_empty() {
        // Printed query does not round-trip through the parser; report the
        // original errors, flagged as possibly misaligned.
        return ContextualTranslationErrors::new(
            synthetic_source,
            with_normalization_stage(annotate_first(
                initial_errors,
                "could not re-parse printed normalized query; spans may be inaccurate",
            )),
        );
    }
    let retry = type_check_with_options::<Query>(Arc::new(parse_wte.output), opts);
    if retry.errors.is_empty() {
        // The re-parsed query type-checks even though the built AST did not;
        // keep the original errors but warn that their spans belong to the
        // built AST rather than the printed text.
        return ContextualTranslationErrors::new(
            synthetic_source,
            with_normalization_stage(annotate_first(
                initial_errors,
                "error spans are from the first type check on the built AST, not this printed query; highlighting may be inaccurate",
            )),
        );
    }
    // Preferred path: the retry errors carry spans into `synthetic_source`.
    ContextualTranslationErrors::new(synthetic_source, with_normalization_stage(retry.errors))
}

/// Attaches `note` as context (at a zero-width span) on the first error in the
/// collection, if any. Shared by the two fallback paths above, which
/// previously duplicated this logic inline.
fn annotate_first(errors: TranslationErrors, note: &str) -> TranslationErrors {
    let TranslationErrors(mut vec) = errors;
    if let Some(first) = vec.first_mut() {
        first.add_context(0..=0, note);
    }
    TranslationErrors(vec)
}
/// Runs the pipeline-level normalization passes over `pipeline` in a fixed
/// order, threading `ctx` through each pass.
///
/// Context flags short-circuit or gate parts of the sequence:
/// - `ctx.skip_pipeline_passes`: return the pipeline untouched.
/// - `ctx.lower_filter`: additionally run `lower_filter`.
/// - `ctx.lower_transform`: additionally rewrite/lower transforms and re-run
///   `normalize_explode` afterwards.
/// - `ctx.skip_projection_fusion`: stop before `fuse_projections`.
///
/// # Errors
/// Propagates the first `TranslationError` produced by any pass.
pub fn normalize_pipeline(
    pipeline: Arc<TypedPipeline>,
    ctx: &mut StatementTranslationContext,
) -> Result<Arc<TypedPipeline>, Arc<TranslationError>> {
    if ctx.skip_pipeline_passes {
        return Ok(pipeline);
    }
    // Unconditional passes, in dependency order.
    let mut p = lower_broadcast_apply(pipeline, ctx)?;
    p = normalize_within(p, ctx)?;
    p = lower_distinct(p, ctx)?;
    p = normalize_agg(p, ctx)?;
    p = normalize_window(p, ctx)?;
    p = extract_window_aggregates(p, ctx)?;
    p = normalize_explode(p, ctx)?;
    p = lower_unnest(p, ctx)?;
    p = lower_parse(p, ctx)?;
    p = lower_nest(p, ctx)?;
    p = expand_array_literals(p, ctx)?;
    p = desugar_in_array_literals(p, ctx)?;
    p = desugar_map_from_arrays(p, ctx)?;
    p = lower_ts_trunc_interval(p, ctx)?;
    p = dedup_append_source(p, ctx)?;
    p = align_append_schema(p, ctx)?;
    if ctx.lower_filter {
        p = lower_filter(p, ctx)?;
    }
    p = recognize_transform_values(p, ctx)?;
    if ctx.lower_transform {
        p = rewrite_transform_values(p, ctx)?;
        p = lower_transform(p, ctx)?;
        // Lowering transforms can introduce explodes; normalize them again.
        p = normalize_explode(p, ctx)?;
    }
    if ctx.skip_projection_fusion {
        return Ok(p);
    }
    fuse_projections(p, ctx)
}