Skip to main content

scirs2_text/
topic_model.rs

1//! # Advanced Topic Modeling
2//!
3//! This module provides advanced topic modeling algorithms distinct from the
4//! variational LDA in [`topic_modeling`](crate::topic_modeling):
5//!
6//! - **LDA via collapsed Gibbs sampling**: The standard Bayesian approach
7//! - **Non-negative Matrix Factorization (NMF)**: A linear-algebraic topic model
8//! - **Topic coherence scoring**: C_v and UMass metrics
9//! - **Topic-document and topic-word distributions**
10//! - **Automatic topic number selection via coherence**
11//!
12//! ## Example
13//!
14//! ```rust
15//! use scirs2_text::topic_model::{GibbsLda, GibbsLdaConfig, NmfTopicModel};
16//!
17//! let docs = vec![
18//!     vec!["machine", "learning", "algorithm", "data"],
19//!     vec!["deep", "learning", "neural", "network"],
20//!     vec!["natural", "language", "processing", "text"],
21//!     vec!["cat", "dog", "pet", "animal"],
22//!     vec!["pet", "care", "food", "animal"],
23//! ];
24//!
25//! // LDA via Gibbs sampling
26//! let config = GibbsLdaConfig {
27//!     n_topics: 2,
28//!     alpha: 0.1,
29//!     beta: 0.01,
30//!     n_iterations: 100,
31//!     seed: Some(42),
32//!     ..Default::default()
33//! };
34//!
35//! let docs_str: Vec<Vec<&str>> = docs.iter().map(|d| d.iter().map(|s| *s).collect()).collect();
36//! let mut lda = GibbsLda::new(config);
37//! lda.fit(&docs_str).unwrap();
38//!
39//! let topics = lda.top_words(5);
40//! assert_eq!(topics.len(), 2);
41//! ```
42
43use crate::error::{Result, TextError};
44use scirs2_core::ndarray::{Array1, Array2};
45use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
46use std::collections::{HashMap, HashSet};
47
48// ---------------------------------------------------------------------------
49// Gibbs LDA
50// ---------------------------------------------------------------------------
51
/// Hyper-parameters for [`GibbsLda`], an LDA model fitted by collapsed Gibbs
/// sampling.
///
/// The usual pattern is to override a few fields on top of the defaults:
/// `GibbsLdaConfig { n_topics: 5, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct GibbsLdaConfig {
    /// Number of latent topics; `GibbsLda::fit` rejects zero.
    pub n_topics: usize,
    /// Symmetric Dirichlet prior on each document's topic mixture.
    pub alpha: f64,
    /// Symmetric Dirichlet prior on each topic's word distribution.
    pub beta: f64,
    /// Number of full Gibbs sweeps over the corpus.
    pub n_iterations: usize,
    /// Intended number of initial sweeps to discard as burn-in.
    ///
    /// Note: `GibbsLda::fit` does not currently read this field; the model
    /// keeps the sampler state reached after the final sweep.
    pub burn_in: usize,
    /// Seed for the sampler's RNG. `None` makes `GibbsLda::fit` fall back to
    /// a fixed seed, so fitting is deterministic either way.
    pub seed: Option<u64>,
}

