/// SVM formulation selector stored in `SvmParameter::svm_type`.
///
/// `#[repr(i32)]` together with the explicit discriminants gives every variant
/// a fixed integer code (`0..=4`) that can be read back with `as i32`.
///
/// NOTE(review): the variant names and codes look like LIBSVM's `svm_type`
/// constants (C-SVC, nu-SVC, one-class, epsilon-SVR, nu-SVR) — confirm before
/// relying on that for file-format compatibility.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SvmType {
    /// Code `0`. `SvmParameter::validate` requires `c > 0` for this type.
    CSvc = 0,
    /// Code `1`. `SvmParameter::validate` requires `0 < nu <= 1`, and
    /// `check_parameter` additionally tests `nu` against the per-class label
    /// counts of the problem.
    NuSvc = 1,
    /// Code `2`. `SvmParameter::validate` requires `0 < nu <= 1`.
    OneClass = 2,
    /// Code `3`. `SvmParameter::validate` requires `c > 0` and `p >= 0`.
    EpsilonSvr = 3,
    /// Code `4`. `SvmParameter::validate` requires `c > 0` and `0 < nu <= 1`.
    NuSvr = 4,
}
/// Kernel selector stored in `SvmParameter::kernel_type`.
///
/// `#[repr(i32)]` together with the explicit discriminants gives every variant
/// a fixed integer code (`0..=4`) that can be read back with `as i32`.
///
/// NOTE(review): the codes look like LIBSVM's `kernel_type` numbering —
/// confirm before relying on that for file-format compatibility.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// Code `0`. `SvmParameter::validate` applies no kernel-specific check.
    Linear = 0,
    /// Code `1`. `SvmParameter::validate` requires `gamma >= 0` and
    /// `degree >= 0`.
    Polynomial = 1,
    /// Code `2`. `SvmParameter::validate` requires `gamma >= 0`.
    Rbf = 2,
    /// Code `3`. `SvmParameter::validate` requires `gamma >= 0`.
    Sigmoid = 3,
    /// Code `4`. `SvmParameter::validate` applies no kernel-specific check.
    Precomputed = 4,
}
/// An `(index, value)` pair; `SvmProblem::instances` and `SvmModel::sv` store
/// vectors of these.
///
/// NOTE(review): presumably one non-zero entry of a sparse feature vector, as
/// in LIBSVM's `svm_node` — confirm the index base and whether any sentinel
/// terminator entry is expected.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SvmNode {
    /// Integer position associated with `value`.
    pub index: i32,
    /// Floating-point payload stored at `index`.
    pub value: f64,
}
/// A data set handed to `check_parameter`: target values plus feature vectors.
///
/// NOTE(review): `labels` and `instances` are presumably row-aligned (entry
/// `i` of each describes the same sample); nothing in the code visible here
/// enforces equal lengths — confirm where that is checked.
#[derive(Clone, Debug, PartialEq)]
pub struct SvmProblem {
    /// One target value per sample. `check_parameter` casts each value to
    /// `i32` when it counts class sizes for `SvmType::NuSvc`.
    pub labels: Vec<f64>,
    /// One list of [`SvmNode`] entries per sample.
    pub instances: Vec<Vec<SvmNode>>,
}
/// Training configuration: which formulation and kernel to use plus their
/// numeric settings.
///
/// The range requirements quoted on the fields are the ones enforced by
/// `SvmParameter::validate`; fields without such a note are not examined by
/// the validation code visible in this file.
#[derive(Clone, Debug, PartialEq)]
pub struct SvmParameter {
    /// Formulation; decides which of `c`, `nu` and `p` are validated.
    pub svm_type: SvmType,
    /// Kernel; decides whether `gamma` and `degree` are validated.
    pub kernel_type: KernelType,
    /// Must be `>= 0` when `kernel_type` is `KernelType::Polynomial`.
    pub degree: i32,
    /// Must be `>= 0` for the `Polynomial`, `Rbf` and `Sigmoid` kernels.
    pub gamma: f64,
    /// Not validated. NOTE(review): presumably the constant term of the
    /// polynomial/sigmoid kernels — confirm against the kernel code.
    pub coef0: f64,
    /// Must be `> 0`. NOTE(review): unit is presumably megabytes as in
    /// LIBSVM — confirm against the cache implementation.
    pub cache_size: f64,
    /// Must be `> 0`. NOTE(review): presumably the solver's stopping
    /// tolerance — confirm.
    pub eps: f64,
    /// Must be `> 0` for `CSvc`, `EpsilonSvr` and `NuSvr`.
    pub c: f64,
    /// Not validated. NOTE(review): presumably `(class label, multiplier for
    /// c)` pairs — confirm against the training code.
    pub weight: Vec<(i32, f64)>,
    /// Must satisfy `0 < nu <= 1` for `NuSvc`, `OneClass` and `NuSvr`.
    pub nu: f64,
    /// Must be `>= 0` for `EpsilonSvr`.
    pub p: f64,
    /// Not validated. NOTE(review): presumably toggles the solver's shrinking
    /// heuristic — confirm.
    pub shrinking: bool,
    /// Not validated. NOTE(review): presumably requests probability
    /// estimates — confirm.
    pub probability: bool,
}
impl Default for SvmParameter {
fn default() -> Self {
Self {
svm_type: SvmType::CSvc,
kernel_type: KernelType::Rbf,
degree: 3,
gamma: 0.0, coef0: 0.0,
cache_size: 100.0,
eps: 0.001,
c: 1.0,
weight: Vec::new(),
nu: 0.5,
p: 0.1,
shrinking: true,
probability: false,
}
}
}
impl SvmParameter {
    /// Checks that the parameter values are usable for the selected
    /// formulation and kernel, independent of any data set.
    ///
    /// Only the fields relevant to the selected `svm_type` / `kernel_type`
    /// are examined; checks run in a fixed order and the first failure is
    /// returned.
    ///
    /// # Errors
    ///
    /// Returns `SvmError::InvalidParameter` with one of these messages:
    ///
    /// * `"gamma is NaN"` / `"gamma < 0"` — `Polynomial`, `Rbf`, `Sigmoid`
    ///   kernels only.
    /// * `"degree of polynomial kernel < 0"` — `Polynomial` kernel only.
    /// * `"cache_size is NaN"` / `"cache_size <= 0"`.
    /// * `"eps is NaN"` / `"eps <= 0"`.
    /// * `"C is NaN"` / `"C <= 0"` — `CSvc`, `EpsilonSvr`, `NuSvr` only.
    /// * `"nu is NaN"` / `"nu <= 0 or nu > 1"` — `NuSvc`, `OneClass`, `NuSvr`
    ///   only.
    /// * `"p is NaN"` / `"p < 0"` — `EpsilonSvr` only.
    pub fn validate(&self) -> Result<(), crate::error::SvmError> {
        use crate::error::SvmError;

        let invalid = |reason: &'static str| -> Result<(), SvmError> {
            Err(SvmError::InvalidParameter(reason.into()))
        };

