use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::marker::PhantomData;
use num::Bounded;
use rand::seq::SliceRandom;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{PredictorBorrow, SupervisedEstimatorBorrow};
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::rand_custom::get_rng_impl;
use crate::svm::Kernel;
/// Hyper-parameters read by the optimizer while `SVC::fit` runs.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
/// Number of passes the optimizer makes over the shuffled training rows.
pub epoch: usize,
/// Magnitude of the box bound on each coefficient: `[0, c]` for positive
/// samples and `[-c, 0]` otherwise.
pub c: TX,
/// Threshold on the gradient gap (`gmax - gmin`); an SMO step reports further
/// progress only while the gap exceeds it.
pub tol: TX,
/// Kernel function; `SVC::fit` returns an error while this is `None`.
/// Excluded from (de)serialization on `wasm32` when the `serde` feature is on.
#[cfg_attr(
all(feature = "serde", target_arch = "wasm32"),
serde(skip_serializing, skip_deserializing)
)]
pub kernel: Option<Box<dyn Kernel>>,
/// Zero-sized marker tying the parameters to the matrix and label types.
m: PhantomData<(X, Y, TY)>,
/// Seed handed to `get_rng_impl` every time the row order is shuffled.
seed: Option<u64>,
}
/// Binary support vector classifier produced by `SVC::fit`.
///
/// Every learned field is `None` on an instance created through `new()`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
#[cfg_attr(
feature = "serde",
serde(bound(
serialize = "TX: Serialize, TY: Serialize, X: Serialize, Y: Serialize",
deserialize = "TX: Deserialize<'de>, TY: Deserialize<'de>, X: Deserialize<'de>, Y: Deserialize<'de>",
))
)]
pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
/// Distinct label values obtained from the training labels via `y.unique()`.
classes: Option<Vec<TY>>,
/// Feature rows of the support vectors retained by the optimizer.
instances: Option<Vec<Vec<TX>>>,
/// Borrowed training parameters; the kernel is read from here at prediction
/// time. Skipped by serde, so it is absent after deserialization.
#[cfg_attr(feature = "serde", serde(skip))]
parameters: Option<&'a SVCParameters<TX, TY, X, Y>>,
/// Coefficient (`alpha`) of each support vector, in the same order as `instances`.
w: Option<Vec<TX>>,
/// Bias term that every decision value starts from.
b: Option<TX>,
/// Zero-sized marker for the matrix and label container types.
phantomdata: PhantomData<(X, Y)>,
}
/// One candidate support vector tracked by the optimizer.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct SupportVector<TX: Number + RealNumber> {
/// Row index of the sample in the training matrix; also used as the cache key.
index: usize,
/// Copy of the sample's feature values.
x: Vec<TX>,
/// Current coefficient; starts at zero and stays within `[cmin, cmax]`.
alpha: f64,
/// Current gradient value for this sample.
grad: f64,
/// Lower bound for `alpha` (`0` for positive samples, `-c` otherwise).
cmin: f64,
/// Upper bound for `alpha` (`c` for positive samples, `0` otherwise).
cmax: f64,
/// Kernel value of the sample with itself.
k: f64,
}
/// Memo table of kernel values keyed by a pair of training-row indices.
struct Cache<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
/// Kernel value stored for each ordered `(row index, row index)` pair.
data: HashMap<(usize, usize), f64>,
/// Zero-sized marker so the cache carries the same type parameters as the optimizer.
phantom: PhantomData<(X, Y, TY, TX)>,
}
/// Working state of the training procedure for a single `SVC::fit` call.
struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
/// Training samples, one per row.
x: &'a X,
/// Training labels, one per row of `x`.
y: &'a Y,
/// Hyper-parameters (kernel, `c`, `tol`, `epoch`, seed).
parameters: &'a SVCParameters<TX, TY, X, Y>,
/// Position in `sv` of the entry that currently defines `gmin`.
svmin: usize,
/// Position in `sv` of the entry that currently defines `gmax`.
svmax: usize,
/// Smallest gradient among entries whose `alpha` is above `cmin`.
gmin: TX,
/// Largest gradient among entries whose `alpha` is below `cmax`.
gmax: TX,
/// Replacement curvature used whenever the computed curvature is not positive.
tau: TX,
/// Current set of candidate support vectors.
sv: Vec<SupportVector<TX>>,
/// When `true`, `gmin`/`gmax`/`svmin`/`svmax` are stale and must be recomputed.
recalculate_minmax_grad: bool,
}
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    SVCParameters<TX, TY, X, Y>
{
    /// Returns the parameters with the number of training epochs replaced.
    pub fn with_epoch(self, epoch: usize) -> Self {
        Self { epoch, ..self }
    }

    /// Returns the parameters with the coefficient bound `c` replaced.
    pub fn with_c(self, c: TX) -> Self {
        Self { c, ..self }
    }

    /// Returns the parameters with the optimizer tolerance replaced.
    pub fn with_tol(self, tol: TX) -> Self {
        Self { tol, ..self }
    }

    /// Returns the parameters with `kernel` boxed and installed as the kernel function.
    pub fn with_kernel<K: Kernel + 'static>(self, kernel: K) -> Self {
        Self {
            kernel: Some(Box::new(kernel)),
            ..self
        }
    }

    /// Returns the parameters with the random seed replaced (`None` clears it).
    pub fn with_seed(self, seed: Option<u64>) -> Self {
        Self { seed, ..self }
    }
}
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Default
    for SVCParameters<TX, TY, X, Y>
{
    /// Two epochs, `c = 1`, `tol = 1e-3`, no kernel and no seed.
    fn default() -> Self {
        let c = TX::one();
        let tol = TX::from_f64(1e-3).unwrap();
        Self {
            epoch: 2,
            c,
            tol,
            kernel: None,
            m: PhantomData,
            seed: None,
        }
    }
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
{
fn new() -> Self {
Self {
classes: Option::None,
instances: Option::None,
parameters: Option::None,
w: Option::None,
b: Option::None,
phantomdata: PhantomData,
}
}
fn fit(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Result<Self, Failed> {
SVC::fit(x, y, parameters)
}
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    PredictorBorrow<'a, X, TX> for SVC<'a, TX, TY, X, Y>
{
    /// Predicts a class label for every row of `x` via the inherent `predict`.
    ///
    /// # Errors
    /// Whatever error the inherent `predict` returns is passed through
    /// unchanged rather than being turned into a panic.
    fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
        // The inherent method already has the right return type; unwrapping
        // and re-wrapping it would only convert an `Err` into a panic.
        self.predict(x)
    }
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array1<TY> + 'a>
SVC<'a, TX, TY, X, Y>
{
/// Trains a binary classifier on `x` (one sample per row) and labels `y`.
///
/// # Errors
/// * `ParametersError` when no kernel has been set on `parameters`.
/// * A fit error when the number of rows of `x` differs from the length of `y`.
/// * A fit error when `y` does not contain exactly two distinct values.
/// * `ParametersError` when any label is not `1` or `-1`; this includes labels
///   that cannot be represented as an `i32`.
pub fn fit(
    x: &'a X,
    y: &'a Y,
    parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
    let (n, _) = x.shape();
    if parameters.kernel.is_none() {
        return Err(Failed::because(
            FailedError::ParametersError,
            "kernel should be defined at this point, please use `with_kernel()`",
        ));
    }
    if n != y.shape() {
        return Err(Failed::fit(
            "Number of rows of X doesn\'t match number of rows of Y",
        ));
    }
    let classes = y.unique();
    if classes.len() != 2 {
        return Err(Failed::fit(&format!(
            "Incorrect number of classes: {}",
            classes.len()
        )));
    }
    // A label with no `i32` representation is simply an invalid label; report
    // it through the same error instead of panicking on `unwrap`.
    let has_invalid_label = y
        .iterator(0)
        .any(|label| !matches!(label.to_i32(), Some(-1) | Some(1)));
    if has_invalid_label {
        return Err(Failed::because(
            FailedError::ParametersError,
            "Class labels must be 1 or -1",
        ));
    }
    let optimizer: Optimizer<'_, TX, TY, X, Y> = Optimizer::new(x, y, parameters);
    let (support_vectors, weight, b) = optimizer.optimize();
    Ok(SVC {
        classes: Some(classes),
        instances: Some(support_vectors),
        parameters: Some(parameters),
        w: Some(weight),
        b: Some(b),
        phantomdata: PhantomData,
    })
}
/// Predicts a class label for every row of `x`.
///
/// A strictly positive decision value maps to the second stored class, any
/// other value to the first; labels are returned converted to `TX`.
///
/// # Panics
/// Panics when the model holds no classes (it was never fitted) and `x` has
/// at least one row.
pub fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
    let mut y_hat: Vec<TX> = self.decision_function(x)?;
    for score in y_hat.iter_mut() {
        let classes = self.classes.as_ref().unwrap();
        let label = if *score > TX::zero() {
            classes[1]
        } else {
            classes[0]
        };
        *score = TX::from(label).unwrap();
    }
    Ok(y_hat)
}
/// Evaluates the raw decision value for every row of `x`.
///
/// Returns one value per row, in row order.
///
/// # Panics
/// Panics when the model was not fitted (or lost its parameters through
/// deserialization) and `x` has at least one row.
pub fn decision_function(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
    let (n, _) = x.shape();
    let y_hat = (0..n)
        .map(|i| {
            // Copy the whole row. The previous code sized the copy with the
            // number of rows, which is unrelated to the length of a row.
            let row: Vec<TX> = x.get_row(i).iterator(0).copied().collect();
            self.predict_for_row(row)
        })
        .collect();
    Ok(y_hat)
}
/// Computes the decision value for one sample: the bias plus, for every
/// support vector, its coefficient times the kernel of the sample with it.
///
/// # Panics
/// Panics when the bias is missing, and, once at least one support vector is
/// visited, when the coefficients, parameters or kernel are missing or the
/// kernel returns an error.
fn predict_for_row(&self, x: Vec<TX>) -> TX {
    let mut f = self.b.unwrap();
    // The query row does not change between support vectors, so convert it to
    // f64 once instead of once per support vector.
    let row: Vec<f64> = x.iter().map(|e| e.to_f64().unwrap()).collect();
    for (i, instance) in self.instances.as_ref().unwrap().iter().enumerate() {
        let weight = self.w.as_ref().unwrap()[i];
        let kernel = self
            .parameters
            .as_ref()
            .unwrap()
            .kernel
            .as_ref()
            .unwrap();
        let support: Vec<f64> = instance.iter().map(|e| e.to_f64().unwrap()).collect();
        let k = kernel.apply(&row, &support).unwrap();
        f += weight * TX::from(k).unwrap();
    }
    f
}
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
    for SVC<'a, TX, TY, X, Y>
{
    /// Approximate equality of two models.
    ///
    /// Biases must agree within `2 * epsilon`, coefficients within `epsilon`,
    /// and support-vector rows must be exactly equal. `classes` and
    /// `parameters` are not compared.
    ///
    /// A field that is `None` on both sides counts as equal and a field set on
    /// only one side makes the models unequal, so comparing unfitted models no
    /// longer panics.
    fn eq(&self, other: &Self) -> bool {
        match (self.b, other.b) {
            (Some(lhs), Some(rhs)) => {
                if (lhs - rhs).abs() > TX::epsilon() * TX::two() {
                    return false;
                }
            }
            (None, None) => {}
            _ => return false,
        }

        match (self.w.as_ref(), other.w.as_ref()) {
            (Some(lhs), Some(rhs)) => {
                // Length is checked first so the element-wise checks only ever
                // see vectors of the same size.
                if lhs.len() != rhs.len() || !lhs.approximate_eq(rhs, TX::epsilon()) {
                    return false;
                }
                let differs = lhs
                    .iter()
                    .zip(rhs.iter())
                    .any(|(a, b)| (*a - *b).abs() > TX::epsilon());
                if differs {
                    return false;
                }
            }
            (None, None) => {}
            _ => return false,
        }

        match (self.instances.as_ref(), other.instances.as_ref()) {
            (Some(lhs), Some(rhs)) => {
                lhs.len() == rhs.len() && lhs.iter().zip(rhs.iter()).all(|(a, b)| a == b)
            }
            (None, None) => true,
            _ => false,
        }
    }
}
impl<TX: Number + RealNumber> SupportVector<TX> {
    /// Builds a support vector for training row `i` with a zero coefficient.
    ///
    /// * `x` - feature values of the sample.
    /// * `y` - label; a value above zero selects the bounds `[0, c]`, anything
    ///   else selects `[-c, 0]`.
    /// * `g` - initial gradient.
    /// * `c` - magnitude of the coefficient bound.
    /// * `k_v` - kernel value of the sample with itself.
    fn new(i: usize, x: Vec<TX>, y: TX, g: f64, c: f64, k_v: f64) -> SupportVector<TX> {
        let is_positive = y > TX::zero();
        SupportVector {
            index: i,
            x,
            alpha: 0f64,
            grad: g,
            cmin: if is_positive { 0f64 } else { -c },
            cmax: if is_positive { c } else { 0f64 },
            k: k_v,
        }
    }
}
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Cache<TX, TY, X, Y> {
    /// Creates an empty cache.
    fn new() -> Cache<TX, TY, X, Y> {
        Cache {
            data: HashMap::new(),
            phantom: PhantomData,
        }
    }

    /// Returns the value stored for the ordered pair `(i.index, j.index)`.
    ///
    /// When the pair has no entry yet, `or_insert` is stored and returned.
    /// `or_insert` is a plain value, so the caller has already computed it
    /// whether or not the pair was cached.
    fn get(&mut self, i: &SupportVector<TX>, j: &SupportVector<TX>, or_insert: f64) -> f64 {
        let key = (i.index, j.index);
        *self.data.entry(key).or_insert(or_insert)
    }

    /// Stores `value` under `key`, replacing any previous entry.
    fn insert(&mut self, key: (usize, usize), value: f64) {
        self.data.insert(key, value);
    }

    /// Removes every entry whose first key component is in `idxs_to_drop`.
    fn drop(&mut self, idxs_to_drop: HashSet<usize>) {
        self.data
            .retain(|&(first, _), _| !idxs_to_drop.contains(&first));
    }
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
Optimizer<'a, TX, TY, X, Y>
{
/// Creates an optimizer over `x`/`y` with no support vectors yet.
///
/// The gradient extremes start inverted (`gmin` at the largest representable
/// value, `gmax` at the smallest) and are flagged for recomputation.
fn new(
    x: &'a X,
    y: &'a Y,
    parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Optimizer<'a, TX, TY, X, Y> {
    let (rows, _) = x.shape();
    let gmin = <TX as Bounded>::max_value();
    let gmax = <TX as Bounded>::min_value();
    let tau = TX::from_f64(1e-12).unwrap();
    Optimizer {
        x,
        y,
        parameters,
        svmin: 0,
        svmax: 0,
        gmin,
        gmax,
        tau,
        sv: Vec::with_capacity(rows),
        recalculate_minmax_grad: true,
    }
}
/// Runs the full training procedure and consumes the optimizer.
///
/// After seeding the support set, every epoch visits the rows in shuffled
/// order: each row is offered to `process`, then `reprocess` is repeated until
/// the gradient gap falls below a fixed threshold of 1000. A final `finish`
/// pass follows.
///
/// Returns `(support vectors, coefficients, bias)`, where the bias is the
/// midpoint of `gmax` and `gmin`.
fn optimize(mut self) -> (Vec<Vec<TX>>, Vec<TX>, TX) {
    let (n, _) = self.x.shape();
    let mut cache: Cache<TX, TY, X, Y> = Cache::new();
    self.initialize(&mut cache);
    let tol = self.parameters.tol;
    let good_enough = TX::from_i32(1000).unwrap();
    for _ in 0..self.parameters.epoch {
        for i in self.permutate(n) {
            // Copy the whole row. The previous code sized the copy with the
            // number of rows, which is unrelated to the length of a row.
            let row: Vec<TX> = self.x.get_row(i).iterator(0).copied().collect();
            let label = *self.y.get(i);
            self.process(i, row, label, &mut cache);
            loop {
                self.reprocess(tol, &mut cache);
                self.find_min_max_gradient();
                if self.gmax - self.gmin < good_enough {
                    break;
                }
            }
        }
    }
    self.finish(&mut cache);
    let b = (self.gmax + self.gmin) / TX::two();
    let (support_vectors, w): (Vec<Vec<TX>>, Vec<TX>) = self
        .sv
        .into_iter()
        .map(|v| (v.x, TX::from(v.alpha).unwrap()))
        .unzip();
    (support_vectors, w, b)
}
/// Seeds the support set with up to five accepted samples per class.
///
/// Rows are visited in shuffled order; a row counts towards its class only
/// when `process` accepts it. The scan stops early once both classes have
/// reached five.
fn initialize(&mut self, cache: &mut Cache<TX, TY, X, Y>) {
    let (n, _) = self.x.shape();
    let few = 5;
    let mut cp = 0;
    let mut cn = 0;
    for i in self.permutate(n) {
        let label = *self.y.get(i);
        if label == TY::one() && cp < few {
            // Copy the whole row. The previous code sized the copy with the
            // number of rows, which is unrelated to the length of a row.
            let row: Vec<TX> = self.x.get_row(i).iterator(0).copied().collect();
            if self.process(i, row, label, cache) {
                cp += 1;
            }
        } else if label == TY::from(-1).unwrap() && cn < few {
            let row: Vec<TX> = self.x.get_row(i).iterator(0).copied().collect();
            if self.process(i, row, label, cache) {
                cn += 1;
            }
        }
        if cp >= few && cn >= few {
            break;
        }
    }
}
/// Offers training row `i` (features `x`, label `y`) to the support set.
///
/// Returns `true` when the row is already in the set or has just been added,
/// after one SMO step anchored on it. Returns `false` when the row is
/// rejected: its initial gradient lies outside the current `[gmin, gmax]`
/// range on the side matching its label.
///
/// # Panics
/// Panics when no kernel is configured or the kernel returns an error.
fn process(&mut self, i: usize, x: Vec<TX>, y: TY, cache: &mut Cache<TX, TY, X, Y>) -> bool {
    if self.sv.iter().any(|v| v.index == i) {
        return true;
    }

    // `parameters` is a shared reference; copying it out keeps `self` free for
    // the mutable calls below.
    let params = self.parameters;
    let kernel = params
        .kernel
        .as_ref()
        .expect("Kernel should be defined at this point, use with_kernel() on parameters");

    // The candidate row is the same for every support vector, so convert it to
    // f64 once rather than inside the loop.
    let x_f64: Vec<f64> = x.iter().map(|e| e.to_f64().unwrap()).collect();

    let mut g: f64 = y.to_f64().unwrap();
    let mut cache_values: Vec<((usize, usize), f64)> = Vec::with_capacity(self.sv.len());
    for v in self.sv.iter() {
        let v_f64: Vec<f64> = v.x.iter().map(|e| e.to_f64().unwrap()).collect();
        let k = kernel.apply(&v_f64, &x_f64).unwrap();
        // NOTE(review): the cached value is still passed through `TX`, exactly
        // as before, so it may be rounded when `TX` is narrower than f64.
        cache_values.push(((i, v.index), TX::from(k).unwrap().to_f64().unwrap()));
        g -= v.alpha * k;
    }

    self.find_min_max_gradient();
    if self.gmin < self.gmax
        && ((y > TY::zero() && g < self.gmin.to_f64().unwrap())
            || (y < TY::zero() && g > self.gmax.to_f64().unwrap()))
    {
        return false;
    }

    // Only accepted rows have their kernel values committed to the cache.
    for (key, value) in cache_values {
        cache.insert(key, value);
    }

    let k_v = kernel.apply(&x_f64, &x_f64).unwrap();
    // `x` is owned and no longer needed here, so it is moved, not cloned.
    self.sv.insert(
        0,
        SupportVector::<TX>::new(
            i,
            x,
            TX::from(y).unwrap(),
            g,
            params.c.to_f64().unwrap(),
            k_v,
        ),
    );

    // The new entry sits at position 0; anchor the SMO step on it.
    if y > TY::zero() {
        self.smo(None, Some(0), TX::zero(), cache);
    } else {
        self.smo(Some(0), None, TX::zero(), cache);
    }
    true
}
/// Performs one unconstrained SMO step, then prunes the support set.
///
/// Returns the status reported by the SMO step.
fn reprocess(&mut self, tol: TX, cache: &mut Cache<TX, TY, X, Y>) -> bool {
    let made_progress = self.smo(None, None, tol, cache);
    self.clean(cache);
    made_progress
}
/// Repeats SMO steps until one reports no further progress or the iteration
/// budget (the support-set size at entry) is used up, then prunes the set.
fn finish(&mut self, cache: &mut Cache<TX, TY, X, Y>) {
    let mut remaining = self.sv.len();
    loop {
        // The step runs before the budget is checked, so one step is always
        // attempted even when the budget is zero.
        let progressed = self.smo(None, None, self.parameters.tol, cache);
        if !progressed || remaining == 0 {
            break;
        }
        remaining -= 1;
    }
    self.clean(cache);
}
/// Recomputes `gmin`/`gmax` and the positions `svmin`/`svmax` if they are
/// flagged as stale; does nothing otherwise.
///
/// Only entries whose `alpha` is strictly above `cmin` are candidates for the
/// minimum, and only entries strictly below `cmax` for the maximum.
fn find_min_max_gradient(&mut self) {
    if self.recalculate_minmax_grad {
        self.gmin = <TX as Bounded>::max_value();
        self.gmax = <TX as Bounded>::min_value();
        for (position, v) in self.sv.iter().enumerate() {
            if v.grad < self.gmin.to_f64().unwrap() && v.alpha > v.cmin {
                self.gmin = TX::from(v.grad).unwrap();
                self.svmin = position;
            }
            if v.grad > self.gmax.to_f64().unwrap() && v.alpha < v.cmax {
                self.gmax = TX::from(v.grad).unwrap();
                self.svmax = position;
            }
        }
        self.recalculate_minmax_grad = false;
    }
}
/// Drops entries with a zero coefficient whose gradient lies at or beyond the
/// current extreme on the side their bounds allow, and removes the matching
/// cache entries. The gradient extremes are flagged stale afterwards.
fn clean(&mut self, cache: &mut Cache<TX, TY, X, Y>) {
    self.find_min_max_gradient();
    let (gmin, gmax) = (self.gmin, self.gmax);
    let mut idxs_to_drop: HashSet<usize> = HashSet::new();
    self.sv.retain(|v| {
        let removable = v.alpha == 0f64 && {
            let grad = TX::from(v.grad).unwrap();
            (grad >= gmax && TX::zero() >= TX::from(v.cmax).unwrap())
                || (grad <= gmin && TX::zero() <= TX::from(v.cmin).unwrap())
        };
        if removable {
            idxs_to_drop.insert(v.index);
        }
        !removable
    });
    cache.drop(idxs_to_drop);
    self.recalculate_minmax_grad = true;
}
/// Returns the indices `0..n` in shuffled order.
///
/// A fresh generator is obtained from `get_rng_impl` with the configured seed
/// on every call.
fn permutate(&self, n: usize) -> Vec<usize> {
    let mut rng = get_rng_impl(self.parameters.seed);
    let mut order = (0..n).collect::<Vec<usize>>();
    order.shuffle(&mut rng);
    order
}
fn select_pair(
&mut self,
idx_1: Option<usize>,
idx_2: Option<usize>,
cache: &mut Cache<TX, TY, X, Y>,
) -> Option<(usize, usize, f64)> {
match (idx_1, idx_2) {
(None, None) => {
if self.gmax > -self.gmin {
self.select_pair(None, Some(self.svmax), cache)
} else {
self.select_pair(Some(self.svmin), None, cache)
}
}
(Some(idx_1), None) => {
let sv1 = &self.sv[idx_1];
let mut idx_2 = None;
let mut k_v_12 = None;
let km = sv1.k;
let gm = sv1.grad;
let mut best = 0f64;
for i in 0..self.sv.len() {
let v = &self.sv[i];
let z = v.grad - gm;
let k = cache.get(
sv1,
v,
self.parameters
.kernel
.as_ref()
.unwrap()
.apply(
&sv1.x.iter().map(|e| e.to_f64().unwrap()).collect(),
&v.x.iter().map(|e| e.to_f64().unwrap()).collect(),
)
.unwrap(),
);
let mut curv = km + v.k - 2f64 * k;
if curv <= 0f64 {
curv = self.tau.to_f64().unwrap();
}
let mu = z / curv;
if (mu > 0f64 && v.alpha < v.cmax) || (mu < 0f64 && v.alpha > v.cmin) {
let gain = z * mu;
if gain > best {
best = gain;
idx_2 = Some(i);
k_v_12 = Some(k);
}
}
}
idx_2.map(|idx_2| {
(
idx_1,
idx_2,
k_v_12.unwrap_or_else(|| {
self.parameters
.kernel
.as_ref()
.unwrap()
.apply(
&self.sv[idx_1]
.x
.iter()
.map(|e| e.to_f64().unwrap())
.collect(),
&self.sv[idx_2]
.x
.iter()
.map(|e| e.to_f64().unwrap())
.collect(),
)
.unwrap()
}),
)
})
}
(None, Some(idx_2)) => {
let mut idx_1 = None;
let sv2 = &self.sv[idx_2];
let mut k_v_12 = None;
let km = sv2.k;
let gm = sv2.grad;
let mut best = 0f64;
for i in 0..self.sv.len() {
let v = &self.sv[i];
let z = gm - v.grad;
let k = cache.get(
sv2,
v,
self.parameters
.kernel
.as_ref()
.unwrap()
.apply(
&sv2.x.iter().map(|e| e.to_f64().unwrap()).collect(),
&v.x.iter().map(|e| e.to_f64().unwrap()).collect(),
)
.unwrap(),
);
let mut curv = km + v.k - 2f64 * k;
if curv <= 0f64 {
curv = self.tau.to_f64().unwrap();
}
let mu = z / curv;
if (mu > 0f64 && v.alpha > v.cmin) || (mu < 0f64 && v.alpha < v.cmax) {
let gain = z * mu;
if gain > best {
best = gain;
idx_1 = Some(i);
k_v_12 = Some(k);
}
}
}
idx_1.map(|idx_1| {
(
idx_1,
idx_2,
k_v_12.unwrap_or_else(|| {
self.parameters
.kernel
.as_ref()
.unwrap()
.apply(
&self.sv[idx_1]
.x
.iter()
.map(|e| e.to_f64().unwrap())
.collect(),
&self.sv[idx_2]
.x
.iter()
.map(|e| e.to_f64().unwrap())
.collect(),
)
.unwrap()
}),
)
})
}
(Some(idx_1), Some(idx_2)) => Some((
idx_1,
idx_2,
self.parameters
.kernel
.as_ref()
.unwrap()
.apply(
&self.sv[idx_1]
.x
.iter()
.map(|e| e.to_f64().unwrap())
.collect(),
&self.sv[idx_2]
.x
.iter()
.map(|e| e.to_f64().unwrap())
.collect(),
)
.unwrap(),
)),
}
}
/// Performs one SMO step on the pair chosen by `select_pair`.
///
/// The step size is the gradient difference divided by the curvature (with
/// `tau` substituted for a non-positive curvature), clipped so that both
/// coefficients stay within their bounds.
///
/// Returns `true` when the gradient gap still exceeds `tol` after the step,
/// and `false` when it does not or when no pair could be selected.
fn smo(
    &mut self,
    idx_1: Option<usize>,
    idx_2: Option<usize>,
    tol: TX,
    cache: &mut Cache<TX, TY, X, Y>,
) -> bool {
    let (first, second, k_pair) = match self.select_pair(idx_1, idx_2, cache) {
        Some(selection) => selection,
        None => return false,
    };

    let step = {
        let a = &self.sv[first];
        let b = &self.sv[second];
        let raw_curv = a.k + b.k - 2f64 * k_pair;
        let curv = if raw_curv <= 0f64 {
            self.tau.to_f64().unwrap()
        } else {
            raw_curv
        };
        let mut clipped = (b.grad - a.grad) / curv;
        if clipped >= 0f64 {
            // Moving forward: limited by how far `a` can drop and `b` can rise.
            let room_a = a.alpha - a.cmin;
            if room_a < clipped {
                clipped = room_a;
            }
            let room_b = b.cmax - b.alpha;
            if room_b < clipped {
                clipped = room_b;
            }
        } else {
            // Moving backward: limited by how far `b` can drop and `a` can rise.
            let room_b = b.cmin - b.alpha;
            if room_b > clipped {
                clipped = room_b;
            }
            let room_a = a.alpha - a.cmax;
            if room_a > clipped {
                clipped = room_a;
            }
        }
        clipped
    };

    self.update(first, second, TX::from(step).unwrap(), cache);
    self.gmax - self.gmin > tol
}
fn update(&mut self, v1: usize, v2: usize, step: TX, cache: &mut Cache<TX, TY, X, Y>) {
self.sv[v1].alpha -= step.to_f64().unwrap();
self.sv[v2].alpha += step.to_f64().unwrap();
for i in 0..self.sv.len() {
let k2 = cache.get(
&self.sv[v2],
&self.sv[i],
self.parameters
.kernel
.as_ref()
.unwrap()
.apply(
&self.sv[v2].x.iter().map(|e| e.to_f64().unwrap()).collect(),
&self.sv[i].x.iter().map(|e| e.to_f64().unwrap()).collect(),
)
.unwrap(),
);
let k1 = cache.get(
&self.sv[v1],
&self.sv[i],
self.parameters
.kernel
.as_ref()
.unwrap()
.apply(
&self.sv[v1].x.iter().map(|e| e.to_f64().unwrap()).collect(),
&self.sv[i].x.iter().map(|e| e.to_f64().unwrap()).collect(),
)
.unwrap(),
);
self.sv[i].grad -= step.to_f64().unwrap() * (k2 - k1);
}
self.recalculate_minmax_grad = true;
self.find_min_max_gradient();
}
}
#[cfg(test)]
mod tests {
use num::ToPrimitive;
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::accuracy;
use crate::svm::Kernels;
// Fits a linear-kernel SVC on 20 four-feature samples and checks that
// predicting the training data reaches at least 90% accuracy.
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn svc_fit_predict() {
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ]);
    let y: Vec<i32> = vec![
        -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    ];
    let knl = Kernels::linear();
    let params = SVCParameters::default()
        .with_c(200.0)
        .with_kernel(knl)
        .with_seed(Some(100));
    // `fit` borrows the parameters; the argument had been garbled to `¶ms`.
    let y_hat = SVC::fit(&x, &y, &params)
        .and_then(|lr| lr.predict(&x))
        .unwrap();
    let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));
    assert!(
        acc >= 0.9,
        "accuracy ({}) is not larger or equal to 0.9",
        acc
    );
}
// Fits a linear-kernel SVC on four 2-D points and checks the ordering of the
// decision values along the diagonal, plus a near-zero value at (3, 3).
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn svc_fit_decision_function() {
    let train = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]);
    let labels: Vec<i32> = vec![-1, -1, 1, 1];
    let queries = DenseMatrix::from_2d_array(&[
        &[3.0, 3.0],
        &[4.0, 4.0],
        &[6.0, 6.0],
        &[10.0, 10.0],
        &[1.0, 1.0],
        &[0.0, 0.0],
    ]);
    let params = SVCParameters::default()
        .with_c(200.0)
        .with_kernel(Kernels::linear());
    let model = SVC::fit(&train, &labels, &params).unwrap();
    let scores = model.decision_function(&queries).unwrap();
    assert!(scores[1] < scores[2]);
    assert!(scores[2] < scores[3]);
    assert!(scores[4] > scores[5]);
    assert!(num::Float::abs(scores[0]) <= 0.1);
}
// Fits an RBF-kernel SVC (gamma 0.7, c 1.0) on 20 four-feature samples and
// checks that predicting the training data reaches at least 90% accuracy.
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn svc_fit_predict_rbf() {
    let samples = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ]);
    let labels: Vec<i32> = vec![
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    ];
    let params = SVCParameters::default()
        .with_c(1.0)
        .with_kernel(Kernels::rbf().with_gamma(0.7));
    let model = SVC::fit(&samples, &labels, &params).unwrap();
    let predicted: Vec<i32> = model
        .predict(&samples)
        .unwrap()
        .iter()
        .map(|e| e.to_i32().unwrap())
        .collect();
    let acc = accuracy(&labels, &predicted);
    assert!(
        acc >= 0.9,
        "accuracy ({}) is not larger or equal to 0.9",
        acc
    );
}
// Round-trips a fitted model through JSON and checks that the deserialized
// model compares equal to the original.
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
fn svc_serde() {
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ]);
    let y: Vec<i32> = vec![
        -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    ];
    let knl = Kernels::linear();
    let params = SVCParameters::default().with_kernel(knl);
    // `fit` borrows the parameters; the argument had been garbled to `¶ms`.
    let svc = SVC::fit(&x, &y, &params).unwrap();
    let deserialized_svc: SVC<f64, i32, _, _> =
        serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
    assert_eq!(svc, deserialized_svc);
}
}