impl Default for GibbsLdaConfig {
    /// 10 topics, `alpha = 0.1`, `beta = 0.01`, 500 sweeps, 50 burn-in
    /// sweeps and no explicit seed.
    fn default() -> Self {
        let n_iterations = 500;
        let burn_in = 50;
        GibbsLdaConfig {
            seed: None,
            burn_in,
            n_iterations,
            beta: 0.01,
            alpha: 0.1,
            n_topics: 10,
        }
    }
}
81
82/// LDA model using collapsed Gibbs sampling (Griffiths & Steyvers, 2004).
83///
84/// In collapsed Gibbs sampling, we integrate out the topic-word and
85/// document-topic distributions and only sample the topic assignments.
86#[derive(Debug)]
87pub struct GibbsLda {
88    config: GibbsLdaConfig,
89    /// Vocabulary: word -> index
90    vocab: HashMap<String, usize>,
91    /// Reverse vocabulary: index -> word
92    rev_vocab: Vec<String>,
93    /// Topic assignment for each word occurrence: topic_assignments[doc][word_pos]
94    topic_assignments: Vec<Vec<usize>>,
95    /// Count of topic k in document d: n_dk[d][k]
96    n_dk: Vec<Vec<usize>>,
97    /// Count of word w assigned to topic k: n_kw[k][w]
98    n_kw: Vec<Vec<usize>>,
99    /// Total words assigned to topic k: n_k[k]
100    n_k: Vec<usize>,
101    /// Number of words in each document
102    doc_lengths: Vec<usize>,
103    /// Document word indices: docs[d][pos] = word_index
104    doc_words: Vec<Vec<usize>>,
105    /// Whether the model has been fitted
106    fitted: bool,
107}
108
109impl GibbsLda {
110    /// Create a new Gibbs LDA model.
111    pub fn new(config: GibbsLdaConfig) -> Self {
112        Self {
113            config,
114            vocab: HashMap::new(),
115            rev_vocab: Vec::new(),
116            topic_assignments: Vec::new(),
117            n_dk: Vec::new(),
118            n_kw: Vec::new(),
119            n_k: Vec::new(),
120            doc_lengths: Vec::new(),
121            doc_words: Vec::new(),
122            fitted: false,
123        }
124    }
125
126    /// Fit the LDA model on a tokenized corpus.
127    pub fn fit(&mut self, documents: &[Vec<&str>]) -> Result<()> {
128        if documents.is_empty() {
129            return Err(TextError::InvalidInput(
130                "Cannot fit LDA on empty corpus".to_string(),
131            ));
132        }
133
134        let n_topics = self.config.n_topics;
135        if n_topics == 0 {
136            return Err(TextError::InvalidInput(
137                "Number of topics must be > 0".to_string(),
138            ));
139        }
140
141        // Build vocabulary
142        self.vocab.clear();
143        self.rev_vocab.clear();
144        let mut word_set: HashSet<String> = HashSet::new();
145        for doc in documents {
146            for &word in doc {
147                word_set.insert(word.to_string());
148            }
149        }
150        let mut sorted_words: Vec<String> = word_set.into_iter().collect();
151        sorted_words.sort();
152        for (idx, word) in sorted_words.iter().enumerate() {
153            self.vocab.insert(word.clone(), idx);
154        }
155        self.rev_vocab = sorted_words;
156        let n_vocab = self.rev_vocab.len();
157
158        if n_vocab == 0 {
159            return Err(TextError::InvalidInput(
160                "Empty vocabulary after tokenization".to_string(),
161            ));
162        }
163
164        let n_docs = documents.len();
165
166        // Convert documents to word indices
167        self.doc_words = documents
168            .iter()
169            .map(|doc| {
170                doc.iter()
171                    .filter_map(|w| self.vocab.get(*w).copied())
172                    .collect()
173            })
174            .collect();
175
176        self.doc_lengths = self.doc_words.iter().map(|d| d.len()).collect();
177
178        // Initialize counts
179        self.n_dk = vec![vec![0; n_topics]; n_docs];
180        self.n_kw = vec![vec![0; n_vocab]; n_topics];
181        self.n_k = vec![0; n_topics];
182        self.topic_assignments = Vec::with_capacity(n_docs);
183
184        // Random initialization of topic assignments
185        let mut rng = match self.config.seed {
186            Some(seed) => StdRng::seed_from_u64(seed),
187            None => StdRng::seed_from_u64(42),
188        };
189
190        for d in 0..n_docs {
191            let mut doc_topics = Vec::with_capacity(self.doc_words[d].len());
192            for &w in &self.doc_words[d] {
193                let k = (rng.random::<f64>() * n_topics as f64) as usize % n_topics;
194                doc_topics.push(k);
195                self.n_dk[d][k] += 1;
196                self.n_kw[k][w] += 1;
197                self.n_k[k] += 1;
198            }
199            self.topic_assignments.push(doc_topics);
200        }
201
202        // Gibbs sampling iterations
203        let alpha = self.config.alpha;
204        let beta = self.config.beta;
205        let beta_sum = beta * n_vocab as f64;
206
207        for _iter in 0..self.config.n_iterations {
208            for d in 0..n_docs {
209                let n_words_d = self.doc_words[d].len();
210                for i in 0..n_words_d {
211                    let w = self.doc_words[d][i];
212                    let old_k = self.topic_assignments[d][i];
213
214                    // Decrement counts
215                    self.n_dk[d][old_k] -= 1;
216                    self.n_kw[old_k][w] -= 1;
217                    self.n_k[old_k] -= 1;
218
219                    // Compute conditional distribution p(z_i = k | ...)
220                    let mut probs = vec![0.0f64; n_topics];
221                    for k in 0..n_topics {
222                        probs[k] = (self.n_dk[d][k] as f64 + alpha)
223                            * (self.n_kw[k][w] as f64 + beta)
224                            / (self.n_k[k] as f64 + beta_sum);
225                    }
226
227                    // Sample new topic
228                    let total: f64 = probs.iter().sum();
229                    if total < 1e-15 {
230                        // Fallback: uniform
231                        let new_k = (rng.random::<f64>() * n_topics as f64) as usize % n_topics;
232                        self.topic_assignments[d][i] = new_k;
233                        self.n_dk[d][new_k] += 1;
234                        self.n_kw[new_k][w] += 1;
235                        self.n_k[new_k] += 1;
236                        continue;
237                    }
238
239                    let threshold = rng.random::<f64>() * total;
240                    let mut cumsum = 0.0;
241                    let mut new_k = n_topics - 1;
242                    for k in 0..n_topics {
243                        cumsum += probs[k];
244                        if cumsum >= threshold {
245                            new_k = k;
246                            break;
247                        }
248                    }
249
250                    // Update counts
251                    self.topic_assignments[d][i] = new_k;
252                    self.n_dk[d][new_k] += 1;
253                    self.n_kw[new_k][w] += 1;
254                    self.n_k[new_k] += 1;
255                }
256            }
257        }
258
259        self.fitted = true;
260        Ok(())
261    }
262
263    /// Get the topic-word distribution for topic k: P(w | k).
264    pub fn topic_word_distribution(&self, topic: usize) -> Result<Array1<f64>> {
265        if !self.fitted {
266            return Err(TextError::ModelNotFitted("LDA not fitted".to_string()));
267        }
268        if topic >= self.config.n_topics {
269            return Err(TextError::InvalidInput(format!(
270                "Topic {} out of range ({})",
271                topic, self.config.n_topics
272            )));
273        }
274
275        let n_vocab = self.rev_vocab.len();
276        let beta = self.config.beta;
277        let beta_sum = beta * n_vocab as f64;
278        let total = self.n_k[topic] as f64 + beta_sum;
279
280        let mut dist = Array1::<f64>::zeros(n_vocab);
281        for w in 0..n_vocab {
282            dist[w] = (self.n_kw[topic][w] as f64 + beta) / total;
283        }
284        Ok(dist)
285    }
286
287    /// Get the document-topic distribution for document d: P(k | d).
288    pub fn doc_topic_distribution(&self, doc: usize) -> Result<Array1<f64>> {
289        if !self.fitted {
290            return Err(TextError::ModelNotFitted("LDA not fitted".to_string()));
291        }
292        if doc >= self.n_dk.len() {
293            return Err(TextError::InvalidInput(format!(
294                "Document {} out of range ({})",
295                doc,
296                self.n_dk.len()
297            )));
298        }
299
300        let n_topics = self.config.n_topics;
301        let alpha = self.config.alpha;
302        let total = self.doc_lengths[doc] as f64 + alpha * n_topics as f64;
303
304        let mut dist = Array1::<f64>::zeros(n_topics);
305        if total < 1e-15 {
306            // Empty document: uniform
307            let uniform = 1.0 / n_topics as f64;
308            for k in 0..n_topics {
309                dist[k] = uniform;
310            }
311        } else {
312            for k in 0..n_topics {
313                dist[k] = (self.n_dk[doc][k] as f64 + alpha) / total;
314            }
315        }
316        Ok(dist)
317    }
318
319    /// Get the full document-topic matrix.
320    pub fn doc_topic_matrix(&self) -> Result<Array2<f64>> {
321        if !self.fitted {
322            return Err(TextError::ModelNotFitted("LDA not fitted".to_string()));
323        }
324
325        let n_docs = self.n_dk.len();
326        let n_topics = self.config.n_topics;
327        let mut matrix = Array2::<f64>::zeros((n_docs, n_topics));
328        for d in 0..n_docs {
329            let dist = self.doc_topic_distribution(d)?;
330            for k in 0..n_topics {
331                matrix[[d, k]] = dist[k];
332            }
333        }
334        Ok(matrix)
335    }
336
337    /// Get top words for each topic.
338    pub fn top_words(&self, n_words: usize) -> Vec<Vec<(String, f64)>> {
339        let n_topics = self.config.n_topics;
340        let mut result = Vec::with_capacity(n_topics);
341
342        for k in 0..n_topics {
343            let dist = match self.topic_word_distribution(k) {
344                Ok(d) => d,
345                Err(_) => {
346                    result.push(Vec::new());
347                    continue;
348                }
349            };
350
351            let mut word_probs: Vec<(usize, f64)> =
352                dist.iter().enumerate().map(|(i, &p)| (i, p)).collect();
353            word_probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
354
355            let top: Vec<(String, f64)> = word_probs
356                .iter()
357                .take(n_words.min(self.rev_vocab.len()))
358                .map(|(idx, prob)| (self.rev_vocab[*idx].clone(), *prob))
359                .collect();
360
361            result.push(top);
362        }
363        result
364    }
365
366    /// Get the vocabulary.
367    pub fn vocabulary(&self) -> &HashMap<String, usize> {
368        &self.vocab
369    }
370
371    /// Get the number of topics.
372    pub fn n_topics(&self) -> usize {
373        self.config.n_topics
374    }
375
376    /// Check if the model is fitted.
377    pub fn is_fitted(&self) -> bool {
378        self.fitted
379    }
380}
381
382// ---------------------------------------------------------------------------
383// Non-negative Matrix Factorization (NMF) Topic Model
384// ---------------------------------------------------------------------------
385
/// Hyper-parameters for [`NmfTopicModel`].
#[derive(Debug, Clone)]
pub struct NmfConfig {
    /// Number of components (topics), i.e. the inner dimension of `W * H`.
    pub n_topics: usize,
    /// Upper bound on the number of multiplicative-update iterations.
    pub max_iter: usize,
    /// Stop once the absolute change in reconstruction error between two
    /// consecutive iterations falls below this value.
    pub tolerance: f64,
    /// Seed for the random initialisation of `W` and `H`.
    pub seed: u64,
}

