use crate::engine::{BoundValue, Match, SiftEvent};
use crate::interval::Interval;
use crate::pattern::Pattern;
use std::collections::HashMap;
use std::fmt::Debug;
/// A [`Match`] paired with the surprise score computed by [`SurpriseScorer::score`].
///
/// Every field except `surprise` is copied unchanged from the source match.
#[derive(Debug, Clone)]
pub struct ScoredMatch<N, V, T>
where
    N: Debug,
    V: Debug,
    T: Debug + Clone,
{
    /// Name of the pattern that produced the match.
    pub pattern: String,
    /// Pattern index carried over from the source match.
    pub pattern_idx: Option<usize>,
    /// Bindings carried over from the source match.
    pub bindings: HashMap<String, BoundValue<N, V>>,
    /// Intervals carried over from the source match.
    pub intervals: HashMap<String, Interval<T>>,
    /// Metadata carried over from the source match.
    pub metadata: HashMap<String, String>,
    /// Surprise in bits (`-log2` ratio of observed rate to baseline); `0.0`
    /// when the pattern is unknown or has no baseline.
    pub surprise: f64,
}
/// Tracks how often each pattern matches across rounds and scores matches by
/// how far that observed rate deviates from a configured baseline.
///
/// Patterns are identified by their position in the `patterns` slice handed to
/// [`observe`](Self::observe), [`observe_events`](Self::observe_events) and
/// [`score`](Self::score). Baselines and counts are keyed by that index, so the
/// same slice (in the same order) should be passed to every call.
#[derive(Debug, Clone, Default)]
pub struct SurpriseScorer {
    /// Expected per-round match probability for each pattern index, in `(0, 1]`.
    baselines: HashMap<usize, f64>,
    /// Number of recorded rounds in which each pattern index matched.
    counts: HashMap<usize, u64>,
    /// Number of rounds seen so far (advanced by `observe` and `tick`).
    total_rounds: u64,
}

impl SurpriseScorer {
    /// Creates a scorer with no baselines, no counts and zero rounds.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the expected per-round match probability for `pattern_idx`,
    /// replacing any previous baseline.
    ///
    /// # Panics
    ///
    /// Panics if `baseline` is not in `(0, 1]`. NaN is rejected too, because
    /// both comparisons fail.
    pub fn set_baseline(&mut self, pattern_idx: usize, baseline: f64) {
        assert!(
            baseline > 0.0 && baseline <= 1.0,
            "baseline must be in (0, 1], got {}",
            baseline
        );
        self.baselines.insert(pattern_idx, baseline);
    }

    /// Records one round of matches and advances the round counter.
    ///
    /// Each pattern is counted at most once per call, no matter how many of
    /// `matches` belong to it. Matches whose pattern name is not found in
    /// `patterns` are ignored.
    pub fn observe<
        N: Debug + PartialEq,
        V: Debug + PartialEq,
        T: Debug + Clone + PartialEq,
        L,
        VV,
    >(
        &mut self,
        matches: &[Match<N, V, T>],
        patterns: &[Pattern<L, VV>],
    ) {
        self.total_rounds += 1;
        let hits = matches
            .iter()
            .filter_map(|m| patterns.iter().position(|p| p.name == m.pattern));
        self.record_hits(hits);
    }

    /// Records the `Completed` events of one round. All other event kinds are
    /// ignored, as are events whose pattern name is not found in `patterns`.
    ///
    /// Like [`observe`](Self::observe), each pattern is counted at most once per
    /// call. Repeated completions of one pattern within a round therefore do not
    /// push its count past the number of rounds, which would drive the estimated
    /// rate above 1. Unlike `observe`, this does NOT advance the round counter:
    /// call it once per round and call [`tick`](Self::tick) alongside it.
    pub fn observe_events<N: Debug, V: Debug, L, VV>(
        &mut self,
        events: &[SiftEvent<N, V>],
        patterns: &[Pattern<L, VV>],
    ) {
        let hits = events.iter().filter_map(|event| match event {
            SiftEvent::Completed { pattern, .. } => {
                patterns.iter().position(|p| p.name == *pattern)
            }
            _ => None,
        });
        self.record_hits(hits);
    }

    /// Advances the round counter without recording any matches.
    pub fn tick(&mut self) {
        self.total_rounds += 1;
    }