        // Every ordered comparison with NaN is false, so a NaN value would
        // slip past the range checks below; each one is therefore preceded by
        // an explicit NaN test.

        let kernel_uses_gamma = matches!(
            self.kernel_type,
            KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
        );
        if kernel_uses_gamma {
            if self.gamma.is_nan() {
                return invalid("gamma is NaN");
            }
            if self.gamma < 0.0 {
                return invalid("gamma < 0");
            }
        }

        if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
            return invalid("degree of polynomial kernel < 0");
        }

        if self.cache_size.is_nan() {
            return invalid("cache_size is NaN");
        }
        if self.cache_size <= 0.0 {
            return invalid("cache_size <= 0");
        }

        if self.eps.is_nan() {
            return invalid("eps is NaN");
        }
        if self.eps <= 0.0 {
            return invalid("eps <= 0");
        }

        let formulation_uses_c = matches!(
            self.svm_type,
            SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
        );
        if formulation_uses_c {
            if self.c.is_nan() {
                return invalid("C is NaN");
            }
            if self.c <= 0.0 {
                return invalid("C <= 0");
            }
        }

        let formulation_uses_nu = matches!(
            self.svm_type,
            SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
        );
        if formulation_uses_nu {
            if self.nu.is_nan() {
                return invalid("nu is NaN");
            }
            if self.nu <= 0.0 || self.nu > 1.0 {
                return invalid("nu <= 0 or nu > 1");
            }
        }

        if self.svm_type == SvmType::EpsilonSvr {
            if self.p.is_nan() {
                return invalid("p is NaN");
            }
            if self.p < 0.0 {
                return invalid("p < 0");
            }
        }

        Ok(())
    }
}
/// Validates `param` on its own and, for `SvmType::NuSvc`, against the class
/// sizes found in `problem`.
///
/// The data-independent checks are delegated to `SvmParameter::validate`.
/// For `NuSvc` the labels are then grouped by their `i32` value and every
/// pair of classes with sizes `n1`, `n2` must satisfy
/// `nu * (n1 + n2) / 2 <= min(n1, n2)`.
///
/// # Errors
///
/// Propagates any error from `SvmParameter::validate`, or returns
/// `SvmError::InvalidParameter("specified nu is infeasible")` when some pair
/// of classes violates the inequality above.
pub fn check_parameter(
    problem: &SvmProblem,
    param: &SvmParameter,
) -> Result<(), crate::error::SvmError> {
    use crate::error::SvmError;

    param.validate()?;

    // Only nu-SVC has a data-dependent constraint.
    if param.svm_type != SvmType::NuSvc {
        return Ok(());
    }

    // Class sizes in order of first appearance. Labels are converted with
    // `as i32`, so fractional labels are truncated toward zero.
    let mut tally: Vec<(i32, usize)> = Vec::new();
    for class in problem.labels.iter().map(|&y| y as i32) {
        match tally.iter().position(|&(seen, _)| seen == class) {
            Some(slot) => tally[slot].1 += 1,
            None => tally.push((class, 1)),
        }
    }

    // Examine each unordered pair of classes once.
    let infeasible = tally.iter().enumerate().any(|(i, &(_, first))| {
        tally[i + 1..].iter().any(|&(_, second)| {
            param.nu * (first + second) as f64 / 2.0 > first.min(second) as f64
        })
    });
    if infeasible {
        return Err(SvmError::InvalidParameter(
            "specified nu is infeasible".into(),
        ));
    }

    Ok(())
}
/// Container for an SVM model: the parameters it was built with plus
/// per-class and per-support-vector tables.
///
/// NOTE(review): nothing in the code visible here reads or writes these
/// fields. The per-field notes below are assumptions based on the field names
/// and on LIBSVM's `svm_model` — confirm them against the training and
/// prediction code.
#[derive(Clone, Debug, PartialEq)]
pub struct SvmModel {
    /// Parameter set stored with the model.
    pub param: SvmParameter,
    /// Presumably the number of classes (TODO confirm the value used for
    /// regression and one-class models).
    pub nr_class: usize,
    /// Presumably the support vectors, one node list each.
    pub sv: Vec<Vec<SvmNode>>,
    /// Presumably the support-vector coefficients of the decision functions
    /// (TODO confirm the row/column layout).
    pub sv_coef: Vec<Vec<f64>>,
    /// Presumably the bias term of each decision function.
    pub rho: Vec<f64>,
    /// Presumably the first parameter set for probability estimates; likely
    /// empty when probability estimates were not requested — confirm.
    pub prob_a: Vec<f64>,
    /// Presumably the second parameter set for probability estimates —
    /// confirm.
    pub prob_b: Vec<f64>,
    /// Presumably density marks used for one-class probability output —
    /// confirm.
    pub prob_density_marks: Vec<f64>,
    /// Presumably the position of each support vector in the training set
    /// (TODO confirm whether 0- or 1-based).
    pub sv_indices: Vec<usize>,
    /// Presumably the integer label of each class.
    pub label: Vec<i32>,
    /// Presumably the number of support vectors per class.
    pub n_sv: Vec<usize>,
}
#[cfg(test)]
mod tests {
    //! Unit tests for `SvmParameter::validate` and `check_parameter`.