impl Default for NmfConfig {
    /// 10 topics, at most 200 iterations, tolerance `1e-4`, seed 42.
    fn default() -> Self {
        let (n_topics, max_iter) = (10, 200);
        NmfConfig {
            seed: 42,
            tolerance: 1e-4,
            max_iter,
            n_topics,
        }
    }
}
409
410/// NMF-based topic model.
411///
412/// Factorizes a term-document matrix V ~ W * H where:
413/// - W (n_docs x n_topics): document-topic weights
414/// - H (n_topics x n_terms): topic-term weights
415///
416/// Uses multiplicative update rules (Lee & Seung, 2001).
417#[derive(Debug)]
418pub struct NmfTopicModel {
419    config: NmfConfig,
420    /// Document-topic matrix W.
421    w: Option<Array2<f64>>,
422    /// Topic-term matrix H.
423    h: Option<Array2<f64>>,
424    /// Vocabulary.
425    vocab: Vec<String>,
426    /// Reconstruction error history.
427    error_history: Vec<f64>,
428    /// Whether fitted.
429    fitted: bool,
430}
431
432impl NmfTopicModel {
433    /// Create a new NMF topic model.
434    pub fn new(config: NmfConfig) -> Self {
435        Self {
436            config,
437            w: None,
438            h: None,
439            vocab: Vec::new(),
440            error_history: Vec::new(),
441            fitted: false,
442        }
443    }
444
445    /// Fit NMF on a non-negative document-term matrix.
446    ///
447    /// `matrix` is (n_docs, n_terms), `vocabulary` maps index -> word.
448    pub fn fit(&mut self, matrix: &Array2<f64>, vocabulary: &[String]) -> Result<()> {
449        let (n_docs, n_terms) = matrix.dim();
450        let n_topics = self.config.n_topics;
451
452        if n_docs == 0 || n_terms == 0 {
453            return Err(TextError::InvalidInput(
454                "Cannot fit NMF on empty matrix".to_string(),
455            ));
456        }
457        if n_topics > n_docs || n_topics > n_terms {
458            return Err(TextError::InvalidInput(format!(
459                "n_topics ({}) must not exceed matrix dimensions ({}, {})",
460                n_topics, n_docs, n_terms
461            )));
462        }
463
464        self.vocab = vocabulary.to_vec();
465
466        // Initialize W and H with small random positive values
467        let mut rng = StdRng::seed_from_u64(self.config.seed);
468        let mut w = Array2::<f64>::zeros((n_docs, n_topics));
469        let mut h = Array2::<f64>::zeros((n_topics, n_terms));
470
471        let eps = 1e-10;
472        for elem in w.iter_mut() {
473            *elem = rng.random::<f64>() * 0.1 + eps;
474        }
475        for elem in h.iter_mut() {
476            *elem = rng.random::<f64>() * 0.1 + eps;
477        }
478
479        self.error_history.clear();
480
481        // Multiplicative update rules
482        for _iter in 0..self.config.max_iter {
483            // Update H: H <- H * (W^T * V) / (W^T * W * H)
484            let wt_v = mat_mul_ata_b(&w, matrix);
485            let wt_w = mat_mul_ata_b(&w, &w);
486            let wt_w_h = mat_mul_ab(&wt_w, &h);
487
488            for i in 0..n_topics {
489                for j in 0..n_terms {
490                    let denom = wt_w_h[[i, j]] + eps;
491                    h[[i, j]] *= wt_v[[i, j]] / denom;
492                    if h[[i, j]] < eps {
493                        h[[i, j]] = eps;
494                    }
495                }
496            }
497
498            // Update W: W <- W * (V * H^T) / (W * H * H^T)
499            let v_ht = mat_mul_abt(matrix, &h);
500            let w_h = mat_mul_ab(&w, &h);
501            let w_h_ht = mat_mul_abt(&w_h, &h);
502
503            for i in 0..n_docs {
504                for j in 0..n_topics {
505                    let denom = w_h_ht[[i, j]] + eps;
506                    w[[i, j]] *= v_ht[[i, j]] / denom;
507                    if w[[i, j]] < eps {
508                        w[[i, j]] = eps;
509                    }
510                }
511            }
512
513            // Compute reconstruction error
514            let wh = mat_mul_ab(&w, &h);
515            let mut error = 0.0;
516            for i in 0..n_docs {
517                for j in 0..n_terms {
518                    let diff = matrix[[i, j]] - wh[[i, j]];
519                    error += diff * diff;
520                }
521            }
522            error = error.sqrt();
523            self.error_history.push(error);
524
525            // Check convergence
526            if self.error_history.len() >= 2 {
527                let prev = self.error_history[self.error_history.len() - 2];
528                if (prev - error).abs() < self.config.tolerance {
529                    break;
530                }
531            }
532        }
533
534        self.w = Some(w);
535        self.h = Some(h);
536        self.fitted = true;
537        Ok(())
538    }
539
540    /// Get the document-topic matrix W.
541    pub fn doc_topic_matrix(&self) -> Result<&Array2<f64>> {
542        self.w
543            .as_ref()
544            .ok_or_else(|| TextError::ModelNotFitted("NMF not fitted".to_string()))
545    }
546
547    /// Get the topic-term matrix H.
548    pub fn topic_term_matrix(&self) -> Result<&Array2<f64>> {
549        self.h
550            .as_ref()
551            .ok_or_else(|| TextError::ModelNotFitted("NMF not fitted".to_string()))
552    }
553
554    /// Get top words for each topic.
555    pub fn top_words(&self, n_words: usize) -> Result<Vec<Vec<(String, f64)>>> {
556        let h = self.topic_term_matrix()?;
557        let n_topics = h.nrows();
558        let mut result = Vec::with_capacity(n_topics);
559
560        for k in 0..n_topics {
561            let row = h.row(k);
562            let mut word_scores: Vec<(usize, f64)> =
563                row.iter().enumerate().map(|(i, &v)| (i, v)).collect();
564            word_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
565
566            let top: Vec<(String, f64)> = word_scores
567                .iter()
568                .take(n_words.min(self.vocab.len()))
569                .filter_map(|(idx, score)| self.vocab.get(*idx).map(|w| (w.clone(), *score)))
570                .collect();
571            result.push(top);
572        }
573        Ok(result)
574    }
575
576    /// Get the reconstruction error history.
577    pub fn error_history(&self) -> &[f64] {
578        &self.error_history
579    }
580
581    /// Check if fitted.
582    pub fn is_fitted(&self) -> bool {
583        self.fitted
584    }
585}
586
587// ---------------------------------------------------------------------------
588// Topic Coherence Scoring
589// ---------------------------------------------------------------------------
590
/// Scores the coherence of topics given as lists of their top words.
///
/// Offers an NPMI-based score (`cv_coherence`) and the UMass score
/// (`umass_coherence`), both computed from document-level co-occurrence
/// counts over a reference corpus.
#[derive(Debug, Clone)]
pub struct TopicCoherenceScorer {
    /// Co-occurrence window size. Configurable through `with_window_size`,
    /// but neither scoring method currently reads it.
    window_size: usize,
    /// Additive smoothing constant applied to counts before taking logarithms.
    epsilon: f64,
}

