use ferrolearn_preprocess::count_vectorizer::CountVectorizer;
/// Fits `cv` on `docs` and returns the learned vocabulary, sorted lexicographically.
///
/// Sorting makes the result independent of the column order the fitted vectorizer
/// assigns, so callers can compare it against a fixed expected list.
///
/// # Panics
///
/// Panics if `CountVectorizer::fit` returns an error. Tests that expect `fit` to
/// fail should call `fit` directly instead of using this helper.
fn fit_vocab(docs: &[&str], cv: CountVectorizer) -> Vec<String> {
    let owned: Vec<String> = docs.iter().map(ToString::to_string).collect();
    let fitted = cv
        .fit(&owned)
        .expect("fit_vocab: CountVectorizer::fit returned Err for a corpus expected to fit");
    let mut vocab = fitted.vocabulary().to_vec();
    // Identical strings are indistinguishable, so an unstable sort gives the same result.
    vocab.sort_unstable();
    vocab
}
/// Divergence check: a corpus made only of one-character tokens should be rejected.
///
/// Every token in the corpus is a single character. The default tokenizer drops
/// such tokens (see `guard_tokenizer_drops_length_one_tokens`), so nothing is
/// left to build a vocabulary from.
#[test]
fn divergence_empty_vocab_all_single_char() {
    let corpus = vec![String::from("a b c d")];
    let vectorizer = CountVectorizer::new();
    let outcome = vectorizer.fit(&corpus);
    assert!(
        outcome.is_err(),
        "sklearn raises ValueError(empty vocabulary) at text.py:1278 for docs \
         with only single-char tokens; ferrolearn returned Ok with vocab={:?}",
        outcome.map(|fitted| fitted.vocabulary().to_vec())
    );
}
/// Divergence check: `fit` should fail when `min_df` removes every term.
///
/// Each term occurs in exactly one of the two documents. A minimum document
/// count of 5 therefore prunes the whole vocabulary.
#[test]
fn divergence_empty_vocab_min_df_prunes_all() {
    let corpus: Vec<String> = ["aa bb", "cc dd"].iter().map(|d| d.to_string()).collect();
    let vectorizer = CountVectorizer::new().min_df(5);
    let outcome = vectorizer.fit(&corpus);
    assert!(
        outcome.is_err(),
        "sklearn raises ValueError at text.py:1382 when min_df prunes every \
         term; ferrolearn returned Ok with vocab={:?}",
        outcome.map(|fitted| fitted.vocabulary().to_vec())
    );
}
/// The default tokenizer must discard one-character tokens such as `"a"`.
#[test]
fn guard_tokenizer_drops_length_one_tokens() {
    let vectorizer = CountVectorizer::new();
    let learned = fit_vocab(&["foo a bar"], vectorizer);
    assert_eq!(learned, ["bar", "foo"]);
}
/// `_` is a word character, so `"a_b"` stays one token instead of splitting into `a` and `b`.
#[test]
fn guard_tokenizer_underscore_is_word_char() {
    let vectorizer = CountVectorizer::new();
    let learned = fit_vocab(&["a_b cd"], vectorizer);
    assert_eq!(learned, ["a_b", "cd"]);
}
/// A fractional `max_df` must be applied without rounding the cutoff up.
///
/// With 3 documents the bound is 0.5 * 3 = 1.5. `cat` appears in 2 documents and
/// must be pruned; rounding the bound up to 2 would wrongly keep it. Every other
/// term appears in a single document and survives.
#[test]
fn guard_max_df_float_threshold_no_ceil() {
    let corpus = ["cat dog", "cat bird", "xx yy"];
    let vectorizer = CountVectorizer::new().max_df(0.5);
    let learned = fit_vocab(&corpus, vectorizer);
    assert_eq!(learned, ["bird", "dog", "xx", "yy"]);
}
/// With default settings, checks both the learned vocabulary and the per-document counts.
#[test]
fn guard_default_value_match() {
    let corpus: Vec<String> = ["the cat sat", "the cat sat on the mat"]
        .iter()
        .map(|d| d.to_string())
        .collect();
    let fitted = CountVectorizer::new().fit(&corpus).unwrap();

    let mut learned = fitted.vocabulary().to_vec();
    learned.sort();
    assert_eq!(learned, ["cat", "mat", "on", "sat", "the"]);

    let counts = fitted.transform(&corpus).unwrap();
    let column_of = fitted.vocabulary_map();

    // Each row is (term, expected count in doc 0, expected count in doc 1).
    let table = [
        ("cat", 1.0, 1.0),
        ("mat", 0.0, 1.0),
        ("on", 0.0, 1.0),
        ("sat", 1.0, 1.0),
        ("the", 1.0, 2.0),
    ];
    // Check all of doc 0 before any of doc 1.
    for &(term, first, _) in table.iter() {
        assert_eq!(counts[[0, column_of[term]]], first, "doc0 term {term}");
    }
    for &(term, _, second) in table.iter() {
        assert_eq!(counts[[1, column_of[term]]], second, "doc1 term {term}");
    }
}
/// In binary mode every present term is expected to count as 1.0.
///
/// `"the"` occurs three times in the document, but the expected value is still 1.0.
#[test]
fn guard_binary_value_match() {
    let corpus = vec![String::from("the the the cat")];
    let vectorizer = CountVectorizer::new().binary(true);
    let fitted = vectorizer.fit(&corpus).unwrap();
    let counts = fitted.transform(&corpus).unwrap();
    let column_of = fitted.vocabulary_map();
    for term in ["the", "cat"] {
        assert_eq!(counts[[0, column_of[term]]], 1.0);
    }
}
/// With default lowercasing, all three spellings of "hello" collapse into one term.
#[test]
fn guard_lowercase_value_match() {
    let vectorizer = CountVectorizer::new();
    let learned = fit_vocab(&["Hello HELLO hello world"], vectorizer);
    assert_eq!(learned, ["hello", "world"]);
}
/// With `lowercase(false)`, differently cased spellings remain distinct terms.
///
/// The expected order follows byte-wise string sorting, which places uppercase
/// `'H'` before lowercase `'h'`.
#[test]
fn guard_no_lowercase_value_match() {
    let vectorizer = CountVectorizer::new().lowercase(false);
    let learned = fit_vocab(&["Hello hello world"], vectorizer);
    assert_eq!(learned, ["Hello", "hello", "world"]);
}
/// An integer `min_df` acts as an absolute document count.
///
/// Only `cat` appears in all three documents, so it is the only term to survive
/// `min_df(3)`.
#[test]
fn guard_min_df_absolute_count() {
    let corpus = ["cat dog", "cat bird", "cat fish"];
    let vectorizer = CountVectorizer::new().min_df(3);
    let learned = fit_vocab(&corpus, vectorizer);
    assert_eq!(learned, ["cat"]);
}
/// `max_features(3)` must keep only the three most frequent terms.
///
/// Corpus-wide counts are cat 4, dog 3, bird 2 and ant 1, so `ant` is expected to
/// be dropped. The test also checks the retained counts for the first document.
///
/// NOTE(review): the name mentions an alphabetical tie-break, but no tie straddles
/// the cutoff in this corpus, so that path is not exercised here.
#[test]
fn guard_max_features_topn_alpha_tiebreak() {
    let corpus = vec![
        String::from("cat cat cat dog dog bird ant"),
        String::from("cat dog bird"),
    ];
    let vectorizer = CountVectorizer::new().max_features(3);
    let fitted = vectorizer.fit(&corpus).unwrap();

    let mut kept = fitted.vocabulary().to_vec();
    kept.sort();
    assert_eq!(kept, ["bird", "cat", "dog"]);

    let counts = fitted.transform(&corpus).unwrap();
    let column_of = fitted.vocabulary_map();
    for (term, expected) in [("cat", 3.0), ("dog", 2.0), ("bird", 1.0)] {
        assert_eq!(counts[[0, column_of[term]]], expected);
    }
}