1use super::helpers::chi_squared_p_value;
39use super::types::{WhiteMethod, WhiteSingleResult, WhiteTestOutput};
40use crate::error::{Error, Result};
41use crate::linalg::{fit_and_predict_linpack, fit_ols_linpack, vec_mean, Matrix};
42
43pub fn white_test(y: &[f64], x_vars: &[Vec<f64>], method: WhiteMethod) -> Result<WhiteTestOutput> {
45 let n = y.len();
46 let k = x_vars.len();
47 let p = k + 1;
48
49 if n <= p {
50 return Err(Error::InsufficientData {
51 required: p + 1,
52 available: n,
53 });
54 }
55
56 super::helpers::validate_regression_data(y, x_vars)?;
58
59 let alpha = 0.05;
60
61 let mut x_data = vec![1.0; n * p];
63 for row in 0..n {
64 for (col, x_var) in x_vars.iter().enumerate() {
65 x_data[row * p + col + 1] = x_var[row];
66 }
67 }
68 let x_full = Matrix::new(n, p, x_data);
69 let beta = fit_ols_linpack(y, &x_full).ok_or(Error::SingularMatrix)?;
70 let predictions = x_full.mul_vec(&beta);
71 let residuals: Vec<f64> = y
72 .iter()
73 .zip(predictions.iter())
74 .map(|(&yi, &yi_hat)| yi - yi_hat)
75 .collect();
76 let e_squared: Vec<f64> = residuals.iter().map(|&e| e * e).collect();
77
78 let (z_data, z_cols) = build_auxiliary_matrix(n, x_vars, method);
80
81 let z_matrix = Matrix::new(n, z_cols, z_data);
84
85 #[cfg(test)]
86 {
87 eprintln!("Z matrix: {} rows x {} cols", n, z_cols);
88 let qr_result = z_matrix.qr_linpack(None);
89 eprintln!("Z rank: {}", qr_result.rank);
90 eprintln!("Pivot order: {:?}", qr_result.pivot);
91
92 for j in qr_result.rank..z_cols {
94 let dropped_col = qr_result.pivot[j] - 1;
95 eprintln!("Dropped column {} (pivot position {})", dropped_col, j);
96 }
97
98 let beta = fit_ols_linpack(&e_squared, &z_matrix);
100 if let Some(ref b) = beta {
101 eprintln!("First 10 coefficients: {:?}", &b[..10.min(b.len())]);
102 eprintln!("Last 5 coefficients: {:?}", &b[b.len().saturating_sub(5)..]);
103 }
104 }
105
106 let pred_aux = fit_and_predict_linpack(&e_squared, &z_matrix).ok_or(Error::SingularMatrix)?;
107
108 #[cfg(test)]
109 {
110 eprintln!(
111 "First few pred_aux: {:?}",
112 &pred_aux[..5.min(pred_aux.len())]
113 );
114 let has_nan = pred_aux.iter().any(|&x| x.is_nan());
115 eprintln!("pred_aux has NaN: {}", has_nan);
116 }
117
118 let (_r_squared_aux, lm_stat) = compute_r2_and_lm(&e_squared, &pred_aux, n);
120
121 let r_result = if method == WhiteMethod::R || method == WhiteMethod::Both {
123 let df_r = (2 * k) as f64;
124 let p_value_r = chi_squared_p_value(lm_stat, df_r);
125 let passed_r = p_value_r > alpha;
126 Some(WhiteSingleResult {
127 method: "R (skedastic::white)".to_string(),
128 statistic: lm_stat,
129 p_value: p_value_r,
130 passed: passed_r,
131 })
132 } else {
133 None
134 };
135
136 let python_result = if method == WhiteMethod::Python || method == WhiteMethod::Both {
137 let theoretical_df = (k * (k + 3) / 2) as f64;
138 let df_p = theoretical_df.min((n - 1) as f64);
139 let p_value_p = chi_squared_p_value(lm_stat, df_p);
140 let passed_p = p_value_p > alpha;
141 Some(WhiteSingleResult {
142 method: "Python (statsmodels)".to_string(),
143 statistic: lm_stat,
144 p_value: p_value_p,
145 passed: passed_p,
146 })
147 } else {
148 None
149 };
150
151 let (interp_text, guid_text) = match (&r_result, &python_result) {
153 (Some(r), None) => interpret_result(r.p_value, alpha),
154 (None, Some(p)) => interpret_result(p.p_value, alpha),
155 (Some(r), Some(p)) => {
156 if r.p_value >= p.p_value {
157 interpret_result(r.p_value, alpha)
158 } else {
159 interpret_result(p.p_value, alpha)
160 }
161 },
162 (None, None) => unreachable!(),
163 };
164
165 Ok(WhiteTestOutput {
166 test_name: "White Test for Heteroscedasticity".to_string(),
167 r_result,
168 python_result,
169 interpretation: interp_text,
170 guidance: guid_text.to_string(),
171 })
172}
173
174fn compute_r2_and_lm(e_squared: &[f64], pred_aux: &[f64], n: usize) -> (f64, f64) {
176 let residuals_aux: Vec<f64> = e_squared
177 .iter()
178 .zip(pred_aux.iter())
179 .map(|(&yi, &yi_hat)| yi - yi_hat)
180 .collect();
181
182 let rss_aux: f64 = residuals_aux.iter().map(|&r| r * r).sum();
183
184 let mean_e_squared = vec_mean(e_squared);
185 let tss_centered: f64 = e_squared
186 .iter()
187 .map(|&e| {
188 let diff = e - mean_e_squared;
189 diff * diff
190 })
191 .sum();
192
193 let r_squared_aux = if tss_centered > 1e-10 {
194 (1.0 - (rss_aux / tss_centered)).clamp(0.0, 1.0)
195 } else {
196 0.0
197 };
198
199 let lm_stat = (n as f64) * r_squared_aux;
200 (r_squared_aux, lm_stat)
201}
202
203fn build_auxiliary_matrix(n: usize, x_vars: &[Vec<f64>], method: WhiteMethod) -> (Vec<f64>, usize) {
205 let k = x_vars.len();
206
207 match method {
208 WhiteMethod::R => {
209 let z_cols = 1 + 2 * k;
210 let mut z_data = vec![0.0; n * z_cols];
211
212 for row in 0..n {
213 let mut col_idx = 0;
214 z_data[row * z_cols + col_idx] = 1.0;
215 col_idx += 1;
216
217 for x_var in x_vars.iter() {
218 z_data[row * z_cols + col_idx] = x_var[row];
219 col_idx += 1;
220 }
221
222 for x_var in x_vars.iter() {
223 z_data[row * z_cols + col_idx] = x_var[row] * x_var[row];
224 col_idx += 1;
225 }
226 }
227
228 (z_data, z_cols)
229 },
230 WhiteMethod::Python => {
231 let num_cross = k * (k - 1) / 2;
232 let z_cols = 1 + 2 * k + num_cross;
233 let mut z_data = vec![0.0; n * z_cols];
234
235 for row in 0..n {
236 let mut col_idx = 0;
237
238 z_data[row * z_cols + col_idx] = 1.0;
239 col_idx += 1;
240
241 for x_var in x_vars.iter() {
242 z_data[row * z_cols + col_idx] = x_var[row];
243 col_idx += 1;
244 }
245
246 for x_var in x_vars.iter() {
247 z_data[row * z_cols + col_idx] = x_var[row] * x_var[row];
248 col_idx += 1;
249 }
250
251 for i in 0..k {
252 for j in (i + 1)..k {
253 z_data[row * z_cols + col_idx] = x_vars[i][row] * x_vars[j][row];
254 col_idx += 1;
255 }
256 }
257 }
258
259 (z_data, z_cols)
260 },
261 WhiteMethod::Both => build_auxiliary_matrix(n, x_vars, WhiteMethod::Python),
262 }
263}
264
/// Turns a p-value into a human-readable verdict plus a fixed guidance line.
///
/// H0 (homoscedasticity) is retained only when `p_value` is strictly greater
/// than `alpha`. Equality, or a NaN p-value, leads to rejecting H0.
fn interpret_result(p_value: f64, alpha: f64) -> (String, &'static str) {
    let no_evidence = p_value > alpha;

    let summary = if no_evidence {
        format!(
            "p-value = {p_value:.4} is greater than {alpha:.2}. Cannot reject H0. No significant evidence of heteroscedasticity."
        )
    } else {
        format!(
            "p-value = {p_value:.4} is less than or equal to {alpha:.2}. Reject H0. Significant evidence of heteroscedasticity detected."
        )
    };

    let guidance = if no_evidence {
        "The assumption of homoscedasticity (constant variance) appears to be met."
    } else {
        "Consider transforming the dependent variable (e.g., log transformation), using weighted least squares, or robust standard errors."
    };

    (summary, guidance)
}
285
286pub fn r_white_method(y: &[f64], x_vars: &[Vec<f64>]) -> Result<WhiteSingleResult> {
288 let result = white_test(y, x_vars, WhiteMethod::R)?;
289 result.r_result.ok_or(Error::SingularMatrix)
290}
291
292pub fn python_white_method(y: &[f64], x_vars: &[Vec<f64>]) -> Result<WhiteSingleResult> {
294 let result = white_test(y, x_vars, WhiteMethod::Python)?;
295 result.python_result.ok_or(Error::SingularMatrix)
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// `mpg` regressed on `wt` and `hp`, taken from the `mtcars` data set.
    fn test_data() -> (Vec<f64>, Vec<Vec<f64>>) {
        let (y, columns) = mtcars_data();
        // mtcars_data column order: cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb.
        let wt = columns[4].clone();
        let hp = columns[2].clone();
        (y, vec![wt, hp])
    }

    #[test]
    fn test_white_test_r_method() {
        let (y, x_vars) = test_data();
        let output = white_test(&y, &x_vars, WhiteMethod::R).expect("R method should succeed");
        assert!(output.r_result.is_some());
        assert!(output.python_result.is_none());
    }

    #[test]
    fn test_white_test_python_method() {
        let (y, x_vars) = test_data();
        let output =
            white_test(&y, &x_vars, WhiteMethod::Python).expect("Python method should succeed");
        assert!(output.r_result.is_none());
        assert!(output.python_result.is_some());
    }

    #[test]
    fn test_white_test_both_methods() {
        let (y, x_vars) = test_data();
        let output = white_test(&y, &x_vars, WhiteMethod::Both).expect("Both should succeed");
        assert!(output.r_result.is_some());
        assert!(output.python_result.is_some());
    }

    #[test]
    fn test_white_test_insufficient_data() {
        let y = vec![1.0, 2.0];
        let x1 = vec![1.0, 2.0];
        let x2 = vec![2.0, 3.0];
        let result = white_test(&y, &[x1, x2], WhiteMethod::R);
        // n = 2 observations, p = 3 parameters -> at least p + 1 = 4 required.
        assert!(matches!(
            result,
            Err(Error::InsufficientData {
                required: 4,
                available: 2
            })
        ));
    }

    /// The standard R `mtcars` data set: mpg as the response and the ten
    /// remaining columns as predictors.
    fn mtcars_data() -> (Vec<f64>, Vec<Vec<f64>>) {
        let y = vec![
            21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22.8, 19.2, 17.8, 16.4, 17.3, 15.2,
            10.4, 10.4, 14.7, 32.4, 30.4, 33.9, 21.5, 15.5, 15.2, 13.3, 19.2, 27.3, 26.0, 30.4,
            15.8, 19.7, 15.0, 21.4,
        ];

        // Row 30 (Ferrari Dino) has 6 cylinders in mtcars. It was previously
        // recorded as 8, which no longer matched the R reference results.
        let cyl = vec![
            6.0, 6.0, 4.0, 6.0, 8.0, 6.0, 8.0, 4.0, 4.0, 6.0, 6.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0,
            4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 8.0, 6.0, 8.0, 4.0,
        ];

        let disp = vec![
            160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 146.7, 140.8, 167.6, 167.6, 275.8,
            275.8, 275.8, 472.0, 460.0, 440.0, 78.7, 75.7, 71.1, 120.1, 318.0, 304.0, 350.0, 400.0,
            79.0, 120.3, 95.1, 351.0, 145.0, 301.0, 121.0,
        ];

        let hp = vec![
            110.0, 110.0, 93.0, 110.0, 175.0, 105.0, 245.0, 62.0, 95.0, 123.0, 123.0, 180.0, 180.0,
            180.0, 205.0, 215.0, 230.0, 66.0, 52.0, 65.0, 97.0, 150.0, 150.0, 245.0, 175.0, 66.0,
            91.0, 113.0, 264.0, 175.0, 335.0, 109.0,
        ];

        let drat = vec![
            3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.92, 3.92, 3.92, 3.07, 3.07, 3.07,
            2.93, 3.00, 3.23, 4.08, 4.93, 4.22, 3.70, 2.76, 3.15, 3.73, 3.08, 4.08, 4.43, 3.77,
            4.22, 3.62, 3.54, 4.11,
        ];

        let wt = vec![
            2.62, 2.875, 2.32, 3.215, 3.44, 3.46, 3.57, 3.19, 3.15, 3.44, 3.44, 4.07, 3.73, 3.78,
            5.25, 5.424, 5.345, 2.2, 1.615, 1.835, 2.465, 3.52, 3.435, 3.84, 3.845, 1.935, 2.14,
            1.513, 3.17, 2.77, 3.57, 2.78,
        ];

        let qsec = vec![
            16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 20.00, 22.90, 18.30, 18.90, 17.40,
            17.60, 18.00, 17.98, 17.82, 17.42, 19.47, 18.52, 19.90, 20.01, 16.87, 17.30, 15.41,
            17.05, 18.90, 16.70, 16.90, 14.50, 15.50, 14.60, 18.60,
        ];

        let vs = vec![
            0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
        ];

        let am = vec![
            1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        ];

        let gear = vec![
            4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0,
            4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0,
        ];

        let carb = vec![
            4.0, 4.0, 1.0, 1.0, 2.0, 1.0, 4.0, 2.0, 2.0, 4.0, 4.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0,
            1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 4.0, 2.0, 1.0, 2.0, 2.0, 4.0, 6.0, 8.0, 2.0,
        ];

        (y, vec![cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb])
    }

    #[test]
    fn test_white_r_validation() {
        let (y, x_vars) = mtcars_data();
        let result = white_test(&y, &x_vars, WhiteMethod::R).unwrap();
        // A missing result must fail the test, not skip the assertions.
        let r = result
            .r_result
            .expect("WhiteMethod::R must populate r_result");

        println!("\n=== White Test R Method Validation ===");
        println!("Reference: LM-statistic = 19.3975, p-value = 0.49614");
        println!(
            "Rust: LM-statistic = {}, p-value = {}",
            r.statistic, r.p_value
        );

        assert!(r.p_value > 0.05);
        assert!(r.passed);
    }

    #[test]
    fn test_white_python_validation() {
        let (y, x_vars) = mtcars_data();
        let result = white_test(&y, &x_vars, WhiteMethod::Python).unwrap();
        // A missing result must fail the test, not skip the assertions.
        let p = result
            .python_result
            .expect("WhiteMethod::Python must populate python_result");

        println!("\n=== White Test Python Method Validation ===");
        println!("Reference: LM-statistic = 32.0, p-value = 0.41674");
        println!(
            "Rust: LM-statistic = {}, p-value = {}",
            p.statistic, p.p_value
        );

        let stat_diff = (p.statistic - 32.0).abs();
        let pval_diff = (p.p_value - 0.41674).abs();
        println!("Differences: stat={:.2}, pval={:.2}", stat_diff, pval_diff);

        assert!(stat_diff < 10.0);
        assert!(pval_diff < 0.3);
        assert!(p.passed);
    }

    #[test]
    fn test_r_white_method_direct() {
        let (y, x_vars) = test_data();
        let output = r_white_method(&y, &x_vars).expect("R wrapper should succeed");
        assert_eq!(output.method, "R (skedastic::white)");
        assert!(output.passed);
    }

    #[test]
    fn test_python_white_method_direct() {
        let (y, x_vars) = test_data();
        let output = python_white_method(&y, &x_vars).expect("Python wrapper should succeed");
        assert_eq!(output.method, "Python (statsmodels)");
        assert!(output.passed);
    }
}