impl Default for TopicCoherenceScorer {
    /// A 10-token window and a smoothing epsilon of `1e-12`.
    fn default() -> Self {
        let epsilon = 1e-12;
        let window_size = 10;
        TopicCoherenceScorer {
            window_size,
            epsilon,
        }
    }
}
610
611impl TopicCoherenceScorer {
612    /// Create a new coherence scorer.
613    pub fn new() -> Self {
614        Self::default()
615    }
616
617    /// Set window size.
618    pub fn with_window_size(mut self, size: usize) -> Self {
619        self.window_size = size;
620        self
621    }
622
623    /// Calculate C_v coherence (NPMI-based, Roder et al. 2015).
624    ///
625    /// Higher values indicate more coherent topics.
626    pub fn cv_coherence(
627        &self,
628        topic_words: &[Vec<String>],
629        documents: &[Vec<String>],
630    ) -> Result<f64> {
631        if topic_words.is_empty() || documents.is_empty() {
632            return Err(TextError::InvalidInput(
633                "Topic words and documents must not be empty".to_string(),
634            ));
635        }
636
637        let n_docs = documents.len() as f64;
638
639        // Compute document frequency and co-document frequency
640        let doc_sets: Vec<HashSet<&String>> =
641            documents.iter().map(|doc| doc.iter().collect()).collect();
642
643        let mut topic_scores = Vec::with_capacity(topic_words.len());
644
645        for words in topic_words {
646            if words.len() < 2 {
647                topic_scores.push(0.0);
648                continue;
649            }
650
651            let mut npmi_sum = 0.0;
652            let mut pair_count = 0;
653
654            for i in 0..words.len() {
655                for j in (i + 1)..words.len() {
656                    let wi = &words[i];
657                    let wj = &words[j];
658
659                    let df_i = doc_sets.iter().filter(|s| s.contains(wi)).count() as f64;
660                    let df_j = doc_sets.iter().filter(|s| s.contains(wj)).count() as f64;
661                    let df_ij = doc_sets
662                        .iter()
663                        .filter(|s| s.contains(wi) && s.contains(wj))
664                        .count() as f64;
665
666                    let p_i = (df_i + self.epsilon) / n_docs;
667                    let p_j = (df_j + self.epsilon) / n_docs;
668                    let p_ij = (df_ij + self.epsilon) / n_docs;
669
670                    // NPMI = (log(P(i,j) / (P(i) * P(j)))) / (-log(P(i,j)))
671                    let pmi = (p_ij / (p_i * p_j)).ln();
672                    let neg_log_p_ij = -(p_ij.ln());
673
674                    let npmi = if neg_log_p_ij.abs() > self.epsilon {
675                        pmi / neg_log_p_ij
676                    } else {
677                        0.0
678                    };
679
680                    npmi_sum += npmi;
681                    pair_count += 1;
682                }
683            }
684
685            let score = if pair_count > 0 {
686                npmi_sum / pair_count as f64
687            } else {
688                0.0
689            };
690            topic_scores.push(score);
691        }
692
693        let avg = topic_scores.iter().sum::<f64>() / topic_scores.len() as f64;
694        Ok(avg)
695    }
696
697    /// Calculate UMass coherence (Mimno et al. 2011).
698    ///
699    /// Uses document co-occurrence. Higher (less negative) values are better.
700    pub fn umass_coherence(
701        &self,
702        topic_words: &[Vec<String>],
703        documents: &[Vec<String>],
704    ) -> Result<f64> {
705        if topic_words.is_empty() || documents.is_empty() {
706            return Err(TextError::InvalidInput(
707                "Topic words and documents must not be empty".to_string(),
708            ));
709        }
710
711        let doc_sets: Vec<HashSet<&String>> =
712            documents.iter().map(|doc| doc.iter().collect()).collect();
713
714        let mut topic_scores = Vec::with_capacity(topic_words.len());
715
716        for words in topic_words {
717            if words.len() < 2 {
718                topic_scores.push(0.0);
719                continue;
720            }
721
722            let mut score = 0.0;
723            let mut pair_count = 0;
724
725            for i in 1..words.len() {
726                for j in 0..i {
727                    let wi = &words[i];
728                    let wj = &words[j];
729
730                    let df_j = doc_sets.iter().filter(|s| s.contains(wj)).count() as f64;
731                    let df_ij = doc_sets
732                        .iter()
733                        .filter(|s| s.contains(wi) && s.contains(wj))
734                        .count() as f64;
735
736                    // UMass: log((D(wi, wj) + epsilon) / D(wj))
737                    score += ((df_ij + self.epsilon) / (df_j + self.epsilon)).ln();
738                    pair_count += 1;
739                }
740            }
741
742            let avg_score = if pair_count > 0 {
743                score / pair_count as f64
744            } else {
745                0.0
746            };
747            topic_scores.push(avg_score);
748        }
749
750        let avg = topic_scores.iter().sum::<f64>() / topic_scores.len() as f64;
751        Ok(avg)
752    }
753}
754
755// ---------------------------------------------------------------------------
756// Automatic Topic Number Selection
757// ---------------------------------------------------------------------------
758
759/// Select the optimal number of topics by maximizing coherence.
760///
761/// Fits LDA models with different numbers of topics and returns the one
762/// with the highest coherence score.
763///
764/// # Arguments
765///
766/// * `documents` - Tokenized documents
767/// * `min_topics` - Minimum number of topics to try
768/// * `max_topics` - Maximum number of topics to try
769/// * `n_iterations` - Gibbs sampling iterations per model
770/// * `seed` - Random seed
771///
772/// # Returns
773///
774/// (optimal_n_topics, coherence_scores)
775pub fn select_n_topics(
776    documents: &[Vec<&str>],
777    min_topics: usize,
778    max_topics: usize,
779    n_iterations: usize,
780    seed: u64,
781) -> Result<(usize, Vec<(usize, f64)>)> {
782    if documents.is_empty() {
783        return Err(TextError::InvalidInput(
784            "Cannot select topics on empty corpus".to_string(),
785        ));
786    }
787    if min_topics == 0 || min_topics > max_topics {
788        return Err(TextError::InvalidInput(format!(
789            "Invalid topic range: {} to {}",
790            min_topics, max_topics
791        )));
792    }
793
794    let scorer = TopicCoherenceScorer::new();
795
796    // Convert documents to owned strings for coherence calculation
797    let doc_strings: Vec<Vec<String>> = documents
798        .iter()
799        .map(|doc| doc.iter().map(|w| w.to_string()).collect())
800        .collect();
801
802    let mut scores: Vec<(usize, f64)> = Vec::new();
803    let mut best_k = min_topics;
804    let mut best_score = f64::NEG_INFINITY;
805
806    for k in min_topics..=max_topics {
807        let config = GibbsLdaConfig {
808            n_topics: k,
809            alpha: 50.0 / k as f64,
810            beta: 0.01,
811            n_iterations,
812            burn_in: n_iterations / 5,
813            seed: Some(seed),
814        };
815
816        let mut lda = GibbsLda::new(config);
817        lda.fit(documents)?;
818
819        let top_words = lda.top_words(10);
820        let topic_word_strs: Vec<Vec<String>> = top_words
821            .iter()
822            .map(|tw| tw.iter().map(|(w, _)| w.clone()).collect())
823            .collect();
824
825        let coherence = scorer.cv_coherence(&topic_word_strs, &doc_strings)?;
826        scores.push((k, coherence));
827
828        if coherence > best_score {
829            best_score = coherence;
830            best_k = k;
831        }
832    }
833
834    Ok((best_k, scores))
835}
836
837// ---------------------------------------------------------------------------
838// Matrix helper functions (for NMF)
839// ---------------------------------------------------------------------------
840
841/// A * B
842fn mat_mul_ab(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
843    let (ar, ac) = a.dim();
844    let (_br, bc) = b.dim();
845    let mut result = Array2::<f64>::zeros((ar, bc));
846    for i in 0..ar {
847        for k in 0..ac {
848            let a_ik = a[[i, k]];
849            if a_ik.abs() < 1e-15 {
850                continue;
851            }
852            for j in 0..bc {
853                result[[i, j]] += a_ik * b[[k, j]];
854            }
855        }
856    }
857    result
858}
859
860/// A^T * B
861fn mat_mul_ata_b(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
862    let (ar, ac) = a.dim();
863    let (_br, bc) = b.dim();
864    let mut result = Array2::<f64>::zeros((ac, bc));
865    for k in 0..ar {
866        for i in 0..ac {
867            let a_ki = a[[k, i]];
868            if a_ki.abs() < 1e-15 {
869                continue;
870            }
871            for j in 0..bc {
872                result[[i, j]] += a_ki * b[[k, j]];
873            }
874        }
875    }
876    result
877}
878
879/// A * B^T
880fn mat_mul_abt(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
881    let (ar, ac) = a.dim();
882    let (br, _bc) = b.dim();
883    let mut result = Array2::<f64>::zeros((ar, br));
884    for i in 0..ar {
885        for j in 0..br {
886            let mut sum = 0.0;
887            for k in 0..ac {
888                sum += a[[i, k]] * b[[j, k]];
889            }
890            result[[i, j]] = sum;
891        }
892    }
893    result
894}
895
896// ---------------------------------------------------------------------------
897// Tests
898// ---------------------------------------------------------------------------
899
#[cfg(test)]
mod tests {
    use super::*;

