Skip to main content

ad_ess/
utils.rs

1use std::ops::{Add, Sub};
2
/// Kullback–Leibler divergence `D(p_1 || p_2)` in bits.
///
/// Computes `sum_i p_1[i] * log2(p_1[i] / p_2[i])` over index pairs of the two
/// slices. Entries where `p_1[i] == 0` contribute nothing (the standard
/// convention `0 * log(0 / q) = 0`); evaluating them literally would give
/// `0 * -inf = NaN` and poison the whole sum. If `p_2[i] == 0` while
/// `p_1[i] > 0`, the result is `+inf`.
///
/// The pairs come from `zip`, so if the slices differ in length the extra
/// trailing entries of the longer one are ignored.
pub fn kl_divergence(p_1: &[f32], p_2: &[f32]) -> f32 {
    p_1.iter()
        .zip(p_2)
        .map(|(&pi_1, &pi_2)| {
            if pi_1 == 0.0 {
                0.0
            } else {
                pi_1 * (pi_1 / pi_2).log2()
            }
        })
        .sum()
}
8
/// Shannon entropy of the distribution `p`, in bits.
///
/// Computes `-sum_i p[i] * log2(p[i])`. Zero-probability entries contribute
/// nothing (`0 * log 0 = 0` by convention); evaluated literally they would be
/// `-0 * -inf = NaN` and make the whole result NaN.
pub fn entropy(p: &[f32]) -> f32 {
    p.iter()
        .map(|&pi| if pi == 0.0 { 0.0 } else { -pi * pi.log2() })
        .sum()
}
12
/// Self-information `-log2(p[i])` of every entry of `p`, in bits.
///
/// Returns one value per input entry, in the same order; a zero probability
/// maps to `+inf`.
pub fn information(p: &[f32]) -> Vec<f32> {
    let mut bits = Vec::with_capacity(p.len());
    for &pi in p {
        bits.push(-pi.log2());
    }
    bits
}
16
/// Running (prefix) sum of `list`, with a leading zero.
///
/// The result has `list.len() + 1` entries: the first is `T::from(0)` and
/// entry `k` is the sum of the first `k` elements of `list`, accumulated
/// left to right.
pub fn cumsum<T>(list: &[T]) -> Vec<T>
where
    T: Clone,
    T: From<u8>,
    for<'a> &'a T: Add<&'a T, Output = T>,
{
    let mut sums = Vec::with_capacity(list.len() + 1);
    let mut running = T::from(0u8);
    for item in list {
        // Compute the next partial sum before moving the current one out.
        let next = &running + item;
        sums.push(running);
        running = next;
    }
    sums.push(running);
    sums
}
28
/// Forward differences of `list`: element `i` of the result is
/// `list[i + 1] - list[i]`.
///
/// The result has one element fewer than the input. Inputs with fewer than
/// two elements (including an empty slice) produce an empty vector.
pub fn differeniate<T>(list: &[T]) -> Vec<T>
where
    T: Copy,
    T: Sub<Output = T>,
{
    // `windows(2)` yields nothing for len < 2, avoiding the `len() - 1`
    // underflow that a manual capacity computation would hit on empty input.
    list.windows(2).map(|pair| pair[1] - pair[0]).collect()
}
40
/// Builds a probability distribution in which each weight `w` receives mass
/// proportional to `2^(-w / res_factor)`, normalized so the entries sum to 1
/// (up to rounding).
///
/// The largest exponent is subtracted before calling `exp2`. This leaves the
/// normalized result mathematically unchanged, but the biggest term becomes
/// exactly 1. As a result, the normalizing sum can neither underflow to zero
/// (large weights with positive `res_factor`) nor overflow to infinity
/// (negative `res_factor`). Either case would otherwise turn every output
/// into NaN.
///
/// An empty `weights` slice yields an empty vector.
pub fn distribution_from_weights(weights: &[usize], res_factor: f32) -> Vec<f32> {
    let mut exps: Vec<f32> = weights
        .iter()
        .map(|&weight| weight as f32 / -res_factor)
        .collect();
    let max_exponent = exps.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    for e in &mut exps {
        *e = (*e - max_exponent).exp2();
    }
    let exps_sum: f32 = exps.iter().sum();
    for e in &mut exps {
        *e /= exps_sum;
    }
    exps
}