    /// Pairs every match with the current surprise of its pattern.
    ///
    /// The pattern is located by name in `patterns`. A match whose pattern is
    /// not found, or has no baseline, gets a surprise of `0.0`. All other fields
    /// are cloned unchanged from the match, including `pattern_idx`.
    pub fn score<
        N: Debug + Clone + PartialEq,
        V: Debug + Clone + PartialEq,
        T: Debug + Clone + PartialEq,
        L,
        VV,
    >(
        &self,
        matches: &[Match<N, V, T>],
        patterns: &[Pattern<L, VV>],
    ) -> Vec<ScoredMatch<N, V, T>> {
        matches
            .iter()
            .map(|m| {
                let idx = patterns.iter().position(|p| p.name == m.pattern);
                let surprise = idx.and_then(|i| self.surprise_for(i)).unwrap_or(0.0);
                ScoredMatch {
                    pattern: m.pattern.clone(),
                    pattern_idx: m.pattern_idx,
                    bindings: m.bindings.clone(),
                    intervals: m.intervals.clone(),
                    metadata: m.metadata.clone(),
                    surprise,
                }
            })
            .collect()
    }

    /// Surprise, in bits, of `pattern_idx`'s observed match rate against its
    /// baseline. Returns `None` if no baseline has been set.
    ///
    /// The observed rate is `p = (count + 1) / (total_rounds + 1)`. The `+ 1`s
    /// keep `p` non-zero and avoid dividing by zero before anything is observed.
    /// The result is `-log2(p / baseline)`: positive when the pattern matches
    /// less often than its baseline, negative when it matches more often.
    pub fn surprise_for(&self, pattern_idx: usize) -> Option<f64> {
        let baseline = *self.baselines.get(&pattern_idx)?;
        let count = self.count_for(pattern_idx);
        let p = (count as f64 + 1.0) / (self.total_rounds as f64 + 1.0);
        Some(-(p / baseline).log2())
    }

    /// Clears all counts and resets the round counter. Baselines are kept.
    pub fn reset_counts(&mut self) {
        self.counts.clear();
        self.total_rounds = 0;
    }

    /// Number of rounds recorded since construction or the last `reset_counts`.
    pub fn total_rounds(&self) -> u64 {
        self.total_rounds
    }

    /// Number of recorded rounds in which `pattern_idx` matched (0 if never).
    pub fn count_for(&self, pattern_idx: usize) -> u64 {
        self.counts.get(&pattern_idx).copied().unwrap_or(0)
    }

