1use ferrolearn_core::error::FerroError;
39use ferrolearn_core::traits::{Fit, Transform};
40use ndarray::Array2;
41use rand::SeedableRng;
42use rand_distr::{Distribution, Uniform};
43use rand_xoshiro::Xoshiro256PlusPlus;
44
/// Strategy used to update the topic-word parameters during fitting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LdaLearningMethod {
    /// Update the topic-word parameters once per outer pass, using statistics
    /// accumulated over the entire corpus (classic batch variational Bayes).
    Batch,
    /// Update the topic-word parameters after every document with a decaying
    /// step size (stochastic/online variational Bayes).
    Online,
}
57
/// Latent Dirichlet Allocation topic model, configured via builder methods.
///
/// Fit it on a document-term count matrix (rows = documents, columns = terms)
/// to estimate per-topic word distributions.
#[derive(Debug, Clone)]
pub struct LatentDirichletAllocation {
    /// Number of topics to extract.
    n_components: usize,
    /// Number of outer passes over the corpus.
    max_iter: usize,
    /// Batch or online variational updates.
    learning_method: LdaLearningMethod,
    /// Offset `tau_0` in the online learning rate `(tau_0 + t)^(-kappa)`.
    learning_offset: f64,
    /// Decay exponent `kappa` in the online learning rate.
    learning_decay: f64,
    /// Dirichlet prior `alpha` on document-topic weights;
    /// `None` defaults to `1 / n_components` at fit time.
    doc_topic_prior: Option<f64>,
    /// Dirichlet prior `eta` on topic-word weights;
    /// `None` defaults to `1 / n_components` at fit time.
    topic_word_prior: Option<f64>,
    /// Cap on per-document E-step iterations.
    max_doc_update_iter: usize,
    /// Seed for the lambda initialization RNG; `None` falls back to seed 0,
    /// so fitting is deterministic either way.
    random_state: Option<u64>,
}
88
89impl LatentDirichletAllocation {
90 #[must_use]
96 pub fn new(n_components: usize) -> Self {
97 Self {
98 n_components,
99 max_iter: 10,
100 learning_method: LdaLearningMethod::Batch,
101 learning_offset: 10.0,
102 learning_decay: 0.7,
103 doc_topic_prior: None,
104 topic_word_prior: None,
105 max_doc_update_iter: 100,
106 random_state: None,
107 }
108 }
109
110 #[must_use]
112 pub fn with_max_iter(mut self, n: usize) -> Self {
113 self.max_iter = n;
114 self
115 }
116
117 #[must_use]
119 pub fn with_learning_method(mut self, m: LdaLearningMethod) -> Self {
120 self.learning_method = m;
121 self
122 }
123
124 #[must_use]
126 pub fn with_learning_offset(mut self, v: f64) -> Self {
127 self.learning_offset = v;
128 self
129 }
130
131 #[must_use]
133 pub fn with_learning_decay(mut self, v: f64) -> Self {
134 self.learning_decay = v;
135 self
136 }
137
138 #[must_use]
140 pub fn with_doc_topic_prior(mut self, v: f64) -> Self {
141 self.doc_topic_prior = Some(v);
142 self
143 }
144
145 #[must_use]
147 pub fn with_topic_word_prior(mut self, v: f64) -> Self {
148 self.topic_word_prior = Some(v);
149 self
150 }
151
152 #[must_use]
154 pub fn with_random_state(mut self, seed: u64) -> Self {
155 self.random_state = Some(seed);
156 self
157 }
158
159 #[must_use]
161 pub fn with_max_doc_update_iter(mut self, n: usize) -> Self {
162 self.max_doc_update_iter = n;
163 self
164 }
165
166 #[must_use]
168 pub fn n_components(&self) -> usize {
169 self.n_components
170 }
171
172 #[must_use]
174 pub fn max_iter(&self) -> usize {
175 self.max_iter
176 }
177
178 #[must_use]
180 pub fn learning_method(&self) -> LdaLearningMethod {
181 self.learning_method
182 }
183
184 #[must_use]
186 pub fn learning_offset(&self) -> f64 {
187 self.learning_offset
188 }
189
190 #[must_use]
192 pub fn learning_decay(&self) -> f64 {
193 self.learning_decay
194 }
195
196 #[must_use]
198 pub fn doc_topic_prior(&self) -> Option<f64> {
199 self.doc_topic_prior
200 }
201
202 #[must_use]
204 pub fn topic_word_prior(&self) -> Option<f64> {
205 self.topic_word_prior
206 }
207
208 #[must_use]
210 pub fn random_state(&self) -> Option<u64> {
211 self.random_state
212 }
213}
214
/// Result of fitting [`LatentDirichletAllocation`]: the variational topic-word
/// parameters plus the hyperparameters needed to transform new documents.
#[derive(Debug, Clone)]
pub struct FittedLatentDirichletAllocation {
    /// Variational topic-word parameters `lambda`, shape `(n_topics, n_words)`.
    components_: Array2<f64>,
    /// Document-topic Dirichlet prior used during fitting.
    alpha_: f64,
    /// Topic-word Dirichlet prior used during fitting.
    beta_: f64,
    /// Number of outer iterations the fit was configured to run.
    n_iter_: usize,
    /// Per-document E-step iteration cap, carried over for `transform`.
    max_doc_update_iter_: usize,
}
239
240impl FittedLatentDirichletAllocation {
241 #[must_use]
246 pub fn components(&self) -> &Array2<f64> {
247 &self.components_
248 }
249
250 #[must_use]
252 pub fn n_iter(&self) -> usize {
253 self.n_iter_
254 }
255
256 #[must_use]
258 pub fn alpha(&self) -> f64 {
259 self.alpha_
260 }
261
262 #[must_use]
264 pub fn beta(&self) -> f64 {
265 self.beta_
266 }
267}
268
/// Digamma function `psi(x)` for `x > 0`; returns NaN for non-positive input.
///
/// Uses the recurrence `psi(x) = psi(x + 1) - 1/x` to shift the argument up to
/// at least 6, then the asymptotic expansion
/// `psi(t) ~ ln t - 1/(2t) - 1/(12 t^2) + 1/(120 t^4) - 1/(252 t^6) + 1/(240 t^8)`.
fn digamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::NAN;
    }
    // Shift the argument upward until the asymptotic series is accurate.
    let mut t = x;
    let mut acc = 0.0;
    while t < 6.0 {
        acc -= 1.0 / t;
        t += 1.0;
    }
    // Leading terms of the expansion.
    acc += t.ln() - 0.5 / t;
    // Even inverse powers, evaluated in Horner form.
    let r = 1.0 / (t * t);
    acc - r * (1.0 / 12.0 - r * (1.0 / 120.0 - r * (1.0 / 252.0 - r * 1.0 / 240.0)))
}
295
/// Runs the variational E-step for a single document.
///
/// Given `doc` (word counts, length `n_words`) and `e_log_beta`
/// (`E[log beta]`, shape `(n_topics, n_words)`), iteratively updates the
/// document's variational Dirichlet parameter `gamma` (length `n_topics`)
/// until the L1 change drops below 1e-3 or `max_iter` iterations have run,
/// then returns `gamma`.
fn e_step_doc(doc: &[f64], e_log_beta: &Array2<f64>, alpha: f64, max_iter: usize) -> Vec<f64> {
    let n_topics = e_log_beta.nrows();
    let n_words = e_log_beta.ncols();

    // Symmetric start: the prior plus the average word mass per topic.
    let mut gamma = vec![alpha + (n_words as f64) / (n_topics as f64); n_topics];

    for _iter in 0..max_iter {
        // E[log theta_k] = digamma(gamma_k) - digamma(sum_k gamma_k).
        let e_log_theta: Vec<f64> = gamma.iter().map(|&g| digamma(g)).collect();
        let gamma_sum_dig = digamma(gamma.iter().sum::<f64>());

        let mut new_gamma = vec![alpha; n_topics];

        for w in 0..n_words {
            // Absent words contribute nothing to gamma.
            if doc[w] < 1e-16 {
                continue;
            }
            // Unnormalized log phi_{wk}; track the max for a stable softmax.
            let mut log_phi = Vec::with_capacity(n_topics);
            let mut max_log = f64::NEG_INFINITY;
            for k in 0..n_topics {
                let v = e_log_theta[k] - gamma_sum_dig + e_log_beta[[k, w]];
                log_phi.push(v);
                if v > max_log {
                    max_log = v;
                }
            }
            // Exponentiate shifted by max_log (log-sum-exp trick) and sum.
            let mut sum_phi = 0.0;
            let mut phi = Vec::with_capacity(n_topics);
            for lp in &log_phi {
                let p = (lp - max_log).exp();
                phi.push(p);
                sum_phi += p;
            }
            // Guard against division by ~0 when all phi underflow.
            if sum_phi < 1e-16 {
                sum_phi = 1e-16;
            }
            // gamma_k <- alpha + sum_w count_w * phi_{wk} (normalized phi).
            for k in 0..n_topics {
                new_gamma[k] += doc[w] * phi[k] / sum_phi;
            }
        }

        // L1 change in gamma decides convergence.
        let mut diff = 0.0;
        for k in 0..n_topics {
            diff += (new_gamma[k] - gamma[k]).abs();
        }
        gamma = new_gamma;
        if diff < 1e-3 {
            break;
        }
    }

    gamma
}
359
impl Fit<Array2<f64>, ()> for LatentDirichletAllocation {
    type Fitted = FittedLatentDirichletAllocation;
    type Error = FerroError;

