use crate::engine::{BoundValue, Match};
use crate::interval::Interval;
use std::collections::HashMap;
use std::fmt::Debug;
/// A [`Match`] annotated with the STU score computed by [`StuScorer::score`].
///
/// The pattern, index, bindings, intervals and metadata are cloned unchanged from the
/// match that was scored.
#[derive(Debug, Clone)]
pub struct StuScoredMatch<N: Debug, V: Debug, T: Debug + Clone> {
/// Pattern that produced the match; also the key of the scorer's observation table.
pub pattern: String,
/// Copied from the source match.
pub pattern_idx: Option<usize>,
/// Bindings copied from the source match.
pub bindings: HashMap<String, BoundValue<N, V>>,
/// Intervals copied from the source match.
pub intervals: HashMap<String, Interval<T>>,
/// Metadata copied from the source match.
pub metadata: HashMap<String, String>,
/// `(property, frequency)` pairs for the match's distinct properties, sorted by
/// ascending frequency (rarest first). Empty when the match had no properties or its
/// pattern had never been observed.
pub property_frequencies: Vec<(String, f64)>,
/// Aggregated score. How to read it depends on the scorer's [`StuAggregation`]:
/// lower is more surprising, except for [`StuAggregation::TfIdf`] where higher is.
pub stu_score: f64,
}
/// Per-pattern observation counts accumulated by [`StuScorer`].
#[derive(Debug, Clone, Default)]
struct PropertyTable {
/// Number of matches of this pattern that have been observed.
total_matches: u64,
/// For each property, the number of observed matches it appeared in
/// (counted at most once per match).
property_counts: HashMap<String, u64>,
/// Co-occurrence counts of unordered property pairs, keyed with the lexicographically
/// smaller property first. Only populated while PMI correction is enabled.
pair_counts: HashMap<(String, String), u64>,
}
/// Strategy used by [`StuScorer`] to combine per-property frequencies into one score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum StuAggregation {
/// Mean of the per-property frequencies (the default). Lower is more surprising.
#[default]
ArithmeticMean,
/// Sum of per-property surprisal, `-log2(frequency)`. Unlike the other variants,
/// higher is more surprising.
TfIdf,
/// Geometric mean of the per-property frequencies. Lower is more surprising.
GeometricMean,
/// Frequency of the rarest property. Lower is more surprising.
Min,
}
/// Scores pattern matches by how unusual their properties are compared with the
/// matches of the same pattern observed so far.
#[derive(Debug, Clone)]
pub struct StuScorer {
/// Observation tables keyed by pattern.
tables: HashMap<String, PropertyTable>,
/// How per-property frequencies are combined into a single score.
aggregation: StuAggregation,
/// Whether property pairs are counted during observation and used to damp
/// strongly co-occurring properties during scoring.
pmi_correction: bool,
}
impl Default for StuScorer {
    /// An empty scorer that uses the aggregation marked `#[default]` on
    /// [`StuAggregation`] and has PMI correction turned off.
    fn default() -> Self {
        let aggregation = StuAggregation::default();
        Self {
            aggregation,
            pmi_correction: false,
            tables: HashMap::default(),
        }
    }
}
impl StuScorer {
pub fn new() -> Self {
Self::default()
}
pub fn with_aggregation(mut self, aggregation: StuAggregation) -> Self {
self.aggregation = aggregation;
self
}
pub fn with_pmi_correction(mut self) -> Self {
self.pmi_correction = true;
self
}
pub fn observe_one(&mut self, pattern: &str, properties: &[impl AsRef<str>]) {
let table = self.tables.entry(pattern.to_string()).or_default();
table.total_matches += 1;
let mut seen = std::collections::HashSet::new();
let mut unique: Vec<String> = Vec::new();
for prop in properties {
if seen.insert(prop.as_ref().to_string()) {
let s = prop.as_ref().to_string();
*table.property_counts.entry(s.clone()).or_insert(0) += 1;
unique.push(s);
}
}
if self.pmi_correction {
for i in 0..unique.len() {
for j in (i + 1)..unique.len() {
let pair = if unique[i] < unique[j] {
(unique[i].clone(), unique[j].clone())
} else {
(unique[j].clone(), unique[i].clone())
};
*table.pair_counts.entry(pair).or_insert(0) += 1;
}
}
}
}
pub fn observe_batch(&mut self, observations: &[(&str, &[String])]) {
for (pattern, props) in observations {
self.observe_one(pattern, props);
}
}
pub fn property_frequency(&self, pattern: &str, property: &str) -> Option<f64> {
let table = self.tables.get(pattern)?;
let count = table.property_counts.get(property).copied().unwrap_or(0);
let vocab_size = table.property_counts.len() as f64;
Some((count as f64 + 1.0) / (table.total_matches as f64 + vocab_size))
}
#[allow(clippy::type_complexity)]
pub fn score<
N: Debug + Clone + PartialEq,
V: Debug + Clone + PartialEq,
T: Debug + Clone + PartialEq,
>(
&self,
matches_with_props: &[(Match<N, V, T>, Vec<String>)],
) -> Vec<StuScoredMatch<N, V, T>> {
matches_with_props
.iter()
.map(|(m, props)| {
let table = self.tables.get(&m.pattern);
if props.is_empty() || table.is_none() {
return StuScoredMatch {
pattern: m.pattern.clone(),
pattern_idx: m.pattern_idx,
bindings: m.bindings.clone(),
intervals: m.intervals.clone(),
metadata: m.metadata.clone(),
property_frequencies: Vec::new(),
stu_score: 1.0,
};
}
let table = table.unwrap();
let vocab_size = table.property_counts.len() as f64;
let unique_props: Vec<&String> = {
let mut seen = std::collections::HashSet::new();
props.iter().filter(|p| seen.insert(p.as_str())).collect()
};
let mut prop_freqs: Vec<(String, f64)> = unique_props
.iter()
.map(|prop| {
let count = table
.property_counts
.get(prop.as_str())
.copied()
.unwrap_or(0);
let freq = (count as f64 + 1.0) / (table.total_matches as f64 + vocab_size);
(prop.to_string(), freq)
})
.collect();
if self.pmi_correction && prop_freqs.len() >= 2 {
let pmi_threshold = 1.0; let n = table.total_matches as f64;
for i in 0..prop_freqs.len() {
for j in (i + 1)..prop_freqs.len() {
let (a, b) = if prop_freqs[i].0 < prop_freqs[j].0 {
(prop_freqs[i].0.clone(), prop_freqs[j].0.clone())
} else {
(prop_freqs[j].0.clone(), prop_freqs[i].0.clone())
};
let pair_count = table.pair_counts.get(&(a, b)).copied().unwrap_or(0);
if pair_count == 0 {
continue;
}
let p_ab = pair_count as f64 / n;
let c_i = table
.property_counts
.get(&prop_freqs[i].0)
.copied()
.unwrap_or(0) as f64;
let c_j = table
.property_counts
.get(&prop_freqs[j].0)
.copied()
.unwrap_or(0) as f64;
let p_i_raw = c_i / n;
let p_j_raw = c_j / n;
if p_i_raw == 0.0 || p_j_raw == 0.0 {
continue;
}
let pmi = (p_ab / (p_i_raw * p_j_raw)).log2();
if pmi > pmi_threshold {
if prop_freqs[i].1 > prop_freqs[j].1 {
prop_freqs[i].1 = (p_ab / p_j_raw).min(1.0);
} else {
prop_freqs[j].1 = (p_ab / p_i_raw).min(1.0);
}
}
}
}
}
prop_freqs
.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let k = prop_freqs.len() as f64;
let stu_score = match self.aggregation {
StuAggregation::ArithmeticMean => {
prop_freqs.iter().map(|(_, f)| f).sum::<f64>() / k
}
StuAggregation::TfIdf => prop_freqs.iter().map(|(_, f)| -f.log2()).sum::<f64>(),
StuAggregation::GeometricMean => {
(prop_freqs.iter().map(|(_, f)| f.ln()).sum::<f64>() / k).exp()
}
StuAggregation::Min => prop_freqs
.iter()
.map(|(_, f)| *f)
.fold(f64::INFINITY, f64::min),
};
let confidence = 1.0 - 1.0 / (table.total_matches as f64 + 1.0);
let stu_score = if matches!(self.aggregation, StuAggregation::TfIdf) {
stu_score * confidence
} else {
1.0 - (1.0 - stu_score) * confidence
};
StuScoredMatch {
pattern: m.pattern.clone(),
pattern_idx: m.pattern_idx,
bindings: m.bindings.clone(),
intervals: m.intervals.clone(),
metadata: m.metadata.clone(),
property_frequencies: prop_freqs,
stu_score,
}
})
.collect()
}
pub fn match_count(&self, pattern: &str) -> u64 {
self.tables
.get(pattern)
.map(|t| t.total_matches)
.unwrap_or(0)
}
pub fn vocabulary_size(&self, pattern: &str) -> usize {
self.tables
.get(pattern)
.map(|t| t.property_counts.len())
.unwrap_or(0)
}
pub fn pmi_for(&self, pattern: &str, pi: &str, pj: &str) -> Option<f64> {
let table = self.tables.get(pattern)?;
if table.total_matches == 0 {
return None;
}
let (a, b) = if pi < pj { (pi, pj) } else { (pj, pi) };
let pair_count = table
.pair_counts
.get(&(a.to_string(), b.to_string()))
.copied()
.unwrap_or(0);
let p_ab = pair_count as f64 / table.total_matches as f64;
let p_a =
table.property_counts.get(a).copied().unwrap_or(0) as f64 / table.total_matches as f64;
let p_b =
table.property_counts.get(b).copied().unwrap_or(0) as f64 / table.total_matches as f64;
if p_a == 0.0 || p_b == 0.0 || p_ab == 0.0 {
return Some(0.0);
}
Some((p_ab / (p_a * p_b)).log2())
}
pub fn reset(&mut self) {
self.tables.clear();
}
}
// Unit tests for `StuScorer`: smoothing, aggregation modes, cold-start attenuation and
// PMI pair counting.
#[cfg(test)]
mod tests {
use super::*;
use crate::engine::Match;
// Builds a `Match` for pattern `name` with no bindings, intervals or metadata.
fn dummy_match(name: &str) -> Match<String, String, i64> {
Match {
pattern: name.to_string(),
pattern_idx: None,
bindings: HashMap::new(),
intervals: HashMap::new(),
metadata: HashMap::new(),
}
}
// 2 of 10 matches are "ambitious" and 8 are "loyal". Under the default
// (arithmetic-mean) aggregation, the rarer trait must get the lower score.
#[test]
fn stu_rare_properties_score_lower() {
let mut stu = StuScorer::new();
for i in 0..10 {
let props = if i < 2 {
vec!["actor_trait=ambitious".to_string()]
} else {
vec!["actor_trait=loyal".to_string()]
};
stu.observe_one("betrayal", &props);
}
let m_ambitious = (
dummy_match("betrayal"),
vec!["actor_trait=ambitious".to_string()],
);
let m_loyal = (
dummy_match("betrayal"),
vec!["actor_trait=loyal".to_string()],
);
let scored = stu.score(&[m_ambitious, m_loyal]);
assert!(
scored[0].stu_score < scored[1].stu_score,
"ambitious ({:.3}) should score lower (more surprising) than loyal ({:.3})",
scored[0].stu_score,
scored[1].stu_score
);
}
// 10 matches and a vocabulary of 1: an unseen property gets (0 + 1) / (10 + 1) ~= 0.09.
#[test]
fn stu_laplace_smoothing_for_novel_property() {
let mut stu = StuScorer::new();
for _ in 0..10 {
stu.observe_one("test", &["trait=brave"]);
}
let freq = stu.property_frequency("test", "trait=cowardly");
assert!(freq.is_some());
let f = freq.unwrap();
assert!(
f > 0.0,
"novel property should have non-zero frequency: {}",
f
);
assert!(f < 0.2, "novel property should have low frequency: {}", f);
}
// A match without properties is not scored; the default aggregation reports 1.0.
#[test]
fn stu_empty_properties_get_default_score() {
let stu = StuScorer::new();
let m = (dummy_match("test"), vec![]);
let scored = stu.score(&[m]);
assert_eq!(
scored[0].stu_score, 1.0,
"empty properties = maximally unsurprising"
);
}
// A pattern with no observation table is not scored; the default aggregation reports 1.0.
#[test]
fn stu_unobserved_pattern_gets_default_score() {
let stu = StuScorer::new();
let m = (dummy_match("unknown"), vec!["some_prop".to_string()]);
let scored = stu.score(&[m]);
assert_eq!(scored[0].stu_score, 1.0, "unobserved pattern = no data");
}
// All 10 matches carry "common=yes" and the first 3 also carry "rare=yes".
// `property_frequencies` must list the rarer property first.
#[test]
fn stu_property_frequencies_sorted_ascending() {
let mut stu = StuScorer::new();
for i in 0..10 {
let mut props = vec!["common=yes".to_string()]; if i < 3 {
props.push("rare=yes".to_string()); }
stu.observe_one("test", &props);
}
let m = (
dummy_match("test"),
vec!["rare=yes".to_string(), "common=yes".to_string()],
);
let scored = stu.score(&[m]);
let pf = &scored[0].property_frequencies;
assert_eq!(pf.len(), 2);
assert!(
pf[0].1 <= pf[1].1,
"properties should be sorted ascending: {:?}",
pf
);
assert!(pf[0].0.contains("rare"), "rarest property should be first");
}
// Batch observation records one match per entry, independently per pattern.
#[test]
fn stu_observe_batch() {
let mut stu = StuScorer::new();
let p1_props = vec!["a".to_string(), "b".to_string()];
let p2_props = vec!["c".to_string()];
let batch: Vec<(&str, &[String])> = vec![("p1", &p1_props), ("p2", &p2_props)];
stu.observe_batch(&batch);
assert_eq!(stu.match_count("p1"), 1);
assert_eq!(stu.match_count("p2"), 1);
assert_eq!(stu.vocabulary_size("p1"), 2);
}
// One match containing "dup" three times counts it once: (1 + 1) / (1 + 1) = 1.0.
#[test]
fn stu_deduplicates_properties_per_match() {
let mut stu = StuScorer::new();
stu.observe_one("test", &["dup", "dup", "dup"]);
assert_eq!(stu.match_count("test"), 1);
let freq = stu.property_frequency("test", "dup").unwrap();
assert!(
(freq - 1.0).abs() < 0.01,
"duplicated property should count once: {}",
freq
);
}
#[test]
fn stu_reset_clears_all() {
let mut stu = StuScorer::new();
stu.observe_one("test", &["a", "b"]);
assert_eq!(stu.match_count("test"), 1);
stu.reset();
assert_eq!(stu.match_count("test"), 0);
assert_eq!(stu.vocabulary_size("test"), 0);
}
// Helper: observes 2 "ambitious" and 8 "loyal" matches with aggregation `agg`.
// Returns (rare score, common score).
fn score_rare_vs_common(agg: StuAggregation) -> (f64, f64) {
let mut stu = StuScorer::new().with_aggregation(agg);
for i in 0..10 {
let props = if i < 2 {
vec!["trait=ambitious".to_string()]
} else {
vec!["trait=loyal".to_string()]
};
stu.observe_one("test", &props);
}
let scored = stu.score(&[
(dummy_match("test"), vec!["trait=ambitious".to_string()]),
(dummy_match("test"), vec!["trait=loyal".to_string()]),
]);
(scored[0].stu_score, scored[1].stu_score)
}
#[test]
fn stu_default_is_arithmetic_mean() {
let (rare_default, common_default) = score_rare_vs_common(StuAggregation::default());
let (rare_explicit, common_explicit) = score_rare_vs_common(StuAggregation::ArithmeticMean);
assert_eq!(rare_default, rare_explicit);
assert_eq!(common_default, common_explicit);
assert!(rare_default < common_default);
}
// TfIdf sums -log2(freq), so the direction is inverted relative to the other modes.
#[test]
fn stu_tfidf_higher_is_more_surprising() {
let (rare, common) = score_rare_vs_common(StuAggregation::TfIdf);
assert!(
rare > common,
"TfIdf: rare ({:.3}) should score HIGHER than common ({:.3})",
rare,
common
);
}
#[test]
fn stu_geometric_mean_rare_scores_lower() {
let (rare, common) = score_rare_vs_common(StuAggregation::GeometricMean);
assert!(
rare < common,
"GeometricMean: rare ({:.3}) should score lower than common ({:.3})",
rare,
common
);
}
// Min picks the rarest frequency. The expected value applies the non-TfIdf
// attenuation 1 - (1 - f) * n / (n + 1) to that frequency.
#[test]
fn stu_min_uses_rarest_property() {
let mut stu = StuScorer::new().with_aggregation(StuAggregation::Min);
for i in 0..10 {
let mut props = vec!["common=yes".to_string()];
if i == 0 {
props.push("rare=yes".to_string());
}
stu.observe_one("test", &props);
}
let scored = stu.score(&[(
dummy_match("test"),
vec!["common=yes".to_string(), "rare=yes".to_string()],
)]);
let score = scored[0].stu_score;
let rare_freq = stu.property_frequency("test", "rare=yes").unwrap();
let confidence = 1.0 - 1.0 / (stu.match_count("test") as f64 + 1.0);
let expected = 1.0 - (1.0 - rare_freq) * confidence;
assert!(
(score - expected).abs() < 1e-10,
"Min score ({:.4}) should equal lerped rare freq ({:.4})",
score,
expected
);
}
// With 3 observations (confidence 3/4), the score sits above the raw frequency.
// With 100 observations (confidence 100/101), it is within 0.02 of the raw frequency.
#[test]
fn stu_cold_start_attenuates_toward_unsurprising() {
let mut stu = StuScorer::new();
stu.observe_one("test", &["common"]);
stu.observe_one("test", &["common"]);
stu.observe_one("test", &["rare"]);
let raw_freq = stu.property_frequency("test", "rare").unwrap();
assert!(raw_freq < 1.0, "rare should have freq < 1.0: {}", raw_freq);
let scored = stu.score(&[(dummy_match("test"), vec!["rare".to_string()])]);
assert!(
scored[0].stu_score > raw_freq,
"cold start should push toward unsurprising: score={:.3}, raw={:.3}",
scored[0].stu_score,
raw_freq
);
let mut stu2 = StuScorer::new();
for _ in 0..50 {
stu2.observe_one("test", &["common"]);
}
for _ in 0..50 {
stu2.observe_one("test", &["rare"]);
}
let scored2 = stu2.score(&[(dummy_match("test"), vec!["rare".to_string()])]);
let raw_freq2 = stu2.property_frequency("test", "rare").unwrap();
assert!(
(scored2[0].stu_score - raw_freq2).abs() < 0.02,
"high-observation score ({:.4}) should be close to raw freq ({:.4})",
scored2[0].stu_score,
raw_freq2
);
}
// TfIdf attenuation scales the surprisal by the confidence; it must stay positive.
#[test]
fn stu_cold_start_tfidf_attenuates_toward_zero() {
let mut stu = StuScorer::new().with_aggregation(StuAggregation::TfIdf);
for i in 0..5 {
let props = if i == 0 {
vec!["rare=yes".to_string()]
} else {
vec!["common=yes".to_string()]
};
stu.observe_one("test", &props);
}
let scored_tfidf = stu.score(&[(dummy_match("test"), vec!["rare=yes".to_string()])]);
assert!(scored_tfidf[0].stu_score > 0.0);
}
#[test]
fn stu_with_aggregation_builder() {
let scorer = StuScorer::new().with_aggregation(StuAggregation::TfIdf);
assert_eq!(scorer.aggregation, StuAggregation::TfIdf);
let default = StuScorer::new();
assert_eq!(default.aggregation, StuAggregation::ArithmeticMean);
}
// rebels+hideout co-occur in 2 of 4 matches, giving PMI = log2(0.5 / 0.25) = 1 > 0.
// rebels+castle never co-occur, and `pmi_for` reports 0.0 by convention.
#[test]
fn pmi_pair_counting() {
let mut stu = StuScorer::new().with_pmi_correction();
stu.observe_one("test", &["rebels", "hideout"]);
stu.observe_one("test", &["rebels", "hideout"]);
stu.observe_one("test", &["crown", "castle"]);
stu.observe_one("test", &["crown", "castle"]);
let pmi_rh = stu.pmi_for("test", "rebels", "hideout").unwrap();
assert!(
pmi_rh > 0.0,
"rebels+hideout should have positive PMI: {:.3}",
pmi_rh
);
let pmi_rc = stu.pmi_for("test", "rebels", "castle").unwrap();
assert_eq!(pmi_rc, 0.0, "rebels+castle should have PMI=0");
}
// Identical observations, with and without PMI correction. The perfectly correlated
// rebels/hideout pair (PMI = log2 5 > 1) must produce a different score.
#[test]
fn pmi_correction_reduces_double_counting() {
let mut no_pmi = StuScorer::new();
let mut with_pmi = StuScorer::new().with_pmi_correction();
for _ in 0..20 {
no_pmi.observe_one("test", &["faction=rebels", "location=hideout"]);
with_pmi.observe_one("test", &["faction=rebels", "location=hideout"]);
}
for _ in 0..80 {
no_pmi.observe_one("test", &["faction=crown", "location=castle"]);
with_pmi.observe_one("test", &["faction=crown", "location=castle"]);
}
let props = vec!["faction=rebels".to_string(), "location=hideout".to_string()];
let scored_no = no_pmi.score(&[(dummy_match("test"), props.clone())]);
let scored_with = with_pmi.score(&[(dummy_match("test"), props)]);
assert!(
(scored_no[0].stu_score - scored_with[0].stu_score).abs() > 0.001,
"PMI correction should change the score: no_pmi={:.4}, with_pmi={:.4}",
scored_no[0].stu_score,
scored_with[0].stu_score
);
}
// Without PMI correction no pair counts are recorded, so `pmi_for` has nothing to report.
#[test]
fn pmi_no_effect_when_disabled() {
let mut stu = StuScorer::new(); for _ in 0..20 {
stu.observe_one("test", &["a", "b"]);
}
assert!(
stu.pmi_for("test", "a", "b").is_none() || stu.pmi_for("test", "a", "b") == Some(0.0),
"PMI should not be available when disabled"
);
}
// Pairs are stored under a canonical key, so argument order does not matter.
#[test]
fn pmi_canonical_order() {
let mut stu = StuScorer::new().with_pmi_correction();
stu.observe_one("test", &["b", "a"]); let pmi = stu.pmi_for("test", "a", "b");
assert!(pmi.is_some());
let pmi_rev = stu.pmi_for("test", "b", "a");
assert_eq!(pmi, pmi_rev);
}
}