#![allow(clippy::field_reassign_with_default)]
use std::convert::TryFrom;
use std::io::{Cursor, Read, Write};
/// Reads four bytes from `r` and decodes them as a little-endian `i32`.
///
/// # Panics
///
/// Panics if `r` cannot supply four bytes, because the `read_exact` result is unwrapped.
fn read_i32_le<R: Read>(r: &mut R) -> i32 {
    let mut raw = [0u8; std::mem::size_of::<i32>()];
    r.read_exact(&mut raw).unwrap();
    i32::from_le_bytes(raw)
}
/// Reads eight bytes from `r` and decodes them as a little-endian `f64`.
///
/// # Panics
///
/// Panics if `r` cannot supply eight bytes, because the `read_exact` result is unwrapped.
fn read_f64_le<R: Read>(r: &mut R) -> f64 {
    let mut raw = [0u8; std::mem::size_of::<f64>()];
    r.read_exact(&mut raw).unwrap();
    f64::from_le_bytes(raw)
}
/// Appends `val` to `w` as four little-endian bytes.
///
/// # Panics
///
/// Panics if the underlying `write_all` call fails.
fn write_i32_le<W: Write>(w: &mut W, val: i32) {
    let encoded = val.to_le_bytes();
    w.write_all(&encoded).unwrap();
}
/// Appends `val` to `w` as eight little-endian bytes.
///
/// # Panics
///
/// Panics if the underlying `write_all` call fails.
fn write_f64_le<W: Write>(w: &mut W, val: f64) {
    let encoded = val.to_le_bytes();
    w.write_all(&encoded).unwrap();
}
use fasttext::args::{Args, LossName, MetricName, ModelName};
/// Pins the value of each field checked below on a freshly built `Args::default()`.
#[test]
fn test_args_defaults() {
    // Absolute-difference comparison used for the two floating-point fields.
    let approx = |actual: f64, expected: f64| (actual - expected).abs() < f64::EPSILON;
    let defaults = Args::default();

    // Input and output paths start out empty.
    assert!(defaults.input.as_os_str().is_empty());
    assert!(defaults.output.as_os_str().is_empty());

    // Training hyper-parameters.
    assert!(approx(defaults.lr, 0.05));
    assert_eq!(defaults.lr_update_rate, 100);
    assert_eq!(defaults.dim, 100);
    assert_eq!(defaults.ws, 5);
    assert_eq!(defaults.epoch, 5);
    assert_eq!(defaults.min_count, 5);
    assert_eq!(defaults.min_count_label, 0);
    assert_eq!(defaults.neg, 5);
    assert_eq!(defaults.word_ngrams, 1);
    assert_eq!(defaults.loss, LossName::NegativeSampling);
    assert_eq!(defaults.model, ModelName::SkipGram);
    assert_eq!(defaults.bucket, 2_000_000);
    assert_eq!(defaults.minn, 3);
    assert_eq!(defaults.maxn, 6);
    assert_eq!(defaults.thread, 12);
    assert!(approx(defaults.t, 1e-4));

    // Labelling, verbosity and miscellaneous switches.
    assert_eq!(defaults.label, "__label__");
    assert_eq!(defaults.verbose, 2);
    assert!(defaults.pretrained_vectors.as_os_str().is_empty());
    assert!(!defaults.save_output);
    assert_eq!(defaults.seed, 0);

    // Quantization-related fields.
    assert!(!defaults.qout);
    assert!(!defaults.retrain);
    assert!(!defaults.qnorm);
    assert_eq!(defaults.cutoff, 0);
    assert_eq!(defaults.dsub, 2);

    // Autotune-related fields.
    assert!(defaults.autotune_validation_file.as_os_str().is_empty());
    assert_eq!(defaults.autotune_metric, "f1");
    assert_eq!(defaults.autotune_predictions, 1);
    assert_eq!(defaults.autotune_duration, 300);
    assert_eq!(defaults.autotune_model_size, "");
}
/// Checks the fields that `apply_supervised_defaults` rewrites on a default `Args`.
#[test]
fn test_args_supervised_overrides() {
    let mut supervised = Args::default();
    supervised.apply_supervised_defaults();

    assert_eq!(supervised.model, ModelName::Supervised);
    assert_eq!(supervised.loss, LossName::Softmax);
    assert_eq!(supervised.min_count, 1);
    assert_eq!(supervised.minn, 0);
    assert_eq!(supervised.maxn, 0);

    let lr_delta = (supervised.lr - 0.1).abs();
    assert!(lr_delta < f64::EPSILON);
}
/// On an otherwise untouched `Args::default()`, supervised defaults leave `bucket` at 0.
#[test]
fn test_args_supervised_bucket_zero() {
    let mut supervised = Args::default();
    supervised.apply_supervised_defaults();
    assert_eq!(supervised.bucket, 0);
}
/// With `word_ngrams` raised to 2 beforehand, supervised defaults keep `bucket` at 2,000,000.
#[test]
fn test_args_supervised_bucket_nonzero_with_word_ngrams() {
    let mut supervised = Args::default();
    supervised.word_ngrams = 2;
    supervised.apply_supervised_defaults();
    assert_eq!(supervised.bucket, 2_000_000);
}
/// With an autotune validation file set beforehand, supervised defaults keep `bucket` at
/// 2,000,000.
#[test]
fn test_args_supervised_bucket_nonzero_with_autotune() {
    let mut supervised = Args::default();
    supervised.autotune_validation_file = "valid.txt".into();
    supervised.apply_supervised_defaults();
    assert_eq!(supervised.bucket, 2_000_000);
}
/// `has_autotune` is false for default arguments and true once a validation file is set.
#[test]
fn test_has_autotune() {
    let plain = Args::default();
    assert!(!plain.has_autotune());

    let mut tuned = Args::default();
    tuned.autotune_validation_file = "validation.txt".into();
    assert!(tuned.has_autotune());
}
/// Saves default arguments, checks the block is 56 bytes long, and verifies that `load`
/// restores every serialized field.
///
/// The target `Args` is deliberately moved away from the defaults in *every* serialized
/// field before loading; otherwise the equality checks would compare defaults against
/// defaults and pass even if `load` skipped a field.
#[test]
fn test_args_binary_serialization_layout() {
    let args = Args::default();
    let mut buf = Vec::new();
    args.save(&mut buf).unwrap();
    assert_eq!(buf.len(), 56, "Args binary block must be exactly 56 bytes");

    // Every value below differs from the corresponding default.
    let mut args2 = Args::default();
    args2.dim = 999;
    args2.ws = 999;
    args2.epoch = 999;
    args2.min_count = 3;
    args2.neg = 10;
    args2.word_ngrams = 2;
    args2.loss = LossName::Softmax;
    args2.model = ModelName::Supervised;
    args2.bucket = 500_000;
    args2.minn = 2;
    args2.maxn = 5;
    args2.lr_update_rate = 50;
    args2.t = 1e-3;

    let mut cursor = Cursor::new(&buf);
    args2.load(&mut cursor).unwrap();

    assert_eq!(args2.dim, args.dim);
    assert_eq!(args2.ws, args.ws);
    assert_eq!(args2.epoch, args.epoch);
    assert_eq!(args2.min_count, args.min_count);
    assert_eq!(args2.neg, args.neg);
    assert_eq!(args2.word_ngrams, args.word_ngrams);
    assert_eq!(args2.loss, args.loss);
    assert_eq!(args2.model, args.model);
    assert_eq!(args2.bucket, args.bucket);
    assert_eq!(args2.minn, args.minn);
    assert_eq!(args2.maxn, args.maxn);
    assert_eq!(args2.lr_update_rate, args.lr_update_rate);
    assert!((args2.t - args.t).abs() < f64::EPSILON);
}
/// Round-trips an `Args` whose serialized fields all hold non-default values.
#[test]
fn test_args_binary_serialization_nondefault() {
    let mut original = Args::default();
    original.dim = 300;
    original.ws = 10;
    original.epoch = 25;
    original.min_count = 3;
    original.neg = 10;
    original.word_ngrams = 2;
    original.loss = LossName::Softmax;
    original.model = ModelName::Supervised;
    original.bucket = 500_000;
    original.minn = 2;
    original.maxn = 5;
    original.lr_update_rate = 50;
    original.t = 1e-3;

    let mut bytes = Vec::new();
    original.save(&mut bytes).unwrap();
    assert_eq!(bytes.len(), 56);

    let mut restored = Args::default();
    restored.load(&mut Cursor::new(&bytes)).unwrap();

    assert_eq!(restored.dim, 300);
    assert_eq!(restored.ws, 10);
    assert_eq!(restored.epoch, 25);
    assert_eq!(restored.min_count, 3);
    assert_eq!(restored.neg, 10);
    assert_eq!(restored.word_ngrams, 2);
    assert_eq!(restored.loss, LossName::Softmax);
    assert_eq!(restored.model, ModelName::Supervised);
    assert_eq!(restored.bucket, 500_000);
    assert_eq!(restored.minn, 2);
    assert_eq!(restored.maxn, 5);
    assert_eq!(restored.lr_update_rate, 50);
    let t_delta = (restored.t - 1e-3).abs();
    assert!(t_delta < f64::EPSILON);
}
/// Verifies the order in which `Args::save` writes its fields.
///
/// Each serialized field gets a distinct marker value (except `loss` and `model`, whose
/// enum values are both 3), the block is saved, and the raw bytes are decoded slot by slot
/// as twelve little-endian `i32`s followed by one little-endian `f64`. Failures name the
/// offending slot, and the test also rejects any bytes written after `t`.
#[test]
fn test_args_binary_serialization_field_order() {
    let mut args = Args::default();
    args.dim = 1;
    args.ws = 2;
    args.epoch = 3;
    args.min_count = 4;
    args.neg = 5;
    args.word_ngrams = 6;
    args.loss = LossName::Softmax;
    args.model = ModelName::Supervised;
    args.bucket = 9;
    args.minn = 10;
    args.maxn = 11;
    args.lr_update_rate = 12;
    args.t = 0.5;

    let mut buf = Vec::new();
    args.save(&mut buf).unwrap();

    // Expected integer slots, in on-disk order. `Softmax` and `Supervised` both encode as 3.
    let expected_ints: [(&str, i32); 12] = [
        ("dim", 1),
        ("ws", 2),
        ("epoch", 3),
        ("min_count", 4),
        ("neg", 5),
        ("word_ngrams", 6),
        ("loss", 3),
        ("model", 3),
        ("bucket", 9),
        ("minn", 10),
        ("maxn", 11),
        ("lr_update_rate", 12),
    ];

    let mut cursor = Cursor::new(&buf);
    for (field, want) in expected_ints.iter() {
        assert_eq!(
            read_i32_le(&mut cursor),
            *want,
            "unexpected value in the `{}` slot",
            field
        );
    }

    let t_val = read_f64_le(&mut cursor);
    assert!(
        (t_val - 0.5).abs() < f64::EPSILON,
        "unexpected value in the `t` slot: {}",
        t_val
    );

    // Nothing may follow the final field.
    let mut trailing = Vec::new();
    cursor.read_to_end(&mut trailing).unwrap();
    assert!(
        trailing.is_empty(),
        "{} unexpected byte(s) after the `t` field",
        trailing.len()
    );
}
/// `load` must return an error when the seventh integer slot (the loss value) is 99.
#[test]
fn test_args_binary_load_invalid_loss() {
    // Twelve integer slots; the seventh one carries the bogus value.
    let int_slots: [i32; 12] = [100, 5, 5, 5, 5, 1, 99, 1, 2000000, 3, 6, 100];

    let mut bytes = Vec::new();
    for slot in int_slots.iter() {
        write_i32_le(&mut bytes, *slot);
    }
    write_f64_le(&mut bytes, 1e-4);

    let mut target = Args::default();
    let outcome = target.load(&mut Cursor::new(&bytes));
    assert!(outcome.is_err());
}
/// `load` must return an error when the eighth integer slot (the model value) is 99.
#[test]
fn test_args_binary_load_invalid_model() {
    // Twelve integer slots; the eighth one carries the bogus value.
    let int_slots: [i32; 12] = [100, 5, 5, 5, 5, 1, 2, 99, 2000000, 3, 6, 100];

    let mut bytes = Vec::new();
    for slot in int_slots.iter() {
        write_i32_le(&mut bytes, *slot);
    }
    write_f64_le(&mut bytes, 1e-4);

    let mut target = Args::default();
    let outcome = target.load(&mut Cursor::new(&bytes));
    assert!(outcome.is_err());
}
/// `load` must return an error when given only 20 bytes, fewer than the 56-byte block.
#[test]
fn test_args_binary_load_truncated() {
    let short_input = vec![0u8; 20];
    let mut target = Args::default();
    let mut reader = Cursor::new(&short_input);
    assert!(target.load(&mut reader).is_err());
}
/// Pins the integer value of each `ModelName` variant.
#[test]
fn test_model_name_values() {
    let table = [
        (ModelName::Cbow, 1),
        (ModelName::SkipGram, 2),
        (ModelName::Supervised, 3),
    ];
    for (variant, code) in table.iter() {
        assert_eq!(*variant as i32, *code);
    }
}
/// Pins the integer value of each `LossName` variant.
#[test]
fn test_loss_name_values() {
    let table = [
        (LossName::HierarchicalSoftmax, 1),
        (LossName::NegativeSampling, 2),
        (LossName::Softmax, 3),
        (LossName::OneVsAll, 4),
    ];
    for (variant, code) in table.iter() {
        assert_eq!(*variant as i32, *code);
    }
}
/// `ModelName::try_from` accepts 1..=3 and rejects 0, 4 and -1.
#[test]
fn test_model_name_try_from_i32() {
    let accepted = [
        (1_i32, ModelName::Cbow),
        (2, ModelName::SkipGram),
        (3, ModelName::Supervised),
    ];
    for (raw, variant) in accepted.iter() {
        assert_eq!(ModelName::try_from(*raw), Ok(*variant));
    }

    let rejected = [0_i32, 4, -1];
    for raw in rejected.iter() {
        assert!(ModelName::try_from(*raw).is_err());
    }
}
/// `LossName::try_from` accepts 1..=4 and rejects 0, 5 and -1.
#[test]
fn test_loss_name_try_from_i32() {
    let accepted = [
        (1_i32, LossName::HierarchicalSoftmax),
        (2, LossName::NegativeSampling),
        (3, LossName::Softmax),
        (4, LossName::OneVsAll),
    ];
    for (raw, variant) in accepted.iter() {
        assert_eq!(LossName::try_from(*raw), Ok(*variant));
    }

    let rejected = [0_i32, 5, -1];
    for raw in rejected.iter() {
        assert!(LossName::try_from(*raw).is_err());
    }
}
/// Pins the string produced by `to_string()` for each `LossName` variant.
#[test]
fn test_loss_display() {
    let expected = [
        (LossName::HierarchicalSoftmax, "hs"),
        (LossName::NegativeSampling, "ns"),
        (LossName::Softmax, "softmax"),
        (LossName::OneVsAll, "one-vs-all"),
    ];
    for (loss, text) in expected.iter() {
        assert_eq!(loss.to_string(), *text);
    }
}
/// Pins the string produced by `to_string()` for each `ModelName` variant.
#[test]
fn test_model_display() {
    let expected = [
        (ModelName::Cbow, "cbow"),
        (ModelName::SkipGram, "sg"),
        (ModelName::Supervised, "sup"),
    ];
    for (model, text) in expected.iter() {
        assert_eq!(model.to_string(), *text);
    }
}
/// Default arguments resolve the autotune metric to `MetricName::F1Score`.
#[test]
fn test_autotune_metric_name_default() {
    let resolved = Args::default().get_autotune_metric_name();
    assert_eq!(resolved, Some(MetricName::F1Score));
}
/// The metric string `"f1:cooking"` resolves to `MetricName::LabelF1Score`.
#[test]
fn test_autotune_metric_name_label_f1() {
    let mut configured = Args::default();
    configured.autotune_metric = String::from("f1:cooking");

    let resolved = configured.get_autotune_metric_name();
    assert_eq!(resolved, Some(MetricName::LabelF1Score));
}
/// The metric string `"precisionAtRecall:50"` resolves to `MetricName::PrecisionAtRecall`.
#[test]
fn test_autotune_metric_name_precision_at_recall() {
    let mut configured = Args::default();
    configured.autotune_metric = String::from("precisionAtRecall:50");

    let resolved = configured.get_autotune_metric_name();
    assert_eq!(resolved, Some(MetricName::PrecisionAtRecall));
}
/// The metric string `"precisionAtRecall:50:cooking"` resolves to
/// `MetricName::PrecisionAtRecallLabel`.
#[test]
fn test_autotune_metric_name_precision_at_recall_label() {
    let mut configured = Args::default();
    configured.autotune_metric = String::from("precisionAtRecall:50:cooking");

    let resolved = configured.get_autotune_metric_name();
    assert_eq!(resolved, Some(MetricName::PrecisionAtRecallLabel));
}
/// The metric string `"recallAtPrecision:50"` resolves to `MetricName::RecallAtPrecision`.
#[test]
fn test_autotune_metric_name_recall_at_precision() {
    let mut configured = Args::default();
    configured.autotune_metric = String::from("recallAtPrecision:50");

    let resolved = configured.get_autotune_metric_name();
    assert_eq!(resolved, Some(MetricName::RecallAtPrecision));
}
/// The metric string `"recallAtPrecision:50:cooking"` resolves to
/// `MetricName::RecallAtPrecisionLabel`.
#[test]
fn test_autotune_metric_name_recall_at_precision_label() {
    let mut configured = Args::default();
    configured.autotune_metric = String::from("recallAtPrecision:50:cooking");

    let resolved = configured.get_autotune_metric_name();
    assert_eq!(resolved, Some(MetricName::RecallAtPrecisionLabel));
}
/// An unrecognized metric string yields `None`.
#[test]
fn test_autotune_metric_name_unknown() {
    let mut configured = Args::default();
    configured.autotune_metric = String::from("unknown_metric");

    let resolved = configured.get_autotune_metric_name();
    assert_eq!(resolved, None);
}
/// Every `LossName` variant survives a save/load round trip.
#[test]
fn test_args_binary_all_loss_types() {
    let variants = [
        LossName::HierarchicalSoftmax,
        LossName::NegativeSampling,
        LossName::Softmax,
        LossName::OneVsAll,
    ];
    for &loss in variants.iter() {
        let mut original = Args::default();
        original.loss = loss;

        let mut bytes = Vec::new();
        original.save(&mut bytes).unwrap();

        let mut restored = Args::default();
        restored.load(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(restored.loss, loss);
    }
}
/// Every `ModelName` variant survives a save/load round trip.
#[test]
fn test_args_binary_all_model_types() {
    let variants = [ModelName::Cbow, ModelName::SkipGram, ModelName::Supervised];
    for &model in variants.iter() {
        let mut original = Args::default();
        original.model = model;

        let mut bytes = Vec::new();
        original.save(&mut bytes).unwrap();

        let mut restored = Args::default();
        restored.load(&mut Cursor::new(&bytes)).unwrap();
        assert_eq!(restored.model, model);
    }
}
/// `lr`, `verbose` and `label` are not part of the binary block: changing them leaves the
/// saved size at 56 bytes, and loading into a fresh `Args` keeps that instance's defaults.
#[test]
fn test_args_binary_does_not_save_non_serialized_fields() {
    let mut original = Args::default();
    original.lr = 0.2;
    original.verbose = 0;
    original.label = "custom".to_string();

    let mut bytes = Vec::new();
    original.save(&mut bytes).unwrap();
    assert_eq!(bytes.len(), 56);

    let mut restored = Args::default();
    restored.load(&mut Cursor::new(&bytes)).unwrap();

    // The loaded instance still shows the default values, not the ones set above.
    let lr_delta = (restored.lr - 0.05).abs();
    assert!(lr_delta < f64::EPSILON);
    assert_eq!(restored.verbose, 2);
    assert_eq!(restored.label, "__label__");
}
/// Boundary values (`i32::MAX`, `i32::MIN`, zero, -1, `f64::MIN_POSITIVE`) round-trip
/// through save/load unchanged.
#[test]
fn test_args_binary_extreme_values() {
    let mut original = Args::default();
    original.dim = i32::MAX;
    original.ws = i32::MIN;
    original.epoch = 0;
    original.min_count = -1;
    original.neg = 0;
    original.word_ngrams = 0;
    original.bucket = i32::MAX;
    original.minn = 0;
    original.maxn = 0;
    original.lr_update_rate = 0;
    original.t = f64::MIN_POSITIVE;

    let mut bytes = Vec::new();
    original.save(&mut bytes).unwrap();

    let mut restored = Args::default();
    restored.load(&mut Cursor::new(&bytes)).unwrap();

    assert_eq!(restored.dim, i32::MAX);
    assert_eq!(restored.ws, i32::MIN);
    assert_eq!(restored.epoch, 0);
    assert_eq!(restored.min_count, -1);
    assert_eq!(restored.neg, 0);
    assert_eq!(restored.word_ngrams, 0);
    assert_eq!(restored.bucket, i32::MAX);
    assert_eq!(restored.minn, 0);
    assert_eq!(restored.maxn, 0);
    assert_eq!(restored.lr_update_rate, 0);
    // Exact comparison is intended: the value is expected to round-trip bit for bit.
    assert_eq!(restored.t, f64::MIN_POSITIVE);
}