Skip to main content

ferrolearn_decomp/
lda_topic.rs

1//! Latent Dirichlet Allocation (LDA) topic model.
2//!
3//! [`LatentDirichletAllocation`] discovers latent topics in a document-term
4//! matrix using variational inference. This is the *topic model* LDA, **not**
5//! Linear Discriminant Analysis (which lives in `ferrolearn-linear`).
6//!
7//! # Algorithm
8//!
9//! Two solvers are supported:
10//!
11//! - **Batch** variational EM: iterates over the full corpus each step.
12//!   E-step updates per-document topic distributions; M-step updates the
13//!   global topic-word distributions.
14//! - **Online** variational Bayes (Hoffman et al. 2010): processes mini-batches
15//!   and uses a decaying learning rate to update global parameters
16//!   incrementally.
17//!
18//! # Examples
19//!
20//! ```
21//! use ferrolearn_decomp::LatentDirichletAllocation;
22//! use ferrolearn_core::traits::{Fit, Transform};
23//! use ndarray::array;
24//!
25//! // Simple 4-document, 6-word corpus
26//! let dtm = array![
27//!     [1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
28//!     [1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
29//!     [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
30//!     [0.0, 0.0, 0.0, 1.0, 1.0, 0.0],
31//! ];
32//! let lda = LatentDirichletAllocation::new(2).with_random_state(42);
33//! let fitted = lda.fit(&dtm, &()).unwrap();
34//! let topics = fitted.transform(&dtm).unwrap();
35//! assert_eq!(topics.dim(), (4, 2));
36//! ```
37
38use ferrolearn_core::error::FerroError;
39use ferrolearn_core::traits::{Fit, Transform};
40use ndarray::Array2;
41use rand::SeedableRng;
42use rand_distr::{Distribution, Uniform};
43use rand_xoshiro::Xoshiro256PlusPlus;
44
45// ---------------------------------------------------------------------------
46// Learning method enum
47// ---------------------------------------------------------------------------
48
/// The learning method for LDA.
///
/// Selects which solver [`Fit::fit`] runs; see the module docs for the
/// trade-offs between the two.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LdaLearningMethod {
    /// Batch variational EM — iterates over the full corpus each step.
    Batch,
    /// Online variational Bayes (Hoffman et al. 2010) — incremental updates
    /// with a decaying learning rate.
    Online,
}
57
58// ---------------------------------------------------------------------------
59// LatentDirichletAllocation (unfitted)
60// ---------------------------------------------------------------------------
61
/// Latent Dirichlet Allocation configuration.
///
/// Holds hyperparameters for the LDA topic model. Calling [`Fit::fit`]
/// learns topic-word distributions and returns a
/// [`FittedLatentDirichletAllocation`]. All fields are set via the
/// `with_*` builder methods; `new` supplies the defaults.
#[derive(Debug, Clone)]
pub struct LatentDirichletAllocation {
    /// Number of topics to extract.
    n_components: usize,
    /// Maximum number of E-M iterations (batch) or passes (online).
    max_iter: usize,
    /// Learning method (batch EM or online variational Bayes).
    learning_method: LdaLearningMethod,
    /// Offset for learning rate in online mode (default 10.0).
    learning_offset: f64,
    /// Decay for learning rate in online mode (default 0.7).
    learning_decay: f64,
    /// Document-topic prior (Dirichlet alpha). None = 1/n_components.
    doc_topic_prior: Option<f64>,
    /// Topic-word prior (Dirichlet beta). None = 1/n_components.
    topic_word_prior: Option<f64>,
    /// Maximum E-step iterations per document (gamma updates).
    max_doc_update_iter: usize,
    /// Optional random seed for the lambda initialisation; None means seed 0.
    random_state: Option<u64>,
}
88
89impl LatentDirichletAllocation {
90    /// Create a new `LatentDirichletAllocation` with `n_components` topics.
91    ///
92    /// Defaults: `max_iter=10`, `learning_method=Batch`,
93    /// `learning_offset=10.0`, `learning_decay=0.7`,
94    /// priors=`1/n_components`, `max_doc_update_iter=100`.
95    #[must_use]
96    pub fn new(n_components: usize) -> Self {
97        Self {
98            n_components,
99            max_iter: 10,
100            learning_method: LdaLearningMethod::Batch,
101            learning_offset: 10.0,
102            learning_decay: 0.7,
103            doc_topic_prior: None,
104            topic_word_prior: None,
105            max_doc_update_iter: 100,
106            random_state: None,
107        }
108    }
109
110    /// Set the maximum number of iterations.
111    #[must_use]
112    pub fn with_max_iter(mut self, n: usize) -> Self {
113        self.max_iter = n;
114        self
115    }
116
117    /// Set the learning method.
118    #[must_use]
119    pub fn with_learning_method(mut self, m: LdaLearningMethod) -> Self {
120        self.learning_method = m;
121        self
122    }
123
124    /// Set the learning offset (online mode).
125    #[must_use]
126    pub fn with_learning_offset(mut self, v: f64) -> Self {
127        self.learning_offset = v;
128        self
129    }
130
131    /// Set the learning decay (online mode).
132    #[must_use]
133    pub fn with_learning_decay(mut self, v: f64) -> Self {
134        self.learning_decay = v;
135        self
136    }
137
138    /// Set the document-topic prior (alpha).
139    #[must_use]
140    pub fn with_doc_topic_prior(mut self, v: f64) -> Self {
141        self.doc_topic_prior = Some(v);
142        self
143    }
144
145    /// Set the topic-word prior (beta).
146    #[must_use]
147    pub fn with_topic_word_prior(mut self, v: f64) -> Self {
148        self.topic_word_prior = Some(v);
149        self
150    }
151
152    /// Set the random seed.
153    #[must_use]
154    pub fn with_random_state(mut self, seed: u64) -> Self {
155        self.random_state = Some(seed);
156        self
157    }
158
159    /// Set the maximum E-step iterations per document.
160    #[must_use]
161    pub fn with_max_doc_update_iter(mut self, n: usize) -> Self {
162        self.max_doc_update_iter = n;
163        self
164    }
165
166    /// Return the configured number of topics.
167    #[must_use]
168    pub fn n_components(&self) -> usize {
169        self.n_components
170    }
171
172    /// Return the configured maximum iterations.
173    #[must_use]
174    pub fn max_iter(&self) -> usize {
175        self.max_iter
176    }
177
178    /// Return the configured learning method.
179    #[must_use]
180    pub fn learning_method(&self) -> LdaLearningMethod {
181        self.learning_method
182    }
183
184    /// Return the configured learning offset.
185    #[must_use]
186    pub fn learning_offset(&self) -> f64 {
187        self.learning_offset
188    }
189
190    /// Return the configured learning decay.
191    #[must_use]
192    pub fn learning_decay(&self) -> f64 {
193        self.learning_decay
194    }
195
196    /// Return the configured document-topic prior, if explicitly set.
197    #[must_use]
198    pub fn doc_topic_prior(&self) -> Option<f64> {
199        self.doc_topic_prior
200    }
201
202    /// Return the configured topic-word prior, if explicitly set.
203    #[must_use]
204    pub fn topic_word_prior(&self) -> Option<f64> {
205        self.topic_word_prior
206    }
207
208    /// Return the configured random state, if any.
209    #[must_use]
210    pub fn random_state(&self) -> Option<u64> {
211        self.random_state
212    }
213}
214
215// ---------------------------------------------------------------------------
216// FittedLatentDirichletAllocation
217// ---------------------------------------------------------------------------
218
/// A fitted LDA model holding the learned topic-word distributions.
///
/// Created by calling [`Fit::fit`] on a [`LatentDirichletAllocation`].
/// Implements [`Transform<Array2<f64>>`] to compute document-topic
/// distributions for new documents.
#[derive(Debug, Clone)]
pub struct FittedLatentDirichletAllocation {
    /// Topic-word distribution (un-normalised), shape `(n_topics, n_words)`.
    /// The `components_[k][w]` entry is proportional to the probability
    /// of word `w` in topic `k`. These are the variational lambda parameters
    /// from fitting, carried over without row normalisation.
    components_: Array2<f64>,
    /// Document-topic prior (alpha) used during fitting; reused by `transform`.
    alpha_: f64,
    /// Topic-word prior (beta) used during fitting.
    beta_: f64,
    /// Number of iterations performed.
    n_iter_: usize,
    /// Maximum E-step iterations per document; reused by `transform`.
    max_doc_update_iter_: usize,
}
239
240impl FittedLatentDirichletAllocation {
241    /// Topic-word distribution, shape `(n_topics, n_words)`.
242    ///
243    /// Each row is a (possibly un-normalised) distribution over the
244    /// vocabulary for one topic.
245    #[must_use]
246    pub fn components(&self) -> &Array2<f64> {
247        &self.components_
248    }
249
250    /// Number of iterations performed during fitting.
251    #[must_use]
252    pub fn n_iter(&self) -> usize {
253        self.n_iter_
254    }
255
256    /// The document-topic prior used during fitting.
257    #[must_use]
258    pub fn alpha(&self) -> f64 {
259        self.alpha_
260    }
261
262    /// The topic-word prior used during fitting.
263    #[must_use]
264    pub fn beta(&self) -> f64 {
265        self.beta_
266    }
267}
268
269// ---------------------------------------------------------------------------
270// Internal: digamma approximation
271// ---------------------------------------------------------------------------
272
/// Approximate digamma function (psi) via recurrence plus an asymptotic tail.
///
/// Arguments below 6 are shifted upward with the identity
/// `psi(x) = psi(x + 1) - 1/x`; the shifted argument is then evaluated with
/// the standard asymptotic expansion
/// `psi(t) ~ ln t - 1/(2t) - 1/(12t^2) + 1/(120t^4) - 1/(252t^6) + 1/(240t^8)`.
/// Returns `NaN` for non-positive input.
fn digamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NAN;
    }
    // Shift the argument into the region (t >= 6) where the expansion is
    // accurate, accumulating the -1/t correction terms as we go.
    let mut t = x;
    let mut acc = 0.0;
    while t < 6.0 {
        acc -= 1.0 / t;
        t += 1.0;
    }
    // Leading terms of the expansion.
    acc += t.ln() - 0.5 / t;
    // Horner evaluation of the 1/t^2 series.
    let z = 1.0 / (t * t);
    let tail = z * (1.0 / 12.0 - z * (1.0 / 120.0 - z * (1.0 / 252.0 - z * 1.0 / 240.0)));
    acc - tail
}
295
296/// Compute the E-step for a single document.
297///
298/// Given the document word counts `doc` (length V) and the current
299/// topic-word log expectations `e_log_beta` (shape K x V), compute the
300/// variational parameters `gamma` (length K, document-topic).
301///
302/// Returns the gamma vector (un-normalised document-topic distribution).
303fn e_step_doc(doc: &[f64], e_log_beta: &Array2<f64>, alpha: f64, max_iter: usize) -> Vec<f64> {
304    let n_topics = e_log_beta.nrows();
305    let n_words = e_log_beta.ncols();
306
307    // Initialise gamma uniformly.
308    let mut gamma = vec![alpha + (n_words as f64) / (n_topics as f64); n_topics];
309
310    for _iter in 0..max_iter {
311        let e_log_theta: Vec<f64> = gamma.iter().map(|&g| digamma(g)).collect();
312        let gamma_sum_dig = digamma(gamma.iter().sum::<f64>());
313
314        let mut new_gamma = vec![alpha; n_topics];
315
316        for w in 0..n_words {
317            if doc[w] < 1e-16 {
318                continue;
319            }
320            // Compute log of un-normalised phi for each topic.
321            let mut log_phi = Vec::with_capacity(n_topics);
322            let mut max_log = f64::NEG_INFINITY;
323            for k in 0..n_topics {
324                let v = e_log_theta[k] - gamma_sum_dig + e_log_beta[[k, w]];
325                log_phi.push(v);
326                if v > max_log {
327                    max_log = v;
328                }
329            }
330            // Normalise in log space.
331            let mut sum_phi = 0.0;
332            let mut phi = Vec::with_capacity(n_topics);
333            for lp in &log_phi {
334                let p = (lp - max_log).exp();
335                phi.push(p);
336                sum_phi += p;
337            }
338            if sum_phi < 1e-16 {
339                sum_phi = 1e-16;
340            }
341            for k in 0..n_topics {
342                new_gamma[k] += doc[w] * phi[k] / sum_phi;
343            }
344        }
345
346        // Check convergence.
347        let mut diff = 0.0;
348        for k in 0..n_topics {
349            diff += (new_gamma[k] - gamma[k]).abs();
350        }
351        gamma = new_gamma;
352        if diff < 1e-3 {
353            break;
354        }
355    }
356
357    gamma
358}
359
360// ---------------------------------------------------------------------------
361// Trait implementations
362// ---------------------------------------------------------------------------
363
impl Fit<Array2<f64>, ()> for LatentDirichletAllocation {
    type Fitted = FittedLatentDirichletAllocation;
    type Error = FerroError;

