pub use super::RegularizationType;
use super::helper_function::{
preliminary_check, validate_learning_rate, validate_max_iterations, validate_tolerance,
};
use crate::error::ModelError;
use crate::{Deserialize, Serialize};
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::{Array1, ArrayBase, Data, Ix1, Ix2, s};
use ndarray_rand::rand::{rng, seq::SliceRandom};
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
// Mini-batches with at least this many samples compute their gradient sum via
// rayon's parallel reduce; smaller batches use a sequential fold (see `fit`).
const LINEAR_SVC_PARALLEL_THRESHOLD: usize = 200;
/// Linear Support Vector Classifier trained with mini-batch stochastic
/// subgradient descent on the hinge loss (see `fit`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LinearSVC {
    /// Learned weight vector; `None` until `fit` has completed.
    weights: Option<Array1<f64>>,
    /// Learned intercept; `None` until `fit` has completed.
    bias: Option<f64>,
    /// Maximum number of training epochs.
    max_iter: usize,
    /// Step size for each subgradient update.
    learning_rate: f64,
    /// L1 or L2 regularization, carrying its strength parameter.
    penalty: RegularizationType,
    /// Whether the bias term is updated during training.
    fit_intercept: bool,
    /// Convergence tolerance on the per-epoch parameter change.
    tol: f64,
    /// Number of epochs actually run by the last `fit`; `None` before fitting.
    n_iter: Option<usize>,
}
impl Default for LinearSVC {
    /// Default configuration: 1000 epochs, learning rate `0.001`, L2 penalty
    /// with strength `1.0`, intercept fitting enabled, tolerance `1e-4`.
    fn default() -> Self {
        Self {
            max_iter: 1000,
            learning_rate: 0.001,
            penalty: RegularizationType::L2(1.0),
            fit_intercept: true,
            tol: 1e-4,
            weights: None,
            bias: None,
            n_iter: None,
        }
    }
}
impl LinearSVC {
pub fn new(
max_iter: usize,
learning_rate: f64,
penalty: RegularizationType,
fit_intercept: bool,
tol: f64,
) -> Result<Self, ModelError> {
validate_max_iterations(max_iter)?;
validate_learning_rate(learning_rate)?;
validate_tolerance(tol)?;
let reg_param = match penalty {
RegularizationType::L1(lambda) | RegularizationType::L2(lambda) => lambda,
};
if reg_param < 0.0 || !reg_param.is_finite() {
return Err(ModelError::InputValidationError(format!(
"Regularization parameter must be non-negative and finite, got {}",
reg_param
)));
}
Ok(LinearSVC {
weights: None,
bias: None,
max_iter,
learning_rate,
penalty,
fit_intercept,
tol,
n_iter: None,
})
}
/// Checks that prediction input is non-empty, all-finite, and — when the
/// model has been fitted — matches the learned feature dimension.
fn validate_input_data<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<(), ModelError>
where
    S: Data<Elem = f64>,
{
    if x.is_empty() {
        return Err(ModelError::InputValidationError(
            "Input data cannot be empty".to_string(),
        ));
    }
    let all_finite = x.iter().all(|v| v.is_finite());
    if !all_finite {
        return Err(ModelError::InputValidationError(
            "Input features contain NaN or infinite values".to_string(),
        ));
    }
    // Dimension check only applies once weights exist (i.e. after fitting).
    match &self.weights {
        Some(w) if w.len() != x.ncols() => Err(ModelError::InputValidationError(format!(
            "Feature dimension mismatch: expected {}, got {}",
            w.len(),
            x.ncols()
        ))),
        _ => Ok(()),
    }
}
/// Verifies that all parameters are still finite after an update step.
fn check_weights_validity(weights: &Array1<f64>, bias: f64) -> Result<(), ModelError> {
    let all_finite = bias.is_finite() && weights.iter().all(|w| w.is_finite());
    if all_finite {
        Ok(())
    } else {
        Err(ModelError::ProcessingError(
            "Weights became NaN or infinite during training. Try reducing learning_rate or regularization_param".to_string()
        ))
    }
}
/// Picks a mini-batch size as one tenth of the sample count, clamped into
/// the `[32, 512]` range so tiny datasets still form a batch and huge ones
/// cap out.
fn calculate_batch_size(n_samples: usize) -> usize {
    const MIN_BATCH_SIZE: usize = 32;
    const MAX_BATCH_SIZE: usize = 512;
    // `clamp` is the idiomatic spelling of max(MIN, min(MAX, x)).
    (n_samples / 10).clamp(MIN_BATCH_SIZE, MAX_BATCH_SIZE)
}
// Read-only accessors generated by the crate's `get_field!` macros.
get_field!(get_fit_intercept, fit_intercept, bool);
get_field!(get_learning_rate, learning_rate, f64);
get_field!(get_max_iter, max_iter, usize);
get_field!(get_tolerance, tol, f64);
// NOTE(review): `get_max_iterations` duplicates `get_max_iter` on the same
// field — presumably a backward-compatible alias; confirm before removing.
get_field!(get_max_iterations, max_iter, usize);
// `None` until `fit` completes, then the number of epochs actually run.
get_field!(get_actual_iterations, n_iter, Option<usize>);
get_field_as_ref!(get_weights, weights, Option<&Array1<f64>>);
get_field!(get_bias, bias, Option<f64>);
get_field!(get_penalty, penalty, RegularizationType);
/// Trains the classifier with mini-batch stochastic subgradient descent on
/// the hinge loss, applying the configured L1/L2 penalty after each batch.
///
/// Labels are binarized first: `y[i] <= 0.0` maps to -1.0, anything else to
/// +1.0. Training stops early once the parameter-change norm over one epoch
/// drops below `self.tol`, otherwise after `self.max_iter` epochs.
///
/// # Arguments
/// * `x` - feature matrix, one sample per row
/// * `y` - labels, one per sample
///
/// # Returns
/// `&mut self` with `weights`, `bias`, and `n_iter` populated.
///
/// # Errors
/// Propagates validation errors from `preliminary_check`, and returns a
/// `ProcessingError` if the parameters diverge to NaN/infinity.
pub fn fit<S>(
    &mut self,
    x: &ArrayBase<S, Ix2>,
    y: &ArrayBase<S, Ix1>,
) -> Result<&mut Self, ModelError>
where
    S: Data<Elem = f64> + Send + Sync,
{
    preliminary_check(x, Some(y))?;
    let n_samples = x.nrows();
    let n_features = x.ncols();
    // Parameters start at zero; the hinge subgradient moves them off the origin.
    let mut weights = Array1::zeros(n_features);
    let mut bias = 0.0;
    // Binarize labels into the {-1, +1} encoding the hinge loss expects.
    let y_binary = y.mapv(|v| if v <= 0.0 { -1.0 } else { 1.0 });
    let mut indices: Vec<usize> = (0..n_samples).collect();
    let mut rng = rng();
    // Snapshots from the previous epoch, used by the convergence test below.
    let mut prev_weights = weights.clone();
    let mut prev_bias = bias;
    let mut n_iter = 0;
    let batch_size = Self::calculate_batch_size(n_samples);
    // Full objective: mean hinge loss over all samples plus the penalty term.
    let calculate_cost = |x: &ArrayBase<S, Ix2>,
                          y: &Array1<f64>,
                          weights: &Array1<f64>,
                          bias: f64,
                          penalty: &RegularizationType|
     -> f64 {
        let n_samples = x.nrows() as f64;
        let hinge_loss: f64 = x
            .outer_iter()
            .zip(y.iter())
            .map(|(xi, &yi)| {
                let margin = xi.dot(weights) + bias;
                (1.0 - yi * margin).max(0.0)
            })
            .sum::<f64>()
            / n_samples;
        let regularization_term = match penalty {
            RegularizationType::L2(lambda) => {
                lambda * weights.iter().map(|&w| w * w).sum::<f64>() / 2.0
            }
            RegularizationType::L1(lambda) => {
                lambda * weights.iter().map(|&w| w.abs()).sum::<f64>()
            }
        };
        hinge_loss + regularization_term
    };
    let progress_bar = ProgressBar::new(self.max_iter as u64);
    progress_bar.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} | Cost: {msg}")
            .expect("Failed to set progress bar template")
            .progress_chars("█▓░"),
    );
    progress_bar.set_message("Initializing...");
    // One outer iteration == one epoch over a freshly shuffled sample order.
    while n_iter < self.max_iter {
        n_iter += 1;
        progress_bar.inc(1);
        indices.shuffle(&mut rng);
        for batch_indices in indices.chunks(batch_size) {
            let batch_len = batch_indices.len() as f64;
            // Hinge subgradient for one sample: non-zero only when the sample
            // violates the margin (y * f(x) < 1), in which case it is y * x.
            let compute_gradient = |&idx: &usize| {
                let xi = x.slice(s![idx, ..]);
                let yi = y_binary[idx];
                let margin = xi.dot(&weights) + bias;
                if yi * margin < 1.0 {
                    let weight_grad = xi.to_owned() * yi;
                    let bias_grad = yi;
                    (weight_grad, bias_grad)
                } else {
                    (Array1::zeros(n_features), 0.0)
                }
            };
            // Large batches sum their gradients in parallel via rayon; small
            // ones sequentially, where scheduling overhead would dominate.
            let (weight_grad_sum, bias_grad_sum) =
                if batch_indices.len() >= LINEAR_SVC_PARALLEL_THRESHOLD {
                    batch_indices.par_iter().map(compute_gradient).reduce(
                        || (Array1::zeros(n_features), 0.0),
                        |mut acc, (w_grad, b_grad)| {
                            acc.0 = &acc.0 + &w_grad;
                            acc.1 += b_grad;
                            acc
                        },
                    )
                } else {
                    batch_indices.iter().map(compute_gradient).fold(
                        (Array1::zeros(n_features), 0.0),
                        |mut acc, (w_grad, b_grad)| {
                            acc.0 = &acc.0 + &w_grad;
                            acc.1 += b_grad;
                            acc
                        },
                    )
                };
            // Descent step on the hinge term, averaged over the batch.
            weights = &weights + &(weight_grad_sum * (self.learning_rate / batch_len));
            // Penalty step: L2 as multiplicative weight decay; L1 as a step
            // against the sign of each weight (subgradient of |w|).
            match self.penalty {
                RegularizationType::L2(lambda) => {
                    weights = &weights * (1.0 - self.learning_rate * lambda);
                }
                RegularizationType::L1(lambda) => {
                    let l1_grad = weights.mapv(|w| {
                        if w > 0.0 {
                            1.0
                        } else if w < 0.0 {
                            -1.0
                        } else {
                            0.0
                        }
                    });
                    weights = &weights - &(l1_grad * (self.learning_rate * lambda));
                }
            }
            // The intercept is never regularized, and only updated when enabled.
            if self.fit_intercept {
                bias += self.learning_rate * bias_grad_sum / batch_len;
            }
            // Abort immediately if any parameter diverged to NaN/infinity.
            Self::check_weights_validity(&weights, bias)?;
        }
        // Convergence metric: sqrt(mean squared weight change + squared bias
        // change) across the epoch just finished.
        let weight_diff = (&weights - &prev_weights)
            .iter()
            .map(|&x| x * x)
            .sum::<f64>()
            / n_features as f64;
        let bias_diff = if self.fit_intercept {
            (bias - prev_bias).powi(2)
        } else {
            0.0
        };
        let total_diff = (weight_diff + bias_diff).sqrt();
        // The full-dataset cost is expensive, so only refresh the progress
        // message every 10 epochs or on the epoch that converges.
        if n_iter % 10 == 0 || total_diff < self.tol {
            let current_cost = calculate_cost(x, &y_binary, &weights, bias, &self.penalty);
            progress_bar.set_message(format!("{:.6}", current_cost));
        }
        if total_diff < self.tol {
            break;
        }
        prev_weights.assign(&weights);
        prev_bias = bias;
    }
    let final_cost = calculate_cost(x, &y_binary, &weights, bias, &self.penalty);
    let convergence_status = if n_iter < self.max_iter {
        "Converged"
    } else {
        "Max iterations"
    };
    progress_bar.finish_with_message(format!(
        "{:.6} | {} | Iterations: {}",
        final_cost, convergence_status, n_iter
    ));
    // Human-readable training summary on stdout.
    println!(
        "\nLinear SVC training completed: {} samples, {} features, {} iterations, final cost: {:.6}",
        n_samples, n_features, n_iter, final_cost
    );
    self.weights = Some(weights);
    self.bias = Some(bias);
    self.n_iter = Some(n_iter);
    Ok(self)
}
/// Predicts a hard class label for every row of `x`: 1.0 when the decision
/// value is positive, 0.0 otherwise.
///
/// # Errors
/// `ModelError::NotFitted` when `fit` has not run yet, or any validation /
/// processing error propagated from `decision_function`.
pub fn predict<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<f64>, ModelError>
where
    S: Data<Elem = f64>,
{
    // Fail fast if the model was never trained.
    self.weights.as_ref().ok_or(ModelError::NotFitted)?;
    self.validate_input_data(x)?;
    let scores = self.decision_function(x)?;
    // Threshold the raw margins at zero to obtain {0, 1} labels.
    Ok(scores.mapv(|s| if s > 0.0 { 1.0 } else { 0.0 }))
}
/// Computes the raw signed score `x · w + b` for each row of `x`.
///
/// # Errors
/// `ModelError::NotFitted` when the model has no learned weights, an
/// input-validation error for bad `x`, or a `ProcessingError` when the
/// resulting scores are not finite.
pub fn decision_function<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<f64>, ModelError>
where
    S: Data<Elem = f64>,
{
    let Some(weights) = self.get_weights() else {
        return Err(ModelError::NotFitted);
    };
    self.validate_input_data(x)?;
    // A missing bias falls back to zero.
    let scores = x.dot(weights) + self.bias.unwrap_or(0.0);
    match scores.iter().all(|s| s.is_finite()) {
        true => Ok(scores),
        false => Err(ModelError::ProcessingError(
            "Decision function produced NaN or infinite values".to_string(),
        )),
    }
}
// Generates the crate's shared model persistence (save/load) methods —
// macro defined elsewhere in this crate; see its definition for the exact API.
model_save_and_load_methods!(LinearSVC);
}