use crate::rete::FactHandle;
use crate::types::Value;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Identifies a proven fact by its type, an optional field, and the raw pattern text.
///
/// Equality and hashing use all three components, so the full `pattern` string
/// (operator and operand included) takes part in key lookups.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FactKey {
    /// Text before the first `.` (trimmed), or the whole pattern when it has no `.`.
    pub fact_type: String,
    /// Identifier text after the first `.`; `None` only when the pattern has no `.`.
    pub field: Option<String>,
    /// The pattern string exactly as it was supplied.
    pub pattern: String,
}
impl FactKey {
    /// Parses a pattern such as `"User.Score >= 80"` into a key.
    ///
    /// The pattern is split at its first `.`. The fact type is the trimmed text
    /// before it. The field is the text after it, up to the first character that is
    /// neither alphanumeric nor `_`, trimmed.
    ///
    /// A pattern with no `.` yields `field: None` and uses the whole, untrimmed
    /// pattern as the fact type.
    ///
    /// NOTE(review): the split uses the first `.` anywhere in the pattern, so a
    /// pattern like `"Score >= 8.5"` splits at the decimal point.
    pub fn from_pattern(pattern: &str) -> Self {
        let (fact_type, field) = match pattern.split_once('.') {
            None => (pattern.to_string(), None),
            Some((head, tail)) => {
                // The field name ends at the first operator/whitespace-like character.
                let field_end = tail
                    .find(|c: char| !(c.is_alphanumeric() || c == '_'))
                    .unwrap_or(tail.len());
                (
                    head.trim().to_string(),
                    Some(tail[..field_end].trim().to_string()),
                )
            }
        };
        Self::new(fact_type, field, pattern.to_string())
    }
    /// Builds a key from already-separated components without any parsing.
    pub fn new(fact_type: String, field: Option<String>, pattern: String) -> Self {
        FactKey {
            fact_type,
            field,
            pattern,
        }
    }
}
/// One way a fact was established: the rule that fired and the premises it relied on.
#[derive(Debug, Clone)]
pub struct Justification {
    /// Name of the rule that produced this justification.
    pub rule_name: String,
    /// Handles of the facts this justification depends on.
    pub premises: Vec<FactHandle>,
    /// Textual form of the premises, as supplied by the caller.
    pub premise_keys: Vec<String>,
    /// Generation value supplied when the justification was recorded.
    pub generation: u64,
}
/// A fact in the proof graph together with every justification that currently supports it.
#[derive(Debug, Clone)]
pub struct ProofGraphNode {
    /// Key describing the fact.
    pub key: FactKey,
    /// Handle of the fact, if one has been assigned (`None` after [`ProofGraphNode::new`]).
    pub handle: Option<FactHandle>,
    /// Justifications still supporting this fact.
    pub justifications: Vec<Justification>,
    /// Handles of facts whose justifications cite this fact as a premise.
    pub dependents: HashSet<FactHandle>,
    /// Whether the fact is currently considered proven.
    pub valid: bool,
    /// Generation of the most recently added justification (0 for a fresh node).
    pub generation: u64,
    /// Variable bindings associated with the fact; never populated by this type itself.
    pub bindings: HashMap<String, Value>,
}
impl ProofGraphNode {
    /// Creates a valid node for `key` with no handle, justifications, dependents or bindings.
    pub fn new(key: FactKey) -> Self {
        Self {
            handle: None,
            justifications: Vec::new(),
            dependents: HashSet::new(),
            valid: true,
            generation: 0,
            bindings: HashMap::new(),
            key,
        }
    }
    /// Appends a justification and marks the node valid at `generation`.
    pub fn add_justification(
        &mut self,
        rule_name: String,
        premises: Vec<FactHandle>,
        premise_keys: Vec<String>,
        generation: u64,
    ) {
        let justification = Justification {
            rule_name,
            premises,
            premise_keys,
            generation,
        };
        self.justifications.push(justification);
        self.generation = generation;
        self.valid = true;
    }
    /// Returns `true` while at least one justification remains.
    pub fn has_valid_justifications(&self) -> bool {
        !self.justifications.is_empty()
    }
    /// Drops every justification that lists `premise_handle` among its premises.
    ///
    /// Marks the node invalid if no justification is left afterwards. Returns `true`
    /// if anything was removed.
    pub fn remove_justifications_with_premise(&mut self, premise_handle: &FactHandle) -> bool {
        let original_count = self.justifications.len();
        self.justifications
            .retain(|j| !j.premises.iter().any(|p| p == premise_handle));
        if self.justifications.is_empty() {
            self.valid = false;
        }
        self.justifications.len() != original_count
    }
}
/// Index of proven facts, their justifications, and the dependency edges between them.
///
/// Invalidating a fact withdraws every justification that used it as a premise. The
/// invalidation then cascades to any fact left unsupported.
pub struct ProofGraph {
    /// Every node, keyed by the handle it was inserted under.
    nodes_by_handle: HashMap<FactHandle, ProofGraphNode>,
    /// Handles recorded for each key; a handle appears at most once per key.
    index_by_key: HashMap<FactKey, Vec<FactHandle>>,
    /// Premise handle -> handles of facts whose justifications cite it.
    ///
    /// This is the authoritative edge set used for invalidation. It also holds edges
    /// for premises that do not (yet) have a node of their own.
    dependencies: HashMap<FactHandle, HashSet<FactHandle>>,
    /// Counter incremented on every `insert_proof` call.
    generation: u64,
    /// Usage counters for diagnostics.
    pub stats: ProofGraphStats,
}
/// Counters describing how a [`ProofGraph`] has been used.
#[derive(Debug, Clone, Default)]
pub struct ProofGraphStats {
    /// Nodes created since construction or the last `clear`.
    pub total_nodes: usize,
    /// Lookups (`lookup_by_key` / `is_proven`) that found at least one valid node.
    pub cache_hits: usize,
    /// Lookups (`lookup_by_key` / `is_proven`) that found no valid node.
    pub cache_misses: usize,
    /// Calls to `invalidate_handle` plus cascade steps that left a dependent invalid.
    pub invalidations: usize,
    /// Justifications recorded through `insert_proof`.
    pub justifications_added: usize,
}
impl ProofGraph {
    /// Creates an empty graph at generation 0 with zeroed statistics.
    pub fn new() -> Self {
        Self {
            nodes_by_handle: HashMap::new(),
            index_by_key: HashMap::new(),
            dependencies: HashMap::new(),
            generation: 0,
            stats: ProofGraphStats::default(),
        }
    }
    /// Records that `handle` (described by `key`) was derived by `rule_name` from `premises`.
    ///
    /// The first call for a handle creates its node. Later calls add further
    /// justifications to the existing node, whose stored key is left unchanged.
    /// The handle is indexed under `key` only once, however many justifications it
    /// accumulates, and is registered as a dependent of every premise.
    pub fn insert_proof(
        &mut self,
        handle: FactHandle,
        key: FactKey,
        rule_name: String,
        premises: Vec<FactHandle>,
        premise_keys: Vec<String>,
    ) {
        use std::collections::hash_map::Entry;

        self.generation += 1;
        let generation = self.generation;

        // Record dependency edges first so `premises` can be moved (not cloned) into
        // the justification below.
        for premise in &premises {
            self.dependencies
                .entry(*premise)
                .or_default()
                .insert(handle);
            if let Some(premise_node) = self.nodes_by_handle.get_mut(premise) {
                premise_node.dependents.insert(handle);
            }
        }

        let node = match self.nodes_by_handle.entry(handle) {
            Entry::Occupied(slot) => slot.into_mut(),
            Entry::Vacant(slot) => {
                let mut node = ProofGraphNode::new(key.clone());
                node.handle = Some(handle);
                // Other facts may already cite this handle as a premise from before it
                // had a node; seed `dependents` so it agrees with `dependencies`.
                if let Some(existing) = self.dependencies.get(&handle) {
                    node.dependents.extend(existing.iter().copied());
                }
                self.stats.total_nodes += 1;
                slot.insert(node)
            }
        };
        node.add_justification(rule_name, premises, premise_keys, generation);
        self.stats.justifications_added += 1;

        // Index once per (key, handle) pair. Otherwise every additional justification
        // would make `lookup_by_key` return the same node again.
        let indexed = self.index_by_key.entry(key).or_default();
        if !indexed.contains(&handle) {
            indexed.push(handle);
        }
    }
    /// Returns the valid nodes indexed under `key`, or `None` if there are none.
    ///
    /// Each call counts as a cache hit (`Some`) or a cache miss (`None`) in `stats`.
    pub fn lookup_by_key(&mut self, key: &FactKey) -> Option<Vec<&ProofGraphNode>> {
        let nodes_by_handle = &self.nodes_by_handle;
        let nodes: Vec<&ProofGraphNode> = match self.index_by_key.get(key) {
            Some(handles) => handles
                .iter()
                .filter_map(|h| nodes_by_handle.get(h))
                .filter(|n| n.valid)
                .collect(),
            None => Vec::new(),
        };
        if nodes.is_empty() {
            self.stats.cache_misses += 1;
            None
        } else {
            self.stats.cache_hits += 1;
            Some(nodes)
        }
    }
    /// Returns `true` if at least one valid node is indexed under `key` (updates hit/miss stats).
    pub fn is_proven(&mut self, key: &FactKey) -> bool {
        self.lookup_by_key(key).is_some()
    }
    /// Marks `handle` invalid and withdraws every justification that cited it as a premise.
    ///
    /// Withdrawal cascades to facts that are left invalid. Counts one invalidation for
    /// the call itself, plus one per cascade step that leaves a dependent invalid.
    pub fn invalidate_handle(&mut self, handle: &FactHandle) {
        self.stats.invalidations += 1;
        let dependents = self.dependencies.get(handle).cloned();
        if let Some(node) = self.nodes_by_handle.get_mut(handle) {
            node.valid = false;
        }
        if let Some(deps) = dependents {
            for dep_handle in deps {
                self.propagate_invalidation(&dep_handle, handle);
            }
        }
    }
    /// Removes `dependent_handle`'s justifications that cite `premise_handle`, then cascades.
    ///
    /// For every node left invalid by a removal, the cascade continues to the facts
    /// that depend on it, as recorded in the `dependencies` map.
    ///
    /// Uses an explicit worklist instead of recursion, so long dependency chains cannot
    /// overflow the stack. Each cascade step requires at least one justification to be
    /// removed, so the loop terminates even when the graph has cycles.
    fn propagate_invalidation(
        &mut self,
        dependent_handle: &FactHandle,
        premise_handle: &FactHandle,
    ) {
        let mut pending = vec![(*dependent_handle, *premise_handle)];
        while let Some((dependent, premise)) = pending.pop() {
            let now_invalid = match self.nodes_by_handle.get_mut(&dependent) {
                Some(node) => node.remove_justifications_with_premise(&premise) && !node.valid,
                None => false,
            };
            if !now_invalid {
                continue;
            }
            self.stats.invalidations += 1;
            if let Some(further) = self.dependencies.get(&dependent) {
                pending.extend(further.iter().map(|&next| (next, dependent)));
            }
        }
    }
    /// Returns the node stored under `handle`, valid or not.
    pub fn get_node(&self, handle: &FactHandle) -> Option<&ProofGraphNode> {
        self.nodes_by_handle.get(handle)
    }
    /// Removes all nodes, index entries and edges, and resets the generation and stats.
    pub fn clear(&mut self) {
        self.nodes_by_handle.clear();
        self.index_by_key.clear();
        self.dependencies.clear();
        self.generation = 0;
        self.stats = ProofGraphStats::default();
    }
    /// Current generation counter (number of `insert_proof` calls since creation/clear).
    pub fn generation(&self) -> u64 {
        self.generation
    }
    /// Prints the statistics to stdout, including the hit rate once any lookup has happened.
    pub fn print_stats(&self) {
        println!("ProofGraph Statistics:");
        println!(" Total nodes: {}", self.stats.total_nodes);
        println!(" Cache hits: {}", self.stats.cache_hits);
        println!(" Cache misses: {}", self.stats.cache_misses);
        println!(" Invalidations: {}", self.stats.invalidations);
        println!(
            " Justifications added: {}",
            self.stats.justifications_added
        );
        let lookups = self.stats.cache_hits + self.stats.cache_misses;
        if lookups > 0 {
            let hit_rate = (self.stats.cache_hits as f64) / (lookups as f64) * 100.0;
            println!(" Cache hit rate: {:.1}%", hit_rate);
        }
    }
}
impl Default for ProofGraph {
fn default() -> Self {
Self::new()
}
}
/// A [`ProofGraph`] behind `Arc<Mutex<_>>`, so several owners can share and mutate it.
pub type SharedProofGraph = Arc<std::sync::Mutex<ProofGraph>>;
/// Creates a fresh, empty [`SharedProofGraph`].
pub fn new_shared() -> SharedProofGraph {
    let guarded = std::sync::Mutex::new(ProofGraph::new());
    Arc::new(guarded)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Records a premise-free proof of `key` under `handle`, credited to `rule`.
    fn prove_axiom(graph: &mut ProofGraph, handle: FactHandle, key: &FactKey, rule: &str) {
        graph.insert_proof(handle, key.clone(), rule.to_string(), Vec::new(), Vec::new());
    }

    #[test]
    fn test_fact_key_from_pattern() {
        let parsed = FactKey::from_pattern("User.Score >= 80");
        assert_eq!(parsed.fact_type, "User");
        assert_eq!(parsed.field.as_deref(), Some("Score"));
        assert_eq!(parsed.pattern, "User.Score >= 80");
    }

    #[test]
    fn test_proof_graph_insert_and_lookup() {
        let mut pg = ProofGraph::new();
        let score_key = FactKey::from_pattern("User.Score >= 80");
        prove_axiom(&mut pg, FactHandle::new(1), &score_key, "ScoreRule");
        assert!(pg.is_proven(&score_key));
        assert_eq!(pg.stats.total_nodes, 1);
    }

    #[test]
    fn test_dependency_tracking() {
        let mut pg = ProofGraph::new();
        let (adult, can_vote) = (FactHandle::new(1), FactHandle::new(2));
        let adult_key = FactKey::from_pattern("User.Age >= 18");
        let vote_key = FactKey::from_pattern("User.CanVote == true");
        prove_axiom(&mut pg, adult, &adult_key, "AgeRule");
        pg.insert_proof(
            can_vote,
            vote_key.clone(),
            "VotingRule".to_string(),
            vec![adult],
            vec!["User.Age >= 18".to_string()],
        );
        assert!(pg.is_proven(&adult_key));
        assert!(pg.is_proven(&vote_key));

        pg.invalidate_handle(&adult);
        assert!(!pg.get_node(&can_vote).unwrap().valid);
        // One direct invalidation plus one cascaded to the conclusion.
        assert_eq!(pg.stats.invalidations, 2);
    }

    #[test]
    fn test_multiple_justifications() {
        let mut pg = ProofGraph::new();
        let vip = FactHandle::new(1);
        let vip_key = FactKey::from_pattern("User.IsVIP == true");
        for rule in ["HighSpenderRule", "LoyaltyRule"] {
            prove_axiom(&mut pg, vip, &vip_key, rule);
        }
        let node = pg.get_node(&vip).unwrap();
        assert_eq!(node.justifications.len(), 2);
        assert!(node.valid);
    }

    #[test]
    fn test_cache_statistics() {
        let mut pg = ProofGraph::new();
        let active_key = FactKey::from_pattern("User.Active == true");
        assert!(!pg.is_proven(&active_key));
        assert_eq!(pg.stats.cache_misses, 1);
        prove_axiom(&mut pg, FactHandle::new(1), &active_key, "ActiveRule");
        assert!(pg.is_proven(&active_key));
        assert_eq!(pg.stats.cache_hits, 1);
    }

    #[test]
    fn test_clear() {
        let mut pg = ProofGraph::new();
        let test_key = FactKey::from_pattern("Test.Value == 42");
        prove_axiom(&mut pg, FactHandle::new(1), &test_key, "TestRule");
        assert!(pg.is_proven(&test_key));
        pg.clear();
        assert!(!pg.is_proven(&test_key));
        assert_eq!(pg.stats.total_nodes, 0);
    }
}