use crate::kernel::k_function;
use crate::probability::{multiclass_probability, predict_one_class_probability, sigmoid_predict};
use crate::types::{SvmModel, SvmNode, SvmType};
/// Computes the decision value(s) of `model` for the input `x` and returns the prediction.
///
/// * For `OneClass`, `EpsilonSvr` and `NuSvr` models a single decision value is written to
///   `dec_values[0]`. One-class models return `1.0` when that value is strictly positive and
///   `-1.0` otherwise; the regression types return the decision value itself.
/// * For `CSvc` and `NuSvc` models one decision value is written per class pair `(i, j)`
///   with `i < j` (pairs enumerated with `i` as the outer index), and each pair casts a vote
///   for `i` when its value is strictly positive, otherwise for `j`. The label of the class
///   with the most votes is returned. When several classes are tied, the one with the
///   lowest index wins, which is the tie-breaking rule of LIBSVM's `svm_predict_values`.
///
/// # Panics
///
/// Panics if `dec_values` has fewer slots than the number of decision values written
/// (1, or `nr_class * (nr_class - 1) / 2` for classification), or if the model's
/// `sv_coef`, `rho`, `n_sv` or `label` arrays are shorter than the indices used here.
pub fn predict_values(model: &SvmModel, x: &[SvmNode], dec_values: &mut [f64]) -> f64 {
    match model.param.svm_type {
        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
            let sv_coef = &model.sv_coef[0];
            let mut sum = 0.0;
            for (i, sv) in model.sv.iter().enumerate() {
                sum += sv_coef[i] * k_function(x, sv, &model.param);
            }
            sum -= model.rho[0];
            dec_values[0] = sum;

            if model.param.svm_type == SvmType::OneClass {
                if sum > 0.0 {
                    1.0
                } else {
                    -1.0
                }
            } else {
                sum
            }
        }
        SvmType::CSvc | SvmType::NuSvc => {
            let nr_class = model.nr_class;

            // Evaluate the kernel once per support vector; every class pair reuses it.
            let kvalue: Vec<f64> = model
                .sv
                .iter()
                .map(|sv| k_function(x, sv, &model.param))
                .collect();

            // start[c] is the offset of class c's first support vector within `model.sv`.
            let mut start = vec![0usize; nr_class];
            for i in 1..nr_class {
                start[i] = start[i - 1] + model.n_sv[i - 1];
            }

            let mut vote = vec![0usize; nr_class];
            let mut p = 0;
            for i in 0..nr_class {
                for j in (i + 1)..nr_class {
                    let si = start[i];
                    let sj = start[j];
                    let ci = model.n_sv[i];
                    let cj = model.n_sv[j];

                    // Class i's support vectors take their coefficients from row j - 1,
                    // class j's support vectors from row i.
                    let coef1 = &model.sv_coef[j - 1];
                    let coef2 = &model.sv_coef[i];

                    // Accumulate sequentially into one sum so the floating-point
                    // rounding order stays fixed.
                    let mut sum = 0.0;
                    for k in 0..ci {
                        sum += coef1[si + k] * kvalue[si + k];
                    }
                    for k in 0..cj {
                        sum += coef2[sj + k] * kvalue[sj + k];
                    }
                    sum -= model.rho[p];
                    dec_values[p] = sum;

                    if sum > 0.0 {
                        vote[i] += 1;
                    } else {
                        vote[j] += 1;
                    }
                    p += 1;
                }
            }

            // Strict `>` keeps the first (lowest-index) class on ties.
            // `Iterator::max_by_key` would return the last maximal element instead.
            let mut vote_max_idx = 0;
            for (i, &votes) in vote.iter().enumerate().skip(1) {
                if votes > vote[vote_max_idx] {
                    vote_max_idx = i;
                }
            }

            model.label[vote_max_idx] as f64
        }
    }
}
/// Predicts the label (classification / one-class) or value (regression) for `x`.
///
/// Allocates a scratch buffer of the size `predict_values` writes into — one slot for
/// one-class and regression models, one slot per class pair for classification — and
/// returns the result of `predict_values`, discarding the decision values.
pub fn predict(model: &SvmModel, x: &[SvmNode]) -> f64 {
    let is_classifier = matches!(model.param.svm_type, SvmType::CSvc | SvmType::NuSvc);

    let slot_count = if is_classifier {
        // One decision value per unordered pair of classes.
        model.nr_class * (model.nr_class - 1) / 2
    } else {
        1
    };

    let mut scratch = vec![0.0; slot_count];
    predict_values(model, x, &mut scratch)
}
/// Predicts a label together with probability estimates, when the model supports them.
///
/// * `CSvc` / `NuSvc` models whose `prob_a` and `prob_b` are both non-empty: each pairwise
///   decision value is passed through `sigmoid_predict`, clamped to `[1e-7, 1 - 1e-7]`,
///   and the pairwise table is combined into per-class estimates (directly for two
///   classes, through `multiclass_probability` otherwise). Returns the label of the
///   first class holding the largest estimate, plus the estimates.
/// * `OneClass` models with non-empty `prob_density_marks`: returns the result of
///   `predict_values` and the vector `[p, 1 - p]`, where `p` comes from
///   `predict_one_class_probability`.
/// * Anything else yields `None`.
// Index loops are kept because each pair index addresses several arrays at once.
#[allow(clippy::needless_range_loop)]
pub fn predict_probability(model: &SvmModel, x: &[SvmNode]) -> Option<(f64, Vec<f64>)> {
    let is_classifier = matches!(model.param.svm_type, SvmType::CSvc | SvmType::NuSvc);

    if is_classifier && !model.prob_a.is_empty() && !model.prob_b.is_empty() {
        // Lower clamp for pairwise probabilities; the upper clamp is its complement.
        const MIN_PROB: f64 = 1e-7;

        let class_count = model.nr_class;
        let mut decisions = vec![0.0; class_count * (class_count - 1) / 2];
        predict_values(model, x, &mut decisions);

        // pairwise[a][b] holds the clamped sigmoid output for pair (a, b);
        // pairwise[b][a] holds its complement.
        let mut pairwise = vec![vec![0.0; class_count]; class_count];
        let mut pair = 0;
        for a in 0..class_count {
            for b in (a + 1)..class_count {
                let raw = sigmoid_predict(decisions[pair], model.prob_a[pair], model.prob_b[pair]);
                let clamped = raw.max(MIN_PROB).min(1.0 - MIN_PROB);
                pairwise[a][b] = clamped;
                pairwise[b][a] = 1.0 - clamped;
                pair += 1;
            }
        }

        let mut prob_estimates = vec![0.0; class_count];
        if class_count == 2 {
            prob_estimates[0] = pairwise[0][1];
            prob_estimates[1] = pairwise[1][0];
        } else {
            multiclass_probability(class_count, &pairwise, &mut prob_estimates);
        }

        // Strict comparison: the earliest class wins when estimates are equal.
        let winner = (1..class_count).fold(0, |best, candidate| {
            if prob_estimates[candidate] > prob_estimates[best] {
                candidate
            } else {
                best
            }
        });

        return Some((model.label[winner] as f64, prob_estimates));
    }

    let is_one_class = matches!(model.param.svm_type, SvmType::OneClass);
    if is_one_class && !model.prob_density_marks.is_empty() {
        let mut decision = [0.0];
        let label = predict_values(model, x, &mut decision);
        let inlier = predict_one_class_probability(&model.prob_density_marks, decision[0]);
        return Some((label, vec![inlier, 1.0 - inlier]));
    }

    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_model;
    use crate::io::load_problem;
    use std::path::PathBuf;

    /// Fixture directory, resolved two levels above the crate manifest.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// The `heart_scale.model` fixture must label more than 85% of `heart_scale` correctly.
    #[test]
    fn predict_heart_scale() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }

        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.85,
            "accuracy {:.2}% too low (expected >85%)",
            accuracy * 100.0
        );
    }

    /// `predict_values` must fill in a non-zero decision value and agree with `predict`.
    #[test]
    fn predict_values_binary() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let mut dec_values = vec![0.0; 1];
        let label = predict_values(&model, &problem.instances[0], &mut dec_values);

        assert!(dec_values[0].abs() > 1e-10);
        assert_eq!(label, predict(&model, &problem.instances[0]));
    }

    /// Binary probability output: two estimates in `[0, 1]` that sum to one.
    #[test]
    fn predict_probability_binary() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert!(!model.prob_a.is_empty());

        for instance in &problem.instances {
            let result = predict_probability(&model, instance);
            assert!(result.is_some(), "should return probability");
            let (label, probs) = result.unwrap();

            assert!(label == 1.0 || label == -1.0);
            assert_eq!(probs.len(), 2);

            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
            for &p in &probs {
                assert!((0.0..=1.0).contains(&p), "prob {} out of [0,1]", p);
            }
        }
    }

    /// Three-class probability output: three estimates that sum to one.
    #[test]
    fn predict_probability_multiclass() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 0.25,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert_eq!(model.nr_class, 3);
        assert_eq!(model.prob_a.len(), 3);

        for instance in problem.instances.iter().take(10) {
            let result = predict_probability(&model, instance);
            assert!(result.is_some());
            let (_label, probs) = result.unwrap();

            assert_eq!(probs.len(), 3);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
        }
    }

    /// Compares every prediction with the output of the vendored C `svm-predict` binary.
    ///
    /// The comparison is skipped (the test passes) when the binary is missing or fails to run.
    #[test]
    fn predict_matches_c_svm_predict() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let c_predict = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("vendor")
            .join("libsvm")
            .join("svm-predict");
        if !c_predict.exists() {
            // Reference binary is not available; nothing to compare against.
            return;
        }

        let problem_path = data_dir().join("heart_scale");
        let model_path = data_dir().join("heart_scale.model");
        let output_path = data_dir().join("heart_scale.predict_test");

        let run = std::process::Command::new(&c_predict)
            .arg(&problem_path)
            .arg(&model_path)
            .arg(&output_path)
            .output();
        let succeeded = run.map(|out| out.status.success()).unwrap_or(false);

        // Read and delete the scratch file before asserting, so that a failing
        // comparison (or a failed run) never leaves it behind in the data directory.
        let raw = std::fs::read_to_string(&output_path);
        let _ = std::fs::remove_file(&output_path);
        if !succeeded {
            return;
        }

        let c_preds: Vec<f64> = raw
            .unwrap()
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .map(|line| line.parse().unwrap())
            .collect();
        assert_eq!(c_preds.len(), problem.labels.len());

        let mut mismatches = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let rust_pred = predict(&model, instance);
            if rust_pred != c_preds[i] {
                mismatches += 1;
            }
        }
        assert_eq!(
            mismatches, 0,
            "{} predictions differ between Rust and C",
            mismatches
        );
    }
}