iridium-db 0.2.0

A high-performance vector-graph hybrid storage and indexing engine
use std::collections::HashSet;

use super::{
    Ast, Catalog, ProjectionItem, QueryError, Result, ScalarComparison, TypedAst, WhereClause,
};

pub(super) fn validate(ast: &Ast, _catalog: &Catalog) -> Result<TypedAst> {
    if ast.match_aliases.is_empty() {
        return Err(QueryError::InvalidSemantic(
            "MATCH must contain at least one alias".to_string(),
        ));
    }

    validate_unique_aliases(&ast.match_aliases)?;
    let scalar_predicate = validate_where_clause(ast.where_clause.as_ref())?;

    if ast.return_items.is_empty() {
        return Err(QueryError::InvalidSemantic(
            "RETURN must contain at least one projection".to_string(),
        ));
    }

    let mut scope_aliases: HashSet<String> = ast.match_aliases.iter().cloned().collect();
    if !ast.with_items.is_empty() {
        validate_projection_items(&ast.with_items, &scope_aliases, true, "WITH")?;
        scope_aliases.clear();
        for item in &ast.with_items {
            let output = item.output_name().ok_or_else(|| {
                QueryError::InvalidSemantic(
                    "WITH function projections must use AS alias".to_string(),
                )
            })?;
            scope_aliases.insert(output);
        }
    }

    validate_projection_items(&ast.return_items, &scope_aliases, false, "RETURN")?;
    Ok(TypedAst {
        ast: ast.clone(),
        scalar_predicate,
    })
}

/// Rejects a `MATCH` alias list that names the same alias more than once.
///
/// # Errors
/// Returns `QueryError::InvalidSemantic` naming the first alias that appears a
/// second time while scanning left to right.
fn validate_unique_aliases(aliases: &[String]) -> Result<()> {
    let mut seen: HashSet<&str> = HashSet::with_capacity(aliases.len());
    // `insert` returns false when the value was already present, i.e. a repeat.
    let repeated = aliases.iter().find(|&alias| !seen.insert(alias.as_str()));
    match repeated {
        Some(alias) => Err(QueryError::InvalidSemantic(format!(
            "duplicate MATCH alias '{}'",
            alias
        ))),
        None => Ok(()),
    }
}

/// Validates the optional WHERE clause and extracts its scalar predicate, if any.
///
/// * `None` yields `Ok(None)`.
/// * `WhereClause::Function` is accepted in two forms:
///   - `vector.cosine` / `vector.euclidean` with exactly two arguments: a non-empty
///     target and a `$name` parameter.
///   - any `bitmap.*` function with exactly two arguments: a non-empty index name
///     and a non-empty value key.
///
///   Function predicates are only checked for shape and always yield `Ok(None)`.
/// * `WhereClause::Comparison` must have the form `alias.field <op> <number>`.
///   `field` must be one of the supported scalar fields and the number must be
///   finite. It yields the matching `ScalarComparison`.
///
/// NOTE(review): this function receives no alias scope, so the alias used in the
/// predicate is not checked against the MATCH aliases here.
///
/// # Errors
/// Returns `QueryError::InvalidSemantic` describing the first violated rule.
fn validate_where_clause(where_clause: Option<&WhereClause>) -> Result<Option<ScalarComparison>> {
    match where_clause {
        None => Ok(None),
        Some(WhereClause::Function { function, args, .. }) => {
            if function.starts_with("vector.") {
                if !matches!(function.as_str(), "vector.cosine" | "vector.euclidean") {
                    return Err(QueryError::InvalidSemantic(format!(
                        "unsupported vector function '{}'; supported: vector.cosine, vector.euclidean",
                        function
                    )));
                }
                if args.len() != 2 {
                    return Err(QueryError::InvalidSemantic(
                        "vector predicate requires target and parameter".to_string(),
                    ));
                }
                if args[0].trim().is_empty() {
                    return Err(QueryError::InvalidSemantic(
                        "vector predicate requires a non-empty target".to_string(),
                    ));
                }
                match args[1].strip_prefix('$') {
                    None => {
                        return Err(QueryError::InvalidSemantic(
                            "vector predicate parameter must start with '$'".to_string(),
                        ));
                    }
                    // A bare `$` names no parameter, so nothing could ever be bound to it.
                    Some(param) if param.trim().is_empty() => {
                        return Err(QueryError::InvalidSemantic(
                            "vector predicate parameter must have a name after '$'".to_string(),
                        ));
                    }
                    Some(_) => {}
                }
                Ok(None)
            } else if function.starts_with("bitmap.") {
                if args.len() != 2 {
                    return Err(QueryError::InvalidSemantic(
                        "bitmap predicate requires index name and value key".to_string(),
                    ));
                }
                if args[0].trim().is_empty() || args[1].trim().is_empty() {
                    return Err(QueryError::InvalidSemantic(
                        "bitmap predicate requires non-empty index name and value key".to_string(),
                    ));
                }
                Ok(None)
            } else {
                Err(QueryError::InvalidSemantic(
                    "WHERE function predicates must use vector.* or bitmap.*".to_string(),
                ))
            }
        }
        Some(WhereClause::Comparison {
            left,
            operator,
            right,
        }) => {
            // Require exactly `alias.field`: a non-empty alias and a single '.'.
            // `split('.').nth(1)` used to accept `.score` and silently drop
            // trailing segments such as `n.score.extra`.
            let field = match left.split_once('.') {
                Some((alias, field)) if !alias.trim().is_empty() && !field.contains('.') => {
                    field
                }
                _ => {
                    return Err(QueryError::InvalidSemantic(format!(
                        "scalar WHERE left side must be alias.field, got '{}'",
                        left
                    )));
                }
            };
            const SUPPORTED: &[&str] = &["adjacency_degree", "delta_count", "has_full", "score"];
            if !SUPPORTED.contains(&field) {
                return Err(QueryError::InvalidSemantic(format!(
                    "unsupported scalar field '{}'; supported: adjacency_degree, delta_count, has_full, score",
                    field
                )));
            }
            // Parsing an `f64` also accepts "NaN", "inf" and out-of-range literals
            // (which become infinite). NaN never compares true and infinities make
            // the comparison trivial, so only finite thresholds are valid numbers.
            let value = right
                .parse::<f64>()
                .ok()
                .filter(|v| v.is_finite())
                .ok_or_else(|| {
                    QueryError::InvalidSemantic(format!(
                        "scalar WHERE value '{}' is not a valid number",
                        right
                    ))
                })?;
            Ok(Some(ScalarComparison {
                field: field.to_string(),
                operator: operator.clone(),
                value,
            }))
        }
    }
}

