cubek_random/
tests_utils.rs1use cubecl::prelude::*;
2
/// Statistics gathered for a single histogram bin by [`calculate_bin_stats`].
#[derive(Debug, Default, Clone, Copy)]
pub struct BinStats {
    /// Number of samples that landed in this bin.
    pub count: usize,
    /// Number of runs observed for this bin, i.e. streaks of consecutive
    /// binned samples that all landed in this bin.
    pub n_runs: usize,
}
8
9pub fn calculate_bin_stats<E: Numeric>(
11 numbers: &[E],
12 number_of_bins: usize,
13 low: f32,
14 high: f32,
15) -> Vec<BinStats> {
16 let range = (high - low) / number_of_bins as f32;
17 let mut output: Vec<BinStats> = (0..number_of_bins).map(|_| Default::default()).collect();
18 let mut initialized = false;
19 let mut current_runs = number_of_bins; for number in numbers {
21 let num = number.to_f32().unwrap();
22 if num < low || num > high {
23 continue;
24 }
25 let index = Ord::min(f32::floor((num - low) / range) as usize, number_of_bins - 1);
27 output[index].count += 1;
28 if initialized && index != current_runs {
29 output[current_runs].n_runs += 1;
30 }
31 initialized = true;
32 current_runs = index;
33 }
34 output[current_runs].n_runs += 1;
35 output
36}
37
38pub fn assert_mean_approx_equal<E: Numeric>(data: &[E], expected_mean: f32) {
42 let mut sum = 0.;
43 for elem in data {
44 let elem = elem.to_f32().unwrap();
45 sum += elem;
46 }
47 let mean = sum / (data.len() as f32);
48
49 let mut sum = 0.0;
50 for elem in data {
51 let elem = elem.to_f32().unwrap();
52 let d = elem - mean;
53 sum += d * d;
54 }
55 let var = sum / ((data.len() - 1) as f32);
57 let std = var.sqrt();
58 let z = (mean - expected_mean).abs() / std;
60
61 assert!(
62 z < 3.,
63 "Uniform RNG validation failed: mean={mean}, expected mean={expected_mean}, std={std}",
64 );
65}
66
67pub fn assert_normal_respects_68_95_99_rule<E: Numeric>(data: &[E], mu: f32, s: f32) {
70 let stats = calculate_bin_stats(data, 6, mu - 3. * s, mu + 3. * s);
72 let assert_approx_eq = |count, percent| {
73 let expected = percent * data.len() as f32 / 100.;
74 assert!(f32::abs(count as f32 - expected) < 2000.);
75 };
76 assert_approx_eq(stats[0].count, 2.1);
77 assert_approx_eq(stats[1].count, 13.6);
78 assert_approx_eq(stats[2].count, 34.1);
79 assert_approx_eq(stats[3].count, 34.1);
80 assert_approx_eq(stats[4].count, 13.6);
81 assert_approx_eq(stats[5].count, 2.1);
82}
83
84pub fn assert_number_of_1_proportional_to_prob<E: Numeric>(data: &[E], prob: f32) {
87 let bin_stats = calculate_bin_stats(data, 2, 0., 1.1);
89 assert!(f32::abs((bin_stats[1].count as f32 / data.len() as f32) - prob) < 0.05);
90}
91
92pub fn assert_wald_wolfowitz_runs_test<E: Numeric>(data: &[E], bins_low: f32, bins_high: f32) {
96 let stats = calculate_bin_stats(data, 2, bins_low, bins_high);
98 let n_0 = stats[0].count as f32;
99 let n_1 = stats[1].count as f32;
100 let n_runs = (stats[0].n_runs + stats[1].n_runs) as f32;
101
102 let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0;
103 let variance = ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1))
104 / ((n_0 + n_1).powf(2.) * (n_0 + n_1 - 1.));
105 let z = (n_runs - expectation) / f32::sqrt(variance);
106
107 assert!(z.abs() < 2.6, "z: {z}, var: {variance}");
112}
113
114pub fn assert_at_least_one_value_per_bin<E: Numeric>(
116 data: &[E],
117 number_of_bins: usize,
118 bins_low: f32,
119 bins_high: f32,
120) {
121 let stats = calculate_bin_stats(data, number_of_bins, bins_low, bins_high);
122 assert!(stats[0].count >= 1);
123 assert!(stats[1].count >= 1);
124 assert!(stats[2].count >= 1);
125}