use crate::{KernelMethod, SolverParams, Svm, SvmError};
use linfa::{platt_scaling::PlattParams, Float, ParamGuard, Platt};
use linfa_kernel::{Kernel, KernelParams};
use std::marker::PhantomData;
/// Hyper-parameter set for an SVM, as returned by the `ParamGuard`
/// implementation of [`SvmParams`] once validation has succeeded.
///
/// `T` is only carried as a marker (see `phantom`); no value of that type is
/// stored.
#[derive(Debug, Clone, PartialEq)]
pub struct SvmValidParams<F: Float, T> {
/// Pair stored by the C-style setters: `(c_pos, c_neg)` from
/// `pos_neg_weights`, or `(c, loss_eps)` from `c_svr`. `None` when a nu-style
/// setter was used last.
c: Option<(F, F)>,
/// Pair stored by the nu-style setters: `(nu, nu)` from `nu_weight`, or
/// `(nu, c)` from `nu_svr`. `None` when a C-style setter was used last.
nu: Option<(F, F)>,
/// Solver settings (`eps` and `shrinking`).
solver_params: SolverParams<F>,
/// Zero-sized marker tying the parameter set to `T`.
phantom: PhantomData<T>,
/// Kernel configuration.
kernel: KernelParams<F>,
/// Platt scaling configuration.
platt: PlattParams<F, ()>,
}
impl<F: Float, T> SvmValidParams<F, T> {
    /// Returns a copy of the stored `c` pair, if one is set.
    pub fn c(&self) -> Option<(F, F)> {
        let Self { c, .. } = self;
        *c
    }

    /// Returns a copy of the stored `nu` pair, if one is set.
    pub fn nu(&self) -> Option<(F, F)> {
        let Self { nu, .. } = self;
        *nu
    }

    /// Borrows the solver settings.
    pub fn solver_params(&self) -> &SolverParams<F> {
        let Self { solver_params, .. } = self;
        solver_params
    }

    /// Borrows the kernel configuration.
    pub fn kernel_params(&self) -> &KernelParams<F> {
        let Self { kernel, .. } = self;
        kernel
    }

    /// Borrows the Platt scaling configuration.
    pub fn platt_params(&self) -> &PlattParams<F, ()> {
        let Self { platt, .. } = self;
        platt
    }
}
/// Unchecked SVM hyper-parameter builder.
///
/// Thin wrapper around [`SvmValidParams`]; the inner value is only handed out
/// through the `ParamGuard` implementation (`check_ref` / `check`), which
/// validates it first.
#[derive(Debug, Clone, PartialEq)]
pub struct SvmParams<F: Float, T>(SvmValidParams<F, T>);
impl<F: Float, T> SvmParams<F, T> {
    /// Creates a parameter set with the defaults: `c = (1, 1)`, no `nu`,
    /// solver `eps = 1e-7`, shrinking disabled, a linear kernel and the
    /// default Platt parameters.
    pub fn new() -> Self {
        let solver_params = SolverParams {
            eps: F::cast(1e-7),
            shrinking: false,
        };
        let kernel: KernelParams<F> = Kernel::params().method(KernelMethod::Linear);
        let platt: PlattParams<F, ()> = Platt::params();

        Self(SvmValidParams {
            c: Some((F::one(), F::one())),
            nu: None,
            solver_params,
            phantom: PhantomData,
            kernel,
            platt,
        })
    }

    /// Sets the solver's `eps` value.
    pub fn eps(mut self, new_eps: F) -> Self {
        let solver = &mut self.0.solver_params;
        solver.eps = new_eps;
        self
    }

    /// Sets the solver's `shrinking` flag.
    pub fn shrinking(mut self, shrinking: bool) -> Self {
        let solver = &mut self.0.solver_params;
        solver.shrinking = shrinking;
        self
    }

    /// Replaces the kernel configuration with `kernel`.
    pub fn with_kernel_params(mut self, kernel: KernelParams<F>) -> Self {
        let inner = &mut self.0;
        inner.kernel = kernel;
        self
    }

    /// Replaces the Platt scaling configuration with `platt`.
    pub fn with_platt_params(mut self, platt: PlattParams<F, ()>) -> Self {
        let inner = &mut self.0;
        inner.platt = platt;
        self
    }

    /// Switches to fresh kernel parameters using `KernelMethod::Gaussian(eps)`.
    ///
    /// Any previously supplied kernel parameters are discarded.
    pub fn gaussian_kernel(self, eps: F) -> Self {
        self.with_kernel_params(Kernel::params().method(KernelMethod::Gaussian(eps)))
    }

    /// Switches to fresh kernel parameters using
    /// `KernelMethod::Polynomial(constant, degree)`.
    ///
    /// Any previously supplied kernel parameters are discarded.
    pub fn polynomial_kernel(self, constant: F, degree: F) -> Self {
        self.with_kernel_params(
            Kernel::params().method(KernelMethod::Polynomial(constant, degree)),
        )
    }