    use super::*;

    #[test]
    fn default_params_are_valid() {
        SvmParameter::default().validate().unwrap();
    }

    #[test]
    fn negative_gamma_rejected() {
        let p = SvmParameter {
            gamma: -1.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_cache_rejected() {
        let p = SvmParameter {
            cache_size: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_eps_rejected() {
        let p = SvmParameter {
            eps: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_c_rejected() {
        let p = SvmParameter {
            c: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_out_of_range_rejected() {
        let p = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.5,
            ..Default::default()
        };
        assert!(p.validate().is_err());
        let p2 = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.0,
            ..Default::default()
        };
        assert!(p2.validate().is_err());
    }

    #[test]
    fn nu_of_one_accepted() {
        // The upper bound of the nu range is inclusive.
        let p = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.0,
            ..Default::default()
        };
        p.validate().unwrap();
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let p = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            p: -0.1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let p = SvmParameter {
            kernel_type: KernelType::Polynomial,
            degree: -1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // Balanced 3 / 3 split: nu * 6 / 2 stays below 3 for both values.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let ok_param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        check_parameter(&problem, &ok_param).unwrap();
        let borderline = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.9,
            ..Default::default()
        };
        check_parameter(&problem, &borderline).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // 5 / 1 split: nu * 6 / 2 = 1.5 exceeds the smaller class size of 1.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        let err = check_parameter(&problem, &param).unwrap_err();
        assert!(err.to_string().contains("infeasible"));
    }

    #[test]
    fn feasibility_skipped_for_other_svm_types() {
        // The same skewed data set is accepted when the type is not nu-SVC.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
            instances: vec![vec![]; 6],
        };
        check_parameter(&problem, &SvmParameter::default()).unwrap();
    }
}