    /// Fits the topic model to a document-term count matrix `x`
    /// (rows are documents, columns are vocabulary terms; `_y` is unused).
    ///
    /// # Errors
    ///
    /// * [`FerroError::InvalidParameter`] if `n_components == 0`, if `x` has
    ///   zero columns, or if any entry of `x` is negative.
    /// * [`FerroError::InsufficientSamples`] if `x` has zero rows.
    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedLatentDirichletAllocation, FerroError> {
        let (n_docs, n_words) = x.dim();

        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if n_docs == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "LatentDirichletAllocation::fit".into(),
            });
        }
        if n_words == 0 {
            return Err(FerroError::InvalidParameter {
                name: "X".into(),
                reason: "document-term matrix must have at least 1 word".into(),
            });
        }
        for &val in x.iter() {
            if val < 0.0 {
                return Err(FerroError::InvalidParameter {
                    name: "X".into(),
                    reason: "LDA requires non-negative entries in the document-term matrix".into(),
                });
            }
        }

        let n_topics = self.n_components;
        // Unset priors default to the common 1/n_components heuristic.
        let alpha = self.doc_topic_prior.unwrap_or(1.0 / n_topics as f64);
        let beta = self.topic_word_prior.unwrap_or(1.0 / n_topics as f64);
        // A missing seed falls back to 0, so fitting is always deterministic.
        let seed = self.random_state.unwrap_or(0);

        let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed);
        // Initialize lambda with uniform noise in [0.5, 1.5) plus the prior to
        // break symmetry between topics.
        let uniform = Uniform::new(0.5, 1.5).unwrap();
        let mut lambda = Array2::<f64>::zeros((n_topics, n_words));
        for elem in lambda.iter_mut() {
            *elem = uniform.sample(&mut rng) + beta;
        }

        match self.learning_method {
            LdaLearningMethod::Batch => {
                self.fit_batch(x, &mut lambda, alpha, beta, n_docs, n_words, n_topics);
            }
            LdaLearningMethod::Online => {
                self.fit_online(
                    x,
                    &mut lambda,
                    alpha,
                    beta,
                    n_docs,
                    n_words,
                    n_topics,
                    &mut rng,
                );
            }
        }

        Ok(FittedLatentDirichletAllocation {
            components_: lambda,
            alpha_: alpha,
            beta_: beta,
            // Both learning paths run exactly `max_iter` outer passes with no
            // early stopping, so the configured count is the performed count.
            n_iter_: self.max_iter,
            max_doc_update_iter_: self.max_doc_update_iter,
        })
    }
}
448
impl LatentDirichletAllocation {
    /// Batch variational Bayes: each outer pass accumulates expected word-topic
    /// sufficient statistics over the whole corpus, then replaces `lambda`
    /// with `beta + statistics` (the M-step).
    #[allow(clippy::too_many_arguments)]
    fn fit_batch(
        &self,
        x: &Array2<f64>,
        lambda: &mut Array2<f64>,
        alpha: f64,
        beta: f64,
        n_docs: usize,
        n_words: usize,
        n_topics: usize,
    ) {
        for _outer in 0..self.max_iter {
            let e_log_beta = compute_e_log_beta(lambda, n_topics, n_words);

            // ss[[k, w]] accumulates sum over docs of count * phi_{dwk}.
            let mut ss = Array2::<f64>::zeros((n_topics, n_words));

            for d in 0..n_docs {
                let doc: Vec<f64> = (0..n_words).map(|w| x[[d, w]]).collect();
                let gamma = e_step_doc(&doc, &e_log_beta, alpha, self.max_doc_update_iter);

                // Recompute phi from the converged gamma (same stabilized
                // softmax as the E-step) to collect sufficient statistics.
                let e_log_theta: Vec<f64> = gamma.iter().map(|&g| digamma(g)).collect();
                let gamma_sum_dig = digamma(gamma.iter().sum::<f64>());

                for w in 0..n_words {
                    // Absent words contribute nothing.
                    if doc[w] < 1e-16 {
                        continue;
                    }
                    let mut log_phi = Vec::with_capacity(n_topics);
                    let mut max_log = f64::NEG_INFINITY;
                    for k in 0..n_topics {
                        let v = e_log_theta[k] - gamma_sum_dig + e_log_beta[[k, w]];
                        log_phi.push(v);
                        if v > max_log {
                            max_log = v;
                        }
                    }
                    // Log-sum-exp shift keeps the exponentials in range.
                    let mut phi = Vec::with_capacity(n_topics);
                    let mut sum_phi = 0.0;
                    for lp in &log_phi {
                        let p = (lp - max_log).exp();
                        phi.push(p);
                        sum_phi += p;
                    }
                    if sum_phi < 1e-16 {
                        sum_phi = 1e-16;
                    }
                    for k in 0..n_topics {
                        ss[[k, w]] += doc[w] * phi[k] / sum_phi;
                    }
                }
            }

            // M-step: lambda = prior + expected counts.
            for k in 0..n_topics {
                for w in 0..n_words {
                    lambda[[k, w]] = beta + ss[[k, w]];
                }
            }
        }
    }