    /// Switches to fresh kernel parameters using `KernelMethod::Linear`.
    ///
    /// Any previously supplied kernel parameters are discarded.
    pub fn linear_kernel(self) -> Self {
        self.with_kernel_params(Kernel::params().method(KernelMethod::Linear))
    }
}
impl<F: Float, T> SvmParams<F, T> {
    /// Stores `(c_pos, c_neg)` as the `c` pair and clears any `nu` setting.
    pub fn pos_neg_weights(mut self, c_pos: F, c_neg: F) -> Self {
        let inner = &mut self.0;
        inner.nu = None;
        inner.c = Some((c_pos, c_neg));
        self
    }

    /// Stores `(nu, nu)` as the `nu` pair and clears any `c` setting.
    pub fn nu_weight(mut self, nu: F) -> Self {
        let inner = &mut self.0;
        inner.c = None;
        inner.nu = Some((nu, nu));
        self
    }
}
impl<F: Float> SvmParams<F, F> {
    /// Sets `c` for regression together with the solver's `eps`.
    ///
    /// The second element of the stored `c` pair is fixed to `0.1`; any `nu`
    /// setting is cleared.
    #[deprecated(since = "0.7.2", note = "Use .c_svr() and .eps()")]
    pub fn c_eps(self, c: F, eps: F) -> Self {
        // `c_svr(c, None)` stores `(c, 0.1)` and clears `nu`; `eps` then sets
        // the solver tolerance, which matches the former inline assignments.
        self.c_svr(c, None).eps(eps)
    }

    /// Sets `nu` for regression together with the solver's `eps`.
    ///
    /// The second element of the stored `nu` pair is fixed to `1`; any `c`
    /// setting is cleared.
    #[deprecated(since = "0.7.2", note = "Use .nu_svr() and .eps()")]
    pub fn nu_eps(self, nu: F, eps: F) -> Self {
        // `nu_svr(nu, None)` stores `(nu, 1)` and clears `c`.
        self.nu_svr(nu, None).eps(eps)
    }

    /// Stores `(c, loss_eps)` as the `c` pair and clears any `nu` setting.
    ///
    /// `loss_eps` falls back to `0.1` when `None`.
    pub fn c_svr(mut self, c: F, loss_eps: Option<F>) -> Self {
        let loss_eps = loss_eps.unwrap_or_else(|| F::cast(0.1));
        let inner = &mut self.0;
        inner.nu = None;
        inner.c = Some((c, loss_eps));
        self
    }

    /// Stores `(nu, c)` as the `nu` pair and clears any `c` setting.
    ///
    /// `c` falls back to `1` when `None`.
    pub fn nu_svr(mut self, nu: F, c: Option<F>) -> Self {
        let c = c.unwrap_or_else(F::one);
        let inner = &mut self.0;
        inner.c = None;
        inner.nu = Some((nu, c));
        self
    }
}
impl<F: Float, L> Default for SvmParams<F, L> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float, L> Svm<F, L> {
    /// Returns a parameter builder initialised with the default values.
    pub fn params() -> SvmParams<F, L> {
        // `Default` for `SvmParams` forwards to `SvmParams::new`.
        SvmParams::default()
    }
}
impl<F: Float, L> ParamGuard for SvmParams<F, L> {
    type Checked = SvmValidParams<F, L>;
    type Error = SvmError;

    /// Validates the hyper-parameters without consuming the builder.
    ///
    /// # Errors
    ///
    /// * Any error reported by the Platt parameters' own `check_ref`.
    /// * `SvmError::InvalidEps` if the solver `eps` is negative, NaN or
    ///   infinite.
    /// * `SvmError::InvalidC` if a `c` pair is set and either element is NaN
    ///   or not strictly positive.
    /// * `SvmError::InvalidNu` if a `nu` pair is set and its first element is
    ///   NaN or outside `(0, 1]`. The second element is not checked here.
    ///
    /// # Panics
    ///
    /// Panics if an offending value cannot be converted to `f32` for the
    /// error payload.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        self.0.platt_params().check_ref()?;

        let eps = self.0.solver_params.eps;
        if eps.is_negative() || eps.is_nan() || eps.is_infinite() {
            return Err(SvmError::InvalidEps(eps.to_f32().unwrap()));
        }

        if let Some((c1, c2)) = self.0.c {
            // Every ordered comparison with NaN is false, so `<=` alone would
            // let a NaN weight through; test for it explicitly.
            if c1.is_nan() || c2.is_nan() || c1 <= F::zero() || c2 <= F::zero() {
                return Err(SvmError::InvalidC((
                    c1.to_f32().unwrap(),
                    c2.to_f32().unwrap(),
                )));
            }
        }

        if let Some((nu, _)) = self.0.nu {
            // Same NaN caveat as above: neither `<=` nor `>` catches NaN.
            if nu.is_nan() || nu <= F::zero() || nu > F::one() {
                return Err(SvmError::InvalidNu(nu.to_f32().unwrap()));
            }
        }

        Ok(&self.0)
    }

    /// Validates the hyper-parameters and, on success, unwraps the builder
    /// into the checked parameter set.
    ///
    /// # Errors
    ///
    /// Returns whatever [`check_ref`](Self::check_ref) returns.
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}