use crfs::Attribute;
use crfs::train::Trainer;
use std::path::Path;
/// Trains an averaged-perceptron model on a nine-item sunny/rainy sequence and
/// checks that tagging the training data back yields more than 70% accuracy.
///
/// The model is written to a per-process file under the system temporary
/// directory (instead of a hard-coded `/tmp` path, which is not portable) and
/// is removed as soon as its bytes have been read back.
#[test]
fn test_ap_basic_training() {
    let xseq = vec![
        vec![Attribute::new("walk", 1.0), Attribute::new("shop", 0.5)],
        vec![Attribute::new("walk", 1.0)],
        vec![Attribute::new("walk", 1.0), Attribute::new("clean", 0.5)],
        vec![Attribute::new("shop", 0.5), Attribute::new("clean", 0.5)],
        vec![Attribute::new("walk", 0.5), Attribute::new("clean", 1.0)],
        vec![Attribute::new("clean", 1.0), Attribute::new("shop", 0.1)],
        vec![Attribute::new("walk", 1.0), Attribute::new("shop", 0.5)],
        vec![],
        vec![Attribute::new("clean", 1.0)],
    ];
    let yseq = [
        "sunny", "sunny", "sunny", "rainy", "rainy", "rainy", "sunny", "sunny", "rainy",
    ];

    let mut trainer = Trainer::averaged_perceptron();
    trainer.verbose(true);
    trainer.params_mut().set_max_iterations(50).unwrap();
    trainer.params_mut().set_epsilon(0.01).unwrap();
    // Fixed seed so the shuffle order is the same on every run.
    trainer.params_mut().set_shuffle_seed(Some(1));
    trainer.append(&xseq, &yseq).unwrap();

    // The process id keeps concurrent test processes from clobbering each
    // other's model file.
    let model_file =
        std::env::temp_dir().join(format!("test_ap_{}.crfsuite", std::process::id()));
    let model_path: &Path = &model_file;
    trainer.train(model_path).unwrap();
    assert!(model_path.exists());

    let model_data = std::fs::read(model_path).unwrap();
    // Best-effort cleanup: the model now lives in memory, so the file is no
    // longer needed. Done before the assertions so a failure does not leak it.
    let _ = std::fs::remove_file(model_path);

    let model = crfs::Model::new(&model_data).unwrap();
    let tagger = model.tagger().unwrap();
    let predicted = tagger.tag(&xseq).unwrap();

    let correct = predicted
        .iter()
        .zip(yseq.iter())
        .filter(|(pred, truth)| pred == truth)
        .count();
    let accuracy = correct as f64 / yseq.len() as f64;
    println!("AP Accuracy: {:.2}%", accuracy * 100.0);
    assert!(
        accuracy > 0.7,
        "AP accuracy too low: {:.2}%",
        accuracy * 100.0
    );
}
/// Trains an averaged-perceptron model with verbose output disabled and checks
/// that a model file is produced.
///
/// The model is written to a per-process file under the system temporary
/// directory (instead of a hard-coded `/tmp` path, which is not portable) and
/// is removed before the final assertion so it is not left behind.
#[test]
fn test_ap_no_verbose() {
    let xseq = vec![
        vec![Attribute::new("walk", 1.0)],
        vec![Attribute::new("clean", 1.0)],
    ];
    let yseq = ["sunny", "rainy"];

    let mut trainer = Trainer::averaged_perceptron();
    trainer.verbose(false);
    trainer.params_mut().set_max_iterations(10).unwrap();
    // Fixed seed so the shuffle order is the same on every run.
    trainer.params_mut().set_shuffle_seed(Some(1));
    trainer.append(&xseq, &yseq).unwrap();

    // The process id keeps concurrent test processes from clobbering each
    // other's model file.
    let model_file =
        std::env::temp_dir().join(format!("test_ap_quiet_{}.crfsuite", std::process::id()));
    let model_path: &Path = &model_file;
    trainer.train(model_path).unwrap();

    // Record the result first, then clean up, so the file is removed whether
    // or not the assertion below passes.
    let model_written = model_path.exists();
    let _ = std::fs::remove_file(model_path);
    assert!(model_written);
}
/// Trains an averaged-perceptron model on a four-item sequence in which
/// attribute `a` always carries label `X` and `b` always carries `Y`, then
/// checks that tagging the training data reproduces the labels exactly.
///
/// The model is written to a per-process file under the system temporary
/// directory (instead of a hard-coded `/tmp` path, which is not portable) and
/// is removed as soon as its bytes have been read back.
#[test]
fn test_ap_convergence() {
    let xseq = vec![
        vec![Attribute::new("a", 1.0)],
        vec![Attribute::new("b", 1.0)],
        vec![Attribute::new("a", 1.0)],
        vec![Attribute::new("b", 1.0)],
    ];
    let yseq = ["X", "Y", "X", "Y"];

    let mut trainer = Trainer::averaged_perceptron();
    trainer.verbose(true);
    trainer.params_mut().set_max_iterations(100).unwrap();
    trainer.params_mut().set_epsilon(0.000001).unwrap();
    // Fixed seed so the shuffle order is the same on every run.
    trainer.params_mut().set_shuffle_seed(Some(1));
    trainer.append(&xseq, &yseq).unwrap();

    // The process id keeps concurrent test processes from clobbering each
    // other's model file.
    let model_file =
        std::env::temp_dir().join(format!("test_ap_converge_{}.crfsuite", std::process::id()));
    let model_path: &Path = &model_file;
    trainer.train(model_path).unwrap();

    let model_data = std::fs::read(model_path).unwrap();
    // Best-effort cleanup: the model now lives in memory, so the file is no
    // longer needed. Done before the assertion so a failure does not leak it.
    let _ = std::fs::remove_file(model_path);

    let model = crfs::Model::new(&model_data).unwrap();
    let tagger = model.tagger().unwrap();
    let predicted = tagger.tag(&xseq).unwrap();
    assert_eq!(predicted, yseq);
}
/// Trains an averaged-perceptron model and an L-BFGS model on the same six-item
/// sunny/rainy sequence and checks each against its own accuracy floor when
/// tagging the training data (above 50% for AP, above 70% for L-BFGS).
///
/// Both models are written to per-process files under the system temporary
/// directory (instead of hard-coded `/tmp` paths, which are not portable) and
/// are removed as soon as their bytes have been read back.
#[test]
fn test_ap_vs_lbfgs() {
    let xseq = vec![
        vec![Attribute::new("walk", 1.0), Attribute::new("shop", 0.5)],
        vec![Attribute::new("walk", 1.0)],
        vec![Attribute::new("walk", 1.0), Attribute::new("clean", 0.5)],
        vec![Attribute::new("shop", 0.5), Attribute::new("clean", 0.5)],
        vec![Attribute::new("walk", 0.5), Attribute::new("clean", 1.0)],
        vec![Attribute::new("clean", 1.0), Attribute::new("shop", 0.1)],
    ];
    let yseq = ["sunny", "sunny", "sunny", "rainy", "rainy", "rainy"];

    // The process id keeps concurrent test processes from clobbering each
    // other's model files.
    let pid = std::process::id();
    let tmp_dir = std::env::temp_dir();

    // --- Averaged perceptron ---
    let mut ap_trainer = Trainer::averaged_perceptron();
    ap_trainer.verbose(true);
    ap_trainer.params_mut().set_max_iterations(100).unwrap();
    ap_trainer.params_mut().set_epsilon(0.001).unwrap();
    // Fixed seed so the shuffle order is the same on every run.
    ap_trainer.params_mut().set_shuffle_seed(Some(1));
    ap_trainer.append(&xseq, &yseq).unwrap();
    let ap_model_file = tmp_dir.join(format!("test_ap_compare_{}.crfsuite", pid));
    let ap_model_path: &Path = &ap_model_file;
    ap_trainer.train(ap_model_path).unwrap();

    // --- L-BFGS ---
    let mut lbfgs_trainer = Trainer::lbfgs();
    lbfgs_trainer.verbose(false);
    lbfgs_trainer.params_mut().set_max_iterations(50).unwrap();
    lbfgs_trainer.append(&xseq, &yseq).unwrap();
    let lbfgs_model_file = tmp_dir.join(format!("test_lbfgs_compare_{}.crfsuite", pid));
    let lbfgs_model_path: &Path = &lbfgs_model_file;
    lbfgs_trainer.train(lbfgs_model_path).unwrap();

    // Load both models, then remove the files right away (best effort) so a
    // failing assertion further down does not leave them behind.
    let ap_model_data = std::fs::read(ap_model_path).unwrap();
    let _ = std::fs::remove_file(ap_model_path);
    let lbfgs_model_data = std::fs::read(lbfgs_model_path).unwrap();
    let _ = std::fs::remove_file(lbfgs_model_path);

    let ap_model = crfs::Model::new(&ap_model_data).unwrap();
    let ap_tagger = ap_model.tagger().unwrap();
    let ap_predicted = ap_tagger.tag(&xseq).unwrap();

    let lbfgs_model = crfs::Model::new(&lbfgs_model_data).unwrap();
    let lbfgs_tagger = lbfgs_model.tagger().unwrap();
    let lbfgs_predicted = lbfgs_tagger.tag(&xseq).unwrap();

    let ap_correct = ap_predicted
        .iter()
        .zip(yseq.iter())
        .filter(|(p, t)| p == t)
        .count();
    let lbfgs_correct = lbfgs_predicted
        .iter()
        .zip(yseq.iter())
        .filter(|(p, t)| p == t)
        .count();
    let ap_accuracy = ap_correct as f64 / yseq.len() as f64;
    let lbfgs_accuracy = lbfgs_correct as f64 / yseq.len() as f64;
    println!("AP Accuracy: {:.2}%", ap_accuracy * 100.0);
    println!("LBFGS Accuracy: {:.2}%", lbfgs_accuracy * 100.0);

    assert!(
        ap_accuracy > 0.5,
        "AP accuracy too low: {:.2}%",
        ap_accuracy * 100.0
    );
    assert!(
        lbfgs_accuracy > 0.7,
        "LBFGS accuracy too low: {:.2}%",
        lbfgs_accuracy * 100.0
    );
}