use super::helper_function::{preliminary_check, validate_max_iterations, validate_tolerance};
pub use crate::KernelType;
use crate::error::ModelError;
use crate::{Deserialize, Serialize};
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::{Array1, Array2, ArrayBase, ArrayView1, ArrayViewMut0, Axis, Data, Ix1, Ix2};
use ndarray_rand::rand::random_range;
use rayon::prelude::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelBridge,
ParallelIterator,
};
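/// Minimum number of samples before Rayon-based parallel paths are taken (kernel matrix,
/// error cache, and prediction loops).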
const SVC_PARALLEL_THRESHOLD: usize = 100;
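/// Support Vector Classifier trained with Platt's SMO (Sequential Minimal Optimization) algorithm.
///
/// Labels must be `1.0` / `-1.0`. After fitting, only the support vectors, their labels,
/// the corresponding Lagrange multipliers (`alphas`), and the bias term are retained.
///
/// # Example
///
/// A minimal usage sketch; the import paths are assumptions and should be adapted to the
/// crate's actual module layout.
///
/// ```ignore
/// use ndarray::array;
///
/// // Four linearly separable points labeled by their first coordinate.
/// let x = array![[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]];
/// let y = array![-1.0, -1.0, 1.0, 1.0];
///
/// let mut model = SVC::default();
/// model.fit(&x, &y).unwrap();
/// let predictions = model.predict(&x).unwrap();
/// ```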
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SVC {
kernel: KernelType,
regularization_param: f64,
alphas: Option<Array1<f64>>,
support_vectors: Option<Array2<f64>>,
support_vector_labels: Option<Array1<f64>>,
bias: Option<f64>,
tol: f64,
max_iter: usize,
eps: f64,
n_iter: Option<usize>,
}
impl Default for SVC {
fn default() -> Self {
SVC {
kernel: KernelType::RBF { gamma: 0.1 },
regularization_param: 1.0,
alphas: None,
support_vectors: None,
support_vector_labels: None,
bias: None,
tol: 0.001,
max_iter: 1000,
eps: 1e-8,
n_iter: None,
}
}
}
impl SVC {
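/// Creates an unfitted `SVC` with the given kernel, regularization parameter (the SVM `C`),
/// tolerance, and maximum number of SMO iterations.
///
/// Returns an `InputValidationError` if the regularization parameter is not positive and
/// finite, or if the tolerance or iteration limit fails validation.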
pub fn new(
kernel: KernelType,
regularization_param: f64,
tol: f64,
max_iter: usize,
) -> Result<Self, ModelError> {
if regularization_param <= 0.0 || !regularization_param.is_finite() {
return Err(ModelError::InputValidationError(format!(
"Regularization parameter must be positive and finite, got {}",
regularization_param
)));
}
validate_tolerance(tol)?;
validate_max_iterations(max_iter)?;
Ok(SVC {
kernel,
regularization_param,
alphas: None,
support_vectors: None,
support_vector_labels: None,
bias: None,
tol,
max_iter,
eps: 1e-8,
n_iter: None,
})
}
get_field!(get_kernel, kernel, KernelType);
get_field!(get_regularization_parameter, regularization_param, f64);
get_field!(get_tolerance, tol, f64);
get_field!(get_max_iterations, max_iter, usize);
get_field!(get_epsilon, eps, f64);
get_field!(get_actual_iterations, n_iter, Option<usize>);
get_field_as_ref!(get_alphas, alphas, Option<&Array1<f64>>);
get_field_as_ref!(get_support_vectors, support_vectors, Option<&Array2<f64>>);
get_field_as_ref!(
get_support_vector_labels,
support_vector_labels,
Option<&Array1<f64>>
);
get_field!(get_bias, bias, Option<f64>);
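/// Evaluates the configured kernel for a pair of samples:
/// - Linear: `x1·x2`
/// - Poly: `(gamma * x1·x2 + coef0)^degree`
/// - RBF: `exp(-gamma * ||x1 - x2||²)`
/// - Sigmoid: `tanh(gamma * x1·x2 + coef0)`
/// - Cosine: `x1·x2 / (||x1|| * ||x2||)`, or `0.0` when either norm is (near) zero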
fn kernel_function(&self, x1: ArrayView1<f64>, x2: ArrayView1<f64>) -> f64 {
match self.kernel {
KernelType::Linear => {
x1.dot(&x2)
}
KernelType::Poly {
degree,
gamma,
coef0,
} => {
(gamma * x1.dot(&x2) + coef0).powf(degree as f64)
}
KernelType::RBF { gamma } => {
let diff = &x1 - &x2;
let squared_norm = diff.dot(&diff);
(-gamma * squared_norm).exp()
}
KernelType::Sigmoid { gamma, coef0 } => {
(gamma * x1.dot(&x2) + coef0).tanh()
}
KernelType::Cosine => {
let norm_product = (x1.dot(&x1) * x2.dot(&x2)).sqrt();
if norm_product <= f64::EPSILON {
0.0
} else {
x1.dot(&x2) / norm_product
}
}
}
}
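/// Builds the symmetric Gram matrix `K[i, j] = kernel(x_i, x_j)` for all training samples.
/// Only the upper triangle is computed (in parallel for large inputs) and then mirrored.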
fn compute_kernel_matrix<S>(&self, x: &ArrayBase<S, Ix2>) -> Array2<f64>
where
S: Data<Elem = f64> + Send + Sync,
{
let n_samples = x.nrows();
let mut kernel_matrix = Array2::<f64>::zeros((n_samples, n_samples));
let pairs: Vec<(usize, usize)> = (0..n_samples)
.flat_map(|i| (i..n_samples).map(move |j| (i, j)))
.collect();
let kernel_values: Vec<((usize, usize), f64)> = if n_samples >= SVC_PARALLEL_THRESHOLD {
pairs
.par_iter()
.map(|&(i, j)| {
let k_val = self.kernel_function(x.row(i), x.row(j));
((i, j), k_val)
})
.collect()
} else {
pairs
.iter()
.map(|&(i, j)| {
let k_val = self.kernel_function(x.row(i), x.row(j));
((i, j), k_val)
})
.collect()
};
for ((i, j), val) in kernel_values {
kernel_matrix[[i, j]] = val;
if i != j {
kernel_matrix[[j, i]] = val;
}
}
kernel_matrix
}
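/// Computes the quadratic part of the dual objective,
/// Σ_i Σ_j α_i α_j y_i y_j K(x_i, x_j), restricted to the support vectors.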
fn compute_quadratic_term(
support_indices: &[usize],
support_vector_alphas: &Array1<f64>,
support_vector_labels: &Array1<f64>,
kernel_matrix: &Array2<f64>,
use_parallel: bool,
) -> f64 {
let compute_fn = |(i, &idx_i): (usize, &usize)| {
support_indices
.iter()
.enumerate()
.map(|(j, &idx_j)| {
let kernel_val = kernel_matrix[[idx_i, idx_j]];
support_vector_alphas[i]
* support_vector_alphas[j]
* support_vector_labels[i]
* support_vector_labels[j]
* kernel_val
})
.sum::<f64>()
};
if use_parallel {
support_indices.par_iter().enumerate().map(compute_fn).sum()
} else {
support_indices.iter().enumerate().map(compute_fn).sum()
}
}
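/// Evaluates the decision function f(x) = Σ_j α_j y_j K(x, sv_j) + bias for a single sample.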
fn compute_decision_value<F>(
x_row: ArrayView1<f64>,
support_vectors: &Array2<f64>,
alphas: &Array1<f64>,
support_vector_labels: &Array1<f64>,
bias: f64,
kernel_fn: F,
) -> f64
where
F: Fn(ArrayView1<f64>, ArrayView1<f64>) -> f64,
{
(0..support_vectors.nrows())
.map(|j| {
let kernel_val = kernel_fn(x_row, support_vectors.row(j));
alphas[j] * support_vector_labels[j] * kernel_val
})
.sum::<f64>()
+ bias
}
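/// Fits the classifier using Platt's SMO algorithm.
///
/// Validates the inputs (labels must be `1.0` or `-1.0`), precomputes the kernel matrix,
/// then alternates between full passes and passes over the non-bound alphas until no
/// multiplier changes or `max_iter` is reached. The support vectors, their alphas and
/// labels, the bias, and the iteration count are stored on success.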
pub fn fit<S>(
&mut self,
x: &ArrayBase<S, Ix2>,
y: &ArrayBase<S, Ix1>,
) -> Result<&mut Self, ModelError>
where
S: Data<Elem = f64> + Send + Sync,
{
preliminary_check(x, Some(y))?;
let (n_samples, n_features) = (x.nrows(), x.ncols());
let y_vec: Vec<f64> = y.to_vec();
let label_check = if n_samples >= SVC_PARALLEL_THRESHOLD {
y_vec.par_iter().all(|&yi| yi == 1.0 || yi == -1.0)
} else {
y_vec.iter().all(|&yi| yi == 1.0 || yi == -1.0)
};
if !label_check {
return Err(ModelError::InputValidationError(
"All labels must be either 1.0 or -1.0".to_string(),
));
}
let mut alphas = Array1::<f64>::zeros(n_samples);
let mut b = 0.0;
let kernel_matrix = self.compute_kernel_matrix(x);
let kernel_vec: Vec<f64> = kernel_matrix.iter().cloned().collect();
let kernel_invalid = if n_samples >= SVC_PARALLEL_THRESHOLD {
kernel_vec.par_iter().any(|&val| !val.is_finite())
} else {
kernel_vec.iter().any(|&val| !val.is_finite())
};
if kernel_invalid {
return Err(ModelError::ProcessingError(
"Kernel matrix contains invalid values - check kernel parameters".to_string(),
));
}
let error_cache: Vec<f64> = if n_samples >= SVC_PARALLEL_THRESHOLD {
(0..n_samples)
.into_par_iter()
.map(|i| self.decision_function_internal(i, &alphas, &kernel_matrix, y, b))
.collect()
} else {
(0..n_samples)
.map(|i| self.decision_function_internal(i, &alphas, &kernel_matrix, y, b))
.collect()
};
let mut error_cache = Array1::from(error_cache);
let mut num_changed_alphas;
let mut examine_all = true;
let mut iteration_count = 0;
let progress_bar = ProgressBar::new(self.max_iter as u64);
progress_bar.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} | {msg}")
.expect("Failed to set progress bar template")
.progress_chars("█▓░"),
);
progress_bar.set_message("Alpha changes: 0 | Examine: All");
loop {
if iteration_count >= self.max_iter {
progress_bar.finish_with_message(
"Warning: Max iterations reached without full convergence",
);
eprintln!(
"Warning: SVC reached maximum iterations ({}) without full convergence",
self.max_iter
);
break;
}
num_changed_alphas = 0;
iteration_count += 1;
progress_bar.inc(1);
let sample_range: Vec<usize> = if examine_all {
(0..n_samples).collect()
} else {
(0..n_samples)
.filter(|&i| alphas[i] > 0.0 && alphas[i] < self.regularization_param)
.collect()
};
for &i in &sample_range {
num_changed_alphas += self.examine_example(
i,
&mut alphas,
&kernel_matrix,
y,
&mut b,
&mut error_cache,
);
}
let examine_mode = if examine_all { "All" } else { "Non-bound" };
progress_bar.set_message(format!(
"Alpha changes: {} | Examine: {}",
num_changed_alphas, examine_mode
));
if examine_all {
examine_all = false;
} else if num_changed_alphas == 0 {
examine_all = true;
}
if !examine_all && num_changed_alphas == 0 {
break;
}
}
progress_bar.finish_with_message(format!("Converged at iteration {}", iteration_count));
let support_indices: Vec<usize> = if n_samples >= SVC_PARALLEL_THRESHOLD {
(0..n_samples)
.into_par_iter()
.filter_map(|i| if alphas[i] > self.eps { Some(i) } else { None })
.collect()
} else {
(0..n_samples)
.filter_map(|i| if alphas[i] > self.eps { Some(i) } else { None })
.collect()
};
if support_indices.is_empty() {
return Err(ModelError::ProcessingError(
"No support vectors found - model failed to converge. Try adjusting parameters."
.to_string(),
));
}
if !b.is_finite() {
return Err(ModelError::ProcessingError(
"Bias term is invalid - numerical instability detected".to_string(),
));
}
let n_support_vectors = support_indices.len();
let mut support_vectors = Array2::<f64>::zeros((n_support_vectors, n_features));
let mut support_vector_labels = Array1::<f64>::zeros(n_support_vectors);
let mut support_vector_alphas = Array1::<f64>::zeros(n_support_vectors);
for (i, &idx) in support_indices.iter().enumerate() {
support_vectors.row_mut(i).assign(&x.row(idx));
support_vector_labels[i] = y[idx];
support_vector_alphas[i] = alphas[idx];
}
let alphas_vec: Vec<f64> = support_vector_alphas.to_vec();
let alphas_invalid = if support_indices.len() >= SVC_PARALLEL_THRESHOLD {
alphas_vec.par_iter().any(|&val| !val.is_finite())
} else {
alphas_vec.iter().any(|&val| !val.is_finite())
};
if alphas_invalid {
return Err(ModelError::ProcessingError(
"Support vector alphas contain invalid values".to_string(),
));
}
let cost = {
let mut dual_objective = 0.0;
let quadratic_term: f64 = Self::compute_quadratic_term(
&support_indices,
&support_vector_alphas,
&support_vector_labels,
&kernel_matrix,
support_indices.len() >= SVC_PARALLEL_THRESHOLD,
);
dual_objective += 0.5 * quadratic_term;
let linear_term: f64 = support_vector_alphas.sum();
dual_objective -= linear_term;
-dual_objective
};
println!(
"\nSVC training completed: {} samples, {} features, {} iterations, {} support vectors, final cost: {:.6}",
n_samples, n_features, iteration_count, n_support_vectors, cost
);
self.alphas = Some(support_vector_alphas);
self.support_vectors = Some(support_vectors);
self.support_vector_labels = Some(support_vector_labels);
self.bias = Some(b);
self.n_iter = Some(iteration_count);
Ok(self)
}
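/// Predicts class labels (`1.0` or `-1.0`) as the sign of the decision function.
///
/// Returns `ModelError::NotFitted` if the model has not been trained, and an
/// `InputValidationError` if the feature count differs from the training data.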
pub fn predict<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<f64>, ModelError>
where
S: Data<Elem = f64> + Send + Sync,
{
let (support_vectors, support_vector_labels, alphas, bias) = match (
&self.support_vectors,
&self.support_vector_labels,
&self.alphas,
self.bias,
) {
(Some(sv), Some(svl), Some(a), Some(b)) => (sv, svl, a, b),
_ => return Err(ModelError::NotFitted),
};
preliminary_check(x, None)?;
let n_features = x.ncols();
if n_features != support_vectors.ncols() {
return Err(ModelError::InputValidationError(format!(
"Input has {} features but model was trained on {} features",
n_features,
support_vectors.ncols()
)));
}
let n_samples = x.nrows();
let compute_prediction = |i: usize| {
let decision_value = Self::compute_decision_value(
x.row(i),
support_vectors,
alphas,
support_vector_labels,
bias,
|x1, x2| self.kernel_function(x1, x2),
);
if !decision_value.is_finite() {
Err(ModelError::ProcessingError(
"Decision function produced invalid value during prediction".to_string(),
))
} else {
Ok(if decision_value >= 0.0 { 1.0 } else { -1.0 })
}
};
let prediction_results: Vec<Result<f64, ModelError>> =
if n_samples >= SVC_PARALLEL_THRESHOLD {
(0..n_samples)
.into_par_iter()
.map(compute_prediction)
.collect()
} else {
(0..n_samples).map(compute_prediction).collect()
};
let mut predictions = Vec::with_capacity(n_samples);
for result in prediction_results {
match result {
Ok(pred) => predictions.push(pred),
Err(e) => return Err(e),
}
}
Ok(Array1::from(predictions))
}
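/// Returns the raw (signed) decision values f(x) for each input sample, without
/// thresholding them into class labels.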
pub fn decision_function<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array1<f64>, ModelError>
where
S: Data<Elem = f64> + Send + Sync,
{
let (support_vectors, support_vector_labels, alphas, bias) = match (
&self.support_vectors,
&self.support_vector_labels,
&self.alphas,
self.bias,
) {
(Some(sv), Some(svl), Some(a), Some(b)) => (sv, svl, a, b),
_ => return Err(ModelError::NotFitted),
};
preliminary_check(x, None)?;
let n_features = x.ncols();
if n_features != support_vectors.ncols() {
return Err(ModelError::InputValidationError(format!(
"Input has {} features but model was trained on {} features",
n_features,
support_vectors.ncols()
)));
}
let n_samples = x.nrows();
let mut decision_values = Array1::<f64>::zeros(n_samples);
let compute_fn = |(i, mut val): (usize, ArrayViewMut0<f64>)| {
let decision_val = Self::compute_decision_value(
x.row(i),
support_vectors,
alphas,
support_vector_labels,
bias,
|x1, x2| self.kernel_function(x1, x2),
);
val.fill(decision_val);
};
if n_samples >= SVC_PARALLEL_THRESHOLD {
decision_values
.axis_iter_mut(Axis(0))
.into_par_iter()
.enumerate()
.for_each(compute_fn);
} else {
decision_values
.axis_iter_mut(Axis(0))
.enumerate()
.for_each(compute_fn);
}
Ok(decision_values)
}
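/// SMO outer-loop step for sample `i2`: if it violates the KKT conditions (within `tol`),
/// tries the second-choice heuristic, then sweeps the non-bound alphas, then all alphas
/// (both sweeps start at a random index). Returns 1 if a pair of alphas was updated, else 0.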
fn examine_example<S>(
&self,
i2: usize,
alphas: &mut Array1<f64>,
kernel_matrix: &Array2<f64>,
y: &ArrayBase<S, Ix1>,
b: &mut f64,
error_cache: &mut Array1<f64>,
) -> usize
where
S: Data<Elem = f64> + Send + Sync,
{
let y2 = y[i2];
let alpha2 = alphas[i2];
let e2 = error_cache[i2];
let r2 = e2 * y2;
if (r2 < -self.tol && alpha2 < self.regularization_param) || (r2 > self.tol && alpha2 > 0.0)
{
let mut i1 = self.select_second_alpha(i2, e2, alphas, error_cache);
if i1 != i2 && self.take_step(i1, i2, alphas, kernel_matrix, y, b, error_cache) {
return 1;
}
let n_samples = alphas.len();
let mut start = random_range(0..n_samples);
for _ in 0..n_samples {
i1 = start;
if alphas[i1] > 0.0
&& alphas[i1] < self.regularization_param
&& i1 != i2
&& self.take_step(i1, i2, alphas, kernel_matrix, y, b, error_cache)
{
return 1;
}
start = (start + 1) % n_samples;
}
start = random_range(0..n_samples);
for _ in 0..n_samples {
i1 = start;
if i1 != i2 && self.take_step(i1, i2, alphas, kernel_matrix, y, b, error_cache) {
return 1;
}
start = (start + 1) % n_samples;
}
}
0
}
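/// Second-choice heuristic: among the non-bound alphas, picks the index whose cached
/// error maximizes |E1 - E2|, falling back to `i2` itself when no candidate is found.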
fn select_second_alpha(
&self,
i2: usize,
e2: f64,
alphas: &Array1<f64>,
error_cache: &Array1<f64>,
) -> usize {
let n_samples = alphas.len();
let result = if n_samples >= SVC_PARALLEL_THRESHOLD {
(0..n_samples)
.into_par_iter()
.filter(|&i| alphas[i] > 0.0 && alphas[i] < self.regularization_param)
.map(|i| {
let e1 = error_cache[i];
let delta_e = (e1 - e2).abs();
(i, delta_e)
})
.reduce(
|| (i2, 0.0),
|a, b| if b.1 > a.1 { b } else { a },
)
} else {
(0..n_samples)
.filter(|&i| alphas[i] > 0.0 && alphas[i] < self.regularization_param)
.map(|i| {
let e1 = error_cache[i];
let delta_e = (e1 - e2).abs();
(i, delta_e)
})
.fold((i2, 0.0), |a, b| if b.1 > a.1 { b } else { a })
};
result.0
}
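/// Jointly optimizes the alpha pair (`i1`, `i2`): computes the feasible segment [L, H],
/// updates α2 analytically when η = K11 + K22 - 2·K12 > 0 (otherwise compares the
/// objective at the segment endpoints), clips it, derives α1, then updates the bias and
/// refreshes the error cache. Returns `true` if the alphas changed.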
fn take_step<S>(
&self,
i1: usize,
i2: usize,
alphas: &mut Array1<f64>,
kernel_matrix: &Array2<f64>,
y: &ArrayBase<S, Ix1>,
b: &mut f64,
error_cache: &mut Array1<f64>,
) -> bool
where
S: Data<Elem = f64> + Send + Sync,
{
if i1 == i2 {
return false;
}
let alpha1_old = alphas[i1];
let alpha2_old = alphas[i2];
let y1 = y[i1];
let y2 = y[i2];
let e1 = error_cache[i1];
let e2 = error_cache[i2];
let s = y1 * y2;
let (l, h) = if y1 != y2 {
(
0.0f64.max(alpha2_old - alpha1_old),
self.regularization_param
.min(self.regularization_param + alpha2_old - alpha1_old),
)
} else {
(
0.0f64.max(alpha1_old + alpha2_old - self.regularization_param),
self.regularization_param.min(alpha1_old + alpha2_old),
)
};
if l == h {
return false;
}
let k11 = kernel_matrix[[i1, i1]];
let k12 = kernel_matrix[[i1, i2]];
let k22 = kernel_matrix[[i2, i2]];
let eta = k11 + k22 - 2.0 * k12;
let mut alpha2_new;
if eta > 0.0 {
alpha2_new = alpha2_old + y2 * (e1 - e2) / eta;
if alpha2_new < l {
alpha2_new = l;
} else if alpha2_new > h {
alpha2_new = h;
}
} else {
// Degenerate case (η ≤ 0): compare the objective at the two segment endpoints.
// Under the f(x) = Σ_j α_j y_j K(x_j, x) + b convention, e_i - b = Σ_j α_j y_j K(x_j, x_i) - y_i.
let f1 = y1 * (e1 - *b) - alpha1_old * k11 - s * alpha2_old * k12;
let f2 = y2 * (e2 - *b) - s * alpha1_old * k12 - alpha2_old * k22;
let l1 = alpha1_old + s * (alpha2_old - l);
let h1 = alpha1_old + s * (alpha2_old - h);
let obj_l =
l1 * f1 + l * f2 + 0.5 * l1 * l1 * k11 + 0.5 * l * l * k22 + s * l * l1 * k12;
let obj_h =
h1 * f1 + h * f2 + 0.5 * h1 * h1 * k11 + 0.5 * h * h * k22 + s * h * h1 * k12;
if obj_l < obj_h - self.eps {
alpha2_new = l;
} else if obj_l > obj_h + self.eps {
alpha2_new = h;
} else {
alpha2_new = alpha2_old;
}
}
if (alpha2_new - alpha2_old).abs() < self.eps * (alpha2_new + alpha2_old + self.eps) {
return false;
}
let alpha1_new = alpha1_old + s * (alpha2_old - alpha2_new);
// With the decision function f(x) = Σ_j α_j y_j K(x_j, x) + b used elsewhere in this file,
// the threshold must shift by -(E + y·Δα·K) so a non-bound example lands on its margin
// (f(x_i) = y_i).
let b1 =
*b - e1 - y1 * (alpha1_new - alpha1_old) * k11 - y2 * (alpha2_new - alpha2_old) * k12;
let b2 =
*b - e2 - y1 * (alpha1_new - alpha1_old) * k12 - y2 * (alpha2_new - alpha2_old) * k22;
if alpha1_new > 0.0 && alpha1_new < self.regularization_param {
*b = b1;
} else if alpha2_new > 0.0 && alpha2_new < self.regularization_param {
*b = b2;
} else {
*b = (b1 + b2) / 2.0;
}
alphas[i1] = alpha1_new;
alphas[i2] = alpha2_new;
self.update_error_cache(alphas, kernel_matrix, y, *b, error_cache);
true
}
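/// Recomputes the cached errors E_i = f(x_i) - y_i for every training sample.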
fn update_error_cache<S>(
&self,
alphas: &Array1<f64>,
kernel_matrix: &Array2<f64>,
y: &ArrayBase<S, Ix1>,
b: f64,
error_cache: &mut Array1<f64>,
) where
S: Data<Elem = f64> + Send + Sync,
{
let n_samples = alphas.len();
if n_samples >= SVC_PARALLEL_THRESHOLD {
error_cache
.indexed_iter_mut()
.par_bridge()
.for_each(|(i, error)| {
*error = self.decision_function_internal(i, alphas, kernel_matrix, y, b);
});
} else {
error_cache.indexed_iter_mut().for_each(|(i, error)| {
*error = self.decision_function_internal(i, alphas, kernel_matrix, y, b);
});
}
}
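/// Computes E_i = Σ_j α_j y_j K(x_i, x_j) + b - y_i from the precomputed kernel matrix,
/// skipping samples whose alpha is zero.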
fn decision_function_internal<S>(
&self,
i: usize,
alphas: &Array1<f64>,
kernel_matrix: &Array2<f64>,
y: &ArrayBase<S, Ix1>,
b: f64,
) -> f64
where
S: Data<Elem = f64> + Send + Sync,
{
let n_samples = alphas.len();
let sum: f64 = if n_samples >= SVC_PARALLEL_THRESHOLD {
let indices: Vec<usize> = (0..n_samples).collect();
indices
.par_iter()
.filter(|&&j| alphas[j] > 0.0)
.map(|&j| alphas[j] * y[j] * kernel_matrix[[i, j]])
.sum()
} else {
(0..n_samples)
.filter(|&j| alphas[j] > 0.0)
.map(|j| alphas[j] * y[j] * kernel_matrix[[i, j]])
.sum()
};
sum - y[i] + b
}
model_save_and_load_methods!(SVC);
}