use crate::traits::samplers::{Shuffle, test::check_chi_square};
use crate::{error::Fallible, measures::ZeroConcentratedDivergence};
use dashu::rbig;
use std::array::from_fn;
use super::*;
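
/// Check the selection distribution of Gumbel noisy-max against its analytic form.
/// By the Gumbel-max trick, adding Gumbel(scale) noise to each score and taking the
/// argmax selects index i with probability proportional to exp(score_i / scale),
/// i.e. the softmax of the scores. The two lowest-scoring indices, whose expected
/// counts fall below the usual chi-square rule of thumb of five observations, are
/// excluded from the comparison.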
#[test]
fn test_rnm_gumbel_distribution_varied() -> Fallible<()> {
let scores: [_; 10] = from_fn(|i| i);
let trials = 10_000;
let mut observed = [0; 10];
(0..trials).try_for_each(|_| {
observed[noisy_top_k(&scores, 1.0, 1, false, true)?[0]] += 1;
Fallible::Ok(())
})?;
    let denom: f64 = (0..10).map(|i| (i as f64).exp()).sum();
    let expected: Vec<f64> = (0..10)
        .map(|i| (i as f64).exp() / denom * (trials as f64))
        .collect();
check_chi_square(&observed[2..], &expected[2..])
}
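
/// Gumbel-based noisy top-1 under zero-concentrated DP: because the score gap dwarfs
/// the noise scale, the index of the largest score (2) is selected. Also checks the
/// privacy map at input distance 1 and scale 1.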
#[test]
fn test_noisy_top_k_gumbel() -> Fallible<()> {
let input_domain = VectorDomain::new(AtomDomain::new_non_nan());
let input_metric = LInfDistance::new(true);
let de = make_noisy_top_k(
input_domain,
input_metric,
ZeroConcentratedDivergence,
1,
1.,
false,
)?;
let release = de.invoke(&vec![1., 2., 30., 2., 1.])?;
assert_eq!(release, vec![2]);
assert_eq!(de.map(&1.0)?, 0.125);
Ok(())
}
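
/// The same top-1 selection under the exponential mechanism (pure DP via
/// `MaxDivergence`): the largest score wins, and the privacy map is checked at input
/// distance 1 and scale 1.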
#[test]
fn test_noisy_top_k_exponential() -> Fallible<()> {
let input_domain = VectorDomain::new(AtomDomain::new_non_nan());
let input_metric = LInfDistance::default();
let de = make_noisy_top_k(input_domain, input_metric, MaxDivergence, 1, 1., false)?;
let release = de.invoke(&vec![1., 2., 30., 2., 1.])?;
assert_eq!(release, vec![2]);
assert_eq!(de.map(&1.0)?, 2.0);
Ok(())
}
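
/// Helper: build a noisy top-k measurement over integer scores with the given
/// measure, scale, and negation flag, invoke it once, and assert the exact outcome.
/// Callers pass settings where the result is essentially deterministic: either a
/// zero scale or a score gap that dwarfs the noise.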
fn check_top_k_outcome<M: TopKMeasure>(
measure: M,
scale: f64,
negate: bool,
input: Vec<i32>,
expected: Vec<usize>,
) -> Fallible<()> {
let m_rnm = make_noisy_top_k(
VectorDomain::new(AtomDomain::new_non_nan()),
LInfDistance::default(),
measure,
expected.len(),
scale,
negate,
)?;
assert_eq!(m_rnm.invoke(&input)?, expected);
Ok(())
}
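
/// With `negate = false` the largest score is selected; with `negate = true` the
/// smallest is, both at zero scale and at scale 1 with a very large score gap.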
#[test]
fn test_max_vs_min_gumbel_top_k() -> Fallible<()> {
check_top_k_outcome(
ZeroConcentratedDivergence,
0.,
false,
vec![1, 2, 3],
vec![2],
)?;
check_top_k_outcome(ZeroConcentratedDivergence, 0., true, vec![1, 2, 3], vec![0])?;
check_top_k_outcome(
ZeroConcentratedDivergence,
1.,
false,
vec![1, 1, 100_000],
vec![2],
)?;
check_top_k_outcome(
ZeroConcentratedDivergence,
1.,
true,
vec![1, 100_000, 100_000],
vec![0],
)?;
Ok(())
}
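
/// Same max-vs-min checks as above, but under the pure-DP exponential mechanism
/// (`MaxDivergence`).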
#[test]
fn test_max_vs_min_exponential_top_k() -> Fallible<()> {
check_top_k_outcome(MaxDivergence, 0., false, vec![1, 2, 3], vec![2])?;
check_top_k_outcome(MaxDivergence, 0., true, vec![1, 2, 3], vec![0])?;
check_top_k_outcome(MaxDivergence, 1., false, vec![1, 1, 100_000], vec![2])?;
check_top_k_outcome(MaxDivergence, 1., true, vec![1, 100_000, 100_000], vec![0])?;
Ok(())
}
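
/// Returns the indices that sort `x` in ascending order.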
fn argsort<T: Ord>(x: &[T]) -> Vec<usize> {
let mut indices = (0..x.len()).collect::<Vec<_>>();
indices.sort_by_key(|&i| &x[i]);
indices
}
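
/// Peeling (iterated) permute-and-flip should return indices in descending order of
/// score: for each requested length, a shuffled prefix of widely separated scores
/// (gaps of 50) is peeled and the selection is compared against the reversed argsort.
/// With scale 0 the match is exact; with scale 1, any other outcome is astronomically
/// unlikely given the score gaps.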
#[test]
fn test_peel_permute_and_flip() {
for len in [0, 1, 2, 3, 4, 5] {
        // at least one trial per length, including the empty-input case
        for _trial in 0..len.max(1) {
for scale in [rbig![0], rbig![1]] {
let mut x = vec![rbig![0], rbig![50], rbig![100], rbig![150]];
x.truncate(len.min(x.len()));
x.shuffle().unwrap();
let mut expected = argsort(&x);
expected.reverse();
let observed = peel_permute_and_flip(x, scale, len, false).unwrap();
assert_eq!(expected, observed);
}
}
}
}
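
/// Basic permute-and-flip behavior: a dominant score (a gap of 100 at scale at most
/// 1) is selected in every trial, a single candidate is trivially selected, and an
/// empty candidate list is an error.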
#[test]
fn test_permute_and_flip() {
for scale in [rbig![0], rbig![1]] {
for _ in 0..100 {
let x = [rbig![100], rbig![0], rbig![0]];
let selection = permute_and_flip(&x, &scale, false).unwrap();
assert_eq!(selection, 0);
}
assert_eq!(permute_and_flip(&[rbig![0]], &scale, false).unwrap(), 0);
assert!(permute_and_flip(&[], &scale, false).is_err());
}
}
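
/// With tied scores and zero scale, permute-and-flip resolves ties in favor of the
/// last index, so the selection is deterministic.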
#[test]
fn test_permute_and_flip_distribution_zero() -> Fallible<()> {
    let scores = vec![rbig!(0); 10];
(0..1000).try_for_each(|_| {
        assert_eq!(
            permute_and_flip(&scores, &rbig!(0), false)?,
            9,
            "P&F with zero scale should deterministically select the last index"
        );
Fallible::Ok(())
})
}
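
/// With tied scores and a nonzero scale, every index should be selected with equal
/// probability; checked with a chi-square test against the uniform distribution.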
#[test]
fn test_permute_and_flip_distribution_uniform() -> Fallible<()> {
    let scores = vec![rbig!(0); 10];
let mut observed = [0; 10];
(0..1000).try_for_each(|_| {
observed[permute_and_flip(&scores, &rbig!(1), false)?] += 1;
Fallible::Ok(())
})?;
check_chi_square(&observed, &[100.0; 10])
}
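
/// With distinct scores, the empirical selection frequencies are checked against the
/// analytic permute-and-flip PMF computed by `permute_and_flip_pmf` below. The three
/// lowest-scoring indices are excluded because their expected counts are too small
/// for the chi-square approximation to be reliable.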
#[test]
fn test_permute_and_flip_distribution_varied() -> Fallible<()> {
let scores: [_; 10] = from_fn(RBig::from);
    let trials = 10_000;
let mut observed = [0; 10];
(0..trials).try_for_each(|_| {
observed[permute_and_flip(&scores, &rbig!(1), false)?] += 1;
Fallible::Ok(())
})?;
let expected: Vec<f64> = permute_and_flip_pmf(
&scores
.into_iter()
.map(|r| r.to_f64().value())
.collect::<Vec<_>>(),
1.0,
)
.into_iter()
.map(|p| p * trials as f64)
.collect();
check_chi_square(&observed[3..], &expected[3..])
}
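
/// Analytic selection probabilities for permute-and-flip over `scores` at `scale`.
/// With flip probabilities p_i = exp((s_i - max_j s_j) / scale), the probability of
/// selecting index r is
///     P(r) = p_r * sum_{k=0}^{n-1} (-1)^k / (k + 1) * e_k(p without p_r),
/// where e_k denotes the k-th elementary symmetric polynomial; the inner helpers
/// below evaluate these polynomials recursively.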
fn permute_and_flip_pmf(scores: &[f64], scale: f64) -> Vec<f64> {
let n = scores.len();
let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let p: Vec<f64> = scores
        .iter()
        .map(|&s| ((s - max_score) / scale).exp())
        .collect();
    // s(k, r) is the k-th elementary symmetric polynomial over the first r flip probabilities.
    fn s(k: usize, r: usize, p: &[f64]) -> f64 {
        if k == 0 {
            return 1.0;
        }
        if r == 0 {
            return 0.0;
        }
        s(k, r - 1, p) + p[r - 1] * s(k - 1, r - 1, p)
    }
    // t(k, r, n) is the k-th elementary symmetric polynomial over all probabilities except p[r - 1],
    // using e_k(p) = e_k(p \ {p_r}) + p_r * e_{k-1}(p \ {p_r}).
    fn t(k: usize, r: usize, n: usize, p: &[f64]) -> f64 {
        if k == 0 {
            return 1.0;
        }
        s(k, n, p) - p[r - 1] * t(k - 1, r, n, p)
    }
    // alternating sign for the inclusion-exclusion sum
    fn sign(i: usize) -> f64 {
        if i % 2 == 0 { 1.0 } else { -1.0 }
    }
let mass = |r: usize| {
let sum = (0..n)
.map(|k| sign(k) / ((k + 1) as f64) * t(k, r, n, &p))
.sum::<f64>();
p[r - 1] * sum
};
(1..=n).map(mass).collect()
}
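
// A small added sanity check, sketched here under the assumption that
// `permute_and_flip_pmf` above is the intended analytic PMF: since permute-and-flip
// always selects exactly one index, the probabilities should sum to one for any
// score vector and scale.
#[test]
fn test_permute_and_flip_pmf_sums_to_one() {
    for scale in [0.5, 1.0, 2.0] {
        let total: f64 = permute_and_flip_pmf(&[0.0, 1.0, 2.0, 3.0], scale)
            .into_iter()
            .sum();
        assert!((total - 1.0).abs() < 1e-9, "PMF should sum to 1, got {}", total);
    }
}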