    /// Increments the count of each distinct index yielded by `indices` exactly
    /// once. This is the shared "once per round" rule for `observe` and
    /// `observe_events`.
    fn record_hits(&mut self, indices: impl IntoIterator<Item = usize>) {
        let mut seen = std::collections::HashSet::new();
        for idx in indices {
            if seen.insert(idx) {
                *self.counts.entry(idx).or_default() += 1;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::PatternBuilder;

    /// Throwaway single-stage pattern; these tests only rely on its name.
    fn dummy_pattern(name: &str) -> Pattern<String, String> {
        PatternBuilder::<String, String>::new(name)
            .stage("e", |s| s.edge("e", "type".into(), "x".into()))
            .build()
    }

    /// Match with no bindings, intervals or metadata, attributed to `name`.
    fn dummy_match(name: &str) -> Match<String, String, i64> {
        Match {
            pattern: String::from(name),
            pattern_idx: None,
            bindings: Default::default(),
            intervals: Default::default(),
            metadata: Default::default(),
        }
    }

    #[test]
    fn surprise_high_for_rare_pattern() {
        let patterns = vec![dummy_pattern("common"), dummy_pattern("rare")];
        let mut scorer = SurpriseScorer::new();
        for idx in 0..2 {
            scorer.set_baseline(idx, 0.5);
        }
        // "common" fires every round; "rare" fires only in round 5.
        for round in 0..10 {
            let batch: Vec<_> = std::iter::once(dummy_match("common"))
                .chain((round == 5).then(|| dummy_match("rare")))
                .collect();
            scorer.observe(&batch, &patterns);
        }
        let common_surprise = scorer.surprise_for(0).unwrap();
        let rare_surprise = scorer.surprise_for(1).unwrap();
        assert!(
            rare_surprise > common_surprise,
            "rare ({:.2}) should be more surprising than common ({:.2})",
            rare_surprise,
            common_surprise
        );
    }

    #[test]
    fn surprise_near_zero_at_baseline() {
        let patterns = vec![dummy_pattern("normal")];
        let mut scorer = SurpriseScorer::new();
        scorer.set_baseline(0, 0.5);
        // Matches on even rounds only: 5 of 10 rounds against a 0.5 baseline.
        for round in 0..10 {
            let batch: Vec<Match<String, String, i64>> = (round % 2 == 0)
                .then(|| dummy_match("normal"))
                .into_iter()
                .collect();
            scorer.observe(&batch, &patterns);
        }
        let surprise = scorer.surprise_for(0).unwrap();
        assert!(
            surprise.abs() < 0.5,
            "surprise should be near zero, got {:.2}",
            surprise
        );
    }

    #[test]
    fn surprise_negative_for_common_pattern() {
        let patterns = vec![dummy_pattern("frequent")];
        let mut scorer = SurpriseScorer::new();
        scorer.set_baseline(0, 0.1);
        let every_round = [dummy_match("frequent")];
        for _ in 0..10 {
            scorer.observe(&every_round, &patterns);
        }
        let surprise = scorer.surprise_for(0).unwrap();
        assert!(
            surprise < 0.0,
            "over-represented pattern should have negative surprise, got {:.2}",
            surprise
        );
    }

    #[test]
    fn surprise_high_for_never_matched() {
        let patterns = vec![dummy_pattern("ghost")];
        let mut scorer = SurpriseScorer::new();
        scorer.set_baseline(0, 0.5);
        let nothing: Vec<Match<String, String, i64>> = Vec::new();
        (0..20).for_each(|_| scorer.observe(&nothing, &patterns));
        let surprise = scorer.surprise_for(0).unwrap();
        assert!(
            surprise > 2.0,
            "never-matched pattern should have high surprise, got {:.2}",
            surprise
        );
    }

    #[test]
    fn observe_events_counts_completions_only() {
        let patterns = vec![dummy_pattern("test")];
        let mut scorer = SurpriseScorer::new();
        scorer.set_baseline(0, 0.5);
        let label = "test";
        // One event of each kind; only the Completed one should count.
        let stream: Vec<SiftEvent<String, String>> = vec![
            SiftEvent::Advanced {
                pattern: label.into(),
                match_id: 0,
                stage_index: 0,
                metadata: HashMap::new(),
            },
            SiftEvent::Completed {
                pattern: label.into(),
                match_id: 1,
                bindings: HashMap::new(),
                metadata: HashMap::new(),
            },
            SiftEvent::Negated {
                pattern: label.into(),
                match_id: 2,
                clause_label: "x".into(),
                trigger_source: "src".into(),
                metadata: HashMap::new(),
            },
            SiftEvent::Expired {
                pattern: label.into(),
                match_id: 3,
                bindings: HashMap::new(),
                stage_reached: 0,
                ticks_elapsed: 10,
                metadata: HashMap::new(),
            },
        ];
        scorer.observe_events(&stream, &patterns);
        assert_eq!(scorer.count_for(0), 1, "only Completed should be counted");
    }

    #[test]
    fn score_returns_scored_matches() {
        let patterns = vec![dummy_pattern("a"), dummy_pattern("b")];
        let mut scorer = SurpriseScorer::new();
        for idx in 0..2 {
            scorer.set_baseline(idx, 0.5);
        }
        // "a" fires every round; "b" fires only in the first round.
        for round in 0..10 {
            let batch: Vec<_> = std::iter::once(dummy_match("a"))
                .chain((round == 0).then(|| dummy_match("b")))
                .collect();
            scorer.observe(&batch, &patterns);
        }
        let scored = scorer.score(&[dummy_match("a"), dummy_match("b")], &patterns);
        assert_eq!(scored.len(), 2);
        let (first, second) = (&scored[0], &scored[1]);
        assert!(
            second.surprise > first.surprise,
            "b should be more surprising than a"
        );
    }

    #[test]
    fn no_baseline_returns_zero_surprise() {
        let patterns = vec![dummy_pattern("unscored")];
        let scored = SurpriseScorer::new().score(&[dummy_match("unscored")], &patterns);
        assert_eq!(scored[0].surprise, 0.0);
    }

    #[test]
    fn reset_clears_counts_preserves_baselines() {
        let patterns = vec![dummy_pattern("test")];
        let mut scorer = SurpriseScorer::new();
        scorer.set_baseline(0, 0.5);
        let hit = [dummy_match("test")];
        (0..5).for_each(|_| scorer.observe(&hit, &patterns));
        assert_eq!(scorer.count_for(0), 5);
        assert_eq!(scorer.total_rounds(), 5);

        scorer.reset_counts();
        assert_eq!(scorer.count_for(0), 0);
        assert_eq!(scorer.total_rounds(), 0);
        // The baseline survives the reset, so surprise is still defined.
        assert!(scorer.surprise_for(0).is_some());
    }

    #[test]
    fn tick_increments_rounds() {
        let mut scorer = SurpriseScorer::new();
        assert_eq!(scorer.total_rounds(), 0);
        for _ in 0..3 {
            scorer.tick();
        }
        assert_eq!(scorer.total_rounds(), 3);
    }
}