scirs2_text/topic_modeling.rs

//! # Topic Modeling Module
//!
//! This module provides advanced topic modeling algorithms for discovering
//! hidden thematic structures in document collections, with a focus on
//! Latent Dirichlet Allocation (LDA).
//!
//! ## Overview
//!
//! Topic modeling is an unsupervised machine learning technique that discovers
//! abstract "topics" that occur in a collection of documents. This module implements:
//!
//! - **Latent Dirichlet Allocation (LDA)**: The most popular topic modeling algorithm
//! - **Batch and Online Learning**: Different training strategies for various dataset sizes
//! - **Coherence Scores**: Each [`Topic`] carries an optional coherence score for model evaluation
//! - **Topic Inspection**: Retrieval of ranked top words per topic for presenting results
//!
//! ## Quick Start
//!
//! ```rust
//! use scirs2_text::topic_modeling::{LatentDirichletAllocation, LdaConfig, LdaLearningMethod};
//! use scirs2_text::vectorize::{CountVectorizer, Vectorizer};
//! use std::collections::HashMap;
//!
//! // Sample documents
//! let documents = vec![
//!     "machine learning algorithms are powerful tools",
//!     "natural language processing uses machine learning",
//!     "deep learning is a subset of machine learning",
//!     "cats and dogs are popular pets",
//!     "pet care requires attention and love",
//!     "dogs need regular exercise and training"
//! ];
//!
//! // Vectorize documents
//! let mut vectorizer = CountVectorizer::new(false);
//! let doc_term_matrix = vectorizer.fit_transform(&documents).expect("Operation failed");
//!
//! // Configure LDA
//! let config = LdaConfig {
//!     ntopics: 2,
//!     doc_topic_prior: Some(0.1),    // Alpha parameter
//!     topic_word_prior: Some(0.01),  // Beta parameter
//!     learning_method: LdaLearningMethod::Batch,
//!     maxiter: 100,
//!     mean_change_tol: 1e-4,
//!     random_seed: Some(42),
//!     ..Default::default()
//! };
//!
//! // Train the model
//! let mut lda = LatentDirichletAllocation::new(config);
//! lda.fit(&doc_term_matrix).expect("Operation failed");
//!
//! // Create a placeholder vocabulary mapping for topic display
//! // (in practice, derive this from the vectorizer's fitted vocabulary)
//! let vocab_map: HashMap<usize, String> = (0..1000).map(|i| (i, format!("word_{}", i))).collect();
//!
//! // Get topics (top 10 words per topic)
//! let topics = lda.get_topics(10, &vocab_map).expect("Operation failed");
//! for (i, topic) in topics.iter().enumerate() {
//!     println!("Topic {}: {:?}", i, topic);
//! }
//!
//! // Transform documents to topic space
//! let doc_topics = lda.transform(&doc_term_matrix).expect("Operation failed");
//! println!("Document-topic distribution: {:?}", doc_topics);
//! ```
//!
//! ## Advanced Usage
//!
//! ### Online Learning for Large Datasets
//!
//! ```rust
//! use scirs2_text::topic_modeling::{LdaConfig, LdaLearningMethod, LatentDirichletAllocation};
//!
//! let config = LdaConfig {
//!     ntopics: 10,
//!     learning_method: LdaLearningMethod::Online,
//!     batch_size: 64,                // Mini-batch size
//!     learning_decay: 0.7,           // Learning rate decay
//!     learning_offset: 10.0,         // Learning rate offset
//!     maxiter: 500,
//!     ..Default::default()
//! };
//!
//! let mut lda = LatentDirichletAllocation::new(config);
//! // `fit` still takes the full document-term matrix, but processes it
//! // internally in shuffled mini-batches for memory efficiency
//! ```
//!
//! ### Custom Hyperparameters
//!
//! ```rust
//! use scirs2_text::topic_modeling::LdaConfig;
//!
//! let config = LdaConfig {
//!     ntopics: 20,
//!     doc_topic_prior: Some(50.0 / 20.0),  // Symmetric Dirichlet (50/ntopics heuristic)
//!     topic_word_prior: Some(0.1),         // Sparse topics
//!     maxiter: 1000,                       // More iterations
//!     mean_change_tol: 1e-6,               // Stricter convergence
//!     ..Default::default()
//! };
//! ```
//!
//! ### Model Evaluation
//!
//! ```rust
//! use scirs2_text::topic_modeling::{LatentDirichletAllocation, LdaConfig};
//! use scirs2_text::vectorize::{CountVectorizer, Vectorizer};
//! use std::collections::HashMap;
//!
//! # let documents = vec!["the quick brown fox", "jumped over the lazy dog"];
//! # let mut vectorizer = CountVectorizer::new(false);
//! # let doc_term_matrix = vectorizer.fit_transform(&documents).expect("Operation failed");
//! # let mut lda = LatentDirichletAllocation::new(LdaConfig::default());
//! # lda.fit(&doc_term_matrix).expect("Operation failed");
//! # let vocab_map: HashMap<usize, String> = (0..100).map(|i| (i, format!("word_{}", i))).collect();
//! // Get model information (top 5 words per topic)
//! let topics = lda.get_topics(5, &vocab_map).expect("Operation failed");
//! println!("Number of topics: {}", topics.len());
//!
//! // Get document-topic probabilities
//! let doc_topic_probs = lda.transform(&doc_term_matrix).expect("Operation failed");
//! println!("Document-topic shape: {:?}", doc_topic_probs.shape());
//! ```
//!
//! ## Parameter Tuning Guide
//!
//! ### Number of Topics
//! - **Too few**: Broad, less meaningful topics
//! - **Too many**: Narrow, potentially noisy topics
//! - **Recommendation**: Start with √(number of documents) and tune based on coherence
//!
//! ### Alpha (doc_topic_prior)
//! - **High values (e.g., 1.0)**: Documents contain many topics
//! - **Low values (e.g., 0.1)**: Documents contain few topics
//! - **Default**: 1/ntopics (symmetric), set automatically when `fit` is called
//!
//! ### Beta (topic_word_prior)
//! - **High values (e.g., 1.0)**: Topics contain many words
//! - **Low values (e.g., 0.01)**: Topics are more focused
//! - **Default**: 1/ntopics; set it lower explicitly for sparser topics (see the builder example below)
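//!
//! These hyperparameters can also be set fluently via [`LdaBuilder`]. The values
//! below are illustrative starting points, not recommendations:
//!
//! ```rust
//! use scirs2_text::topic_modeling::{LdaBuilder, LdaLearningMethod};
//!
//! let lda = LdaBuilder::new()
//!     .ntopics(20)
//!     .doc_topic_prior(0.1)    // Alpha: few topics per document
//!     .topic_word_prior(0.01)  // Beta: sparse, focused topics
//!     .learning_method(LdaLearningMethod::Batch)
//!     .maxiter(200)
//!     .random_seed(42)
//!     .build();
//! ```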
//!
//! ## Performance Optimization
//!
//! 1. **Use Online Learning**: For datasets that don't fit in memory
//! 2. **Tune Batch Size**: Balance between speed and convergence stability
//! 3. **Set Tolerance**: Stop early once convergence is reached (see the sketch below)
//! 4. **Monitor Perplexity**: Track model performance during training
//! 5. **Parallel Processing**: Enable for faster vocabulary building
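//!
//! For instance, a configuration for a larger corpus might combine points 1–3
//! (a sketch; the values are assumptions to tune against your own data):
//!
//! ```rust
//! use scirs2_text::topic_modeling::{LdaConfig, LdaLearningMethod};
//!
//! let config = LdaConfig {
//!     ntopics: 50,
//!     learning_method: LdaLearningMethod::Online, // Point 1: mini-batched training
//!     batch_size: 256,          // Point 2: speed vs. update stability
//!     mean_change_tol: 1e-3,    // Point 3: looser tolerance stops training earlier
//!     maxiter: 100,
//!     ..Default::default()
//! };
//! ```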
//!
//! ## Mathematical Background
//!
//! LDA assumes each document is a mixture of topics, and each topic is a distribution over words.
//! The generative process:
//!
//! 1. For each topic k: Draw word distribution φₖ ~ Dirichlet(β)
//! 2. For each document d:
//!    - Draw topic distribution θ_d ~ Dirichlet(α)
//!    - For each word n in document d:
//!      - Draw topic assignment z_{d,n} ~ Multinomial(θ_d)
//!      - Draw word w_{d,n} ~ Multinomial(φ_{z_{d,n}})
//!
//! The goal is to infer the posterior distributions of θ and φ given the observed words.
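//!
//! Exact posterior inference is intractable, so this module uses a simplified
//! mean-field variational scheme. Each document d keeps a variational topic
//! distribution γ_d, updated by iterating
//!
//! - φ_{w,k} ∝ γ_{d,k} · β_{k,w} (word-topic responsibilities)
//! - γ_{d,k} = α + Σ_w n_{d,w} · φ_{w,k}
//!
//! until the mean absolute change in γ_d drops below `mean_change_tol`.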

use crate::error::{Result, TextError};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::{rngs::StdRng, SeedableRng};
use std::collections::HashMap;

/// Learning method for LDA
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LdaLearningMethod {
    /// Batch learning - process all documents at once
    Batch,
    /// Online learning - process documents in mini-batches
    Online,
}

/// Latent Dirichlet Allocation configuration
#[derive(Debug, Clone)]
pub struct LdaConfig {
    /// Number of topics
    pub ntopics: usize,
    /// Prior for document-topic distribution (alpha)
    pub doc_topic_prior: Option<f64>,
    /// Prior for topic-word distribution (beta, sometimes called eta)
    pub topic_word_prior: Option<f64>,
    /// Learning method
    pub learning_method: LdaLearningMethod,
    /// Learning decay for online learning
    pub learning_decay: f64,
    /// Learning offset for online learning
    pub learning_offset: f64,
    /// Maximum iterations
    pub maxiter: usize,
    /// Batch size for online learning
    pub batch_size: usize,
    /// Mean change tolerance for convergence
    pub mean_change_tol: f64,
    /// Maximum iterations for document E-step
    pub max_doc_update_iter: usize,
    /// Random seed
    pub random_seed: Option<u64>,
}

impl Default for LdaConfig {
    fn default() -> Self {
        Self {
            ntopics: 10,
            doc_topic_prior: None,  // Will be set to 1/ntopics
            topic_word_prior: None, // Will be set to 1/ntopics
            learning_method: LdaLearningMethod::Batch,
            learning_decay: 0.7,
            learning_offset: 10.0,
            maxiter: 10,
            batch_size: 128,
            mean_change_tol: 1e-3,
            max_doc_update_iter: 100,
            random_seed: None,
        }
    }
}

/// Topic representation
#[derive(Debug, Clone)]
pub struct Topic {
    /// Topic ID
    pub id: usize,
    /// Top words in the topic with their weights
    pub top_words: Vec<(String, f64)>,
    /// Topic coherence score (if computed)
    pub coherence: Option<f64>,
}

/// Latent Dirichlet Allocation
pub struct LatentDirichletAllocation {
    config: LdaConfig,
    /// Topic-word distribution (learned parameters)
    components: Option<Array2<f64>>,
    /// exp(E[log(beta)]) for efficient computation
    exp_dirichlet_component: Option<Array2<f64>>,
    /// Vocabulary mapping
    #[allow(dead_code)]
    vocabulary: Option<HashMap<usize, String>>,
    /// Number of documents seen
    n_documents: usize,
    /// Number of iterations performed
    n_iter: usize,
    /// Variational bound per epoch (tracked during online learning)
    #[allow(dead_code)]
    bound: Option<Vec<f64>>,
}

impl LatentDirichletAllocation {
    /// Create a new LDA model with the given configuration
    pub fn new(config: LdaConfig) -> Self {
        Self {
            config,
            components: None,
            exp_dirichlet_component: None,
            vocabulary: None,
            n_documents: 0,
            n_iter: 0,
            bound: None,
        }
    }

    /// Create a new LDA model with default configuration
    pub fn with_ntopics(ntopics: usize) -> Self {
        let config = LdaConfig {
            ntopics,
            ..Default::default()
        };
        Self::new(config)
    }

    /// Fit the LDA model on a document-term matrix
    pub fn fit(&mut self, doc_term_matrix: &Array2<f64>) -> Result<&mut Self> {
        if doc_term_matrix.nrows() == 0 || doc_term_matrix.ncols() == 0 {
            return Err(TextError::InvalidInput(
                "Document-term matrix cannot be empty".to_string(),
            ));
        }

        let n_samples = doc_term_matrix.nrows();
        let n_features = doc_term_matrix.ncols();

        // Set default priors if not provided
        let doc_topic_prior = self
            .config
            .doc_topic_prior
            .unwrap_or(1.0 / self.config.ntopics as f64);
        let topic_word_prior = self
            .config
            .topic_word_prior
            .unwrap_or(1.0 / self.config.ntopics as f64);

        // Initialize topic-word distribution randomly
        let mut rng = self.create_rng();
        self.components = Some(self.initialize_components(n_features, &mut rng));

        // Perform training based on learning method
        match self.config.learning_method {
            LdaLearningMethod::Batch => {
                self.fit_batch(doc_term_matrix, doc_topic_prior, topic_word_prior)?;
            }
            LdaLearningMethod::Online => {
                self.fit_online(doc_term_matrix, doc_topic_prior, topic_word_prior)?;
            }
        }

        self.n_documents = n_samples;
        Ok(self)
    }

    /// Transform documents to topic distribution
    pub fn transform(&self, doc_term_matrix: &Array2<f64>) -> Result<Array2<f64>> {
        if self.components.is_none() {
            return Err(TextError::ModelNotFitted(
                "LDA model not fitted yet".to_string(),
            ));
        }

        let n_samples = doc_term_matrix.nrows();
        let ntopics = self.config.ntopics;

        // Initialize document-topic distribution
        let mut doc_topic_distr = Array2::zeros((n_samples, ntopics));

        // Get exp(E[log(beta)])
        let exp_dirichlet_component = self.get_exp_dirichlet_component()?;

        // Set default prior
        let doc_topic_prior = self.config.doc_topic_prior.unwrap_or(1.0 / ntopics as f64);

        // Update document-topic distribution for each document
        for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
            let mut gamma = Array1::from_elem(ntopics, doc_topic_prior);
            self.update_doc_distribution(
                &doc.to_owned(),
                &mut gamma,
                exp_dirichlet_component,
                doc_topic_prior,
            )?;

            // Normalize to get probability distribution
            let gamma_sum = gamma.sum();
            if gamma_sum > 0.0 {
                gamma /= gamma_sum;
            }

            doc_topic_distr.row_mut(doc_idx).assign(&gamma);
        }

        Ok(doc_topic_distr)
    }

    /// Fit and transform in one step
    pub fn fit_transform(&mut self, doc_term_matrix: &Array2<f64>) -> Result<Array2<f64>> {
        self.fit(doc_term_matrix)?;
        self.transform(doc_term_matrix)
    }

    /// Get the topics with top words
    pub fn get_topics(
        &self,
        n_top_words: usize,
        vocabulary: &HashMap<usize, String>,
    ) -> Result<Vec<Topic>> {
        if self.components.is_none() {
            return Err(TextError::ModelNotFitted(
                "LDA model not fitted yet".to_string(),
            ));
        }

        let components = self.components.as_ref().expect("Operation failed");
        let mut topics = Vec::new();

        for (topic_idx, topic_dist) in components.axis_iter(Axis(0)).enumerate() {
            // Get indices of top words
            let mut word_scores: Vec<(usize, f64)> = topic_dist
                .iter()
                .enumerate()
                .map(|(idx, &score)| (idx, score))
                .collect();

            word_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Operation failed"));

            // Get top words with their scores
            let top_words: Vec<(String, f64)> = word_scores
                .into_iter()
                .take(n_top_words)
                .filter_map(|(idx, score)| vocabulary.get(&idx).map(|word| (word.clone(), score)))
                .collect();

            topics.push(Topic {
                id: topic_idx,
                top_words,
                coherence: None,
            });
        }

        Ok(topics)
    }

    /// Get the topic-word distribution matrix
    pub fn get_topic_word_distribution(&self) -> Option<&Array2<f64>> {
        self.components.as_ref()
    }

    // Helper functions

    fn create_rng(&self) -> StdRng {
        match self.config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut scirs2_core::random::rng()),
        }
    }

    fn initialize_components(&self, n_features: usize, rng: &mut StdRng) -> Array2<f64> {
        let mut components = Array2::zeros((self.config.ntopics, n_features));
        for mut row in components.axis_iter_mut(Axis(0)) {
            for val in row.iter_mut() {
                *val = rng.random_range(0.0..1.0);
            }
            // Normalize each topic
            let row_sum: f64 = row.sum();
            if row_sum > 0.0 {
                row /= row_sum;
            }
        }

        components
    }

    fn get_exp_dirichlet_component(&self) -> Result<&Array2<f64>> {
        self.exp_dirichlet_component
            .as_ref()
            .ok_or_else(|| TextError::ModelNotFitted("Components not initialized".to_string()))
    }

    fn fit_batch(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_prior: f64,
        topic_word_prior: f64,
    ) -> Result<()> {
        let n_samples = doc_term_matrix.nrows();
        let ntopics = self.config.ntopics;

        // Initialize document-topic distribution
        let mut doc_topic_distr = Array2::from_elem((n_samples, ntopics), doc_topic_prior);

        // Training loop
        for iter in 0..self.config.maxiter {
            // Update exp(E[log(beta)])
            self.update_exp_dirichlet_component()?;

            // E-step: Update document-topic distribution
            let mut mean_change = 0.0;
            for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
                let mut gamma = doc_topic_distr.row(doc_idx).to_owned();
                let old_gamma = gamma.clone();

                self.update_doc_distribution(
                    &doc.to_owned(),
                    &mut gamma,
                    self.get_exp_dirichlet_component()?,
                    doc_topic_prior,
                )?;

                // Calculate mean change
                let change: f64 = (&gamma - &old_gamma).iter().map(|&x| x.abs()).sum();
                mean_change += change / ntopics as f64;

                doc_topic_distr.row_mut(doc_idx).assign(&gamma);
            }
            mean_change /= n_samples as f64;

            // M-step: Update topic-word distribution
            self.update_topic_distribution(doc_term_matrix, &doc_topic_distr, topic_word_prior)?;

            // Count the completed iteration, then check convergence
            self.n_iter = iter + 1;
            if mean_change < self.config.mean_change_tol {
                break;
            }
        }

        Ok(())
    }

    fn fit_online(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_prior: f64,
        topic_word_prior: f64,
    ) -> Result<()> {
        let (n_samples, n_features) = doc_term_matrix.dim();
        self.vocabulary
            .get_or_insert_with(|| (0..n_features).map(|i| (i, format!("word_{i}"))).collect());
        self.bound.get_or_insert_with(Vec::new);

        // Initialize topic-word distribution if not already done
        if self.components.is_none() {
            let mut rng = if let Some(seed) = self.config.random_seed {
                StdRng::seed_from_u64(seed)
            } else {
                StdRng::from_rng(&mut scirs2_core::random::rng())
            };

            let mut components = Array2::<f64>::zeros((self.config.ntopics, n_features));
            for i in 0..self.config.ntopics {
                for j in 0..n_features {
                    components[[i, j]] = rng.random::<f64>() + topic_word_prior;
                }
            }
            self.components = Some(components);
        }

        let batch_size = self.config.batch_size.min(n_samples);
        let n_batches = n_samples.div_ceil(batch_size);

        for epoch in 0..self.config.maxiter {
            let mut total_bound = 0.0;

            // Shuffle document indices for each epoch
            let mut doc_indices: Vec<usize> = (0..n_samples).collect();
            let mut rng = if let Some(seed) = self.config.random_seed {
                StdRng::seed_from_u64(seed + epoch as u64)
            } else {
                StdRng::from_rng(&mut scirs2_core::random::rng())
            };
            doc_indices.shuffle(&mut rng);

            for batch_idx in 0..n_batches {
                let start_idx = batch_idx * batch_size;
                let end_idx = ((batch_idx + 1) * batch_size).min(n_samples);

                // Get batch documents
                let batch_docs: Vec<usize> = doc_indices[start_idx..end_idx].to_vec();

                // E-step: Update document-topic distributions for batch
                let mut batch_gamma = Array2::<f64>::zeros((batch_docs.len(), self.config.ntopics));
                let mut batch_bound = 0.0;

                // Exponentiate the topic-word distribution once per batch
                // instead of once per document
                let exp_topic_word_distr = self
                    .components
                    .as_ref()
                    .expect("Operation failed")
                    .map(|x| x.exp());

                for (local_idx, &doc_idx) in batch_docs.iter().enumerate() {
                    let doc = doc_term_matrix.row(doc_idx);
                    let mut gamma = Array1::<f64>::from_elem(self.config.ntopics, doc_topic_prior);

                    // Update document distribution
                    self.update_doc_distribution(
                        &doc.to_owned(),
                        &mut gamma,
                        &exp_topic_word_distr,
                        doc_topic_prior,
                    )?;

                    batch_gamma.row_mut(local_idx).assign(&gamma);

                    // Compute bound contribution (simplified)
                    batch_bound += gamma.sum();
                }

                // M-step: Update topic-word distributions
                let learning_rate = self.compute_learning_rate(epoch * n_batches + batch_idx);
                self.update_topic_word_distribution(
                    &batch_docs,
                    doc_term_matrix,
                    &batch_gamma,
                    topic_word_prior,
                    learning_rate,
                    n_samples,
                )?;

                total_bound += batch_bound;
            }

            // Store bound for this epoch
            if let Some(ref mut bound) = self.bound {
                bound.push(total_bound / n_samples as f64);
            }

            self.n_iter = epoch + 1;

            // Check convergence of the per-epoch bound
            if let Some(ref bound) = self.bound {
                if bound.len() > 1 {
                    let current_bound = bound[bound.len() - 1];
                    let prev_bound = bound[bound.len() - 2];
                    let change = (current_bound - prev_bound).abs();
                    if change < self.config.mean_change_tol {
                        break;
                    }
                }
            }
        }

        // Cache exp(E[log(beta)]) so transform() works after online fitting
        self.update_exp_dirichlet_component()?;

        self.n_documents = n_samples;
        Ok(())
    }

    /// Compute learning rate for online learning
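    ///
    /// Uses the schedule ρ_t = (learning_offset + t)^(−learning_decay), so later
    /// mini-batches receive progressively smaller updates; decay values in
    /// (0.5, 1.0] are typical for online variational LDA.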
    fn compute_learning_rate(&self, iteration: usize) -> f64 {
        (self.config.learning_offset + iteration as f64).powf(-self.config.learning_decay)
    }

    /// Update topic-word distributions in online learning
    fn update_topic_word_distribution(
        &mut self,
        batch_docs: &[usize],
        doc_term_matrix: &Array2<f64>,
        batch_gamma: &Array2<f64>,
        topic_word_prior: f64,
        learning_rate: f64,
        total_docs: usize,
    ) -> Result<()> {
        let batch_size = batch_docs.len();
        let n_features = doc_term_matrix.ncols();

        if let Some(ref mut components) = self.components {
            // Compute sufficient statistics for this batch
            let mut batch_stats = Array2::<f64>::zeros((self.config.ntopics, n_features));

            for (local_idx, &doc_idx) in batch_docs.iter().enumerate() {
                let doc = doc_term_matrix.row(doc_idx);
                let gamma = batch_gamma.row(local_idx);
                let gamma_sum = gamma.sum();

                for (word_idx, &count) in doc.iter().enumerate() {
                    if count > 0.0 {
                        for topic_idx in 0..self.config.ntopics {
                            let phi = gamma[topic_idx] / gamma_sum;
                            batch_stats[[topic_idx, word_idx]] += count * phi;
                        }
                    }
                }
            }

            // Scale batch statistics to full corpus size
            let scale_factor = total_docs as f64 / batch_size as f64;
            batch_stats.mapv_inplace(|x| x * scale_factor);

            // Update components using natural gradient with learning rate
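            // i.e. components ← (1 − ρ_t)·components + ρ_t·(η + scaled batch stats)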
            for topic_idx in 0..self.config.ntopics {
                for word_idx in 0..n_features {
                    let old_val = components[[topic_idx, word_idx]];
                    let new_val = topic_word_prior + batch_stats[[topic_idx, word_idx]];
                    components[[topic_idx, word_idx]] =
                        (1.0 - learning_rate) * old_val + learning_rate * new_val;
                }
            }
        }

        Ok(())
    }

    fn update_doc_distribution(
        &self,
        doc: &Array1<f64>,
        gamma: &mut Array1<f64>,
        exp_topic_word_distr: &Array2<f64>,
        doc_topic_prior: f64,
    ) -> Result<()> {
        let ntopics = self.config.ntopics;

        // Simple mean-field update: iterate
        //   phi_{w,k} ∝ gamma_k · exp_topic_word_distr[k, w]
        //   gamma_k   = alpha + Σ_w count_w · phi_{w,k}
        for _ in 0..self.config.max_doc_update_iter {
            let old_gamma = gamma.clone();

            // Reset gamma to the prior
            gamma.fill(doc_topic_prior);

            // Update based on word counts and topic-word probabilities
            for (word_idx, &count) in doc.iter().enumerate() {
                if count > 0.0 {
                    // Word-topic responsibilities under the previous gamma
                    let mut phi: Vec<f64> = (0..ntopics)
                        .map(|k| old_gamma[k] * exp_topic_word_distr[[k, word_idx]])
                        .collect();
                    let phi_sum: f64 = phi.iter().sum();
                    if phi_sum > 0.0 {
                        for (k, p) in phi.iter_mut().enumerate() {
                            *p /= phi_sum;
                            gamma[k] += count * *p;
                        }
                    }
                }
            }

            // Check convergence
            let change: f64 = (&*gamma - &old_gamma).iter().map(|&x| x.abs()).sum();
            if change < self.config.mean_change_tol {
                break;
            }
        }

        Ok(())
    }

    fn update_topic_distribution(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_distr: &Array2<f64>,
        topic_word_prior: f64,
    ) -> Result<()> {
        if let Some(ref mut components) = self.components {
            // Reset components
            components.fill(topic_word_prior);

            // Accumulate sufficient statistics
            for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
                let doc_topics = doc_topic_distr.row(doc_idx);

                for (word_idx, &count) in doc.iter().enumerate() {
                    if count > 0.0 {
                        for topic_idx in 0..self.config.ntopics {
                            components[[topic_idx, word_idx]] += count * doc_topics[topic_idx];
                        }
                    }
                }
            }

            // Normalize each topic
            for mut topic in components.axis_iter_mut(Axis(0)) {
                let topic_sum = topic.sum();
                if topic_sum > 0.0 {
                    topic /= topic_sum;
                }
            }
        }

        Ok(())
    }

    fn update_exp_dirichlet_component(&mut self) -> Result<()> {
        if let Some(ref components) = self.components {
            // For simplicity, we'll use the components directly
            // In a full implementation, this would compute exp(E[log(beta)])
            self.exp_dirichlet_component = Some(components.clone());
        }
        Ok(())
    }
}

/// Builder for creating LDA models
pub struct LdaBuilder {
    config: LdaConfig,
}

impl LdaBuilder {
    /// Create a new builder with default configuration
    pub fn new() -> Self {
        Self {
            config: LdaConfig::default(),
        }
    }

    /// Set the number of topics
    pub fn ntopics(mut self, ntopics: usize) -> Self {
        self.config.ntopics = ntopics;
        self
    }

    /// Set the document-topic prior (alpha)
    pub fn doc_topic_prior(mut self, prior: f64) -> Self {
        self.config.doc_topic_prior = Some(prior);
        self
    }

    /// Set the topic-word prior (eta)
    pub fn topic_word_prior(mut self, prior: f64) -> Self {
        self.config.topic_word_prior = Some(prior);
        self
    }

    /// Set the learning method
    pub fn learning_method(mut self, method: LdaLearningMethod) -> Self {
        self.config.learning_method = method;
        self
    }

    /// Set the maximum iterations
    pub fn maxiter(mut self, maxiter: usize) -> Self {
        self.config.maxiter = maxiter;
        self
    }

    /// Set the random seed
    pub fn random_seed(mut self, seed: u64) -> Self {
        self.config.random_seed = Some(seed);
        self
    }

    /// Build the LDA model
    pub fn build(self) -> LatentDirichletAllocation {
        LatentDirichletAllocation::new(self.config)
    }
}

impl Default for LdaBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lda_creation() {
        let lda = LatentDirichletAllocation::with_ntopics(5);
        assert_eq!(lda.config.ntopics, 5);
    }

    #[test]
    fn test_lda_builder() {
        let lda = LdaBuilder::new()
            .ntopics(10)
            .doc_topic_prior(0.1)
            .maxiter(20)
            .random_seed(42)
            .build();

        assert_eq!(lda.config.ntopics, 10);
        assert_eq!(lda.config.doc_topic_prior, Some(0.1));
        assert_eq!(lda.config.maxiter, 20);
        assert_eq!(lda.config.random_seed, Some(42));
    }

    #[test]
    fn test_lda_fit_transform() {
        // Create a simple document-term matrix
        let doc_term_matrix = Array2::from_shape_vec(
            (4, 6),
            vec![
                1.0, 1.0, 0.0, 0.0, 0.0, 0.0, // Doc 1
                0.0, 1.0, 1.0, 0.0, 0.0, 0.0, // Doc 2
                0.0, 0.0, 0.0, 1.0, 1.0, 0.0, // Doc 3
                0.0, 0.0, 0.0, 0.0, 1.0, 1.0, // Doc 4
            ],
        )
        .expect("Operation failed");

        let mut lda = LatentDirichletAllocation::with_ntopics(2);
        let doc_topics = lda
            .fit_transform(&doc_term_matrix)
            .expect("Operation failed");

        assert_eq!(doc_topics.nrows(), 4);
        assert_eq!(doc_topics.ncols(), 2);

        // Check that each document's topic distribution sums to 1
        for row in doc_topics.axis_iter(Axis(0)) {
            let sum: f64 = row.sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_get_topics() {
        let doc_term_matrix = Array2::from_shape_vec(
            (4, 3),
            vec![2.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0, 2.0, 2.0, 1.0, 1.0],
        )
        .expect("Operation failed");

        let mut vocabulary = HashMap::new();
        vocabulary.insert(0, "word1".to_string());
        vocabulary.insert(1, "word2".to_string());
        vocabulary.insert(2, "word3".to_string());

        let mut lda = LatentDirichletAllocation::with_ntopics(2);
        lda.fit(&doc_term_matrix).expect("Operation failed");

        let topics = lda.get_topics(3, &vocabulary).expect("Operation failed");
        assert_eq!(topics.len(), 2);

        for topic in &topics {
            assert_eq!(topic.top_words.len(), 3);
        }
    }
}