use linfa::prelude::*;
use linfa_svm::{error::Result, Svm};
/// Trains a Gaussian-kernel SVM classifier on the wine-quality dataset and
/// reports its performance on a held-out validation split.
///
/// Wines whose quality target is strictly greater than 6 form the positive
/// class. The data is split with ratio 0.9 into training and validation sets.
/// The fitted model, the validation confusion matrix, and the accuracy and
/// MCC are printed to stdout.
///
/// # Errors
///
/// Propagates any error returned while fitting the SVM or while building the
/// confusion matrix.
fn main() -> Result<()> {
    // Maps a boolean class label to a human-readable tag for reporting.
    fn tag_classes(x: &bool) -> String {
        String::from(if *x { "good" } else { "bad" })
    }

    // Binarise the quality score: strictly above 6 is the positive class.
    let binarised = linfa_datasets::winequality().map_targets(|quality| *quality > 6);
    let (train, valid) = binarised.split_with_ratio(0.9);

    println!(
        "Fit SVM classifier with #{} training points",
        train.nsamples()
    );

    // Positive samples get a 10x larger weight than negatives. Presumably
    // this compensates for "good" wines being the minority class.
    // NOTE(review): confirm the class balance justifies these constants.
    let hyper_params = Svm::<_, bool>::params()
        .pos_neg_weights(50000., 5000.)
        .gaussian_kernel(80.0);
    let model = hyper_params.fit(&train)?;
    println!("{model}");

    // Tag both the ground truth and the predictions with the same string
    // labels so the confusion matrix is reported as "good"/"bad".
    let validation = valid.map_targets(tag_classes);
    let predictions = model.predict(&validation).map(tag_classes);

    let cm = predictions.confusion_matrix(&validation)?;
    println!("{cm:?}");

    let (accuracy, mcc) = (cm.accuracy(), cm.mcc());
    println!("accuracy {}, MCC {}", accuracy, mcc);

    Ok(())
}