use scirs2_core::random::prelude::*;
use scirs2_core::random::{rngs::StdRng, SeedableRng};
/// Errors returned by the topic-model fitting routines in this module.
#[derive(Debug, thiserror::Error)]
pub enum TopicError {
    /// The corpus passed to a `fit` routine contained no documents.
    #[error("empty corpus")]
    EmptyCorpus,
    /// A word id (first field) was >= the vocabulary size (second field).
    #[error("word id {0} out of vocab range {1}")]
    WordOutOfVocab(usize, usize),
}
/// Hyperparameters for the truncated HDP Gibbs sampler ([`Hdp`]).
#[derive(Debug, Clone)]
pub struct HdpConfig {
    /// Document-level Dirichlet concentration (doc-topic smoothing).
    pub alpha: f64,
    /// Corpus-level DP concentration.
    /// NOTE(review): currently unused — `hdp_gibbs_sample` never spawns new
    /// topics, so its `_gamma` parameter is ignored.
    pub gamma: f64,
    /// Topic-word Dirichlet smoothing.
    /// NOTE(review): `topic_distribution`/`perplexity` read this field, but
    /// `hdp_gibbs_sample` hard-codes 0.1 instead — confirm intent.
    pub eta: f64,
    /// Number of full Gibbs sweeps over the corpus.
    pub n_iter: usize,
    /// Truncation level: fixed upper bound on the number of topics.
    pub max_topics: usize,
    /// RNG seed for reproducible runs.
    pub seed: u64,
}
impl Default for HdpConfig {
fn default() -> Self {
HdpConfig {
alpha: 1.0,
gamma: 1.0,
eta: 0.1,
n_iter: 100,
max_topics: 20,
seed: 42,
}
}
}
/// Mutable sampler state: assignment vectors plus the count matrices the
/// collapsed Gibbs sampler maintains incrementally.
#[derive(Debug, Clone)]
pub struct HdpState {
    /// Number of topics with at least one token assigned (set after `fit`).
    pub n_topics: usize,
    /// `[max_topics][vocab_size]` — tokens of word `w` assigned to topic `k`.
    pub topic_word_counts: Vec<Vec<usize>>,
    /// `[n_docs][max_topics]` — tokens in doc `d` assigned to topic `k`.
    pub doc_topic_counts: Vec<Vec<usize>>,
    /// `[n_docs][doc_len]` — current topic id of each token.
    pub word_assignments: Vec<Vec<usize>>,
}
/// Truncated hierarchical-Dirichlet-process topic model driven by collapsed
/// Gibbs sampling. Construct with [`Hdp::new`], then call [`Hdp::fit`].
pub struct Hdp {
    // Hyperparameters; fixed after construction.
    config: HdpConfig,
    // Count matrices and per-token assignments updated during sampling.
    state: HdpState,
    // Valid word ids are 0..vocab_size.
    vocab_size: usize,
    // Updated to the actual corpus length by `fit`.
    n_docs: usize,
    // Owned copy of the training corpus, kept for `perplexity`.
    corpus: Vec<Vec<usize>>,
    // True once `fit` has completed successfully.
    fitted: bool,
}
impl Hdp {
    /// Creates an unfitted model with a fixed topic truncation.
    ///
    /// Count matrices are pre-sized to `config.max_topics` × `vocab_size`
    /// (topics) and `n_docs` × `config.max_topics` (documents); [`Hdp::fit`]
    /// re-initializes them and overwrites `n_docs` with the actual corpus
    /// length, so the `n_docs` passed here is only a pre-allocation hint.
    pub fn new(config: HdpConfig, n_docs: usize, vocab_size: usize) -> Self {
        let t = config.max_topics;
        Hdp {
            config,
            state: HdpState {
                n_topics: 0,
                topic_word_counts: vec![vec![0; vocab_size]; t],
                doc_topic_counts: vec![vec![0; t]; n_docs],
                word_assignments: Vec::new(),
            },
            vocab_size,
            n_docs,
            corpus: Vec::new(),
            fitted: false,
        }
    }

