use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;
use statrs::statistics::Statistics;
use crate::errors::ElinorError;
/// Outcome of a randomized Tukey HSD test over a set of systems.
///
/// Stores the settings that were used for the run together with the
/// pairwise p-value matrix computed by [`RandomizedTukeyHsdTester::test`].
#[derive(Debug, Clone)]
pub struct RandomizedTukeyHsdTest {
    /// Number of systems compared.
    n_systems: usize,
    /// Number of topics (samples) in the input.
    n_topics: usize,
    /// Number of randomization iterations performed.
    n_iters: usize,
    /// Seed that was used to initialize the RNG.
    random_state: u64,
    /// Pairwise p-values, `n_systems x n_systems`.
    p_values: Vec<Vec<f64>>,
}

impl RandomizedTukeyHsdTest {
    /// Runs the test using the default tester settings for `n_systems` systems.
    ///
    /// Every item of `samples` is one topic and must hold exactly one score per system.
    ///
    /// # Errors
    ///
    /// Forwards any error returned by [`RandomizedTukeyHsdTester::test`].
    pub fn from_tupled_samples<I, S>(samples: I, n_systems: usize) -> Result<Self, ElinorError>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<[f64]>,
    {
        let tester = RandomizedTukeyHsdTester::new(n_systems);
        tester.test(samples)
    }

    /// Number of systems that were compared.
    pub const fn n_systems(&self) -> usize {
        self.n_systems
    }

    /// Number of topics (input samples) the test was run on.
    pub const fn n_topics(&self) -> usize {
        self.n_topics
    }

    /// Number of randomization iterations that were performed.
    pub const fn n_iters(&self) -> usize {
        self.n_iters
    }

    /// Seed that initialized the random number generator for this run.
    pub const fn random_state(&self) -> u64 {
        self.random_state
    }

    /// Returns an owned copy of the pairwise p-value matrix.
    ///
    /// Entry `[i][j]` is the p-value for systems `i` and `j`.
    pub fn p_values(&self) -> Vec<Vec<f64>> {
        self.p_values.to_vec()
    }
}
/// Configurable runner for the randomized Tukey HSD test.
///
/// Defaults to 10000 iterations. When no seed is set via
/// [`with_random_state`](Self::with_random_state), one is drawn from
/// `rand::thread_rng()` and recorded in the resulting [`RandomizedTukeyHsdTest`].
#[derive(Debug, Clone)]
pub struct RandomizedTukeyHsdTester {
    /// Expected number of scores (systems) in every sample.
    n_systems: usize,
    /// Number of randomization iterations (always at least 1).
    n_iters: usize,
    /// Optional fixed RNG seed for reproducible results.
    random_state: Option<u64>,
}

impl RandomizedTukeyHsdTester {
    /// Creates a tester for `n_systems` systems with 10000 iterations and no fixed seed.
    pub const fn new(n_systems: usize) -> Self {
        Self {
            n_systems,
            n_iters: 10000,
            random_state: None,
        }
    }

    /// Sets the number of randomization iterations; values below 1 are clamped to 1.
    pub fn with_n_iters(mut self, n_iters: usize) -> Self {
        self.n_iters = n_iters.max(1);
        self
    }

