use linfa::Float;
use ndarray::{ArrayBase, Data, Ix1};
use std::fmt;
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
mod classification;
pub mod error;
pub mod hyperparams;
mod permutable_kernel;
mod regression;
pub mod solver_smo;
pub use error::{Result, SvmError};
pub use hyperparams::{SvmParams, SvmValidParams};
use linfa_kernel::KernelMethod;
pub use solver_smo::{SeparatingHyperplane, SolverParams};
use std::ops::Mul;
/// Reason why fitting stopped, stored on [`Svm`] and reflected in its
/// `Display` output.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExitReason {
/// Displayed as "Exited after N iterations ...".
// NOTE(review): presumably the solver's stopping threshold was met;
// confirm against `solver_smo`.
ReachedThreshold,
/// Displayed as "Reached maximal iterations N ...".
ReachedIterations,
}
/// A fitted support vector machine model.
///
/// `F` is the floating point type of the model and `T` is a marker for the
/// target type; `T` is only carried through `PhantomData` and never stored.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct Svm<F: Float, T> {
/// Per-sample coefficients. Entries whose magnitude exceeds
/// `100 * F::epsilon()` are counted as support vectors by `nsupport` and are
/// the ones used by `weighted_sum`.
pub alpha: Vec<F>,
// NOTE(review): presumably the offset of the decision function; it is not
// read anywhere in this file, so confirm against the predictors.
pub rho: F,
// NOTE(review): optional scalar whose meaning is not visible in this file.
r: Option<F>,
/// Why fitting stopped; selects the message printed by `Display`.
exit_reason: ExitReason,
/// Iteration count reported by `Display`.
iterations: usize,
/// Objective value reported by `Display`.
obj: F,
/// Kernel used by `weighted_sum` to compare stored vectors with a sample.
#[cfg_attr(
feature = "serde",
serde(bound(
serialize = "KernelMethod<F>: Serialize",
deserialize = "KernelMethod<F>: Deserialize<'de>"
))
)]
kernel_method: KernelMethod<F>,
/// Either explicit linear weights or a set of stored vectors that are
/// combined through the kernel (see `weighted_sum`).
sep_hyperplane: SeparatingHyperplane<F>,
// NOTE(review): presumably coefficients for probability calibration; not
// used in this file, so confirm against the classification module.
probability_coeffs: Option<(F, F)>,
/// Zero-sized marker tying the model to its target type `T`.
phantom: PhantomData<T>,
}
impl<F: Float, T> Svm<F, T> {
    /// Returns the number of support vectors.
    ///
    /// A coefficient in `alpha` counts as a support vector when its absolute
    /// value is strictly greater than `100 * F::epsilon()`.
    pub fn nsupport(&self) -> usize {
        let threshold = F::cast(100.) * F::epsilon();

        self.alpha
            .iter()
            .filter(move |coef| coef.abs() > threshold)
            .count()
    }

    /// Converts the model to another target marker type `S`.
    ///
    /// Every field is moved over unchanged; only the `PhantomData` marker is
    /// replaced.
    pub(crate) fn with_phantom<S>(self) -> Svm<F, S> {
        let Svm {
            alpha,
            rho,
            r,
            exit_reason,
            iterations,
            obj,
            kernel_method,
            sep_hyperplane,
            probability_coeffs,
            phantom: _,
        } = self;

        Svm {
            alpha,
            rho,
            r,
            exit_reason,
            iterations,
            obj,
            kernel_method,
            sep_hyperplane,
            probability_coeffs,
            phantom: PhantomData,
        }
    }

    /// Computes the weighted sum of a sample with the separating hyperplane.
    ///
    /// * `Linear`: the sum of the element-wise product of the stored weights
    ///   and `sample`.
    /// * `WeightedCombination`: each row of the stored vectors is paired, in
    ///   order, with the next `alpha` entry whose magnitude exceeds
    ///   `100 * F::epsilon()`; the kernel value between the row and `sample`
    ///   is scaled by that coefficient and all terms are summed. Pairing
    ///   stops as soon as either side runs out.
    ///
    /// `rho` is not applied here.
    pub fn weighted_sum<D: Data<Elem = F>>(&self, sample: &ArrayBase<D, Ix1>) -> F {
        match &self.sep_hyperplane {
            SeparatingHyperplane::Linear(weights) => weights.mul(sample).sum(),
            SeparatingHyperplane::WeightedCombination(support_vectors) => {
                let threshold = F::cast(100.) * F::epsilon();
                let active_coefs = self
                    .alpha
                    .iter()
                    .filter(move |coef| coef.abs() > threshold);

                support_vectors
                    .outer_iter()
                    .zip(active_coefs)
                    .map(|(vector, &coef)| {
                        self.kernel_method.distance(vector, sample.view()) * coef
                    })
                    .sum()
            }
        }
    }
}
/// Prints a one-line summary of the fit: how it ended, the iteration count,
/// the objective value and the number of support vectors.
impl<F: Float, T> fmt::Display for Svm<F, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let iterations = self.iterations;
        let objective = &self.obj;
        let n_support = self.nsupport();

        match &self.exit_reason {
            ExitReason::ReachedThreshold => write!(
                f,
                "Exited after {} iterations with obj = {} and {} support vectors",
                iterations, objective, n_support
            ),
            ExitReason::ReachedIterations => write!(
                f,
                "Reached maximal iterations {} with obj = {} and {} support vectors",
                iterations, objective, n_support
            ),
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{Svm, SvmParams, SvmValidParams};
    use linfa::prelude::*;

    /// The model and both parameter types must be `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}

        has_autotraits::<Svm<f64, usize>>();
        has_autotraits::<SvmParams<f64, usize>>();
        has_autotraits::<SvmValidParams<f64, usize>>();
    }

    /// Fits a Gaussian-kernel classifier on each fold of the wine quality data
    /// and checks that the accuracy averaged over the folds is at least 0.5.
    #[test]
    fn test_iter_folding_for_classification() {
        const FOLDS: usize = 4;

        // Binary target: quality strictly above 6.
        let mut wine = linfa_datasets::winequality().map_targets(|quality| *quality > 6);

        let svm_params = Svm::<_, bool>::params()
            .pos_neg_weights(7., 0.6)
            .gaussian_kernel(80.0);

        let accuracy_total: f32 = wine
            .iter_fold(FOLDS, |train| svm_params.fit(train).unwrap())
            .map(|(model, valid)| {
                let predicted = model.predict(valid.view());
                let matrix = predicted.confusion_matrix(&valid).unwrap();
                matrix.accuracy()
            })
            .sum();

        let avg_acc = accuracy_total / FOLDS as f32;
        assert!(avg_acc >= 0.5)
    }
}