use crfs::{Attribute, Trainer};
use std::path::Path;
/// Trains a model with the `l2sgd` trainer on one six-item sequence, reloads the
/// written model file, tags the training sequence again and requires that more
/// than half of the predicted labels match the reference labels.
#[test]
fn test_l2sgd_basic_training() {
    let xseq = vec![
        vec![Attribute::new("walk", 1.0), Attribute::new("shop", 0.5)],
        vec![Attribute::new("walk", 1.0)],
        vec![Attribute::new("walk", 1.0), Attribute::new("clean", 0.5)],
        vec![Attribute::new("shop", 0.5), Attribute::new("clean", 0.5)],
        vec![Attribute::new("walk", 0.5), Attribute::new("clean", 1.0)],
        vec![Attribute::new("clean", 1.0), Attribute::new("shop", 0.1)],
    ];
    let yseq = ["sunny", "sunny", "sunny", "rainy", "rainy", "rainy"];

    let mut trainer = Trainer::l2sgd();
    trainer.verbose(true);
    trainer.params_mut().set_c2(1.0).unwrap();
    trainer.params_mut().set_max_iterations(50).unwrap();
    trainer.params_mut().set_period(10).unwrap();
    trainer.append(&xseq, &yseq).unwrap();

    // Resolve the platform temp directory instead of assuming `/tmp` exists.
    let model_file = std::env::temp_dir().join("test_l2sgd.crfsuite");
    let model_path: &Path = model_file.as_path();
    // Drop any artifact left by an earlier run so the existence check below
    // proves that *this* run wrote the file. The error is ignored because the
    // file normally does not exist yet.
    let _ = std::fs::remove_file(model_path);

    trainer.train(model_path).unwrap();
    assert!(model_path.exists());

    let model_data = std::fs::read(model_path).unwrap();
    let model = crfs::Model::new(&model_data).unwrap();
    let tagger = model.tagger().unwrap();
    let predicted = tagger.tag(&xseq).unwrap();
    // `zip` silently truncates to the shorter side, so check the lengths first.
    assert_eq!(
        predicted.len(),
        yseq.len(),
        "tagger returned a different number of labels than items"
    );

    let correct = predicted
        .iter()
        .zip(yseq.iter())
        .filter(|(p, t)| p == t)
        .count();
    let accuracy = correct as f64 / yseq.len() as f64;
    println!("L2SGD Accuracy: {:.2}%", accuracy * 100.0);
    assert!(accuracy > 0.5, "L2SGD accuracy too low");

    // Best-effort cleanup; a failed assertion above leaves the file behind for
    // inspection, and the next run removes it before training.
    let _ = std::fs::remove_file(model_path);
}
/// Trains with the `l2sgd` trainer after setting `calibration_samples` to 4 and
/// `calibration_candidates` to 5, and checks that training writes a model file.
///
/// NOTE(review): presumably these two parameters steer the trainer's
/// learning-rate calibration phase; confirm against the `crfs` documentation.
#[test]
fn test_l2sgd_calibration() {
    let xseq = vec![
        vec![Attribute::new("a", 1.0)],
        vec![Attribute::new("b", 1.0)],
        vec![Attribute::new("a", 1.0)],
        vec![Attribute::new("b", 1.0)],
    ];
    let yseq = ["X", "Y", "X", "Y"];

    let mut trainer = Trainer::l2sgd();
    trainer.verbose(true);
    trainer.params_mut().set_c2(1.0).unwrap();
    trainer.params_mut().set_max_iterations(20).unwrap();
    trainer.params_mut().set_calibration_samples(4).unwrap();
    trainer.params_mut().set_calibration_candidates(5).unwrap();
    trainer.append(&xseq, &yseq).unwrap();

    // Resolve the platform temp directory instead of assuming `/tmp` exists.
    let model_file = std::env::temp_dir().join("test_l2sgd_calibration.crfsuite");
    let model_path: &Path = model_file.as_path();
    // The existence check is this test's only assertion, so a stale file from a
    // previous run would make it pass vacuously. Remove it first; the error is
    // ignored because the file normally does not exist yet.
    let _ = std::fs::remove_file(model_path);

    trainer.train(model_path).unwrap();
    assert!(model_path.exists());

    // Best-effort cleanup of the generated artifact.
    let _ = std::fs::remove_file(model_path);
}
/// Trains the same six-item sequence with the `l2sgd` and `lbfgs` trainers,
/// tags the training sequence with each resulting model, and requires the
/// L2SGD accuracy to exceed 0.5 and the LBFGS accuracy to exceed 0.7.
#[test]
fn test_l2sgd_vs_lbfgs() {
    /// Fraction of positions at which `predicted` equals `expected`.
    ///
    /// Panics when the two slices differ in length, because `zip` would
    /// otherwise silently ignore the surplus entries.
    fn accuracy<P, T>(predicted: &[P], expected: &[T]) -> f64
    where
        P: PartialEq<T>,
    {
        assert_eq!(
            predicted.len(),
            expected.len(),
            "tagger returned a different number of labels than items"
        );
        let correct = predicted
            .iter()
            .zip(expected.iter())
            .filter(|(p, t)| p == t)
            .count();
        correct as f64 / expected.len() as f64
    }

    let xseq = vec![
        vec![Attribute::new("walk", 1.0), Attribute::new("shop", 0.5)],
        vec![Attribute::new("walk", 1.0)],
        vec![Attribute::new("walk", 1.0), Attribute::new("clean", 0.5)],
        vec![Attribute::new("shop", 0.5), Attribute::new("clean", 0.5)],
        vec![Attribute::new("walk", 0.5), Attribute::new("clean", 1.0)],
        vec![Attribute::new("clean", 1.0), Attribute::new("shop", 0.1)],
    ];
    let yseq = ["sunny", "sunny", "sunny", "rainy", "rainy", "rainy"];

    // Resolve the platform temp directory instead of assuming `/tmp` exists.
    let temp_dir = std::env::temp_dir();
    let l2sgd_model_file = temp_dir.join("test_l2sgd_compare.crfsuite");
    let lbfgs_model_file = temp_dir.join("test_lbfgs_compare_l2sgd.crfsuite");
    let l2sgd_model_path: &Path = l2sgd_model_file.as_path();
    let lbfgs_model_path: &Path = lbfgs_model_file.as_path();
    // Remove stale artifacts so the models read back below come from this run.
    // Errors are ignored because the files normally do not exist yet.
    let _ = std::fs::remove_file(l2sgd_model_path);
    let _ = std::fs::remove_file(lbfgs_model_path);

    let mut l2sgd_trainer = Trainer::l2sgd();
    l2sgd_trainer.verbose(false);
    l2sgd_trainer.params_mut().set_c2(1.0).unwrap();
    l2sgd_trainer.params_mut().set_max_iterations(100).unwrap();
    l2sgd_trainer.params_mut().set_period(10).unwrap();
    l2sgd_trainer.append(&xseq, &yseq).unwrap();
    l2sgd_trainer.train(l2sgd_model_path).unwrap();

    let mut lbfgs_trainer = Trainer::lbfgs();
    lbfgs_trainer.verbose(false);
    // c1 is set to 0.0 and c2 to 1.0, the same c2 value given to the L2SGD trainer.
    lbfgs_trainer.params_mut().set_c1(0.0).unwrap();
    lbfgs_trainer.params_mut().set_c2(1.0).unwrap();
    lbfgs_trainer.params_mut().set_max_iterations(100).unwrap();
    lbfgs_trainer.append(&xseq, &yseq).unwrap();
    lbfgs_trainer.train(lbfgs_model_path).unwrap();

    let l2sgd_model_data = std::fs::read(l2sgd_model_path).unwrap();
    let l2sgd_model = crfs::Model::new(&l2sgd_model_data).unwrap();
    let l2sgd_tagger = l2sgd_model.tagger().unwrap();
    let l2sgd_predicted = l2sgd_tagger.tag(&xseq).unwrap();

    let lbfgs_model_data = std::fs::read(lbfgs_model_path).unwrap();
    let lbfgs_model = crfs::Model::new(&lbfgs_model_data).unwrap();
    let lbfgs_tagger = lbfgs_model.tagger().unwrap();
    let lbfgs_predicted = lbfgs_tagger.tag(&xseq).unwrap();

    let l2sgd_accuracy = accuracy(&l2sgd_predicted, &yseq);
    let lbfgs_accuracy = accuracy(&lbfgs_predicted, &yseq);
    println!("L2SGD Accuracy: {:.2}%", l2sgd_accuracy * 100.0);
    println!("LBFGS Accuracy: {:.2}%", lbfgs_accuracy * 100.0);

    assert!(
        l2sgd_accuracy > 0.5,
        "L2SGD accuracy too low: {:.2}%",
        l2sgd_accuracy * 100.0
    );
    assert!(
        lbfgs_accuracy > 0.7,
        "LBFGS accuracy too low: {:.2}%",
        lbfgs_accuracy * 100.0
    );

    // Best-effort cleanup of both generated artifacts.
    let _ = std::fs::remove_file(l2sgd_model_path);
    let _ = std::fs::remove_file(lbfgs_model_path);
}
/// Checks which values the `l2sgd` parameter setters accept and which they
/// reject, without running any training.
#[test]
fn test_l2sgd_parameter_validation() {
    let mut sgd = Trainer::l2sgd();

    // Values the setters are expected to accept.
    assert!(sgd.params_mut().set_c2(1.0).is_ok());
    assert!(sgd.params_mut().set_period(10).is_ok());
    assert!(sgd.params_mut().set_delta(1e-5).is_ok());
    assert!(sgd.params_mut().set_calibration_eta(0.1).is_ok());
    assert!(sgd.params_mut().set_calibration_rate(2.0).is_ok());

    // Values the setters are expected to reject: a zero period, a zero delta,
    // a zero calibration eta and a calibration rate of exactly 1.0.
    assert!(sgd.params_mut().set_period(0).is_err());
    assert!(sgd.params_mut().set_delta(0.0).is_err());
    assert!(sgd.params_mut().set_calibration_eta(0.0).is_err());
    assert!(sgd.params_mut().set_calibration_rate(1.0).is_err());
}
/// Trains the `l2sgd` trainer on a strictly alternating two-label sequence with
/// `period` set to 5 and `delta` set to 1e-4, then requires that tagging the
/// training sequence matches more than half of the reference labels.
///
/// NOTE(review): presumably `period` and `delta` drive the trainer's stopping
/// criterion; confirm against the `crfs` documentation.
#[test]
fn test_l2sgd_convergence() {
    let xseq = vec![
        vec![Attribute::new("a", 1.0)],
        vec![Attribute::new("b", 1.0)],
        vec![Attribute::new("a", 1.0)],
        vec![Attribute::new("b", 1.0)],
        vec![Attribute::new("a", 1.0)],
        vec![Attribute::new("b", 1.0)],
    ];
    let yseq = ["X", "Y", "X", "Y", "X", "Y"];

    let mut trainer = Trainer::l2sgd();
    trainer.verbose(true);
    trainer.params_mut().set_c2(1.0).unwrap();
    trainer.params_mut().set_max_iterations(100).unwrap();
    trainer.params_mut().set_period(5).unwrap();
    trainer.params_mut().set_delta(1e-4).unwrap();
    trainer.append(&xseq, &yseq).unwrap();

    // Resolve the platform temp directory instead of assuming `/tmp` exists.
    let model_file = std::env::temp_dir().join("test_l2sgd_converge.crfsuite");
    let model_path: &Path = model_file.as_path();
    // Remove a stale artifact so the model read back below comes from this run.
    // The error is ignored because the file normally does not exist yet.
    let _ = std::fs::remove_file(model_path);

    trainer.train(model_path).unwrap();

    let model_data = std::fs::read(model_path).unwrap();
    let model = crfs::Model::new(&model_data).unwrap();
    let tagger = model.tagger().unwrap();
    let predicted = tagger.tag(&xseq).unwrap();
    // `zip` silently truncates to the shorter side, so check the lengths first.
    assert_eq!(
        predicted.len(),
        yseq.len(),
        "tagger returned a different number of labels than items"
    );

    let correct = predicted
        .iter()
        .zip(yseq.iter())
        .filter(|(p, t)| p == t)
        .count();
    let accuracy = correct as f64 / yseq.len() as f64;
    assert!(accuracy > 0.5, "Accuracy too low: {:.2}%", accuracy * 100.0);

    // Best-effort cleanup of the generated artifact.
    let _ = std::fs::remove_file(model_path);
}