use crate::query::datalog::stratification::DependencyGraph;
use crate::query::datalog::types::Rule;
use anyhow::Result;
use std::collections::HashMap;
/// Registry of Datalog rules, grouped by the name of the predicate they define.
///
/// Rules for a given predicate are kept in the order they were registered.
#[derive(Clone, Debug)]
pub struct RuleRegistry {
    /// Predicate name -> rules defining that predicate (appended in
    /// registration order).
    rules: HashMap<String, Vec<Rule>>,
}
impl Default for RuleRegistry {
fn default() -> Self {
Self::new()
}
}
impl RuleRegistry {
    /// Creates an empty registry containing no rules.
    pub fn new() -> Self {
        RuleRegistry {
            rules: HashMap::new(),
        }
    }

    /// Registers `rule` as a definition of `predicate`, keeping the rule set
    /// stratifiable.
    ///
    /// The rule is appended tentatively, a [`DependencyGraph`] is built from the
    /// whole registry, and stratification is attempted. If stratification
    /// fails, the rule is removed again, along with the predicate's entry if
    /// that leaves it with no rules. A rejected call therefore leaves the
    /// registry as it was before the call.
    ///
    /// Every call rebuilds the dependency graph from all registered rules, so
    /// the cost per call presumably grows with the registry's size.
    ///
    /// # Errors
    ///
    /// Returns the stratification error unchanged when adding `rule` would make
    /// the rule set unstratifiable (for example, a cycle through negation).
    pub fn register_rule(&mut self, predicate: String, rule: Rule) -> Result<()> {
        // `predicate` is still needed by the rollback path below, hence the clone.
        self.rules.entry(predicate.clone()).or_default().push(rule);
        let graph = DependencyGraph::from_rules(self);
        if let Err(e) = graph.stratify() {
            let rules = self
                .rules
                .get_mut(&predicate)
                .expect("entry for `predicate` was inserted at the top of register_rule");
            // Undo exactly the push made above.
            rules.pop();
            if rules.is_empty() {
                // Drop the key so `has_rule`/`predicate_count` don't report a
                // predicate that has no rules.
                self.rules.remove(&predicate);
            }
            return Err(e);
        }
        Ok(())
    }

    /// Returns a clone of the rules registered for `predicate`, in registration
    /// order, or an empty `Vec` if none are registered.
    pub fn get_rules(&self, predicate: &str) -> Vec<Rule> {
        self.rules.get(predicate).cloned().unwrap_or_default()
    }

    /// Returns `true` if at least one rule is registered for `predicate`.
    #[allow(dead_code)] // NOTE(review): allowance implies no non-test caller yet — confirm before removing
    pub fn has_rule(&self, predicate: &str) -> bool {
        self.rules.contains_key(predicate)
    }

    /// Returns the total number of rules across all predicates.
    #[allow(dead_code)] // NOTE(review): allowance implies no non-test caller yet — confirm before removing
    pub fn rule_count(&self) -> usize {
        self.rules.values().map(Vec::len).sum()
    }

    /// Returns the number of distinct predicates that have registered rules.
    #[allow(dead_code)] // NOTE(review): allowance implies no non-test caller yet — confirm before removing
    pub fn predicate_count(&self) -> usize {
        self.rules.len()
    }

    /// Removes every registered rule.
    #[allow(dead_code)] // NOTE(review): allowance implies no non-test caller yet — confirm before removing
    pub fn clear(&mut self) {
        self.rules.clear();
    }

    /// Returns the names of all predicates that have registered rules.
    ///
    /// The order is unspecified (it follows `HashMap` iteration order).
    #[allow(dead_code)] // NOTE(review): allowance implies no non-test caller yet — confirm before removing
    pub fn predicate_names(&self) -> Vec<String> {
        self.rules.keys().cloned().collect()
    }

    /// Appends `rule` under `predicate` without building a [`DependencyGraph`]
    /// or checking stratifiability.
    ///
    /// Unlike [`RuleRegistry::register_rule`], this can never fail. The caller
    /// must ensure the resulting rule set is still stratifiable.
    #[allow(dead_code)] // NOTE(review): allowance implies no caller in this build yet — confirm before removing
    pub(crate) fn register_rule_unchecked(&mut self, predicate: String, rule: Rule) {
        self.rules.entry(predicate).or_default().push(rule);
    }