    /// Online (stochastic) variational Bayes with one-document mini-batches:
    /// after each document, `lambda` is blended toward that document's
    /// corpus-scaled statistics with step size
    /// `rho_t = (learning_offset + t)^(-learning_decay)`.
    #[allow(clippy::too_many_arguments)]
    fn fit_online(
        &self,
        x: &Array2<f64>,
        lambda: &mut Array2<f64>,
        alpha: f64,
        beta: f64,
        n_docs: usize,
        n_words: usize,
        n_topics: usize,
        // NOTE(review): currently unused — presumably reserved for document
        // shuffling between passes; confirm before removing.
        _rng: &mut Xoshiro256PlusPlus,
    ) {
        // t in the Hoffman-style learning rate; counts every document update.
        let mut update_count = 0u64;

        for _outer in 0..self.max_iter {
            for d in 0..n_docs {
                let doc: Vec<f64> = (0..n_words).map(|w| x[[d, w]]).collect();

                // lambda changes every document, so E[log beta] must be
                // recomputed per document (unlike the batch path).
                let e_log_beta = compute_e_log_beta(lambda, n_topics, n_words);
                let gamma = e_step_doc(&doc, &e_log_beta, alpha, self.max_doc_update_iter);

                // Same stabilized softmax as the E-step, to collect this
                // document's sufficient statistics.
                let e_log_theta: Vec<f64> = gamma.iter().map(|&g| digamma(g)).collect();
                let gamma_sum_dig = digamma(gamma.iter().sum::<f64>());

                let mut ss = Array2::<f64>::zeros((n_topics, n_words));
                for w in 0..n_words {
                    if doc[w] < 1e-16 {
                        continue;
                    }
                    let mut log_phi = Vec::with_capacity(n_topics);
                    let mut max_log = f64::NEG_INFINITY;
                    for k in 0..n_topics {
                        let v = e_log_theta[k] - gamma_sum_dig + e_log_beta[[k, w]];
                        log_phi.push(v);
                        if v > max_log {
                            max_log = v;
                        }
                    }
                    let mut phi = Vec::with_capacity(n_topics);
                    let mut sum_phi = 0.0;
                    for lp in &log_phi {
                        let p = (lp - max_log).exp();
                        phi.push(p);
                        sum_phi += p;
                    }
                    if sum_phi < 1e-16 {
                        sum_phi = 1e-16;
                    }
                    for k in 0..n_topics {
                        ss[[k, w]] += doc[w] * phi[k] / sum_phi;
                    }
                }

                update_count += 1;
                // Decaying step size; learning_offset damps early updates.
                let rho = (self.learning_offset + update_count as f64).powf(-self.learning_decay);

                // Scale the one-document statistics up to corpus size, then
                // take a convex step from the old lambda toward the target.
                let n_docs_f = n_docs as f64;
                for k in 0..n_topics {
                    for w in 0..n_words {
                        let target = beta + n_docs_f * ss[[k, w]];
                        lambda[[k, w]] = (1.0 - rho) * lambda[[k, w]] + rho * target;
                    }
                }
            }
        }
    }
}
587
588fn compute_e_log_beta(lambda: &Array2<f64>, n_topics: usize, n_words: usize) -> Array2<f64> {
590 let mut e_log_beta = Array2::<f64>::zeros((n_topics, n_words));
591 for k in 0..n_topics {
592 let row_sum: f64 = (0..n_words).map(|w| lambda[[k, w]]).sum();
593 let dig_sum = digamma(row_sum);
594 for w in 0..n_words {
595 e_log_beta[[k, w]] = digamma(lambda[[k, w]]) - dig_sum;
596 }
597 }
598 e_log_beta
599}
600
impl Transform<Array2<f64>> for FittedLatentDirichletAllocation {
    type Output = Array2<f64>;
    type Error = FerroError;

