pub use super::RegularizationType;
use super::helper_function::{
preliminary_check, validate_learning_rate, validate_max_iterations, validate_regulation_type,
validate_tolerance,
};
use crate::error::ModelError;
use crate::math::{logistic_loss, sigmoid};
use crate::{Deserialize, Serialize};
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, Ix1, Ix2, s};
use rayon::prelude::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
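/// Sample-count threshold above which sigmoid evaluation is parallelized with
/// rayon; below it, a serial loop avoids the thread-pool overhead.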
const LOGISTIC_REGRESSION_PARALLEL_THRESHOLD: usize = 1000;
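/// Binary logistic regression classifier trained with batch gradient descent.
///
/// Targets must be encoded as `0.0` / `1.0`. Supports an optional intercept
/// term (stored as the first weight) and optional L1/L2 regularization.
///
/// Illustrative usage (a minimal sketch, assuming this type is in scope; not
/// compiled as a doctest):
///
/// ```ignore
/// use ndarray::array;
///
/// let x = array![[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]];
/// let y = array![0.0, 0.0, 1.0, 1.0];
/// let mut model = LogisticRegression::default();
/// let labels = model.fit_predict(&x, &y)?;
/// ```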
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogisticRegression {
weights: Option<Array1<f64>>,
fit_intercept: bool,
learning_rate: f64,
max_iter: usize,
tol: f64,
n_iter: Option<usize>,
regularization_type: Option<RegularizationType>,
}
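/// Defaults: intercept enabled, learning rate `0.01`, at most `100`
/// iterations, tolerance `1e-4`, no regularization.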
impl Default for LogisticRegression {
fn default() -> Self {
LogisticRegression {
weights: None,
fit_intercept: true,
learning_rate: 0.01,
max_iter: 100,
tol: 1e-4,
n_iter: None,
regularization_type: None,
}
}
}
impl LogisticRegression {
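    /// Creates a model with the given configuration after validating the
    /// learning rate, maximum iteration count, tolerance, and regularization
    /// settings.
    ///
    /// # Errors
    ///
    /// Returns a `ModelError` if any hyperparameter fails validation.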
pub fn new(
fit_intercept: bool,
learning_rate: f64,
max_iterations: usize,
tolerance: f64,
regularization_type: Option<RegularizationType>,
) -> Result<Self, ModelError> {
validate_learning_rate(learning_rate)?;
validate_max_iterations(max_iterations)?;
validate_tolerance(tolerance)?;
validate_regulation_type(regularization_type)?;
Ok(LogisticRegression {
weights: None,
fit_intercept,
learning_rate,
max_iter: max_iterations,
tol: tolerance,
n_iter: None,
regularization_type,
})
}
get_field!(get_fit_intercept, fit_intercept, bool);
get_field!(get_learning_rate, learning_rate, f64);
get_field!(get_max_iterations, max_iter, usize);
get_field!(get_tolerance, tol, f64);
get_field!(get_actual_iterations, n_iter, Option<usize>);
get_field!(
get_regularization_type,
regularization_type,
Option<RegularizationType>
);
get_field_as_ref!(get_weights, weights, Option<&Array1<f64>>);
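    /// Fits the model with batch gradient descent on the logistic loss,
    /// optionally adding an L1 or L2 penalty (the intercept weight is never
    /// regularized). The gradient of the data-fit term is
    /// `Xᵀ(σ(Xw) − y) / n`.
    ///
    /// Training stops early once the change in loss between iterations drops
    /// below the tolerance; otherwise it runs for `max_iter` iterations.
    ///
    /// # Errors
    ///
    /// Returns an error if validation fails, if `y` contains values other
    /// than `0.0` and `1.0`, or if the gradients, weights, or loss become
    /// NaN or infinite during training.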
pub fn fit<S>(
&mut self,
x: &ArrayBase<S, Ix2>,
y: &ArrayBase<S, Ix1>,
) -> Result<&mut Self, ModelError>
where
S: Data<Elem = f64>,
{
preliminary_check(x, Some(y))?;
        // Binary classification: targets must be encoded as 0.0 or 1.0.
        for &val in y.iter() {
            if val != 0.0 && val != 1.0 {
                return Err(ModelError::InputValidationError(
                    "Target vector must contain only 0 or 1".to_string(),
                ));
            }
        }
        let (n_samples, mut n_features) = x.dim();

        // When fitting an intercept, prepend a column of ones so the first
        // weight acts as the bias term.
        let x_train_view: ArrayView2<f64>;
        let _x_train_owned: Option<Array2<f64>>;
        if self.fit_intercept {
            n_features += 1;
            let mut x_with_bias = Array2::ones((n_samples, n_features));
            x_with_bias.slice_mut(s![.., 1..]).assign(x);
            _x_train_owned = Some(x_with_bias);
            x_train_view = _x_train_owned.as_ref().unwrap().view();
        } else {
            _x_train_owned = None;
            x_train_view = x.view();
        }
        let mut weights = Array1::zeros(n_features);
        let mut prev_cost = f64::INFINITY;
        let mut final_cost = prev_cost;
        let mut n_iter = 0;
        let mut converged = false;
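        // Progress bar that reports the current loss each iteration.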
let progress_bar = ProgressBar::new(self.max_iter as u64);
progress_bar.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} | Loss: {msg}")
.expect("Failed to set progress bar template")
.progress_chars("█▓░"),
);
progress_bar.set_message(format!("{:.6}", f64::INFINITY));
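        // Batch gradient descent loop.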
while n_iter < self.max_iter {
n_iter += 1;
            // Linear predictions (logits), then sigmoid probabilities.
            // Parallelize only above the sample-count threshold, where the
            // thread overhead pays off.
            let predictions = x_train_view.dot(&weights);
            let sigmoid_preds = if n_samples >= LOGISTIC_REGRESSION_PARALLEL_THRESHOLD {
                let sigmoid_vec = (0..n_samples)
                    .into_par_iter()
                    .map(|i| sigmoid(predictions[i]))
                    .collect::<Vec<f64>>();
                Array1::from(sigmoid_vec)
            } else {
                predictions.mapv(|z| sigmoid(z))
            };
            // Gradient of the average logistic loss: Xᵀ(σ(Xw) − y) / n.
            let errors = &sigmoid_preds - y;
            let mut gradients = x_train_view.t().dot(&errors) / n_samples as f64;
if gradients.iter().any(|&val| !val.is_finite()) {
progress_bar.finish_with_message("Error: NaN or infinite gradients");
return Err(ModelError::ProcessingError(
"Gradient calculation resulted in NaN or infinite values".to_string(),
));
}
            // Add the regularization contribution to the gradient, skipping
            // the intercept weight when one is present.
            if let Some(reg_type) = &self.regularization_type {
                let start_idx = if self.fit_intercept { 1 } else { 0 };
                match reg_type {
                    RegularizationType::L1(regularization_strength) => {
                        // L1 uses the subgradient sign(w), with sign(0) = 0.
                        for i in start_idx..n_features {
                            let sign = if weights[i] > 0.0 {
                                1.0
                            } else if weights[i] < 0.0 {
                                -1.0
                            } else {
                                0.0
                            };
                            gradients[i] += regularization_strength * sign / n_samples as f64;
                        }
                    }
                    RegularizationType::L2(regularization_strength) => {
                        // L2 adds λ·w / n to the gradient.
                        for i in start_idx..n_features {
                            gradients[i] += regularization_strength * weights[i] / n_samples as f64;
                        }
                    }
                }
            }
            // Gradient descent step.
            weights = &weights - self.learning_rate * &gradients;
if weights.iter().any(|&val| !val.is_finite()) {
progress_bar.finish_with_message("Error: NaN or infinite weights");
return Err(ModelError::ProcessingError(
"Weight update resulted in NaN or infinite values".to_string(),
));
}
            // Regularized loss for the convergence check. The data-fit term is
            // evaluated at the pre-update predictions computed above.
            let mut cost = logistic_loss(&predictions, y);
            if let Some(reg_type) = &self.regularization_type {
                let start_idx = if self.fit_intercept { 1 } else { 0 };
                match reg_type {
                    RegularizationType::L1(regularization_strength) => {
                        let l1_penalty: f64 =
                            weights.slice(s![start_idx..]).mapv(|w| w.abs()).sum();
                        cost += regularization_strength * l1_penalty / n_samples as f64;
                    }
                    RegularizationType::L2(regularization_strength) => {
                        let l2_penalty: f64 = weights.slice(s![start_idx..]).mapv(|w| w * w).sum();
                        cost += regularization_strength * l2_penalty / (2.0 * n_samples as f64);
                    }
                }
            }
final_cost = cost;
if !cost.is_finite() {
progress_bar.finish_with_message("Error: NaN or infinite cost");
return Err(ModelError::ProcessingError(
"Cost calculation resulted in NaN or infinite value".to_string(),
));
}
progress_bar.set_message(format!("{:.6}", cost));
progress_bar.inc(1);
            // Stop once the improvement in loss falls below the tolerance.
            if (prev_cost - cost).abs() < self.tol {
                converged = true;
                break;
            }
            prev_cost = cost;
}
        let convergence_status = if converged {
            "Converged"
        } else {
            "Max iterations"
        };
progress_bar.finish_with_message(format!(
"{:.6} | {} | Iterations: {}",
final_cost, convergence_status, n_iter
));
self.weights = Some(weights);
self.n_iter = Some(n_iter);
println!(
"\nLogistic Regression training completed: {} samples, {} features, {} iterations, final loss: {:.6}",
n_samples, n_features, n_iter, final_cost
);
Ok(self)
}
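    /// Predicts class labels (`0` or `1`) by thresholding the predicted
    /// probabilities at `0.5`.
    ///
    /// # Errors
    ///
    /// Returns `ModelError::NotFitted` if called before `fit`, and an
    /// `InputValidationError` if the input is empty, has a column count that
    /// does not match the training data, or contains non-finite values.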
pub fn predict<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<i32>, ModelError>
where
S: Data<Elem = f64>,
{
if self.weights.is_none() {
return Err(ModelError::NotFitted);
}
let (n_samples, n_features) = x.dim();
if n_samples == 0 {
return Err(ModelError::InputValidationError(
"Cannot predict on empty dataset".to_string(),
));
}
        // The stored weight vector includes the bias term when an intercept
        // was fitted, so subtract it to recover the expected feature count.
        let expected_features = if self.fit_intercept {
            self.weights.as_ref().unwrap().len() - 1
        } else {
            self.weights.as_ref().unwrap().len()
        };
if n_features != expected_features {
return Err(ModelError::InputValidationError(format!(
"Number of features does not match training data, x columns: {}, expected: {}",
n_features, expected_features
)));
}
if x.iter().any(|&val| !val.is_finite()) {
return Err(ModelError::InputValidationError(
"Input data contains NaN or infinite values".to_string(),
));
}
        // Prepend the bias column when an intercept was fitted, so the input
        // matches the shape of the learned weight vector.
        let probs = if self.fit_intercept {
            let mut x_with_bias = Array2::ones((n_samples, n_features + 1));
            x_with_bias.slice_mut(s![.., 1..]).assign(x);
            self.predict_proba(&x_with_bias)?
        } else {
            self.predict_proba(x)?
        };
if probs.iter().any(|&val| !val.is_finite()) {
return Err(ModelError::ProcessingError(
"Probability calculation resulted in NaN or infinite values".to_string(),
));
}
        // Threshold probabilities at 0.5 to produce class labels.
        Ok(probs.mapv(|prob| if prob >= 0.5 { 1 } else { 0 }))
}
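    /// Computes `σ(x · w)` for each row. Internal helper: when an intercept
    /// was fitted, the caller must already have prepended the bias column.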
fn predict_proba<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<f64>, ModelError>
where
S: Data<Elem = f64>,
{
if let Some(weights) = &self.weights {
            // Logits, then an in-place sigmoid; parallelized for large inputs.
            let mut predictions = x.dot(weights);
            let n_samples = predictions.len();
            if n_samples >= LOGISTIC_REGRESSION_PARALLEL_THRESHOLD {
                predictions.par_mapv_inplace(|z| sigmoid(z));
            } else {
                predictions.mapv_inplace(|z| sigmoid(z));
            }
Ok(predictions)
} else {
Err(ModelError::NotFitted)
}
}
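    /// Convenience method that fits the model on the training data and then
    /// returns predictions for that same data.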
pub fn fit_predict<S>(
&mut self,
train_x: &ArrayBase<S, Ix2>,
train_y: &ArrayBase<S, Ix1>,
) -> Result<Array1<i32>, ModelError>
where
S: Data<Elem = f64>,
{
        self.fit(train_x, train_y)?;
        self.predict(train_x)
}
model_save_and_load_methods!(LogisticRegression);
}
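/// Expands `x` into polynomial features of all degrees `1..=degree`, without
/// a constant column. Columns are ordered by degree; within a degree,
/// feature-index combinations are enumerated in non-decreasing order, so for
/// two features `[a, b]` and `degree = 2` the output columns are
/// `[a, b, a², ab, b²]`.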
pub fn generate_polynomial_features<S>(x: &ArrayBase<S, Ix2>, degree: usize) -> Array2<f64>
where
S: Data<Elem = f64> + Send + Sync,
{
    let (n_samples, n_features) = x.dim();

    // Number of output columns: for each degree d in 1..=degree there are
    // C(n_features + d - 1, d) multisets of features (combinations with
    // repetition). The inner loop builds the binomial coefficient with exact
    // integer arithmetic, since each partial product is itself a binomial
    // coefficient.
    let n_output_features = {
        let mut count = 0;
        for d in 1..=degree {
            let mut term = 1;
            for i in 0..d {
                term = term * (n_features + i) / (i + 1);
            }
            count += term;
        }
        count
    };
    let mut result = Array2::<f64>::zeros((n_samples, n_output_features));

    // Degree-1 terms: copy the original features into the first columns.
    result
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .for_each(|(i, mut row)| {
            for j in 0..n_features {
                row[j] = x[[i, j]];
            }
        });
    // Higher-degree terms: fill the remaining columns.
    if degree >= 2 {
        let mut col_idx = n_features;

        // Recursively enumerate non-decreasing feature-index combinations
        // (multisets) of the target degree; each complete combination fills
        // one output column with the product of the chosen features.
        fn add_combinations<S>(
            x: &ArrayBase<S, Ix2>,
            result: &mut Array2<f64>,
            col_idx: &mut usize,
            _n_samples: usize,
            n_features: usize,
            degree: usize,
            current_degree: usize,
            start_feature: usize,
            combination: &mut Vec<usize>,
        ) where
            S: Data<Elem = f64> + Send + Sync,
        {
            // A complete combination: write its feature product for every
            // sample (rows processed in parallel).
            if current_degree == degree {
                let current_col = *col_idx;
                *col_idx += 1;
result
.axis_iter_mut(Axis(0))
.into_par_iter()
.enumerate()
.for_each(|(i, mut row)| {
let mut value = 1.0;
for &feat_idx in combination.iter() {
value *= x[[i, feat_idx]];
}
row[current_col] = value;
});
return;
}
for j in start_feature..n_features {
combination.push(j);
add_combinations(
x,
result,
col_idx,
n_samples,
n_features,
degree,
current_degree + 1,
j,
combination,
);
combination.pop();
}
}
for d in 2..=degree {
add_combinations(
x,
&mut result,
&mut col_idx,
n_samples,
n_features,
d,
0,
0,
&mut vec![],
);
}
}
result
}