use libsvm_rs::cross_validation::svm_cross_validation;
use libsvm_rs::io::{
load_model_from_reader, load_problem, load_problem_from_reader, save_model_to_writer,
};
use libsvm_rs::predict::predict;
use libsvm_rs::train::svm_train;
use libsvm_rs::types::{KernelType, SvmModel, SvmNode, SvmParameter, SvmProblem, SvmType};
use proptest::prelude::*;
use std::path::Path;
/// Loads the `heart_scale` dataset from the `data` directory two levels above
/// this crate's manifest directory.
///
/// # Panics
///
/// Panics if the dataset cannot be loaded.
fn load_heart_scale() -> SvmProblem {
    // Resolved at compile time relative to the crate manifest.
    const HEART_SCALE_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../../data/heart_scale");

    let dataset = Path::new(HEART_SCALE_PATH);
    load_problem(dataset).expect("Failed to load heart_scale dataset")
}
/// Returns the labels of `prob` sorted ascending with adjacent duplicates removed.
///
/// Pairs of labels that have no defined ordering (`partial_cmp` returns `None`)
/// are treated as equal while sorting.
fn unique_labels(prob: &SvmProblem) -> Vec<f64> {
    let mut distinct: Vec<f64> = prob.labels.iter().copied().collect();
    distinct.sort_by(|lhs, rhs| match lhs.partial_cmp(rhs) {
        Some(ordering) => ordering,
        None => std::cmp::Ordering::Equal,
    });
    distinct.dedup();
    distinct
}
/// Strategy producing `f64` values drawn from the half-open range `-1000.0..1000.0`.
fn finite_value_strategy() -> impl Strategy<Value = f64> {
    const BOUND: f64 = 1000.0;
    -BOUND..BOUND
}
/// Strategy producing one sparse instance: fewer than 12 nodes whose indices lie
/// in `1..64`, sorted by index with at most one node per index.
fn sparse_instance_strategy() -> impl Strategy<Value = Vec<SvmNode>> {
    let raw_pairs = prop::collection::vec((1i32..64, finite_value_strategy()), 0..12);

    raw_pairs.prop_map(|mut entries| {
        // Stable sort followed by dedup keeps the first generated value for each index.
        entries.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
        entries.dedup_by(|later, earlier| later.0 == earlier.0);

        let mut nodes = Vec::with_capacity(entries.len());
        for (index, value) in entries {
            nodes.push(SvmNode { index, value });
        }
        nodes
    })
}
/// Strategy producing an `SvmProblem` with between 1 and 15 rows, where the
/// label vector and the instance vector always have the same length.
fn problem_strategy() -> impl Strategy<Value = SvmProblem> {
    (1usize..16).prop_flat_map(|row_count| {
        let label_column = prop::collection::vec(finite_value_strategy(), row_count);
        let instance_column = prop::collection::vec(sparse_instance_strategy(), row_count);

        (label_column, instance_column)
            .prop_map(|(labels, instances)| SvmProblem { labels, instances })
    })
}
/// Strategy producing a two-class `SvmModel` with a linear kernel and between
/// 1 and 11 support vectors.
///
/// The support vectors are split between the two classes at a generated point,
/// so the two `n_sv` entries always sum to the number of support vectors.
fn binary_model_strategy() -> impl Strategy<Value = SvmModel> {
    (1usize..12).prop_flat_map(|total_sv| {
        // Tuple order is kept stable so value generation order does not change.
        let first_class_count = 0usize..=total_sv;
        let support_vectors = prop::collection::vec(sparse_instance_strategy(), total_sv);
        let coefficients = prop::collection::vec(finite_value_strategy(), total_sv);
        let offset = finite_value_strategy();

        (first_class_count, support_vectors, coefficients, offset).prop_map(
            move |(split, sv, coef, rho)| {
                let param = SvmParameter {
                    svm_type: SvmType::CSvc,
                    kernel_type: KernelType::Linear,
                    gamma: 0.0,
                    ..Default::default()
                };
                let remaining = total_sv - split;

                SvmModel {
                    param,
                    nr_class: 2,
                    sv,
                    // A single coefficient row for the two-class model.
                    sv_coef: vec![coef],
                    rho: vec![rho],
                    prob_a: Vec::new(),
                    prob_b: Vec::new(),
                    prob_density_marks: Vec::new(),
                    // One-based indices 1..=total_sv.
                    sv_indices: (1..=total_sv).collect(),
                    label: vec![1, -1],
                    n_sv: vec![split, remaining],
                }
            },
        )
    })
}
/// Renders `prob` as text, one line per row: the label followed by a
/// space-separated list of `index:value` pairs, terminated by a newline.
fn problem_to_text(prob: &SvmProblem) -> String {
    prob.labels
        .iter()
        .zip(prob.instances.iter())
        .map(|(label, instance)| {
            // Each feature is prefixed by its separating space.
            let features: String = instance
                .iter()
                .map(|node| format!(" {}:{}", node.index, node.value))
                .collect();
            format!("{}{}\n", label, features)
        })
        .collect()
}
/// Trains a model on four hand-written sparse instances and checks that
/// predicting the same instance twice yields the same value.
#[test]
fn kernel_deterministic() {
    libsvm_rs::set_quiet(true);

    // Builds one sparse instance from `(index, value)` pairs.
    let instance = |pairs: &[(i32, f64)]| -> Vec<SvmNode> {
        pairs
            .iter()
            .map(|&(index, value)| SvmNode { index, value })
            .collect()
    };

    let prob = SvmProblem {
        labels: vec![1.0, -1.0, 1.0, -1.0],
        instances: vec![
            instance(&[(1, 2.5), (5, -1.3), (18, 0.7)]),
            instance(&[(2, 1.1), (8, 3.2), (15, -2.1)]),
            instance(&[(3, -0.5), (10, 1.9), (20, 2.8)]),
            instance(&[(1, 1.2), (4, -1.5), (12, 0.3)]),
        ],
    };

    let param = SvmParameter {
        // 20 is the largest feature index used above.
        gamma: 1.0 / 20.0,
        shrinking: false,
        eps: 0.01,
        ..Default::default()
    };

    let model = svm_train(&prob, &param);

    let test_instance = &prob.instances[0];
    let pred1 = predict(&model, test_instance);
    let pred2 = predict(&model, test_instance);
    assert_eq!(
        pred1, pred2,
        "Predictions should be deterministic; got {} and {}",
        pred1, pred2
    );
}
/// Trains on the `heart_scale` dataset and checks that predicting each of the
/// first five training instances twice yields the same value both times.
#[test]
fn predict_deterministic() {
    libsvm_rs::set_quiet(true);
    let prob = load_heart_scale();
    let param = SvmParameter {
        gamma: 1.0 / 13.0,
        ..Default::default()
    };
    let model = svm_train(&prob, &param);

    // Only a small prefix of the dataset is checked to keep the test fast.
    for (idx, test_instance) in prob.instances.iter().enumerate().take(5) {
        let pred1 = predict(&model, test_instance);
        let pred2 = predict(&model, test_instance);
        assert_eq!(
            pred1, pred2,
            "Prediction for instance {} should be deterministic; got {} and {}",
            idx, pred1, pred2
        );
    }
}
/// Trains on the `heart_scale` dataset and checks that every prediction on a
/// training instance is one of the labels that appear in the training data.
#[test]
fn train_predict_labels_in_range() {
    libsvm_rs::set_quiet(true);
    let prob = load_heart_scale();
    let valid_labels = unique_labels(&prob);
    let param = SvmParameter {
        gamma: 1.0 / 13.0,
        ..Default::default()
    };
    let model = svm_train(&prob, &param);

    for (idx, instance) in prob.instances.iter().enumerate() {
        let pred = predict(&model, instance);
        assert!(
            valid_labels.contains(&pred),
            "Instance {} prediction {} is not in training labels {:?}",
            idx,
            pred,
            valid_labels
        );
    }
}
/// Runs cross-validation on the `heart_scale` dataset and checks that the
/// output has one entry per training row, and that every entry is finite and
/// equal to one of the training labels.
#[test]
fn cross_validation_results_valid() {
    libsvm_rs::set_quiet(true);
    let prob = load_heart_scale();
    let valid_labels = unique_labels(&prob);
    let param = SvmParameter {
        gamma: 1.0 / 13.0,
        ..Default::default()
    };

    // NOTE(review): the third argument is presumably the number of folds — confirm
    // against the `svm_cross_validation` signature.
    let cv_targets = svm_cross_validation(&prob, &param, 5);

    assert_eq!(
        cv_targets.len(),
        prob.labels.len(),
        "CV output length should match problem size"
    );
    for (idx, &target) in cv_targets.iter().enumerate() {
        assert!(
            target.is_finite(),
            "CV target[{}] = {} is not finite",
            idx,
            target
        );
        assert!(
            valid_labels.contains(&target),
            "CV target[{}] = {} is not in training labels {:?}",
            idx,
            target,
            valid_labels
        );
    }
}
proptest! {
    #![proptest_config(ProptestConfig::with_cases(96))]

    // A generated problem rendered as text must parse back to an equal problem,
    // and parsing the same text a second time must give an equal result.
    #[test]
    fn problem_loader_roundtrip_stability(prob in problem_strategy()) {
        let rendered = problem_to_text(&prob);

        let first_parse = load_problem_from_reader(rendered.as_bytes())
            .expect("generated problem should parse");
        let second_parse = load_problem_from_reader(rendered.as_bytes())
            .expect("generated problem should parse repeatedly");

        prop_assert_eq!(&first_parse, &prob);
        prop_assert_eq!(&second_parse, &first_parse);
    }

    // Saving a model, loading it, and saving it again must produce identical bytes,
    // and the loaded model must keep the structural properties of the generated one.
    #[test]
    fn model_save_load_save_is_byte_stable(model in binary_model_strategy()) {
        let mut initial_bytes = Vec::new();
        save_model_to_writer(&mut initial_bytes, &model)
            .expect("generated model should serialize");

        let loaded = load_model_from_reader(initial_bytes.as_slice())
            .expect("serialized model should parse");

        let mut resaved_bytes = Vec::new();
        save_model_to_writer(&mut resaved_bytes, &loaded)
            .expect("loaded model should serialize");

        prop_assert_eq!(&initial_bytes, &resaved_bytes);

        prop_assert_eq!(loaded.param.svm_type, SvmType::CSvc);
        prop_assert_eq!(loaded.param.kernel_type, KernelType::Linear);
        prop_assert_eq!(loaded.nr_class, 2);
        prop_assert_eq!(loaded.sv.len(), model.sv.len());
        prop_assert_eq!(loaded.sv_coef.len(), 1);

        let sv_total: usize = loaded.n_sv.iter().sum();
        prop_assert_eq!(sv_total, loaded.sv.len());
    }
}