use super::helpers::{compute_rss, f_p_value, fit_ols};
use super::types::{RainbowMethod, RainbowSingleResult, RainbowTestOutput};
use crate::error::{Error, Result};
use crate::linalg::Matrix;
/// Selects the central subset exactly as R's `lmtest::raintest` does.
///
/// The lower bound is the type-7 quantile of `1..=n` at probability
/// `center - fraction / 2` (i.e. `1 + p * (n - 1)`), rounded up; the window
/// then spans `floor(fraction * n)` observations. Returns 0-based inclusive
/// `(start, end)` bounds into the index vector.
fn raintest_subset_r(n: usize, fraction: f64, center: f64) -> (usize, usize) {
    let total = n as f64;
    // Probability at the lower edge of the central window.
    let lower_prob = center - fraction / 2.0;
    // R type-7 quantile of the sequence 1..n, then ceil to a 1-based index.
    let start_1based = (lower_prob * (total - 1.0) + 1.0).ceil() as usize;
    let window_len = (fraction * total).floor() as usize;
    let end_1based = start_1based + window_len - 1;
    // Convert both 1-based inclusive bounds to 0-based inclusive bounds.
    (start_1based - 1, end_1based - 1)
}
/// Selects the central subset exactly as statsmodels' `linear_rainbow` does.
///
/// statsmodels computes `lowidx = ceil(0.5 * (1 - frac) * n)` and an
/// exclusive upper index `floor(lowidx + frac * n)`. Returns 0-based
/// inclusive `(start, end)` bounds, since the caller slices with `..=end`.
fn raintest_subset_python(n: usize, fraction: f64) -> (usize, usize) {
    let total = n as f64;
    let low = (0.5 * (1.0 - fraction) * total).ceil() as usize;
    let upper_exclusive = (low as f64 + fraction * total).floor() as usize;
    // Convert the exclusive upper bound to an inclusive one.
    (low, upper_exclusive - 1)
}
/// Computes the Rainbow F statistic and its p-value for one subset choice.
///
/// Fits OLS on the central subset given by `subset_range` (0-based inclusive
/// bounds into `indices`) and compares its residual sum of squares against
/// `rss_full` from the full-sample fit via
/// `F = ((RSS_full - RSS_sub) / (n - m)) / (RSS_sub / (m - p))`.
///
/// * `y` - full response vector (length `n`).
/// * `x_data` - row-major full design matrix data (`n * p` values).
/// * `x_full` / `beta_full` - full design matrix and its fitted coefficients.
/// * `indices` - observation ordering to subset from.
/// * `_method` - kept for interface symmetry; both variants share this math.
///
/// # Errors
/// Returns `InsufficientData` when the subset is too small to leave positive
/// residual degrees of freedom, and `InvalidInput` when the subset covers the
/// whole sample (numerator df would be zero) or the denominator degenerates.
fn rainbow_test_internal(
    y: &[f64],
    x_data: &[f64],
    x_full: &Matrix,
    beta_full: &[f64],
    indices: &[usize],
    subset_range: (usize, usize),
    _method: RainbowMethod,
) -> Result<(f64, f64)> {
    let n = y.len();
    let p = x_full.cols;
    let rss_full = compute_rss(y, x_full, beta_full)?;
    let (start, end) = subset_range;
    let subset_indices = &indices[start..=end];
    let m = subset_indices.len();
    // Need strictly more than p observations so df2 = m - p is positive;
    // m == p would make the denominator degrees of freedom zero.
    if m <= p {
        return Err(Error::InsufficientData {
            required: p + 1,
            available: m,
        });
    }
    // The subset must be proper, otherwise df1 = n - m is zero and the
    // F statistic would be 0/0 (NaN). This can happen when fraction == 1.0.
    if m >= n {
        return Err(Error::InvalidInput(
            "Rainbow test subset must be a proper subset of the sample".to_string(),
        ));
    }
    let mut y_subset = Vec::with_capacity(m);
    let mut x_subset_data = Vec::with_capacity(m * p);
    for &idx in subset_indices {
        y_subset.push(y[idx]);
        // Copy the whole design-matrix row for observation `idx`.
        x_subset_data.extend_from_slice(&x_data[idx * p..(idx + 1) * p]);
    }
    let x_subset = Matrix::new(m, p, x_subset_data);
    let beta_subset = fit_ols(&y_subset, &x_subset)?;
    let predictions_subset = x_subset.mul_vec(&beta_subset);
    let rss_subset: f64 = y_subset
        .iter()
        .zip(predictions_subset.iter())
        .map(|(&yi, &yi_hat)| (yi - yi_hat) * (yi - yi_hat))
        .sum();
    let df1 = (n - m) as f64;
    let df2 = (m - p) as f64;
    // RSS_full >= RSS_subset in exact arithmetic; clamp guards round-off.
    let numerator = (rss_full - rss_subset).max(0.0) / df1;
    let denominator = rss_subset / df2;
    if denominator <= 1e-10 {
        return Err(Error::InvalidInput(
            "Invalid denominator in Rainbow test".to_string(),
        ));
    }
    let f_stat = numerator / denominator;
    let p_value = f_p_value(f_stat, df1, df2);
    Ok((f_stat, p_value))
}
/// Rainbow test for linearity (Utts, 1982).
///
/// Fits OLS on the full sample and on a central subset of the observations,
/// then compares the residual sums of squares with an F-test. H0: the model
/// is linear (the central-subset fit is not significantly better).
///
/// * `y` - response values.
/// * `x_vars` - predictor columns; an intercept column is added automatically.
/// * `fraction` - fraction of observations in the central subset; values
///   outside `(0, 1]` fall back to the conventional default of 0.5.
/// * `method` - which reference implementation(s) to mirror: R's
///   `lmtest::raintest`, Python's statsmodels `linear_rainbow`, or both.
///
/// # Errors
/// Returns `InsufficientData` when `n <= p + 1`, or any error raised by
/// data validation / the full-sample OLS fit. Failures inside a single
/// variant are reported as a `None` result rather than an error.
pub fn rainbow_test(
    y: &[f64],
    x_vars: &[Vec<f64>],
    fraction: f64,
    method: RainbowMethod,
) -> Result<RainbowTestOutput> {
    // Significance level used for the pass/fail verdict and interpretation.
    const ALPHA: f64 = 0.05;
    let n = y.len();
    let k = x_vars.len();
    let p = k + 1; // predictors plus intercept
    if n <= p + 1 {
        return Err(Error::InsufficientData {
            required: p + 2,
            available: n,
        });
    }
    super::helpers::validate_regression_data(y, x_vars)?;
    // Clamp nonsensical fractions to the conventional default of 0.5.
    let fraction = if fraction <= 0.0 || fraction > 1.0 {
        0.5
    } else {
        fraction
    };
    let center = 0.5;
    // Build the n x p design matrix [1 | x_vars] in row-major order.
    let mut x_data = vec![0.0; n * p];
    for row in 0..n {
        x_data[row * p] = 1.0;
        for (col, x_var) in x_vars.iter().enumerate() {
            x_data[row * p + col + 1] = x_var[row];
        }
    }
    let x_full = Matrix::new(n, p, x_data.clone());
    let beta_full = fit_ols(y, &x_full)?;
    // Both reference implementations take the central subset in the original
    // data order by default, so the identity index vector is used for both.
    let indices: Vec<usize> = (0..n).collect();
    // Runs one variant and packages the outcome; a computation failure in a
    // variant yields a missing result instead of failing the whole test.
    let run_variant = |range: (usize, usize), m: RainbowMethod, label: &str| {
        rainbow_test_internal(y, &x_data, &x_full, &beta_full, &indices, range, m)
            .ok()
            .map(|(f_stat, p_value)| RainbowSingleResult {
                method: label.to_string(),
                statistic: f_stat,
                p_value,
                passed: p_value > ALPHA,
            })
    };
    let r_result = match method {
        RainbowMethod::R | RainbowMethod::Both => run_variant(
            raintest_subset_r(n, fraction, center),
            RainbowMethod::R,
            "R (lmtest::raintest)",
        ),
        _ => None,
    };
    let python_result = match method {
        RainbowMethod::Python | RainbowMethod::Both => run_variant(
            raintest_subset_python(n, fraction),
            RainbowMethod::Python,
            "Python (statsmodels)",
        ),
        _ => None,
    };
    // Prefer the R variant's verdict when both were computed.
    let primary = r_result.as_ref().or(python_result.as_ref());
    let (interpretation, guidance) = match primary {
        Some(result) if result.passed => (
            format!("p-value = {:.4} is greater than {:.2}. Cannot reject H0. No significant evidence of non-linearity.", result.p_value, ALPHA),
            "The linear model appears appropriate. Consider other diagnostic tests.".to_string(),
        ),
        Some(result) => (
            format!("p-value = {:.4} is less than or equal to {:.2}. Reject H0. Significant evidence of non-linearity detected.", result.p_value, ALPHA),
            "Consider adding polynomial terms, transforming variables, or using non-linear modeling.".to_string(),
        ),
        None => (
            "Unable to compute Rainbow test.".to_string(),
            "Check your data and try again.".to_string(),
        ),
    };
    Ok(RainbowTestOutput {
        test_name: "Rainbow Test for Linearity".to_string(),
        r_result,
        python_result,
        interpretation,
        guidance,
    })
}