use indexmap::IndexMap;
use once_cell::sync::Lazy;
use serde_json::Value;
use crate::{
database::SchemaProvider,
parser::{
aggregators_helper::AggregateRegistry, analyzer::{AggregateResolver, AnalyzedIdentifier, AnalyzedQuery, AnalyzerError, ColumnKey, ColumnResolver, IdentifierResolver, JoinResolver, OrderByResolver, PredicateResolver, ScalarResolver, TypeInference}, ast::{Collection, Column, Query, ScalarExpr}
}
};
// Process-wide default aggregate registry, built lazily on first access so
// callers of `AnalysisContext::new` share one registry without paying for
// construction up front.
static DEFAULT_REGISTRY: Lazy<AggregateRegistry> = Lazy::new(AggregateRegistry::default_aggregate_registry);
/// Mutable state threaded through query analysis: visible-to-backing
/// collection mapping, schema lookup, aggregate metadata, and bound
/// query parameters.
pub struct AnalysisContext<'a> {
    // Maps each visible collection name (alias, or table name when no alias)
    // to its backing collection; IndexMap preserves registration order.
    pub collections: IndexMap<String, String>,
    // Source of per-collection schemas used for column and type resolution.
    pub schemas: &'a dyn SchemaProvider,
    // Registry of known aggregate functions (count, sum, ...).
    pub aggregates: &'a AggregateRegistry,
    // Bound parameters as a JSON value; `Value::Null` when none were supplied.
    pub parameters: Value,
    // Index of the next positional parameter to consume.
    pub current_param: usize,
}
impl<'a> AnalysisContext<'a> {
/// Builds an empty analysis context over `schemas` with an explicit
/// aggregate registry: no collections registered, no parameters bound.
pub fn new_with_aggregates(
    schemas: &'a dyn SchemaProvider,
    aggregates: &'a AggregateRegistry,
) -> Self {
    Self {
        schemas,
        aggregates,
        collections: IndexMap::default(),
        parameters: Value::Null,
        current_param: 0,
    }
}
/// Builds a context backed by the process-wide default aggregate registry.
pub fn new(schemas: &'a dyn SchemaProvider) -> Self {
    // `&DEFAULT_REGISTRY` deref-coerces from `&Lazy<AggregateRegistry>`.
    Self::new_with_aggregates(schemas, &DEFAULT_REGISTRY)
}
/// Registers `visible` (an alias or bare table name) as referring to the
/// `backing` collection; a later registration under the same name wins.
pub fn add_collection(&mut self, visible: impl Into<String>, backing: impl Into<String>) {
    let key = visible.into();
    let value = backing.into();
    self.collections.insert(key, value);
}
/// Alias-flavoured synonym of `add_collection`, kept for call-site clarity.
/// Previously a byte-for-byte duplicate of `add_collection`; delegating
/// guarantees the two entry points cannot drift apart.
pub fn add_collection_alias(&mut self, visible: impl Into<String>, backing: impl Into<String>) {
    self.add_collection(visible, backing);
}
/// Seeds a context from the FROM list and JOINs of `q` and binds `parameters`.
///
/// Each table registers its alias (or, absent one, its name) as the visible
/// collection name. Subquery collections are not supported yet and yield an
/// `AnalyzerError::Other`.
pub fn build_context_from_query(
    q: &Query,
    sp: &'a dyn SchemaProvider,
    aggregates: &'a AggregateRegistry,
    parameters: Value,
) -> Result<Self, AnalyzerError> {
    let mut ctx = Self::new_with_aggregates(sp, aggregates);
    ctx.parameters = parameters;
    // Walk FROM entries first, then JOINed collections, preserving query
    // order; each entry carries its own "unsupported subquery" message.
    let from_entries = q
        .collections
        .iter()
        .map(|c| (c, "Collection::Query not yet supported in analyzer"));
    let join_entries = q
        .joins
        .iter()
        .map(|j| (&j.collection, "Join of subquery not yet supported in analyzer"));
    for (coll, unsupported) in from_entries.chain(join_entries) {
        match coll {
            Collection::Table { name, alias } => {
                let visible = alias.clone().unwrap_or_else(|| name.clone());
                ctx.add_collection(visible, name.clone());
            }
            Collection::Query => {
                return Err(AnalyzerError::Other(unsupported.into()));
            }
        }
    }
    Ok(ctx)
}
/// Derives a fallback output name for an un-aliased projection expression.
fn default_name_for_expr_in_analyzer(e: &ScalarExpr) -> String {
    match e {
        // Qualified column: keep the qualifier so the name stays unambiguous.
        ScalarExpr::Column(Column::WithCollection { collection, name }) => {
            format!("{collection}.{name}")
        }
        ScalarExpr::Column(Column::Name { name }) => name.to_owned(),
        // Functions surface under their lowercased function name.
        ScalarExpr::Function(f) => f.name.to_ascii_lowercase(),
        ScalarExpr::WildCard | ScalarExpr::WildCardWithCollection(_) => String::from("*"),
        ScalarExpr::Literal(_) => String::from("_lit"),
        ScalarExpr::Parameter | ScalarExpr::Args(_) => String::from("_param"),
    }
}
/// Assigns a unique `output_name` to every projected identifier, in order.
///
/// Preference: explicit alias; bare column name while still unused; the
/// qualified `collection.name`; otherwise the expression's default name.
/// Remaining collisions are resolved by appending `_1`, `_2`, ...
fn assign_output_names(ids: &mut [AnalyzedIdentifier]) {
    use std::collections::HashSet;
    let mut taken: HashSet<String> = HashSet::new();
    for ident in ids.iter_mut() {
        let candidate = match (&ident.alias, &ident.expression) {
            (Some(alias), _) => alias.clone(),
            (None, ScalarExpr::Column(Column::WithCollection { collection, name })) => {
                // Prefer the short name; fall back to qualified when taken.
                if taken.contains(name) {
                    format!("{collection}.{name}")
                } else {
                    name.clone()
                }
            }
            (None, expr) => AnalysisContext::default_name_for_expr_in_analyzer(expr),
        };
        // Suffix an increasing counter until the name is free.
        let mut unique = candidate.clone();
        let mut suffix = 1usize;
        while taken.contains(&unique) {
            unique = format!("{candidate}_{suffix}");
            suffix += 1;
        }
        taken.insert(unique.clone());
        ident.output_name = unique;
    }
}
/// Analyzes `query` end-to-end: resolves FROM collections, expands and
/// types the projection, qualifies/folds WHERE and HAVING, validates
/// GROUP BY / aggregate usage, and resolves ORDER BY.
///
/// # Errors
/// Returns `AnalyzerError` for unknown collections, aggregates in WHERE,
/// HAVING without aggregation, or SELECT/HAVING/ORDER BY expressions that
/// reference non-grouped columns outside aggregates.
pub fn analyze_query(
    query: &Query,
    schema_provider: &'a dyn SchemaProvider,
    aggregates: &'a AggregateRegistry,
    parameters: Value,
) -> Result<AnalyzedQuery, AnalyzerError> {
    let mut ctx = Self::build_context_from_query(query, schema_provider, aggregates, parameters)?;
    // Resolve the FROM list to (visible, backing) pairs in declaration order.
    let mut from_collections: Vec<(String, String)> = Vec::with_capacity(query.collections.len());
    for c in &query.collections {
        if let Collection::Table { name, alias } = c {
            let visible = alias.clone().unwrap_or_else(|| name.clone());
            if let Some(backing) = ctx.collections.get(&visible) {
                from_collections.push((visible, backing.clone()));
            } else {
                return Err(AnalyzerError::UnknownCollection(visible));
            }
        }
    }
    // Expand wildcards, then qualify, constant-fold, and type each projection.
    let expanded_proj = IdentifierResolver::expand_projection_idents(&query.projection, &ctx)?;
    let mut analyzed_proj = Vec::with_capacity(expanded_proj.len());
    for id in expanded_proj {
        let qexpr = ScalarResolver::qualify_scalar(&id.expression, &mut ctx, false)?;
        let fexpr = ScalarResolver::fold_scalar(&qexpr);
        let (ty, nullable) = TypeInference::infer_scalar(&fexpr, &ctx)?;
        analyzed_proj.push(AnalyzedIdentifier {
            expression: fexpr,
            alias: id.alias.clone(),
            ty,
            nullable,
            output_name: String::new(), // filled in by assign_output_names below
        });
    }
    AnalysisContext::assign_output_names(&mut analyzed_proj);
    let analyzed_joins = JoinResolver::qualify_and_fold_joins(query, &mut ctx)?;
    // Qualify WHERE/HAVING first and keep the pre-fold forms: aggregate
    // validation must run on the un-folded predicates so folding cannot
    // hide an aggregate from the checks.
    let criteria_qualified = match &query.criteria {
        Some(predicate) => Some(PredicateResolver::qualify_predicate(predicate, &mut ctx)?),
        None => None,
    };
    let criteria = criteria_qualified.as_ref().map(PredicateResolver::fold_predicate);
    let having_qualified = match &query.having {
        Some(predicate) => Some(PredicateResolver::qualify_predicate(predicate, &mut ctx)?),
        None => None,
    };
    let having = having_qualified.as_ref().map(PredicateResolver::fold_predicate);
    // Qualify GROUP BY columns and remember their keys for membership checks.
    let mut group_by = Vec::with_capacity(query.group_by.len());
    let mut group_set = std::collections::HashSet::<ColumnKey>::new();
    for c in &query.group_by {
        let (qc, _) = ColumnResolver::qualify_column(c, &ctx)?;
        group_set.insert(ColumnKey::of(&qc));
        group_by.push(qc);
    }
    // The query aggregates if it groups, projects an aggregate, or filters
    // on an aggregate in HAVING.
    let is_agg_query = !group_by.is_empty()
        || analyzed_proj.iter().any(|id| AggregateResolver::contains_aggregate(&id.expression))
        || having_qualified.as_ref().is_some_and(AggregateResolver::predicate_contains_aggregate);
    if !is_agg_query && having_qualified.is_some() {
        return Err(AnalyzerError::Other("HAVING without GROUP BY must reference an aggregate".into()));
    }
    if let Some(pq) = &criteria_qualified {
        if AggregateResolver::predicate_contains_aggregate(pq) {
            return Err(AnalyzerError::Other("Aggregates are not allowed in WHERE".into()));
        }
    }
    if is_agg_query {
        // Every projected column must be grouped or sit inside an aggregate.
        for id in &analyzed_proj {
            if !AggregateResolver::uses_only_group_by(&id.expression, &group_set, false) {
                return Err(AnalyzerError::Other("SELECT expression references columns not in GROUP BY and outside aggregates".into()));
            }
        }
        if let Some(hv_q) = &having_qualified {
            if !AggregateResolver::predicate_uses_only_group_by_or_agg(hv_q, &group_set) {
                return Err(AnalyzerError::Other("HAVING references columns not in GROUP BY and outside aggregates".into()));
            }
        }
    }
    // ORDER BY resolution follows the same aggregate classification computed
    // above. This was previously re-derived from the *folded* HAVING, which
    // could disagree with `is_agg_query` whenever folding eliminated an
    // aggregate-bearing predicate, skipping grouped ORDER BY validation.
    let order_by = if is_agg_query {
        OrderByResolver::qualify_order_by(&query.order_by, &analyzed_proj, &mut ctx, &group_set)?
    } else {
        OrderByResolver::qualify_order_by_non_agg(&query.order_by, &analyzed_proj, &mut ctx)?
    };
    Ok(AnalyzedQuery {
        projection: analyzed_proj,
        collections: from_collections,
        joins: analyzed_joins,
        criteria,
        group_by,
        having,
        order_by,
        limit: query.limit,
        offset: query.offset,
    })
}
}
#[cfg(test)]
mod tests {
use crate::{
database::FieldInfo,
parser::ast::{Column, ComparatorOp, Function, Identifier, Literal, OrderBy, Predicate, ScalarExpr, Truth},
JsonPrimitive, SchemaDict
};
use super::*;
use indexmap::IndexMap;
/// Minimal in-memory `SchemaProvider` implementation for tests.
struct DummySchemas {
    by_name: std::collections::HashMap<String, SchemaDict>,
}
impl DummySchemas {
    fn new() -> Self {
        Self { by_name: Default::default() }
    }
    /// Builder-style helper: registers collection `name` with the given
    /// `(field, type, nullable)` triples and returns `self` for chaining.
    fn with(mut self, name: &str, fields: Vec<(&str, JsonPrimitive, bool)>) -> Self {
        let map: IndexMap<String, FieldInfo> = fields
            .into_iter()
            .map(|(k, ty, nullable)| (k.to_string(), FieldInfo { ty, nullable }))
            .collect();
        self.by_name.insert(name.to_string(), SchemaDict { fields: map });
        self
    }
}
impl SchemaProvider for DummySchemas {
    /// Looks up a schema by backing-collection name, cloning it out.
    fn schema_of(&self, backing_collection: &str) -> Option<SchemaDict> {
        self.by_name.get(backing_collection).map(Clone::clone)
    }
}
/// Builds an `AnalysisContext` for `query` using the default aggregate
/// registry and no parameters, panicking if context construction fails.
fn simple_ctx_for<'a>(query: &'a Query, sp: &'a DummySchemas) -> AnalysisContext<'a> {
    AnalysisContext::build_context_from_query(query, sp, &DEFAULT_REGISTRY, Value::Null)
        .expect("build context")
}
/// Builds a single-table `Query` over `table` with the supplied clauses;
/// every field not listed takes its `Default` value.
fn make_query_with_table(
    table: &str,
    projection: Vec<Identifier>,
    criteria: Option<Predicate>,
    group_by: Vec<Column>,
    having: Option<Predicate>,
    order_by: Vec<OrderBy>,
) -> Query {
    let collections = vec![Collection::Table { name: table.to_string(), alias: None }];
    Query {
        collections,
        projection,
        criteria,
        group_by,
        having,
        order_by,
        joins: Vec::new(),
        ..Default::default()
    }
}
// Verifies GROUP BY validation: projecting a grouped column plus an aggregate
// is accepted; projecting a non-grouped bare column is rejected with an error
// that names the clause.
#[test]
fn select_group_by_validation_error() {
    // Schema: t(a INT NOT NULL, b INT NOT NULL).
    let sp = DummySchemas::new().with("t", vec![
        ("a", JsonPrimitive::Int, false),
        ("b", JsonPrimitive::Int, false),
    ]);
    // SELECT a, sum(b) FROM t GROUP BY a — valid: each projection is either
    // a grouped column or an aggregate.
    let q_ok = make_query_with_table(
        "t",
        vec![
            Identifier { expression: ScalarExpr::Column(Column::Name { name: "a".into() }), alias: None },
            Identifier { expression: ScalarExpr::Function(Function { name: "sum".into(), args: vec![ScalarExpr::Column(Column::Name { name: "b".into() })], distinct: false }), alias: None },
        ],
        None,
        vec![Column::Name { name: "a".into() }],
        None,
        vec![],
    );
    // SELECT a, b FROM t GROUP BY a — invalid: b is neither grouped nor
    // inside an aggregate.
    let q_err = make_query_with_table(
        "t",
        vec![
            Identifier { expression: ScalarExpr::Column(Column::Name { name: "a".into() }), alias: None },
            Identifier { expression: ScalarExpr::Column(Column::Name { name: "b".into() }), alias: None },
        ],
        None,
        vec![Column::Name { name: "a".into() }],
        None,
        vec![],
    );
    let analyzed_ok = AnalysisContext::analyze_query(&q_ok, &sp, &DEFAULT_REGISTRY, Value::Null);
    assert!(analyzed_ok.is_ok(), "expected OK, got: {:?}", analyzed_ok);
    let analyzed_err = AnalysisContext::analyze_query(&q_err, &sp, &DEFAULT_REGISTRY, Value::Null);
    assert!(analyzed_err.is_err(), "expected GROUP BY validation error");
    // The error text should mention the failing clause for diagnosability.
    let msg = format!("{analyzed_err:?}");
    assert!(msg.to_lowercase().contains("group by"), "err msg should mention group by; got: {msg}");
}
// Verifies that aggregate calls are rejected when they appear in WHERE.
#[test]
fn where_rejects_aggregates() {
    let sp = DummySchemas::new().with("t", vec![
        ("a", JsonPrimitive::Int, false),
        ("b", JsonPrimitive::Int, false),
    ]);
    // WHERE sum(b) > 10 — aggregates belong in HAVING/SELECT, not WHERE.
    let crit = Some(Predicate::Compare {
        left: ScalarExpr::Function(Function { name: "sum".into(), args: vec![ScalarExpr::Column(Column::Name { name: "b".into() })], distinct: false }),
        op: ComparatorOp::Gt,
        right: ScalarExpr::Literal(Literal::Int(10)),
    });
    let q = make_query_with_table(
        "t",
        vec![ Identifier { expression: ScalarExpr::Column(Column::Name { name: "a".into() }), alias: None } ],
        crit,
        vec![],
        None,
        vec![],
    );
    let res = AnalysisContext::analyze_query(&q, &sp, &DEFAULT_REGISTRY, Value::Null);
    assert!(res.is_err(), "aggregates in WHERE should error");
    // The error text should point at the offending clause.
    let msg = format!("{res:?}");
    assert!(msg.to_lowercase().contains("where"), "err msg should mention WHERE; got: {msg}");
}
// Verifies that aggregate calls are accepted in HAVING on a grouped query.
#[test]
fn having_allows_aggregates() {
    let sp = DummySchemas::new().with("t", vec![
        ("a", JsonPrimitive::Int, false),
        ("b", JsonPrimitive::Int, false),
    ]);
    // HAVING count(*) > 1
    let having = Some(Predicate::Compare {
        left: ScalarExpr::Function(Function { name: "count".into(), args: vec![ScalarExpr::WildCard], distinct: false }),
        op: ComparatorOp::Gt,
        right: ScalarExpr::Literal(Literal::Int(1)),
    });
    // SELECT a, count(*) FROM t GROUP BY a HAVING count(*) > 1
    let q = make_query_with_table(
        "t",
        vec![
            Identifier { expression: ScalarExpr::Column(Column::Name { name: "a".into() }), alias: None },
            Identifier { expression: ScalarExpr::Function(Function { name: "count".into(), args: vec![ScalarExpr::WildCard], distinct: false }), alias: None },
        ],
        None,
        vec![Column::Name { name: "a".into() }],
        having,
        vec![],
    );
    let res = AnalysisContext::analyze_query(&q, &sp, &DEFAULT_REGISTRY, Value::Null);
    assert!(res.is_ok(), "HAVING with aggregate should be accepted: {:?}", res.err());
}
// Verifies three ORDER BY behaviors: resolution through a projection alias,
// resolution through a 1-based positional index, and rejection of a
// non-grouped column in an aggregate query.
#[test]
fn order_by_alias_and_positional_and_validation() {
    let sp = DummySchemas::new().with("t", vec![
        ("name", JsonPrimitive::String, false),
        ("age", JsonPrimitive::Int, false),
    ]);
    // ORDER BY n ASC, 2 DESC — "n" is the alias of the first projection,
    // "2" is a positional reference to the second projection.
    let q = make_query_with_table(
        "t",
        vec![
            Identifier { expression: ScalarExpr::Column(Column::Name { name: "name".into() }), alias: Some("n".into()) },
            Identifier { expression: ScalarExpr::Column(Column::Name { name: "age".into() }), alias: None },
        ],
        None,
        vec![Column::Name { name: "name".into() }, Column::Name { name: "age".into() }],
        None,
        vec![
            OrderBy { expr: ScalarExpr::Column(Column::Name { name: "n".into() }), ascending: true },
            OrderBy { expr: ScalarExpr::Literal(Literal::Int(2)), ascending: false },
        ],
    );
    let analyzed = AnalysisContext::analyze_query(&q, &sp, &DEFAULT_REGISTRY, Value::Null).expect("analyze");
    assert_eq!(analyzed.order_by.len(), 2);
    // Alias "n" should resolve back to the fully-qualified t.name column.
    match &analyzed.order_by[0].expr {
        ScalarExpr::Column(Column::WithCollection { collection, name }) => {
            assert_eq!(name, "name");
            assert_eq!(collection, "t");
        }
        e => panic!("unexpected first order by expr: {e:?}"),
    }
    // Positional "2" should resolve to the second projection (age).
    match &analyzed.order_by[1].expr {
        ScalarExpr::Column(Column::WithCollection { name, .. }) => assert_eq!(name, "age"),
        e => panic!("unexpected second order by expr: {e:?}"),
    }
    // Aggregate query ordering by a non-grouped bare column must fail.
    let q_bad = make_query_with_table(
        "t",
        vec![ Identifier { expression: ScalarExpr::Function(Function { name: "count".into(), args: vec![ScalarExpr::WildCard], distinct: false }), alias: None } ],
        None,
        vec![Column::Name { name: "name".into() }],
        None,
        vec![ OrderBy { expr: ScalarExpr::Column(Column::Name { name: "age".into() }), ascending: true } ],
    );
    let err = AnalysisContext::analyze_query(&q_bad, &sp, &DEFAULT_REGISTRY, Value::Null);
    assert!(err.is_err(), "ORDER BY should error when referencing non-grouped columns outside aggregates");
    let msg = format!("{err:?}");
    assert!(msg.to_lowercase().contains("order by"), "err msg should mention ORDER BY; got: {msg}");
}
// Verifies that `SELECT *` over multiple tables expands deterministically:
// FROM-list order first, then each table's declared field order.
#[test]
fn wildcard_expansion_is_stable() {
    let sp = DummySchemas::new()
        .with("t1", vec![
            ("id", JsonPrimitive::Int, false),
            ("name",JsonPrimitive::String, false),
        ])
        .with("t2", vec![
            ("x", JsonPrimitive::Int, false),
        ]);
    // SELECT * FROM t1, t2
    let query = Query {
        projection: vec![ Identifier { expression: ScalarExpr::WildCard, alias: None } ],
        collections: vec![
            Collection::Table { name: "t1".into(), alias: None },
            Collection::Table { name: "t2".into(), alias: None },
        ],
        joins: vec![],
        criteria: None,
        group_by: vec![],
        having: None,
        order_by: vec![],
        ..Default::default()
    };
    let analyzed = AnalysisContext::analyze_query(&query, &sp, &DEFAULT_REGISTRY, Value::Null).expect("analyze");
    // Collect the qualified-column expansions in projection order.
    let cols: Vec<(String,String)> = analyzed.projection.iter().filter_map(|id| {
        if let ScalarExpr::Column(Column::WithCollection{collection, name}) = &id.expression {
            Some((collection.clone(), name.clone()))
        } else { None }
    }).collect();
    assert_eq!(cols, vec![
        ("t1".into(), "id".into()),
        ("t1".into(), "name".into()),
        ("t2".into(), "x".into()),
    ]);
}
// Constant-folding behavior of LIKE and IN as exercised here: these literal
// cases fold to definite truth values at analysis time.
#[test]
fn folding_like_case_insensitive_with_escape_and_in_null_unknown() {
    // 'Hello' LIKE 'he%' folds to TRUE (match is case-insensitive here).
    let like_plain = Predicate::Like {
        expr: ScalarExpr::Literal(Literal::String("Hello".into())),
        pattern: ScalarExpr::Literal(Literal::String("he%".into())),
        negated: false,
    };
    match PredicateResolver::fold_predicate(&like_plain) {
        Predicate::Const3(Truth::True) => {},
        other => panic!("expected Const3(True), got {other:?}"),
    }
    // A backslash-escaped '%' in the pattern matches a literal percent sign.
    let like_escaped = Predicate::Like {
        expr: ScalarExpr::Literal(Literal::String("he%llo".into())),
        pattern: ScalarExpr::Literal(Literal::String(r"he\%l%".into())),
        negated: false,
    };
    match PredicateResolver::fold_predicate(&like_escaped) {
        Predicate::Const3(Truth::True) => {},
        other => panic!("expected Const3(True) for escaped %, got {other:?}"),
    }
    // Three-valued logic: `2 IN (1, NULL)` folds to UNKNOWN, not FALSE.
    let in_with_null = Predicate::InList {
        expr: ScalarExpr::Literal(Literal::Int(2)),
        list: vec![ScalarExpr::Literal(Literal::Int(1)), ScalarExpr::Literal(Literal::Null)],
        negated: false,
    };
    match PredicateResolver::fold_predicate(&in_with_null) {
        Predicate::Const3(Truth::Unknown) => {},
        other => panic!("expected Const3(Unknown) for IN with NULL, got {other:?}"),
    }
}
// Verifies the inferred (type, nullability) of common aggregates as
// asserted below: count(*) -> non-null Int; sum keeps its argument's numeric
// type but is nullable; avg(int) -> nullable Float; min(string) -> nullable
// String.
#[test]
fn type_inference_for_aggregates() {
    let sp = DummySchemas::new().with("t", vec![
        ("i", JsonPrimitive::Int, false),
        ("f", JsonPrimitive::Float, false),
        ("s", JsonPrimitive::String,false),
    ]);
    // An empty projection is fine: only the FROM-list context is needed.
    let q_base = Query {
        projection: vec![],
        collections: vec![Collection::Table { name: "t".into(), alias: None }],
        joins: vec![],
        criteria: None,
        group_by: vec![],
        having: None,
        order_by: vec![],
        ..Default::default()
    };
    let ctx = simple_ctx_for(&q_base, &sp);
    // count(*): Int, not nullable.
    let cnt = ScalarExpr::Function(Function { name: "count".into(), args: vec![ScalarExpr::WildCard], distinct: false });
    let (ty, nullable) = TypeInference::infer_scalar(&cnt, &ctx).expect("type");
    assert_eq!(ty, JsonPrimitive::Int);
    assert!(!nullable);
    // sum(int): Int, nullable.
    let sum_i = ScalarExpr::Function(Function { name: "sum".into(), args: vec![ScalarExpr::Column(Column::Name { name: "i".into() })], distinct: false });
    let (ty, nullable) = TypeInference::infer_scalar(&sum_i, &ctx).expect("type");
    assert_eq!(ty, JsonPrimitive::Int);
    assert!(nullable);
    // sum(float): Float, nullable.
    let sum_f = ScalarExpr::Function(Function { name: "sum".into(), args: vec![ScalarExpr::Column(Column::Name { name: "f".into() })], distinct: false });
    let (ty, nullable) = TypeInference::infer_scalar(&sum_f, &ctx).expect("type");
    assert_eq!(ty, JsonPrimitive::Float);
    assert!(nullable);
    // avg(int): widens to Float, nullable.
    let avg_i = ScalarExpr::Function(Function { name: "avg".into(), args: vec![ScalarExpr::Column(Column::Name { name: "i".into() })], distinct: false });
    let (ty, nullable) = TypeInference::infer_scalar(&avg_i, &ctx).expect("type");
    assert_eq!(ty, JsonPrimitive::Float);
    assert!(nullable);
    // min(string): preserves the argument type, nullable.
    let min_s = ScalarExpr::Function(Function { name: "min".into(), args: vec![ScalarExpr::Column(Column::Name { name: "s".into() })], distinct: false });
    let (ty, nullable) = TypeInference::infer_scalar(&min_s, &ctx).expect("type");
    assert_eq!(ty, JsonPrimitive::String);
    assert!(nullable);
}
}