use super::facts::TypedFacts;
use super::network::ReteUlNode;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
fn compute_facts_hash(facts: &TypedFacts) -> u64 {
let mut hasher = DefaultHasher::new();
let mut sorted_facts: Vec<_> = facts.get_all().iter().collect();
sorted_facts.sort_by_key(|(k, _)| *k);
for (key, value) in sorted_facts {
key.hash(&mut hasher);
value.as_str().hash(&mut hasher);
}
hasher.finish()
}
/// Hashes a node through its `Debug` rendering.
///
/// The `Debug` text stands in for the node's identity. Any two nodes whose
/// `Debug` output is identical produce the same hash.
fn compute_node_hash(node: &ReteUlNode) -> u64 {
    let rendered = format!("{:?}", node);
    let mut state = DefaultHasher::new();
    rendered.hash(&mut state);
    state.finish()
}
/// Caches boolean results of evaluating a node against a fact set.
///
/// Results are keyed by a pair of 64-bit hashes, `(node hash, facts hash)`.
/// Hit and miss counts are tracked for diagnostics.
#[derive(Debug)]
pub struct MemoizedEvaluator {
    /// Memoized evaluation results keyed by `(node hash, facts hash)`.
    cache: HashMap<(u64, u64), bool>,
    /// Number of lookups answered from the cache.
    hits: usize,
    /// Number of lookups that had to run the evaluation closure.
    misses: usize,
}
impl MemoizedEvaluator {
    /// Creates an evaluator with an empty cache and zeroed hit/miss counters.
    pub fn new() -> Self {
        Self {
            cache: HashMap::new(),
            hits: 0,
            misses: 0,
        }
    }

    /// Evaluates `node` against `facts`, reusing a cached result when one exists.
    ///
    /// The cache key is `(compute_node_hash(node), compute_facts_hash(facts))`.
    /// On a hit, the cached value is returned and `eval_fn` is not called.
    /// On a miss, `eval_fn(node, facts)` runs exactly once and its result is stored.
    ///
    /// Only the 64-bit hashes are stored, not the inputs themselves. Two
    /// distinct `(node, facts)` pairs whose hashes both collide would
    /// therefore share an entry. Entries are never evicted; call
    /// [`clear`](Self::clear) to release them.
    pub fn evaluate(
        &mut self,
        node: &ReteUlNode,
        facts: &TypedFacts,
        eval_fn: impl FnOnce(&ReteUlNode, &TypedFacts) -> bool,
    ) -> bool {
        use std::collections::hash_map::Entry;

        let key = (compute_node_hash(node), compute_facts_hash(facts));
        // One `entry` lookup serves both the hit check and the insertion,
        // instead of hashing and probing the key twice (`get` + `insert`).
        match self.cache.entry(key) {
            Entry::Occupied(slot) => {
                self.hits += 1;
                *slot.get()
            }
            Entry::Vacant(slot) => {
                self.misses += 1;
                *slot.insert(eval_fn(node, facts))
            }
        }
    }

    /// Returns a snapshot of the cache size and hit/miss counters.
    ///
    /// `hit_rate` is `hits / (hits + misses)`, or `0.0` before any lookup has happened.
    pub fn stats(&self) -> MemoStats {
        let lookups = self.hits + self.misses;
        MemoStats {
            cache_size: self.cache.len(),
            hits: self.hits,
            misses: self.misses,
            hit_rate: if lookups > 0 {
                self.hits as f64 / lookups as f64
            } else {
                0.0
            },
        }
    }

    /// Removes every cached result and resets the hit/miss counters to zero.
    pub fn clear(&mut self) {
        self.cache.clear();
        self.hits = 0;
        self.misses = 0;
    }

    /// Number of results currently held in the cache.
    pub fn cache_size(&self) -> usize {
        self.cache.len()
    }
}
impl Default for MemoizedEvaluator {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time snapshot of a memoized evaluator's cache counters.
///
/// `Default` yields the all-zero snapshot, which matches the stats of an
/// evaluator that has not yet performed any lookup.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct MemoStats {
    /// Number of entries currently stored in the cache.
    pub cache_size: usize,
    /// Lookups answered from the cache.
    pub hits: usize,
    /// Lookups that ran the evaluation closure.
    pub misses: usize,
    /// `hits / (hits + misses)`, or `0.0` when no lookup has happened yet.
    pub hit_rate: f64,
}
impl std::fmt::Display for MemoStats {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Memo Stats: {} entries, {} hits, {} misses, {:.2}% hit rate",
self.cache_size,
self.hits,
self.misses,
self.hit_rate * 100.0
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rete::alpha::AlphaNode;
    use crate::rete::facts::TypedFacts;
    use crate::rete::network::ReteUlNode;

    /// Evaluating the same node against the same facts twice must run the
    /// closure only once; the second call is answered from the cache.
    #[test]
    fn test_memoization() {
        let mut facts = TypedFacts::new();
        facts.set("age", 25i64);
        let adult_check = ReteUlNode::UlAlpha(AlphaNode {
            field: "age".to_string(),
            operator: ">".to_string(),
            value: "18".to_string(),
        });
        let mut memo = MemoizedEvaluator::new();

        let mut closure_calls = 0;
        for _ in 0..2 {
            let outcome = memo.evaluate(&adult_check, &facts, |n, f| {
                closure_calls += 1;
                n.evaluate_typed(f)
            });
            assert!(outcome);
            // Only the first (missing) lookup may invoke the closure.
            assert_eq!(closure_calls, 1);
        }

        let stats = memo.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hit_rate, 0.5);
    }
}