pub mod svc;
pub mod svr;
pub mod search;
use core::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
use typetag;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, ArrayView1};
/// A kernel function `K(x_i, x_j)` evaluated on a pair of feature vectors.
///
/// When the `serde` feature is enabled on non-wasm32 targets, trait objects are
/// (de)serialized through `typetag`, with the concrete implementation recorded
/// under a `"type"` tag.
#[cfg_attr(
all(feature = "serde", not(target_arch = "wasm32")),
typetag::serde(tag = "type")
)]
pub trait Kernel: Debug {
/// Evaluates the kernel on the pair `(x_i, x_j)` and returns the scalar result.
///
/// # Errors
/// Returns [`Failed`] when the kernel cannot be evaluated with its current
/// configuration.
// `&Vec<f64>` (rather than `&[f64]`) is part of this public trait's signature,
// so clippy's `ptr_arg` lint is silenced instead of changing it.
#[allow(clippy::ptr_arg)]
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
}
/// Built-in kernel functions.
///
/// Hyper-parameters are stored as `Option<f64>`. Applying a kernel (via
/// [`Kernel::apply`]) while one of its required parameters is still `None`
/// returns a `FailedError::ParametersError`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub enum Kernels {
/// Linear kernel: `K(x_i, x_j) = <x_i, x_j>` (dot product).
Linear,
/// Radial basis function kernel: `K(x_i, x_j) = exp(-gamma * ||x_i - x_j||^2)`.
RBF {
/// Scale applied to the squared Euclidean distance; required.
gamma: Option<f64>,
},
/// Polynomial kernel: `K(x_i, x_j) = (gamma * <x_i, x_j> + coef0)^degree`.
Polynomial {
/// Exponent applied to the shifted, scaled dot product; required.
degree: Option<f64>,
/// Scale applied to the dot product; required.
gamma: Option<f64>,
/// Additive offset; required.
coef0: Option<f64>,
},
/// Sigmoid kernel: `K(x_i, x_j) = tanh(gamma * <x_i, x_j> + coef0)`.
Sigmoid {
/// Scale applied to the dot product; required.
gamma: Option<f64>,
/// Additive offset; required.
coef0: Option<f64>,
},
}
impl Kernels {
pub fn linear() -> Self {
Kernels::Linear
}
pub fn rbf() -> Self {
Kernels::RBF { gamma: None }
}
pub fn polynomial() -> Self {
Kernels::Polynomial {
gamma: None,
degree: None,
coef0: Some(1.0),
}
}
pub fn sigmoid() -> Self {
Kernels::Sigmoid {
gamma: None,
coef0: Some(1.0),
}
}
pub fn with_gamma(self, gamma: f64) -> Self {
match self {
Kernels::RBF { .. } => Kernels::RBF { gamma: Some(gamma) },
Kernels::Polynomial { degree, coef0, .. } => Kernels::Polynomial {
gamma: Some(gamma),
degree,
coef0,
},
Kernels::Sigmoid { coef0, .. } => Kernels::Sigmoid {
gamma: Some(gamma),
coef0,
},
other => other,
}
}
pub fn with_degree(self, degree: f64) -> Self {
match self {
Kernels::Polynomial { gamma, coef0, .. } => Kernels::Polynomial {
degree: Some(degree),
gamma,
coef0,
},
other => other,
}
}
pub fn with_coef0(self, coef0: f64) -> Self {
match self {
Kernels::Polynomial { degree, gamma, .. } => Kernels::Polynomial {
degree,
gamma,
coef0: Some(coef0),
},
Kernels::Sigmoid { gamma, .. } => Kernels::Sigmoid {
gamma,
coef0: Some(coef0),
},
other => other,
}
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for Kernels {
    /// Evaluates the configured kernel on `(x_i, x_j)`.
    ///
    /// # Errors
    /// Returns a `FailedError::ParametersError` for the first required
    /// hyper-parameter that is still `None`. For the polynomial kernel the
    /// parameters are checked in the order degree, gamma, coef0.
    fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
        // Turns an optional hyper-parameter into its value, or into a
        // parameters error carrying `msg`.
        fn require(param: Option<f64>, msg: &'static str) -> Result<f64, Failed> {
            param.ok_or_else(|| Failed::because(FailedError::ParametersError, msg))
        }

        let value = match *self {
            Kernels::Linear => x_i.dot(x_j),
            Kernels::RBF { gamma } => {
                let gamma = require(gamma, "gamma not set")?;
                let diff = x_i.sub(x_j);
                let squared_distance = diff.mul(&diff).sum();
                (-gamma * squared_distance).exp()
            }
            Kernels::Polynomial {
                degree,
                gamma,
                coef0,
            } => {
                let degree = require(degree, "degree not set")?;
                let gamma = require(gamma, "gamma not set")?;
                let coef0 = require(coef0, "coef0 not set")?;
                (gamma * x_i.dot(x_j) + coef0).powf(degree)
            }
            Kernels::Sigmoid { gamma, coef0 } => {
                let gamma = require(gamma, "gamma not set")?;
                let coef0 = require(coef0, "coef0 not set")?;
                (gamma * x_i.dot(x_j) + coef0).tanh()
            }
        };
        Ok(value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::svm::Kernels;

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn linear_kernel() {
        let v1 = vec![1., 2., 3.];
        let v2 = vec![4., 5., 6.];
        assert_eq!(32f64, Kernels::linear().apply(&v1, &v2).unwrap());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn rbf_kernel() {
        let v1 = vec![1., 2., 3.];
        let v2 = vec![4., 5., 6.];
        let result = Kernels::rbf()
            .with_gamma(0.055)
            .apply(&v1, &v2)
            .unwrap()
            .abs();
        // exp(-0.055 * 27) ≈ 0.226502. Compare the absolute difference: a
        // one-sided `expected - result < tol` accepts any larger result.
        assert!((0.2265f64 - result).abs() < 1e-4);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn polynomial_kernel() {
        let v1 = vec![1., 2., 3.];
        let v2 = vec![4., 5., 6.];
        let result = Kernels::polynomial()
            .with_gamma(0.5)
            .with_degree(3.0)
            .with_coef0(1.0)
            .apply(&v1, &v2)
            .unwrap()
            .abs();
        // (0.5 * 32 + 1)^3 = 17^3 = 4913.
        assert!((4913f64 - result).abs() < f64::EPSILON);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn sigmoid_kernel() {
        let v1 = vec![1., 2., 3.];
        let v2 = vec![4., 5., 6.];
        let result = Kernels::sigmoid()
            .with_gamma(0.01)
            .with_coef0(0.1)
            .apply(&v1, &v2)
            .unwrap()
            .abs();
        // tanh(0.01 * 32 + 0.1) = tanh(0.42) ≈ 0.396931.
        assert!((0.3969f64 - result).abs() < 1e-4);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn kernels_require_parameters() {
        let v1 = vec![1., 2., 3.];
        let v2 = vec![4., 5., 6.];
        // The default constructors leave required hyper-parameters unset.
        assert!(Kernels::rbf().apply(&v1, &v2).is_err());
        assert!(Kernels::polynomial()
            .with_gamma(0.5)
            .apply(&v1, &v2)
            .is_err());
        assert!(Kernels::sigmoid().apply(&v1, &v2).is_err());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn builders_ignore_unrelated_parameters() {
        assert_eq!(
            Kernels::linear()
                .with_gamma(0.5)
                .with_degree(2.0)
                .with_coef0(3.0),
            Kernels::linear()
        );
        assert_eq!(
            Kernels::rbf().with_degree(2.0).with_coef0(3.0),
            Kernels::rbf()
        );
        assert_eq!(Kernels::sigmoid().with_degree(2.0), Kernels::sigmoid());
    }
}