use crate::error::StatsResult;
use crate::nonparametric_bayes::hdp::{hdp_fit, hdp_perplexity, HdpConfig, HdpResult};
/// A fitted Hierarchical Dirichlet Process (HDP) topic model, pairing the
/// raw fit output with an optional vocabulary for mapping word indices
/// back to human-readable strings.
pub struct HdpTopicModel {
// Raw output of `hdp_fit` (topic-word matrix, doc-topic matrix, topic count).
pub result: HdpResult,
// Word strings indexed by word id; empty when the model was fitted via
// `fit` (without a vocabulary), in which case lookups fall back to the
// numeric index rendered as a string.
pub vocabulary: Vec<String>,
}
impl HdpTopicModel {
    /// Fits an HDP topic model to `documents` without attaching a vocabulary.
    ///
    /// Each document is a sequence of word indices in `0..vocab_size`.
    ///
    /// # Errors
    /// Propagates any error returned by [`hdp_fit`].
    pub fn fit(
        documents: &[Vec<usize>],
        vocab_size: usize,
        config: HdpConfig,
    ) -> StatsResult<Self> {
        let result = hdp_fit(documents, vocab_size, &config)?;
        Ok(Self {
            result,
            vocabulary: Vec::new(),
        })
    }

    /// Fits an HDP topic model and stores `vocabulary` for word-string lookup.
    ///
    /// `vocabulary.len()` is used as the vocabulary size, so every word index
    /// in `documents` must be less than `vocabulary.len()`.
    ///
    /// # Errors
    /// Propagates any error returned by [`hdp_fit`].
    pub fn fit_with_vocab(
        documents: &[Vec<usize>],
        vocabulary: Vec<String>,
        config: HdpConfig,
    ) -> StatsResult<Self> {
        let vocab_size = vocabulary.len();
        let result = hdp_fit(documents, vocab_size, &config)?;
        Ok(Self { result, vocabulary })
    }

    /// Returns up to `n` `(word_index, probability)` pairs for `topic_id`,
    /// sorted by descending probability (ties broken by ascending index).
    ///
    /// Returns an empty vector when `topic_id` is out of range; fewer than
    /// `n` pairs are returned when the vocabulary is smaller than `n`.
    pub fn top_words(&self, topic_id: usize, n: usize) -> Vec<(usize, f64)> {
        if topic_id >= self.result.topic_word_matrix.nrows() {
            return Vec::new();
        }
        let row = self.result.topic_word_matrix.row(topic_id);
        let mut pairs: Vec<(usize, f64)> = row.iter().copied().enumerate().collect();
        // `total_cmp` is a genuine total order over f64. The previous
        // `partial_cmp(..).unwrap_or(Equal)` comparator violates the sort
        // contract when NaNs are present (and can panic on recent Rust);
        // the explicit index tie-break reproduces the old stable-sort order
        // for equal probabilities while allowing the faster unstable sort.
        pairs.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
        pairs.truncate(n);
        pairs
    }

    /// Like [`Self::top_words`], but maps each word index to its vocabulary
    /// string. Indices without a vocabulary entry (e.g. the model was fitted
    /// via [`Self::fit`]) fall back to the index rendered as a string.
    pub fn top_word_strings(&self, topic_id: usize, n: usize) -> Vec<(String, f64)> {
        self.top_words(topic_id, n)
            .into_iter()
            .map(|(idx, prob)| {
                let word = self
                    .vocabulary
                    .get(idx)
                    .cloned()
                    .unwrap_or_else(|| idx.to_string());
                (word, prob)
            })
            .collect()
    }

    /// Computes held-out perplexity of `documents` under the fitted model.
    pub fn perplexity(&self, documents: &[Vec<usize>], config: &HdpConfig) -> f64 {
        hdp_perplexity(documents, &self.result, config)
    }

    /// Returns the topic distribution for document `doc_id`, or an empty
    /// slice when `doc_id` is out of range.
    ///
    /// NOTE(review): `to_slice()` returns `None` for a non-contiguous row
    /// view, which would also yield an empty slice here; this assumes the
    /// doc-topic matrix is standard (row-major) layout — TODO confirm.
    pub fn doc_topic_distribution(&self, doc_id: usize) -> &[f64] {
        if doc_id >= self.result.doc_topic_matrix.nrows() {
            return &[];
        }
        self.result
            .doc_topic_matrix
            .row(doc_id)
            .to_slice()
            .unwrap_or(&[])
    }

