use regex::Regex;
use std::sync::OnceLock;
/// Stateless entry point for Cypher query checks.
///
/// The type holds no data; validation and fix suggestions are exposed as
/// associated functions (e.g. `CypherValidator::validate`).
pub struct CypherValidator;
/// Process-wide cache for the compiled validation regexes.
///
/// Starts empty and is filled on the first call to `ValidationPatterns::get`.
static PATTERNS: OnceLock<ValidationPatterns> = OnceLock::new();
/// Compiled regular expressions used by `CypherValidator::validate`.
struct ValidationPatterns {
/// Matches when the query contains at least one recognised Cypher keyword.
basic_cypher: Regex,
/// Matches operations the validator rejects as dangerous (`DROP` / `DELETE`).
dangerous_ops: Regex,
/// Matches the presence of a `MATCH` clause; its absence produces a warning.
match_clause: Regex,
/// Matches the presence of a `RETURN` clause; its absence produces a warning.
return_clause: Regex,
}
impl ValidationPatterns {
    /// Returns the shared pattern set, compiling the regexes on the first call.
    ///
    /// Every pattern is case-insensitive and anchored on word boundaries, so a
    /// keyword only counts when it appears as a whole word. Without the
    /// anchors, identifiers such as `RESET`, `backdrop` or `soft_delete` were
    /// mistaken for the `SET`, `DROP` and `DELETE` keywords. The
    /// dangerous-operation pattern also no longer requires trailing
    /// whitespace, so forms like `DELETE(n)` are caught as well.
    ///
    /// # Panics
    ///
    /// Panics only if one of the hard-coded patterns fails to compile, which
    /// would be a programming error in this function.
    fn get() -> &'static Self {
        PATTERNS.get_or_init(|| Self {
            basic_cypher: Regex::new(
                r"(?i)\b(MATCH|CREATE|MERGE|DELETE|SET|REMOVE|RETURN|WITH|UNWIND|CALL)\b",
            )
            .expect("hard-coded Cypher keyword pattern is a valid regex"),
            dangerous_ops: Regex::new(r"(?i)\b(DROP|DELETE)\b")
                .expect("hard-coded dangerous-operation pattern is a valid regex"),
            match_clause: Regex::new(r"(?i)\bMATCH\b")
                .expect("hard-coded MATCH pattern is a valid regex"),
            return_clause: Regex::new(r"(?i)\bRETURN\b")
                .expect("hard-coded RETURN pattern is a valid regex"),
        })
    }
}
/// Outcome of `CypherValidator::validate`.
#[derive(Debug, Clone)]
pub struct ValidationResult {
/// `true` when no errors were recorded; warnings do not affect this flag.
pub is_valid: bool,
/// Human-readable descriptions of problems that make the query invalid.
pub errors: Vec<String>,
/// Human-readable notes about suspicious but non-fatal aspects of the query.
pub warnings: Vec<String>,
}
impl CypherValidator {
/// Runs lightweight regex- and character-based checks over a Cypher query.
///
/// The query is trimmed first. An empty query short-circuits with a single
/// error. Otherwise the checks run in this order:
///
/// * error if the Cypher keyword pattern does not match;
/// * error if the dangerous-operation pattern matches;
/// * warning if there is no `MATCH` clause, unless the query starts with
///   `CREATE`, `MERGE`, `CALL` or `UNWIND` (compared case-insensitively);
/// * warning if there is no `RETURN` clause;
/// * error for unbalanced parentheses, then for unbalanced brackets.
///
/// The returned `is_valid` is `true` exactly when no error was recorded.
#[must_use]
pub fn validate(query: &str) -> ValidationResult {
    // Leading keywords for which a missing MATCH clause is not worth a warning.
    const NON_MATCH_OPENERS: [&str; 4] = ["CREATE", "MERGE", "CALL", "UNWIND"];

    let trimmed = query.trim();
    let patterns = ValidationPatterns::get();

    if trimmed.is_empty() {
        return ValidationResult {
            is_valid: false,
            errors: vec![String::from("Query is empty")],
            warnings: Vec::new(),
        };
    }

    let mut errors: Vec<String> = Vec::new();
    let mut warnings: Vec<String> = Vec::new();

    if !patterns.basic_cypher.is_match(trimmed) {
        errors.push(String::from("Query does not contain valid Cypher keywords"));
    }
    if patterns.dangerous_ops.is_match(trimmed) {
        errors.push(String::from(
            "Query contains potentially dangerous operations (DROP, DELETE ALL)",
        ));
    }

    let uppercased = trimmed.to_uppercase();
    let opens_without_match = NON_MATCH_OPENERS
        .iter()
        .any(|keyword| uppercased.starts_with(*keyword));
    if !(patterns.match_clause.is_match(trimmed) || opens_without_match) {
        warnings.push(String::from("Query does not contain a MATCH clause"));
    }
    if !patterns.return_clause.is_match(trimmed) {
        warnings.push(String::from("Query does not contain a RETURN clause"));
    }

    if !Self::check_balanced_parentheses(trimmed) {
        errors.push(String::from("Unbalanced parentheses in query"));
    }
    if !Self::check_balanced_brackets(trimmed) {
        errors.push(String::from("Unbalanced brackets in query"));
    }

    ValidationResult {
        is_valid: errors.is_empty(),
        errors,
        warnings,
    }
}
/// Returns `true` when every `(` in `query` is closed by a later `)` and no
/// `)` appears before its opener.
///
/// Characters inside quoted spans are ignored: single- and double-quoted
/// string literals (with `\` escaping the next character) and backtick-quoted
/// identifiers. A parenthesis inside a literal such as `':)'` therefore
/// neither breaks a valid query nor masks a missing one in the real syntax.
fn check_balanced_parentheses(query: &str) -> bool {
    let mut depth = 0usize;
    // Delimiter of the quoted span currently being skipped, if any.
    let mut quote: Option<char> = None;
    let mut escaped = false;

    for c in query.chars() {
        if let Some(delimiter) = quote {
            if escaped {
                escaped = false;
            } else if c == '\\' && delimiter != '`' {
                // Backslash escapes apply to string literals only; backtick
                // identifiers escape a backtick by doubling it, which the
                // close-then-reopen sequence below handles naturally.
                escaped = true;
            } else if c == delimiter {
                quote = None;
            }
            continue;
        }

        match c {
            '\'' | '"' | '`' => quote = Some(c),
            '(' => depth += 1,
            ')' => match depth.checked_sub(1) {
                Some(remaining) => depth = remaining,
                // A closer with no matching opener.
                None => return false,
            },
            _ => {}
        }
    }

    depth == 0
}
/// Returns `true` when every `[` in `query` is closed by a later `]` and no
/// `]` appears before its opener.
///
/// Characters inside quoted spans are ignored: single- and double-quoted
/// string literals (with `\` escaping the next character) and backtick-quoted
/// identifiers. A bracket inside a literal such as `'['` therefore neither
/// breaks a valid query nor masks a missing one in the real syntax.
fn check_balanced_brackets(query: &str) -> bool {
    let mut depth = 0usize;
    // Delimiter of the quoted span currently being skipped, if any.
    let mut quote: Option<char> = None;
    let mut escaped = false;

    for c in query.chars() {
        if let Some(delimiter) = quote {
            if escaped {
                escaped = false;
            } else if c == '\\' && delimiter != '`' {
                // Backslash escapes apply to string literals only; backtick
                // identifiers escape a backtick by doubling it, which the
                // close-then-reopen sequence below handles naturally.
                escaped = true;
            } else if c == delimiter {
                quote = None;
            }
            continue;
        }

        match c {
            '\'' | '"' | '`' => quote = Some(c),
            '[' => depth += 1,
            ']' => match depth.checked_sub(1) {
                Some(remaining) => depth = remaining,
                // A closer with no matching opener.
                None => return false,
            },
            _ => {}
        }
    }

    depth == 0
}
#[allow(dead_code)]
pub fn suggest_fix(
query: &str,
error: &str,
) -> Option<String> {
let query = query.trim();
let error_lower = error.to_lowercase();
if error_lower.contains("syntax error") || error_lower.contains("invalid syntax") {
if !query.to_uppercase().contains("RETURN") {
return Some(format!("{query}\nRETURN *"));
}
}
if (error_lower.contains("syntax error") || error_lower.contains("invalid syntax"))
&& query.contains('=')
&& !query.to_uppercase().contains("WHERE")
&& query.to_uppercase().contains("MATCH")
&& let Some(fixed) = Self::try_add_where_clause(query)
{
return Some(fixed);
}
if error_lower.contains("property") && error_lower.contains("not found") {
tracing::info!(
"Property not found error, consider checking schema or using toLower() for case-insensitive matching"
);
}
None
}
fn try_add_where_clause(query: &str) -> Option<String> {
let re = Regex::new(r"(?i)(MATCH\s+\([^)]+\))\s+([a-zA-Z_][a-zA-Z0-9_]*\.[a-zA-Z_][a-zA-Z0-9_]*\s*=)").ok()?;
if let Some(caps) = re.captures(query) {
let _match_part = caps.get(1)?.as_str();
let condition_start = caps.get(2)?.start();
let before = &query[..condition_start];
let after = &query[condition_start..];
return Some(format!("{before}\nWHERE {after}"));
}
None
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed read query passes with no errors.
    #[test]
    fn test_valid_query() {
        let result = CypherValidator::validate("MATCH (n:Person) WHERE n.name = 'John' RETURN n");
        assert!(result.is_valid, "Query should be valid");
        assert!(result.errors.is_empty(), "Should have no errors");
    }

    /// An empty string is rejected and reports at least one error.
    #[test]
    fn test_empty_query() {
        let result = CypherValidator::validate("");
        assert!(!result.is_valid, "Empty query should be invalid");
        assert!(!result.errors.is_empty(), "Should have errors");
    }

    /// A node pattern that is never closed makes the query invalid.
    #[test]
    fn test_unbalanced_parentheses() {
        let result = CypherValidator::validate("MATCH (n:Person WHERE n.name = 'John' RETURN n");
        assert!(
            !result.is_valid,
            "Query with unbalanced parentheses should be invalid"
        );
    }

    /// `DROP` is treated as a dangerous operation.
    #[test]
    fn test_dangerous_operations() {
        let result = CypherValidator::validate("MATCH (n) DROP n");
        assert!(!result.is_valid, "Query with DROP should be invalid");
    }

    /// Direct checks of the parenthesis counter on minimal inputs.
    #[test]
    fn test_balanced_parentheses() {
        for balanced in ["()", "(())", "(()())"] {
            assert!(CypherValidator::check_balanced_parentheses(balanced));
        }
        for unbalanced in ["(()", "())"] {
            assert!(!CypherValidator::check_balanced_parentheses(unbalanced));
        }
    }
}