dsfb_database/residual/
workload_phase.rs1use super::{ResidualClass, ResidualSample, ResidualStream};
16use std::collections::HashMap;
17
/// Jensen–Shannon divergence (base-2 logarithm) between two count histograms.
///
/// Each map is normalised by its own total, turning the counts into probability
/// distributions over the union of both key sets. A key missing from one map
/// counts as zero there. Because the logarithm is base 2, the result lies in
/// `[0, 1]`:
/// - `0.0` for identical distributions;
/// - `1.0` for distributions with disjoint support.
///
/// If either histogram is empty or all of its counts are zero, `0.0` is
/// returned.
///
/// Totals are accumulated in `f64` rather than `u64`, so very large counts
/// cannot overflow the sum. A `u64` overflow would panic in debug builds and
/// wrap silently in release builds.
pub fn js_divergence(p: &HashMap<String, u64>, q: &HashMap<String, u64>) -> f64 {
    let total = |h: &HashMap<String, u64>| -> f64 { h.values().map(|&c| c as f64).sum() };
    let sp = total(p);
    let sq = total(q);
    if sp == 0.0 || sq == 0.0 {
        return 0.0;
    }

    // Visit the union of keys in sorted order. The floating-point summation
    // order, and hence the rounded result, then does not depend on HashMap
    // iteration order.
    let mut keys: Vec<&String> = p.keys().chain(q.keys()).collect();
    keys.sort_unstable();
    keys.dedup();

    let mut acc = 0.0;
    for k in keys {
        let pi = p.get(k).copied().unwrap_or(0) as f64 / sp;
        let qi = q.get(k).copied().unwrap_or(0) as f64 / sq;
        let mi = 0.5 * (pi + qi);
        // Terms with zero probability contribute 0 (0 * log 0 := 0) and are
        // skipped. When a term is positive, mi >= term / 2 > 0, so the ratio
        // is finite.
        if pi > 0.0 {
            acc += 0.5 * pi * (pi / mi).log2();
        }
        if qi > 0.0 {
            acc += 0.5 * qi * (qi / mi).log2();
        }
    }
    // Rounding can push the sum marginally outside the theoretical range.
    acc.clamp(0.0, 1.0)
}
44
45pub fn push_jsd(stream: &mut ResidualStream, t: f64, bucket_id: &str, jsd: f64) {
48 stream.push(ResidualSample::new(t, ResidualClass::WorkloadPhase, jsd).with_channel(bucket_id));
49}
50
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a count histogram from `(key, count)` pairs.
    fn hist(pairs: &[(&str, u64)]) -> HashMap<String, u64> {
        pairs.iter().map(|&(k, c)| (k.to_string(), c)).collect()
    }

    #[test]
    fn js_zero_when_equal() {
        let p = hist(&[("a", 3), ("b", 7)]);
        let d = js_divergence(&p, &p);
        assert!(d.abs() < 1e-12);
    }

    #[test]
    fn js_positive_when_disjoint() {
        let p = hist(&[("a", 10)]);
        let q = hist(&[("b", 10)]);
        assert!(js_divergence(&p, &q) > 0.99);
    }
}