1use crate::error::{Error, Result};
39use crate::linalg::{Matrix, vec_mean, fit_ols_linpack, fit_and_predict_linpack};
40use super::types::{WhiteTestOutput, WhiteSingleResult, WhiteMethod};
41use super::helpers::chi_squared_p_value;
42
43pub fn white_test(
45 y: &[f64],
46 x_vars: &[Vec<f64>],
47 method: WhiteMethod,
48) -> Result<WhiteTestOutput> {
49 let n = y.len();
50 let k = x_vars.len();
51 let p = k + 1;
52
53 if n <= p {
54 return Err(Error::InsufficientData { required: p + 1, available: n });
55 }
56
57 let alpha = 0.05;
58
59 let mut x_data = vec![1.0; n * p];
61 for row in 0..n {
62 for (col, x_var) in x_vars.iter().enumerate() {
63 x_data[row * p + col + 1] = x_var[row];
64 }
65 }
66 let x_full = Matrix::new(n, p, x_data);
67 let beta = fit_ols_linpack(y, &x_full).ok_or(Error::SingularMatrix)?;
68 let predictions = x_full.mul_vec(&beta);
69 let residuals: Vec<f64> = y.iter().zip(predictions.iter())
70 .map(|(&yi, &yi_hat)| yi - yi_hat)
71 .collect();
72 let e_squared: Vec<f64> = residuals.iter().map(|&e| e * e).collect();
73
74 let (z_data, z_cols) = build_auxiliary_matrix(n, x_vars, method);
76
77 let z_matrix = Matrix::new(n, z_cols, z_data);
80
81 #[cfg(test)]
82 {
83 eprintln!("Z matrix: {} rows x {} cols", n, z_cols);
84 let qr_result = z_matrix.qr_linpack(None);
85 eprintln!("Z rank: {}", qr_result.rank);
86 eprintln!("Pivot order: {:?}", qr_result.pivot);
87
88 for j in qr_result.rank..z_cols {
90 let dropped_col = qr_result.pivot[j] - 1;
91 eprintln!("Dropped column {} (pivot position {})", dropped_col, j);
92 }
93
94 let beta = fit_ols_linpack(&e_squared, &z_matrix);
96 if let Some(ref b) = beta {
97 eprintln!("First 10 coefficients: {:?}", &b[..10.min(b.len())]);
98 eprintln!("Last 5 coefficients: {:?}", &b[b.len().saturating_sub(5)..]);
99 }
100 }
101
102 let pred_aux = fit_and_predict_linpack(&e_squared, &z_matrix).ok_or(Error::SingularMatrix)?;
103
104 #[cfg(test)]
105 {
106 eprintln!("First few pred_aux: {:?}", &pred_aux[..5.min(pred_aux.len())]);
107 let has_nan = pred_aux.iter().any(|&x| x.is_nan());
108 eprintln!("pred_aux has NaN: {}", has_nan);
109 }
110
111 let (_r_squared_aux, lm_stat) = compute_r2_and_lm(&e_squared, &pred_aux, n);
113
114 let r_result = if method == WhiteMethod::R || method == WhiteMethod::Both {
116 let df_r = (2 * k) as f64;
117 let p_value_r = chi_squared_p_value(lm_stat, df_r);
118 let passed_r = p_value_r > alpha;
119 Some(WhiteSingleResult {
120 method: "R (skedastic::white)".to_string(),
121 statistic: lm_stat,
122 p_value: p_value_r,
123 passed: passed_r,
124 })
125 } else {
126 None
127 };
128
129 let python_result = if method == WhiteMethod::Python || method == WhiteMethod::Both {
130 let theoretical_df = (k * (k + 3) / 2) as f64;
131 let df_p = theoretical_df.min((n - 1) as f64);
132 let p_value_p = chi_squared_p_value(lm_stat, df_p);
133 let passed_p = p_value_p > alpha;
134 Some(WhiteSingleResult {
135 method: "Python (statsmodels)".to_string(),
136 statistic: lm_stat,
137 p_value: p_value_p,
138 passed: passed_p,
139 })
140 } else {
141 None
142 };
143
144 let (interp_text, guid_text) = match (&r_result, &python_result) {
146 (Some(r), None) => interpret_result(r.p_value, alpha),
147 (None, Some(p)) => interpret_result(p.p_value, alpha),
148 (Some(r), Some(p)) => {
149 if r.p_value >= p.p_value {
150 interpret_result(r.p_value, alpha)
151 } else {
152 interpret_result(p.p_value, alpha)
153 }
154 }
155 (None, None) => unreachable!(),
156 };
157
158 Ok(WhiteTestOutput {
159 test_name: "White Test for Heteroscedasticity".to_string(),
160 r_result,
161 python_result,
162 interpretation: interp_text,
163 guidance: guid_text.to_string(),
164 })
165}
166
167fn compute_r2_and_lm(e_squared: &[f64], pred_aux: &[f64], n: usize) -> (f64, f64) {
169 let residuals_aux: Vec<f64> = e_squared.iter().zip(pred_aux.iter())
170 .map(|(&yi, &yi_hat)| yi - yi_hat)
171 .collect();
172
173 let rss_aux: f64 = residuals_aux.iter().map(|&r| r * r).sum();
174
175 let mean_e_squared = vec_mean(e_squared);
176 let tss_centered: f64 = e_squared.iter()
177 .map(|&e| {
178 let diff = e - mean_e_squared;
179 diff * diff
180 })
181 .sum();
182
183 let r_squared_aux = if tss_centered > 1e-10 {
184 (1.0 - (rss_aux / tss_centered)).clamp(0.0, 1.0)
185 } else {
186 0.0
187 };
188
189 let lm_stat = (n as f64) * r_squared_aux;
190 (r_squared_aux, lm_stat)
191}
192
193fn build_auxiliary_matrix(
195 n: usize,
196 x_vars: &[Vec<f64>],
197 method: WhiteMethod,
198) -> (Vec<f64>, usize) {
199 let k = x_vars.len();
200
201 match method {
202 WhiteMethod::R => {
203 let z_cols = 1 + 2 * k;
204 let mut z_data = vec![0.0; n * z_cols];
205
206 for row in 0..n {
207 let mut col_idx = 0;
208 z_data[row * z_cols + col_idx] = 1.0;
209 col_idx += 1;
210
211 for x_var in x_vars.iter() {
212 z_data[row * z_cols + col_idx] = x_var[row];
213 col_idx += 1;
214 }
215
216 for x_var in x_vars.iter() {
217 z_data[row * z_cols + col_idx] = x_var[row] * x_var[row];
218 col_idx += 1;
219 }
220 }
221
222 (z_data, z_cols)
223 }
224 WhiteMethod::Python => {
225 let num_cross = k * (k - 1) / 2;
226 let z_cols = 1 + 2 * k + num_cross;
227 let mut z_data = vec![0.0; n * z_cols];
228
229 for row in 0..n {
230 let mut col_idx = 0;
231
232 z_data[row * z_cols + col_idx] = 1.0;
233 col_idx += 1;
234
235 for x_var in x_vars.iter() {
236 z_data[row * z_cols + col_idx] = x_var[row];
237 col_idx += 1;
238 }
239
240 for x_var in x_vars.iter() {
241 z_data[row * z_cols + col_idx] = x_var[row] * x_var[row];
242 col_idx += 1;
243 }
244
245 for i in 0..k {
246 for j in (i + 1)..k {
247 z_data[row * z_cols + col_idx] = x_vars[i][row] * x_vars[j][row];
248 col_idx += 1;
249 }
250 }
251 }
252
253 (z_data, z_cols)
254 }
255 WhiteMethod::Both => {
256 build_auxiliary_matrix(n, x_vars, WhiteMethod::Python)
257 }
258 }
259}
260
/// Produces the human-readable interpretation and follow-up guidance for a p-value.
///
/// Returns `(interpretation, guidance)`. When `p_value > alpha` the null
/// hypothesis of homoscedasticity is not rejected; otherwise (including when the
/// comparison is false because `p_value` is NaN) the rejection text is returned.
fn interpret_result(p_value: f64, alpha: f64) -> (String, &'static str) {
    let fails_to_reject = p_value > alpha;

    if !fails_to_reject {
        let summary = format!(
            "p-value = {:.4} is less than or equal to {:.2}. Reject H0. Significant evidence of heteroscedasticity detected.",
            p_value, alpha
        );
        return (
            summary,
            "Consider transforming the dependent variable (e.g., log transformation), using weighted least squares, or robust standard errors.",
        );
    }

    let summary = format!(
        "p-value = {:.4} is greater than {:.2}. Cannot reject H0. No significant evidence of heteroscedasticity.",
        p_value, alpha
    );
    (
        summary,
        "The assumption of homoscedasticity (constant variance) appears to be met.",
    )
}
281
282pub fn r_white_method(
284 y: &[f64],
285 x_vars: &[Vec<f64>],
286) -> Result<WhiteSingleResult> {
287 let result = white_test(y, x_vars, WhiteMethod::R)?;
288 result.r_result.ok_or(Error::SingularMatrix)
289}
290
291pub fn python_white_method(
293 y: &[f64],
294 x_vars: &[Vec<f64>],
295) -> Result<WhiteSingleResult> {
296 let result = white_test(y, x_vars, WhiteMethod::Python)?;
297 result.python_result.ok_or(Error::SingularMatrix)
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Response column (`mpg`) shared by every fixture.
    fn mtcars_mpg() -> Vec<f64> {
        vec![
            21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22.8, 19.2,
            17.8, 16.4, 17.3, 15.2, 10.4, 10.4, 14.7, 32.4, 30.4, 33.9,
            21.5, 15.5, 15.2, 13.3, 19.2, 27.3, 26.0, 30.4, 15.8, 19.7,
            15.0, 21.4,
        ]
    }

    /// `wt` predictor column.
    fn mtcars_wt() -> Vec<f64> {
        vec![
            2.62, 2.875, 2.32, 3.215, 3.44, 3.46, 3.57, 3.19, 3.15, 3.44,
            3.44, 4.07, 3.73, 3.78, 5.25, 5.424, 5.345, 2.2, 1.615, 1.835,
            2.465, 3.52, 3.435, 3.84, 3.845, 1.935, 2.14, 1.513, 3.17, 2.77,
            3.57, 2.78,
        ]
    }

    /// `hp` predictor column.
    fn mtcars_hp() -> Vec<f64> {
        vec![
            110.0, 110.0, 93.0, 110.0, 175.0, 105.0, 245.0, 62.0, 95.0, 123.0,
            123.0, 180.0, 180.0, 180.0, 205.0, 215.0, 230.0, 66.0, 52.0, 65.0,
            97.0, 150.0, 150.0, 245.0, 175.0, 66.0, 91.0, 113.0, 264.0, 175.0,
            335.0, 109.0,
        ]
    }

    /// Small two-predictor fixture: `mpg` regressed on `wt` and `hp`.
    fn test_data() -> (Vec<f64>, Vec<Vec<f64>>) {
        (mtcars_mpg(), vec![mtcars_wt(), mtcars_hp()])
    }

    /// Full ten-predictor fixture; with 32 observations the Python-style
    /// auxiliary matrix (66 columns) has more columns than rows.
    fn mtcars_data() -> (Vec<f64>, Vec<Vec<f64>>) {
        let cyl = vec![
            6.0, 6.0, 4.0, 6.0, 8.0, 6.0, 8.0, 4.0, 4.0, 6.0,
            6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0,
            4.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 8.0, 8.0,
            8.0, 4.0,
        ];

        let disp = vec![
            160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 146.7, 140.8, 167.6,
            167.6, 275.8, 275.8, 275.8, 472.0, 460.0, 440.0, 78.7, 75.7, 71.1,
            120.1, 318.0, 304.0, 350.0, 400.0, 79.0, 120.3, 95.1, 351.0, 145.0,
            301.0, 121.0,
        ];

        let drat = vec![
            3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.92, 3.92,
            3.92, 3.07, 3.07, 3.07, 2.93, 3.00, 3.23, 4.08, 4.93, 4.22,
            3.70, 2.76, 3.15, 3.73, 3.08, 4.08, 4.43, 3.77, 4.22, 3.62,
            3.54, 4.11,
        ];

        let qsec = vec![
            16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 20.00, 22.90, 18.30,
            18.90, 17.40, 17.60, 18.00, 17.98, 17.82, 17.42, 19.47, 18.52, 19.90,
            20.01, 16.87, 17.30, 15.41, 17.05, 18.90, 16.70, 16.90, 14.50, 15.50,
            14.60, 18.60,
        ];

        let vs = vec![
            0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0,
            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0,
            1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0,
            0.0, 1.0,
        ];

        let am = vec![
            1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0,
            1.0, 1.0,
        ];

        let gear = vec![
            4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0,
            4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0,
            3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0,
            5.0, 4.0,
        ];

        let carb = vec![
            4.0, 4.0, 1.0, 1.0, 2.0, 1.0, 4.0, 2.0, 2.0, 4.0,
            4.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 1.0, 2.0, 1.0,
            1.0, 2.0, 2.0, 4.0, 2.0, 1.0, 2.0, 2.0, 4.0, 6.0,
            8.0, 2.0,
        ];

        (
            mtcars_mpg(),
            vec![
                cyl,
                disp,
                mtcars_hp(),
                drat,
                mtcars_wt(),
                qsec,
                vs,
                am,
                gear,
                carb,
            ],
        )
    }

    #[test]
    fn test_white_test_r_method() {
        let (y, x_vars) = test_data();
        let output = white_test(&y, &x_vars, WhiteMethod::R).expect("R method should succeed");
        assert!(output.r_result.is_some());
        assert!(output.python_result.is_none());
    }

    #[test]
    fn test_white_test_python_method() {
        let (y, x_vars) = test_data();
        let output =
            white_test(&y, &x_vars, WhiteMethod::Python).expect("Python method should succeed");
        assert!(output.r_result.is_none());
        assert!(output.python_result.is_some());
    }

    #[test]
    fn test_white_test_both_methods() {
        let (y, x_vars) = test_data();
        let output =
            white_test(&y, &x_vars, WhiteMethod::Both).expect("Both methods should succeed");
        assert!(output.r_result.is_some());
        assert!(output.python_result.is_some());
    }

    #[test]
    fn test_white_test_insufficient_data() {
        let y = vec![1.0, 2.0];
        let x1 = vec![1.0, 2.0];
        let x2 = vec![2.0, 3.0];
        let result = white_test(&y, &[x1, x2], WhiteMethod::R);
        // Two observations cannot support an intercept plus two predictors.
        assert!(matches!(result, Err(Error::InsufficientData { .. })));
    }

    #[test]
    fn test_white_r_validation() {
        let (y, x_vars) = mtcars_data();
        let result = white_test(&y, &x_vars, WhiteMethod::R).unwrap();

        // Fail loudly instead of silently skipping the assertions below.
        let r = result
            .r_result
            .expect("R result must be present for WhiteMethod::R");

        println!("\n=== White Test R Method Validation ===");
        println!("Reference: LM-statistic = 19.3975, p-value = 0.49614");
        println!("Rust: LM-statistic = {}, p-value = {}", r.statistic, r.p_value);

        assert!(r.p_value > 0.05);
        assert!(r.passed);
    }

    #[test]
    fn test_white_python_validation() {
        let (y, x_vars) = mtcars_data();
        let result = white_test(&y, &x_vars, WhiteMethod::Python).unwrap();

        // Fail loudly instead of silently skipping the assertions below.
        let p = result
            .python_result
            .expect("Python result must be present for WhiteMethod::Python");

        println!("\n=== White Test Python Method Validation ===");
        println!("Reference: LM-statistic = 32.0, p-value = 0.41674");
        println!("Rust: LM-statistic = {}, p-value = {}", p.statistic, p.p_value);

        let stat_diff = (p.statistic - 32.0).abs();
        let pval_diff = (p.p_value - 0.41674).abs();
        println!("Differences: stat={:.2}, pval={:.2}", stat_diff, pval_diff);

        // NOTE(review): these tolerances are very loose relative to the
        // reference values; tighten once the numerical agreement is confirmed.
        assert!(stat_diff < 10.0);
        assert!(pval_diff < 0.3);
        assert!(p.passed);
    }

    #[test]
    fn test_r_white_method_direct() {
        let (y, x_vars) = test_data();
        let output = r_white_method(&y, &x_vars).expect("R wrapper should succeed");
        assert_eq!(output.method, "R (skedastic::white)");
        assert!(output.passed);
    }

    #[test]
    fn test_python_white_method_direct() {
        let (y, x_vars) = test_data();
        let output = python_white_method(&y, &x_vars).expect("Python wrapper should succeed");
        assert_eq!(output.method, "Python (statsmodels)");
        assert!(output.passed);
    }
}