#![allow(clippy::field_reassign_with_default)]
use std::sync::Arc;
/// 32-bit FNV-1a over a byte slice, reproducing the fastText quirk of
/// sign-extending each byte (`int8_t` -> `uint32_t` in the C++ original)
/// before XOR-ing it into the state. For ASCII input this matches standard
/// FNV-1a; bytes >= 0x80 diverge on purpose for hash parity with fastText.
fn fnv1a_hash(s: &[u8]) -> u32 {
    s.iter().fold(2166136261u32, |state, &b| {
        // Sign-extend deliberately: b as i8 as i32 as u32.
        (state ^ (b as i8 as i32 as u32)).wrapping_mul(16777619)
    })
}
use fasttext::args::{Args, ModelName};
use fasttext::dictionary::{Dictionary, EntryType, BOW, EOS, EOW, MAX_LINE_SIZE};
/// Default `Args` wrapped in the `Arc` that `Dictionary` expects.
fn make_args() -> Arc<Args> {
    let args = Args::default();
    Arc::new(args)
}
/// A fresh `Dictionary` over default `Args` with a small initial capacity.
fn make_dict() -> Dictionary {
    let args = make_args();
    Dictionary::new_with_capacity(args, 1024)
}
#[test]
fn test_tokenize_whitespace() {
    // Space, tab, vertical tab (\x0b), form feed (\x0c), NUL and carriage
    // return all act as token separators; leading, trailing and repeated
    // separators must not produce empty tokens.
    assert_eq!(Dictionary::tokenize("hello world"), vec!["hello", "world"]);
    assert_eq!(Dictionary::tokenize("hello\tworld"), vec!["hello", "world"]);
    assert_eq!(
        Dictionary::tokenize("hello\x0bworld"),
        vec!["hello", "world"]
    );
    assert_eq!(
        Dictionary::tokenize("hello\x0cworld"),
        vec!["hello", "world"]
    );
    assert_eq!(Dictionary::tokenize("hello\0world"), vec!["hello", "world"]);
    assert_eq!(Dictionary::tokenize("hello\rworld"), vec!["hello", "world"]);
    assert_eq!(
        Dictionary::tokenize("hello world"),
        vec!["hello", "world"]
    );
    assert_eq!(
        Dictionary::tokenize(" hello world "),
        vec!["hello", "world"]
    );
    assert_eq!(
        Dictionary::tokenize("hello \t world"),
        vec!["hello", "world"]
    );
}
#[test]
fn test_tokenize_eos() {
    // '\n' is not plain whitespace: each newline becomes one EOS token,
    // wherever it appears (start, end, doubled, or after '\r').
    let cases: [(&str, Vec<&str>); 7] = [
        ("hello\nworld", vec!["hello", EOS, "world"]),
        ("\nhello", vec![EOS, "hello"]),
        ("hello\n", vec!["hello", EOS]),
        ("hello\n\nworld", vec!["hello", EOS, EOS, "world"]),
        ("hello\r\nworld", vec!["hello", EOS, "world"]),
        ("a\nb", vec!["a", EOS, "b"]),
        ("hello world\nfoo bar", vec!["hello", "world", EOS, "foo", "bar"]),
    ];
    for (input, expected) in cases {
        assert_eq!(Dictionary::tokenize(input), expected);
    }
}
#[test]
fn test_tokenize_max_tokens() {
    // tokenize caps each line at MAX_LINE_SIZE tokens: feed one more word
    // than the cap and expect exactly MAX_LINE_SIZE back.
    let words: Vec<String> = (0..1025).map(|i| format!("word{}", i)).collect();
    let text = words.join(" ");
    let result = Dictionary::tokenize(&text);
    assert_eq!(
        result.len(),
        MAX_LINE_SIZE,
        "Expected {} tokens, got {}",
        MAX_LINE_SIZE,
        result.len()
    );
    // Truncation keeps the earliest tokens, in input order.
    for (i, token) in result.iter().enumerate() {
        assert_eq!(token, &format!("word{}", i));
    }
    // The cap is per line: two over-long lines joined by '\n' yield
    // MAX_LINE_SIZE tokens each, plus the EOS token between them.
    let text2 = format!("{}\n{}", text, text);
    let result2 = Dictionary::tokenize(&text2);
    assert_eq!(result2.len(), 2 * MAX_LINE_SIZE + 1);
    assert_eq!(result2[MAX_LINE_SIZE], EOS);
}
#[test]
fn test_tokenize_utf8() {
    // Multi-byte UTF-8 tokens (CJK, accented Latin, emoji) must survive
    // tokenization intact, including when mixed with tabs and newlines.
    let cases: [(&str, Vec<&str>); 5] = [
        ("hello 日本語 world", vec!["hello", "日本語", "world"]),
        ("café naïve", vec!["café", "naïve"]),
        ("hello\t日本語\ncafé", vec!["hello", "日本語", EOS, "café"]),
        ("hello 🎉 world", vec!["hello", "🎉", "world"]),
        ("你好 世界", vec!["你好", "世界"]),
    ];
    for (input, expected) in cases {
        assert_eq!(Dictionary::tokenize(input), expected);
    }
}
#[test]
fn test_label_detection_default() {
    // With default Args, only tokens starting with "__label__" are labels.
    let dict = make_dict();
    let cases = [
        ("__label__cat", EntryType::Label, "__label__cat should be a label"),
        ("__label__", EntryType::Label, "__label__ alone should be a label"),
        (
            "__label__very_long_label_name",
            EntryType::Label,
            "longer label should be detected",
        ),
        ("hello", EntryType::Word, "regular word should be Word"),
        ("label__", EntryType::Word, "label__ without __ prefix should be Word"),
        ("_label_cat", EntryType::Word, "single underscore prefix should be Word"),
        ("", EntryType::Word, "empty string should be Word"),
        (EOS, EntryType::Word, "EOS token should be Word"),
    ];
    for (token, expected, why) in cases {
        assert_eq!(dict.get_type_from_str(token), expected, "{}", why);
    }
}
#[test]
fn test_label_detection_custom_prefix() {
    // The label prefix is configurable; once changed, the default
    // "__label__" prefix no longer marks a label.
    let mut args = Args::default();
    args.label = "#".to_string();
    let dict = Dictionary::new_with_capacity(Arc::new(args), 1024);
    let cases = [
        ("#cat", EntryType::Label, "#cat should be a label with # prefix"),
        ("#", EntryType::Label, "# alone should be a label"),
        ("cat", EntryType::Word, "cat should be a word"),
        (
            "__label__cat",
            EntryType::Word,
            "__label__cat should be Word with # prefix",
        ),
    ];
    for (token, expected, why) in cases {
        assert_eq!(dict.get_type_from_str(token), expected, "{}", why);
    }
}
#[test]
fn test_vocab_lookup() {
    // Added tokens are retrievable by ID; unknown tokens map to None.
    let mut dict = make_dict();
    for token in ["hello", "world", "hello"] {
        dict.add(token);
    }
    let id_hello = dict.get_id("hello");
    let id_world = dict.get_id("world");
    assert!(id_hello.is_some(), "hello should be in vocab");
    assert!(id_world.is_some(), "world should be in vocab");
    assert_ne!(
        id_hello, id_world,
        "hello and world should have different IDs"
    );
    assert_eq!(dict.get_id("foo"), None, "foo should be OOV");
    // IDs round-trip back to their spellings.
    assert_eq!(dict.get_word(id_hello.unwrap()), "hello");
    assert_eq!(dict.get_word(id_world.unwrap()), "world");
}
#[test]
fn test_vocab_counts() {
    // ntokens counts every add (duplicates included); size counts distinct
    // entries.
    let mut dict = make_dict();
    for token in ["hello", "world", "hello", "hello"] {
        dict.add(token);
    }
    assert_eq!(dict.ntokens(), 4);
    assert_eq!(dict.size(), 2);
    // A threshold of 1 keeps everything; both entries are words, no labels.
    dict.threshold(1, 1);
    assert_eq!(dict.nwords(), 2);
    assert_eq!(dict.nlabels(), 0);
    assert_eq!(dict.size(), 2);
}
#[test]
fn test_vocab_label_counts() {
    // Words and labels are tallied in separate pools after threshold().
    let mut dict = make_dict();
    for token in ["word1", "__label__cat", "word2", "__label__dog", "__label__cat"] {
        dict.add(token);
    }
    assert_eq!(dict.ntokens(), 5);
    dict.threshold(1, 1);
    assert_eq!(dict.nwords(), 2);
    assert_eq!(dict.nlabels(), 2);
    assert_eq!(dict.size(), 4);
    // __label__cat (count 2) precedes __label__dog (count 1).
    assert_eq!(dict.get_label(0).unwrap(), "__label__cat");
    assert_eq!(dict.get_label(1).unwrap(), "__label__dog");
}
#[test]
fn test_vocab_sorted_order() {
    // After threshold(), words come first sorted by descending count,
    // followed by labels sorted the same way.
    let mut dict = make_dict();
    let corpus = [
        ("rare", 5),
        ("common", 10),
        ("__label__cat", 3),
        ("__label__dog", 7),
    ];
    for (token, occurrences) in corpus {
        for _ in 0..occurrences {
            dict.add(token);
        }
    }
    dict.threshold(1, 1);
    assert_eq!(dict.nwords(), 2, "Should have 2 words");
    assert_eq!(dict.nlabels(), 2, "Should have 2 labels");
    let expected_words = ["common", "rare", "__label__dog", "__label__cat"];
    let expected_counts = [10, 5, 7, 3];
    for i in 0..4 {
        assert_eq!(dict.words()[i].word, expected_words[i]);
        assert_eq!(dict.words()[i].count, expected_counts[i]);
    }
    // IDs follow the sorted order.
    assert_eq!(dict.get_id("common"), Some(0));
    assert_eq!(dict.get_id("rare"), Some(1));
    assert_eq!(dict.get_id("__label__dog"), Some(2));
    assert_eq!(dict.get_id("__label__cat"), Some(3));
}
#[test]
fn test_vocab_threshold_filtering() {
    // threshold(2, 1): words need count >= 2 to survive, labels only >= 1.
    let mut dict = make_dict();
    for token in ["rare", "common", "common", "__label__a", "__label__b", "__label__b"] {
        dict.add(token);
    }
    dict.threshold(2, 1);
    assert_eq!(dict.nwords(), 1, "Only 'common' survives");
    assert_eq!(dict.nlabels(), 2, "Both labels survive");
    assert!(dict.get_id("common").is_some(), "common should be in vocab");
    assert_eq!(dict.get_id("rare"), None, "rare should be filtered out");
    assert!(
        dict.get_id("__label__a").is_some(),
        "label a should be in vocab"
    );
    assert!(
        dict.get_id("__label__b").is_some(),
        "label b should be in vocab"
    );
}
#[test]
fn test_vocab_hash_collision_resolution() {
    // 100 distinct words must all be stored and retrievable even when they
    // collide in the hash table.
    let mut dict = make_dict();
    let vocabulary: Vec<String> = (0..100).map(|i| format!("word_{}", i)).collect();
    for word in &vocabulary {
        dict.add(word);
    }
    assert_eq!(dict.size(), 100);
    for (i, word) in vocabulary.iter().enumerate() {
        assert!(dict.get_id(word).is_some(), "word_{} should be in vocab", i);
    }
    assert_eq!(dict.get_id("not_in_vocab"), None);
}
#[test]
fn test_vocab_eos_is_word_type() {
    // The EOS sentinel is classified as a regular word, never a label.
    assert_eq!(make_dict().get_type_from_str(EOS), EntryType::Word);
}
#[test]
fn test_get_label_out_of_range() {
    // With exactly one label, only index 0 is valid; both directions out of
    // range must error rather than panic.
    let mut dict = make_dict();
    dict.add("__label__cat");
    dict.threshold(1, 1);
    assert!(dict.get_label(0).is_ok());
    assert!(dict.get_label(1).is_err());
    assert!(dict.get_label(-1).is_err());
}
#[test]
fn test_tokenize_empty_string() {
    // Tokenizing nothing yields nothing — no phantom empty tokens.
    let tokens = Dictionary::tokenize("");
    assert!(tokens.is_empty(), "Empty string should yield no tokens");
}
#[test]
fn test_tokenize_only_whitespace() {
    // Separators alone produce no tokens at all.
    let tokens = Dictionary::tokenize(" \t ");
    assert!(tokens.is_empty(), "Whitespace-only should yield no tokens");
}
#[test]
fn test_tokenize_only_newlines() {
    // Each '\n' maps to exactly one EOS token.
    assert_eq!(Dictionary::tokenize("\n\n\n"), vec![EOS; 3]);
}
#[test]
fn test_add_increments_ntokens() {
    // Every add() bumps the raw token counter, duplicates included.
    let mut dict = make_dict();
    assert_eq!(dict.ntokens(), 0);
    let mut expected = 0;
    for token in ["a", "b", "a"] {
        dict.add(token);
        expected += 1;
        assert_eq!(dict.ntokens(), expected);
    }
}
#[test]
fn test_word_count_accumulation() {
    // Repeatedly adding the same token accumulates its stored count.
    let mut dict = make_dict();
    for _ in 0..5 {
        dict.add("hello");
    }
    dict.threshold(1, 1);
    let wid = dict.get_id("hello").expect("hello must be in vocab");
    assert_eq!(dict.words()[wid as usize].count, 5);
}
#[test]
fn test_read_word_from_reader_basic() {
    // Words are yielded one at a time, '\n' surfaces as an EOS token, and
    // exhaustion is reported by returning false.
    let mut reader = "hello world\nfoo".as_bytes();
    let mut pending = false;
    let mut word = String::new();
    for expected in ["hello", "world", EOS, "foo"] {
        assert!(Dictionary::read_word_from_reader(
            &mut reader,
            &mut pending,
            &mut word
        ));
        assert_eq!(word, expected);
    }
    assert!(!Dictionary::read_word_from_reader(
        &mut reader,
        &mut pending,
        &mut word
    ));
}
#[test]
fn test_read_word_from_reader_utf8() {
    // Multi-byte UTF-8 words must come back intact from the byte reader.
    let mut reader = "日本語 café\nhello".as_bytes();
    let mut pending = false;
    let mut word = String::new();
    let utf8_cases = [
        (
            "日本語",
            "Multi-byte UTF-8 token '日本語' should be preserved intact",
        ),
        (
            "café",
            "UTF-8 token 'café' with accented character should be preserved intact",
        ),
    ];
    for (expected, why) in utf8_cases {
        assert!(Dictionary::read_word_from_reader(
            &mut reader,
            &mut pending,
            &mut word
        ));
        assert_eq!(word, expected, "{}", why);
    }
    // The newline becomes EOS, then the last ASCII word follows.
    assert!(Dictionary::read_word_from_reader(
        &mut reader,
        &mut pending,
        &mut word
    ));
    assert_eq!(word, EOS);
    assert!(Dictionary::read_word_from_reader(
        &mut reader,
        &mut pending,
        &mut word
    ));
    assert_eq!(word, "hello");
}
#[test]
fn test_read_from_file_utf8_tokens() {
    // Build a vocabulary from a reader mixing ASCII and multi-byte UTF-8
    // tokens: lookups, stored spellings and per-word counts must round-trip.
    let mut args = Args::default();
    // Keep every word regardless of frequency.
    args.min_count = 1;
    let args = Arc::new(args);
    let mut dict = Dictionary::new_with_capacity(args, 1024);
    let content = "日本語 café hello\n日本語 world\ncafé test\n";
    let mut reader = content.as_bytes();
    dict.read_from_file(&mut reader).unwrap();
    let id_jp = dict.get_id("日本語");
    let id_cafe = dict.get_id("café");
    let id_hello = dict.get_id("hello");
    let id_world = dict.get_id("world");
    let id_test = dict.get_id("test");
    assert!(id_jp.is_some(), "'日本語' should be in vocabulary");
    assert!(id_cafe.is_some(), "'café' should be in vocabulary");
    assert!(id_hello.is_some(), "'hello' should be in vocabulary");
    assert!(id_world.is_some(), "'world' should be in vocabulary");
    assert!(id_test.is_some(), "'test' should be in vocabulary");
    let id_jp = id_jp.unwrap();
    let id_cafe = id_cafe.unwrap();
    assert_eq!(
        dict.get_word(id_jp),
        "日本語",
        "Stored word for id_jp should be '日本語'"
    );
    assert_eq!(
        dict.get_word(id_cafe),
        "café",
        "Stored word for id_café should be 'café'"
    );
    // Each of these appears on two separate lines of `content` above.
    assert_eq!(
        dict.words()[id_jp as usize].count,
        2,
        "'日本語' should have count 2"
    );
    assert_eq!(
        dict.words()[id_cafe as usize].count,
        2,
        "'café' should have count 2"
    );
}
#[test]
fn test_tokenize_eos_entry_type() {
    // Even once EOS is added to the vocabulary, its entry type stays Word.
    let mut dict = make_dict();
    dict.add(EOS);
    dict.add("hello");
    dict.threshold(1, 1);
    let wid = dict.get_id(EOS).expect("EOS must be in vocab");
    assert_eq!(dict.get_type_by_id(wid), EntryType::Word);
}
/// Build `Args` for subword tests: character n-gram lengths in
/// [`minn`, `maxn`] and `bucket` hash slots for n-gram IDs.
/// (Field reassignment after `default()` is intentional — see the
/// `clippy::field_reassign_with_default` allow at the top of the file.)
fn make_subword_args(minn: i32, maxn: i32, bucket: i32) -> Arc<Args> {
    let mut args = Args::default();
    args.minn = minn;
    args.maxn = maxn;
    args.bucket = bucket;
    Arc::new(args)
}
#[test]
fn test_subword_computation_bow_eow_wrapping() {
    // "<he>" already carries the BOW/EOW markers and has 4 chars, so with
    // minn=3 maxn=4 the n-grams are "<he", "he>" and "<he>" — 3 in total.
    let dict = Dictionary::new_with_capacity(make_subword_args(3, 4, 100000), 1024);
    let mut subword_ids = Vec::new();
    dict.compute_subwords("<he>", &mut subword_ids);
    assert_eq!(
        subword_ids.len(),
        3,
        "Should have 3 n-grams for '<he>' with minn=3 maxn=4"
    );
}
#[test]
fn test_subword_bucket_index() {
    // After init_ngrams, each word's subword list starts with its own ID and
    // continues with n-gram IDs that must land in [nwords, nwords + bucket).
    let args = make_subword_args(3, 6, 200000);
    let mut dict = Dictionary::new_with_capacity(args, 1024);
    dict.add("hello");
    dict.add("world");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let nwords = dict.nwords();
    // Must match the bucket passed to make_subword_args above.
    let bucket = 200000;
    for wid in 0..nwords {
        let subwords = dict.get_subwords(wid);
        assert_eq!(subwords[0], wid, "First subword should be word ID");
        for &sid in &subwords[1..] {
            assert!(
                sid >= nwords,
                "Subword ID {} should be >= nwords {}",
                sid,
                nwords
            );
            assert!(
                sid < nwords + bucket,
                "Subword ID {} should be < nwords+bucket {}",
                sid,
                nwords + bucket
            );
        }
    }
}
#[test]
fn test_subword_computation_known_values() {
    // For "hello" with minn=3, maxn=6, the n-grams over "<hello>" (BOW/EOW
    // markers included) are, grouped by start position:
    //   <he  <hel  <hell  <hello
    //   hel  hell  hello  hello>
    //   ell  ello  ello>
    //   llo  llo>
    //   lo>
    // That is 14 n-grams; with the word's own ID prepended, 15 subwords.
    // (A previous version carried these strings in a dead `ngram_strings`
    // array padded with a filler entry; it was never read, so it is now a
    // comment instead.)
    let args = make_subword_args(3, 6, 2_000_000);
    let mut dict = Dictionary::new_with_capacity(args, 1024);
    dict.add("hello");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let wid = dict.get_id("hello").unwrap();
    let subwords = dict.get_subwords(wid);
    assert_eq!(
        subwords.len(),
        15,
        "Expected 15 subwords for 'hello' (1 word ID + 14 n-grams): got {:?}",
        subwords
    );
}
#[test]
fn test_subword_computation_utf8_aware() {
    // N-gram extraction must operate on characters, not bytes: "café"
    // contains the two-byte 'é', and byte-wise slicing would split it.
    let args = make_subword_args(2, 3, 100000);
    let dict = Dictionary::new_with_capacity(args, 1024);
    let word_with_markers = format!("{}{}{}", BOW, "café", EOW);
    let mut ngrams = Vec::new();
    dict.compute_subwords(&word_with_markers, &mut ngrams);
    // compute_subwords yields raw bucket indices (no nwords offset here).
    for &id in &ngrams {
        assert!(
            (0..100000).contains(&id),
            "N-gram ID {} out of range [0, 100000)",
            id
        );
    }
    // "<café>" has 6 characters: 5 bigrams + 4 trigrams = 9 n-grams.
    assert_eq!(
        ngrams.len(),
        9,
        "Expected 9 n-grams for '<café>' with minn=2 maxn=3"
    );
    // IDs come from the fastText hash (FNV-1a with sign-extended bytes, as
    // reproduced by fnv1a_hash at the top of this file) taken modulo bucket.
    let expected_hash_of_f_e_gt = (fnv1a_hash("fé>".as_bytes()) % 100000) as i32;
    assert!(
        ngrams.contains(&expected_hash_of_f_e_gt),
        "N-grams should contain hash('fé>') = {}",
        expected_hash_of_f_e_gt
    );
}
#[test]
fn test_subword_eos_no_subwords() {
    // EOS never receives character n-grams: its subword list is just itself.
    let mut dict = Dictionary::new_with_capacity(make_subword_args(3, 6, 2_000_000), 1024);
    dict.add(EOS);
    dict.add("hello");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let eos_id = dict.get_id(EOS).expect("EOS must be in vocab");
    let subwords = dict.get_subwords(eos_id);
    assert_eq!(
        subwords.len(),
        1,
        "EOS should have only 1 entry (its own ID)"
    );
    assert_eq!(subwords[0], eos_id, "EOS subwords[0] should be its own ID");
}
#[test]
fn test_subword_zero_bucket() {
    // bucket=0 leaves no hash space for n-grams, so every subword list
    // collapses to just the word's own ID.
    let mut dict = Dictionary::new_with_capacity(make_subword_args(3, 6, 0), 1024);
    for token in ["hello", "world"] {
        dict.add(token);
    }
    dict.threshold(1, 1);
    dict.init_ngrams();
    for wid in 0..dict.nwords() {
        let subwords = dict.get_subwords(wid);
        assert_eq!(
            subwords.len(),
            1,
            "With bucket=0, word {} should have only 1 subword",
            dict.get_word(wid)
        );
        assert_eq!(subwords[0], wid, "Subwords[0] should be the word ID");
    }
}
#[test]
fn test_subword_zero_maxn() {
    // minn=maxn=0 disables character n-grams even with a large bucket.
    let mut dict = Dictionary::new_with_capacity(make_subword_args(0, 0, 2_000_000), 1024);
    dict.add("hello");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let wid = dict.get_id("hello").unwrap();
    assert_eq!(
        dict.get_subwords(wid).len(),
        1,
        "With maxn=0, only word ID should be in subwords"
    );
}
#[test]
fn test_subword_compute_subwords_direct() {
    // With bucket=0 there is nowhere to hash n-grams into, so the direct
    // computation yields nothing.
    let dict = Dictionary::new_with_capacity(make_subword_args(3, 6, 0), 1024);
    let mut ngram_ids = Vec::new();
    dict.compute_subwords("<hello>", &mut ngram_ids);
    assert!(
        ngram_ids.is_empty(),
        "compute_subwords with bucket=0 should produce nothing"
    );
}
#[test]
fn test_subword_compute_subwords_maxn_zero() {
    // maxn=0 means no n-gram lengths are enabled at all.
    let dict = Dictionary::new_with_capacity(make_subword_args(0, 0, 100), 1024);
    let mut ngram_ids = Vec::new();
    dict.compute_subwords("<hello>", &mut ngram_ids);
    assert!(
        ngram_ids.is_empty(),
        "compute_subwords with maxn=0 should produce nothing"
    );
}
#[test]
fn test_word_ngram_hash_bigram() {
    // Word n-grams combine token hashes with the fastText recurrence
    //   h = h_prev * 116049371 + h_next   (wrapping u64, then mod bucket).
    // The dictionary holds no words, so the emitted ID has no nwords offset.
    let mut args = Args::default();
    args.bucket = 100000;
    args.word_ngrams = 2;
    let dict = Dictionary::new_with_capacity(Arc::new(args), 1024);
    let h1 = fnv1a_hash(b"hello") as i32;
    let h2 = fnv1a_hash(b"world") as i32;
    let hashes = vec![h1, h2];
    let mut line = Vec::new();
    dict.add_word_ngrams(&mut line, &hashes, 2);
    assert_eq!(line.len(), 1, "Bigram should produce exactly one hash");
    // i32 -> i64 -> u64 sign-extends, mirroring C++ uint64_t(int32_t).
    let expected_h = (h1 as i64 as u64)
        .wrapping_mul(116049371u64)
        .wrapping_add(h2 as i64 as u64);
    let expected_id = (expected_h % 100000) as i32;
    assert_eq!(
        line[0], expected_id,
        "Bigram hash mismatch: got {}, expected {}",
        line[0], expected_id
    );
}
#[test]
fn test_word_ngram_hash_trigram() {
    // With word_ngrams=3 and tokens (w1, w2, w3), the emitted n-grams are
    // (h1,h2), (h1,h2,h3) and (h2,h3), each built with the fastText
    // recurrence h = h_prev * 116049371 + h_next (wrapping u64, mod bucket).
    // The dictionary holds no words, so IDs carry no nwords offset.
    let mut args = Args::default();
    args.bucket = 1_000_000;
    args.word_ngrams = 3;
    let dict = Dictionary::new_with_capacity(Arc::new(args), 1024);
    let h1 = fnv1a_hash(b"the") as i32;
    let h2 = fnv1a_hash(b"quick") as i32;
    let h3 = fnv1a_hash(b"brown") as i32;
    let hashes = vec![h1, h2, h3];
    let mut line = Vec::new();
    dict.add_word_ngrams(&mut line, &hashes, 3);
    assert_eq!(
        line.len(),
        3,
        "3 words with wordNgrams=3 should produce 3 hashes"
    );
    // i32 -> i64 -> u64 sign-extends, mirroring C++ uint64_t(int32_t).
    let h12 = (h1 as i64 as u64)
        .wrapping_mul(116049371u64)
        .wrapping_add(h2 as i64 as u64);
    let id12 = (h12 % 1_000_000) as i32;
    assert_eq!(line[0], id12, "First hash should be bigram (h1,h2)");
    // The trigram chains from the bigram's running hash.
    let h123 = h12
        .wrapping_mul(116049371u64)
        .wrapping_add(h3 as i64 as u64);
    let id123 = (h123 % 1_000_000) as i32;
    assert_eq!(line[1], id123, "Second hash should be trigram (h1,h2,h3)");
    let h23 = (h2 as i64 as u64)
        .wrapping_mul(116049371u64)
        .wrapping_add(h3 as i64 as u64);
    let id23 = (h23 % 1_000_000) as i32;
    assert_eq!(line[2], id23, "Third hash should be bigram (h2,h3)");
}
#[test]
fn test_word_ngram_no_ngrams_for_word_ngrams_1() {
    // wordNgrams=1 means unigrams only — the words themselves — so
    // add_word_ngrams must emit nothing extra.
    let mut args = Args::default();
    args.bucket = 100000;
    args.word_ngrams = 1;
    let dict = Dictionary::new_with_capacity(Arc::new(args), 1024);
    let token_hashes = vec![fnv1a_hash(b"hello") as i32, fnv1a_hash(b"world") as i32];
    let mut line = Vec::new();
    dict.add_word_ngrams(&mut line, &token_hashes, 1);
    assert!(line.is_empty(), "wordNgrams=1 should produce no n-grams");
}
#[test]
fn test_word_ngram_zero_bucket() {
    // bucket=0 leaves no hash space, so no word n-gram IDs can be produced.
    let mut args = Args::default();
    args.bucket = 0;
    args.word_ngrams = 2;
    let dict = Dictionary::new_with_capacity(Arc::new(args), 1024);
    let token_hashes = vec![fnv1a_hash(b"hello") as i32, fnv1a_hash(b"world") as i32];
    let mut line = Vec::new();
    dict.add_word_ngrams(&mut line, &token_hashes, 2);
    assert!(
        line.is_empty(),
        "bucket=0 should disable word n-gram hashing"
    );
}
#[test]
fn test_word_ngram_ids_in_range() {
    // With real words in the dictionary, emitted n-gram IDs are offset past
    // the word IDs: all must land in [nwords, nwords + bucket).
    let mut args = Args::default();
    args.bucket = 500;
    args.word_ngrams = 3;
    let mut dict = Dictionary::new_with_capacity(Arc::new(args), 1024);
    dict.add("a");
    dict.add("b");
    dict.add("c");
    dict.threshold(1, 1);
    let hashes: Vec<i32> = ["a", "b", "c"]
        .iter()
        .map(|w| fnv1a_hash(w.as_bytes()) as i32)
        .collect();
    let mut line = Vec::new();
    dict.add_word_ngrams(&mut line, &hashes, 3);
    let nwords = dict.nwords();
    for &id in &line {
        assert!(
            id >= nwords,
            "N-gram ID {} should be >= nwords {}",
            id,
            nwords
        );
        // 500 must match args.bucket above.
        assert!(
            id < nwords + 500,
            "N-gram ID {} should be < nwords+bucket {}",
            id,
            nwords + 500
        );
    }
}
#[test]
fn test_discard_table_formula() {
    // Subsampling probability follows the formula asserted below: with
    // word frequency f = count / ntokens, pdiscard = sqrt(t/f) + t/f.
    let mut args = Args::default();
    args.t = 0.0001;
    let mut dict = Dictionary::new_with_capacity(Arc::new(args), 1024);
    // 100 occurrences of "word" out of 1000 tokens total -> f = 0.1.
    for _ in 0..100 {
        dict.add("word");
    }
    for _ in 0..900 {
        dict.add("other");
    }
    dict.threshold(1, 1);
    dict.init_discard();
    let wid = dict.get_id("word").unwrap();
    let pdiscard = dict.get_discard(wid);
    let f = 100.0f32 / 1000.0f32;
    let t = 0.0001f32;
    let expected = (t / f).sqrt() + t / f;
    assert!(
        (pdiscard - expected).abs() < 1e-5,
        "pdiscard {} should be close to expected {}",
        pdiscard,
        expected
    );
}
#[test]
fn test_discard_supervised_bypass() {
    // With supervised defaults applied, discard() must return false for any
    // random draw — even with t set to its maximum of 1.0.
    let mut args = Args::default();
    args.t = 1.0;
    args.apply_supervised_defaults();
    let mut dict = Dictionary::new_with_capacity(Arc::new(args), 1024);
    for _ in 0..3 {
        dict.add("word");
    }
    dict.threshold(1, 1);
    dict.init_discard();
    let wid = dict.get_id("word").unwrap();
    for rand in [1.0, 0.5, 0.0] {
        assert!(
            !dict.discard(wid, rand),
            "Supervised mode should never discard"
        );
    }
}
#[test]
fn test_discard_unsupervised_formula() {
    // Unsupervised (skipgram) subsampling: a very rare word under a huge t
    // gets pdiscard > 1 and is never dropped; a dominant word under a tiny t
    // gets pdiscard < 1 and is dropped when the random draw exceeds it.
    let mut args = Args::default();
    args.t = 0.9;
    args.model = ModelName::SkipGram;
    let mut dict = Dictionary::new_with_capacity(Arc::new(args), 1024);
    dict.add("rare");
    for _ in 0..1000 {
        dict.add("common");
    }
    dict.threshold(1, 1);
    dict.init_discard();
    let wid = dict.get_id("rare").unwrap();
    let pdiscard_rare = dict.get_discard(wid);
    assert!(
        pdiscard_rare > 1.0,
        "Very rare word with high t should have pdiscard > 1.0"
    );
    assert!(
        !dict.discard(wid, 0.5),
        "Should not discard rare word with rand=0.5"
    );
    assert!(
        !dict.discard(wid, 0.99),
        "Should not discard rare word with rand=0.99"
    );
    let _wid_common = dict.get_id("common");
    // Second dictionary: "frequent" is 9999 of 10000 tokens with t=0.0001.
    let mut args2 = Args::default();
    args2.t = 0.0001;
    args2.model = ModelName::SkipGram;
    let mut dict2 = Dictionary::new_with_capacity(Arc::new(args2), 1024);
    for _ in 0..9999 {
        dict2.add("frequent");
    }
    dict2.add("rare2");
    dict2.threshold(1, 1);
    dict2.init_discard();
    let wid_frequent = dict2.get_id("frequent").unwrap();
    let pdiscard_freq = dict2.get_discard(wid_frequent);
    assert!(
        pdiscard_freq < 1.0,
        "Frequent word should have pdiscard < 1.0, got {}",
        pdiscard_freq
    );
    assert!(
        dict2.discard(wid_frequent, 0.5),
        "Should discard frequent word with rand=0.5 > pdiscard"
    );
    assert!(
        !dict2.discard(wid_frequent, 0.001),
        "Should not discard frequent word with rand=0.001 < pdiscard"
    );
}
#[test]
fn test_getline_word_label_split() {
    // get_line routes label-prefixed tokens into `labels` (as IDs offset by
    // nwords) and everything else into `words`; subwords are disabled here.
    let args = make_subword_args(0, 0, 0);
    let mut dict = Dictionary::new_with_capacity(args, 1024);
    dict.add("cat");
    dict.add("sit");
    dict.add("on");
    dict.add("mat");
    dict.add("__label__good");
    dict.add("__label__bad");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let wid_cat = dict.get_id("cat").unwrap();
    let wid_sit = dict.get_id("sit").unwrap();
    let wid_on = dict.get_id("on").unwrap();
    let wid_mat = dict.get_id("mat").unwrap();
    let wid_good = dict.get_id("__label__good").unwrap();
    let wid_bad = dict.get_id("__label__bad").unwrap();
    let nwords = dict.nwords();
    let text = "__label__good cat sit on mat\n";
    let mut reader = text.as_bytes();
    let mut words = Vec::new();
    let mut labels = Vec::new();
    let mut pending = false;
    dict.get_line(&mut reader, &mut words, &mut labels, &mut pending);
    assert_eq!(labels.len(), 1, "Should have one label");
    assert_eq!(
        labels[0],
        wid_good - nwords,
        "Label should be __label__good"
    );
    assert_eq!(words.len(), 4, "Should have 4 words");
    assert!(words.contains(&wid_cat), "Should contain 'cat'");
    assert!(words.contains(&wid_sit), "Should contain 'sit'");
    assert!(words.contains(&wid_on), "Should contain 'on'");
    assert!(words.contains(&wid_mat), "Should contain 'mat'");
    // A line may carry several labels; all known ones are collected.
    let text2 = "__label__good __label__bad cat sit\n";
    let mut reader2 = text2.as_bytes();
    let mut words2 = Vec::new();
    let mut labels2 = Vec::new();
    let mut pending2 = false;
    dict.get_line(&mut reader2, &mut words2, &mut labels2, &mut pending2);
    assert_eq!(labels2.len(), 2, "Should have two labels");
    assert!(labels2.contains(&(wid_good - nwords)));
    assert!(labels2.contains(&(wid_bad - nwords)));
    assert_eq!(words2.len(), 2);
    let _ = wid_bad;
}
#[test]
fn test_getline_eos_terminates() {
    // get_line consumes only up to the first '\n': the second identical line
    // must not leak into the output, and EOS itself is appended as a word.
    let mut dict = Dictionary::new_with_capacity(make_subword_args(0, 0, 0), 1024);
    for token in [EOS, "hello", "world"] {
        dict.add(token);
    }
    dict.threshold(1, 1);
    dict.init_ngrams();
    let mut reader = "hello world\nhello world\n".as_bytes();
    let mut words = Vec::new();
    let mut labels = Vec::new();
    let mut pending = false;
    dict.get_line(&mut reader, &mut words, &mut labels, &mut pending);
    let wid_eos = dict.get_id(EOS).unwrap();
    let wid_hello = dict.get_id("hello").unwrap();
    let wid_world = dict.get_id("world").unwrap();
    assert!(words.contains(&wid_eos), "words should contain EOS word id");
    assert!(words.contains(&wid_hello));
    assert!(words.contains(&wid_world));
    assert_eq!(words.len(), 3, "Only first line + EOS should be in words");
}
#[test]
fn test_getline_oov_with_subwords() {
    // With character n-grams enabled, an OOV token still contributes its
    // subword hash IDs (all >= nwords); in-vocab "hello" contributes its
    // word ID.
    let args = make_subword_args(3, 6, 100000);
    let mut dict = Dictionary::new_with_capacity(args, 1024);
    dict.add("hello");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let text = "hello unknown\n";
    let mut reader = text.as_bytes();
    let mut words = Vec::new();
    let mut labels = Vec::new();
    let mut pending = false;
    dict.get_line(&mut reader, &mut words, &mut labels, &mut pending);
    let wid_hello = dict.get_id("hello").unwrap();
    assert!(!words.is_empty(), "words should not be empty");
    assert!(
        words.contains(&wid_hello),
        "words should contain hello's word ID"
    );
    // Every non-word entry must be an n-gram hash in the bucket range.
    let nwords = dict.nwords();
    for &id in &words {
        if id != wid_hello {
            assert!(
                id >= nwords,
                "N-gram hash ID {} should be >= nwords {}",
                id,
                nwords
            );
        }
    }
}
#[test]
fn test_getline_oov_without_subwords() {
    // With subwords disabled (maxn=0, bucket=0), an OOV token contributes
    // nothing to the line.
    let mut dict = Dictionary::new_with_capacity(make_subword_args(0, 0, 0), 1024);
    dict.add("hello");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let mut reader = "hello unknown_word\n".as_bytes();
    let mut words = Vec::new();
    let mut labels = Vec::new();
    let mut pending = false;
    dict.get_line(&mut reader, &mut words, &mut labels, &mut pending);
    let wid_hello = dict.get_id("hello").unwrap();
    assert_eq!(words, vec![wid_hello], "Only 'hello' should be in words");
}
#[test]
fn test_getline_oov_label_dropped() {
    // A label absent from the vocabulary has no ID to map to, so get_line
    // skips it and keeps only the known label.
    let mut dict = Dictionary::new_with_capacity(make_subword_args(0, 0, 0), 1024);
    dict.add("hello");
    dict.add("__label__good");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let mut reader = "__label__good __label__unknown hello\n".as_bytes();
    let mut words = Vec::new();
    let mut labels = Vec::new();
    let mut pending = false;
    dict.get_line(&mut reader, &mut words, &mut labels, &mut pending);
    assert_eq!(labels.len(), 1, "Only known label should be in labels");
    let wid_good = dict.get_id("__label__good").unwrap();
    assert_eq!(labels[0], wid_good - dict.nwords());
}
#[test]
fn test_getline_with_word_ngrams() {
    // With word_ngrams=2 and character n-grams disabled, get_line appends a
    // bigram ID alongside the two word IDs.
    let mut args = Args::default();
    args.minn = 0;
    args.maxn = 0;
    args.bucket = 100000;
    args.word_ngrams = 2;
    let mut dict = Dictionary::new_with_capacity(Arc::new(args), 1024);
    dict.add("hello");
    dict.add("world");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let text = "hello world\n";
    let mut reader = text.as_bytes();
    let mut words = Vec::new();
    let mut labels = Vec::new();
    let mut pending = false;
    dict.get_line(&mut reader, &mut words, &mut labels, &mut pending);
    let wid_hello = dict.get_id("hello").unwrap();
    let wid_world = dict.get_id("world").unwrap();
    let nwords = dict.nwords();
    assert!(words.contains(&wid_hello), "words should contain hello ID");
    assert!(words.contains(&wid_world), "words should contain world ID");
    // Reconstruct the expected bigram ID: fastText recurrence
    // h1 * 116049371 + h2 (wrapping u64), mod bucket, offset by nwords.
    let h1 = fnv1a_hash(b"hello") as i32;
    let h2 = fnv1a_hash(b"world") as i32;
    let h_bigram = (h1 as i64 as u64)
        .wrapping_mul(116049371u64)
        .wrapping_add(h2 as i64 as u64);
    let bigram_id = nwords + (h_bigram % 100000) as i32;
    assert!(
        words.contains(&bigram_id),
        "words should contain bigram hash {}, words={:?}",
        bigram_id,
        words
    );
}
#[test]
fn test_getline_returns_ntokens() {
    // get_line reports the raw token count: "a", "b" and the implicit EOS.
    let mut dict = Dictionary::new_with_capacity(make_subword_args(0, 0, 0), 1024);
    for token in [EOS, "a", "b"] {
        dict.add(token);
    }
    dict.threshold(1, 1);
    dict.init_ngrams();
    let mut reader = "a b\n".as_bytes();
    let mut words = Vec::new();
    let mut labels = Vec::new();
    let mut pending = false;
    let ntokens = dict.get_line(&mut reader, &mut words, &mut labels, &mut pending);
    assert_eq!(ntokens, 3, "Should count 3 tokens (a, b, EOS)");
}
#[test]
fn test_getline_from_str() {
    // get_line_from_str mirrors get_line for an in-memory &str input,
    // splitting labels from words.
    let mut dict = Dictionary::new_with_capacity(make_subword_args(0, 0, 0), 1024);
    for token in ["cat", "mat", "__label__good"] {
        dict.add(token);
    }
    dict.threshold(1, 1);
    dict.init_ngrams();
    let mut words = Vec::new();
    let mut labels = Vec::new();
    let ntokens = dict.get_line_from_str("__label__good cat mat", &mut words, &mut labels);
    assert_eq!(ntokens, 3);
    assert_eq!(labels.len(), 1);
    // Label IDs are offset by nwords into their own range.
    let wid_good = dict.get_id("__label__good").unwrap();
    assert_eq!(labels[0], wid_good - dict.nwords());
    assert_eq!(words.len(), 2);
}
#[test]
fn test_get_subwords_for_string_in_vocab() {
    // For an in-vocabulary word, the string-based lookup must agree with the
    // ID-based lookup.
    let mut dict = Dictionary::new_with_capacity(make_subword_args(3, 6, 100000), 1024);
    dict.add("hello");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let wid = dict.get_id("hello").unwrap();
    assert_eq!(
        dict.get_subwords(wid),
        dict.get_subwords_for_string("hello"),
        "get_subwords_for_string should return same as get_subwords for in-vocab word"
    );
}
#[test]
fn test_get_subwords_for_string_oov() {
    // An OOV word still gets n-gram subwords, and their IDs live in the
    // bucket range above nwords.
    let mut dict = Dictionary::new_with_capacity(make_subword_args(3, 6, 100000), 1024);
    dict.add("hello");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let oov_subwords = dict.get_subwords_for_string("unknown");
    assert!(
        !oov_subwords.is_empty(),
        "OOV word with maxn>0 should have subwords"
    );
    let nwords = dict.nwords();
    for &id in &oov_subwords {
        assert!(
            id >= nwords,
            "OOV subword ID {} should be >= nwords {}",
            id,
            nwords
        );
    }
}
#[test]
fn test_get_subwords_for_string_eos_oov() {
    // Even queried as an OOV string, EOS yields no n-grams.
    let dict = Dictionary::new_with_capacity(make_subword_args(3, 6, 100000), 1024);
    assert!(
        dict.get_subwords_for_string(EOS).is_empty(),
        "EOS OOV should have no subwords"
    );
}
#[test]
fn test_add_subwords_oov_no_subwords_when_maxn_zero() {
    // An OOV token (id -1) with n-grams disabled contributes nothing.
    let dict = Dictionary::new_with_capacity(make_subword_args(0, 0, 0), 1024);
    let mut line = Vec::new();
    dict.add_subwords(&mut line, "oov_word", -1);
    assert!(line.is_empty(), "OOV with maxn=0 should add nothing");
}
#[test]
fn test_add_subwords_in_vocab_with_subwords() {
    // For an in-vocab word, add_subwords appends exactly the word's full
    // subword list (word ID followed by its n-gram IDs).
    let mut dict = Dictionary::new_with_capacity(make_subword_args(3, 6, 100000), 1024);
    dict.add("hello");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let wid = dict.get_id("hello").unwrap();
    let mut line = Vec::new();
    dict.add_subwords(&mut line, "hello", wid);
    assert_eq!(
        line,
        dict.get_subwords(wid),
        "add_subwords should push all subwords for in-vocab word"
    );
}
#[test]
fn test_add_subwords_in_vocab_no_subwords() {
    // With n-grams disabled, the only subword pushed is the word ID itself.
    let mut dict = Dictionary::new_with_capacity(make_subword_args(0, 0, 0), 1024);
    dict.add("hello");
    dict.threshold(1, 1);
    dict.init_ngrams();
    let wid = dict.get_id("hello").unwrap();
    let mut line = Vec::new();
    dict.add_subwords(&mut line, "hello", wid);
    assert_eq!(
        line,
        vec![wid],
        "With maxn=0, add_subwords should only push word ID"
    );
}