/// Scores how "surprising" a similarity match is relative to its base rate.
///
/// Returns `ln(cosine_sim / base_rate)` clamped to `[-2.0, 5.0]`. This is a
/// pointwise-mutual-information-style log ratio:
///
/// * a high similarity against a rare item (small `base_rate`) scores high;
/// * the same similarity against a common item scores lower;
/// * similarity equal to the base rate scores `0.0`.
///
/// # Arguments
///
/// * `cosine_sim` - similarity score. It is floored at `1e-10` before the
///   logarithm, so zero, negative, or NaN similarities yield the lower bound
///   `-2.0` instead of `-inf`/NaN.
/// * `base_rate` - presumably the prior frequency of the matched item (TODO:
///   confirm with callers). It is floored at `1e-6` to avoid dividing by zero;
///   NaN is also replaced by that floor (`f32::max` returns the non-NaN operand).
/// * `_namespace_size` - currently unused and kept only for API compatibility.
///   The score does not include a `p(q) = 1 / namespace_size` term.
///
/// # Returns
///
/// The clamped log ratio, in `[-2.0, 5.0]` for all finite inputs.
#[must_use]
pub fn pointwise_mutual_information(
    cosine_sim: f32,
    base_rate: f32,
    _namespace_size: usize,
) -> f32 {
    // Floors keep the ratio strictly positive so `ln` never sees 0 or a negative value.
    const MIN_BASE_RATE: f32 = 1e-6;
    const MIN_SIMILARITY: f32 = 1e-10;
    // Output range of the score.
    const MIN_PMI: f32 = -2.0;
    const MAX_PMI: f32 = 5.0;

    let p_m = base_rate.max(MIN_BASE_RATE);
    (cosine_sim.max(MIN_SIMILARITY) / p_m)
        .ln()
        .clamp(MIN_PMI, MAX_PMI)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_high_similarity_low_frequency_is_surprising() {
        let pmi = pointwise_mutual_information(0.9, 0.01, 1000);
        assert!(
            pmi > 0.0,
            "Expected PMI > 0.0 for high similarity / low frequency, got {pmi}"
        );
    }

    #[test]
    fn test_low_similarity_is_not_surprising() {
        let pmi = pointwise_mutual_information(0.1, 0.5, 1000);
        assert!(
            pmi < 0.5,
            "Expected PMI < 0.5 for low similarity / high frequency, got {pmi}"
        );
    }

    #[test]
    fn test_common_memory_less_surprising() {
        let pmi_rare = pointwise_mutual_information(0.8, 0.01, 1000);
        let pmi_common = pointwise_mutual_information(0.8, 0.5, 1000);
        assert!(
            pmi_rare > pmi_common,
            "Expected rare memory (PMI={pmi_rare}) to be more surprising than common memory (PMI={pmi_common})"
        );
    }

    #[test]
    fn test_score_is_clamped_to_upper_bound() {
        // A zero base rate is floored to 1e-6, so ln(1.0 / 1e-6) ~ 13.8 exceeds the cap.
        let pmi = pointwise_mutual_information(1.0, 0.0, 1000);
        assert_eq!(pmi, 5.0, "Expected upper clamp 5.0, got {pmi}");
    }

    #[test]
    fn test_degenerate_similarity_hits_lower_bound() {
        // Zero, negative, and NaN similarities are floored to 1e-10 before the log.
        for sim in [0.0_f32, -0.5, f32::NAN] {
            let pmi = pointwise_mutual_information(sim, 0.5, 1000);
            assert_eq!(pmi, -2.0, "Expected lower clamp -2.0 for similarity {sim}, got {pmi}");
        }
    }

    #[test]
    fn test_similarity_equal_to_base_rate_is_neutral() {
        let pmi = pointwise_mutual_information(0.5, 0.5, 1000);
        assert_eq!(pmi, 0.0, "Expected neutral PMI 0.0, got {pmi}");
    }
}