use super::*;
#[test]
fn test_pseudo_random_deterministic() {
let lda1 = LatentDirichletAllocation::new(2).with_random_seed(42);
let lda2 = LatentDirichletAllocation::new(2).with_random_seed(42);
for idx in 0..50 {
assert_eq!(lda1.pseudo_random(idx), lda2.pseudo_random(idx));
}
}
#[test]
fn test_pseudo_random_different_seeds() {
let lda1 = LatentDirichletAllocation::new(2).with_random_seed(1);
let lda2 = LatentDirichletAllocation::new(2).with_random_seed(2);
let mut any_different = false;
for idx in 0..10 {
if (lda1.pseudo_random(idx) - lda2.pseudo_random(idx)).abs() > 1e-10 {
any_different = true;
break;
}
}
assert!(any_different);
}
#[test]
fn test_lda_fit_all_zero_document() {
let dtm = Matrix::from_vec(
3,
3,
vec![
1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0,
],
)
.expect("matrix");
let mut lda = LatentDirichletAllocation::new(2);
lda.fit(&dtm, 10).expect("fit should succeed");
let doc_topics = lda.document_topics().expect("topics");
assert_eq!(doc_topics.n_rows(), 3);
}
#[test]
fn test_lda_fit_uniform_distribution() {
let dtm = Matrix::from_vec(
3,
4,
vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
)
.expect("matrix");
let mut lda = LatentDirichletAllocation::new(2);
lda.fit(&dtm, 10).expect("fit should succeed");
let topic_words = lda.topic_words().expect("topic words");
for topic in 0..2 {
let mut sum = 0.0;
for word in 0..4 {
sum += topic_words.get(topic, word);
}
assert!((sum - 1.0).abs() < 0.01);
}
}
#[test]
fn test_lda_top_words_with_ties() {
let dtm = Matrix::from_vec(2, 4, vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).expect("matrix");
let mut lda = LatentDirichletAllocation::new(1);
lda.fit(&dtm, 10).expect("fit");
let vocab = vec![
"a".to_string(),
"b".to_string(),
"c".to_string(),
"d".to_string(),
];
let top_words = lda.top_words(&vocab, 2).expect("top words");
assert_eq!(top_words.len(), 1);
assert_eq!(top_words[0].len(), 2);
}
#[test]
fn test_lda_single_document() {
let dtm = Matrix::from_vec(1, 3, vec![1.0, 2.0, 3.0]).expect("matrix");
let mut lda = LatentDirichletAllocation::new(2);
lda.fit(&dtm, 5).expect("fit");
let doc_topics = lda.document_topics().expect("topics");
assert_eq!(doc_topics.n_rows(), 1);
assert_eq!(doc_topics.n_cols(), 2);
}
#[test]
fn test_lda_single_term() {
let dtm = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("matrix");
let mut lda = LatentDirichletAllocation::new(2);
lda.fit(&dtm, 5).expect("fit");
let topic_words = lda.topic_words().expect("topic words");
assert_eq!(topic_words.n_cols(), 1);
for topic in 0..2 {
assert!((topic_words.get(topic, 0) - 1.0).abs() < 0.01);
}
}
#[test]
fn test_lda_topics_equals_terms() {
let dtm =
Matrix::from_vec(3, 3, vec![2.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 2.0]).expect("matrix");
let mut lda = LatentDirichletAllocation::new(3);
lda.fit(&dtm, 20).expect("fit");
let doc_topics = lda.document_topics().expect("topics");
assert_eq!(doc_topics.n_cols(), 3);
}
#[test]
fn test_lda_topics_exceeds_terms() {
let dtm = Matrix::from_vec(2, 2, vec![1.0, 2.0, 2.0, 1.0]).expect("matrix");
let mut lda = LatentDirichletAllocation::new(5);
lda.fit(&dtm, 5).expect("fit");
let doc_topics = lda.document_topics().expect("topics");
assert_eq!(doc_topics.n_cols(), 5);
}
#[test]
fn test_lda_high_count_values() {
let dtm = Matrix::from_vec(2, 3, vec![100.0, 50.0, 10.0, 10.0, 50.0, 100.0]).expect("matrix");
let mut lda = LatentDirichletAllocation::new(2);
lda.fit(&dtm, 10).expect("fit");
let doc_topics = lda.document_topics().expect("topics");
for row in 0..2 {
let mut sum = 0.0;
for col in 0..2 {
let val = doc_topics.get(row, col);
assert!(val >= 0.0 && val <= 1.0);
sum += val;
}
assert!((sum - 1.0).abs() < 0.01);
}
}
#[test]
fn test_lda_small_count_values() {
let dtm =
Matrix::from_vec(2, 3, vec![0.001, 0.002, 0.001, 0.002, 0.001, 0.002]).expect("matrix");
let mut lda = LatentDirichletAllocation::new(2);
lda.fit(&dtm, 10).expect("fit");
let topic_words = lda.topic_words().expect("topic words");
assert_eq!(topic_words.n_rows(), 2);
}
#[test]
fn test_empty_rows_error() {
let dtm = Matrix::from_vec(0, 3, vec![]).expect("matrix");
let mut lda = LatentDirichletAllocation::new(2);
let result = lda.fit(&dtm, 5);
assert!(result.is_err());
}
#[test]
fn test_empty_cols_error() {
let dtm = Matrix::from_vec(3, 0, vec![]).expect("matrix");
let mut lda = LatentDirichletAllocation::new(2);
let result = lda.fit(&dtm, 5);
assert!(result.is_err());
}
#[test]
fn test_lda_n_topics_getter() {
let lda = LatentDirichletAllocation::new(7);
assert_eq!(lda.n_topics, 7);
}