use super::{ResidualClass, ResidualSample, ResidualStream};
use std::collections::HashMap;
/// Jensen–Shannon divergence (base-2 logarithm) between two count histograms.
///
/// Each map is normalised by its own total, so `p` and `q` may come from
/// different sample sizes. A key missing from one map counts as zero there.
/// Because logs are base 2, the result lies in `[0.0, 1.0]`:
/// - `0.0` for identical distributions;
/// - `1.0` for distributions with disjoint support.
///
/// Returns `0.0` if either histogram's total count is zero, including an empty
/// map. Callers cannot tell this case apart from "identical distributions".
pub fn js_divergence(p: &HashMap<String, u64>, q: &HashMap<String, u64>) -> f64 {
    // Total in u128: summing many large u64 counts in u64 would overflow
    // (panic in debug builds, silent wrap-around in release builds).
    let total = |m: &HashMap<String, u64>| -> u128 { m.values().map(|&c| u128::from(c)).sum() };
    let sp = total(p);
    let sq = total(q);
    if sp == 0 || sq == 0 {
        return 0.0;
    }
    let (sp, sq) = (sp as f64, sq as f64);

    // Visit the union of keys in sorted order. HashMap iteration order is
    // randomised, and a fixed order keeps the floating-point sum reproducible.
    let mut keys: Vec<&String> = p.keys().chain(q.keys()).collect();
    keys.sort_unstable();
    keys.dedup();

    let mut acc = 0.0_f64;
    for k in keys {
        let pi = p.get(k).copied().unwrap_or(0) as f64 / sp;
        let qi = q.get(k).copied().unwrap_or(0) as f64 / sq;
        let mi = 0.5 * (pi + qi);
        // By convention 0 * log(0 / m) = 0, so zero-probability terms are skipped.
        // When a term is present its probability is > 0, hence mi > 0 and the log is finite.
        if pi > 0.0 {
            acc += 0.5 * pi * (pi / mi).log2();
        }
        if qi > 0.0 {
            acc += 0.5 * qi * (qi / mi).log2();
        }
    }
    // Rounding can push the sum marginally outside the theoretical [0, 1] range.
    acc.clamp(0.0, 1.0)
}
/// Appends a Jensen–Shannon divergence value to `stream`.
///
/// The sample is built with `ResidualSample::new(t, ResidualClass::WorkloadPhase, jsd)`
/// and tagged with `bucket_id` through `with_channel` before being pushed.
/// `t` is presumably the sample's timestamp (TODO: confirm against `ResidualSample`).
pub fn push_jsd(stream: &mut ResidualStream, t: f64, bucket_id: &str, jsd: f64) {
    let sample = ResidualSample::new(t, ResidualClass::WorkloadPhase, jsd);
    stream.push(sample.with_channel(bucket_id));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a count histogram from `(key, count)` pairs.
    fn hist(pairs: &[(&str, u64)]) -> HashMap<String, u64> {
        pairs.iter().map(|&(k, c)| (k.to_string(), c)).collect()
    }

    /// A distribution has (numerically) zero divergence from itself.
    #[test]
    fn js_zero_when_equal() {
        let p = hist(&[("a", 3), ("b", 7)]);
        let d = js_divergence(&p, &p);
        assert!(d.abs() < 1e-12);
    }

    /// Distributions sharing no keys sit at (essentially) the upper bound of 1.
    #[test]
    fn js_positive_when_disjoint() {
        let p = hist(&[("a", 10)]);
        let q = hist(&[("b", 10)]);
        assert!(js_divergence(&p, &q) > 0.99);
    }
}