    /// Runs collapsed Gibbs sampling over `corpus`.
    ///
    /// Validates the corpus, assigns each token a uniformly random initial
    /// topic, then performs `config.n_iter` full sweeps in which every token
    /// is resampled via `hdp_gibbs_sample`. Afterwards `n_topics` is set to
    /// the number of topics that still hold at least one token.
    ///
    /// # Errors
    /// * [`TopicError::EmptyCorpus`] if `corpus` has no documents.
    /// * [`TopicError::WordOutOfVocab`] if any word id is `>= vocab_size`.
    pub fn fit(&mut self, corpus: &[Vec<usize>]) -> Result<(), TopicError> {
        if corpus.is_empty() {
            return Err(TopicError::EmptyCorpus);
        }
        for doc in corpus {
            for &w in doc {
                if w >= self.vocab_size {
                    return Err(TopicError::WordOutOfVocab(w, self.vocab_size));
                }
            }
        }
        self.corpus = corpus.to_vec();
        self.n_docs = corpus.len();
        let t = self.config.max_topics;
        let voc = self.vocab_size;
        // Reset all counts so repeated fits start from a clean slate.
        self.state.topic_word_counts = vec![vec![0usize; voc]; t];
        self.state.doc_topic_counts = vec![vec![0usize; t]; self.n_docs];
        self.state.word_assignments = corpus.iter().map(|doc| vec![0usize; doc.len()]).collect();
        let mut rng = StdRng::seed_from_u64(self.config.seed);
        // Random initialization: every token starts in a uniformly drawn topic.
        for (d, doc) in corpus.iter().enumerate() {
            for (n, &w) in doc.iter().enumerate() {
                let k = rng.random_range(0..t);
                self.state.word_assignments[d][n] = k;
                self.state.topic_word_counts[k][w] += 1;
                self.state.doc_topic_counts[d][k] += 1;
            }
        }
        let alpha = self.config.alpha;
        let gamma = self.config.gamma;
        for _iter in 0..self.config.n_iter {
            for (d, doc) in corpus.iter().enumerate() {
                for (n, &w) in doc.iter().enumerate() {
                    hdp_gibbs_sample(
                        &mut self.state,
                        d,
                        n,
                        w,
                        alpha,
                        gamma,
                        self.vocab_size,
                        &mut rng,
                    );
                }
            }
        }
        // A topic is "active" if any token is currently assigned to it.
        self.state.n_topics = self
            .state
            .topic_word_counts
            .iter()
            .filter(|row| row.iter().sum::<usize>() > 0)
            .count();
        self.fitted = true;
        Ok(())
    }

    /// Smoothed word distribution `phi_k` for `topic`: `(n_kw + eta) /
    /// (n_k + V * eta)`. Sums to 1 over the vocabulary.
    ///
    /// # Panics
    /// Panics if `topic >= config.max_topics`.
    pub fn topic_distribution(&self, topic: usize) -> Vec<f64> {
        let eta = self.config.eta;
        let eta_sum = eta * self.vocab_size as f64;
        let counts = &self.state.topic_word_counts[topic];
        let total: f64 = counts.iter().sum::<usize>() as f64 + eta_sum;
        counts.iter().map(|&c| (c as f64 + eta) / total).collect()
    }

    /// Smoothed topic distribution `theta_d` for document `doc`.
    ///
    /// Each entry is `(n_dk + alpha/T) / (n_d + alpha)`; the per-topic
    /// `alpha/T` shares sum to `alpha`, so the vector sums to 1.
    ///
    /// # Panics
    /// Panics if `doc >= n_docs`.
    pub fn document_distribution(&self, doc: usize) -> Vec<f64> {
        let alpha = self.config.alpha;
        let t = self.config.max_topics;
        let counts = &self.state.doc_topic_counts[doc];
        let total: f64 = counts.iter().sum::<usize>() as f64 + alpha;
        counts
            .iter()
            .map(|&c| (c as f64 + alpha / t as f64) / total)
            .collect()
    }

    /// Number of topics with at least one assigned token after `fit`.
    pub fn active_topics(&self) -> usize {
        self.state.n_topics
    }

