1use statrs::distribution::{ContinuousCDF, Normal};
7
8use crate::errors::{DecisionError, Result};
9
10pub trait StatisticalTest {
12 fn test(&self) -> Result<f64>;
14
15 fn is_significant(&self, alpha: f64) -> Result<bool> {
17 Ok(self.test()? < alpha)
18 }
19}
20
/// Observed counts for a two-proportion z-test comparing group 1 against group 2.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ZTest {
    /// Number of successes observed in group 1.
    pub successes_1: u64,
    /// Number of trials run in group 1.
    pub trials_1: u64,
    /// Number of successes observed in group 2.
    pub successes_2: u64,
    /// Number of trials run in group 2.
    pub trials_2: u64,
}
36
37impl ZTest {
38 pub fn new(successes_1: u64, trials_1: u64, successes_2: u64, trials_2: u64) -> Self {
40 Self {
41 successes_1,
42 trials_1,
43 successes_2,
44 trials_2,
45 }
46 }
47
48 pub fn proportions(&self) -> (f64, f64) {
50 let p1 = if self.trials_1 > 0 {
51 self.successes_1 as f64 / self.trials_1 as f64
52 } else {
53 0.0
54 };
55
56 let p2 = if self.trials_2 > 0 {
57 self.successes_2 as f64 / self.trials_2 as f64
58 } else {
59 0.0
60 };
61
62 (p1, p2)
63 }
64
65 pub fn pooled_proportion(&self) -> f64 {
67 let total_successes = self.successes_1 + self.successes_2;
68 let total_trials = self.trials_1 + self.trials_2;
69
70 if total_trials > 0 {
71 total_successes as f64 / total_trials as f64
72 } else {
73 0.0
74 }
75 }
76
77 pub fn z_statistic(&self) -> Result<f64> {
79 let (p1, p2) = self.proportions();
80 let p_pool = self.pooled_proportion();
81
82 let n1 = self.trials_1 as f64;
83 let n2 = self.trials_2 as f64;
84
85 if n1 == 0.0 || n2 == 0.0 {
86 return Err(DecisionError::InsufficientData(
87 "Cannot perform z-test with zero trials".to_string()
88 ));
89 }
90
91 let se = (p_pool * (1.0 - p_pool) * (1.0/n1 + 1.0/n2)).sqrt();
93
94 if se == 0.0 {
95 return Err(DecisionError::StatisticalError(
96 "Standard error is zero, cannot compute z-statistic".to_string()
97 ));
98 }
99
100 Ok((p1 - p2) / se)
102 }
103
104 pub fn confidence_interval(&self, confidence: f64) -> Result<(f64, f64)> {
106 let (p1, p2) = self.proportions();
107 let diff = p1 - p2;
108
109 let n1 = self.trials_1 as f64;
110 let n2 = self.trials_2 as f64;
111
112 if n1 == 0.0 || n2 == 0.0 {
113 return Err(DecisionError::InsufficientData(
114 "Cannot calculate confidence interval with zero trials".to_string()
115 ));
116 }
117
118 let se = ((p1 * (1.0 - p1) / n1) + (p2 * (1.0 - p2) / n2)).sqrt();
120
121 let z = match confidence {
123 c if (c - 0.90).abs() < 0.001 => 1.645,
124 c if (c - 0.95).abs() < 0.001 => 1.96,
125 c if (c - 0.99).abs() < 0.001 => 2.576,
126 _ => {
127 let normal = Normal::new(0.0, 1.0)
128 .map_err(|e| DecisionError::StatisticalError(e.to_string()))?;
129 let alpha = 1.0 - confidence;
130 normal.inverse_cdf(1.0 - alpha / 2.0)
131 }
132 };
133
134 let margin = z * se;
135 Ok((diff - margin, diff + margin))
136 }
137
138 pub fn effect_size(&self) -> f64 {
140 let (p1, p2) = self.proportions();
141
142 2.0 * (p1.sqrt().asin() - p2.sqrt().asin())
144 }
145
146 pub fn power(&self, alpha: f64, effect_size: f64) -> Result<f64> {
148 let n1 = self.trials_1 as f64;
149 let n2 = self.trials_2 as f64;
150
151 if n1 == 0.0 || n2 == 0.0 {
152 return Ok(0.0);
153 }
154
155 let n_harmonic = 2.0 / (1.0/n1 + 1.0/n2);
157 let noncentrality = effect_size * (n_harmonic / 4.0).sqrt();
158
159 let normal = Normal::new(0.0, 1.0)
160 .map_err(|e| DecisionError::StatisticalError(e.to_string()))?;
161
162 let z_alpha = normal.inverse_cdf(1.0 - alpha / 2.0);
163 let power = 1.0 - normal.cdf(z_alpha - noncentrality);
164
165 Ok(power)
166 }
167}
168
169impl StatisticalTest for ZTest {
170 fn test(&self) -> Result<f64> {
172 let z = self.z_statistic()?;
173
174 let normal = Normal::new(0.0, 1.0)
175 .map_err(|e| DecisionError::StatisticalError(e.to_string()))?;
176
177 let p_value = 2.0 * (1.0 - normal.cdf(z.abs()));
179
180 Ok(p_value)
181 }
182}
183
/// Inputs for estimating the sample size of a two-proportion test.
#[derive(Debug, Clone, PartialEq)]
pub struct SampleSizeCalculator {
    /// Success rate of the control group, strictly between 0 and 1.
    pub baseline_rate: f64,
    /// Minimum relative effect to detect: the treatment rate is
    /// `baseline_rate * (1 + min_effect)`.
    pub min_effect: f64,
    /// Desired statistical power (1 − β), strictly between 0 and 1.
    pub power: f64,
    /// Two-sided significance level, strictly between 0 and 1.
    pub alpha: f64,
}
195
196impl SampleSizeCalculator {
197 pub fn new(baseline_rate: f64, min_effect: f64, power: f64, alpha: f64) -> Result<Self> {
199 if baseline_rate <= 0.0 || baseline_rate >= 1.0 {
200 return Err(DecisionError::InvalidConfig(
201 "Baseline rate must be between 0 and 1".to_string()
202 ));
203 }
204
205 if power <= 0.0 || power >= 1.0 {
206 return Err(DecisionError::InvalidConfig(
207 "Power must be between 0 and 1".to_string()
208 ));
209 }
210
211 if alpha <= 0.0 || alpha >= 1.0 {
212 return Err(DecisionError::InvalidConfig(
213 "Alpha must be between 0 and 1".to_string()
214 ));
215 }
216
217 Ok(Self {
218 baseline_rate,
219 min_effect,
220 power,
221 alpha,
222 })
223 }
224
225 pub fn calculate(&self) -> Result<usize> {
227 let p1 = self.baseline_rate;
228 let p2 = self.baseline_rate * (1.0 + self.min_effect);
229
230 if p2 >= 1.0 {
231 return Err(DecisionError::InvalidConfig(
232 "Effect size too large, treatment rate exceeds 1.0".to_string()
233 ));
234 }
235
236 let normal = Normal::new(0.0, 1.0)
237 .map_err(|e| DecisionError::StatisticalError(e.to_string()))?;
238
239 let z_alpha = normal.inverse_cdf(1.0 - self.alpha / 2.0);
240 let z_beta = normal.inverse_cdf(self.power);
241
242 let p_avg = (p1 + p2) / 2.0;
243 let delta = (p2 - p1).abs();
244
245 let n = ((z_alpha + z_beta).powi(2) * 2.0 * p_avg * (1.0 - p_avg)) / delta.powi(2);
247
248 Ok(n.ceil() as usize)
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// 50/100 vs 60/100: a modest ten-point gap between the groups.
    fn modest_gap() -> ZTest {
        ZTest::new(50, 100, 60, 100)
    }

    /// 30/100 vs 70/100: a wide forty-point gap between the groups.
    fn wide_gap() -> ZTest {
        ZTest::new(30, 100, 70, 100)
    }

    #[test]
    fn test_z_test_proportions() {
        assert_eq!(modest_gap().proportions(), (0.5, 0.6));
    }

    #[test]
    fn test_pooled_proportion() {
        assert_eq!(modest_gap().pooled_proportion(), 0.55);
    }

    #[test]
    fn test_z_statistic() {
        let z = modest_gap().z_statistic().unwrap();

        // Group 1 succeeds less often, so the statistic is negative.
        assert!(z < 0.0);

        let magnitude = z.abs();
        assert!(magnitude > 1.0 && magnitude < 2.0);
    }

    #[test]
    fn test_z_test_significant_difference() {
        let p_value = wide_gap().test().unwrap();
        assert!(p_value < 0.05);
    }

    #[test]
    fn test_z_test_no_difference() {
        let identical = ZTest::new(50, 100, 50, 100);
        assert!(identical.test().unwrap() > 0.05);
    }

    #[test]
    fn test_is_significant() {
        assert!(wide_gap().is_significant(0.05).unwrap());
    }

    #[test]
    fn test_confidence_interval() {
        let (lower, upper) = modest_gap().confidence_interval(0.95).unwrap();
        let observed_diff = -0.1;

        assert!(lower < observed_diff && observed_diff < upper);
        assert!(upper - lower < 0.3);
    }

    #[test]
    fn test_effect_size() {
        let h = wide_gap().effect_size();
        assert!(h.abs() > 0.5);
    }

    #[test]
    fn test_insufficient_data_error() {
        // Group 2 has no trials at all.
        let empty_second_group = ZTest::new(5, 10, 0, 0);
        assert!(empty_second_group.z_statistic().is_err());
    }

    #[test]
    fn test_sample_size_calculator() {
        let n = SampleSizeCalculator::new(0.1, 0.2, 0.8, 0.05)
            .and_then(|calc| calc.calculate())
            .unwrap();

        assert!(n > 100);
        assert!(n < 100000);
    }

    #[test]
    fn test_sample_size_larger_effect() {
        let required_n = |relative_lift: f64| {
            SampleSizeCalculator::new(0.1, relative_lift, 0.8, 0.05)
                .unwrap()
                .calculate()
                .unwrap()
        };

        assert!(required_n(0.5) < required_n(0.1));
    }

    #[test]
    fn test_power_calculation() {
        let power = ZTest::new(500, 1000, 550, 1000).power(0.05, 0.1).unwrap();
        assert!(power > 0.0 && power < 1.0);
    }

    #[test]
    fn test_real_world_scenario() {
        // 10% vs 15% success rate with 1,000 trials per group.
        let scenario = ZTest::new(100, 1000, 150, 1000);

        let (rate_1, rate_2) = scenario.proportions();
        assert_relative_eq!(rate_1, 0.1, epsilon = 0.001);
        assert_relative_eq!(rate_2, 0.15, epsilon = 0.001);

        let p_value = scenario.test().unwrap();
        assert!(p_value < 0.05, "p-value {} should be < 0.05", p_value);

        let (lower, upper) = scenario.confidence_interval(0.95).unwrap();
        let entirely_negative = lower < 0.0 && upper < 0.0;
        let entirely_positive = lower > 0.0 && upper > 0.0;
        assert!(entirely_negative || entirely_positive);
    }
}