    /// Total number of topic slots in the truncated model (matrix rows).
    pub fn n_topics(&self) -> usize {
        self.result.topic_word_matrix.nrows()
    }

    /// Number of topics actually assigned mass during fitting.
    pub fn n_topics_used(&self) -> usize {
        self.result.n_topics_used
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Three tiny documents over word ids 0..6; docs 0 and 2 share words.
    fn small_corpus() -> Vec<Vec<usize>> {
        Vec::from([
            vec![0, 1, 2, 0, 1],
            vec![3, 4, 5, 3, 4],
            vec![0, 2, 1, 0],
        ])
    }

    #[test]
    fn test_fit_returns_model() {
        let config = HdpConfig { n_topics: 4, n_iter: 10, ..Default::default() };
        let fitted = HdpTopicModel::fit(&small_corpus(), 6, config);
        assert!(fitted.is_ok());
    }

    #[test]
    fn test_top_words_length() {
        let config = HdpConfig { n_topics: 4, n_iter: 10, ..Default::default() };
        let model = HdpTopicModel::fit(&small_corpus(), 6, config).expect("fit");
        assert_eq!(model.top_words(0, 3).len(), 3);
    }

    #[test]
    fn test_top_words_sorted_descending() {
        let config = HdpConfig { n_topics: 4, n_iter: 20, ..Default::default() };
        let model = HdpTopicModel::fit(&small_corpus(), 6, config).expect("fit");
        let top = model.top_words(0, 6);
        // Compare each adjacent pair; index in the message matches the
        // position of the second element of the pair.
        for (offset, (prev, cur)) in top.iter().zip(top.iter().skip(1)).enumerate() {
            let i = offset + 1;
            assert!(prev.1 >= cur.1, "top_words not sorted at index {i}");
        }
    }

    #[test]
    fn test_top_words_out_of_range_topic() {
        let config = HdpConfig { n_topics: 4, n_iter: 5, ..Default::default() };
        let model = HdpTopicModel::fit(&small_corpus(), 6, config).expect("fit");
        assert!(model.top_words(999, 5).is_empty());
    }

    #[test]
    fn test_doc_topic_distribution_correct_length() {
        let config = HdpConfig { n_topics: 4, n_iter: 10, ..Default::default() };
        let model = HdpTopicModel::fit(&small_corpus(), 6, config).expect("fit");
        assert_eq!(model.doc_topic_distribution(0).len(), 4);
    }

    #[test]
    fn test_doc_topic_distribution_sums_to_one() {
        let config = HdpConfig { n_topics: 4, n_iter: 10, ..Default::default() };
        let model = HdpTopicModel::fit(&small_corpus(), 6, config).expect("fit");
        let n_docs = small_corpus().len();
        for d in 0..n_docs {
            let s: f64 = model.doc_topic_distribution(d).iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "doc {d} sum = {s}");
        }
    }

    #[test]
    fn test_doc_topic_out_of_range_returns_empty() {
        let config = HdpConfig { n_topics: 4, n_iter: 5, ..Default::default() };
        let model = HdpTopicModel::fit(&small_corpus(), 6, config).expect("fit");
        assert!(model.doc_topic_distribution(999).is_empty());
    }

    #[test]
    fn test_perplexity_is_finite() {
        let docs = small_corpus();
        let config = HdpConfig { n_topics: 4, n_iter: 20, ..Default::default() };
        let model = HdpTopicModel::fit(&docs, 6, config.clone()).expect("fit");
        let ppl = model.perplexity(&docs, &config);
        assert!(ppl.is_finite() && ppl > 0.0, "perplexity = {ppl}");
    }

    #[test]
    fn test_fit_with_vocab_attaches_strings() {
        let words: Vec<String> = (0..6).map(|i| format!("word{i}")).collect();
        let config = HdpConfig { n_topics: 4, n_iter: 5, ..Default::default() };
        let model =
            HdpTopicModel::fit_with_vocab(&small_corpus(), words, config).expect("fit");
        let top = model.top_word_strings(0, 2);
        assert_eq!(top.len(), 2);
        assert!(top[0].0.starts_with("word"), "expected word string, got {}", top[0].0);
    }

    #[test]
    fn test_n_topics_used() {
        let config = HdpConfig { n_topics: 10, n_iter: 30, ..Default::default() };
        let model = HdpTopicModel::fit(&small_corpus(), 6, config).expect("fit");
        assert!(model.n_topics_used() <= model.n_topics());
    }
}