cubecl_random/
tests_utils.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
/// Statistics gathered for a single bin by `calculate_bin_stats`.
#[derive(Debug, Default, Clone, Copy)]
pub struct BinStats {
    /// Number of values that landed in this bin.
    pub count: usize,
    /// Number of runs (sequences of consecutive values falling in the same bin)
    /// that belong to this bin.
    pub n_runs: usize,
}
9
10/// Sorts the data into bins for ranges of equal sizes
11pub fn calculate_bin_stats<E: Numeric>(
12    numbers: &[E],
13    number_of_bins: usize,
14    low: f32,
15    high: f32,
16) -> Vec<BinStats> {
17    let range = (high - low) / number_of_bins as f32;
18    let mut output: Vec<BinStats> = (0..number_of_bins).map(|_| Default::default()).collect();
19    let mut initialized = false;
20    let mut current_runs = number_of_bins; // impossible value for starting point
21    for number in numbers {
22        let num = number.to_f32().unwrap();
23        if num < low || num > high {
24            continue;
25        }
26        // When num = high, index should be clamped to `number_of_bins` - 1
27        let index = (f32::floor((num - low) / range) as usize).min(number_of_bins - 1);
28        output[index].count += 1;
29        if initialized && index != current_runs {
30            output[current_runs].n_runs += 1;
31        }
32        initialized = true;
33        current_runs = index;
34    }
35    output[current_runs].n_runs += 1;
36    output
37}
38
39/// Asserts that the mean of a dataset is approximately equal to an expected value,
40/// within 2.5 standard deviations.
41/// There is a very small chance this raises a false negative.
42pub fn assert_mean_approx_equal<E: Numeric>(data: &[E], expected_mean: f32) {
43    let mut sum = 0.;
44    for elem in data {
45        let elem = elem.to_f32().unwrap();
46        sum += elem;
47    }
48    let mean = sum / (data.len() as f32);
49
50    let mut sum = 0.0;
51    for elem in data {
52        let elem = elem.to_f32().unwrap();
53        let d = elem - mean;
54        sum += d * d;
55    }
56    // sample variance
57    let var = sum / ((data.len() - 1) as f32);
58    let std = var.sqrt();
59    // z-score
60    let z = (mean - expected_mean).abs() / std;
61
62    assert!(
63        z < 3.,
64        "Uniform RNG validation failed: mean={mean}, expected mean={expected_mean}, std={std}",
65    );
66}
67
68/// Asserts that the distribution follows the 68-95-99 rule of normal distributions,
69/// following the given mean and standard deviation.
70pub fn assert_normal_respects_68_95_99_rule<E: Numeric>(data: &[E], mu: f32, s: f32) {
71    // https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule
72    let stats = calculate_bin_stats(data, 6, mu - 3. * s, mu + 3. * s);
73    let assert_approx_eq = |count, percent| {
74        let expected = percent * data.len() as f32 / 100.;
75        assert!(f32::abs(count as f32 - expected) < 2000.);
76    };
77    assert_approx_eq(stats[0].count, 2.1);
78    assert_approx_eq(stats[1].count, 13.6);
79    assert_approx_eq(stats[2].count, 34.1);
80    assert_approx_eq(stats[3].count, 34.1);
81    assert_approx_eq(stats[4].count, 13.6);
82    assert_approx_eq(stats[5].count, 2.1);
83}
84
85/// For a bernoulli distribution: asserts that the proportion of 1s to 0s is approximately equal
86/// to the expected probability.
87pub fn assert_number_of_1_proportional_to_prob<E: Numeric>(data: &[E], prob: f32) {
88    // High bound slightly over 1 so 1.0 is included in second bin
89    let bin_stats = calculate_bin_stats(data, 2, 0., 1.1);
90    assert!(f32::abs((bin_stats[1].count as f32 / data.len() as f32) - prob) < 0.05);
91}
92
93/// Asserts that the elements of the data, sorted into two bins, are elements of the sequence
94/// are mutually independent.
95/// There is a very small chance it gives a false negative.
96pub fn assert_wald_wolfowitz_runs_test<E: Numeric>(data: &[E], bins_low: f32, bins_high: f32) {
97    //https://en.wikipedia.org/wiki/Wald%E2%80%93Wolfowitz_runs_test
98    let stats = calculate_bin_stats(data, 2, bins_low, bins_high);
99    let n_0 = stats[0].count as f32;
100    let n_1 = stats[1].count as f32;
101    let n_runs = (stats[0].n_runs + stats[1].n_runs) as f32;
102
103    let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0;
104    let variance = ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1))
105        / ((n_0 + n_1).powf(2.) * (n_0 + n_1 - 1.));
106    let z = (n_runs - expectation) / f32::sqrt(variance);
107
108    // below 2 means we can have good confidence in the randomness
109    // we put 2.6 to make sure it passes even when very unlucky.
110    // With higher vectorization, adjacent values are more
111    // correlated, which makes this test is more flaky.
112    assert!(z.abs() < 2.6, "z: {z}, var: {variance}");
113}
114
115/// Asserts that there is at least one value per bin
116pub fn assert_at_least_one_value_per_bin<E: Numeric>(
117    data: &[E],
118    number_of_bins: usize,
119    bins_low: f32,
120    bins_high: f32,
121) {
122    let stats = calculate_bin_stats(data, number_of_bins, bins_low, bins_high);
123    assert!(stats[0].count >= 1);
124    assert!(stats[1].count >= 1);
125    assert!(stats[2].count >= 1);
126}