    /// Per-token perplexity of the training corpus under the current point
    /// estimates: `exp(-mean_w ln p(w))` with `p(w) = sum_k theta_dk phi_kw`.
    /// Returns 1.0 for an empty corpus.
    pub fn perplexity(&self) -> f64 {
        let t = self.config.max_topics;
        let eta = self.config.eta;
        let eta_sum = eta * self.vocab_size as f64;
        let alpha = self.config.alpha;
        // Hoisted: per-topic token totals do not depend on the document or
        // word, so compute them once instead of summing a vocab-length row
        // for every (token, topic) pair.
        let topic_totals: Vec<f64> = self
            .state
            .topic_word_counts
            .iter()
            .map(|row| row.iter().sum::<usize>() as f64 + eta_sum)
            .collect();
        let mut total_ll = 0.0f64;
        let mut total_tokens = 0usize;
        for (d, doc) in self.corpus.iter().enumerate() {
            let doc_total: f64 =
                self.state.doc_topic_counts[d].iter().sum::<usize>() as f64 + alpha;
            for &w in doc {
                // Defensive: fit already validated the corpus, but skip any
                // out-of-vocab id rather than panic on the index below.
                if w >= self.vocab_size {
                    continue;
                }
                let p_w: f64 = (0..t)
                    .map(|k| {
                        let theta_dk = (self.state.doc_topic_counts[d][k] as f64
                            + alpha / t as f64)
                            / doc_total;
                        let phi_kw =
                            (self.state.topic_word_counts[k][w] as f64 + eta) / topic_totals[k];
                        theta_dk * phi_kw
                    })
                    .sum();
                if p_w > 0.0 {
                    total_ll += p_w.ln();
                }
                total_tokens += 1;
            }
        }
        if total_tokens == 0 {
            return 1.0;
        }
        let avg_ll = total_ll / total_tokens as f64;
        (-avg_ll).exp()
    }

