1use crate::error::{Error, Result};
39use crate::linalg::{Matrix, vec_mean, fit_ols_linpack, fit_and_predict_linpack};
40use super::types::{WhiteTestOutput, WhiteSingleResult, WhiteMethod};
41use super::helpers::chi_squared_p_value;
42
43pub fn white_test(
45 y: &[f64],
46 x_vars: &[Vec<f64>],
47 method: WhiteMethod,
48) -> Result<WhiteTestOutput> {
49 let n = y.len();
50 let k = x_vars.len();
51 let p = k + 1;
52
53 if n <= p {
54 return Err(Error::InsufficientData { required: p + 1, available: n });
55 }
56
57 super::helpers::validate_regression_data(y, x_vars)?;
59
60 let alpha = 0.05;
61
62 let mut x_data = vec![1.0; n * p];
64 for row in 0..n {
65 for (col, x_var) in x_vars.iter().enumerate() {
66 x_data[row * p + col + 1] = x_var[row];
67 }
68 }
69 let x_full = Matrix::new(n, p, x_data);
70 let beta = fit_ols_linpack(y, &x_full).ok_or(Error::SingularMatrix)?;
71 let predictions = x_full.mul_vec(&beta);
72 let residuals: Vec<f64> = y.iter().zip(predictions.iter())
73 .map(|(&yi, &yi_hat)| yi - yi_hat)
74 .collect();
75 let e_squared: Vec<f64> = residuals.iter().map(|&e| e * e).collect();
76
77 let (z_data, z_cols) = build_auxiliary_matrix(n, x_vars, method);
79
80 let z_matrix = Matrix::new(n, z_cols, z_data);
83
84 #[cfg(test)]
85 {
86 eprintln!("Z matrix: {} rows x {} cols", n, z_cols);
87 let qr_result = z_matrix.qr_linpack(None);
88 eprintln!("Z rank: {}", qr_result.rank);
89 eprintln!("Pivot order: {:?}", qr_result.pivot);
90
91 for j in qr_result.rank..z_cols {
93 let dropped_col = qr_result.pivot[j] - 1;
94 eprintln!("Dropped column {} (pivot position {})", dropped_col, j);
95 }
96
97 let beta = fit_ols_linpack(&e_squared, &z_matrix);
99 if let Some(ref b) = beta {
100 eprintln!("First 10 coefficients: {:?}", &b[..10.min(b.len())]);
101 eprintln!("Last 5 coefficients: {:?}", &b[b.len().saturating_sub(5)..]);
102 }
103 }
104
105 let pred_aux = fit_and_predict_linpack(&e_squared, &z_matrix).ok_or(Error::SingularMatrix)?;
106
107 #[cfg(test)]
108 {
109 eprintln!("First few pred_aux: {:?}", &pred_aux[..5.min(pred_aux.len())]);
110 let has_nan = pred_aux.iter().any(|&x| x.is_nan());
111 eprintln!("pred_aux has NaN: {}", has_nan);
112 }
113
114 let (_r_squared_aux, lm_stat) = compute_r2_and_lm(&e_squared, &pred_aux, n);
116
117 let r_result = if method == WhiteMethod::R || method == WhiteMethod::Both {
119 let df_r = (2 * k) as f64;
120 let p_value_r = chi_squared_p_value(lm_stat, df_r);
121 let passed_r = p_value_r > alpha;
122 Some(WhiteSingleResult {
123 method: "R (skedastic::white)".to_string(),
124 statistic: lm_stat,
125 p_value: p_value_r,
126 passed: passed_r,
127 })
128 } else {
129 None
130 };
131
132 let python_result = if method == WhiteMethod::Python || method == WhiteMethod::Both {
133 let theoretical_df = (k * (k + 3) / 2) as f64;
134 let df_p = theoretical_df.min((n - 1) as f64);
135 let p_value_p = chi_squared_p_value(lm_stat, df_p);
136 let passed_p = p_value_p > alpha;
137 Some(WhiteSingleResult {
138 method: "Python (statsmodels)".to_string(),
139 statistic: lm_stat,
140 p_value: p_value_p,
141 passed: passed_p,
142 })
143 } else {
144 None
145 };
146
147 let (interp_text, guid_text) = match (&r_result, &python_result) {
149 (Some(r), None) => interpret_result(r.p_value, alpha),
150 (None, Some(p)) => interpret_result(p.p_value, alpha),
151 (Some(r), Some(p)) => {
152 if r.p_value >= p.p_value {
153 interpret_result(r.p_value, alpha)
154 } else {
155 interpret_result(p.p_value, alpha)
156 }
157 }
158 (None, None) => unreachable!(),
159 };
160
161 Ok(WhiteTestOutput {
162 test_name: "White Test for Heteroscedasticity".to_string(),
163 r_result,
164 python_result,
165 interpretation: interp_text,
166 guidance: guid_text.to_string(),
167 })
168}
169
170fn compute_r2_and_lm(e_squared: &[f64], pred_aux: &[f64], n: usize) -> (f64, f64) {
172 let residuals_aux: Vec<f64> = e_squared.iter().zip(pred_aux.iter())
173 .map(|(&yi, &yi_hat)| yi - yi_hat)
174 .collect();
175
176 let rss_aux: f64 = residuals_aux.iter().map(|&r| r * r).sum();
177
178 let mean_e_squared = vec_mean(e_squared);
179 let tss_centered: f64 = e_squared.iter()
180 .map(|&e| {
181 let diff = e - mean_e_squared;
182 diff * diff
183 })
184 .sum();
185
186 let r_squared_aux = if tss_centered > 1e-10 {
187 (1.0 - (rss_aux / tss_centered)).clamp(0.0, 1.0)
188 } else {
189 0.0
190 };
191
192 let lm_stat = (n as f64) * r_squared_aux;
193 (r_squared_aux, lm_stat)
194}
195
196fn build_auxiliary_matrix(
198 n: usize,
199 x_vars: &[Vec<f64>],
200 method: WhiteMethod,
201) -> (Vec<f64>, usize) {
202 let k = x_vars.len();
203
204 match method {
205 WhiteMethod::R => {
206 let z_cols = 1 + 2 * k;
207 let mut z_data = vec![0.0; n * z_cols];
208
209 for row in 0..n {
210 let mut col_idx = 0;
211 z_data[row * z_cols + col_idx] = 1.0;
212 col_idx += 1;
213
214 for x_var in x_vars.iter() {
215 z_data[row * z_cols + col_idx] = x_var[row];
216 col_idx += 1;
217 }
218
219 for x_var in x_vars.iter() {
220 z_data[row * z_cols + col_idx] = x_var[row] * x_var[row];
221 col_idx += 1;
222 }
223 }
224
225 (z_data, z_cols)
226 }
227 WhiteMethod::Python => {
228 let num_cross = k * (k - 1) / 2;
229 let z_cols = 1 + 2 * k + num_cross;
230 let mut z_data = vec![0.0; n * z_cols];
231
232 for row in 0..n {
233 let mut col_idx = 0;
234
235 z_data[row * z_cols + col_idx] = 1.0;
236 col_idx += 1;
237
238 for x_var in x_vars.iter() {
239 z_data[row * z_cols + col_idx] = x_var[row];
240 col_idx += 1;
241 }
242
243 for x_var in x_vars.iter() {
244 z_data[row * z_cols + col_idx] = x_var[row] * x_var[row];
245 col_idx += 1;
246 }
247
248 for i in 0..k {
249 for j in (i + 1)..k {
250 z_data[row * z_cols + col_idx] = x_vars[i][row] * x_vars[j][row];
251 col_idx += 1;
252 }
253 }
254 }
255
256 (z_data, z_cols)
257 }
258 WhiteMethod::Both => {
259 build_auxiliary_matrix(n, x_vars, WhiteMethod::Python)
260 }
261 }
262}
263
/// Turns a p-value into a human-readable verdict and a follow-up recommendation.
///
/// The null hypothesis is kept only when `p_value` is strictly greater than
/// `alpha`; a p-value equal to `alpha` counts as a rejection.
///
/// Returns `(interpretation, guidance)`.
fn interpret_result(p_value: f64, alpha: f64) -> (String, &'static str) {
    // Guard clause: no evidence against constant variance.
    if p_value > alpha {
        let verdict = format!(
            "p-value = {:.4} is greater than {:.2}. Cannot reject H0. No significant evidence of heteroscedasticity.",
            p_value, alpha
        );
        return (
            verdict,
            "The assumption of homoscedasticity (constant variance) appears to be met.",
        );
    }

    let verdict = format!(
        "p-value = {:.4} is less than or equal to {:.2}. Reject H0. Significant evidence of heteroscedasticity detected.",
        p_value, alpha
    );
    (
        verdict,
        "Consider transforming the dependent variable (e.g., log transformation), using weighted least squares, or robust standard errors.",
    )
}
284
285pub fn r_white_method(
287 y: &[f64],
288 x_vars: &[Vec<f64>],
289) -> Result<WhiteSingleResult> {
290 let result = white_test(y, x_vars, WhiteMethod::R)?;
291 result.r_result.ok_or(Error::SingularMatrix)
292}
293
294pub fn python_white_method(
296 y: &[f64],
297 x_vars: &[Vec<f64>],
298) -> Result<WhiteSingleResult> {
299 let result = white_test(y, x_vars, WhiteMethod::Python)?;
300 result.python_result.ok_or(Error::SingularMatrix)
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-regressor fixture: the response of `mtcars_data` regressed on its
    /// `wt` and `hp` columns (in that order).
    fn test_data() -> (Vec<f64>, Vec<Vec<f64>>) {
        let (y, columns) = mtcars_data();
        // Column order in `mtcars_data`: cyl, disp, hp, drat, wt, ...
        let wt = columns[4].clone();
        let hp = columns[2].clone();
        (y, vec![wt, hp])
    }

    /// Full fixture: 32 observations of the response and ten regressors, in
    /// the order cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb.
    fn mtcars_data() -> (Vec<f64>, Vec<Vec<f64>>) {
        let y = vec![
            21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22.8, 19.2,
            17.8, 16.4, 17.3, 15.2, 10.4, 10.4, 14.7, 32.4, 30.4, 33.9,
            21.5, 15.5, 15.2, 13.3, 19.2, 27.3, 26.0, 30.4, 15.8, 19.7,
            15.0, 21.4
        ];

        let cyl = vec![
            6.0, 6.0, 4.0, 6.0, 8.0, 6.0, 8.0, 4.0, 4.0, 6.0,
            6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0,
            4.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 8.0, 8.0,
            8.0, 4.0
        ];

        let disp = vec![
            160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 146.7, 140.8, 167.6,
            167.6, 275.8, 275.8, 275.8, 472.0, 460.0, 440.0, 78.7, 75.7, 71.1,
            120.1, 318.0, 304.0, 350.0, 400.0, 79.0, 120.3, 95.1, 351.0, 145.0,
            301.0, 121.0
        ];

        let hp = vec![
            110.0, 110.0, 93.0, 110.0, 175.0, 105.0, 245.0, 62.0, 95.0, 123.0,
            123.0, 180.0, 180.0, 180.0, 205.0, 215.0, 230.0, 66.0, 52.0, 65.0,
            97.0, 150.0, 150.0, 245.0, 175.0, 66.0, 91.0, 113.0, 264.0, 175.0,
            335.0, 109.0
        ];

        let drat = vec![
            3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.92, 3.92,
            3.92, 3.07, 3.07, 3.07, 2.93, 3.00, 3.23, 4.08, 4.93, 4.22,
            3.70, 2.76, 3.15, 3.73, 3.08, 4.08, 4.43, 3.77, 4.22, 3.62,
            3.54, 4.11
        ];

        let wt = vec![
            2.62, 2.875, 2.32, 3.215, 3.44, 3.46, 3.57, 3.19, 3.15, 3.44,
            3.44, 4.07, 3.73, 3.78, 5.25, 5.424, 5.345, 2.2, 1.615, 1.835,
            2.465, 3.52, 3.435, 3.84, 3.845, 1.935, 2.14, 1.513, 3.17, 2.77,
            3.57, 2.78
        ];

        let qsec = vec![
            16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 20.00, 22.90, 18.30,
            18.90, 17.40, 17.60, 18.00, 17.98, 17.82, 17.42, 19.47, 18.52, 19.90,
            20.01, 16.87, 17.30, 15.41, 17.05, 18.90, 16.70, 16.90, 14.50, 15.50,
            14.60, 18.60
        ];

        let vs = vec![
            0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0,
            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0,
            1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0,
            0.0, 1.0
        ];

        let am = vec![
            1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0,
            1.0, 1.0
        ];

        let gear = vec![
            4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0,
            4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0,
            3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0,
            5.0, 4.0
        ];

        let carb = vec![
            4.0, 4.0, 1.0, 1.0, 2.0, 1.0, 4.0, 2.0, 2.0, 4.0,
            4.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 1.0, 2.0, 1.0,
            1.0, 2.0, 2.0, 4.0, 2.0, 1.0, 2.0, 2.0, 4.0, 6.0,
            8.0, 2.0
        ];

        (y, vec![cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb])
    }

    #[test]
    fn test_white_test_r_method() {
        let (y, x_vars) = test_data();
        let output = white_test(&y, &x_vars, WhiteMethod::R)
            .expect("R variant should succeed on the two-regressor fixture");
        assert!(output.r_result.is_some());
        assert!(output.python_result.is_none());
    }

    #[test]
    fn test_white_test_python_method() {
        let (y, x_vars) = test_data();
        let output = white_test(&y, &x_vars, WhiteMethod::Python)
            .expect("Python variant should succeed on the two-regressor fixture");
        assert!(output.r_result.is_none());
        assert!(output.python_result.is_some());
    }

    #[test]
    fn test_white_test_both_methods() {
        let (y, x_vars) = test_data();
        let output = white_test(&y, &x_vars, WhiteMethod::Both)
            .expect("both variants should succeed on the two-regressor fixture");
        assert!(output.r_result.is_some());
        assert!(output.python_result.is_some());
    }

    #[test]
    fn test_white_test_insufficient_data() {
        // Two observations cannot support an intercept plus two regressors.
        let y = vec![1.0, 2.0];
        let x1 = vec![1.0, 2.0];
        let x2 = vec![2.0, 3.0];
        let result = white_test(&y, &[x1, x2], WhiteMethod::R);
        assert!(result.is_err());
    }

    #[test]
    fn test_white_r_validation() {
        let (y, x_vars) = mtcars_data();
        let result = white_test(&y, &x_vars, WhiteMethod::R).unwrap();

        // Fail loudly if the result is missing instead of silently skipping
        // every assertion below.
        let r = result
            .r_result
            .expect("R variant must be present when WhiteMethod::R is requested");

        println!("\n=== White Test R Method Validation ===");
        println!("Reference: LM-statistic = 19.3975, p-value = 0.49614");
        println!("Rust: LM-statistic = {}, p-value = {}", r.statistic, r.p_value);

        assert!(r.p_value > 0.05);
        assert!(r.passed);
    }

    #[test]
    fn test_white_python_validation() {
        let (y, x_vars) = mtcars_data();
        let result = white_test(&y, &x_vars, WhiteMethod::Python).unwrap();

        // Fail loudly if the result is missing instead of silently skipping
        // every assertion below.
        let p = result
            .python_result
            .expect("Python variant must be present when WhiteMethod::Python is requested");

        println!("\n=== White Test Python Method Validation ===");
        println!("Reference: LM-statistic = 32.0, p-value = 0.41674");
        println!("Rust: LM-statistic = {}, p-value = {}", p.statistic, p.p_value);

        let stat_diff = (p.statistic - 32.0).abs();
        let pval_diff = (p.p_value - 0.41674).abs();
        println!("Differences: stat={:.2}, pval={:.2}", stat_diff, pval_diff);

        // NOTE(review): these tolerances are very loose for a validation
        // against reference output; tighten once the expected agreement with
        // the reference implementation has been confirmed.
        assert!(stat_diff < 10.0);
        assert!(pval_diff < 0.3);
        assert!(p.passed);
    }

    #[test]
    fn test_r_white_method_direct() {
        let (y, x_vars) = test_data();
        let output = r_white_method(&y, &x_vars)
            .expect("r_white_method should succeed on the two-regressor fixture");
        assert_eq!(output.method, "R (skedastic::white)");
        assert!(output.passed);
    }

    #[test]
    fn test_python_white_method_direct() {
        let (y, x_vars) = test_data();
        let output = python_white_method(&y, &x_vars)
            .expect("python_white_method should succeed on the two-regressor fixture");
        assert_eq!(output.method, "Python (statsmodels)");
        assert!(output.passed);
    }
}