    /// Six short documents covering two rough themes (ML/NLP vs. pets).
    fn sample_docs() -> Vec<Vec<&'static str>> {
        vec![
            vec!["machine", "learning", "algorithm", "data", "model"],
            vec!["deep", "learning", "neural", "network", "training"],
            vec!["natural", "language", "processing", "text", "word"],
            vec!["cat", "dog", "pet", "animal", "food"],
            vec!["pet", "care", "food", "animal", "home"],
            vec!["dog", "cat", "play", "park", "fun"],
        ]
    }

    /// Converts borrowed words into the owned `Vec<String>` form used by the
    /// NMF vocabulary and the coherence scorer inputs.
    fn owned(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    /// Documents x terms (4 x 5) with block structure: documents 0-1 only use
    /// terms 0-1 and documents 2-3 only use terms 2-4.
    fn block_doc_term_matrix() -> Array2<f64> {
        Array2::from_shape_vec(
            (4, 5),
            vec![
                1.0, 2.0, 0.0, 0.0, 0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0, 0.0,
                0.0, 2.0, 1.0, 2.0,
            ],
        )
        .expect("matrix creation failed")
    }

    /// Vocabulary whose entries label the 5 columns of `block_doc_term_matrix`.
    fn block_vocab() -> Vec<String> {
        owned(&["ml", "deep", "cat", "dog", "pet"])
    }

    /// Documents x terms (3 x 4) with overlapping term usage.
    fn small_doc_term_matrix() -> Array2<f64> {
        Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0],
        )
        .expect("matrix creation failed")
    }

    /// Synthetic vocabulary `w0`, `w1`, ..., `w{n-1}`.
    fn numbered_vocab(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("w{}", i)).collect()
    }

    #[test]
    fn test_gibbs_lda_fit() {
        let docs = sample_docs();
        let config = GibbsLdaConfig {
            n_topics: 2,
            n_iterations: 50,
            seed: Some(42),
            ..Default::default()
        };
        let mut lda = GibbsLda::new(config);
        lda.fit(&docs).expect("fit failed");
        assert!(lda.is_fitted());
    }

    #[test]
    fn test_gibbs_lda_top_words() {
        let docs = sample_docs();
        let config = GibbsLdaConfig {
            n_topics: 2,
            n_iterations: 100,
            seed: Some(42),
            ..Default::default()
        };
        let mut lda = GibbsLda::new(config);
        lda.fit(&docs).expect("fit failed");

        let topics = lda.top_words(5);
        assert_eq!(topics.len(), 2);
        for topic in &topics {
            assert_eq!(topic.len(), 5);
            // The reported top-word probabilities must carry positive mass.
            let prob_sum: f64 = topic.iter().map(|(_, p)| p).sum();
            assert!(prob_sum > 0.0);
        }
    }

    #[test]
    fn test_gibbs_lda_doc_topic_distribution() {
        let docs = sample_docs();
        let config = GibbsLdaConfig {
            n_topics: 2,
            n_iterations: 50,
            seed: Some(42),
            ..Default::default()
        };
        let mut lda = GibbsLda::new(config);
        lda.fit(&docs).expect("fit failed");

        let dist = lda.doc_topic_distribution(0).expect("dist failed");
        assert_eq!(dist.len(), 2);
        let sum: f64 = dist.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_gibbs_lda_topic_word_distribution() {
        let docs = sample_docs();
        let config = GibbsLdaConfig {
            n_topics: 2,
            n_iterations: 50,
            seed: Some(42),
            ..Default::default()
        };
        let mut lda = GibbsLda::new(config);
        lda.fit(&docs).expect("fit failed");

        let dist = lda.topic_word_distribution(0).expect("dist failed");
        let sum: f64 = dist.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_gibbs_lda_doc_topic_matrix() {
        let docs = sample_docs();
        let config = GibbsLdaConfig {
            n_topics: 3,
            n_iterations: 50,
            seed: Some(42),
            ..Default::default()
        };
        let mut lda = GibbsLda::new(config);
        lda.fit(&docs).expect("fit failed");

        let matrix = lda.doc_topic_matrix().expect("matrix failed");
        assert_eq!(matrix.dim(), (6, 3));

        // Each row is a per-document topic distribution and should sum to ~1.
        for i in 0..6 {
            let sum: f64 = matrix.row(i).iter().sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_gibbs_lda_empty_corpus() {
        let config = GibbsLdaConfig::default();
        let mut lda = GibbsLda::new(config);
        let result = lda.fit(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_gibbs_lda_not_fitted() {
        let lda = GibbsLda::new(GibbsLdaConfig::default());
        assert!(lda.doc_topic_distribution(0).is_err());
        assert!(lda.topic_word_distribution(0).is_err());
    }

    #[test]
    fn test_nmf_fit() {
        let matrix = block_doc_term_matrix();
        let vocab = block_vocab();

        let config = NmfConfig {
            n_topics: 2,
            max_iter: 100,
            ..Default::default()
        };

        let mut nmf = NmfTopicModel::new(config);
        nmf.fit(&matrix, &vocab).expect("nmf fit failed");
        assert!(nmf.is_fitted());
    }

    #[test]
    fn test_nmf_top_words() {
        let matrix = block_doc_term_matrix();
        let vocab = block_vocab();

        let config = NmfConfig {
            n_topics: 2,
            max_iter: 100,
            ..Default::default()
        };

        let mut nmf = NmfTopicModel::new(config);
        nmf.fit(&matrix, &vocab).expect("nmf fit failed");

        let topics = nmf.top_words(3).expect("top_words failed");
        assert_eq!(topics.len(), 2);
        for topic in &topics {
            assert!(topic.len() <= 3);
        }
    }

    #[test]
    fn test_nmf_convergence() {
        let matrix = small_doc_term_matrix();
        let vocab = numbered_vocab(4);

        let config = NmfConfig {
            n_topics: 2,
            max_iter: 200,
            ..Default::default()
        };

        let mut nmf = NmfTopicModel::new(config);
        nmf.fit(&matrix, &vocab).expect("nmf fit failed");

        let errors = nmf.error_history();
        assert!(!errors.is_empty());
        // The final recorded error should not exceed the first one (small
        // tolerance for floating-point noise). With a single entry the two
        // coincide and the check is trivially satisfied.
        if let (Some(&first), Some(&last)) = (errors.first(), errors.last()) {
            assert!(
                last <= first + 1e-6,
                "NMF error increased: first = {}, last = {}",
                first,
                last
            );
        }
    }

    #[test]
    fn test_nmf_not_fitted() {
        let nmf = NmfTopicModel::new(NmfConfig::default());
        assert!(nmf.doc_topic_matrix().is_err());
        assert!(nmf.topic_term_matrix().is_err());
    }

    #[test]
    fn test_coherence_cv() {
        let topic_words = vec![
            owned(&["machine", "learning", "algorithm"]),
            owned(&["cat", "dog", "pet"]),
        ];

        let documents = vec![
            owned(&["machine", "learning", "algorithm"]),
            owned(&["deep", "learning", "neural"]),
            owned(&["cat", "dog", "pet"]),
            owned(&["cat", "play", "fun"]),
        ];

        let scorer = TopicCoherenceScorer::new();
        let cv = scorer
            .cv_coherence(&topic_words, &documents)
            .expect("cv failed");
        // Should return a finite value
        assert!(cv.is_finite());
    }

    #[test]
    fn test_coherence_umass() {
        let topic_words = vec![owned(&["machine", "learning"]), owned(&["cat", "dog"])];

        let documents = vec![owned(&["machine", "learning"]), owned(&["cat", "dog"])];

        let scorer = TopicCoherenceScorer::new();
        let umass = scorer
            .umass_coherence(&topic_words, &documents)
            .expect("umass failed");
        assert!(umass.is_finite());
    }

    #[test]
    fn test_coherence_empty() {
        let scorer = TopicCoherenceScorer::new();
        assert!(scorer.cv_coherence(&[], &[]).is_err());
        assert!(scorer.umass_coherence(&[], &[]).is_err());
    }

    #[test]
    fn test_select_n_topics() {
        let docs = sample_docs();
        let (best_k, scores) = select_n_topics(&docs, 2, 3, 30, 42).expect("select failed");
        assert!((2..=3).contains(&best_k));
        assert_eq!(scores.len(), 2);
    }

    #[test]
    fn test_select_n_topics_invalid_range() {
        let docs = sample_docs();
        assert!(select_n_topics(&docs, 5, 2, 30, 42).is_err());
    }

    #[test]
    fn test_lda_vocabulary() {
        let docs = sample_docs();
        let config = GibbsLdaConfig {
            n_topics: 2,
            n_iterations: 10,
            seed: Some(42),
            ..Default::default()
        };
        let mut lda = GibbsLda::new(config);
        lda.fit(&docs).expect("fit failed");

        let vocab = lda.vocabulary();
        assert!(vocab.contains_key("machine"));
        assert!(vocab.contains_key("cat"));
    }

    #[test]
    fn test_lda_n_topics() {
        let config = GibbsLdaConfig {
            n_topics: 5,
            ..Default::default()
        };
        let lda = GibbsLda::new(config);
        assert_eq!(lda.n_topics(), 5);
    }

    #[test]
    fn test_coherence_window_size() {
        let scorer = TopicCoherenceScorer::new().with_window_size(5);
        assert_eq!(scorer.window_size, 5);
    }

    #[test]
    fn test_nmf_doc_topic_matrix() {
        let matrix = small_doc_term_matrix();
        let vocab = numbered_vocab(4);
        let config = NmfConfig {
            n_topics: 2,
            max_iter: 50,
            ..Default::default()
        };

        let mut nmf = NmfTopicModel::new(config);
        nmf.fit(&matrix, &vocab).expect("fit failed");

        let dtm = nmf.doc_topic_matrix().expect("dtm failed");
        assert_eq!(dtm.dim(), (3, 2));

        // All values should be non-negative
        for &v in dtm.iter() {
            assert!(v >= 0.0);
        }
    }
}