/// Selects which SVM formulation a parameter set or model refers to.
///
/// The enum is `#[repr(i32)]` with explicitly pinned discriminants, so
/// `variant as i32` yields exactly the integer written next to each variant.
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SvmType {
    /// Discriminant `0`.
    CSvc = 0,
    /// Discriminant `1`.
    NuSvc = 1,
    /// Discriminant `2`.
    OneClass = 2,
    /// Discriminant `3`.
    EpsilonSvr = 3,
    /// Discriminant `4`.
    NuSvr = 4,
}
/// Selects the kernel function a parameter set refers to.
///
/// The enum is `#[repr(i32)]` with explicitly pinned discriminants, so
/// `variant as i32` yields exactly the integer written next to each variant.
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum KernelType {
    /// Discriminant `0`.
    Linear = 0,
    /// Discriminant `1`.
    Polynomial = 1,
    /// Discriminant `2`.
    Rbf = 2,
    /// Discriminant `3`.
    Sigmoid = 3,
    /// Discriminant `4`.
    Precomputed = 4,
}
/// One `index:value` entry of an instance.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SvmNode {
    /// Position of the entry within the instance.
    pub index: i32,
    /// Value stored at `index`.
    pub value: f64,
}
/// A training problem: one label per instance.
///
/// `check_parameter` rejects problems whose `labels` and `instances` differ
/// in length, as well as empty problems.
#[derive(Clone, Debug, PartialEq)]
pub struct SvmProblem {
    /// Target value of each instance, in the same order as `instances`.
    pub labels: Vec<f64>,
    /// The instances, each given as a list of nodes.
    pub instances: Vec<Vec<SvmNode>>,
}
/// Training parameters.
///
/// `SvmParameter::validate` enforces the per-field constraints noted below;
/// which of them apply depends on `svm_type` and `kernel_type`.
#[derive(Clone, Debug, PartialEq)]
pub struct SvmParameter {
    /// SVM formulation to use.
    pub svm_type: SvmType,
    /// Kernel function to use.
    pub kernel_type: KernelType,
    /// Must be `>= 0` when the kernel is `Polynomial`.
    pub degree: i32,
    /// Must be `>= 0` when the kernel is `Polynomial`, `Rbf` or `Sigmoid`.
    pub gamma: f64,
    /// Not range-checked by `validate`.
    pub coef0: f64,
    /// Must be `> 0`. NOTE(review): presumably megabytes — confirm.
    pub cache_size: f64,
    /// Must be `> 0`.
    pub eps: f64,
    /// Must be `> 0` for `CSvc`, `EpsilonSvr` and `NuSvr`.
    pub c: f64,
    /// NOTE(review): presumably `(class label, weight)` pairs — confirm
    /// against the trainer.
    pub weight: Vec<(i32, f64)>,
    /// Must lie in `(0, 1]` for `NuSvc`, `OneClass` and `NuSvr`.
    pub nu: f64,
    /// Must be `>= 0` for `EpsilonSvr`.
    pub p: f64,
    /// NOTE(review): presumably toggles the shrinking heuristic — confirm.
    pub shrinking: bool,
    /// NOTE(review): presumably requests probability estimates — confirm.
    pub probability: bool,
}
impl Default for SvmParameter {
fn default() -> Self {
Self {
svm_type: SvmType::CSvc,
kernel_type: KernelType::Rbf,
degree: 3,
gamma: 0.0, coef0: 0.0,
cache_size: 100.0,
eps: 0.001,
c: 1.0,
weight: Vec::new(),
nu: 0.5,
p: 0.1,
shrinking: true,
probability: false,
}
}
}
impl SvmParameter {
    /// Checks that the parameters lie in their allowed ranges.
    ///
    /// The following constraints are enforced, in this order:
    ///
    /// * `gamma >= 0` when the kernel is `Polynomial`, `Rbf` or `Sigmoid`
    /// * `degree >= 0` when the kernel is `Polynomial`
    /// * `cache_size > 0`
    /// * `eps > 0`
    /// * `c > 0` for `CSvc`, `EpsilonSvr` and `NuSvr`
    /// * `0 < nu <= 1` for `NuSvc`, `OneClass` and `NuSvr`
    /// * `p >= 0` for `EpsilonSvr`
    ///
    /// A NaN in any floating-point parameter that is subject to one of these
    /// checks is rejected too. Parameters that are not checked for the
    /// selected SVM/kernel type are ignored, NaN or not.
    ///
    /// # Errors
    ///
    /// Returns `SvmError::InvalidParameter` describing the first violated
    /// constraint.
    pub fn validate(&self) -> Result<(), crate::error::SvmError> {
        use crate::error::SvmError;

        // Every ordered comparison against NaN evaluates to `false`, so a NaN
        // would slip through the range checks below unless rejected here.
        let reject_nan = |name: &str, value: f64| -> Result<(), SvmError> {
            if value.is_nan() {
                Err(SvmError::InvalidParameter(format!("{} is NaN", name)))
            } else {
                Ok(())
            }
        };
        let invalid = |message: &str| -> Result<(), SvmError> {
            Err(SvmError::InvalidParameter(message.into()))
        };

        let gamma_checked = matches!(
            self.kernel_type,
            KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
        );
        if gamma_checked {
            reject_nan("gamma", self.gamma)?;
            if self.gamma < 0.0 {
                return invalid("gamma < 0");
            }
        }

        if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
            return invalid("degree of polynomial kernel < 0");
        }

