use super::types::{VifDetail, VifDiagnosticResult};
use crate::core::calculate_vif;
use crate::error::{Error, Result};
/// Runs the Variance Inflation Factor (VIF) diagnostic on a set of predictors.
///
/// The per-variable computation is delegated to [`calculate_vif`]; this function
/// validates the input, collects the per-variable results and summarises them
/// into an overall interpretation plus guidance text.
///
/// # Arguments
///
/// * `y` - Response values. Only its length is used here: it defines the
///   number of observations every predictor column must have.
/// * `x_vars` - Predictor columns, one `Vec<f64>` per variable.
///
/// # Errors
///
/// * [`Error::InsufficientData`] if fewer than two predictor columns are given.
/// * [`Error::InvalidInput`] if any column's length differs from `y.len()`, or
///   if any predictor value is not finite. All length checks run before any
///   finiteness check.
///
/// # Returns
///
/// A [`VifDiagnosticResult`] holding the per-variable details, the largest VIF
/// (never below `0.0`), and interpretation/guidance chosen from three bands:
/// severe (VIF > 10 or infinite), moderate (5 < VIF <= 10), or acceptable.
pub fn vif_test(y: &[f64], x_vars: &[Vec<f64>]) -> Result<VifDiagnosticResult> {
    let n_obs = y.len();
    let n_predictors = x_vars.len();

    if n_predictors < 2 {
        return Err(Error::InsufficientData {
            required: 2,
            available: n_predictors,
        });
    }

    // Pass 1: every column must have one value per observation. This pass
    // finishes before the finiteness pass so a length error always wins.
    let mismatched = x_vars
        .iter()
        .enumerate()
        .find(|(_, column)| column.len() != n_obs);
    if let Some((col_idx, column)) = mismatched {
        return Err(Error::InvalidInput(format!(
            "x_vars[{}] has {} elements, expected {}",
            col_idx,
            column.len(),
            n_obs
        )));
    }

    // Pass 2: reject the first NaN / infinite entry, scanning column by column.
    for (col_idx, column) in x_vars.iter().enumerate() {
        if let Some(row_idx) = column.iter().position(|value| !value.is_finite()) {
            return Err(Error::InvalidInput(format!(
                "x_vars[{}] contains non-finite value at index {}",
                col_idx, row_idx
            )));
        }
    }

    // Labels handed to `calculate_vif`: "Intercept" followed by X1..Xk.
    let mut names: Vec<String> = Vec::with_capacity(n_predictors + 1);
    names.push("Intercept".to_string());
    names.extend((1..=n_predictors).map(|i| format!("X{}", i)));

    let raw_results = calculate_vif(x_vars, &names, n_obs);

    let mut details: Vec<VifDetail> = Vec::new();
    for entry in raw_results.iter() {
        details.push(VifDetail {
            variable: entry.variable.clone(),
            vif: entry.vif,
            rsquared: entry.rsquared,
            interpretation: entry.interpretation.clone(),
        });
    }

    // `f64::max` propagates +infinity and skips NaN, and the 0.0 seed keeps
    // the result non-negative even when there are no details.
    let max_vif = details
        .iter()
        .map(|detail| detail.vif)
        .fold(0.0_f64, f64::max);

    // The two buckets are mutually exclusive, so one pass can tally both.
    let mut high_vif_count = 0_usize;
    let mut moderate_vif_count = 0_usize;
    for detail in &details {
        if detail.vif > 10.0 || detail.vif.is_infinite() {
            high_vif_count += 1;
        } else if detail.vif > 5.0 {
            moderate_vif_count += 1;
        }
    }

    let interpretation: String;
    let guidance: &str;
    if high_vif_count > 0 {
        interpretation = format!(
            "Found {} variable(s) with VIF > 10 (severe multicollinearity). Maximum VIF = {:.2}.",
            high_vif_count, max_vif
        );
        guidance = "Consider removing or combining highly correlated predictors. High multicollinearity makes coefficient estimates unstable and difficult to interpret.";
    } else if moderate_vif_count > 0 {
        interpretation = format!(
            "Found {} variable(s) with VIF > 5 (moderate multicollinearity). Maximum VIF = {:.2}.",
            moderate_vif_count, max_vif
        );
        guidance = "Monitor these variables. Moderate multicollinearity may indicate redundant predictors. Consider dimensionality reduction if interpretation becomes problematic.";
    } else {
        interpretation = format!(
            "All VIF values are within acceptable range (VIF ≤ 5). Maximum VIF = {:.2}.",
            max_vif
        );
        guidance = "No concerning multicollinearity detected. Coefficient estimates should be stable.";
    }

    Ok(VifDiagnosticResult {
        test_name: "Variance Inflation Factor (VIF)".to_string(),
        max_vif,
        vif_results: details,
        interpretation,
        guidance: guidance.to_string(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Predictors that do not track each other closely should give a small
    /// maximum VIF and the "no concern" guidance.
    #[test]
    fn test_vif_low_correlation() {
        let response = vec![2.5, 3.7, 4.2, 5.1, 6.3, 7.0, 8.2, 9.1, 10.5, 11.2];
        let predictors = [
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            vec![2.0, 4.0, 5.0, 4.0, 3.0, 2.0, 4.0, 5.0, 6.0, 7.0],
        ];

        let outcome = vif_test(&response, &predictors).unwrap();

        assert_eq!(outcome.test_name, "Variance Inflation Factor (VIF)");
        assert!(outcome.max_vif < 5.0, "Max VIF should be low for uncorrelated predictors");
        assert_eq!(outcome.vif_results.len(), 2);
        assert!(outcome.passed_interpretation());
    }

    /// The second predictor is exactly twice the first, so the diagnostic
    /// must land in the severe band.
    #[test]
    fn test_vif_high_correlation() {
        let response = vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0];
        let predictors = [
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0],
        ];

        let outcome = vif_test(&response, &predictors).unwrap();

        assert!(outcome.max_vif > 10.0, "VIF should be high for perfectly correlated predictors");
        assert!(!outcome.passed_interpretation());
    }

    /// A single predictor is rejected.
    #[test]
    fn test_vif_insufficient_predictors() {
        let response = vec![2.0, 4.0, 6.0, 8.0];
        let only_predictor = vec![1.0, 2.0, 3.0, 4.0];

        assert!(vif_test(&response, &[only_predictor]).is_err());
    }

    /// A predictor column shorter than the response is rejected.
    #[test]
    fn test_vif_mismatched_dimensions() {
        let response = vec![2.0, 4.0, 6.0, 8.0];
        let full_column = vec![1.0, 2.0, 3.0, 4.0];
        let short_column = vec![1.0, 2.0, 3.0];

        assert!(vif_test(&response, &[full_column, short_column]).is_err());
    }

    /// Every per-variable detail carries a name, a VIF of at least 1, an
    /// R-squared in [0, 1] and a non-empty interpretation.
    #[test]
    fn test_vif_detail_structure() {
        let response = vec![2.5, 3.7, 4.2, 5.1, 6.3];
        let predictors = [
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![2.0, 3.0, 4.0, 5.0, 6.0],
        ];

        let outcome = vif_test(&response, &predictors).unwrap();

        assert_eq!(outcome.vif_results.len(), 2);
        for entry in outcome.vif_results.iter() {
            assert!(!entry.variable.is_empty());
            assert!(entry.vif >= 1.0);
            assert!(entry.rsquared >= 0.0 && entry.rsquared <= 1.0);
            assert!(!entry.interpretation.is_empty());
        }
    }
}
impl VifDiagnosticResult {
    /// Reports whether this result reads as a "pass".
    ///
    /// Returns `true` only when `max_vif` is strictly below 10 and the
    /// guidance text contains the "No concerning multicollinearity" phrase.
    /// A NaN `max_vif` fails the comparison and therefore yields `false`.
    // In this file the helper is called only from the unit tests, so
    // non-test builds would otherwise warn about it being unused.
    #[allow(dead_code)]
    fn passed_interpretation(&self) -> bool {
        let below_severe_threshold = self.max_vif < 10.0;
        let guidance_reports_no_concern = self
            .guidance
            .contains("No concerning multicollinearity");
        below_severe_threshold && guidance_reports_no_concern
    }
}