1use numra_core::Scalar;
8
9use crate::descriptive;
10use crate::distributions::{
11 chi_squared::ChiSquared, f_dist::FDist, student_t::StudentT, ContinuousDistribution,
12};
13use crate::error::StatsError;
14
15#[derive(Clone, Debug)]
17pub struct TestResult<S: Scalar> {
18 pub statistic: S,
20 pub p_value: S,
22 pub reject: bool,
24}
25
26pub fn ttest_1samp<S: Scalar>(data: &[S], mu0: S, alpha: S) -> Result<TestResult<S>, StatsError> {
28 if data.len() < 2 {
29 return Err(StatsError::EmptyData);
30 }
31 let n = data.len();
32 let m = descriptive::mean(data)?;
33 let s = descriptive::std_dev(data)?;
34 let sqrt_n = S::from_usize(n).sqrt();
35 let t_stat = (m - mu0) / (s / sqrt_n);
36 let df = S::from_usize(n - 1);
37 let t_dist = StudentT::new(df);
38 let p_value = S::TWO * (S::ONE - t_dist.cdf(t_stat.abs()));
40 Ok(TestResult {
41 statistic: t_stat,
42 p_value,
43 reject: p_value < alpha,
44 })
45}
46
47pub fn ttest_ind<S: Scalar>(
49 data1: &[S],
50 data2: &[S],
51 alpha: S,
52) -> Result<TestResult<S>, StatsError> {
53 if data1.len() < 2 || data2.len() < 2 {
54 return Err(StatsError::EmptyData);
55 }
56 let n1 = data1.len();
57 let n2 = data2.len();
58 let m1 = descriptive::mean(data1)?;
59 let m2 = descriptive::mean(data2)?;
60 let v1 = descriptive::variance(data1)?;
61 let v2 = descriptive::variance(data2)?;
62 let n1s = S::from_usize(n1);
63 let n2s = S::from_usize(n2);
64
65 let se = (v1 / n1s + v2 / n2s).sqrt();
66 let t_stat = (m1 - m2) / se;
67
68 let vn1 = v1 / n1s;
70 let vn2 = v2 / n2s;
71 let num = (vn1 + vn2) * (vn1 + vn2);
72 let denom = vn1 * vn1 / (n1s - S::ONE) + vn2 * vn2 / (n2s - S::ONE);
73 let df = num / denom;
74
75 let t_dist = StudentT::new(df);
76 let p_value = S::TWO * (S::ONE - t_dist.cdf(t_stat.abs()));
77 Ok(TestResult {
78 statistic: t_stat,
79 p_value,
80 reject: p_value < alpha,
81 })
82}
83
84pub fn ttest_rel<S: Scalar>(
86 data1: &[S],
87 data2: &[S],
88 alpha: S,
89) -> Result<TestResult<S>, StatsError> {
90 if data1.len() != data2.len() {
91 return Err(StatsError::LengthMismatch {
92 expected: data1.len(),
93 got: data2.len(),
94 });
95 }
96 let diffs: Vec<S> = data1
97 .iter()
98 .zip(data2.iter())
99 .map(|(&a, &b)| a - b)
100 .collect();
101 ttest_1samp(&diffs, S::ZERO, alpha)
102}
103
104pub fn chi2_test<S: Scalar>(
106 observed: &[S],
107 expected: &[S],
108 alpha: S,
109) -> Result<TestResult<S>, StatsError> {
110 if observed.len() != expected.len() {
111 return Err(StatsError::LengthMismatch {
112 expected: observed.len(),
113 got: expected.len(),
114 });
115 }
116 if observed.len() < 2 {
117 return Err(StatsError::EmptyData);
118 }
119 let chi2_stat: S = observed
120 .iter()
121 .zip(expected.iter())
122 .fold(S::ZERO, |a, (&o, &e)| {
123 let d = o - e;
124 a + d * d / e
125 });
126 let df = S::from_usize(observed.len() - 1);
127 let chi2_dist = ChiSquared::new(df);
128 let p_value = S::ONE - chi2_dist.cdf(chi2_stat);
129 Ok(TestResult {
130 statistic: chi2_stat,
131 p_value,
132 reject: p_value < alpha,
133 })
134}
135
136pub fn ks_test<S: Scalar>(
138 data: &[S],
139 dist: &dyn ContinuousDistribution<S>,
140 alpha: S,
141) -> Result<TestResult<S>, StatsError> {
142 if data.is_empty() {
143 return Err(StatsError::EmptyData);
144 }
145 let n = data.len();
146 let mut sorted: Vec<S> = data.to_vec();
147 sorted.sort_by(|a, b| a.to_f64().partial_cmp(&b.to_f64()).unwrap());
148
149 let mut d_max = S::ZERO;
150 let ns = S::from_usize(n);
151 for (i, &x) in sorted.iter().enumerate() {
152 let f_x = dist.cdf(x);
153 let ecdf_above = S::from_usize(i + 1) / ns;
154 let ecdf_below = S::from_usize(i) / ns;
155 let d1 = (ecdf_above - f_x).abs();
156 let d2 = (ecdf_below - f_x).abs();
157 let d = if d1 > d2 { d1 } else { d2 };
158 if d > d_max {
159 d_max = d;
160 }
161 }
162
163 let sqrt_n = ns.sqrt();
165 let z = (sqrt_n + S::from_f64(0.12) + S::from_f64(0.11) / sqrt_n) * d_max;
166 let p_value = ks_pvalue(z);
167
168 Ok(TestResult {
169 statistic: d_max,
170 p_value,
171 reject: p_value < alpha,
172 })
173}
174
175fn ks_pvalue<S: Scalar>(z: S) -> S {
177 let z_f64 = z.to_f64();
178 if z_f64 < 0.27 {
179 return S::ONE;
180 }
181 if z_f64 > 3.1 {
182 return S::ZERO;
183 }
184 let mut sum = 0.0;
186 for k in 1..=100 {
187 let kf = k as f64;
188 let term = (-2.0 * kf * kf * z_f64 * z_f64).exp();
189 if k % 2 == 1 {
190 sum += term;
191 } else {
192 sum -= term;
193 }
194 if term < 1e-15 {
195 break;
196 }
197 }
198 S::from_f64((2.0 * sum).clamp(0.0, 1.0))
199}
200
201pub fn anova_oneway<S: Scalar>(groups: &[&[S]], alpha: S) -> Result<TestResult<S>, StatsError> {
203 if groups.len() < 2 {
204 return Err(StatsError::InvalidParameter(
205 "ANOVA requires at least 2 groups".into(),
206 ));
207 }
208 let k = groups.len();
209 let mut total_n = 0;
210 let mut grand_sum = S::ZERO;
211 let mut group_means = Vec::with_capacity(k);
212 let mut group_sizes = Vec::with_capacity(k);
213
214 for g in groups {
215 if g.is_empty() {
216 return Err(StatsError::EmptyData);
217 }
218 let m = descriptive::mean(g)?;
219 group_means.push(m);
220 group_sizes.push(g.len());
221 total_n += g.len();
222 grand_sum += g.iter().copied().fold(S::ZERO, |a, b| a + b);
223 }
224 let grand_mean = grand_sum / S::from_usize(total_n);
225
226 let ss_between: S = group_means
228 .iter()
229 .zip(group_sizes.iter())
230 .fold(S::ZERO, |a, (&m, &n)| {
231 a + S::from_usize(n) * (m - grand_mean) * (m - grand_mean)
232 });
233
234 let ss_within: S = groups
236 .iter()
237 .zip(group_means.iter())
238 .fold(S::ZERO, |a, (g, &m)| {
239 a + g
240 .iter()
241 .copied()
242 .fold(S::ZERO, |b, x| b + (x - m) * (x - m))
243 });
244
245 let df_between = S::from_usize(k - 1);
246 let df_within = S::from_usize(total_n - k);
247
248 let ms_between = ss_between / df_between;
249 let ms_within = ss_within / df_within;
250
251 let f_stat = ms_between / ms_within;
252 let f_dist = FDist::new(df_between, df_within);
253 let p_value = S::ONE - f_dist.cdf(f_stat);
254
255 Ok(TestResult {
256 statistic: f_stat,
257 p_value,
258 reject: p_value < alpha,
259 })
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ttest_1samp_no_effect() {
        // Symmetric about zero, so the sample mean equals mu0 exactly.
        let sample = vec![-1.0_f64, -0.5, 0.0, 0.5, 1.0];
        let outcome = ttest_1samp(&sample, 0.0, 0.05).unwrap();
        assert!(!outcome.reject);
        assert!(outcome.statistic.abs() < 1e-12);
    }

    #[test]
    fn test_ttest_1samp_reject() {
        // Centred near 11 with a small spread: far from mu0 = 0.
        let sample = vec![10.0_f64, 11.0, 12.0, 10.5, 11.5];
        let outcome = ttest_1samp(&sample, 0.0, 0.05).unwrap();
        assert!(outcome.reject);
    }

    #[test]
    fn test_ttest_ind_same_distribution() {
        // Second sample is the first shifted by 0.5: too small to detect.
        let left = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let right = vec![1.5_f64, 2.5, 3.5, 4.5, 5.5];
        let outcome = ttest_ind(&left, &right, 0.05).unwrap();
        assert!(!outcome.reject);
    }

    #[test]
    fn test_ttest_rel() {
        // Every pair drops by exactly 5.
        let before = vec![200.0_f64, 220.0, 190.0, 210.0, 230.0];
        let after = vec![195.0, 215.0, 185.0, 205.0, 225.0];
        let outcome = ttest_rel(&before, &after, 0.05).unwrap();
        assert!(outcome.reject);
    }

    #[test]
    fn test_chi2_test_uniform() {
        // 100 counts spread nearly evenly over six categories (~100 / 6 each).
        let observed = vec![18.0_f64, 16.0, 17.0, 15.0, 17.0, 17.0];
        let expected = vec![16.67_f64; 6];
        let outcome = chi2_test(&observed, &expected, 0.05).unwrap();
        assert!(!outcome.reject);
    }

    #[test]
    fn test_ks_test_normal() {
        use crate::distributions::normal::Normal;
        use rand::SeedableRng;

        // Draws from N(0, 1) should be consistent with N(0, 1).
        let std_normal = Normal::<f64>::standard();
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let sample = std_normal.sample_n(&mut rng, 200);
        let outcome = ks_test(&sample, &std_normal, 0.05).unwrap();
        assert!(
            !outcome.reject,
            "KS test unexpectedly rejected: stat={}, p={}",
            outcome.statistic, outcome.p_value
        );
    }

    #[test]
    fn test_anova_equal_groups() {
        let first = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let second = vec![1.5_f64, 2.5, 3.5, 4.5, 5.5];
        let third = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let outcome = anova_oneway(&[&first, &second, &third], 0.05).unwrap();
        assert!(!outcome.reject);
    }

    #[test]
    fn test_anova_different_groups() {
        // Group means 2, 11, 21 with unit spread inside each group.
        let first = vec![1.0_f64, 2.0, 3.0];
        let second = vec![10.0, 11.0, 12.0];
        let third = vec![20.0, 21.0, 22.0];
        let outcome = anova_oneway(&[&first, &second, &third], 0.05).unwrap();
        assert!(outcome.reject);
    }
}