pub mod context;
pub mod error;
pub mod registry;
pub mod rule;
pub mod rules;
pub mod walker;
use context::{RewriteContext, RewriteStats};
use error::RewriteError;
use registry::RewriteRegistry;
use walker::ExpressionWalker;
use std::sync::OnceLock;
use uni_cypher::ast::{Expr, Statement};
/// Process-wide rewrite rule registry, populated lazily on first access.
static GLOBAL_REGISTRY: OnceLock<RewriteRegistry> = OnceLock::new();

/// Returns the shared [`RewriteRegistry`], building it from the built-in rules
/// the first time it is requested.
///
/// `OnceLock::get_or_init` runs the initializer at most once, so the
/// initialization log line is emitted only on first use.
fn get_or_init_registry() -> &'static RewriteRegistry {
    GLOBAL_REGISTRY.get_or_init(init_builtin_registry)
}

/// Builds a registry preloaded with every built-in rewrite rule.
fn init_builtin_registry() -> RewriteRegistry {
    tracing::info!("Initializing query rewrite framework");
    RewriteRegistry::with_builtin_rules()
}
/// Emits a `tracing` summary of a completed rewrite pass.
///
/// Nothing is logged when the pass visited no function calls. Otherwise the
/// visited/rewritten/skipped counts are logged at `info`, and any recorded
/// errors are logged separately at `debug`.
fn log_rewrite_stats(stats: &RewriteStats) {
    let visited_any = stats.functions_visited > 0;
    if !visited_any {
        return;
    }

    tracing::info!(
        "Rewrite pass complete: {} functions visited, {} rewritten, {} skipped",
        stats.functions_visited,
        stats.functions_rewritten,
        stats.functions_skipped
    );

    let has_errors = !stats.errors.is_empty();
    if has_errors {
        tracing::debug!("Rewrite errors: {:?}", stats.errors);
    }
}
/// Applies the built-in rewrite rules to a whole query.
///
/// A fresh default [`RewriteContext`] is used for each call. After the
/// walker finishes, its statistics are logged via `log_rewrite_stats`.
///
/// # Errors
///
/// This implementation always returns `Ok`; the `Result` is part of the
/// public signature.
pub fn rewrite_query(
    query: uni_cypher::ast::Query,
) -> Result<uni_cypher::ast::Query, RewriteError> {
    let mut walker = ExpressionWalker::new(get_or_init_registry(), RewriteContext::default());
    let output = walker.rewrite_query(query);
    log_rewrite_stats(&walker.context().stats);
    Ok(output)
}
/// Applies the built-in rewrite rules to a single Cypher [`Statement`].
///
/// Each call uses its own default [`RewriteContext`]. The walker's
/// statistics are logged via `log_rewrite_stats` once the pass completes.
///
/// # Errors
///
/// This implementation always returns `Ok`; the `Result` is part of the
/// public signature.
pub fn rewrite_statement(stmt: Statement) -> Result<Statement, RewriteError> {
    let mut walker = ExpressionWalker::new(get_or_init_registry(), RewriteContext::default());
    let output = walker.rewrite_statement(stmt);
    log_rewrite_stats(&walker.context().stats);
    Ok(output)
}
pub fn rewrite_expr(expr: Expr) -> Result<Expr, RewriteError> {
let registry = get_or_init_registry();
let context = RewriteContext::default();
let mut walker = ExpressionWalker::new(registry, context);
Ok(walker.rewrite_expr(expr))
}
/// Applies the built-in rewrite rules to `expr` using a caller-supplied context.
///
/// The context is threaded through the walker and handed back afterwards, so
/// the caller can inspect whatever the pass recorded (for example
/// `context.stats`). Unlike [`rewrite_query`], this function does not log the
/// stats itself.
///
/// # Errors
///
/// This implementation always returns `Ok`; the `Result` is part of the
/// public signature.
pub fn rewrite_expr_with_context(
    expr: Expr,
    context: RewriteContext,
) -> Result<(Expr, RewriteContext), RewriteError> {
    let mut walker = ExpressionWalker::new(get_or_init_registry(), context);
    let output = walker.rewrite_expr(expr);
    Ok((output, walker.into_context()))
}
pub fn get_stats() -> RewriteStats {
RewriteStats::default()
}
/// Reports whether the global registry has a rewrite rule registered under
/// `function_name`.
pub fn has_rewrite_rule(function_name: &str) -> bool {
    get_or_init_registry().has_rule(function_name)
}
/// Lists the function names for which the global registry holds a rewrite rule.
pub fn registered_functions() -> Vec<String> {
    get_or_init_registry().registered_functions()
}
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// Integer literal `42`, used as an input that no rule should rewrite.
    fn plain_literal() -> Expr {
        Expr::Literal(CypherLiteral::Integer(42))
    }

    #[test]
    fn test_rewrite_expr_basic() {
        let input = plain_literal();
        let output = rewrite_expr(input.clone()).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn test_has_rewrite_rule() {
        let expected = [
            "uni.temporal.validAt",
            "uni.temporal.overlaps",
            "uni.temporal.isOngoing",
        ];
        for name in expected {
            assert!(has_rewrite_rule(name));
        }
        assert!(!has_rewrite_rule("nonexistent.function"));
    }

    #[test]
    fn test_registered_functions() {
        let names = registered_functions();
        assert!(names.len() >= 3);
        assert!(names.iter().any(|n| n == "uni.temporal.validAt"));
        assert!(names.iter().any(|n| n == "uni.temporal.overlaps"));
    }

    #[test]
    fn test_rewrite_with_context() {
        use context::RewriteConfig;

        let input = plain_literal();
        let config = RewriteConfig::default().with_verbose_logging();
        let ctx = RewriteContext::with_config(config);
        let (output, returned_ctx) = rewrite_expr_with_context(input.clone(), ctx).unwrap();
        assert_eq!(output, input);
        // A bare literal contains no function calls, so nothing is visited.
        assert_eq!(returned_ctx.stats.functions_visited, 0);
    }
}