use std::collections::HashSet;
use super::{
Ast, Catalog, ProjectionItem, QueryError, Result, ScalarComparison, TypedAst, WhereClause,
};
/// Performs semantic validation of a parsed query and wraps it in a [`TypedAst`].
///
/// Checks, in order:
/// 1. MATCH declares at least one alias and no alias is repeated.
/// 2. The WHERE clause (if any) is well-formed; a scalar comparison, if present, is
///    extracted into `TypedAst::scalar_predicate`.
/// 3. RETURN has at least one projection.
/// 4. WITH projections (if any) reference MATCH aliases. Their output names must be unique.
///    They *replace* the MATCH aliases as the scope visible to RETURN.
/// 5. RETURN projections reference aliases in the resulting scope.
///
/// `_catalog` is currently unused.
///
/// # Errors
/// Returns `QueryError::InvalidSemantic` describing the first violated rule.
pub(super) fn validate(ast: &Ast, _catalog: &Catalog) -> Result<TypedAst> {
    if ast.match_aliases.is_empty() {
        return Err(QueryError::InvalidSemantic(
            "MATCH must contain at least one alias".to_string(),
        ));
    }
    validate_unique_aliases(&ast.match_aliases)?;
    let scalar_predicate = validate_where_clause(ast.where_clause.as_ref())?;
    if ast.return_items.is_empty() {
        return Err(QueryError::InvalidSemantic(
            "RETURN must contain at least one projection".to_string(),
        ));
    }
    let match_scope: HashSet<String> = ast.match_aliases.iter().cloned().collect();
    let scope_aliases: HashSet<String> = if ast.with_items.is_empty() {
        match_scope
    } else {
        validate_projection_items(&ast.with_items, &match_scope, true, "WITH")?;
        // WITH is a projection barrier: RETURN may only see the names WITH produced.
        let mut with_scope: HashSet<String> = HashSet::with_capacity(ast.with_items.len());
        for item in &ast.with_items {
            let output = item.output_name().ok_or_else(|| {
                QueryError::InvalidSemantic(
                    "WITH function projections must use AS alias".to_string(),
                )
            })?;
            // Two WITH columns with the same name would otherwise be silently merged,
            // mirroring the duplicate-alias rule already enforced for MATCH.
            if with_scope.contains(&output) {
                return Err(QueryError::InvalidSemantic(format!(
                    "duplicate WITH alias '{}'",
                    output
                )));
            }
            with_scope.insert(output);
        }
        with_scope
    };
    validate_projection_items(&ast.return_items, &scope_aliases, false, "RETURN")?;
    Ok(TypedAst {
        ast: ast.clone(),
        scalar_predicate,
    })
}
/// Ensures no MATCH alias is declared more than once.
///
/// Membership is tracked with borrowed `&str` keys, so no alias strings are cloned.
///
/// # Errors
/// Returns `QueryError::InvalidSemantic` naming the first alias that repeats.
fn validate_unique_aliases(aliases: &[String]) -> Result<()> {
    let mut seen: HashSet<&str> = HashSet::with_capacity(aliases.len());
    for alias in aliases {
        if !seen.insert(alias.as_str()) {
            return Err(QueryError::InvalidSemantic(format!(
                "duplicate MATCH alias '{}'",
                alias
            )));
        }
    }
    Ok(())
}
/// Validates the optional WHERE clause and extracts a scalar predicate from it.
///
/// * `None`: there is no WHERE clause, so the result is `Ok(None)`.
/// * `WhereClause::Function`: two forms are accepted.
///   - `vector.cosine` / `vector.euclidean` with a non-empty target and a `$name` parameter.
///   - Any `bitmap.*` function with a non-empty index name and value key.
///
///   Both forms are only shape-checked here and yield `Ok(None)`.
/// * `WhereClause::Comparison`: the left side must be exactly `alias.field`, with `field` one
///   of the supported scalar fields. The right side must parse as a non-NaN `f64`. Yields
///   `Ok(Some(ScalarComparison))`.
///
/// # Errors
/// Returns `QueryError::InvalidSemantic` describing the first malformed part.
fn validate_where_clause(where_clause: Option<&WhereClause>) -> Result<Option<ScalarComparison>> {
    match where_clause {
        None => Ok(None),
        // `operator` / `threshold` are not inspected during semantic validation.
        Some(WhereClause::Function { function, args, .. }) => {
            if function.starts_with("vector.") {
                if !matches!(function.as_str(), "vector.cosine" | "vector.euclidean") {
                    return Err(QueryError::InvalidSemantic(format!(
                        "unsupported vector function '{}'; supported: vector.cosine, vector.euclidean",
                        function
                    )));
                }
                if args.len() != 2 {
                    return Err(QueryError::InvalidSemantic(
                        "vector predicate requires target and parameter".to_string(),
                    ));
                }
                if args[0].trim().is_empty() {
                    return Err(QueryError::InvalidSemantic(
                        "vector predicate requires non-empty target".to_string(),
                    ));
                }
                // A bare `$` names no parameter at all, so require a name after the sigil.
                match args[1].strip_prefix('$') {
                    None => {
                        return Err(QueryError::InvalidSemantic(
                            "vector predicate parameter must start with '$'".to_string(),
                        ));
                    }
                    Some(name) if name.is_empty() => {
                        return Err(QueryError::InvalidSemantic(
                            "vector predicate parameter name must not be empty".to_string(),
                        ));
                    }
                    Some(_) => {}
                }
                Ok(None)
            } else if function.starts_with("bitmap.") {
                if args.len() != 2 {
                    return Err(QueryError::InvalidSemantic(
                        "bitmap predicate requires index name and value key".to_string(),
                    ));
                }
                if args[0].trim().is_empty() || args[1].trim().is_empty() {
                    return Err(QueryError::InvalidSemantic(
                        "bitmap predicate requires non-empty index name and value key".to_string(),
                    ));
                }
                Ok(None)
            } else {
                Err(QueryError::InvalidSemantic(
                    "WHERE function predicates must use vector.* or bitmap.*".to_string(),
                ))
            }
        }
        Some(WhereClause::Comparison {
            left,
            operator,
            right,
        }) => {
            // Split exactly once and reject extra segments, so `n.score.extra` is not
            // silently read as `score` and `.score` (no alias) is not accepted.
            let field = match left.split_once('.') {
                Some((alias, field))
                    if !alias.trim().is_empty() && !field.is_empty() && !field.contains('.') =>
                {
                    field
                }
                _ => {
                    return Err(QueryError::InvalidSemantic(format!(
                        "scalar WHERE left side must be alias.field, got '{}'",
                        left
                    )));
                }
            };
            const SUPPORTED: &[&str] = &["adjacency_degree", "delta_count", "has_full", "score"];
            if !SUPPORTED.contains(&field) {
                // Build the list from SUPPORTED so the message cannot drift from the check.
                return Err(QueryError::InvalidSemantic(format!(
                    "unsupported scalar field '{}'; supported: {}",
                    field,
                    SUPPORTED.join(", ")
                )));
            }
            // `f64::from_str` accepts "NaN", but IEEE comparisons with NaN are false (except
            // `!=`). Reject it together with unparsable input.
            let value = right
                .parse::<f64>()
                .ok()
                .filter(|v| !v.is_nan())
                .ok_or_else(|| {
                    QueryError::InvalidSemantic(format!(
                        "scalar WHERE value '{}' is not a valid number",
                        right
                    ))
                })?;
            Ok(Some(ScalarComparison {
                field: field.to_string(),
                operator: operator.clone(),
                value,
            }))
        }
    }
}
/// Validates every projection item of a WITH or RETURN clause.
///
/// * `items`: the projection list to check.
/// * `scope_aliases`: the aliases that identifiers and function arguments may reference.
/// * `require_function_alias`: when `true`, every function projection must carry `AS alias`.
/// * `clause_name`: the clause label (e.g. `"WITH"`, `"RETURN"`) used in error messages.
///
/// Only `count`, `sum` and `collect` are accepted (case-insensitively), and `*` is accepted
/// only as the argument of `count`. An explicit alias must not be blank.
///
/// # Errors
/// Returns `QueryError::InvalidSemantic` for the first invalid item.
fn validate_projection_items(
    items: &[ProjectionItem],
    scope_aliases: &HashSet<String>,
    require_function_alias: bool,
    clause_name: &str,
) -> Result<()> {
    for item in items {
        match item {
            ProjectionItem::Identifier(value) => {
                validate_identifier_reference(value, scope_aliases, clause_name)?;
            }
            ProjectionItem::Function {
                name,
                argument,
                alias,
            } => {
                let lower = name.to_ascii_lowercase();
                if !matches!(lower.as_str(), "count" | "sum" | "collect") {
                    return Err(QueryError::InvalidSemantic(format!(
                        "{} function '{}' is not supported",
                        clause_name, name
                    )));
                }
                // Handle `*` before alias resolution. Otherwise `sum(*)` fails with the
                // misleading "references unknown alias '*'" instead of this error.
                if argument == "*" {
                    if lower != "count" {
                        return Err(QueryError::InvalidSemantic(format!(
                            "{} function '{}' does not support '*'",
                            clause_name, name
                        )));
                    }
                } else {
                    validate_identifier_reference(argument, scope_aliases, clause_name)?;
                }
                if require_function_alias && alias.is_none() {
                    return Err(QueryError::InvalidSemantic(format!(
                        "{} function projections must use AS alias",
                        clause_name
                    )));
                }
                if matches!(alias, Some(a) if a.trim().is_empty()) {
                    return Err(QueryError::InvalidSemantic(format!(
                        "{} projection alias must not be empty",
                        clause_name
                    )));
                }
            }
        }
    }
    Ok(())
}
/// Checks that an identifier such as `n` or `n.prop` is rooted at an alias in scope.
///
/// The root is the text before the first `.` (or the whole value when there is no `.`),
/// with surrounding whitespace trimmed. Anything after the root is not inspected.
///
/// # Errors
/// Returns `QueryError::InvalidSemantic` if the root is empty or whitespace-only, or if
/// it is not contained in `scope_aliases`. `clause_name` prefixes the message.
fn validate_identifier_reference(
    value: &str,
    scope_aliases: &HashSet<String>,
    clause_name: &str,
) -> Result<()> {
    let head = value
        .split_once('.')
        .map_or(value, |(before_dot, _)| before_dot)
        .trim();
    if head.is_empty() {
        return Err(QueryError::InvalidSemantic(format!(
            "{} projection has invalid identifier '{}'",
            clause_name, value
        )));
    }
    if scope_aliases.contains(head) {
        Ok(())
    } else {
        Err(QueryError::InvalidSemantic(format!(
            "{} references unknown alias '{}'",
            clause_name, head
        )))
    }
}