use crate::error::{StatsError, StatsResult};
use crate::regression::utils::*;
use crate::regression::RegressionResults;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::Float;
use scirs2_linalg::lstsq;
use std::collections::HashSet;
/// Direction in which stepwise variable selection proceeds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StepwiseDirection {
    /// Start from an empty model and add variables one at a time.
    Forward,
    /// Start from the full model and remove variables one at a time.
    Backward,
    /// Alternate: try to add a variable first, and try a removal only when
    /// no addition improved the criterion.
    Both,
}
/// Statistic used to compare candidate models at each selection step.
///
/// `calculate_criterion` orients every variant so that a SMALLER value is
/// better (AdjR2, F and T are negated there).
#[derive(Debug, Clone, Copy)]
pub enum StepwiseCriterion {
    /// Akaike information criterion.
    AIC,
    /// Bayesian information criterion.
    BIC,
    /// Adjusted R-squared.
    AdjR2,
    /// Overall model F-statistic.
    F,
    /// Smallest absolute coefficient t-statistic.
    T,
}
/// Outcome of a stepwise regression run.
pub struct StepwiseResults<F>
where
    F: Float + std::fmt::Debug + std::fmt::Display + 'static,
{
    /// Regression fitted on the final set of selected variables.
    pub final_model: RegressionResults<F>,
    /// Column indices (into the original predictor matrix) kept in the model.
    pub selected_indices: Vec<usize>,
    /// Step history as (variable index, true = added / false = removed).
    pub sequence: Vec<(usize, bool)>,
    /// Criterion value recorded at each corresponding step in `sequence`.
    pub criteria_values: Vec<F>,
}
impl<F> StepwiseResults<F>
where
    F: Float + std::fmt::Debug + std::fmt::Display + 'static,
{
    /// Renders a human-readable report: the selected variables, the
    /// add/remove history with criterion values, and the final model summary.
    pub fn summary(&self) -> String {
        // "X0, X3, X5" — comma-separated list of the kept predictors.
        let selected = self
            .selected_indices
            .iter()
            .map(|idx| format!("X{}", idx))
            .collect::<Vec<_>>()
            .join(", ");

        let mut out = String::from("=== Stepwise Regression Results ===\n\n");
        out.push_str("Selected variables: ");
        out.push_str(&selected);
        out.push_str("\n\n");
        out.push_str("Sequence of variable entry/exit:\n");
        for (step, &(idx, is_entry)) in self.sequence.iter().enumerate() {
            let action = if is_entry { "Added" } else { "Removed" };
            out.push_str(&format!(
                "Step {}: {} X{} (criterion value: {})\n",
                step + 1,
                action,
                idx,
                self.criteria_values[step]
            ));
        }
        out.push('\n');
        out.push_str("Final Model:\n");
        out.push_str(&self.final_model.summary());
        out
    }
}
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn stepwise_regression<F>(
x: &ArrayView2<F>,
y: &ArrayView1<F>,
direction: StepwiseDirection,
criterion: StepwiseCriterion,
p_enter: Option<F>,
p_remove: Option<F>,
max_steps: Option<usize>,
include_intercept: bool,
) -> StatsResult<StepwiseResults<F>>
where
F: Float
+ std::iter::Sum<F>
+ std::ops::Div<Output = F>
+ std::fmt::Debug
+ std::fmt::Display
+ 'static
+ scirs2_core::numeric::NumAssign
+ scirs2_core::numeric::One
+ scirs2_core::ndarray::ScalarOperand
+ Send
+ Sync,
{
if x.nrows() != y.len() {
return Err(StatsError::DimensionMismatch(format!(
"Input x has {} rows but y has length {}",
x.nrows(),
y.len()
)));
}
let n = x.nrows();
let p = x.ncols();
if n < 3 {
return Err(StatsError::InvalidArgument(
"At least 3 observations required for stepwise regression".to_string(),
));
}
let p_enter =
p_enter.unwrap_or_else(|| F::from(0.05).expect("Failed to convert constant to float"));
let p_remove =
p_remove.unwrap_or_else(|| F::from(0.1).expect("Failed to convert constant to float"));
let max_steps = max_steps.unwrap_or(p * 2);
let mut selected_indices = match direction {
StepwiseDirection::Forward => HashSet::new(),
StepwiseDirection::Backward | StepwiseDirection::Both => {
let mut indices = HashSet::new();
for i in 0..p {
indices.insert(i);
}
indices
}
};
let mut sequence = Vec::new();
let mut criteria_values = Vec::new();
let mut current_x = match direction {
StepwiseDirection::Forward => {
if include_intercept {
Array2::<F>::ones((n, 1))
} else {
Array2::<F>::zeros((n, 0))
}
}
StepwiseDirection::Backward | StepwiseDirection::Both => {
if include_intercept {
let mut x_full = Array2::<F>::zeros((n, p + 1));
x_full.slice_mut(s![.., 0]).fill(F::one());
for i in 0..p {
x_full.slice_mut(s![.., i + 1]).assign(&x.slice(s![.., i]));
}
x_full
} else {
x.to_owned()
}
}
};
let mut step = 0;
let mut criterion_improved = true;
while step < max_steps && criterion_improved {
criterion_improved = false;
if direction == StepwiseDirection::Forward || direction == StepwiseDirection::Both {
let mut best_var = None;
let mut best_criterion = F::infinity();
for i in 0..p {
if selected_indices.contains(&i) {
continue;
}
let mut test_x = create_model_matrix(x, &selected_indices, include_intercept);
let var_col = x.slice(s![.., i]).to_owned();
test_x
.push_column(var_col.view())
.expect("Failed to push column");
if let Ok(model) = linear_regression(&test_x.view(), y) {
let crit_value =
calculate_criterion(&model, n, model.coefficients.len(), criterion);
if is_criterion_better(crit_value, best_criterion, criterion) {
best_var = Some(i);
best_criterion = crit_value;
}
}
}
if let Some(var_idx) = best_var {
let mut test_x = create_model_matrix(x, &selected_indices, include_intercept);
let var_col = x.slice(s![.., var_idx]).to_owned();
test_x
.push_column(var_col.view())
.expect("Failed to push column");
if let Ok(model) = linear_regression(&test_x.view(), y) {
let var_pos = test_x.ncols() - 1;
let _t_value = model.t_values[var_pos];
let p_value = model.p_values[var_pos];
if p_value <= p_enter {
selected_indices.insert(var_idx);
current_x = test_x;
sequence.push((var_idx, true));
criteria_values.push(best_criterion);
criterion_improved = true;
}
}
}
}
if (direction == StepwiseDirection::Backward || direction == StepwiseDirection::Both)
&& !criterion_improved
&& !selected_indices.is_empty()
{
let mut worst_var = None;
let mut worst_criterion = F::infinity();
for &var_idx in &selected_indices {
let mut test_indices = selected_indices.clone();
test_indices.remove(&var_idx);
let test_x = create_model_matrix(x, &test_indices, include_intercept);
if let Ok(model) = linear_regression(&test_x.view(), y) {
let crit_value =
calculate_criterion(&model, n, model.coefficients.len(), criterion);
if is_criterion_better(crit_value, worst_criterion, criterion) {
worst_var = Some(var_idx);
worst_criterion = crit_value;
}
}
}
if let Some(var_idx) = worst_var {
let var_pos = find_var_position(¤t_x, x, var_idx, include_intercept);
if let Ok(model) = linear_regression(¤t_x.view(), y) {
let p_value = model.p_values[var_pos];
if p_value > p_remove {
selected_indices.remove(&var_idx);
current_x = create_model_matrix(x, &selected_indices, include_intercept);
sequence.push((var_idx, false));
criteria_values.push(worst_criterion);
criterion_improved = true;
}
}
}
}
step += 1;
}
let final_model = linear_regression(¤t_x.view(), y)?;
let selected_indices = selected_indices.into_iter().collect();
Ok(StepwiseResults {
final_model,
selected_indices,
sequence,
criteria_values,
})
}
/// Builds a design matrix from the columns of `x` named by `indices`,
/// optionally prepending a column of ones for the intercept.
///
/// The indices are sorted before filling so the column order (and therefore
/// the coefficient order of any model fitted on the result) is
/// deterministic; iterating the `HashSet` directly would yield an
/// unspecified order that varies between runs.
#[allow(dead_code)]
fn create_model_matrix<F>(
    x: &ArrayView2<F>,
    indices: &HashSet<usize>,
    include_intercept: bool,
) -> Array2<F>
where
    F: Float + 'static + std::iter::Sum<F> + std::fmt::Display,
{
    let n = x.nrows();
    let mut sorted_indices: Vec<usize> = indices.iter().copied().collect();
    sorted_indices.sort_unstable();
    let offset = if include_intercept { 1 } else { 0 };
    let mut x_model = Array2::<F>::zeros((n, sorted_indices.len() + offset));
    if include_intercept {
        x_model.slice_mut(s![.., 0]).fill(F::one());
    }
    for (i, &idx) in sorted_indices.iter().enumerate() {
        x_model
            .slice_mut(s![.., i + offset])
            .assign(&x.slice(s![.., idx]));
    }
    x_model
}
/// Locates the column of `current_x` that holds predictor `var_idx` of `x`.
///
/// Columns are compared element-wise within `F::epsilon()`; the intercept
/// column, when present, is skipped. Falls back to the last column index if
/// no column matches.
#[allow(dead_code)]
fn find_var_position<F>(
    current_x: &Array2<F>,
    x: &ArrayView2<F>,
    var_idx: usize,
    include_intercept: bool,
) -> usize
where
    F: Float + 'static + std::iter::Sum<F> + std::fmt::Display,
{
    let first_col = if include_intercept { 1 } else { 0 };
    let target = x.slice(s![.., var_idx]);
    (first_col..current_x.ncols())
        .find(|&col_idx| {
            current_x
                .slice(s![.., col_idx])
                .iter()
                .zip(target.iter())
                .all(|(&a, &b)| (a - b).abs() < F::epsilon())
        })
        .unwrap_or(current_x.ncols() - 1)
}
/// Computes the model-selection criterion for a fitted model, oriented so
/// that SMALLER is always better (AdjR2, F and T values are negated).
///
/// * `n` - number of observations
/// * `p` - number of fitted coefficients (including any intercept column)
#[allow(dead_code)]
fn calculate_criterion<F>(
    model: &RegressionResults<F>,
    n: usize,
    p: usize,
    criterion: StepwiseCriterion,
) -> F
where
    F: Float + 'static + std::iter::Sum<F> + std::fmt::Debug + std::fmt::Display,
{
    // Shared n * ln(RSS / n) log-likelihood term used by both AIC and BIC
    // (previously duplicated verbatim in the two arms).
    let log_likelihood_term = || {
        let rss: F = model
            .residuals
            .iter()
            .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
            .sum();
        let n_f = F::from(n).expect("Failed to convert to float");
        n_f * scirs2_core::numeric::Float::ln(rss / n_f)
    };
    match criterion {
        StepwiseCriterion::AIC => {
            let k_f = F::from(p).expect("Failed to convert to float");
            log_likelihood_term()
                + F::from(2.0).expect("Failed to convert constant to float") * k_f
        }
        StepwiseCriterion::BIC => {
            let n_f = F::from(n).expect("Failed to convert to float");
            let k_f = F::from(p).expect("Failed to convert to float");
            log_likelihood_term() + k_f * scirs2_core::numeric::Float::ln(n_f)
        }
        // Negated so that "smaller is better" holds uniformly.
        StepwiseCriterion::AdjR2 => -model.adj_r_squared,
        StepwiseCriterion::F => -model.f_statistic,
        StepwiseCriterion::T => {
            // Smallest absolute t-statistic across coefficients, negated.
            let min_t = model
                .t_values
                .iter()
                .map(|&t| t.abs())
                .fold(F::infinity(), |a, b| a.min(b));
            -min_t
        }
    }
}
/// Returns true when `new_value` beats `old_value`.
///
/// All variants produced by `calculate_criterion` are oriented so that
/// smaller is better (AdjR2/F/T are negated there), so a single comparison
/// covers every criterion — the original `match` had two identical arms.
/// The criterion parameter is kept (underscored) for signature stability.
/// Also fixes the misleading `_new_value` name: the underscore prefix marks
/// a parameter as unused, yet it was used.
#[allow(dead_code)]
fn is_criterion_better<F>(new_value: F, old_value: F, _criterion: StepwiseCriterion) -> bool
where
    F: Float + std::fmt::Display,
{
    new_value < old_value
}
/// Fits an ordinary least squares regression of `y` on the columns of `x`.
///
/// Produces coefficient estimates with standard errors, t-statistics,
/// approximate two-sided p-values, 95% confidence intervals and
/// goodness-of-fit statistics.
///
/// NOTE(review): the p-values use a crude closed-form approximation of the
/// t-distribution tail, the confidence intervals use a fixed 1.96 (normal)
/// multiplier, and `f_p_value` is hard-coded to zero — these are
/// placeholders, not exact t/F-distribution quantiles.
///
/// # Errors
///
/// Returns an error when `n <= p` or when the least-squares solve fails.
#[allow(dead_code)]
fn linear_regression<F>(x: &ArrayView2<F>, y: &ArrayView1<F>) -> StatsResult<RegressionResults<F>>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + std::fmt::Debug
        + std::fmt::Display
        + 'static
        + scirs2_core::numeric::NumAssign
        + scirs2_core::numeric::One
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    let n = x.nrows();
    let p = x.ncols();
    if n <= p {
        return Err(StatsError::InvalidArgument(format!(
            "Number of observations ({}) must be greater than number of predictors ({})",
            n, p
        )));
    }
    let coefficients = match lstsq(x, y, None) {
        Ok(result) => result.x,
        Err(e) => {
            return Err(StatsError::ComputationError(format!(
                "Least squares computation failed: {:?}",
                e
            )))
        }
    };
    let fitted_values = x.dot(&coefficients);
    let residuals = y.to_owned() - &fitted_values;
    // FIX: `p - 1` panicked with a usize underflow when p == 0 (an empty
    // design matrix, reachable from forward selection without an intercept).
    let df_model = p.saturating_sub(1);
    let df_residuals = n - p;
    let (_y_mean, ss_total, ss_residual, ss_explained) =
        calculate_sum_of_squares(y, &residuals.view());
    let r_squared = ss_explained / ss_total;
    let adj_r_squared = F::one()
        - (F::one() - r_squared) * F::from(n - 1).expect("Failed to convert to float")
            / F::from(df_residuals).expect("Failed to convert to float");
    let mse = ss_residual / F::from(df_residuals).expect("Failed to convert to float");
    let residual_std_error = scirs2_core::numeric::Float::sqrt(mse);
    // Best-effort: fall back to zero standard errors if the computation fails.
    let std_errors = match calculate_std_errors(x, &residuals.view(), df_residuals) {
        Ok(se) => se,
        Err(_) => Array1::<F>::zeros(p),
    };
    let t_values = calculate_t_values(&coefficients, &std_errors);
    // Approximate two-sided p-value: 2 * (1 - |t| / sqrt(df + t^2)).
    let p_values = t_values.mapv(|t| {
        let t_abs = scirs2_core::numeric::Float::abs(t);
        let df_f = F::from(df_residuals).expect("Failed to convert to float");
        F::from(2.0).expect("Failed to convert constant to float")
            * (F::one() - t_abs / scirs2_core::numeric::Float::sqrt(df_f + t_abs * t_abs))
    });
    // 95% confidence intervals using the normal-approximation multiplier.
    let mut conf_intervals = Array2::<F>::zeros((p, 2));
    for i in 0..p {
        let margin = std_errors[i] * F::from(1.96).expect("Failed to convert constant to float");
        conf_intervals[[i, 0]] = coefficients[i] - margin;
        conf_intervals[[i, 1]] = coefficients[i] + margin;
    }
    let f_statistic = if df_model > 0 && df_residuals > 0 {
        (ss_explained / F::from(df_model).expect("Failed to convert to float"))
            / (ss_residual / F::from(df_residuals).expect("Failed to convert to float"))
    } else {
        F::infinity()
    };
    // Placeholder; an exact value requires an F-distribution CDF.
    let f_p_value = F::zero();
    Ok(RegressionResults {
        coefficients,
        std_errors,
        t_values,
        p_values,
        conf_intervals,
        r_squared,
        adj_r_squared,
        f_statistic,
        f_p_value,
        residual_std_error,
        df_residuals,
        residuals,
        fitted_values,
        // Plain OLS performs no outlier detection: all observations inliers.
        inlier_mask: vec![true; n],
    })
}