scirs2_text/topic/hdp.rs
//! Hierarchical Dirichlet Process (HDP) topic model — automatic topic
//! number selection via the Chinese Restaurant Franchise (CRF) analogy.
//!
//! Unlike LDA, which requires the number of topics K to be specified a priori,
//! HDP places a Dirichlet Process prior on topic proportions so the number
//! of active topics can grow with the data (up to a truncation `max_topics`).
//!
//! ## Algorithm
//!
//! We use the **truncated stick-breaking** approximation of Teh et al. (2006)
//! combined with **collapsed Gibbs sampling** over topic assignments. The
//! implementation closely follows the description in:
//!
//! > Teh, Y. W., Jordan, M. I., Beal, M. J., & Blei, D. M. (2006).
//! > "Hierarchical Dirichlet Processes." *JASA*, 101(476), 1566–1581.
//! > <https://doi.org/10.1198/016214506000000302>
//!
//! ## Error types
//!
//! This module defines its own [`TopicError`] for self-contained use; it does
//! not depend on `crate::error::TextError`.

use scirs2_core::random::prelude::*;
use scirs2_core::random::{rngs::StdRng, SeedableRng};

// ── TopicError ────────────────────────────────────────────────────────────────

/// Errors that can be returned by [`Hdp`].
#[derive(Debug, thiserror::Error)]
pub enum TopicError {
    /// Corpus passed to [`Hdp::fit`] contains no documents.
    #[error("empty corpus")]
    EmptyCorpus,

    /// A word identifier exceeds the declared vocabulary size.
    #[error("word id {0} out of vocab range {1}")]
    WordOutOfVocab(usize, usize),
}

// ── HdpConfig ─────────────────────────────────────────────────────────────────

/// Configuration for [`Hdp`].
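///
/// A minimal usage sketch (values are illustrative): fields can be overridden
/// individually on top of [`Default`].
///
/// ```rust
/// use scirs2_text::topic::hdp::HdpConfig;
///
/// let cfg = HdpConfig { n_iter: 200, max_topics: 30, ..Default::default() };
/// assert_eq!(cfg.alpha, 1.0); // untouched fields keep their defaults
/// ```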
#[derive(Debug, Clone)]
pub struct HdpConfig {
    /// Document-level DP concentration parameter α.
    ///
    /// Governs how many distinct topics appear in each document.
    /// Larger values encourage more topics per document. Default: 1.0.
    pub alpha: f64,

    /// Corpus-level (global) DP concentration γ.
    ///
    /// Controls how spread-out the global topic distribution is.
    /// Default: 1.0.
    pub gamma: f64,

    /// Symmetric Dirichlet word prior η. Default: 0.1.
    pub eta: f64,

    /// Number of Gibbs sampling iterations. Default: 100.
    pub n_iter: usize,

    /// Truncation level T — maximum topics the model can represent.
    /// Default: 20.
    pub max_topics: usize,

    /// RNG seed for reproducibility. Default: 42.
    pub seed: u64,
}

impl Default for HdpConfig {
    fn default() -> Self {
        HdpConfig {
            alpha: 1.0,
            gamma: 1.0,
            eta: 0.1,
            n_iter: 100,
            max_topics: 20,
            seed: 42,
        }
    }
}

// ── HdpState ──────────────────────────────────────────────────────────────────

/// Mutable Gibbs-sampling state for [`Hdp`].
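///
/// Invariants maintained by the sampler (a restatement of what the code
/// guarantees): for every document `d`, `Σ_k doc_topic_counts[d][k]` equals
/// the length of document `d`, and `Σ_w topic_word_counts[k][w]` equals the
/// number of tokens currently assigned to topic `k`.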
#[derive(Debug, Clone)]
pub struct HdpState {
    /// Number of currently active topics (≥ 1 word assigned).
    pub n_topics: usize,
    /// `topic_word_counts[k][w]` — count of word `w` assigned to topic `k`.
    pub topic_word_counts: Vec<Vec<usize>>,
    /// `doc_topic_counts[d][k]` — tokens in document `d` assigned to topic `k`.
    pub doc_topic_counts: Vec<Vec<usize>>,
    /// `word_assignments[d][pos]` — topic assigned to word at position `pos`
    /// in document `d`.
    pub word_assignments: Vec<Vec<usize>>,
}

// ── Hdp ───────────────────────────────────────────────────────────────────────

/// Hierarchical Dirichlet Process topic model.
///
/// Call [`fit`](Hdp::fit) to perform Gibbs sampling, then query:
/// - [`active_topics`](Hdp::active_topics) — number of topics with ≥1 token.
/// - [`topic_distribution`](Hdp::topic_distribution) — topic-word probabilities.
/// - [`document_distribution`](Hdp::document_distribution) — per-document
///   topic proportions.
/// - [`perplexity`](Hdp::perplexity) — training-corpus per-token perplexity.
/// - [`top_words`](Hdp::top_words) — most probable word indices per topic.
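///
/// # Example
///
/// A minimal end-to-end sketch (corpus and sizes are illustrative):
///
/// ```rust
/// use scirs2_text::topic::hdp::{Hdp, HdpConfig};
///
/// let corpus = vec![vec![0usize, 1, 2], vec![2usize, 3, 4]];
/// let cfg = HdpConfig { n_iter: 10, max_topics: 5, ..Default::default() };
/// let mut model = Hdp::new(cfg, corpus.len(), 5);
/// model.fit(&corpus).expect("fit should succeed");
/// assert!(model.active_topics() >= 1);
/// assert!((model.topic_distribution(0).iter().sum::<f64>() - 1.0).abs() < 1e-9);
/// ```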
pub struct Hdp {
    config: HdpConfig,
    state: HdpState,
    vocab_size: usize,
    n_docs: usize,
    /// Corpus kept for perplexity / document_distribution after fit.
    corpus: Vec<Vec<usize>>,
    fitted: bool,
}

impl Hdp {
    /// Construct an unfitted model.
    pub fn new(config: HdpConfig, n_docs: usize, vocab_size: usize) -> Self {
        let t = config.max_topics;
        Hdp {
            config,
            state: HdpState {
                n_topics: 0,
                topic_word_counts: vec![vec![0; vocab_size]; t],
                doc_topic_counts: vec![vec![0; t]; n_docs],
                word_assignments: Vec::new(),
            },
            vocab_size,
            n_docs,
            corpus: Vec::new(),
            fitted: false,
        }
    }

    // ── fit ──────────────────────────────────────────────────────────────────

    /// Fit the HDP model to `corpus` using collapsed Gibbs sampling.
    ///
    /// `corpus[d]` is a sequence of word indices (all must be < `vocab_size`).
    ///
    /// # Errors
    ///
    /// Returns [`TopicError::EmptyCorpus`] when `corpus` is empty and
    /// [`TopicError::WordOutOfVocab`] when any index is ≥ `vocab_size`.
    pub fn fit(&mut self, corpus: &[Vec<usize>]) -> Result<(), TopicError> {
        if corpus.is_empty() {
            return Err(TopicError::EmptyCorpus);
        }

        for doc in corpus {
            for &w in doc {
                if w >= self.vocab_size {
                    return Err(TopicError::WordOutOfVocab(w, self.vocab_size));
                }
            }
        }

        self.corpus = corpus.to_vec();
        self.n_docs = corpus.len();

        let t = self.config.max_topics;
        let voc = self.vocab_size;

        // Re-initialise count tables with correct sizes
        self.state.topic_word_counts = vec![vec![0usize; voc]; t];
        self.state.doc_topic_counts = vec![vec![0usize; t]; self.n_docs];
        self.state.word_assignments = corpus.iter().map(|doc| vec![0usize; doc.len()]).collect();

        let mut rng = StdRng::seed_from_u64(self.config.seed);

        // Random initialisation
        for (d, doc) in corpus.iter().enumerate() {
            for (n, &w) in doc.iter().enumerate() {
                let k = rng.random_range(0..t);
                self.state.word_assignments[d][n] = k;
                self.state.topic_word_counts[k][w] += 1;
                self.state.doc_topic_counts[d][k] += 1;
            }
        }

        let alpha = self.config.alpha;
        let gamma = self.config.gamma;
        let eta = self.config.eta;

        // Collapsed Gibbs sampling
        for _iter in 0..self.config.n_iter {
            for d in 0..self.n_docs {
                for n in 0..corpus[d].len() {
                    let w = corpus[d][n];
                    hdp_gibbs_sample(
                        &mut self.state,
                        d,
                        n,
                        w,
                        alpha,
                        gamma,
                        eta,
                        self.vocab_size,
                        &mut rng,
                    );
                }
            }
        }

        // Count active topics
        let topic_totals: Vec<usize> = (0..t)
            .map(|k| self.state.topic_word_counts[k].iter().sum())
            .collect();
        self.state.n_topics = topic_totals.iter().filter(|&&c| c > 0).count();
        self.fitted = true;

        Ok(())
    }

    // ── topic_distribution ────────────────────────────────────────────────────

    /// Return the normalised topic-word distribution for topic `topic`.
    ///
    /// The result is a `Vec<f64>` of length `vocab_size` that sums to 1.
    /// Smoothed by the Dirichlet word prior η.
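    ///
    /// Concretely (restating the code below):
    ///
    /// ```text
    /// φ_{kw} = (n_{kw} + η) / (n_k + η·V)
    /// ```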
224 ///
225 /// # Panics
226 /// Panics when `topic >= max_topics` (out-of-bounds).
227 pub fn topic_distribution(&self, topic: usize) -> Vec<f64> {
228 let eta = self.config.eta;
229 let eta_sum = eta * self.vocab_size as f64;
230 let counts = &self.state.topic_word_counts[topic];
231 let total: f64 = counts.iter().sum::<usize>() as f64 + eta_sum;
232 counts.iter().map(|&c| (c as f64 + eta) / total).collect()
233 }
234
235 // ── document_distribution ─────────────────────────────────────────────────
236
237 /// Return the normalised document-topic distribution for document `d`.
238 ///
239 /// The result is a `Vec<f64>` of length `max_topics` that sums to 1.
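    ///
    /// Each entry is the smoothed proportion (restating the code below):
    ///
    /// ```text
    /// θ_{dk} = (n_{dk} + α/T) / (n_d + α)
    /// ```
    ///
    /// where `n_d = Σ_k n_{dk}` is the document length; the α/T terms make
    /// the entries sum exactly to 1.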
240 ///
241 /// # Panics
242 /// Panics when `doc >= n_docs`.
243 pub fn document_distribution(&self, doc: usize) -> Vec<f64> {
244 let alpha = self.config.alpha;
245 let t = self.config.max_topics;
246 let counts = &self.state.doc_topic_counts[doc];
247 let total: f64 = counts.iter().sum::<usize>() as f64 + alpha;
248 counts
249 .iter()
250 .map(|&c| (c as f64 + alpha / t as f64) / total)
251 .collect()
252 }
253
254 // ── active_topics ─────────────────────────────────────────────────────────
255
256 /// Number of topics that have at least one word token assigned.
257 pub fn active_topics(&self) -> usize {
258 self.state.n_topics
259 }
260
261 // ── perplexity ────────────────────────────────────────────────────────────
262
263 /// Per-token perplexity on the training corpus.
264 ///
265 /// Computed as `exp(-avg_log_likelihood)`. Returns `1.0` when the corpus
266 /// contains no tokens.
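    ///
    /// The per-token likelihood is the usual mixture (restating the code
    /// below):
    ///
    /// ```text
    /// p(w | d) = Σ_k θ_{dk} · φ_{kw}
    /// ```
    ///
    /// so perplexity is `exp(-(1/N) Σ_{d,w} ln p(w | d))` over all `N`
    /// training tokens.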
    pub fn perplexity(&self) -> f64 {
        let t = self.config.max_topics;
        let eta = self.config.eta;
        let eta_sum = eta * self.vocab_size as f64;
        let alpha = self.config.alpha;

        // Precompute per-topic totals once; they do not change per token.
        let topic_totals: Vec<f64> = (0..t)
            .map(|k| self.state.topic_word_counts[k].iter().sum::<usize>() as f64 + eta_sum)
            .collect();

        let mut total_ll = 0.0f64;
        let mut total_tokens = 0usize;

        for (d, doc) in self.corpus.iter().enumerate() {
            let doc_total: f64 =
                self.state.doc_topic_counts[d].iter().sum::<usize>() as f64 + alpha;

            for &w in doc {
                if w >= self.vocab_size {
                    continue;
                }
                let p_w: f64 = (0..t)
                    .map(|k| {
                        let theta_dk = (self.state.doc_topic_counts[d][k] as f64
                            + alpha / t as f64)
                            / doc_total;
                        let phi_kw = (self.state.topic_word_counts[k][w] as f64 + eta)
                            / topic_totals[k];
                        theta_dk * phi_kw
                    })
                    .sum();

                if p_w > 0.0 {
                    total_ll += p_w.ln();
                }
                total_tokens += 1;
            }
        }

        if total_tokens == 0 {
            return 1.0;
        }

        let avg_ll = total_ll / total_tokens as f64;
        (-avg_ll).exp()
    }

    // ── top_words ─────────────────────────────────────────────────────────────

    /// Return the top `k` word indices for topic `topic`, sorted by
    /// descending probability.
    ///
    /// If `k >= vocab_size`, all word indices are returned.
    pub fn top_words(&self, topic: usize, k: usize) -> Vec<usize> {
        let phi = self.topic_distribution(topic);
        let mut indices: Vec<usize> = (0..phi.len()).collect();
        indices.sort_by(|&a, &b| {
            phi[b]
                .partial_cmp(&phi[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        indices.truncate(k);
        indices
    }

    // ── Accessors ─────────────────────────────────────────────────────────────

    /// Borrow the current Gibbs state.
    pub fn state(&self) -> &HdpState {
        &self.state
    }

    /// Whether the model has been fitted.
    pub fn is_fitted(&self) -> bool {
        self.fitted
    }
}

impl std::fmt::Debug for Hdp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Hdp")
            .field("max_topics", &self.config.max_topics)
            .field("active_topics", &self.state.n_topics)
            .field("vocab_size", &self.vocab_size)
            .field("fitted", &self.fitted)
            .finish()
    }
}

// ── hdp_gibbs_sample ──────────────────────────────────────────────────────────

/// Remove token `(doc, pos, word)` from its current topic assignment, then
/// sample a new topic from the truncated conditional below. (`_gamma` is
/// accepted for signature stability but is unused in this truncated
/// approximation of the CRF sampling step.)
///
/// The conditional probability for topic `k` is:
/// ```text
///                    n_{dk} + α/T     n_{kw} + η
/// P(z = k | rest) ∝ ─────────────── × ────────────
///                      n_d + α         n_k + η·V
/// ```
/// where:
/// - `n_{dk}` = token count for document `d` and topic `k`
/// - `n_{kw}` = global count of word `w` under topic `k`
/// - `n_d` = total tokens in document `d`
/// - `n_k` = total tokens under topic `k`
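///
/// A worked instance (illustrative numbers, matching the weights the code
/// computes): with `T = 2`, `α = 1`, `η = 0.1`, `V = 10`, a token whose
/// document counts are `n_{d·} = (3, 1)`, word counts `n_{·w} = (2, 0)`, and
/// topic totals `n_k = (20, 10)` gets unnormalised weights
/// `(3 + 0.5) · 2.1/21 = 0.35` and `(1 + 0.5) · 0.1/11 ≈ 0.014`, so topic 0
/// is drawn with probability `0.35 / (0.35 + 0.014) ≈ 0.96`.
///
/// Note that the `1/(n_d + α)` factor is constant across `k` and therefore
/// omitted from the computed weights; it cancels in the normalisation.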
fn hdp_gibbs_sample(
    state: &mut HdpState,
    doc: usize,
    pos: usize,
    word: usize,
    alpha: f64,
    _gamma: f64,
    eta: f64,
    vocab_size: usize,
    rng: &mut StdRng,
) {
    let t = state.topic_word_counts.len();
    let eta_sum = eta * vocab_size as f64;

    // Remove current assignment
    let k_old = state.word_assignments[doc][pos];
    state.topic_word_counts[k_old][word] = state.topic_word_counts[k_old][word].saturating_sub(1);
    state.doc_topic_counts[doc][k_old] = state.doc_topic_counts[doc][k_old].saturating_sub(1);

    // Compute unnormalised probabilities for each topic.
    // doc_factor omits the constant 1/(n_d + α); see the doc comment above.
    let mut probs = vec![0.0f64; t];
    for k in 0..t {
        let doc_factor = state.doc_topic_counts[doc][k] as f64 + alpha / t as f64;
        let kw = state.topic_word_counts[k][word] as f64 + eta;
        let k_total: f64 = state.topic_word_counts[k].iter().sum::<usize>() as f64 + eta_sum;
        probs[k] = doc_factor * (kw / k_total);
    }

    // Sample new topic
    let k_new = sample_categorical(&probs, rng);

    // Update counts
    state.word_assignments[doc][pos] = k_new;
    state.topic_word_counts[k_new][word] += 1;
    state.doc_topic_counts[doc][k_new] += 1;
}

/// Sample a categorical index from an unnormalised probability vector.
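///
/// Uses inverse-CDF sampling: draw `u` uniformly in `[0, total)` and return
/// the first index whose cumulative mass exceeds `u`. Falls back to a uniform
/// draw when all weights are zero (or negative).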
fn sample_categorical(probs: &[f64], rng: &mut StdRng) -> usize {
    let total: f64 = probs.iter().sum();
    if total <= 0.0 {
        return rng.random_range(0..probs.len());
    }
    let u: f64 = rng.random_range(0.0..total);
    let mut cumulative = 0.0f64;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    probs.len() - 1
}

// ── HdpTopicConfig ────────────────────────────────────────────────────────────

/// Configuration for [`HdpTopicModel`].
///
/// This is a task-API-compatible configuration distinct from [`HdpConfig`]
/// (which is used by [`Hdp`]). It adds `burn_in` and names the truncation
/// level `t_max` (the counterpart of `HdpConfig::max_topics`).
#[derive(Debug, Clone)]
pub struct HdpTopicConfig {
    /// Per-document DP concentration parameter α. Default: 1.0.
    pub alpha: f64,
    /// Global DP concentration parameter γ. Default: 1.0.
    pub gamma: f64,
    /// Symmetric Dirichlet word prior η. Default: 0.1.
    pub eta: f64,
    /// Truncation level T — max topics. Default: 50.
    pub t_max: usize,
    /// Total Gibbs iterations (including burn-in). Default: 150.
    pub n_iter: usize,
    /// Number of burn-in iterations to discard when counting active topics.
    /// Default: 50. (The current implementation approximates burn-in by
    /// counting non-empty topics after the full run; see
    /// [`HdpTopicModel::fit`].)
    pub burn_in: usize,
    /// RNG seed for reproducibility. Default: 42.
    pub seed: u64,
}

impl Default for HdpTopicConfig {
    fn default() -> Self {
        HdpTopicConfig {
            alpha: 1.0,
            gamma: 1.0,
            eta: 0.1,
            t_max: 50,
            n_iter: 150,
            burn_in: 50,
            seed: 42,
        }
    }
}

// ── HdpTopicModel ─────────────────────────────────────────────────────────────

/// Task-API Hierarchical Dirichlet Process topic model.
///
/// Provides the interface `HdpTopicModel::fit(corpus, vocab_size, config)`,
/// `.transform(doc)`, `.topics()`, and `.num_topics_inferred()`.
///
/// Internally delegates to [`Hdp`] for the Gibbs sampling loop, then
/// post-processes to expose `phi` (topic × word) and `theta` (document × topic)
/// arrays.
///
/// # Example
///
/// ```rust
/// use scirs2_text::topic::hdp::{HdpTopicConfig, HdpTopicModel};
///
/// let corpus = vec![
///     vec![0usize, 1, 2],
///     vec![3usize, 4, 5],
/// ];
/// let cfg = HdpTopicConfig { n_iter: 10, t_max: 5, burn_in: 2, seed: 0, ..Default::default() };
/// let model = HdpTopicModel::fit(&corpus, 6, cfg).expect("fit must succeed");
/// assert!(model.num_topics_inferred() >= 1);
/// ```
pub struct HdpTopicModel {
    /// φ\[k\]\[w\] = word probability in topic k. Shape: `[t_max × vocab_size]`.
    pub phi: Vec<Vec<f64>>,
    /// θ\[d\]\[k\] = topic proportion for document d. Shape: `[n_docs × t_max]`.
    pub theta: Vec<Vec<f64>>,
    /// Number of active (non-empty) topics after Gibbs sampling.
    k_inferred: usize,
    /// Vocabulary size used during fit.
    vocab_size: usize,
    /// t_max used during fit (for transform).
    t_max: usize,
    /// eta used during fit (for transform).
    eta: f64,
    /// alpha used during fit (for transform).
    alpha: f64,
    /// Raw topic-word count matrix kept for transform (t_max × vocab_size).
    topic_word_counts: Vec<Vec<usize>>,
    /// Raw topic total counts (t_max).
    topic_counts: Vec<usize>,
}

impl HdpTopicModel {
    /// Fit the HDP topic model to `corpus`.
    ///
    /// # Parameters
    /// - `corpus`: slice of documents, each a `Vec<usize>` of word indices
    ///   (all must be < `vocab_size`).
    /// - `vocab_size`: vocabulary size.
    /// - `config`: hyperparameters and iteration counts.
    ///
    /// # Errors
    /// Returns [`TopicError::EmptyCorpus`] when `corpus` is empty, and
    /// [`TopicError::WordOutOfVocab`] when any word index ≥ `vocab_size`.
    pub fn fit(
        corpus: &[Vec<usize>],
        vocab_size: usize,
        config: HdpTopicConfig,
    ) -> Result<Self, TopicError> {
        if corpus.is_empty() {
            return Err(TopicError::EmptyCorpus);
        }
        for doc in corpus {
            for &w in doc {
                if w >= vocab_size {
                    return Err(TopicError::WordOutOfVocab(w, vocab_size));
                }
            }
        }

        let t = config.t_max;
        let n_docs = corpus.len();

        // Delegate Gibbs sampling to the existing Hdp struct via its HdpConfig
        let hdp_cfg = HdpConfig {
            alpha: config.alpha,
            gamma: config.gamma,
            eta: config.eta,
            n_iter: config.n_iter,
            max_topics: t,
            seed: config.seed,
        };

        let mut hdp = Hdp::new(hdp_cfg, n_docs, vocab_size);
        hdp.fit(corpus)?;

        // Extract counts from HdpState
        let state = hdp.state();
        let topic_word_counts: Vec<Vec<usize>> = state.topic_word_counts.clone();
        let topic_counts: Vec<usize> = topic_word_counts
            .iter()
            .map(|row| row.iter().sum())
            .collect();

        // Count active topics — burn-in is approximated by simply checking
        // n_k > 0 after the full run (`config.burn_in` is not applied directly).
        let k_inferred = topic_counts.iter().filter(|&&c| c > 0).count().max(1);

        let eta = config.eta;
        let eta_sum = eta * vocab_size as f64;
        let alpha = config.alpha;

        // Compute phi: normalised topic-word distributions for ALL t topics;
        // inactive topics reduce to the uniform prior.
        let phi: Vec<Vec<f64>> = (0..t)
            .map(|k| {
                let total = topic_counts[k] as f64 + eta_sum;
                (0..vocab_size)
                    .map(|w| (topic_word_counts[k][w] as f64 + eta) / total)
                    .collect()
            })
            .collect();

        // Compute theta for training documents
        let doc_topic_counts = &state.doc_topic_counts;
        let theta: Vec<Vec<f64>> = (0..n_docs)
            .map(|d| {
                let doc_total: f64 = doc_topic_counts[d].iter().sum::<usize>() as f64 + alpha;
                (0..t)
                    .map(|k| (doc_topic_counts[d][k] as f64 + alpha / t as f64) / doc_total)
                    .collect()
            })
            .collect();

        Ok(HdpTopicModel {
            phi,
            theta,
            k_inferred,
            vocab_size,
            t_max: t,
            eta,
            alpha,
            topic_word_counts,
            topic_counts,
        })
    }

    /// Infer the topic distribution for an unseen document.
    ///
    /// Returns a vector of length `t_max` that sums to 1.0, with each entry
    /// representing the proportion of the document's content assigned to that
    /// topic.
    ///
    /// Word indices ≥ `vocab_size` are silently skipped.
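    ///
    /// Inference here is a single folding-in pass rather than a full Gibbs
    /// chain: each token's unit mass is split across topics in proportion to
    /// `θ_k · φ_{kw}`, accumulated into `θ`, and the result is renormalised.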
    pub fn transform(&self, doc: &[usize]) -> Vec<f64> {
        let t = self.t_max;
        let eta = self.eta;
        let eta_sum = eta * self.vocab_size as f64;

        // Initialise with the symmetric prior
        let mut theta_doc = vec![self.alpha / t as f64; t];

        for &w in doc {
            if w >= self.vocab_size {
                continue;
            }
            // Compute normalised word-topic weights
            let mut word_probs: Vec<f64> = (0..t)
                .map(|k| {
                    theta_doc[k] * (self.topic_word_counts[k][w] as f64 + eta)
                        / (self.topic_counts[k] as f64 + eta_sum)
                })
                .collect();

            let sum: f64 = word_probs.iter().sum();
            if sum > 0.0 {
                word_probs.iter_mut().for_each(|p| *p /= sum);
                for k in 0..t {
                    theta_doc[k] += word_probs[k];
                }
            }
        }

        // Normalise
        let total: f64 = theta_doc.iter().sum();
        if total > 0.0 {
            theta_doc.iter_mut().for_each(|p| *p /= total);
        }

        theta_doc
    }

    /// Return all `t_max` topic-word distributions (active and inactive).
    ///
    /// The outer slice has length `t_max`; the inner slices each have length
    /// `vocab_size`. Inactive topics fall back to the uniform prior (each
    /// word has probability `1/vocab_size`).
    pub fn topics(&self) -> &[Vec<f64>] {
        &self.phi
    }

    /// Number of topics with at least one word token assigned after Gibbs
    /// sampling (approximates the model's belief about how many topics the
    /// corpus requires).
    pub fn num_topics_inferred(&self) -> usize {
        self.k_inferred
    }
}

impl std::fmt::Debug for HdpTopicModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HdpTopicModel")
            .field("t_max", &self.t_max)
            .field("k_inferred", &self.k_inferred)
            .field("vocab_size", &self.vocab_size)
            .finish()
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Synthetic corpus: 3 well-separated topics, `n_per_topic` docs each,
    /// over a 15-word vocabulary.
    fn make_corpus(n_per_topic: usize, seed: u64) -> Vec<Vec<usize>> {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut corpus = Vec::new();
        // Topic 0: words 0–4
        for _ in 0..n_per_topic {
            corpus.push((0..20).map(|_| rng.random_range(0..5)).collect());
        }
        // Topic 1: words 5–9
        for _ in 0..n_per_topic {
            corpus.push((0..20).map(|_| rng.random_range(5..10)).collect());
        }
        // Topic 2: words 10–14
        for _ in 0..n_per_topic {
            corpus.push((0..20).map(|_| rng.random_range(10..15)).collect());
        }
        corpus
    }

    // ── active_topics ────────────────────────────────────────────────────────

    #[test]
    fn active_topics_in_valid_range() {
        let corpus = make_corpus(10, 1);
        let config = HdpConfig {
            n_iter: 20,
            max_topics: 15,
            seed: 42,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");

        let active = model.active_topics();
        assert!(active >= 1, "active topics must be >= 1, got {active}");
        assert!(
            active <= 15,
            "active topics ({active}) must be <= max_topics (15)"
        );
    }

    // ── topic_distribution ───────────────────────────────────────────────────

    #[test]
    fn topic_distribution_sums_to_one() {
        let corpus = make_corpus(8, 2);
        let config = HdpConfig {
            n_iter: 10,
            seed: 7,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");

        let dist = model.topic_distribution(0);
        let sum: f64 = dist.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "topic_distribution must sum to 1.0, got {sum}"
        );
    }

    // ── document_distribution ────────────────────────────────────────────────

    #[test]
    fn document_distribution_sums_to_one() {
        let corpus = make_corpus(8, 3);
        let config = HdpConfig {
            n_iter: 10,
            seed: 11,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");

        let dist = model.document_distribution(0);
        let sum: f64 = dist.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "document_distribution must sum to 1.0, got {sum}"
        );
    }

    // ── perplexity ───────────────────────────────────────────────────────────

    #[test]
    fn perplexity_is_finite_positive() {
        let corpus = make_corpus(8, 4);
        let config = HdpConfig {
            n_iter: 15,
            seed: 99,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");

        let pp = model.perplexity();
        assert!(pp.is_finite(), "perplexity must be finite, got {pp}");
        assert!(pp > 0.0, "perplexity must be positive, got {pp}");
    }

    // ── top_words ────────────────────────────────────────────────────────────

    #[test]
    fn top_words_returns_k_distinct_indices() {
        let corpus = make_corpus(10, 5);
        let config = HdpConfig {
            n_iter: 15,
            seed: 55,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");

        let top5 = model.top_words(0, 5);
        // All indices distinct
        let mut sorted = top5.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(
            sorted.len(),
            top5.len(),
            "top_words must contain distinct indices"
        );
        // All within vocab range
        for &w in &top5 {
            assert!(w < 15, "word index {w} must be < vocab_size 15");
        }
    }

    // ── error cases ──────────────────────────────────────────────────────────

    #[test]
    fn fit_empty_corpus_returns_error() {
        let mut model = Hdp::new(HdpConfig::default(), 0, 10);
        let result = model.fit(&[]);
        assert!(
            result.is_err(),
            "fit on empty corpus must return TopicError"
        );
    }

    #[test]
    fn fit_out_of_vocab_returns_error() {
        let corpus = vec![vec![0usize, 1, 99]]; // 99 >= vocab_size=5
        let mut model = Hdp::new(HdpConfig::default(), 1, 5);
        let result = model.fit(&corpus);
        assert!(
            result.is_err(),
            "fit with OOV word must return TopicError::WordOutOfVocab"
        );
    }

    #[test]
    fn top_words_all_nontrivial() {
        let corpus = make_corpus(6, 6);
        let config = HdpConfig {
            n_iter: 10,
            seed: 77,
            max_topics: 10,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");
        // For all topics, the top 3 words must be valid indices
        for k in 0..10 {
            for &w in &model.top_words(k, 3) {
                assert!(w < 15, "top word index {w} must be in vocab");
            }
        }
    }
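
    // ── added checks (sketches, not part of the original suite) ──────────────

    /// Added check: `HdpTopicModel::transform` on an unseen document should
    /// return a length-`t_max` distribution that sums to 1.
    #[test]
    fn hdp_topic_model_transform_sums_to_one() {
        let corpus = make_corpus(6, 8);
        let cfg = HdpTopicConfig {
            n_iter: 10,
            t_max: 8,
            burn_in: 2,
            seed: 3,
            ..Default::default()
        };
        let model = HdpTopicModel::fit(&corpus, 15, cfg).expect("fit must succeed");
        let theta = model.transform(&[0usize, 1, 5, 10]);
        assert_eq!(theta.len(), 8);
        let sum: f64 = theta.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "transform must sum to 1.0, got {sum}"
        );
    }

    /// Added check: with all mass on one index, `sample_categorical` must
    /// always return that index.
    #[test]
    fn sample_categorical_degenerate_mass() {
        let mut rng = StdRng::seed_from_u64(0);
        let probs = [0.0, 0.0, 5.0, 0.0];
        for _ in 0..20 {
            assert_eq!(sample_categorical(&probs, &mut rng), 2);
        }
    }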
}