use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::numeric::Float;
pub use self::linear::{linear_regression, linregress, multilinear_regression, odr};
pub use self::polynomial::polyfit;
pub use self::regularized::{elastic_net, group_lasso, lasso_regression, ridge_regression};
pub use self::robust::{
bisquare_regression, huber_regression, lts_regression, ransac, theilslopes, HuberT, LtsResult,
TheilSlopesResult,
};
pub use self::stepwise::{
stepwise_regression, StepwiseCriterion, StepwiseDirection, StepwiseResults,
};
/// Shorthand for the value produced by multilinear regression: a
/// [`StatsResult`] wrapping `(Array1<F>, Array1<F>, usize, Array1<F>)`.
///
/// NOTE(review): the meaning of each tuple element is not visible in this
/// file (presumably coefficients plus fit diagnostics such as residuals and
/// rank) — confirm against `linear::multilinear_regression` before relying
/// on the ordering.
pub type MultilinearRegressionResult<F> = StatsResult<(Array1<F>, Array1<F>, usize, Array1<F>)>;
pub struct RegressionResults<F>
where
F: Float + std::fmt::Debug + std::fmt::Display + 'static,
{
pub coefficients: Array1<F>,
pub std_errors: Array1<F>,
pub t_values: Array1<F>,
pub p_values: Array1<F>,
pub conf_intervals: Array2<F>,
pub r_squared: F,
pub adj_r_squared: F,
pub f_statistic: F,
pub f_p_value: F,
pub residual_std_error: F,
pub df_residuals: usize,
pub residuals: Array1<F>,
pub fitted_values: Array1<F>,
pub inlier_mask: Vec<bool>,
}
impl<F> RegressionResults<F>
where
    F: Float + std::fmt::Debug + std::fmt::Display + 'static,
{
    /// Predicts response values for new observations via the matrix-vector
    /// product `xnew · coefficients`.
    ///
    /// Each row of `xnew` is one observation; its columns must line up with
    /// the fitted coefficient vector (including any intercept column the
    /// design matrix used — confirm against the fitting routine).
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::DimensionMismatch`] when `xnew.ncols()` differs
    /// from the number of fitted coefficients.
    pub fn predict(&self, xnew: &ArrayView2<F>) -> StatsResult<Array1<F>>
    where
        F: std::ops::Mul<Output = F> + std::iter::Sum<F>,
    {
        if xnew.ncols() != self.coefficients.len() {
            return Err(StatsError::DimensionMismatch(format!(
                "Number of features in x_new ({}) must match the number of coefficients ({})",
                xnew.ncols(),
                self.coefficients.len()
            )));
        }
        // One prediction per row of `xnew`.
        Ok(xnew.dot(&self.coefficients))
    }

    /// Builds a human-readable, R-style report of the fit: goodness-of-fit
    /// statistics, a per-coefficient table, and — when `inlier_mask` flags
    /// any observation as an outlier — a robust-fit section describing the
    /// inlier/outlier split.
    pub fn summary(&self) -> String {
        // Formatting into a `String` is infallible, so `write!` results are
        // deliberately discarded. Using `write!` instead of
        // `push_str(&format!(...))` avoids one throwaway allocation per line.
        use std::fmt::Write as _;

        let mut summary = String::new();
        // Any `false` entry means a robust method excluded at least one
        // observation. Computed once and reused (the original scanned the
        // mask twice).
        let has_outliers = self.inlier_mask.iter().any(|&x| !x);

        summary.push_str(if has_outliers {
            "=== Robust Regression Results ===\n\n"
        } else {
            "=== Regression Results ===\n\n"
        });
        let _ = write!(summary, "R^2 = {:.6}\n", self.r_squared);
        let _ = write!(summary, "Adjusted R^2 = {:.6}\n", self.adj_r_squared);
        let _ = write!(
            summary,
            "Residual Std. Error = {:.6} (df = {})\n",
            self.residual_std_error, self.df_residuals
        );
        let _ = write!(
            summary,
            "F-statistic = {:.6} (p-value = {:.6})\n\n",
            self.f_statistic, self.f_p_value
        );

        summary.push_str("Coefficients:\n");
        summary.push_str(
            " Estimate Std. Error t value Pr(>|t|) [95% Conf. Interval]\n",
        );
        summary.push_str(
            "------------------------------------------------------------------------------\n",
        );
        for i in 0..self.coefficients.len() {
            // Owned `String` instead of borrowing a temporary `&format!(...)`
            // out of an `if`/`else` arm, which relied on subtle
            // temporary-lifetime-extension rules.
            let coef_name = if i == 0 {
                String::from("Intercept")
            } else {
                format!("X{}", i)
            };
            let _ = write!(
                summary,
                "{:10} {:10.6} {:12.6} {:9.4} {:10.6} [{:.6}, {:.6}]\n",
                coef_name,
                self.coefficients[i],
                self.std_errors[i],
                self.t_values[i],
                self.p_values[i],
                self.conf_intervals[[i, 0]],
                self.conf_intervals[[i, 1]]
            );
        }

        if has_outliers {
            // `has_outliers` guarantees `outlier_count >= 1`, so the original
            // redundant `outlier_count > 0` guards were dropped.
            let inlier_count = self.inlier_mask.iter().filter(|&&x| x).count();
            let outlier_count = self.inlier_mask.len() - inlier_count;
            let outlier_percentage =
                (outlier_count as f64 * 100.0) / self.inlier_mask.len() as f64;
            summary.push_str("\nRobust Statistics:\n");
            let _ = write!(
                summary,
                " Total observations: {}\n",
                self.inlier_mask.len()
            );
            let _ = write!(
                summary,
                " Inliers: {} ({:.1}%)\n",
                inlier_count,
                100.0 - outlier_percentage
            );
            let _ = write!(
                summary,
                " Outliers: {} ({:.1}%)\n",
                outlier_count, outlier_percentage
            );
            // Only enumerate individual indices while the list stays short.
            if outlier_count <= 10 {
                let outlier_indices: Vec<_> = self
                    .inlier_mask
                    .iter()
                    .enumerate()
                    .filter_map(|(i, &is_inlier)| if !is_inlier { Some(i) } else { None })
                    .collect();
                let _ = write!(summary, " Outlier indices: {:?}\n", outlier_indices);
            }
            summary.push_str("\nNote: This model used a robust method that identified and handled outliers.\n");
            summary.push_str(" The coefficients are less influenced by outliers than traditional OLS.\n");
        }
        summary
    }
}
// Submodule contents are not visible in this file; the comments below are
// grounded in the `pub use` re-exports at the top of this module.
pub mod functional; // publicly exposed as a whole module (no item re-exports here)
mod linear; // linear_regression, linregress, multilinear_regression, odr
mod polynomial; // polyfit
mod regularized; // ridge, lasso, elastic net, group lasso
mod robust; // huber, bisquare, LTS, RANSAC, Theil-Sen estimators
mod stepwise; // stepwise_regression and its criterion/direction/result types
mod utils; // internal helpers — not re-exported; confirm contents in utils.rs