1use super::helpers::chi_squared_p_value;
39use super::types::{WhiteMethod, WhiteSingleResult, WhiteTestOutput};
40use crate::error::{Error, Result};
41use crate::linalg::{fit_and_predict_linpack, fit_ols_linpack, vec_mean, Matrix};
42
43pub fn white_test(y: &[f64], x_vars: &[Vec<f64>], method: WhiteMethod) -> Result<WhiteTestOutput> {
79 let n = y.len();
80 let k = x_vars.len();
81 let p = k + 1;
82
83 if n <= p {
84 return Err(Error::InsufficientData {
85 required: p + 1,
86 available: n,
87 });
88 }
89
90 super::helpers::validate_regression_data(y, x_vars)?;
92
93 let alpha = 0.05;
94
95 let mut x_data = vec![1.0; n * p];
97 for row in 0..n {
98 for (col, x_var) in x_vars.iter().enumerate() {
99 x_data[row * p + col + 1] = x_var[row];
100 }
101 }
102 let x_full = Matrix::new(n, p, x_data);
103 let beta = fit_ols_linpack(y, &x_full).ok_or(Error::SingularMatrix)?;
104 let predictions = x_full.mul_vec(&beta);
105 let residuals: Vec<f64> = y
106 .iter()
107 .zip(predictions.iter())
108 .map(|(&yi, &yi_hat)| yi - yi_hat)
109 .collect();
110 let e_squared: Vec<f64> = residuals.iter().map(|&e| e * e).collect();
111
112 let (z_data, z_cols) = build_auxiliary_matrix(n, x_vars, method);
114
115 let z_matrix = Matrix::new(n, z_cols, z_data);
118
119 #[cfg(test)]
120 {
121 eprintln!("Z matrix: {} rows x {} cols", n, z_cols);
122 let qr_result = z_matrix.qr_linpack(None);
123 eprintln!("Z rank: {}", qr_result.rank);
124 eprintln!("Pivot order: {:?}", qr_result.pivot);
125
126 for j in qr_result.rank..z_cols {
128 let dropped_col = qr_result.pivot[j] - 1;
129 eprintln!("Dropped column {} (pivot position {})", dropped_col, j);
130 }
131
132 let beta = fit_ols_linpack(&e_squared, &z_matrix);
134 if let Some(ref b) = beta {
135 eprintln!("First 10 coefficients: {:?}", &b[..10.min(b.len())]);
136 eprintln!("Last 5 coefficients: {:?}", &b[b.len().saturating_sub(5)..]);
137 }
138 }
139
140 let pred_aux = fit_and_predict_linpack(&e_squared, &z_matrix).ok_or(Error::SingularMatrix)?;
141
142 #[cfg(test)]
143 {
144 eprintln!(
145 "First few pred_aux: {:?}",
146 &pred_aux[..5.min(pred_aux.len())]
147 );
148 let has_nan = pred_aux.iter().any(|&x| x.is_nan());
149 eprintln!("pred_aux has NaN: {}", has_nan);
150 }
151
152 let (_r_squared_aux, lm_stat) = compute_r2_and_lm(&e_squared, &pred_aux, n);
154
155 let r_result = if method == WhiteMethod::R || method == WhiteMethod::Both {
157 let df_r = (2 * k) as f64;
158 let p_value_r = chi_squared_p_value(lm_stat, df_r);
159 let passed_r = p_value_r > alpha;
160 Some(WhiteSingleResult {
161 method: "R (skedastic::white)".to_string(),
162 statistic: lm_stat,
163 p_value: p_value_r,
164 passed: passed_r,
165 })
166 } else {
167 None
168 };
169
170 let python_result = if method == WhiteMethod::Python || method == WhiteMethod::Both {
171 let theoretical_df = (k * (k + 3) / 2) as f64;
172 let df_p = theoretical_df.min((n - 1) as f64);
173 let p_value_p = chi_squared_p_value(lm_stat, df_p);
174 let passed_p = p_value_p > alpha;
175 Some(WhiteSingleResult {
176 method: "Python (statsmodels)".to_string(),
177 statistic: lm_stat,
178 p_value: p_value_p,
179 passed: passed_p,
180 })
181 } else {
182 None
183 };
184
185 let (interp_text, guid_text) = match (&r_result, &python_result) {
187 (Some(r), None) => interpret_result(r.p_value, alpha),
188 (None, Some(p)) => interpret_result(p.p_value, alpha),
189 (Some(r), Some(p)) => {
190 if r.p_value >= p.p_value {
191 interpret_result(r.p_value, alpha)
192 } else {
193 interpret_result(p.p_value, alpha)
194 }
195 },
196 (None, None) => unreachable!(),
197 };
198
199 Ok(WhiteTestOutput {
200 test_name: "White Test for Heteroscedasticity".to_string(),
201 r_result,
202 python_result,
203 interpretation: interp_text,
204 guidance: guid_text.to_string(),
205 })
206}
207
208fn compute_r2_and_lm(e_squared: &[f64], pred_aux: &[f64], n: usize) -> (f64, f64) {
210 let residuals_aux: Vec<f64> = e_squared
211 .iter()
212 .zip(pred_aux.iter())
213 .map(|(&yi, &yi_hat)| yi - yi_hat)
214 .collect();
215
216 let rss_aux: f64 = residuals_aux.iter().map(|&r| r * r).sum();
217
218 let mean_e_squared = vec_mean(e_squared);
219 let tss_centered: f64 = e_squared
220 .iter()
221 .map(|&e| {
222 let diff = e - mean_e_squared;
223 diff * diff
224 })
225 .sum();
226
227 let r_squared_aux = if tss_centered > 1e-10 {
228 (1.0 - (rss_aux / tss_centered)).clamp(0.0, 1.0)
229 } else {
230 0.0
231 };
232
233 let lm_stat = (n as f64) * r_squared_aux;
234 (r_squared_aux, lm_stat)
235}
236
237fn build_auxiliary_matrix(n: usize, x_vars: &[Vec<f64>], method: WhiteMethod) -> (Vec<f64>, usize) {
239 let k = x_vars.len();
240
241 match method {
242 WhiteMethod::R => {
243 let z_cols = 1 + 2 * k;
244 let mut z_data = vec![0.0; n * z_cols];
245
246 for row in 0..n {
247 let mut col_idx = 0;
248 z_data[row * z_cols + col_idx] = 1.0;
249 col_idx += 1;
250
251 for x_var in x_vars.iter() {
252 z_data[row * z_cols + col_idx] = x_var[row];
253 col_idx += 1;
254 }
255
256 for x_var in x_vars.iter() {
257 z_data[row * z_cols + col_idx] = x_var[row] * x_var[row];
258 col_idx += 1;
259 }
260 }
261
262 (z_data, z_cols)
263 },
264 WhiteMethod::Python => {
265 let num_cross = k * (k - 1) / 2;
266 let z_cols = 1 + 2 * k + num_cross;
267 let mut z_data = vec![0.0; n * z_cols];
268
269 for row in 0..n {
270 let mut col_idx = 0;
271
272 z_data[row * z_cols + col_idx] = 1.0;
273 col_idx += 1;
274
275 for x_var in x_vars.iter() {
276 z_data[row * z_cols + col_idx] = x_var[row];
277 col_idx += 1;
278 }
279
280 for x_var in x_vars.iter() {
281 z_data[row * z_cols + col_idx] = x_var[row] * x_var[row];
282 col_idx += 1;
283 }
284
285 for i in 0..k {
286 for j in (i + 1)..k {
287 z_data[row * z_cols + col_idx] = x_vars[i][row] * x_vars[j][row];
288 col_idx += 1;
289 }
290 }
291 }
292
293 (z_data, z_cols)
294 },
295 WhiteMethod::Both => build_auxiliary_matrix(n, x_vars, WhiteMethod::Python),
296 }
297}
298
/// Turns a p-value into a human-readable verdict and follow-up guidance.
///
/// Returns `(interpretation, guidance)`. The null hypothesis is kept only when
/// `p_value > alpha`; every other case (including equality) is reported as a
/// rejection.
fn interpret_result(p_value: f64, alpha: f64) -> (String, &'static str) {
    let keep_null = p_value > alpha;

    let guidance = if keep_null {
        "The assumption of homoscedasticity (constant variance) appears to be met."
    } else {
        "Consider transforming the dependent variable (e.g., log transformation), using weighted least squares, or robust standard errors."
    };

    let interpretation = if keep_null {
        format!(
            "p-value = {:.4} is greater than {:.2}. Cannot reject H0. No significant evidence of heteroscedasticity.",
            p_value, alpha
        )
    } else {
        format!(
            "p-value = {:.4} is less than or equal to {:.2}. Reject H0. Significant evidence of heteroscedasticity detected.",
            p_value, alpha
        )
    };

    (interpretation, guidance)
}
319
320pub fn r_white_method(y: &[f64], x_vars: &[Vec<f64>]) -> Result<WhiteSingleResult> {
350 let result = white_test(y, x_vars, WhiteMethod::R)?;
351 result.r_result.ok_or(Error::SingularMatrix)
352}
353
354pub fn python_white_method(y: &[f64], x_vars: &[Vec<f64>]) -> Result<WhiteSingleResult> {
384 let result = white_test(y, x_vars, WhiteMethod::Python)?;
385 result.python_result.ok_or(Error::SingularMatrix)
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    // Positions of the predictors reused by `test_data` inside the column
    // vector returned by `mtcars_data`.
    const HP_COLUMN: usize = 2;
    const WT_COLUMN: usize = 4;

    /// Two-predictor fixture: `mpg` regressed on `wt` and `hp`.
    fn test_data() -> (Vec<f64>, Vec<Vec<f64>>) {
        let (y, columns) = mtcars_data();
        let wt = columns[WT_COLUMN].clone();
        let hp = columns[HP_COLUMN].clone();
        (y, vec![wt, hp])
    }

    #[test]
    fn test_white_test_r_method() {
        let (y, x_vars) = test_data();
        let output = white_test(&y, &x_vars, WhiteMethod::R).expect("R method should succeed");
        assert!(output.r_result.is_some());
        assert!(output.python_result.is_none());
    }

    #[test]
    fn test_white_test_python_method() {
        let (y, x_vars) = test_data();
        let output =
            white_test(&y, &x_vars, WhiteMethod::Python).expect("Python method should succeed");
        assert!(output.r_result.is_none());
        assert!(output.python_result.is_some());
    }

    #[test]
    fn test_white_test_both_methods() {
        let (y, x_vars) = test_data();
        let output =
            white_test(&y, &x_vars, WhiteMethod::Both).expect("combined method should succeed");
        assert!(output.r_result.is_some());
        assert!(output.python_result.is_some());
    }

    #[test]
    fn test_white_test_insufficient_data() {
        // Two observations cannot support an intercept plus two predictors.
        let y = vec![1.0, 2.0];
        let x1 = vec![1.0, 2.0];
        let x2 = vec![2.0, 3.0];
        let result = white_test(&y, &[x1, x2], WhiteMethod::R);
        assert!(result.is_err());
    }

    /// Full fixture: `mpg` as the response and the ten remaining mtcars
    /// columns, in the order `cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb`.
    fn mtcars_data() -> (Vec<f64>, Vec<Vec<f64>>) {
        let y = vec![
            21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22.8, 19.2, 17.8, 16.4, 17.3, 15.2,
            10.4, 10.4, 14.7, 32.4, 30.4, 33.9, 21.5, 15.5, 15.2, 13.3, 19.2, 27.3, 26.0, 30.4,
            15.8, 19.7, 15.0, 21.4,
        ];

        let cyl = vec![
            6.0, 6.0, 4.0, 6.0, 8.0, 6.0, 8.0, 4.0, 4.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0,
            4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 4.0,
        ];

        let disp = vec![
            160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 146.7, 140.8, 167.6, 167.6, 275.8,
            275.8, 275.8, 472.0, 460.0, 440.0, 78.7, 75.7, 71.1, 120.1, 318.0, 304.0, 350.0, 400.0,
            79.0, 120.3, 95.1, 351.0, 145.0, 301.0, 121.0,
        ];

        let hp = vec![
            110.0, 110.0, 93.0, 110.0, 175.0, 105.0, 245.0, 62.0, 95.0, 123.0, 123.0, 180.0, 180.0,
            180.0, 205.0, 215.0, 230.0, 66.0, 52.0, 65.0, 97.0, 150.0, 150.0, 245.0, 175.0, 66.0,
            91.0, 113.0, 264.0, 175.0, 335.0, 109.0,
        ];

        let drat = vec![
            3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.92, 3.92, 3.92, 3.07, 3.07, 3.07,
            2.93, 3.00, 3.23, 4.08, 4.93, 4.22, 3.70, 2.76, 3.15, 3.73, 3.08, 4.08, 4.43, 3.77,
            4.22, 3.62, 3.54, 4.11,
        ];

        let wt = vec![
            2.62, 2.875, 2.32, 3.215, 3.44, 3.46, 3.57, 3.19, 3.15, 3.44, 3.44, 4.07, 3.73, 3.78,
            5.25, 5.424, 5.345, 2.2, 1.615, 1.835, 2.465, 3.52, 3.435, 3.84, 3.845, 1.935, 2.14,
            1.513, 3.17, 2.77, 3.57, 2.78,
        ];

        let qsec = vec![
            16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 20.00, 22.90, 18.30, 18.90, 17.40,
            17.60, 18.00, 17.98, 17.82, 17.42, 19.47, 18.52, 19.90, 20.01, 16.87, 17.30, 15.41,
            17.05, 18.90, 16.70, 16.90, 14.50, 15.50, 14.60, 18.60,
        ];

        let vs = vec![
            0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
        ];

        let am = vec![
            1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        ];

        let gear = vec![
            4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0,
            4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0,
        ];

        let carb = vec![
            4.0, 4.0, 1.0, 1.0, 2.0, 1.0, 4.0, 2.0, 2.0, 4.0, 4.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0,
            1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 4.0, 2.0, 1.0, 2.0, 2.0, 4.0, 6.0, 8.0, 2.0,
        ];

        (y, vec![cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb])
    }

    #[test]
    fn test_white_r_validation() {
        let (y, x_vars) = mtcars_data();
        let result = white_test(&y, &x_vars, WhiteMethod::R).unwrap();

        // A missing result must fail the test instead of skipping the checks.
        let r = result
            .r_result
            .expect("R method must produce an R result");

        println!("\n=== White Test R Method Validation ===");
        println!("Reference: LM-statistic = 19.3975, p-value = 0.49614");
        println!(
            "Rust: LM-statistic = {}, p-value = {}",
            r.statistic, r.p_value
        );

        assert!(r.p_value > 0.05);
        assert!(r.passed);
    }

    #[test]
    fn test_white_python_validation() {
        let (y, x_vars) = mtcars_data();
        let result = white_test(&y, &x_vars, WhiteMethod::Python).unwrap();

        // A missing result must fail the test instead of skipping the checks.
        let p = result
            .python_result
            .expect("Python method must produce a Python result");

        println!("\n=== White Test Python Method Validation ===");
        println!("Reference: LM-statistic = 32.0, p-value = 0.41674");
        println!(
            "Rust: LM-statistic = {}, p-value = {}",
            p.statistic, p.p_value
        );

        let stat_diff = (p.statistic - 32.0).abs();
        let pval_diff = (p.p_value - 0.41674).abs();
        println!("Differences: stat={:.2}, pval={:.2}", stat_diff, pval_diff);

        assert!(stat_diff < 10.0);
        assert!(pval_diff < 0.3);
        assert!(p.passed);
    }

    #[test]
    fn test_r_white_method_direct() {
        let (y, x_vars) = test_data();
        let output = r_white_method(&y, &x_vars).expect("R wrapper should succeed");
        assert_eq!(output.method, "R (skedastic::white)");
        assert!(output.passed);
    }

    #[test]
    fn test_python_white_method_direct() {
        let (y, x_vars) = test_data();
        let output = python_white_method(&y, &x_vars).expect("Python wrapper should succeed");
        assert_eq!(output.method, "Python (statsmodels)");
        assert!(output.passed);
    }
}