    /// Infers per-document topic proportions for `x` using the fitted topics.
    ///
    /// Returns an `(n_docs, n_topics)` matrix whose rows are the normalized
    /// variational `gamma` parameters (each row sums to 1).
    ///
    /// # Errors
    ///
    /// * [`FerroError::ShapeMismatch`] if `x` has a different number of
    ///   columns than the training vocabulary.
    /// * [`FerroError::InvalidParameter`] if `x` contains a negative entry.
    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
        let n_words = self.components_.ncols();
        if x.ncols() != n_words {
            return Err(FerroError::ShapeMismatch {
                expected: vec![x.nrows(), n_words],
                actual: vec![x.nrows(), x.ncols()],
                context: "FittedLatentDirichletAllocation::transform".into(),
            });
        }
        for &val in x.iter() {
            if val < 0.0 {
                return Err(FerroError::InvalidParameter {
                    name: "X".into(),
                    reason: "LDA requires non-negative entries".into(),
                });
            }
        }

        let n_docs = x.nrows();
        let n_topics = self.components_.nrows();
        let e_log_beta = compute_e_log_beta(&self.components_, n_topics, n_words);

        let mut result = Array2::<f64>::zeros((n_docs, n_topics));
        for d in 0..n_docs {
            let doc: Vec<f64> = (0..n_words).map(|w| x[[d, w]]).collect();
            let gamma = e_step_doc(&doc, &e_log_beta, self.alpha_, self.max_doc_update_iter_);

            // Normalize gamma into a probability distribution over topics.
            let gamma_sum: f64 = gamma.iter().sum();
            if gamma_sum > 1e-16 {
                for k in 0..n_topics {
                    result[[d, k]] = gamma[k] / gamma_sum;
                }
            } else {
                // Degenerate gamma (all ~0): fall back to a uniform row.
                let uniform = 1.0 / n_topics as f64;
                for k in 0..n_topics {
                    result[[d, k]] = uniform;
                }
            }
        }

