use crate::predict::{predict, predict_probability};
use crate::train::svm_train;
use crate::types::{SvmModel, SvmNode, SvmParameter, SvmProblem, SvmType};
use crate::util::{group_classes, shuffle_range};
/// Produces the cross-validation prediction for a single instance.
///
/// When probability estimates are enabled on a classification model
/// (C-SVC / nu-SVC), the probability-based label is preferred; in every
/// other case — or when probability prediction yields nothing — the
/// plain decision-function prediction is used.
fn predict_cv_target(model: &SvmModel, param: &SvmParameter, x: &[SvmNode]) -> f64 {
    let is_classification = matches!(param.svm_type, SvmType::CSvc | SvmType::NuSvc);
    if !(param.probability && is_classification) {
        return predict(model, x);
    }
    match predict_probability(model, x) {
        Some((label, _probs)) => label,
        None => predict(model, x),
    }
}
/// Converts per-fold sizes into cumulative start offsets.
///
/// The result carries one extra trailing entry equal to the total size,
/// so fold `i` occupies `result[i]..result[i + 1]`.
fn fold_starts(fold_count: &[usize]) -> Vec<usize> {
    let mut starts = Vec::with_capacity(fold_count.len() + 1);
    let mut running = 0usize;
    starts.push(running);
    for &size in fold_count {
        running += size;
        starts.push(running);
    }
    starts
}
/// Decides whether stratified fold assignment applies.
///
/// Only classification problems (C-SVC / nu-SVC) are stratified, and
/// never in the leave-one-out case (`nr_fold >= l`).
fn is_stratified_cv(param: &SvmParameter, nr_fold: usize, l: usize) -> bool {
    if nr_fold >= l {
        // Leave-one-out: each fold holds a single instance, nothing to stratify.
        return false;
    }
    matches!(param.svm_type, SvmType::CSvc | SvmType::NuSvc)
}
/// Runs `nr_fold`-fold cross validation over `prob` and returns one
/// predicted value per instance, indexed like `prob.labels`.
///
/// For C-SVC / nu-SVC the folds are stratified: instances are shuffled
/// within each class and then dealt across folds so each fold keeps
/// roughly the per-class proportions of the full problem. Otherwise the
/// whole permutation is shuffled uniformly. `nr_fold` is clamped to the
/// number of instances (leave-one-out) when it is 0 or exceeds the data
/// size; an empty problem yields an empty result.
pub fn svm_cross_validation(
    prob: &SvmProblem,
    param: &SvmParameter,
    mut nr_fold: usize,
) -> Vec<f64> {
    let l = prob.labels.len();
    if l == 0 {
        return Vec::new();
    }
    // Clamp degenerate fold counts to leave-one-out, warning like libsvm does.
    if nr_fold == 0 {
        crate::info(
            "WARNING: # folds (0) <= 0. Will use # folds = # data instead \
             (i.e., leave-one-out cross validation)\n",
        );
        nr_fold = l;
    }
    if nr_fold > l {
        crate::info(&format!(
            "WARNING: # folds ({}) > # data ({}). Will use # folds = # data instead \
             (i.e., leave-one-out cross validation)\n",
            nr_fold, l
        ));
        nr_fold = l;
    }
    // Fixed seed keeps fold assignment deterministic across runs.
    let mut rng: u64 = 1;
    let mut perm: Vec<usize> = (0..l).collect();
    let (_fold_count, fold_start) = if is_stratified_cv(param, nr_fold, l) {
        // Stratified split: shuffle within each class, then deal each class's
        // shuffled slice across the folds so class proportions are preserved.
        let grouped = group_classes(&prob.labels);
        let nr_class = grouped.label.len();
        let start = &grouped.start;
        let count = &grouped.count;
        let mut index = grouped.perm;
        for c in 0..nr_class {
            let s = start[c];
            let n = count[c];
            shuffle_range(&mut index, s, n, &mut rng);
        }
        // Fold i receives floor((i+1)*cnt/nr_fold) - floor(i*cnt/nr_fold)
        // instances of each class; summed over all folds this is exactly cnt,
        // so the fold sizes add up to l.
        let fold_count: Vec<usize> = (0..nr_fold)
            .map(|i| {
                count
                    .iter()
                    .map(|&cnt| ((i + 1) * cnt) / nr_fold - (i * cnt) / nr_fold)
                    .sum()
            })
            .collect();
        let fold_start = fold_starts(&fold_count);
        // offset[i] tracks how many slots of fold i have been filled so far.
        let mut offset = vec![0usize; nr_fold];
        for c in 0..nr_class {
            for i in 0..nr_fold {
                let begin = start[c] + (i * count[c]) / nr_fold;
                let end = start[c] + ((i + 1) * count[c]) / nr_fold;
                for &idx in &index[begin..end] {
                    perm[fold_start[i] + offset[i]] = idx;
                    offset[i] += 1;
                }
            }
        }
        (fold_count, fold_start)
    } else {
        // Unstratified path (non-classification or leave-one-out): plain shuffle
        // and near-equal fold sizes.
        shuffle_range(&mut perm, 0, l, &mut rng);
        let fold_count: Vec<usize> = (0..nr_fold)
            .map(|i| (i + 1) * l / nr_fold - i * l / nr_fold)
            .collect();
        let fold_start = fold_starts(&fold_count);
        (fold_count, fold_start)
    };
    let mut target = vec![0.0; l];
    for i in 0..nr_fold {
        // Train on everything outside perm[begin..end], then predict the
        // held-out fold, writing results back in original instance order.
        let begin = fold_start[i];
        let end = fold_start[i + 1];
        let sub_l = l - (end - begin);
        let mut sub_labels = Vec::with_capacity(sub_l);
        let mut sub_instances = Vec::with_capacity(sub_l);
        for &pi in &perm[..begin] {
            sub_labels.push(prob.labels[pi]);
            sub_instances.push(prob.instances[pi].clone());
        }
        for &pi in &perm[end..l] {
            sub_labels.push(prob.labels[pi]);
            sub_instances.push(prob.instances[pi].clone());
        }
        let subprob = SvmProblem {
            labels: sub_labels,
            instances: sub_instances,
        };
        let submodel = svm_train(&subprob, param);
        for j in begin..end {
            target[perm[j]] = predict_cv_target(&submodel, param, &prob.instances[perm[j]]);
        }
    }
    target
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_problem;
    use crate::types::{KernelType, SvmModel, SvmNode};
    use std::path::PathBuf;

    /// Path to the shared `data/` directory two levels above this crate.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// 5-fold CV on a tiny synthetic problem returns one label per instance.
    /// (Fixes mojibake `¶m` -> `&param` in the call below.)
    #[test]
    fn cross_validation_basic() {
        let labels = vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
        let instances: Vec<Vec<SvmNode>> = (0..10)
            .map(|i| {
                vec![SvmNode {
                    index: 1,
                    value: i as f64 * 0.1,
                }]
            })
            .collect();
        let prob = SvmProblem { labels, instances };
        let param = SvmParameter {
            kernel_type: KernelType::Linear,
            ..Default::default()
        };
        let target = svm_cross_validation(&prob, &param, 5);
        assert_eq!(target.len(), 10);
        for &pred in &target {
            assert!(pred == 1.0 || pred == -1.0);
        }
    }

    /// C-SVC 5-fold CV on heart_scale should reach reasonable accuracy.
    #[test]
    fn cross_validation_classification() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };
        let target = svm_cross_validation(&problem, &param, 5);
        assert_eq!(target.len(), problem.labels.len());
        let correct = target
            .iter()
            .zip(problem.labels.iter())
            .filter(|(&pred, &label)| pred == label)
            .count();
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.70,
            "5-fold CV accuracy {:.1}% too low (expected >70%)",
            accuracy * 100.0
        );
    }

    /// Epsilon-SVR 5-fold CV on housing_scale should produce a finite,
    /// bounded mean squared error.
    #[test]
    fn cross_validation_regression() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            p: 0.1,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };
        let target = svm_cross_validation(&problem, &param, 5);
        assert_eq!(target.len(), problem.labels.len());
        let mse: f64 = target
            .iter()
            .zip(problem.labels.iter())
            .map(|(&pred, &label)| (pred - label).powi(2))
            .sum::<f64>()
            / problem.labels.len() as f64;
        assert!(mse.is_finite(), "MSE is not finite");
        assert!(mse < 500.0, "MSE {} too high", mse);
    }

    /// nr_fold == 0 is clamped to leave-one-out and still predicts every
    /// instance.
    #[test]
    fn cross_validation_zero_folds_clamps_to_leave_one_out() {
        let labels = vec![1.0, -1.0, 1.0, -1.0];
        let instances: Vec<Vec<SvmNode>> = vec![
            vec![SvmNode {
                index: 1,
                value: 1.0,
            }],
            vec![SvmNode {
                index: 1,
                value: -1.0,
            }],
            vec![SvmNode {
                index: 1,
                value: 0.8,
            }],
            vec![SvmNode {
                index: 1,
                value: -0.9,
            }],
        ];
        let prob = SvmProblem { labels, instances };
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Linear,
            c: 1.0,
            eps: 0.001,
            ..Default::default()
        };
        let target = svm_cross_validation(&prob, &param, 0);
        assert_eq!(target.len(), prob.labels.len());
        for &pred in &target {
            assert!(pred == 1.0 || pred == -1.0);
        }
    }

    /// An empty problem short-circuits to an empty target vector.
    #[test]
    fn cross_validation_empty_problem_returns_empty() {
        let prob = SvmProblem {
            labels: Vec::new(),
            instances: Vec::new(),
        };
        let target = svm_cross_validation(&prob, &SvmParameter::default(), 5);
        assert!(target.is_empty());
    }

    /// With probability estimates enabled, the CV prediction must come from
    /// `predict_probability`, even when it disagrees with the voting label.
    #[test]
    fn predict_cv_target_uses_probability_label_for_classification() {
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Linear,
            probability: true,
            ..Default::default()
        };
        // Hand-built 2-class model whose sigmoid parameters flip the label
        // relative to the raw decision-value vote.
        let model = SvmModel {
            param: param.clone(),
            nr_class: 2,
            sv: vec![
                vec![SvmNode {
                    index: 1,
                    value: 1.0,
                }],
                vec![SvmNode {
                    index: 1,
                    value: 1.0,
                }],
            ],
            sv_coef: vec![vec![1.0, -1.0]],
            rho: vec![-1.0],
            prob_a: vec![1.0],
            prob_b: vec![0.0],
            prob_density_marks: Vec::new(),
            sv_indices: vec![1, 2],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };
        let x = vec![SvmNode {
            index: 1,
            value: 1.0,
        }];
        let vote_label = predict(&model, &x);
        let (prob_label, _) = predict_probability(&model, &x).unwrap();
        assert_eq!(vote_label, 1.0);
        assert_eq!(prob_label, -1.0);
        assert_eq!(predict_cv_target(&model, &param, &x), prob_label);
    }
}