        reject_nan("cache_size", self.cache_size)?;
        if self.cache_size <= 0.0 {
            return invalid("cache_size <= 0");
        }

        reject_nan("eps", self.eps)?;
        if self.eps <= 0.0 {
            return invalid("eps <= 0");
        }

        let c_checked = matches!(
            self.svm_type,
            SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
        );
        if c_checked {
            reject_nan("C", self.c)?;
            if self.c <= 0.0 {
                return invalid("C <= 0");
            }
        }

        let nu_checked = matches!(
            self.svm_type,
            SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
        );
        if nu_checked {
            reject_nan("nu", self.nu)?;
            if self.nu <= 0.0 || self.nu > 1.0 {
                return invalid("nu <= 0 or nu > 1");
            }
        }

        if self.svm_type == SvmType::EpsilonSvr {
            reject_nan("p", self.p)?;
            if self.p < 0.0 {
                return invalid("p < 0");
            }
        }

        Ok(())
    }
}
/// Validates `param` on its own and then against `problem`.
///
/// Checks run in this order:
///
/// 1. `param.validate()`.
/// 2. `problem.labels` and `problem.instances` have the same length.
/// 3. The problem is not empty.
/// 4. For a `Precomputed` kernel, every instance starts with a node whose
///    `index` is `0` and whose `value` is a whole number in
///    `[1, number of instances]`.
/// 5. For `NuSvc`, `nu` is feasible for every pair of classes, i.e.
///    `nu * (n1 + n2) / 2 <= min(n1, n2)` where `n1`, `n2` are the class
///    sizes. Labels are truncated to `i32` when grouping into classes.
///
/// # Errors
///
/// Returns `SvmError::InvalidParameter` for the first check that fails.
pub fn check_parameter(
    problem: &SvmProblem,
    param: &SvmParameter,
) -> Result<(), crate::error::SvmError> {
    use crate::error::SvmError;

    param.validate()?;

    let label_count = problem.labels.len();
    let instance_count = problem.instances.len();
    if label_count != instance_count {
        return Err(SvmError::InvalidParameter(format!(
            "labels length ({}) does not match instance length ({})",
            label_count, instance_count
        )));
    }
    if label_count == 0 {
        return Err(SvmError::InvalidParameter(
            "problem has no instances".into(),
        ));
    }

    if let KernelType::Precomputed = param.kernel_type {
        let max_serial = instance_count as f64;
        for (position, nodes) in problem.instances.iter().enumerate() {
            // Rows are reported 1-based in error messages.
            let row_number = position + 1;
            let head = match nodes.first() {
                Some(node) => node,
                None => {
                    return Err(SvmError::InvalidParameter(format!(
                        "precomputed kernel row {} is missing 0:sample_serial_number",
                        row_number
                    )));
                }
            };
            let serial = head.value;
            // `is_finite` is tested before the range comparisons, so NaN and
            // infinities never reach them.
            let well_formed = head.index == 0
                && serial.is_finite()
                && serial >= 1.0
                && serial <= max_serial
                && serial.fract() == 0.0;
            if !well_formed {
                return Err(SvmError::InvalidParameter(format!(
                    "precomputed kernel row {} must start with 0:sample_serial_number in [1, {}]",
                    row_number, instance_count
                )));
            }
        }
    }

    if let SvmType::NuSvc = param.svm_type {
        // Tally instances per class, in order of first appearance.
        let mut tallies: Vec<(i32, usize)> = Vec::new();
        for &raw_label in &problem.labels {
            let class = raw_label as i32;
            match tallies.iter().position(|&(known, _)| known == class) {
                Some(slot) => tallies[slot].1 += 1,
                None => tallies.push((class, 1)),
            }
        }

        // Each unordered pair of classes must be able to accommodate `nu`.
        let infeasible = tallies.iter().enumerate().any(|(i, &(_, n1))| {
            tallies[i + 1..]
                .iter()
                .any(|&(_, n2)| param.nu * (n1 + n2) as f64 / 2.0 > n1.min(n2) as f64)
        });
        if infeasible {
            return Err(SvmError::InvalidParameter(
                "specified nu is infeasible".into(),
            ));
        }
    }

    Ok(())
}
/// A trained model.
#[derive(Clone, Debug, PartialEq)]
pub struct SvmModel {
    /// Parameters stored with the model; `svm_type()` reads its `svm_type`.
    pub param: SvmParameter,
    /// Value reported by `class_count()`.
    pub nr_class: usize,
    /// Support vectors; `support_vector_count()` reports this list's length.
    pub sv: Vec<Vec<SvmNode>>,
    /// NOTE(review): presumably the support-vector coefficients — confirm
    /// the layout against the trainer.
    pub sv_coef: Vec<Vec<f64>>,
    /// NOTE(review): presumably the decision-function bias terms — confirm.
    pub rho: Vec<f64>,
    /// Must be non-empty for a `CSvc`/`NuSvc`/`EpsilonSvr`/`NuSvr` model to
    /// count as having a probability model; for the SVR types its first
    /// element is what `svr_probability()` returns.
    pub prob_a: Vec<f64>,
    /// Must also be non-empty for a `CSvc`/`NuSvc` probability model.
    pub prob_b: Vec<f64>,
    /// Must be non-empty for a `OneClass` probability model.
    pub prob_density_marks: Vec<f64>,
    /// Slice returned by `support_vector_indices()`.
    pub sv_indices: Vec<usize>,
    /// Slice returned by `labels()`.
    pub label: Vec<i32>,
    /// NOTE(review): presumably the number of support vectors per class —
    /// confirm.
    pub n_sv: Vec<usize>,
}
impl SvmModel {
pub fn svm_type(&self) -> SvmType {
self.param.svm_type
}
pub fn class_count(&self) -> usize {
self.nr_class
}
pub fn labels(&self) -> &[i32] {
&self.label
}
pub fn support_vector_indices(&self) -> &[usize] {
&self.sv_indices
}
pub fn support_vector_count(&self) -> usize {
self.sv.len()
}
pub fn svr_probability(&self) -> Option<f64> {
match self.param.svm_type {
SvmType::EpsilonSvr | SvmType::NuSvr => self.prob_a.first().copied(),
_ => None,
}
}
pub fn has_probability_model(&self) -> bool {
match self.param.svm_type {
SvmType::CSvc | SvmType::NuSvc => !self.prob_a.is_empty() && !self.prob_b.is_empty(),
SvmType::EpsilonSvr | SvmType::NuSvr => !self.prob_a.is_empty(),
SvmType::OneClass => !self.prob_density_marks.is_empty(),
}
}
}
pub fn svm_get_svm_type(model: &SvmModel) -> SvmType {
model.svm_type()
}
pub fn svm_get_nr_class(model: &SvmModel) -> usize {
model.class_count()
}
pub fn svm_get_labels(model: &SvmModel) -> &[i32] {
model.labels()
}
pub fn svm_get_sv_indices(model: &SvmModel) -> &[usize] {
model.support_vector_indices()
}
pub fn svm_get_nr_sv(model: &SvmModel) -> usize {
model.support_vector_count()
}
pub fn svm_get_svr_probability(model: &SvmModel) -> Option<f64> {
model.svr_probability()
}
pub fn svm_check_probability_model(model: &SvmModel) -> bool {
model.has_probability_model()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::svm_train;
    use std::path::PathBuf;

    /// Location of the shared data directory: `<manifest dir>/../../data`.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    #[test]
    fn default_params_are_valid() {
        SvmParameter::default().validate().unwrap();
    }

    #[test]
    fn negative_gamma_rejected() {
        let p = SvmParameter {
            gamma: -1.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_cache_rejected() {
        let p = SvmParameter {
            cache_size: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_c_rejected() {
        let p = SvmParameter {
            c: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_out_of_range_rejected() {
        let too_large = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.5,
            ..Default::default()
        };
        assert!(too_large.validate().is_err());

        let zero = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.0,
            ..Default::default()
        };
        assert!(zero.validate().is_err());
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let p = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            p: -0.1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let p = SvmParameter {
            kernel_type: KernelType::Polynomial,
            degree: -1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn check_parameter_rejects_empty_problem() {
        let problem = SvmProblem {
            labels: Vec::new(),
            instances: Vec::new(),
        };
        let err = check_parameter(&problem, &SvmParameter::default()).unwrap_err();
        assert!(err.to_string().contains("problem has no instances"));
    }

    #[test]
    fn check_parameter_rejects_label_instance_length_mismatch() {
        let problem = SvmProblem {
            labels: vec![1.0],
            instances: Vec::new(),
        };
        let err = check_parameter(&problem, &SvmParameter::default()).unwrap_err();
        assert!(err.to_string().contains("does not match instance length"));
    }

    #[test]
    fn check_parameter_rejects_precomputed_rows_without_sample_serial_number() {
        let problem = SvmProblem {
            labels: vec![1.0, -1.0],
            instances: vec![
                // The first row is empty, so it has no `0:serial` node.
                vec![],
                vec![SvmNode {
                    index: 0,
                    value: 2.0,
                }],
            ],
        };
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };
        let err = check_parameter(&problem, &param).unwrap_err();
        assert!(err.to_string().contains("missing 0:sample_serial_number"));
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // Two balanced classes of three: feasible while nu * 6 / 2 <= 3.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let ok_param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        check_parameter(&problem, &ok_param).unwrap();

        let borderline = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.9,
            ..Default::default()
        };
        check_parameter(&problem, &borderline).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // Class sizes 5 and 1: 0.5 * 6 / 2 = 1.5 > min(5, 1) = 1.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        let err = check_parameter(&problem, &param).unwrap_err();
        assert!(err.to_string().contains("infeasible"));
    }

    #[test]
    fn c_api_style_model_helpers() {
        let problem = crate::io::load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            gamma: 1.0 / 13.0,
            ..Default::default()
        };
        let model = svm_train(&problem, &param);

        assert_eq!(svm_get_svm_type(&model), SvmType::CSvc);
        assert_eq!(svm_get_nr_class(&model), 2);
        assert_eq!(svm_get_nr_sv(&model), model.sv.len());
        assert_eq!(svm_get_labels(&model), model.label.as_slice());
        assert_eq!(svm_get_sv_indices(&model), model.sv_indices.as_slice());
        assert!(!svm_check_probability_model(&model));
        assert_eq!(svm_get_svr_probability(&model), None);
    }

    #[test]
    fn probability_helpers_by_svm_type() {
        let sv = vec![SvmNode {
            index: 1,
            value: 1.0,
        }];

        let csvc_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![sv.clone()],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![1.0],
            prob_b: vec![-0.5],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![1, -1],
            n_sv: vec![1, 0],
        };
        assert!(csvc_model.has_probability_model());
        assert!(svm_check_probability_model(&csvc_model));
        assert_eq!(svm_get_svr_probability(&csvc_model), None);

        let eps_svr_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::EpsilonSvr,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![sv.clone()],
            sv_coef: vec![vec![0.8]],
            rho: vec![0.0],
            prob_a: vec![0.123],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(eps_svr_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&eps_svr_model), Some(0.123));

        let one_class_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::OneClass,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![sv],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![0.1; 10],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(one_class_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&one_class_model), None);
    }
}