        Ok(result)
    }
}
660
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Six documents over a six-word vocabulary: the first three documents use
    /// only words 0-2, the last three only words 3-5, giving two clear topics.
    fn two_topic_corpus() -> Array2<f64> {
        array![
            [5.0, 5.0, 5.0, 0.0, 0.0, 0.0],
            [4.0, 6.0, 3.0, 0.0, 0.0, 0.0],
            [5.0, 4.0, 6.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 5.0, 5.0, 5.0],
            [0.0, 0.0, 0.0, 6.0, 4.0, 3.0],
            [0.0, 0.0, 0.0, 4.0, 6.0, 5.0],
        ]
    }

    #[test]
    fn test_lda_basic_shape() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 6));
    }

    #[test]
    fn test_lda_transform_shape() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        let topics = fitted.transform(&dtm).unwrap();
        assert_eq!(topics.dim(), (6, 2));
    }

    #[test]
    fn test_lda_topic_proportions_sum_to_one() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2)
            .with_max_iter(20)
            .with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        let topics = fitted.transform(&dtm).unwrap();
        for i in 0..topics.nrows() {
            let sum: f64 = topics.row(i).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_lda_topics_distinguish_groups() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2)
            .with_max_iter(30)
            .with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        let topics = fitted.transform(&dtm).unwrap();

        // Dominant topic per document in each group.
        let first_group_topic: Vec<usize> = (0..3)
            .map(|i| {
                if topics[[i, 0]] > topics[[i, 1]] {
                    0
                } else {
                    1
                }
            })
            .collect();
        let second_group_topic: Vec<usize> = (3..6)
            .map(|i| {
                if topics[[i, 0]] > topics[[i, 1]] {
                    0
                } else {
                    1
                }
            })
            .collect();

        // Majority vote within each group.
        let fg_mode = if first_group_topic.iter().filter(|&&t| t == 0).count() >= 2 {
            0
        } else {
            1
        };
        let sg_mode = if second_group_topic.iter().filter(|&&t| t == 0).count() >= 2 {
            0
        } else {
            1
        };

        assert_ne!(
            fg_mode, sg_mode,
            "the two document groups should be assigned to different topics"
        );
    }

    #[test]
    fn test_lda_online_learning() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2)
            .with_learning_method(LdaLearningMethod::Online)
            .with_max_iter(10)
            .with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 6));
        let topics = fitted.transform(&dtm).unwrap();
        for i in 0..topics.nrows() {
            let sum: f64 = topics.row(i).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_lda_components_non_negative() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        for &val in fitted.components().iter() {
            assert!(val >= 0.0, "component should be non-negative, got {val}");
        }
    }

    #[test]
    fn test_lda_transform_shape_mismatch() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        // Trained vocabulary has 6 words; a 3-column matrix must be rejected.
        let bad = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&bad).is_err());
    }

    #[test]
    fn test_lda_transform_negative_rejected() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        let bad = array![[1.0, -1.0, 0.0, 0.0, 0.0, 0.0]];
        assert!(fitted.transform(&bad).is_err());
    }

    #[test]
    fn test_lda_invalid_n_components_zero() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(0);
        assert!(lda.fit(&dtm, &()).is_err());
    }

    #[test]
    fn test_lda_negative_input_rejected() {
        let dtm = array![[1.0, -1.0], [2.0, 3.0]];
        let lda = LatentDirichletAllocation::new(1);
        assert!(lda.fit(&dtm, &()).is_err());
    }

    #[test]
    fn test_lda_empty_corpus() {
        let dtm = Array2::<f64>::zeros((0, 5));
        let lda = LatentDirichletAllocation::new(2);
        assert!(lda.fit(&dtm, &()).is_err());
    }

    #[test]
    fn test_lda_zero_words() {
        let dtm = Array2::<f64>::zeros((5, 0));
        let lda = LatentDirichletAllocation::new(2);
        assert!(lda.fit(&dtm, &()).is_err());
    }

    #[test]
    fn test_lda_getters() {
        let lda = LatentDirichletAllocation::new(5)
            .with_max_iter(20)
            .with_learning_method(LdaLearningMethod::Online)
            .with_learning_offset(15.0)
            .with_learning_decay(0.5)
            .with_doc_topic_prior(0.1)
            .with_topic_word_prior(0.01)
            .with_random_state(99);
        assert_eq!(lda.n_components(), 5);
        assert_eq!(lda.max_iter(), 20);
        assert_eq!(lda.learning_method(), LdaLearningMethod::Online);
        assert!((lda.learning_offset() - 15.0).abs() < 1e-10);
        assert!((lda.learning_decay() - 0.5).abs() < 1e-10);
        assert_eq!(lda.doc_topic_prior(), Some(0.1));
        assert_eq!(lda.topic_word_prior(), Some(0.01));
        assert_eq!(lda.random_state(), Some(99));
    }

    #[test]
    fn test_lda_fitted_accessors() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(2)
            .with_doc_topic_prior(0.5)
            .with_topic_word_prior(0.1)
            .with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        assert!((fitted.alpha() - 0.5).abs() < 1e-10);
        assert!((fitted.beta() - 0.1).abs() < 1e-10);
        assert!(fitted.n_iter() > 0);
    }

    #[test]
    fn test_lda_single_topic() {
        let dtm = two_topic_corpus();
        let lda = LatentDirichletAllocation::new(1).with_random_state(42);
        let fitted = lda.fit(&dtm, &()).unwrap();
        let topics = fitted.transform(&dtm).unwrap();
        assert_eq!(topics.ncols(), 1);
        // With a single topic every document must be fully assigned to it.
        for i in 0..topics.nrows() {
            assert_abs_diff_eq!(topics[[i, 0]], 1.0, epsilon = 1e-3);
        }
    }

    #[test]
    fn test_digamma_basic() {
        // digamma(1) = -Euler-Mascheroni constant.
        let val = digamma(1.0);
        assert!((val - (-0.5772156649)).abs() < 1e-4, "digamma(1) = {val}");
    }

    #[test]
    fn test_digamma_large() {
        // digamma(10) = H_9 - gamma.
        let val = digamma(10.0);
        assert!((val - 2.2517525890).abs() < 1e-4, "digamma(10) = {val}");
    }
}