use std::collections::{HashMap, HashSet};
#[derive(Debug, Clone, Default)]
pub struct PivotDetector {
current_counts: HashMap<String, u64>,
prev_distribution: HashMap<String, f64>,
current_total: u64,
history: Vec<f64>,
}
impl PivotDetector {
pub fn new() -> Self {
Self::default()
}
pub fn push(&mut self, event_type: &str) {
*self
.current_counts
.entry(event_type.to_string())
.or_insert(0) += 1;
self.current_total += 1;
}
pub fn end_tick(&mut self) -> f64 {
if self.current_total == 0 {
self.history.push(0.0);
return 0.0;
}
let current_dist: HashMap<String, f64> = self
.current_counts
.iter()
.map(|(k, &v)| (k.clone(), v as f64 / self.current_total as f64))
.collect();
let jsd = if self.prev_distribution.is_empty() {
0.0 } else {
jensen_shannon_divergence(&self.prev_distribution, ¤t_dist)
};
self.history.push(jsd);
self.prev_distribution = current_dist;
self.current_counts.clear();
self.current_total = 0;
jsd
}
pub fn last_pivot(&self) -> f64 {
self.history.last().copied().unwrap_or(0.0)
}
pub fn average_pivot(&self, window: usize) -> f64 {
if self.history.is_empty() || window == 0 {
return 0.0;
}
let start = self.history.len().saturating_sub(window);
let slice = &self.history[start..];
slice.iter().sum::<f64>() / slice.len() as f64
}
pub fn history(&self) -> &[f64] {
&self.history
}
pub fn reset(&mut self) {
self.current_counts.clear();
self.prev_distribution.clear();
self.current_total = 0;
self.history.clear();
}
}
/// Base-2 Jensen–Shannon divergence between two discrete distributions.
///
/// Each map goes from category to probability. A category absent from one map
/// counts as probability `0.0` there. Inputs are expected to be normalised
/// (each summing to ~1); they are not renormalised here.
///
/// The result is `0.0` for identical distributions and `1.0` for distributions
/// with disjoint support. It is clamped to `[0, 1]` so that floating-point
/// rounding or slightly unnormalised input cannot push it outside that range.
/// A `NaN` sum (e.g. from infinite inputs) is reported as `0.0`.
fn jensen_shannon_divergence(p: &HashMap<String, f64>, q: &HashMap<String, f64>) -> f64 {
    /// One half-weighted KL summand, `0.5 * x * log2(x / m)`.
    ///
    /// Non-positive probabilities contribute nothing (the `x * log x -> 0`
    /// limit), which also avoids evaluating `log2(0) = -inf`.
    fn half_kl_term(x: f64, m: f64) -> f64 {
        if x > 0.0 && m > 0.0 {
            0.5 * x * (x / m).log2()
        } else {
            0.0
        }
    }

    let all_keys: HashSet<&String> = p.keys().chain(q.keys()).collect();
    let mut jsd = 0.0_f64;
    for key in all_keys {
        let p_val = p.get(key).copied().unwrap_or(0.0);
        let q_val = q.get(key).copied().unwrap_or(0.0);
        // Mixture distribution M = (P + Q) / 2.
        let m_val = (p_val + q_val) / 2.0;
        jsd += half_kl_term(p_val, m_val);
        jsd += half_kl_term(q_val, m_val);
    }

    // `f64::clamp` would propagate NaN; report it as 0.0 instead.
    if jsd.is_nan() {
        0.0
    } else {
        jsd.clamp(0.0, 1.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes every event in `events`, closes the tick, and returns its pivot score.
    fn tick(detector: &mut PivotDetector, events: &[&str]) -> f64 {
        for event in events {
            detector.push(event);
        }
        detector.end_tick()
    }

    #[test]
    fn identical_distributions_zero_jsd() {
        let mut p = PivotDetector::new();
        tick(&mut p, &["trade", "talk"]);
        let jsd = tick(&mut p, &["trade", "talk"]);
        assert!(
            jsd.abs() < 0.001,
            "identical distributions should have JSD ≈ 0, got {}",
            jsd
        );
    }

    #[test]
    fn completely_different_distributions_high_jsd() {
        let mut p = PivotDetector::new();
        tick(&mut p, &["peace", "peace"]);
        let jsd = tick(&mut p, &["war", "war"]);
        assert!(
            jsd > 0.9,
            "completely different distributions should have high JSD, got {}",
            jsd
        );
    }

    #[test]
    fn first_tick_returns_zero() {
        let mut p = PivotDetector::new();
        assert_eq!(tick(&mut p, &["test"]), 0.0);
    }

    #[test]
    fn empty_tick_returns_zero() {
        let mut p = PivotDetector::new();
        assert_eq!(p.end_tick(), 0.0);
    }

    // An empty tick must not replace the reference distribution: the next
    // non-empty tick is compared against the last non-empty one.
    #[test]
    fn empty_tick_keeps_previous_distribution() {
        let mut p = PivotDetector::new();
        tick(&mut p, &["a"]);
        assert_eq!(p.end_tick(), 0.0);
        let jsd = tick(&mut p, &["b"]);
        assert!(
            jsd > 0.99,
            "expected comparison against the tick before the empty one, got {}",
            jsd
        );
    }

    #[test]
    fn partial_overlap_moderate_jsd() {
        let mut p = PivotDetector::new();
        tick(&mut p, &["trade", "trade", "talk"]);
        let jsd = tick(&mut p, &["trade", "attack", "attack"]);
        assert!(
            jsd > 0.1 && jsd < 0.9,
            "partial overlap should give moderate JSD, got {}",
            jsd
        );
    }

    #[test]
    fn average_pivot_over_window() {
        let mut p = PivotDetector::new();
        for _ in 0..5 {
            tick(&mut p, &["same"]);
        }
        assert!(
            p.average_pivot(5) < 0.01,
            "stable events should have low average pivot"
        );
    }

    #[test]
    fn average_pivot_zero_window_returns_zero() {
        let mut p = PivotDetector::new();
        tick(&mut p, &["test"]);
        let avg = p.average_pivot(0);
        assert_eq!(avg, 0.0);
    }

    #[test]
    fn average_pivot_window_larger_than_history_uses_all_ticks() {
        let mut p = PivotDetector::new();
        tick(&mut p, &["a"]); // first tick: 0.0
        tick(&mut p, &["b"]); // disjoint from previous: 1.0
        assert!((p.average_pivot(10) - 0.5).abs() < 1e-9);
        assert!((p.average_pivot(1) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn history_and_last_pivot_track_ticks() {
        let mut p = PivotDetector::new();
        assert!(p.history().is_empty());
        assert_eq!(p.last_pivot(), 0.0);
        tick(&mut p, &["a"]);
        tick(&mut p, &["b"]);
        assert_eq!(p.history().len(), 2);
        assert_eq!(p.history()[0], 0.0);
        assert_eq!(p.last_pivot(), p.history()[1]);
        assert!(p.last_pivot() > 0.99);
    }

    #[test]
    fn reset_discards_history_reference_and_pending_events() {
        let mut p = PivotDetector::new();
        tick(&mut p, &["a"]);
        p.push("b"); // pending, never ticked
        p.reset();
        assert!(p.history().is_empty());
        assert_eq!(p.last_pivot(), 0.0);
        // No reference distribution after reset, so this behaves as a first tick.
        assert_eq!(tick(&mut p, &["c"]), 0.0);
        // Had the pending "b" survived the reset, the previous tick would have
        // been {b, c} and this one would score well above zero.
        assert!(tick(&mut p, &["c"]).abs() < 1e-9);
    }

    #[test]
    fn jsd_is_bounded_zero_one() {
        let p: HashMap<String, f64> = [("a".into(), 1.0)].into();
        let q: HashMap<String, f64> = [("b".into(), 1.0)].into();
        let jsd = jensen_shannon_divergence(&p, &q);
        assert!(
            jsd >= 0.0 && jsd <= 1.0,
            "JSD should be in [0, 1], got {}",
            jsd
        );
    }

    #[test]
    fn jsd_is_symmetric() {
        let p: HashMap<String, f64> =
            [("trade".into(), 2.0 / 3.0), ("talk".into(), 1.0 / 3.0)].into();
        let q: HashMap<String, f64> =
            [("trade".into(), 1.0 / 3.0), ("attack".into(), 2.0 / 3.0)].into();
        let pq = jensen_shannon_divergence(&p, &q);
        let qp = jensen_shannon_divergence(&q, &p);
        assert!(
            (pq - qp).abs() < 1e-12,
            "JSD should be symmetric, got {} vs {}",
            pq,
            qp
        );
    }
}