    /// Indices of the `k` highest-probability words of `topic`, most
    /// probable first. Returns fewer than `k` only if `k > vocab_size`.
    ///
    /// # Panics
    /// Panics if `topic >= config.max_topics`.
    pub fn top_words(&self, topic: usize, k: usize) -> Vec<usize> {
        let phi = self.topic_distribution(topic);
        let mut indices: Vec<usize> = (0..phi.len()).collect();
        // Stable sort keeps tied words in ascending-index order.
        indices.sort_by(|&a, &b| {
            phi[b]
                .partial_cmp(&phi[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        indices.truncate(k);
        indices
    }

    /// Read-only view of the sampler state.
    pub fn state(&self) -> &HdpState {
        &self.state
    }

    /// Whether `fit` has completed successfully.
    pub fn is_fitted(&self) -> bool {
        self.fitted
    }
}
impl std::fmt::Debug for Hdp {
    /// Compact debug view: sizes and fit status only, never the full
    /// count matrices or the stored corpus.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Hdp");
        builder.field("max_topics", &self.config.max_topics);
        builder.field("active_topics", &self.state.n_topics);
        builder.field("vocab_size", &self.vocab_size);
        builder.field("fitted", &self.fitted);
        builder.finish()
    }
}
/// Resamples the topic of one token (document `doc`, position `pos`, word id
/// `word`) conditioned on every other assignment (collapsed Gibbs step).
///
/// NOTE(review): this sampler never spawns new topics — it is a fixed
/// truncation over `state.topic_word_counts.len()` topics — so `_gamma`
/// (the DP concentration that would govern new-topic mass) is unused.
///
/// NOTE(review): `eta` is hard-coded to 0.1 below, which silently disagrees
/// with `HdpConfig::eta` whenever a caller configures a different value
/// (only matching the default by coincidence). Confirm and thread the
/// configured eta through; that requires changing this signature and its
/// call site in `Hdp::fit` together.
fn hdp_gibbs_sample(
    state: &mut HdpState,
    doc: usize,
    pos: usize,
    word: usize,
    alpha: f64,
    _gamma: f64,
    vocab_size: usize,
    rng: &mut StdRng,
) {
    let t = state.topic_word_counts.len();
    // Hard-coded topic-word smoothing — see NOTE(review) above.
    let eta = 0.1_f64;
    let eta_sum = eta * vocab_size as f64;
    // Remove this token's current assignment from the counts ("collapse out"
    // the token). saturating_sub guards against underflow should the counts
    // ever become inconsistent.
    let k_old = state.word_assignments[doc][pos];
    state.topic_word_counts[k_old][word] = state.topic_word_counts[k_old][word].saturating_sub(1);
    state.doc_topic_counts[doc][k_old] = state.doc_topic_counts[doc][k_old].saturating_sub(1);
    // Unnormalized conditional: p(z = k | rest) ∝ (n_dk + alpha/T) * (n_kw + eta) / (n_k + V*eta).
    let mut probs = vec![0.0f64; t];
    for k in 0..t {
        let doc_factor = state.doc_topic_counts[doc][k] as f64 + alpha / t as f64;
        let kw = state.topic_word_counts[k][word] as f64 + eta;
        // O(vocab) row sum per topic per token; a cached per-topic total
        // would be faster but must then be kept in sync with every update.
        let k_total: f64 = state.topic_word_counts[k].iter().sum::<usize>() as f64 + eta_sum;
        probs[k] = doc_factor * (kw / k_total);
    }
    // Draw the new topic and add the token back into the counts.
    let k_new = sample_categorical(&probs, rng);
    state.word_assignments[doc][pos] = k_new;
    state.topic_word_counts[k_new][word] += 1;
    state.doc_topic_counts[doc][k_new] += 1;
}
/// Draws an index from the categorical distribution proportional to `probs`
/// (weights need not be normalized). Falls back to a uniform draw when the
/// weights sum to zero or less, and to the last index if floating-point
/// round-off lets the cursor run past the final cumulative weight.
fn sample_categorical(probs: &[f64], rng: &mut StdRng) -> usize {
    let total: f64 = probs.iter().sum();
    if total <= 0.0 {
        // Degenerate weight vector: every outcome equally likely.
        return rng.random_range(0..probs.len());
    }
    let threshold: f64 = rng.random_range(0.0..total);
    // Inverse-CDF scan: first index whose running sum exceeds the draw.
    let mut running = 0.0f64;
    probs
        .iter()
        .position(|&weight| {
            running += weight;
            threshold < running
        })
        .unwrap_or(probs.len() - 1)
}
/// Hyperparameters for [`HdpTopicModel::fit`]; forwarded to [`HdpConfig`].
#[derive(Debug, Clone)]
pub struct HdpTopicConfig {
    /// Document-level Dirichlet concentration.
    pub alpha: f64,
    /// Corpus-level DP concentration (unused by the underlying sampler —
    /// see `hdp_gibbs_sample`).
    pub gamma: f64,
    /// Topic-word Dirichlet smoothing used when building `phi`.
    pub eta: f64,
    /// Truncation level: maximum number of topics.
    pub t_max: usize,
    /// Number of Gibbs sweeps.
    pub n_iter: usize,
    /// NOTE(review): declared but never read by `HdpTopicModel::fit` —
    /// point estimates come from the final sampler state only.
    pub burn_in: usize,
    /// RNG seed for reproducible runs.
    pub seed: u64,
}
impl Default for HdpTopicConfig {
fn default() -> Self {
HdpTopicConfig {
alpha: 1.0,
gamma: 1.0,
eta: 0.1,
t_max: 50,
n_iter: 150,
burn_in: 50,
seed: 42,
}
}
}
/// Point-estimate topic model built from a fitted [`Hdp`] sampler:
/// dense probability matrices plus the raw counts needed by `transform`.
pub struct HdpTopicModel {
    /// `[t_max][vocab_size]` topic-word probabilities; each row sums to 1.
    pub phi: Vec<Vec<f64>>,
    /// `[n_docs][t_max]` document-topic probabilities; each row sums to 1.
    pub theta: Vec<Vec<f64>>,
    // Topics with nonzero token counts after fitting (at least 1).
    k_inferred: usize,
    // Valid word ids are 0..vocab_size.
    vocab_size: usize,
    // Truncation level used during fitting.
    t_max: usize,
    // Smoothing hyperparameters retained for `transform`.
    eta: f64,
    alpha: f64,
    // Final sampler counts, kept so `transform` can score unseen documents.
    topic_word_counts: Vec<Vec<usize>>,
    // Per-topic totals of `topic_word_counts` rows.
    topic_counts: Vec<usize>,
}
impl HdpTopicModel {
    /// Fits a truncated HDP sampler on `corpus` and converts its final count
    /// state into dense `phi` (topic-word) and `theta` (doc-topic)
    /// probability matrices.
    ///
    /// NOTE(review): `config.burn_in` is never read — the point estimates
    /// come from the final sampler state only; confirm whether averaging
    /// post-burn-in samples was intended.
    ///
    /// # Errors
    /// * [`TopicError::EmptyCorpus`] if `corpus` has no documents.
    /// * [`TopicError::WordOutOfVocab`] if any word id is `>= vocab_size`.
    pub fn fit(
        corpus: &[Vec<usize>],
        vocab_size: usize,
        config: HdpTopicConfig,
    ) -> Result<Self, TopicError> {
        if corpus.is_empty() {
            return Err(TopicError::EmptyCorpus);
        }
        for doc in corpus {
            for &w in doc {
                if w >= vocab_size {
                    return Err(TopicError::WordOutOfVocab(w, vocab_size));
                }
            }
        }
        let t = config.t_max;
        let n_docs = corpus.len();
        let hdp_cfg = HdpConfig {
            alpha: config.alpha,
            gamma: config.gamma,
            eta: config.eta,
            n_iter: config.n_iter,
            max_topics: t,
            seed: config.seed,
        };
        let mut hdp = Hdp::new(hdp_cfg, n_docs, vocab_size);
        hdp.fit(corpus)?;
        let state = hdp.state();
        let topic_word_counts: Vec<Vec<usize>> = state.topic_word_counts.clone();
        // Per-topic token totals; reused for phi and kept for transform.
        let topic_counts: Vec<usize> = topic_word_counts
            .iter()
            .map(|row| row.iter().sum())
            .collect();
        // At least 1 so downstream consumers never see a zero-topic model.
        let k_inferred = topic_counts.iter().filter(|&&c| c > 0).count().max(1);
        let eta = config.eta;
        let eta_sum = eta * vocab_size as f64;
        let alpha = config.alpha;
        // phi_kw = (n_kw + eta) / (n_k + V*eta); each row sums to 1.
        let phi: Vec<Vec<f64>> = (0..t)
            .map(|k| {
                let total = topic_counts[k] as f64 + eta_sum;
                (0..vocab_size)
                    .map(|w| (topic_word_counts[k][w] as f64 + eta) / total)
                    .collect()
            })
            .collect();
        // theta_dk = (n_dk + alpha/T) / (n_d + alpha); each row sums to 1.
        let doc_topic_counts = &state.doc_topic_counts;
        let theta: Vec<Vec<f64>> = (0..n_docs)
            .map(|d| {
                let doc_total: f64 = doc_topic_counts[d].iter().sum::<usize>() as f64 + alpha;
                (0..t)
                    .map(|k| (doc_topic_counts[d][k] as f64 + alpha / t as f64) / doc_total)
                    .collect()
            })
            .collect();
        Ok(HdpTopicModel {
            phi,
            theta,
            k_inferred,
            vocab_size,
            t_max: t,
            eta,
            alpha,
            topic_word_counts,
            topic_counts,
        })
    }

    /// Infers a topic distribution for an unseen document by a single
    /// folding-in pass: starting from the uniform prior `alpha/T`, each
    /// token's normalized topic responsibilities are accumulated into the
    /// document weights, which are normalized at the end. Out-of-vocabulary
    /// word ids are skipped. The result sums to 1 (for `alpha > 0`).
    pub fn transform(&self, doc: &[usize]) -> Vec<f64> {
        let t = self.t_max;
        let eta = self.eta;
        let eta_sum = eta * self.vocab_size as f64;
        let mut theta_doc = vec![self.alpha / t as f64; t];
        // Scratch buffer reused across tokens: avoids one Vec allocation
        // per word, which the previous version paid.
        let mut word_probs = vec![0.0f64; t];
        for &w in doc {
            if w >= self.vocab_size {
                continue;
            }
            for (k, slot) in word_probs.iter_mut().enumerate() {
                *slot = theta_doc[k] * (self.topic_word_counts[k][w] as f64 + eta)
                    / (self.topic_counts[k] as f64 + eta_sum);
            }
            let sum: f64 = word_probs.iter().sum();
            if sum > 0.0 {
                // Add the normalized responsibilities (equivalent to
                // normalizing word_probs in place and then accumulating).
                for k in 0..t {
                    theta_doc[k] += word_probs[k] / sum;
                }
            }
        }
        let total: f64 = theta_doc.iter().sum();
        if total > 0.0 {
            theta_doc.iter_mut().for_each(|p| *p /= total);
        }
        theta_doc
    }

    /// Topic-word probability matrix (`phi`), one row per truncation slot.
    pub fn topics(&self) -> &[Vec<f64>] {
        &self.phi
    }

    /// Number of topics with at least one assigned token (minimum 1).
    pub fn num_topics_inferred(&self) -> usize {
        self.k_inferred
    }
}
impl std::fmt::Debug for HdpTopicModel {
    /// Compact debug view: dimensions only, never the dense matrices.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("HdpTopicModel");
        builder.field("t_max", &self.t_max);
        builder.field("k_inferred", &self.k_inferred);
        builder.field("vocab_size", &self.vocab_size);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a synthetic corpus of three document groups over a 15-word
    /// vocabulary; each group draws its 20 tokens uniformly from a disjoint
    /// 5-word band, giving three well-separated latent topics.
    fn make_corpus(n_per_topic: usize, seed: u64) -> Vec<Vec<usize>> {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut corpus = Vec::new();
        for band in [0..5usize, 5..10, 10..15] {
            for _ in 0..n_per_topic {
                let doc: Vec<usize> = (0..20).map(|_| rng.random_range(band.clone())).collect();
                corpus.push(doc);
            }
        }
        corpus
    }

    #[test]
    fn active_topics_in_valid_range() {
        let docs = make_corpus(10, 1);
        let cfg = HdpConfig {
            n_iter: 20,
            max_topics: 15,
            seed: 42,
            ..Default::default()
        };
        let mut hdp = Hdp::new(cfg, docs.len(), 15);
        hdp.fit(&docs).expect("fit must succeed");
        let active = hdp.active_topics();
        assert!(active >= 1, "active topics must be >= 1, got {active}");
        assert!(
            active <= 15,
            "active topics ({active}) must be <= max_topics (15)"
        );
    }

    #[test]
    fn topic_distribution_sums_to_one() {
        let docs = make_corpus(8, 2);
        let cfg = HdpConfig {
            n_iter: 10,
            seed: 7,
            ..Default::default()
        };
        let mut hdp = Hdp::new(cfg, docs.len(), 15);
        hdp.fit(&docs).expect("fit must succeed");
        let sum: f64 = hdp.topic_distribution(0).iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "topic_distribution must sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn document_distribution_sums_to_one() {
        let docs = make_corpus(8, 3);
        let cfg = HdpConfig {
            n_iter: 10,
            seed: 11,
            ..Default::default()
        };
        let mut hdp = Hdp::new(cfg, docs.len(), 15);
        hdp.fit(&docs).expect("fit must succeed");
        let sum: f64 = hdp.document_distribution(0).iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "document_distribution must sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn perplexity_is_finite_positive() {
        let docs = make_corpus(8, 4);
        let cfg = HdpConfig {
            n_iter: 15,
            seed: 99,
            ..Default::default()
        };
        let mut hdp = Hdp::new(cfg, docs.len(), 15);
        hdp.fit(&docs).expect("fit must succeed");
        let pp = hdp.perplexity();
        assert!(pp.is_finite(), "perplexity must be finite, got {pp}");
        assert!(pp > 0.0, "perplexity must be positive, got {pp}");
    }

    #[test]
    fn top_words_returns_k_distinct_indices() {
        let docs = make_corpus(10, 5);
        let cfg = HdpConfig {
            n_iter: 15,
            seed: 55,
            ..Default::default()
        };
        let mut hdp = Hdp::new(cfg, docs.len(), 15);
        hdp.fit(&docs).expect("fit must succeed");
        let top5 = hdp.top_words(0, 5);
        // De-duplicating must not shrink the list if all indices are distinct.
        let mut deduped = top5.clone();
        deduped.sort_unstable();
        deduped.dedup();
        assert_eq!(
            deduped.len(),
            top5.len(),
            "top_words must contain distinct indices"
        );
        for &w in &top5 {
            assert!(w < 15, "word index {w} must be < vocab_size 15");
        }
    }

    #[test]
    fn fit_empty_corpus_returns_error() {
        let mut hdp = Hdp::new(HdpConfig::default(), 0, 10);
        let outcome = hdp.fit(&[]);
        assert!(
            outcome.is_err(),
            "fit on empty corpus must return TopicError"
        );
    }

    #[test]
    fn fit_out_of_vocab_returns_error() {
        // Word id 99 lies outside the declared vocabulary of 5.
        let docs = vec![vec![0usize, 1, 99]];
        let mut hdp = Hdp::new(HdpConfig::default(), 1, 5);
        let outcome = hdp.fit(&docs);
        assert!(
            outcome.is_err(),
            "fit with OOV word must return TopicError::WordOutOfVocab"
        );
    }

    #[test]
    fn top_words_all_nontrivial() {
        let docs = make_corpus(6, 6);
        let cfg = HdpConfig {
            n_iter: 10,
            seed: 77,
            max_topics: 10,
            ..Default::default()
        };
        let mut hdp = Hdp::new(cfg, docs.len(), 15);
        hdp.fit(&docs).expect("fit must succeed");
        for topic in 0..10 {
            for &w in &hdp.top_words(topic, 3) {
                assert!(w < 15, "top word index {w} must be in vocab");
            }
        }
    }
}