    /// Fixes the RNG seed so that repeated runs produce identical p-values.
    pub const fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Runs the randomized Tukey HSD test.
    ///
    /// Each item of `samples` is one topic holding one score per system. In every
    /// iteration the scores inside each topic are shuffled across systems. The
    /// range (max - min) of the resulting per-system means is then compared with
    /// the observed absolute mean difference of every pair of systems. The p-value
    /// of a pair is the fraction of iterations whose range is at least that
    /// observed difference. The returned matrix is symmetric with a diagonal of 1.0.
    ///
    /// # Errors
    ///
    /// Returns [`ElinorError::InvalidArgument`] if any sample's length differs
    /// from the number of systems, if any sample contains a non-finite value
    /// (NaN or infinity), or if `samples` is empty.
    pub fn test<I, S>(&self, samples: I) -> Result<RandomizedTukeyHsdTest, ElinorError>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<[f64]>,
    {
        let samples: Vec<Vec<f64>> = samples
            .into_iter()
            .map(|topic| {
                let topic = topic.as_ref();
                if topic.len() != self.n_systems {
                    return Err(ElinorError::InvalidArgument(
                        "The length of each sample must be equal to the number of systems."
                            .to_string(),
                    ));
                }
                // A NaN (or inf - inf) turns every permuted range into NaN, so each
                // `>=` comparison below would be false and every pair would silently
                // get p = 0.0 ("significant"). Reject such input up front.
                if topic.iter().any(|x| !x.is_finite()) {
                    return Err(ElinorError::InvalidArgument(
                        "Each sample must contain only finite values.".to_string(),
                    ));
                }
                Ok(topic.to_vec())
            })
            .collect::<Result<_, _>>()?;
        if samples.is_empty() {
            return Err(ElinorError::InvalidArgument(
                "The input must have at least one sample.".to_string(),
            ));
        }

        let n_systems = self.n_systems;
        let n_samples = samples.len() as f64;
        let random_state = self
            .random_state
            .unwrap_or_else(|| rand::thread_rng().gen());
        let mut rng = StdRng::seed_from_u64(random_state);

        // Observed absolute difference of mean scores for every pair i < j.
        let means: Vec<f64> = (0..n_systems)
            .map(|i| samples.iter().map(|sample| sample[i]).sum::<f64>() / n_samples)
            .collect();
        let mut observed = vec![vec![0_f64; n_systems]; n_systems];
        for i in 0..n_systems {
            for j in (i + 1)..n_systems {
                observed[i][j] = (means[i] - means[j]).abs();
            }
        }

        // Scratch buffers are reused across iterations rather than cloning every
        // sample each time. Each topic is still copied and shuffled in the same
        // order, so the RNG is consumed identically and a fixed seed reproduces
        // the same p-values.
        let mut counts = vec![vec![0_usize; n_systems]; n_systems];
        let mut shuffled = vec![0_f64; n_systems];
        let mut shuffled_means = vec![0_f64; n_systems];
        for _ in 0..self.n_iters {
            shuffled_means.fill(0.0);
            for sample in &samples {
                shuffled.copy_from_slice(sample);
                shuffled.shuffle(&mut rng);
                for (acc, &score) in shuffled_means.iter_mut().zip(&shuffled) {
                    *acc += score;
                }
            }
            for acc in &mut shuffled_means {
                *acc /= n_samples;
            }
            // Tukey HSD compares every pair against the range of all permuted means.
            let shuffled_range =
                shuffled_means.as_slice().max() - shuffled_means.as_slice().min();
            for i in 0..n_systems {
                for j in (i + 1)..n_systems {
                    if shuffled_range >= observed[i][j] {
                        counts[i][j] += 1;
                    }
                }
            }
        }

        let n_iters = self.n_iters as f64;
        let mut p_values = vec![vec![1_f64; n_systems]; n_systems];
        for i in 0..n_systems {
            for j in (i + 1)..n_systems {
                let p_value = counts[i][j] as f64 / n_iters;
                p_values[i][j] = p_value;
                p_values[j][i] = p_value;
            }
        }

        Ok(RandomizedTukeyHsdTest {
            n_systems,
            n_topics: samples.len(),
            n_iters: self.n_iters,
            random_state,
            p_values,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_randomized_tukey_hsd_test_from_tupled_samples_empty() {
        let samples: Vec<[f64; 2]> = vec![];
        let result = RandomizedTukeyHsdTest::from_tupled_samples(samples, 2);
        assert_eq!(
            result.unwrap_err(),
            ElinorError::InvalidArgument("The input must have at least one sample.".to_string())
        );
    }

    #[test]
    fn test_randomized_tukey_hsd_test_from_tupled_samples_single() {
        let samples = vec![[1.0, 2.0]];
        let result = RandomizedTukeyHsdTest::from_tupled_samples(samples, 2).unwrap();
        assert_eq!(result.n_systems(), 2);
        assert_eq!(result.n_topics(), 1);
        assert_eq!(result.n_iters(), 10000);
        // With one topic every permutation has the same range as the observed difference.
        assert_eq!(result.p_values(), vec![vec![1.0, 1.0], vec![1.0, 1.0]]);
    }

    #[test]
    fn test_randomized_tukey_hsd_test_from_tupled_samples_invalid_length() {
        let samples = vec![vec![1.0, 2.0], vec![3.0]];
        let result = RandomizedTukeyHsdTest::from_tupled_samples(samples, 2);
        assert_eq!(
            result.unwrap_err(),
            ElinorError::InvalidArgument(
                "The length of each sample must be equal to the number of systems.".to_string()
            )
        );
    }

    #[test]
    fn test_randomized_tukey_hsd_tester_identical_systems() {
        let samples = vec![[0.5, 0.5], [0.25, 0.25], [0.75, 0.75]];
        let result = RandomizedTukeyHsdTester::new(2)
            .with_n_iters(100)
            .with_random_state(1)
            .test(samples)
            .unwrap();
        assert_eq!(result.p_values(), vec![vec![1.0, 1.0], vec![1.0, 1.0]]);
    }

    #[test]
    fn test_randomized_tukey_hsd_tester_reproducible_with_seed() {
        let samples = vec![
            [0.1, 0.4, 0.3],
            [0.5, 0.2, 0.9],
            [0.7, 0.6, 0.8],
            [0.2, 0.3, 0.1],
        ];
        let tester = RandomizedTukeyHsdTester::new(3)
            .with_n_iters(500)
            .with_random_state(7);
        let a = tester.test(&samples).unwrap();
        let b = tester.test(&samples).unwrap();
        assert_eq!(a.p_values(), b.p_values());
        assert_eq!(a.random_state(), 7);
        assert_eq!(a.n_iters(), 500);
        assert_eq!(a.n_topics(), 4);
        assert_eq!(a.n_systems(), 3);

        let p = a.p_values();
        for i in 0..3 {
            assert_eq!(p[i][i], 1.0);
            for j in 0..3 {
                assert_eq!(p[i][j], p[j][i]);
                assert!((0.0..=1.0).contains(&p[i][j]));
            }
        }
    }

    #[test]
    fn test_randomized_tukey_hsd_tester_detects_clear_difference() {
        let samples = vec![[1.0, 0.0]; 10];
        let result = RandomizedTukeyHsdTester::new(2)
            .with_random_state(42)
            .test(samples)
            .unwrap();
        assert!(result.p_values()[0][1] < 0.05);
    }

    #[test]
    fn test_randomized_tukey_hsd_tester_n_iters_clamped() {
        let result = RandomizedTukeyHsdTester::new(2)
            .with_n_iters(0)
            .test(vec![[1.0, 2.0]])
            .unwrap();
        assert_eq!(result.n_iters(), 1);
    }
}