    /// Iterates over every predicate together with its rules.
    ///
    /// Predicate order is unspecified (it follows `HashMap` iteration order).
    /// Within a predicate, rules appear in registration order.
    pub fn all_rules(&self) -> impl Iterator<Item = (&str, &[Rule])> {
        self.rules.iter().map(|(k, v)| (k.as_str(), v.as_slice()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::datalog::types::EdnValue;

    /// Builds a negation-free rule `(predicate ?x ?y) :- [?x :connected ?y]`.
    fn create_test_rule(predicate: &str) -> Rule {
        use crate::query::datalog::types::{Pattern, WhereClause};
        Rule {
            head: vec![
                EdnValue::Symbol(predicate.to_string()),
                EdnValue::Symbol("?x".to_string()),
                EdnValue::Symbol("?y".to_string()),
            ],
            body: vec![WhereClause::Pattern(Pattern::new(
                EdnValue::Symbol("?x".to_string()),
                EdnValue::Keyword(":connected".to_string()),
                EdnValue::Symbol("?y".to_string()),
            ))],
        }
    }

    /// Builds `(head ?x) :- (not (negated ?x))`, i.e. `head` depends
    /// negatively on `negated`.
    fn negation_rule(head: &str, negated: &str) -> Rule {
        use crate::query::datalog::types::WhereClause;
        Rule {
            head: vec![
                EdnValue::Symbol(head.to_string()),
                EdnValue::Symbol("?x".to_string()),
            ],
            body: vec![WhereClause::Not(vec![WhereClause::RuleInvocation {
                predicate: negated.to_string(),
                args: vec![EdnValue::Symbol("?x".to_string())],
            }])],
        }
    }

    #[test]
    fn test_rule_registry_new() {
        let registry = RuleRegistry::new();
        assert_eq!(registry.rule_count(), 0);
        assert_eq!(registry.predicate_count(), 0);
    }

    #[test]
    fn test_register_single_rule() {
        let mut registry = RuleRegistry::new();
        let rule = create_test_rule("reachable");
        registry
            .register_rule("reachable".to_string(), rule)
            .unwrap();
        assert_eq!(registry.rule_count(), 1);
        assert_eq!(registry.predicate_count(), 1);
        assert!(registry.has_rule("reachable"));
    }

    #[test]
    fn test_register_multiple_rules_same_predicate() {
        let mut registry = RuleRegistry::new();
        let base_rule = create_test_rule("reachable");
        let recursive_rule = create_test_rule("reachable");
        registry
            .register_rule("reachable".to_string(), base_rule)
            .unwrap();
        registry
            .register_rule("reachable".to_string(), recursive_rule)
            .unwrap();
        assert_eq!(registry.rule_count(), 2);
        assert_eq!(registry.predicate_count(), 1);
        assert_eq!(registry.get_rules("reachable").len(), 2);
    }

    #[test]
    fn test_register_rules_different_predicates() {
        let mut registry = RuleRegistry::new();
        let rule1 = create_test_rule("reachable");
        let rule2 = create_test_rule("ancestor");
        registry
            .register_rule("reachable".to_string(), rule1)
            .unwrap();
        registry
            .register_rule("ancestor".to_string(), rule2)
            .unwrap();
        assert_eq!(registry.rule_count(), 2);
        assert_eq!(registry.predicate_count(), 2);
        assert!(registry.has_rule("reachable"));
        assert!(registry.has_rule("ancestor"));
    }

    #[test]
    fn test_get_rules_empty() {
        let registry = RuleRegistry::new();
        let rules = registry.get_rules("nonexistent");
        assert_eq!(rules.len(), 0);
    }

    #[test]
    fn test_get_rules_returns_all() {
        let mut registry = RuleRegistry::new();
        let rule1 = create_test_rule("reachable");
        let rule2 = create_test_rule("reachable");
        let rule3 = create_test_rule("reachable");
        registry
            .register_rule("reachable".to_string(), rule1)
            .unwrap();
        registry
            .register_rule("reachable".to_string(), rule2)
            .unwrap();
        registry
            .register_rule("reachable".to_string(), rule3)
            .unwrap();
        let rules = registry.get_rules("reachable");
        assert_eq!(rules.len(), 3);
    }

    #[test]
    fn test_has_rule() {
        let mut registry = RuleRegistry::new();
        assert!(!registry.has_rule("reachable"));
        let rule = create_test_rule("reachable");
        registry
            .register_rule("reachable".to_string(), rule)
            .unwrap();
        assert!(registry.has_rule("reachable"));
        assert!(!registry.has_rule("ancestor"));
    }

    #[test]
    fn test_clear() {
        let mut registry = RuleRegistry::new();
        let rule1 = create_test_rule("reachable");
        let rule2 = create_test_rule("ancestor");
        registry
            .register_rule("reachable".to_string(), rule1)
            .unwrap();
        registry
            .register_rule("ancestor".to_string(), rule2)
            .unwrap();
        assert_eq!(registry.rule_count(), 2);
        registry.clear();
        assert_eq!(registry.rule_count(), 0);
        assert_eq!(registry.predicate_count(), 0);
        assert!(!registry.has_rule("reachable"));
        assert!(!registry.has_rule("ancestor"));
    }

    #[test]
    fn test_register_rule_rejects_negative_cycle() {
        let mut registry = RuleRegistry::new();
        registry
            .register_rule("p".to_string(), negation_rule("p", "q"))
            .unwrap();
        let result = registry.register_rule("q".to_string(), negation_rule("q", "p"));
        assert!(
            result.is_err(),
            "Expected stratification error for negative cycle"
        );
        assert!(registry.get_rules("q").is_empty());
        // Rollback must also drop the now-empty predicate entry, not leave `q -> []`.
        assert!(!registry.has_rule("q"));
        assert_eq!(registry.predicate_count(), 1);
        assert_eq!(registry.rule_count(), 1);
    }

    #[test]
    fn test_rejected_rule_keeps_existing_rules_for_predicate() {
        use crate::query::datalog::types::{Pattern, WhereClause};
        let mut registry = RuleRegistry::new();
        // A negation-free base definition of `q`.
        let base_q = Rule {
            head: vec![
                EdnValue::Symbol("q".to_string()),
                EdnValue::Symbol("?x".to_string()),
            ],
            body: vec![WhereClause::Pattern(Pattern::new(
                EdnValue::Symbol("?x".to_string()),
                EdnValue::Keyword(":flag".to_string()),
                EdnValue::Boolean(true),
            ))],
        };
        registry.register_rule("q".to_string(), base_q).unwrap();
        registry
            .register_rule("p".to_string(), negation_rule("p", "q"))
            .unwrap();
        // Closing the cycle p -not-> q -not-> p must be rejected...
        let result = registry.register_rule("q".to_string(), negation_rule("q", "p"));
        assert!(
            result.is_err(),
            "Expected stratification error for negative cycle"
        );
        // ...while the earlier, valid definition of `q` survives the rollback.
        assert!(registry.has_rule("q"));
        assert_eq!(registry.get_rules("q").len(), 1);
        assert_eq!(registry.rule_count(), 2);
    }

    #[test]
    fn test_register_rule_accepts_stratifiable_negation() {
        use crate::query::datalog::types::{Pattern, WhereClause};
        let mut registry = RuleRegistry::new();
        let rule_eligible = Rule {
            head: vec![
                EdnValue::Symbol("eligible".to_string()),
                EdnValue::Symbol("?x".to_string()),
            ],
            body: vec![
                WhereClause::Pattern(Pattern::new(
                    EdnValue::Symbol("?x".to_string()),
                    EdnValue::Keyword(":applied".to_string()),
                    EdnValue::Boolean(true),
                )),
                WhereClause::Not(vec![WhereClause::RuleInvocation {
                    predicate: "rejected".to_string(),
                    args: vec![EdnValue::Symbol("?x".to_string())],
                }]),
            ],
        };
        registry
            .register_rule("eligible".to_string(), rule_eligible)
            .unwrap();
        assert_eq!(registry.get_rules("eligible").len(), 1);
    }

    #[test]
    fn test_predicate_names() {
        let mut registry = RuleRegistry::new();
        let rule1 = create_test_rule("reachable");
        let rule2 = create_test_rule("ancestor");
        let rule3 = create_test_rule("reachable");
        registry
            .register_rule("reachable".to_string(), rule1)
            .unwrap();
        registry
            .register_rule("ancestor".to_string(), rule2)
            .unwrap();
        registry
            .register_rule("reachable".to_string(), rule3)
            .unwrap();
        // HashMap order is unspecified; sort so the comparison is exact.
        let mut names = registry.predicate_names();
        names.sort();
        assert_eq!(
            names,
            vec!["ancestor".to_string(), "reachable".to_string()]
        );
    }
}