use crfs::Attribute;
use crfs::train::Trainer;
/// Trains an L-BFGS model (c1 = 0, c2 = 1, at most 50 iterations) on a
/// nine-item weather sequence and checks that a model is written to disk.
#[test]
fn test_basic_training() {
    // One inner vector of weighted attributes per item; the eighth item has
    // no attributes at all.
    let xseq = vec![
        vec![Attribute::new("walk", 1.0), Attribute::new("shop", 0.5)],
        vec![Attribute::new("walk", 1.0)],
        vec![Attribute::new("walk", 1.0), Attribute::new("clean", 0.5)],
        vec![Attribute::new("shop", 0.5), Attribute::new("clean", 0.5)],
        vec![Attribute::new("walk", 0.5), Attribute::new("clean", 1.0)],
        vec![Attribute::new("clean", 1.0), Attribute::new("shop", 0.1)],
        vec![Attribute::new("walk", 1.0), Attribute::new("shop", 0.5)],
        vec![],
        vec![Attribute::new("clean", 1.0)],
    ];
    // One label per item in `xseq`.
    let yseq = vec![
        "sunny", "sunny", "sunny", "rainy", "rainy", "rainy", "sunny", "sunny", "rainy",
    ];

    let mut trainer = Trainer::lbfgs();
    trainer.verbose(true).append(&xseq, &yseq).unwrap();
    trainer.params_mut().set_c1(0.0).unwrap();
    trainer.params_mut().set_c2(1.0).unwrap();
    trainer.params_mut().set_max_iterations(50).unwrap();

    let temp_file = tempfile::NamedTempFile::new().unwrap();
    let model_path = temp_file.path();

    if let Err(e) = trainer.train(model_path) {
        panic!("Training failed: {}", e);
    }
    println!("Training completed successfully!");

    // `NamedTempFile::new()` creates the file up front, so an `exists()` check
    // would pass even if training wrote nothing. Require actual content.
    let model_len = std::fs::metadata(model_path)
        .expect("Failed to stat model file")
        .len();
    assert!(
        model_len > 0,
        "Training reported success but the model file is empty"
    );
}
/// Checks that values written through `params_mut()` are reported back by the
/// matching getters on `params()`.
#[test]
fn test_trainer_params() {
    let mut trainer = Trainer::lbfgs();
    trainer.params_mut().set_c1(0.5).unwrap();
    trainer.params_mut().set_c2(2.0).unwrap();
    trainer.params_mut().set_max_iterations(100).unwrap();

    // Both regularization coefficients are floats: compare them the same way
    // (within machine epsilon) and report the actual value on failure.
    let c1 = trainer.params().c1();
    assert!(
        (c1 - 0.5).abs() < f64::EPSILON,
        "c1 was {}, expected 0.5",
        c1
    );
    let c2 = trainer.params().c2();
    assert!(
        (c2 - 2.0).abs() < f64::EPSILON,
        "c2 was {}, expected 2.0",
        c2
    );

    assert_eq!(trainer.params().max_iterations(), 100);
}
/// Calling `train` on a trainer that never had a sequence appended is
/// expected to return an error.
#[test]
fn test_trainer_validation() {
    let mut trainer = Trainer::lbfgs();
    // The temp file only supplies a valid output path for the call.
    let model_file = tempfile::NamedTempFile::new().unwrap();
    let outcome = trainer.train(model_file.path());
    assert!(outcome.is_err());
}
/// Trains with a non-zero L1 coefficient (c1 = 0.1, c2 = 1.0) and checks that
/// training succeeds and produces a model file with content.
///
/// NOTE(review): the failure message suggests non-zero c1 selects the OWL-QN
/// variant of L-BFGS — confirm against the trainer implementation.
#[test]
fn test_lbfgs_with_l1_regularization() {
    let xseq = vec![
        vec![Attribute::new("walk", 1.0), Attribute::new("shop", 0.5)],
        vec![Attribute::new("walk", 1.0)],
        vec![Attribute::new("walk", 1.0), Attribute::new("clean", 0.5)],
        vec![Attribute::new("shop", 0.5), Attribute::new("clean", 0.5)],
        vec![Attribute::new("walk", 0.5), Attribute::new("clean", 1.0)],
        vec![Attribute::new("clean", 1.0), Attribute::new("shop", 0.1)],
    ];
    let yseq = vec!["sunny", "sunny", "sunny", "rainy", "rainy", "rainy"];

    let mut trainer = Trainer::lbfgs();
    trainer.append(&xseq, &yseq).unwrap();
    trainer.params_mut().set_c1(0.1).unwrap();
    trainer.params_mut().set_c2(1.0).unwrap();
    trainer.params_mut().set_max_iterations(50).unwrap();

    let temp_file = tempfile::NamedTempFile::new().unwrap();
    let result = trainer.train(temp_file.path());
    assert!(result.is_ok(), "OWL-QN training failed: {:?}", result.err());

    // The temp file exists from the moment it is created, so `exists()` proves
    // nothing about training. Check that model bytes were actually written.
    let model_len = std::fs::metadata(temp_file.path())
        .expect("Failed to stat model file")
        .len();
    assert!(
        model_len > 0,
        "OWL-QN training reported success but the model file is empty"
    );
}
/// Round-trip check: train a small model, read the file back, load it as a
/// `Model`, and tag a two-item sequence with it.
///
/// Only the number of tags and their membership in the training label set are
/// asserted, not the specific label chosen for each item.
#[test]
fn test_pruned_model_roundtrip() {
    use crfs::Model;

    // --- train -----------------------------------------------------------
    let train_items = vec![
        vec![Attribute::new("walk", 1.0)],
        vec![Attribute::new("shop", 1.0)],
        vec![Attribute::new("walk", 1.0)],
    ];
    let train_labels = vec!["sunny", "rainy", "sunny"];

    let mut trainer = Trainer::lbfgs();
    trainer.append(&train_items, &train_labels).unwrap();
    trainer.params_mut().set_max_iterations(20).unwrap();

    let model_file = tempfile::NamedTempFile::new().unwrap();
    trainer.train(model_file.path()).unwrap();

    // --- reload ----------------------------------------------------------
    let bytes = std::fs::read(model_file.path()).expect("Failed to read model file");
    let model = Model::new(&bytes).expect("Failed to load pruned model");
    let tagger = model.tagger().expect("Failed to create tagger");

    // --- tag -------------------------------------------------------------
    let query = vec![
        vec![Attribute::new("walk", 1.0)],
        vec![Attribute::new("shop", 1.0)],
    ];
    let outcome = tagger.tag(&query);
    assert!(
        outcome.is_ok(),
        "Tagging with pruned model failed: {:?}",
        outcome.err()
    );
    let tags = outcome.unwrap();

    // One tag per query item, each drawn from the labels seen in training.
    assert_eq!(tags.len(), 2);
    for label in &tags {
        let recognised = ["sunny", "rainy"]
            .iter()
            .any(|expected| *label == *expected);
        assert!(recognised, "Unexpected tag: {}", label);
    }
}
/// Sets the `period` parameter to 0, verifies the getter reports it, and
/// checks that training a two-item sequence still succeeds.
///
/// NOTE(review): per the test name, a period of 0 is presumably meant to turn
/// off the delta-based stopping test — confirm against the L-BFGS trainer.
#[test]
fn test_lbfgs_period_zero_disables_delta_test() {
    let mut trainer = Trainer::lbfgs();
    trainer.params_mut().set_period(0);
    assert_eq!(trainer.params().period(), 0);

    let items = vec![
        vec![Attribute::new("walk", 1.0)],
        vec![Attribute::new("shop", 1.0)],
    ];
    let labels = vec!["sunny", "rainy"];
    trainer.append(&items, &labels).unwrap();
    trainer.params_mut().set_max_iterations(10).unwrap();

    let model_file = tempfile::NamedTempFile::new().unwrap();
    let outcome = trainer.train(model_file.path());
    assert!(
        outcome.is_ok(),
        "Training with period=0 failed: {:?}",
        outcome.err()
    );
}