extern crate crfsuite;
use crfsuite::{Algorithm, Attribute, CrfError, GraphicalModel, Model, Trainer};
/// A trainer obtained through `Default` has no algorithm selected yet, so a
/// training attempt must be rejected with `CrfError::AlgorithmNotSelected`.
#[test]
fn test_trainer_default_impl() {
    let result = Trainer::default().train("tests/test.crfsuite", 1i32);
    assert_eq!(result.err(), Some(CrfError::AlgorithmNotSelected));
}
/// A trainer built with the explicit constructor `Trainer::new` (as opposed to
/// `Trainer::default`, covered by `test_trainer_default_impl`) must likewise
/// refuse to train until an algorithm has been selected.
///
/// Previously this test was an exact copy of `test_trainer_default_impl` and
/// therefore added no coverage.
#[test]
fn test_trainer_train_uninitialized() {
    let mut trainer = Trainer::new(false);
    let ret = trainer.train("tests/test.crfsuite", 1i32);
    assert_eq!(ret.err(), Some(CrfError::AlgorithmNotSelected));
}
/// With an algorithm selected but no sequences appended, training must fail
/// with `CrfError::EmptyData`.
#[test]
fn test_trainer_train_empty_data() {
    let mut trainer = Trainer::default();
    trainer.select(Algorithm::LBFGS, GraphicalModel::CRF1D).unwrap();

    assert_eq!(
        trainer.train("tests/test.crfsuite", -1i32).err(),
        Some(CrfError::EmptyData)
    );
}
/// End-to-end check: train an L-BFGS / CRF1D model on a tiny labelled
/// sequence, load the model back from disk and verify that tagging the
/// training sequence reproduces the training labels.
///
/// One item of the sequence deliberately carries no attributes at all.
#[test]
fn test_train_and_tag() {
    let xseq = vec![
        vec![Attribute::new("walk", 1.0), Attribute::new("shop", 0.5)],
        vec![Attribute::new("walk", 1.0)],
        vec![Attribute::new("walk", 1.0), Attribute::new("clean", 0.5)],
        vec![Attribute::new("shop", 0.5), Attribute::new("clean", 0.5)],
        vec![Attribute::new("walk", 0.5), Attribute::new("clean", 1.0)],
        vec![Attribute::new("clean", 1.0), Attribute::new("shop", 0.1)],
        vec![Attribute::new("walk", 1.0), Attribute::new("shop", 0.5)],
        vec![],
        vec![Attribute::new("clean", 1.0)],
    ];
    let yseq = [
        "sunny", "sunny", "sunny", "rainy", "rainy", "rainy", "sunny", "sunny", "rainy",
    ];

    // Write the model to a process-unique file in the OS temp directory
    // instead of a fixed, cwd-relative path inside the source tree. This keeps
    // `tests/` free of build artifacts and avoids clashing with other tests or
    // concurrent test runs that use the same name.
    let model_path = std::env::temp_dir().join(format!(
        "crfsuite-test-train-and-tag-{}.crfsuite",
        std::process::id()
    ));
    let model_file = model_path
        .to_str()
        .expect("temporary directory path is not valid UTF-8");

    let mut trainer = Trainer::new(true);
    trainer
        .select(Algorithm::LBFGS, GraphicalModel::CRF1D)
        .unwrap();
    trainer.append(&xseq, &yseq, 0i32).unwrap();
    trainer.train(model_file, -1i32).unwrap();
    drop(trainer);

    let model = Model::from_file(model_file).unwrap();
    let mut tagger = model.tagger().unwrap();
    let res = tagger.tag(&xseq).unwrap();

    // Release the tagger and model before deleting the file they were loaded
    // from (the tagger presumably borrows the model, so it goes first).
    drop(tagger);
    drop(model);
    // Best-effort cleanup: failing to delete a temp file must not fail the test.
    let _ = std::fs::remove_file(&model_path);

    assert_eq!(res, yseq);
}
/// Clearing a configured trainer that has had no data appended must succeed.
#[test]
fn test_clear_empty() {
    let mut trainer = Trainer::default();
    trainer.select(Algorithm::LBFGS, GraphicalModel::CRF1D).unwrap();

    let cleared = trainer.clear();
    cleared.unwrap();
}
/// Clearing a trainer after one single-item sequence has been appended must
/// succeed.
#[test]
fn test_clear_not_empty() {
    let item = vec![Attribute::new("walk", 1.0), Attribute::new("shop", 0.5)];
    let xseq = vec![item];
    let labels = ["sunny"];

    let mut trainer = Trainer::default();
    trainer.select(Algorithm::LBFGS, GraphicalModel::CRF1D).unwrap();
    trainer.append(&xseq, &labels, 0i32).unwrap();

    trainer.clear().unwrap();
}
/// The parameter list reported by the trainer depends on the selected
/// algorithm.
///
/// L-BFGS exposes `c1`, `c2` and `num_memories`. After switching to L2SGD,
/// `c2` is still present but `c1` is not.
#[test]
fn test_params() {
    let mut trainer = Trainer::default();

    trainer.select(Algorithm::LBFGS, GraphicalModel::CRF1D).unwrap();
    let lbfgs_params = trainer.params();
    for name in &["c1", "c2", "num_memories"] {
        assert!(lbfgs_params.contains(&name.to_string()));
    }

    trainer.select(Algorithm::L2SGD, GraphicalModel::CRF1D).unwrap();
    let l2sgd_params = trainer.params();
    assert!(!l2sgd_params.contains(&"c1".to_string()));
    assert!(l2sgd_params.contains(&"c2".to_string()));
}
/// `help` returns a description for a known parameter of the selected
/// algorithm.
///
/// Under L-BFGS the help for `c1` mentions "L1". Under L2SGD the help for
/// `c2` mentions "L2".
#[test]
fn test_help() {
    let mut trainer = Trainer::default();

    trainer.select(Algorithm::LBFGS, GraphicalModel::CRF1D).unwrap();
    let c1_help = trainer.help("c1").unwrap();
    assert!(c1_help.contains("L1"));

    trainer.select(Algorithm::L2SGD, GraphicalModel::CRF1D).unwrap();
    let c2_help = trainer.help("c2").unwrap();
    assert!(c2_help.contains("L2"));
}
/// Asking for help on a parameter the selected algorithm does not have must
/// yield `CrfError::ParamNotFound`.
#[test]
fn test_help_invalid_argument() {
    let mut trainer = Trainer::default();
    trainer.select(Algorithm::LBFGS, GraphicalModel::CRF1D).unwrap();

    let err = trainer.help("foo").err().unwrap();
    if let CrfError::ParamNotFound(_) = err {
        return;
    }
    panic!("test fail");
}
/// Reading the existing `c2` parameter of an L2SGD trainer must succeed.
#[test]
fn test_get() {
    let mut trainer = Trainer::default();
    trainer.select(Algorithm::L2SGD, GraphicalModel::CRF1D).unwrap();

    let _c2 = trainer.get("c2").unwrap();
}
/// Reading a parameter the selected algorithm does not have must yield
/// `CrfError::ParamNotFound`.
#[test]
fn test_get_invalid_argument() {
    let mut trainer = Trainer::default();
    trainer.select(Algorithm::L2SGD, GraphicalModel::CRF1D).unwrap();

    let err = trainer.get("foo").err().unwrap();
    if let CrfError::ParamNotFound(_) = err {
        return;
    }
    panic!("test fail");
}
/// Setting `c2` on an L2SGD trainer must change the value subsequently
/// reported by `get`.
#[test]
fn test_set() {
    let mut trainer = Trainer::default();
    trainer
        .select(Algorithm::L2SGD, GraphicalModel::CRF1D)
        .unwrap();
    let before = trainer.get("c2").unwrap();
    trainer.set("c2", "0.5").unwrap();
    let after = trainer.get("c2").unwrap();
    // `assert_ne!` prints both values on failure; `assert!(a != b)` does not.
    assert_ne!(before, after);
}
/// Setting a parameter the selected algorithm does not have must yield
/// `CrfError::ParamNotFound`.
#[test]
fn test_set_invalid_argument() {
    let mut trainer = Trainer::default();
    trainer.select(Algorithm::L2SGD, GraphicalModel::CRF1D).unwrap();

    let err = trainer.set("foo", "1.0").err().unwrap();
    if let CrfError::ParamNotFound(_) = err {
        return;
    }
    panic!("test fail");
}