use crate::query::rewrite::error::RewriteError;
use std::collections::HashMap;
#[derive(Default)]
pub struct RewriteContext {
pub scope: HashMap<String, VariableInfo>,
pub stats: RewriteStats,
pub config: RewriteConfig,
}
impl RewriteContext {
pub fn new() -> Self {
Self::default()
}
pub fn with_config(config: RewriteConfig) -> Self {
Self {
scope: HashMap::new(),
stats: RewriteStats::default(),
config,
}
}
pub fn get_variable(&self, name: &str) -> Option<&VariableInfo> {
self.scope.get(name)
}
pub fn add_variable(&mut self, name: String, info: VariableInfo) {
self.scope.insert(name, info);
}
}
/// Information tracked about a single variable bound in a rewrite scope.
///
/// Derives `PartialEq`/`Eq` (every field supports them) so bindings can be
/// compared directly, e.g. in assertions or when detecting redefinitions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VariableInfo {
    /// The variable's name.
    pub name: String,
    /// Optional label attached to the variable.
    pub label: Option<String>,
    /// Whether the variable is bound to an edge (presumably a node otherwise — confirm).
    pub is_edge: bool,
    /// Known property types, keyed by property name.
    pub properties: HashMap<String, PropertyType>,
}

/// Coarse type tag for a property value.
///
/// Derives `Hash` in addition to `Eq` so it can be used as a set member or
/// map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PropertyType {
    String,
    Integer,
    Float,
    Boolean,
    DateTime,
    List,
    Map,
    /// Catch-all variant for a type that is not one of the above.
    Unknown,
}
#[derive(Debug, Default, Clone)]
pub struct RewriteStats {
pub functions_visited: usize,
pub functions_rewritten: usize,
pub functions_skipped: usize,
pub errors: Vec<RewriteError>,
pub rule_stats: HashMap<String, RuleStats>,
}
impl RewriteStats {
fn rule_stats_mut(&mut self, rule_name: &str) -> &mut RuleStats {
self.rule_stats.entry(rule_name.to_string()).or_default()
}
pub fn record_success(&mut self, rule_name: &str) {
self.functions_rewritten += 1;
self.rule_stats_mut(rule_name).record_success();
}
pub fn record_failure(&mut self, rule_name: &str, error: RewriteError) {
self.functions_skipped += 1;
self.rule_stats_mut(rule_name).record_failure(error.clone());
self.errors.push(error);
}
pub fn record_visit(&mut self) {
self.functions_visited += 1;
}
}
#[derive(Debug, Default, Clone)]
pub struct RuleStats {
pub attempts: usize,
pub successes: usize,
pub failures: HashMap<String, usize>,
}
impl RuleStats {
fn record_success(&mut self) {
self.attempts += 1;
self.successes += 1;
}
fn record_failure(&mut self, error: RewriteError) {
self.attempts += 1;
let error_key = format!("{error:?}");
*self.failures.entry(error_key).or_default() += 1;
}
}
/// Switches that control which families of rewrites are attempted.
#[derive(Debug, Clone)]
pub struct RewriteConfig {
    /// Attempt temporal rewrites.
    pub enable_temporal: bool,
    /// Attempt spatial rewrites.
    pub enable_spatial: bool,
    /// Attempt property rewrites.
    pub enable_property: bool,
    /// Presumably: fall back to scalar evaluation when no rewrite applies — confirm with the rewriter.
    pub fallback_to_scalar: bool,
    /// Emit verbose logging.
    pub verbose_logging: bool,
}

impl Default for RewriteConfig {
    /// Temporal rewrites only, with scalar fallback on and verbose logging off.
    fn default() -> Self {
        Self {
            enable_temporal: true,
            ..Self::all_disabled()
        }
    }
}

impl RewriteConfig {
    /// Every rewrite family enabled; scalar fallback on, verbose logging off.
    pub fn all_enabled() -> Self {
        Self {
            enable_spatial: true,
            enable_property: true,
            ..Self::default()
        }
    }

    /// Every rewrite family disabled; scalar fallback stays on, verbose logging off.
    pub fn all_disabled() -> Self {
        Self {
            verbose_logging: false,
            fallback_to_scalar: true,
            enable_property: false,
            enable_spatial: false,
            enable_temporal: false,
        }
    }

    /// Returns this configuration with verbose logging switched on.
    pub fn with_verbose_logging(self) -> Self {
        Self {
            verbose_logging: true,
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal node binding for scope tests.
    fn node(name: &str, label: &str) -> VariableInfo {
        VariableInfo {
            name: name.to_string(),
            label: Some(label.to_string()),
            is_edge: false,
            properties: HashMap::new(),
        }
    }

    #[test]
    fn test_context_default() {
        let ctx = RewriteContext::default();
        assert!(ctx.scope.is_empty());
        assert_eq!(ctx.stats.functions_visited, 0);
        assert!(ctx.config.enable_temporal);
    }

    #[test]
    fn test_context_scope_round_trip() {
        let mut ctx = RewriteContext::with_config(RewriteConfig::all_enabled());
        assert!(ctx.config.enable_spatial);
        assert!(ctx.get_variable("n").is_none());

        ctx.add_variable("n".to_string(), node("n", "Person"));
        let info = ctx.get_variable("n").expect("variable was just added");
        assert_eq!(info.label.as_deref(), Some("Person"));
        assert!(!info.is_edge);

        // Re-adding the same name replaces the earlier binding.
        ctx.add_variable("n".to_string(), node("n", "Company"));
        assert_eq!(ctx.scope.len(), 1);
        assert_eq!(
            ctx.get_variable("n").and_then(|v| v.label.as_deref()),
            Some("Company")
        );
    }

    #[test]
    fn test_stats_recording() {
        let mut stats = RewriteStats::default();
        stats.record_success("test.func");
        assert_eq!(stats.functions_rewritten, 1);
        assert_eq!(stats.rule_stats.get("test.func").unwrap().successes, 1);
        stats.record_failure(
            "test.func",
            RewriteError::NotApplicable {
                reason: "test".into(),
            },
        );
        assert_eq!(stats.functions_skipped, 1);
        assert_eq!(stats.errors.len(), 1);

        // Per-rule accounting covers both outcomes.
        let rule = &stats.rule_stats["test.func"];
        assert_eq!(rule.attempts, 2);
        assert_eq!(rule.successes, 1);
        assert_eq!(rule.failures.values().sum::<usize>(), 1);

        // Outcomes are not visits; visits are counted separately.
        assert_eq!(stats.functions_visited, 0);
        stats.record_visit();
        assert_eq!(stats.functions_visited, 1);
    }

    #[test]
    fn test_config_builders() {
        let all_enabled = RewriteConfig::all_enabled();
        assert!(all_enabled.enable_temporal);
        assert!(all_enabled.enable_spatial);
        assert!(all_enabled.enable_property);
        assert!(all_enabled.fallback_to_scalar);
        let all_disabled = RewriteConfig::all_disabled();
        assert!(!all_disabled.enable_temporal);
        assert!(!all_disabled.enable_spatial);
        assert!(!all_disabled.enable_property);
        // Disabling every rewrite still keeps the scalar fallback on.
        assert!(all_disabled.fallback_to_scalar);

        let verbose = RewriteConfig::default().with_verbose_logging();
        assert!(verbose.verbose_logging);
        assert!(verbose.enable_temporal);
        assert!(!RewriteConfig::default().verbose_logging);
    }
}