pub use super::RegularizationType;
use super::helper_function::{
preliminary_check, validate_learning_rate, validate_max_iterations, validate_regulation_type,
validate_tolerance,
};
use crate::error::ModelError;
use crate::math::sum_of_squared_errors;
use crate::{Deserialize, Serialize};
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::{Array1, ArrayBase, Data, Ix1, Ix2};
use rayon::prelude::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelBridge,
ParallelIterator,
};
/// Feature-count threshold at which `fit` switches its L1-penalty
/// computations from sequential to rayon-parallel iteration.
const LINEAR_REGRESSION_PARALLEL_THRESHOLD: usize = 200;
/// Linear regression trained by batch gradient descent, with optional
/// L1/L2 regularization (see `fit`). Serializable via serde.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LinearRegression {
    // Learned weights, one per feature; `None` until `fit` succeeds.
    coefficients: Option<Array1<f64>>,
    // Learned bias term; `None` until `fit` succeeds.
    intercept: Option<f64>,
    // Whether to learn an intercept in addition to the feature weights.
    fit_intercept: bool,
    // Gradient-descent step size.
    learning_rate: f64,
    // Upper bound on gradient-descent iterations.
    max_iter: usize,
    // Convergence tolerance on the iteration-to-iteration cost change.
    tol: f64,
    // Iterations actually run by the most recent `fit`; `None` before training.
    n_iter: Option<usize>,
    // Optional L1/L2 penalty applied during training.
    regularization_type: Option<RegularizationType>,
}
impl Default for LinearRegression {
fn default() -> Self {
Self {
coefficients: None,
intercept: None,
fit_intercept: true,
learning_rate: 0.01,
max_iter: 1000,
tol: 1e-5,
n_iter: None,
regularization_type: None,
}
}
}
impl LinearRegression {
pub fn new(
fit_intercept: bool,
learning_rate: f64,
max_iterations: usize,
tolerance: f64,
regularization_type: Option<RegularizationType>,
) -> Result<Self, ModelError> {
validate_learning_rate(learning_rate)?;
validate_max_iterations(max_iterations)?;
validate_tolerance(tolerance)?;
validate_regulation_type(regularization_type)?;
Ok(LinearRegression {
coefficients: None,
intercept: None,
fit_intercept,
learning_rate,
max_iter: max_iterations,
tol: tolerance,
n_iter: None,
regularization_type,
})
}
// Generated getters (see the `get_field!` / `get_field_as_ref!` macro
// definitions); each expands to a public accessor for the named field.
get_field!(get_fit_intercept, fit_intercept, bool);
get_field!(get_learning_rate, learning_rate, f64);
get_field!(get_max_iter, max_iter, usize);
get_field!(get_tolerance, tol, f64);
// NOTE(review): `get_max_iterations` reads the same `max_iter` field as
// `get_max_iter` above — presumably kept for backward compatibility;
// confirm before removing either.
get_field!(get_max_iterations, max_iter, usize);
get_field!(get_actual_iterations, n_iter, Option<usize>);
get_field!(
    get_regularization_type,
    regularization_type,
    Option<RegularizationType>
);
get_field_as_ref!(get_coefficients, coefficients, Option<&Array1<f64>>);
get_field!(get_intercept, intercept, Option<f64>);
/// Trains the model on `(x, y)` with batch gradient descent.
///
/// Each iteration computes predictions over the whole data set, derives
/// the cost (SSE / (2 * n_samples) plus the configured regularization
/// term), then steps the weights — and the intercept, when enabled —
/// against the gradient. An `indicatif` progress bar reports the cost
/// and the convergence streak while training runs.
///
/// # Parameters
///
/// - `x`: feature matrix, one sample per row.
/// - `y`: target vector, one value per sample.
///
/// # Returns
///
/// `&mut Self`, so calls can be chained.
///
/// # Errors
///
/// - propagates whatever `preliminary_check` rejects for `x`/`y`;
/// - `ModelError::ProcessingError` if the cost, a gradient, or a
///   parameter becomes NaN or infinite mid-training.
pub fn fit<S>(
    &mut self,
    x: &ArrayBase<S, Ix2>,
    y: &ArrayBase<S, Ix1>,
) -> Result<&mut Self, ModelError>
where
    S: Data<Elem = f64>,
{
    preliminary_check(x, Some(y))?;
    let n_samples = x.nrows();
    let n_features = x.ncols();
    // Parameters start at zero; the intercept is kept outside `weights`
    // so it can be excluded from regularization.
    let mut weights = Array1::<f64>::zeros(n_features);
    let mut intercept = 0.0;
    let mut prev_cost = f64::INFINITY;
    // Convergence is only declared after the cost change stays below
    // `tol` for CONVERGENCE_THRESHOLD consecutive iterations.
    let mut convergence_count = 0;
    const CONVERGENCE_THRESHOLD: usize = 3;
    let mut n_iter = 0;
    // Work buffers reused on every iteration to avoid reallocating.
    let mut predictions = Array1::<f64>::zeros(n_samples);
    let mut error_vec = Array1::<f64>::zeros(n_samples);
    let progress_bar = ProgressBar::new(self.max_iter as u64);
    progress_bar.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} | Cost: {msg}")
            .expect("Failed to set progress bar template")
            .progress_chars("█▓░"),
    );
    progress_bar.set_message(format!(
        "{:.6} | Convergence: 0/{}",
        f64::INFINITY,
        CONVERGENCE_THRESHOLD
    ));
    while n_iter < self.max_iter {
        n_iter += 1;
        // Forward pass: predictions = X·w (+ b when fitting an intercept).
        predictions.assign(&x.dot(&weights));
        if self.fit_intercept {
            predictions += intercept;
        }
        error_vec.assign(&(&predictions - y));
        let sse = sum_of_squared_errors(&predictions, &y);
        // Penalty added to the reported cost.
        // NOTE(review): the L2 term here is alpha * ||w||^2, while the
        // gradient step below adds alpha * w — the gradient of
        // (alpha/2) * ||w||^2 — so the cost and the gradient disagree by
        // a factor of 2 on the L2 penalty; confirm which convention is
        // intended. Also note the penalty is not scaled by
        // 1/(2 * n_samples) the way the SSE term is.
        let regularization_term = match &self.regularization_type {
            None => 0.0,
            Some(RegularizationType::L1(alpha)) => {
                // Parallelize the |w| sum only for wide feature spaces.
                // NOTE(review): `par_bridge` on a sequential iterator
                // carries scheduling overhead; a slice-based `par_iter`
                // would likely be cheaper — verify before changing.
                if n_features >= LINEAR_REGRESSION_PARALLEL_THRESHOLD {
                    alpha * weights.iter().par_bridge().map(|w| w.abs()).sum::<f64>()
                } else {
                    alpha * weights.iter().map(|w| w.abs()).sum::<f64>()
                }
            }
            Some(RegularizationType::L2(alpha)) => alpha * weights.dot(&weights),
        };
        let cost = sse / (2.0 * n_samples as f64) + regularization_term;
        // `convergence_count` shown here is the streak from the previous
        // iteration; it is updated at the bottom of the loop.
        progress_bar.set_message(format!(
            "{:.6} | Convergence: {}/{}",
            cost, convergence_count, CONVERGENCE_THRESHOLD
        ));
        progress_bar.inc(1);
        // Abort on numeric blow-up instead of iterating on NaN forever.
        if !cost.is_finite() {
            progress_bar.finish_with_message("Error: NaN or infinite cost");
            return Err(ModelError::ProcessingError(
                "Cost calculation resulted in NaN or infinite value".to_string(),
            ));
        }
        // Gradient of the unregularized cost: Xᵀ·err / n_samples.
        let mut weight_gradients = x.t().dot(&error_vec) / (n_samples as f64);
        let intercept_gradient = if self.fit_intercept {
            error_vec.sum() / (n_samples as f64)
        } else {
            0.0
        };
        if weight_gradients.iter().any(|&val| !val.is_finite())
            || !intercept_gradient.is_finite()
        {
            progress_bar.finish_with_message("Error: NaN or infinite gradients");
            return Err(ModelError::ProcessingError(
                "Gradient calculation resulted in NaN or infinite values".to_string(),
            ));
        }
        // Add the regularization contribution to the weight gradients;
        // the intercept is intentionally never regularized.
        match &self.regularization_type {
            None => {}
            Some(RegularizationType::L1(alpha)) => {
                let alpha_val = *alpha;
                // L1 subgradient: alpha * sign(w).
                if n_features >= LINEAR_REGRESSION_PARALLEL_THRESHOLD {
                    let weights_slice = weights.as_slice().unwrap();
                    let gradients_slice = weight_gradients.as_slice_mut().unwrap();
                    gradients_slice
                        .par_iter_mut()
                        .zip(weights_slice.par_iter())
                        .for_each(|(grad, w)| {
                            *grad += alpha_val * w.signum();
                        });
                } else {
                    weight_gradients
                        .iter_mut()
                        .zip(weights.iter())
                        .for_each(|(grad, w)| {
                            *grad += alpha_val * w.signum();
                        });
                }
            }
            Some(RegularizationType::L2(alpha)) => {
                weight_gradients.scaled_add(*alpha, &weights);
            }
        }
        // Gradient-descent update step.
        weights.scaled_add(-self.learning_rate, &weight_gradients);
        if self.fit_intercept {
            intercept -= self.learning_rate * intercept_gradient;
        }
        if weights.iter().any(|&val| !val.is_finite()) || !intercept.is_finite() {
            progress_bar.finish_with_message("Error: NaN or infinite parameters");
            return Err(ModelError::ProcessingError(
                "Parameter update resulted in NaN or infinite values".to_string(),
            ));
        }
        // Count consecutive sub-tolerance cost changes; any larger
        // change resets the streak.
        let cost_change = (prev_cost - cost).abs();
        if cost_change < self.tol {
            convergence_count += 1;
            if convergence_count >= CONVERGENCE_THRESHOLD {
                break;
            }
        } else {
            convergence_count = 0;
        }
        prev_cost = cost;
    }
    // Breaking before the iteration cap means the tolerance test passed.
    // NOTE(review): a convergence break on the very last iteration
    // (n_iter == max_iter) is reported as "Max iterations".
    let convergence_status = if n_iter < self.max_iter {
        "Converged"
    } else {
        "Max iterations"
    };
    progress_bar.finish_with_message(format!(
        "{:.6} | {} | Iterations: {}",
        prev_cost, convergence_status, n_iter
    ));
    self.coefficients = Some(weights);
    self.intercept = Some(if self.fit_intercept { intercept } else { 0.0 });
    self.n_iter = Some(n_iter);
    println!(
        "\nLinear Regression training completed: {} samples, {} features, {} iterations, final cost: {:.6}",
        n_samples, n_features, n_iter, prev_cost
    );
    Ok(self)
}
/// Predicts targets for `x` using the fitted coefficients.
///
/// # Errors
///
/// - `ModelError::NotFitted` when `fit` has not been run;
/// - `ModelError::InputValidationError` for an empty matrix, a feature
///   count mismatch, or non-finite input values;
/// - `ModelError::ProcessingError` when the computed predictions are
///   NaN or infinite.
pub fn predict<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<f64>, ModelError>
where
    S: Data<Elem = f64>,
{
    // Require a trained model before looking at the input at all.
    let coeffs = match self.coefficients.as_ref() {
        Some(c) => c,
        None => return Err(ModelError::NotFitted),
    };
    let intercept = self.intercept.unwrap_or(0.0);

    // Validate the input before doing any linear algebra.
    if x.is_empty() {
        return Err(ModelError::InputValidationError(
            "Cannot predict on empty dataset".to_string(),
        ));
    }
    if x.ncols() != coeffs.len() {
        return Err(ModelError::InputValidationError(format!(
            "Number of features does not match training data, x columns: {}, coefficients: {}",
            x.ncols(),
            coeffs.len()
        )));
    }
    if x.iter().any(|&val| !val.is_finite()) {
        return Err(ModelError::InputValidationError(
            "Input data contains NaN or infinite values".to_string(),
        ));
    }

    // ŷ = X·w (+ b when the intercept was fitted).
    let mut predictions = x.dot(coeffs);
    if self.fit_intercept {
        predictions += intercept;
    }
    if predictions.iter().any(|&val| !val.is_finite()) {
        return Err(ModelError::ProcessingError(
            "Prediction calculation resulted in NaN or infinite values".to_string(),
        ));
    }
    Ok(predictions)
}
/// Fits the model on `(x, y)` and immediately returns predictions
/// for `x`; equivalent to calling `fit` followed by `predict`.
///
/// # Errors
///
/// Propagates any error raised by `fit` or `predict`.
pub fn fit_predict<S>(
    &mut self,
    x: &ArrayBase<S, Ix2>,
    y: &ArrayBase<S, Ix1>,
) -> Result<Array1<f64>, ModelError>
where
    S: Data<Elem = f64>,
{
    self.fit(x, y)?;
    // `predict` already returns the correct Result type; re-wrapping
    // with `Ok(... ?)` was redundant (clippy: needless_question_mark).
    self.predict(x)
}
// Generates model persistence methods for `LinearRegression` —
// presumably serde-based save/load helpers; see the
// `model_save_and_load_methods!` macro definition for the exact API.
model_save_and_load_methods!(LinearRegression);
}