    /// Fit the LDA model on a document-term matrix of shape
    /// `(n_docs, n_words)`.
    ///
    /// # Errors
    ///
    /// - [`FerroError::InvalidParameter`] if `n_components` is zero, the
    ///   matrix has zero words (columns), or any entry is negative.
    /// - [`FerroError::InsufficientSamples`] if there are zero documents.
    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedLatentDirichletAllocation, FerroError> {
        let (n_docs, n_words) = x.dim();

        // Validate configuration and input before any allocation.
        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if n_docs == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "LatentDirichletAllocation::fit".into(),
            });
        }
        if n_words == 0 {
            return Err(FerroError::InvalidParameter {
                name: "X".into(),
                reason: "document-term matrix must have at least 1 word".into(),
            });
        }
        for &val in x.iter() {
            if val < 0.0 {
                return Err(FerroError::InvalidParameter {
                    name: "X".into(),
                    reason: "LDA requires non-negative entries in the document-term matrix".into(),
                });
            }
        }

        let n_topics = self.n_components;
        // Unset priors default to the symmetric 1/K choice.
        let alpha = self.doc_topic_prior.unwrap_or(1.0 / n_topics as f64);
        let beta = self.topic_word_prior.unwrap_or(1.0 / n_topics as f64);
        let seed = self.random_state.unwrap_or(0);

        // Initialise lambda (topic-word variational parameters) randomly:
        // uniform in [0.5, 1.5) shifted by the prior, so entries stay positive.
        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
        let uniform = Uniform::new(0.5, 1.5).unwrap();
        let mut lambda = Array2::<f64>::zeros((n_topics, n_words));
        for elem in lambda.iter_mut() {
            *elem = uniform.sample(&mut rng) + beta;
        }

        match self.learning_method {
            LdaLearningMethod::Batch => {
                self.fit_batch(x, &mut lambda, alpha, beta, n_docs, n_words, n_topics);
            }
            LdaLearningMethod::Online => {
                self.fit_online(
                    x,
                    &mut lambda,
                    alpha,
                    beta,
                    n_docs,
                    n_words,
                    n_topics,
                    &mut rng,
                );
            }
        }

        Ok(FittedLatentDirichletAllocation {
            components_: lambda,
            alpha_: alpha,
            beta_: beta,
            // Both solvers run exactly `max_iter` outer passes (no early
            // stopping at this level), so this is the actual count.
            n_iter_: self.max_iter,
            max_doc_update_iter_: self.max_doc_update_iter,
        })
    }
}
448
449impl LatentDirichletAllocation {
450    /// Batch variational EM.
451    #[allow(clippy::too_many_arguments)]
452    fn fit_batch(
453        &self,
454        x: &Array2<f64>,
455        lambda: &mut Array2<f64>,
456        alpha: f64,
457        beta: f64,
458        n_docs: usize,
459        n_words: usize,
460        n_topics: usize,
461    ) {
462        for _outer in 0..self.max_iter {
463            // Compute E[log beta] from current lambda.
464            let e_log_beta = compute_e_log_beta(lambda, n_topics, n_words);
465
466            // Accumulate sufficient statistics.
467            let mut ss = Array2::<f64>::zeros((n_topics, n_words));
468
469            for d in 0..n_docs {
470                let doc: Vec<f64> = (0..n_words).map(|w| x[[d, w]]).collect();
471                let gamma = e_step_doc(&doc, &e_log_beta, alpha, self.max_doc_update_iter);
472
473                // Compute phi for this document and accumulate.
474                let e_log_theta: Vec<f64> = gamma.iter().map(|&g| digamma(g)).collect();
475                let gamma_sum_dig = digamma(gamma.iter().sum::<f64>());
476
477                for w in 0..n_words {
478                    if doc[w] < 1e-16 {
479                        continue;
480                    }
481                    let mut log_phi = Vec::with_capacity(n_topics);
482                    let mut max_log = f64::NEG_INFINITY;
483                    for k in 0..n_topics {
484                        let v = e_log_theta[k] - gamma_sum_dig + e_log_beta[[k, w]];
485                        log_phi.push(v);
486                        if v > max_log {
487                            max_log = v;
488                        }
489                    }
490                    let mut phi = Vec::with_capacity(n_topics);
491                    let mut sum_phi = 0.0;
492                    for lp in &log_phi {
493                        let p = (lp - max_log).exp();
494                        phi.push(p);
495                        sum_phi += p;
496                    }
497                    if sum_phi < 1e-16 {
498                        sum_phi = 1e-16;
499                    }
500                    for k in 0..n_topics {
501                        ss[[k, w]] += doc[w] * phi[k] / sum_phi;
502                    }
503                }
504            }
505
506            // M-step: update lambda.
507            for k in 0..n_topics {
508                for w in 0..n_words {
509                    lambda[[k, w]] = beta + ss[[k, w]];
510                }
511            }
512        }
513    }
514
515    /// Online variational Bayes (Hoffman et al. 2010).
516    #[allow(clippy::too_many_arguments)]
517    fn fit_online(
518        &self,
519        x: &Array2<f64>,
520        lambda: &mut Array2<f64>,
521        alpha: f64,
522        beta: f64,
523        n_docs: usize,
524        n_words: usize,
525        n_topics: usize,
526        _rng: &mut Xoshiro256PlusPlus,
527    ) {
528        let mut update_count = 0u64;
529
530        for _outer in 0..self.max_iter {
531            // Process each document as a mini-batch of size 1.
532            for d in 0..n_docs {
533                let doc: Vec<f64> = (0..n_words).map(|w| x[[d, w]]).collect();
534
535                let e_log_beta = compute_e_log_beta(lambda, n_topics, n_words);
536                let gamma = e_step_doc(&doc, &e_log_beta, alpha, self.max_doc_update_iter);
537
538                // Compute sufficient statistics for this document.
539                let e_log_theta: Vec<f64> = gamma.iter().map(|&g| digamma(g)).collect();
540                let gamma_sum_dig = digamma(gamma.iter().sum::<f64>());
541
542                let mut ss = Array2::<f64>::zeros((n_topics, n_words));
543                for w in 0..n_words {
544                    if doc[w] < 1e-16 {
545                        continue;
546                    }
547                    let mut log_phi = Vec::with_capacity(n_topics);
548                    let mut max_log = f64::NEG_INFINITY;
549                    for k in 0..n_topics {
550                        let v = e_log_theta[k] - gamma_sum_dig + e_log_beta[[k, w]];
551                        log_phi.push(v);
552                        if v > max_log {
553                            max_log = v;
554                        }
555                    }
556                    let mut phi = Vec::with_capacity(n_topics);
557                    let mut sum_phi = 0.0;
558                    for lp in &log_phi {
559                        let p = (lp - max_log).exp();
560                        phi.push(p);
561                        sum_phi += p;
562                    }
563                    if sum_phi < 1e-16 {
564                        sum_phi = 1e-16;
565                    }
566                    for k in 0..n_topics {
567                        ss[[k, w]] += doc[w] * phi[k] / sum_phi;
568                    }
569                }
570
571                // Online update with decaying step size.
572                update_count += 1;
573                let rho = (self.learning_offset + update_count as f64).powf(-self.learning_decay);
574
575                // lambda_new = (1-rho)*lambda + rho*(beta + n_docs * ss)
576                let n_docs_f = n_docs as f64;
577                for k in 0..n_topics {
578                    for w in 0..n_words {
579                        let target = beta + n_docs_f * ss[[k, w]];
580                        lambda[[k, w]] = (1.0 - rho) * lambda[[k, w]] + rho * target;
581                    }
582                }
583            }
584        }
585    }
586}
587
588/// Compute E[log beta] from lambda (the variational parameters for topic-word).
589fn compute_e_log_beta(lambda: &Array2<f64>, n_topics: usize, n_words: usize) -> Array2<f64> {
590    let mut e_log_beta = Array2::<f64>::zeros((n_topics, n_words));
591    for k in 0..n_topics {
592        let row_sum: f64 = (0..n_words).map(|w| lambda[[k, w]]).sum();
593        let dig_sum = digamma(row_sum);
594        for w in 0..n_words {
595            e_log_beta[[k, w]] = digamma(lambda[[k, w]]) - dig_sum;
596        }
597    }
598    e_log_beta
599}
600
601impl Transform<Array2<f64>> for FittedLatentDirichletAllocation {
602    type Output = Array2<f64>;
603    type Error = FerroError;
604
605    /// Compute the document-topic distribution for new documents.
606    ///
607    /// Returns an array of shape `(n_docs, n_topics)` where each row sums
608    /// approximately to 1.
609    ///
610    /// # Errors
611    ///
612    /// - [`FerroError::ShapeMismatch`] if the number of words does not match
613    ///   the vocabulary size from fitting.
614    /// - [`FerroError::InvalidParameter`] if any entry is negative.
615    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
616        let n_words = self.components_.ncols();
617        if x.ncols() != n_words {
618            return Err(FerroError::ShapeMismatch {
619                expected: vec![x.nrows(), n_words],
620                actual: vec![x.nrows(), x.ncols()],
621                context: "FittedLatentDirichletAllocation::transform".into(),
622            });
623        }
624        for &val in x.iter() {
625            if val < 0.0 {
626                return Err(FerroError::InvalidParameter {
627                    name: "X".into(),
628                    reason: "LDA requires non-negative entries".into(),
629                });
630            }
631        }
632
633        let n_docs = x.nrows();
634        let n_topics = self.components_.nrows();
635        let e_log_beta = compute_e_log_beta(&self.components_, n_topics, n_words);
636
637        let mut result = Array2::<f64>::zeros((n_docs, n_topics));
638        for d in 0..n_docs {
639            let doc: Vec<f64> = (0..n_words).map(|w| x[[d, w]]).collect();
640            let gamma = e_step_doc(&doc, &e_log_beta, self.alpha_, self.max_doc_update_iter_);
641
642            // Normalise gamma to get document-topic proportions.
643            let gamma_sum: f64 = gamma.iter().sum();
644            if gamma_sum > 1e-16 {
645                for k in 0..n_topics {
646                    result[[d, k]] = gamma[k] / gamma_sum;
647                }
648            } else {
649                // Uniform fallback.
650                let uniform = 1.0 / n_topics as f64;
651                for k in 0..n_topics {
652                    result[[d, k]] = uniform;
653                }
654            }
655        }
656
657        Ok(result)
658    }
659}
660
661// ---------------------------------------------------------------------------
662// Tests
663// ---------------------------------------------------------------------------
664
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Simple two-topic corpus: documents 0-2 use only words 0-2 and
    /// documents 3-5 use only words 3-5, so two clean topics exist.
    fn two_topic_corpus() -> Array2<f64> {
        array![
            [5.0, 5.0, 5.0, 0.0, 0.0, 0.0],
            [4.0, 6.0, 3.0, 0.0, 0.0, 0.0],
            [5.0, 4.0, 6.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 5.0, 5.0, 5.0],
            [0.0, 0.0, 0.0, 6.0, 4.0, 3.0],
            [0.0, 0.0, 0.0, 4.0, 6.0, 5.0],
        ]
    }

    #[test]
    fn test_lda_basic_shape() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 6));
    }

    #[test]
    fn test_lda_transform_shape() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        let topics = fitted.transform(&dtm).unwrap();
        assert_eq!(topics.dim(), (6, 2));
    }

    #[test]
    fn test_lda_topic_proportions_sum_to_one() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2)
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        let topics = fitted.transform(&dtm).unwrap();
        // transform normalises each gamma row, so rows must sum to ~1.
        for i in 0..topics.nrows() {
            let sum: f64 = topics.row(i).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_lda_topics_distinguish_groups() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2)
            .with_max_iter(30)
            .with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        let topics = fitted.transform(&dtm).unwrap();

        // First 3 docs should cluster on one topic, last 3 on another.
        // Check that the dominant topic differs between the two groups.
        let first_group_topic: Vec<usize> = (0..3)
            .map(|i| {
                if topics[[i, 0]] > topics[[i, 1]] {
                    0
                } else {
                    1
                }
            })
            .collect();
        let second_group_topic: Vec<usize> = (3..6)
            .map(|i| {
                if topics[[i, 0]] > topics[[i, 1]] {
                    0
                } else {
                    1
                }
            })
            .collect();

        // At least 2 out of 3 in each group should agree on the topic.
        let fg_mode = if first_group_topic.iter().filter(|&&t| t == 0).count() >= 2 {
            0
        } else {
            1
        };
        let sg_mode = if second_group_topic.iter().filter(|&&t| t == 0).count() >= 2 {
            0
        } else {
            1
        };

        assert_ne!(
            fg_mode, sg_mode,
            "the two document groups should be assigned to different topics"
        );
    }

    #[test]
    fn test_lda_online_learning() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2)
            .with_learning_method(LdaLearningMethod::Online)
            .with_max_iter(10)
            .with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 6));
        let topics = fitted.transform(&dtm).unwrap();
        // Each row should sum to ~1.
        for i in 0..topics.nrows() {
            let sum: f64 = topics.row(i).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_lda_components_non_negative() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        for &val in fitted.components().iter() {
            assert!(val >= 0.0, "component should be non-negative, got {val}");
        }
    }

    #[test]
    fn test_lda_transform_shape_mismatch() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        let bad = array![[1.0, 2.0, 3.0]]; // 3 words instead of 6
        assert!(fitted.transform(&bad).is_err());
    }

    #[test]
    fn test_lda_transform_negative_rejected() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        let bad = array![[1.0, -1.0, 0.0, 0.0, 0.0, 0.0]];
        assert!(fitted.transform(&bad).is_err());
    }

    #[test]
    fn test_lda_invalid_n_components_zero() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(0);
        assert!(lda.fit(&dtm, &()).is_err());
    }

    #[test]
    fn test_lda_negative_input_rejected() {
        let dtm = array![[1.0, -1.0], [2.0, 3.0]];
        let lda = LatentDirichletAllocation::new(1);
        assert!(lda.fit(&dtm, &()).is_err());
    }

    #[test]
    fn test_lda_empty_corpus() {
        // Zero documents must be rejected.
        let dtm = Array2::<f64>::zeros((0, 5));
        let lda = LatentDirichletAllocation::new(2);
        assert!(lda.fit(&dtm, &()).is_err());
    }

    #[test]
    fn test_lda_zero_words() {
        // Zero vocabulary must be rejected.
        let dtm = Array2::<f64>::zeros((5, 0));
        let lda = LatentDirichletAllocation::new(2);
        assert!(lda.fit(&dtm, &()).is_err());
    }

    #[test]
    fn test_lda_getters() {
        let lda = LatentDirichletAllocation::new(5)
            .with_max_iter(20)
            .with_learning_method(LdaLearningMethod::Online)
            .with_learning_offset(15.0)
            .with_learning_decay(0.5)
            .with_doc_topic_prior(0.1)
            .with_topic_word_prior(0.01)
            .with_random_state(99);
        assert_eq!(lda.n_components(), 5);
        assert_eq!(lda.max_iter(), 20);
        assert_eq!(lda.learning_method(), LdaLearningMethod::Online);
        assert!((lda.learning_offset() - 15.0).abs() < 1e-10);
        assert!((lda.learning_decay() - 0.5).abs() < 1e-10);
        assert_eq!(lda.doc_topic_prior(), Some(0.1));
        assert_eq!(lda.topic_word_prior(), Some(0.01));
        assert_eq!(lda.random_state(), Some(99));
    }

    #[test]
    fn test_lda_fitted_accessors() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2)
            .with_doc_topic_prior(0.5)
            .with_topic_word_prior(0.1)
            .with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        assert!((fitted.alpha() - 0.5).abs() < 1e-10);
        assert!((fitted.beta() - 0.1).abs() < 1e-10);
        assert!(fitted.n_iter() > 0);
    }

    #[test]
    fn test_lda_single_topic() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(1).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        let topics = fitted.transform(&dtm).unwrap();
        assert_eq!(topics.ncols(), 1);
        // With 1 topic, all documents should have proportion ~1.
        for i in 0..topics.nrows() {
            assert_abs_diff_eq!(topics[[i, 0]], 1.0, epsilon = 1e-3);
        }
    }

    #[test]
    fn test_digamma_basic() {
        // digamma(1) = -gamma (Euler-Mascheroni constant) ~ -0.5772
        let val = digamma(1.0);
        assert!((val - (-0.5772156649)).abs() < 1e-4, "digamma(1) = {val}");
    }

    #[test]
    fn test_digamma_large() {
        // digamma(10) ~ 2.2517525890
        let val = digamma(10.0);
        assert!((val - 2.2517525890).abs() < 1e-4, "digamma(10) = {val}");
    }
}