use linfa::prelude::*;
use linfa_logistic::MultiLogisticRegression;
use std::error::Error;
/// Trains a multinomial logistic regression classifier on the wine-quality
/// dataset and reports its confusion matrix, accuracy and MCC on a held-out
/// validation split.
///
/// # Errors
///
/// Returns an error if model fitting fails or if the confusion matrix cannot
/// be built from the predictions and validation targets.
fn main() -> Result<(), Box<dyn Error>> {
    // Keep 90% of the samples for training and the remaining 10% for validation.
    // NOTE(review): no shuffling is done here before splitting; confirm the
    // dataset's row order makes the tail a representative validation set.
    let (train, valid) = linfa_datasets::winequality().split_with_ratio(0.9);

    println!(
        "Fit Multinomial Logistic Regression classifier with #{} training points",
        train.nsamples()
    );

    // Fit the model with the optimizer capped at 50 iterations. Propagate a
    // failure through `main`'s `Result` instead of panicking.
    let model = MultiLogisticRegression::default()
        .max_iterations(50)
        .fit(&train)?;

    // Predict a class for every validation sample.
    let pred = model.predict(&valid);

    // Compare the predictions against the validation targets.
    let cm = pred.confusion_matrix(&valid)?;

    // Per-class table of predicted vs. actual labels.
    println!("{cm:?}");

    // Summary metrics: overall accuracy and the Matthews correlation coefficient.
    println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());

    Ok(())
}