use crate::predict::predict;
use crate::train::svm_train;
use crate::types::{SvmParameter, SvmProblem, SvmType};
/// Advances a 64-bit linear congruential generator in place and returns
/// the high bits of the new state as a pseudo-random `usize`.
///
/// Deterministic by construction: the same starting `state` always yields
/// the same sequence, which keeps cross-validation folds reproducible.
fn rng_next(state: &mut u64) -> usize {
    // Knuth/PCG-style LCG constants for a full-period 2^64 generator.
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    let next = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = next;
    // Discard the weak low bits; the top 31 bits have better statistics.
    (next >> 33) as usize
}
/// Runs `nr_fold`-fold cross validation and returns, for each training
/// instance, the prediction made by a model trained on the other folds.
///
/// Mirrors libsvm's `svm_cross_validation`: for C-SVC / nu-SVC the folds
/// are stratified so every fold keeps roughly the original class
/// proportions; for all other SVM types the instances are shuffled
/// uniformly into equal-size folds.
///
/// If `nr_fold` exceeds the number of instances it is clamped to it
/// (leave-one-out), with a warning on stderr.
///
/// # Panics
/// Panics if `nr_fold` is zero.
pub fn svm_cross_validation(
    prob: &SvmProblem,
    param: &SvmParameter,
    mut nr_fold: usize,
) -> Vec<f64> {
    let l = prob.labels.len();
    // Empty problem: nothing to predict. Previously this fell through to a
    // divide-by-zero panic when nr_fold got clamped to 0.
    if l == 0 {
        return Vec::new();
    }
    assert!(nr_fold >= 1, "nr_fold must be at least 1");
    if nr_fold > l {
        eprintln!(
            "WARNING: # folds ({}) > # data ({}). Will use # folds = # data instead \
             (i.e., leave-one-out cross validation)",
            nr_fold, l
        );
        nr_fold = l;
    }
    // Fixed-seed LCG keeps fold assignment deterministic across runs.
    let mut rng: u64 = 1;
    let mut perm: Vec<usize> = (0..l).collect();
    let mut fold_start = vec![0usize; nr_fold + 1];
    if matches!(param.svm_type, SvmType::CSvc | SvmType::NuSvc) && nr_fold < l {
        // Stratified folds: shuffle within each class block, then deal each
        // class's instances across the folds so class ratios are preserved.
        let (label_list, start, count, group_perm) = group_classes(&prob.labels);
        let nr_class = label_list.len();
        let mut index = group_perm;
        // Fisher-Yates shuffle inside each class block of `index`.
        for c in 0..nr_class {
            let s = start[c];
            let n = count[c];
            for i in 0..n {
                let j = i + rng_next(&mut rng) % (n - i);
                index.swap(s + i, s + j);
            }
        }
        // fold_count[i]: total instances assigned to fold i over all classes.
        // The ((i+1)*count/nr_fold - i*count/nr_fold) form distributes each
        // class's remainder instances as evenly as possible.
        let mut fold_count = vec![0usize; nr_fold];
        for i in 0..nr_fold {
            for c in 0..nr_class {
                fold_count[i] += ((i + 1) * count[c]) / nr_fold - (i * count[c]) / nr_fold;
            }
        }
        for i in 1..=nr_fold {
            fold_start[i] = fold_start[i - 1] + fold_count[i - 1];
        }
        // Deal each class's slice of `index` into the folds. Unlike the C
        // original (which mutates fold_start while dealing and must rebuild
        // it afterwards), the separate `offset` vector leaves fold_start
        // intact, so the former duplicate recomputation was dead code and
        // has been removed.
        let mut offset = vec![0usize; nr_fold];
        for c in 0..nr_class {
            for i in 0..nr_fold {
                let begin = start[c] + (i * count[c]) / nr_fold;
                let end = start[c] + ((i + 1) * count[c]) / nr_fold;
                for j in begin..end {
                    perm[fold_start[i] + offset[i]] = index[j];
                    offset[i] += 1;
                }
            }
        }
    } else {
        // Unstratified: plain Fisher-Yates shuffle, equal-size folds.
        for i in 0..l {
            let j = i + rng_next(&mut rng) % (l - i);
            perm.swap(i, j);
        }
        for i in 0..=nr_fold {
            fold_start[i] = i * l / nr_fold;
        }
    }
    // Train on everything outside fold i, predict the instances inside it.
    let mut target = vec![0.0; l];
    for i in 0..nr_fold {
        let begin = fold_start[i];
        let end = fold_start[i + 1];
        let sub_l = l - (end - begin);
        let mut sub_labels = Vec::with_capacity(sub_l);
        let mut sub_instances = Vec::with_capacity(sub_l);
        // Everything before `begin` and after `end` forms the training split.
        for &p in perm[..begin].iter().chain(perm[end..].iter()) {
            sub_labels.push(prob.labels[p]);
            sub_instances.push(prob.instances[p].clone());
        }
        let subprob = SvmProblem {
            labels: sub_labels,
            instances: sub_instances,
        };
        let submodel = svm_train(&subprob, param);
        for j in begin..end {
            target[perm[j]] = predict(&submodel, &prob.instances[perm[j]]);
        }
    }
    target
}
/// Groups training instances by class label, libsvm-style.
///
/// Returns `(label_list, start, count, perm)` where:
/// - `label_list`: distinct labels (truncated to `i32`) in order of first
///   appearance — except that a two-class set seen as {-1, +1} is reordered
///   so +1 becomes class 0, matching libsvm's convention;
/// - `count[c]`: number of instances of class `c`;
/// - `start[c]`: offset of class `c`'s block inside `perm`;
/// - `perm`: original instance indices, grouped contiguously by class.
fn group_classes(labels: &[f64]) -> (Vec<i32>, Vec<usize>, Vec<usize>, Vec<usize>) {
    let mut label_list: Vec<i32> = Vec::new();
    let mut count: Vec<usize> = Vec::new();
    // Class id (position in label_list) assigned to each instance.
    let mut data_label: Vec<usize> = Vec::with_capacity(labels.len());
    for &raw in labels {
        let lab = raw as i32;
        match label_list.iter().position(|&known| known == lab) {
            Some(cls) => {
                count[cls] += 1;
                data_label.push(cls);
            }
            None => {
                data_label.push(label_list.len());
                label_list.push(lab);
                count.push(1);
            }
        }
    }
    // libsvm convention: for binary {-1, +1} data where -1 was encountered
    // first, swap so that +1 is class 0.
    if label_list.as_slice() == [-1, 1] {
        label_list.swap(0, 1);
        count.swap(0, 1);
        for cls in data_label.iter_mut() {
            // Class ids are 0 or 1 here, so this flips them.
            *cls = 1 - *cls;
        }
    }
    let nr_class = label_list.len();
    let mut start = vec![0usize; nr_class];
    for c in 1..nr_class {
        start[c] = start[c - 1] + count[c - 1];
    }
    // Scatter each instance index into its class block; `cursor` tracks the
    // next free slot per class.
    let mut perm = vec![0usize; labels.len()];
    let mut cursor = start.clone();
    for (i, &cls) in data_label.iter().enumerate() {
        perm[cursor[cls]] = i;
        cursor[cls] += 1;
    }
    (label_list, start, count, perm)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_problem;
    use crate::types::{KernelType, SvmNode};
    use std::path::PathBuf;

    /// Path to the shared test-data directory at the workspace root.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// 5-fold CV on a tiny two-class toy problem: one prediction per
    /// instance, and every prediction must be one of the two labels.
    /// NOTE: `&param` at all three call sites below had been mangled into
    /// the mojibake `¶m` (HTML entity `&para;`), which does not compile.
    #[test]
    fn cross_validation_basic() {
        let labels = vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
        let instances: Vec<Vec<SvmNode>> = (0..10)
            .map(|i| vec![SvmNode { index: 1, value: i as f64 * 0.1 }])
            .collect();
        let prob = SvmProblem { labels, instances };
        let param = SvmParameter {
            kernel_type: KernelType::Linear,
            ..Default::default()
        };
        let target = svm_cross_validation(&prob, &param, 5);
        assert_eq!(target.len(), 10);
        for &pred in &target {
            assert!(pred == 1.0 || pred == -1.0);
        }
    }

    /// 5-fold CV accuracy on heart_scale should comfortably beat chance.
    #[test]
    fn cross_validation_classification() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0, // 1 / num_features for heart_scale
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };
        let target = svm_cross_validation(&problem, &param, 5);
        assert_eq!(target.len(), problem.labels.len());
        let correct = target
            .iter()
            .zip(problem.labels.iter())
            .filter(|(&pred, &label)| pred == label)
            .count();
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.70,
            "5-fold CV accuracy {:.1}% too low (expected >70%)",
            accuracy * 100.0
        );
    }

    /// Epsilon-SVR regression on housing_scale: MSE must be finite and
    /// within a loose sanity bound.
    #[test]
    fn cross_validation_regression() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            p: 0.1,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };
        let target = svm_cross_validation(&problem, &param, 5);
        assert_eq!(target.len(), problem.labels.len());
        let mse: f64 = target
            .iter()
            .zip(problem.labels.iter())
            .map(|(&pred, &label)| (pred - label).powi(2))
            .sum::<f64>()
            / problem.labels.len() as f64;
        assert!(mse.is_finite(), "MSE is not finite");
        assert!(mse < 500.0, "MSE {} too high", mse);
    }
}