/// Validates every item of a `WITH` or `RETURN` projection list.
///
/// * Identifier items must start with an alias present in `scope_aliases`.
/// * Function items must be `count`, `sum` or `collect` (matched case-insensitively).
///   Only `count` accepts `*`. Every other argument must start with an alias in
///   scope.
/// * When `require_function_alias` is true, function items must carry an `AS` alias.
///   Any alias that is given must not be blank.
///
/// `clause_name` (e.g. `"WITH"`, `"RETURN"`) only prefixes error messages.
///
/// # Errors
/// Returns `QueryError::InvalidSemantic` for the first item that violates a rule.
fn validate_projection_items(
    items: &[ProjectionItem],
    scope_aliases: &HashSet<String>,
    require_function_alias: bool,
    clause_name: &str,
) -> Result<()> {
    for item in items {
        match item {
            ProjectionItem::Identifier(value) => {
                validate_identifier_reference(value, scope_aliases, clause_name)?;
            }
            ProjectionItem::Function {
                name,
                argument,
                alias,
            } => {
                let lower = name.to_ascii_lowercase();
                if !matches!(lower.as_str(), "count" | "sum" | "collect") {
                    return Err(QueryError::InvalidSemantic(format!(
                        "{} function '{}' is not supported",
                        clause_name, name
                    )));
                }
                // Handle `*` before alias resolution. Previously `sum(*)` was first
                // resolved as an identifier and failed with "unknown alias '*'",
                // so the dedicated error below could never be produced.
                if argument == "*" {
                    if lower != "count" {
                        return Err(QueryError::InvalidSemantic(format!(
                            "{} function '{}' does not support '*'",
                            clause_name, name
                        )));
                    }
                } else {
                    validate_identifier_reference(argument, scope_aliases, clause_name)?;
                }
                if require_function_alias && alias.is_none() {
                    return Err(QueryError::InvalidSemantic(format!(
                        "{} function projections must use AS alias",
                        clause_name
                    )));
                }
                if let Some(alias) = alias {
                    if alias.trim().is_empty() {
                        return Err(QueryError::InvalidSemantic(format!(
                            "{} projection alias must not be empty",
                            clause_name
                        )));
                    }
                }
            }
        }
    }
    Ok(())
}

/// Checks that a projected identifier starts with an alias that is in scope.
///
/// Only the segment before the first `.` is inspected, after trimming
/// whitespace. For example, `n` in `n.name`. The rest of the identifier is not
/// validated here.
///
/// # Errors
/// Returns `QueryError::InvalidSemantic` in two cases: the leading segment is
/// empty, or it is not one of `scope_aliases`.
fn validate_identifier_reference(
    value: &str,
    scope_aliases: &HashSet<String>,
    clause_name: &str,
) -> Result<()> {
    let head = match value.split_once('.') {
        Some((head, _rest)) => head,
        None => value,
    };
    let alias = head.trim();

    if alias.is_empty() {
        Err(QueryError::InvalidSemantic(format!(
            "{} projection has invalid identifier '{}'",
            clause_name, value
        )))
    } else if scope_aliases.contains(alias) {
        Ok(())
    } else {
        Err(QueryError::InvalidSemantic(format!(
            "{} references unknown alias '{}'",
            clause_name, alias
        )))
    }
}