use crate::error::{Result, TextError};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::{rngs::StdRng, SeedableRng};
use std::collections::HashMap;
/// Strategy used when fitting the LDA model.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LdaLearningMethod {
    /// Process the entire corpus on every iteration.
    Batch,
    /// Process shuffled mini-batches with a decaying learning rate.
    Online,
}
/// Configuration for [`LatentDirichletAllocation`].
#[derive(Debug, Clone)]
pub struct LdaConfig {
    /// Number of topics to extract.
    pub ntopics: usize,
    /// Dirichlet prior on document-topic weights; `None` defaults to `1.0 / ntopics`.
    pub doc_topic_prior: Option<f64>,
    /// Dirichlet prior on topic-word weights; `None` defaults to `1.0 / ntopics`.
    pub topic_word_prior: Option<f64>,
    /// Batch or online (mini-batch) fitting.
    pub learning_method: LdaLearningMethod,
    /// Exponent controlling how fast the online learning rate decays.
    pub learning_decay: f64,
    /// Offset that down-weights early online iterations.
    pub learning_offset: f64,
    /// Maximum number of passes over the corpus.
    pub maxiter: usize,
    /// Mini-batch size used by online learning.
    pub batch_size: usize,
    /// Convergence threshold on the mean change of document-topic weights.
    pub mean_change_tol: f64,
    /// Maximum iterations of the per-document update loop.
    pub max_doc_update_iter: usize,
    /// Seed for reproducible runs; `None` draws fresh entropy.
    pub random_seed: Option<u64>,
}
impl Default for LdaConfig {
fn default() -> Self {
Self {
ntopics: 10,
doc_topic_prior: None, topic_word_prior: None, learning_method: LdaLearningMethod::Batch,
learning_decay: 0.7,
learning_offset: 10.0,
maxiter: 10,
batch_size: 128,
mean_change_tol: 1e-3,
max_doc_update_iter: 100,
random_seed: None,
}
}
}
/// A single extracted topic.
#[derive(Debug, Clone)]
pub struct Topic {
    /// Topic index (row in the topic-word matrix).
    pub id: usize,
    /// Highest-weighted words with their scores, best first.
    pub top_words: Vec<(String, f64)>,
    /// Optional coherence score; `get_topics` in this module leaves it `None`.
    pub coherence: Option<f64>,
}
/// Latent Dirichlet Allocation topic model.
pub struct LatentDirichletAllocation {
    /// Fitting configuration.
    config: LdaConfig,
    /// Topic-word matrix `(ntopics, n_features)`; `None` until fitted.
    components: Option<Array2<f64>>,
    /// Cached copy of `components` consumed by the per-document updates.
    exp_dirichlet_component: Option<Array2<f64>>,
    #[allow(dead_code)]
    /// Index-to-word mapping; populated with placeholders during online fitting.
    vocabulary: Option<HashMap<usize, String>>,
    /// Number of documents seen by the most recent `fit`.
    n_documents: usize,
    /// Number of iterations performed by the most recent `fit`.
    n_iter: usize,
    #[allow(dead_code)]
    /// Per-epoch bound history recorded during online fitting.
    bound: Option<Vec<f64>>,
}
impl LatentDirichletAllocation {
pub fn new(config: LdaConfig) -> Self {
Self {
config,
components: None,
exp_dirichlet_component: None,
vocabulary: None,
n_documents: 0,
n_iter: 0,
bound: None,
}
}
pub fn with_ntopics(ntopics: usize) -> Self {
let config = LdaConfig {
ntopics,
..Default::default()
};
Self::new(config)
}
/// Fits the model to a document-term count matrix (rows = documents,
/// columns = vocabulary terms).
///
/// Unset priors default to `1.0 / ntopics`. Dispatches to batch or
/// online learning per the configuration.
///
/// # Errors
/// Returns `TextError::InvalidInput` if the matrix has zero rows or columns.
pub fn fit(&mut self, doc_termmatrix: &Array2<f64>) -> Result<&mut Self> {
    if doc_termmatrix.nrows() == 0 || doc_termmatrix.ncols() == 0 {
        // Fixed message typo (was "Document-term _matrix").
        return Err(TextError::InvalidInput(
            "Document-term matrix cannot be empty".to_string(),
        ));
    }
    let n_samples = doc_termmatrix.nrows();
    let n_features = doc_termmatrix.ncols();
    let doc_topic_prior = self
        .config
        .doc_topic_prior
        .unwrap_or(1.0 / self.config.ntopics as f64);
    let topic_word_prior = self
        .config
        .topic_word_prior
        .unwrap_or(1.0 / self.config.ntopics as f64);
    let mut rng = self.create_rng();
    self.components = Some(self.initialize_components(n_features, &mut rng));
    match self.config.learning_method {
        LdaLearningMethod::Batch => {
            self.fit_batch(doc_termmatrix, doc_topic_prior, topic_word_prior)?;
        }
        LdaLearningMethod::Online => {
            self.fit_online(doc_termmatrix, doc_topic_prior, topic_word_prior)?;
        }
    }
    // Refresh the cached matrix regardless of learning method: the online
    // path never set exp_dirichlet_component, so transform() after an
    // online fit used to fail with ModelNotFitted.
    self.update_exp_dirichlet_component()?;
    self.n_documents = n_samples;
    Ok(self)
}
/// Infers the topic distribution for each document.
///
/// Returns an `(n_documents, ntopics)` matrix; each row is normalized to
/// sum to one whenever its unnormalized weights are positive.
///
/// # Errors
/// Returns `TextError::ModelNotFitted` if called before `fit`.
pub fn transform(&self, doc_termmatrix: &Array2<f64>) -> Result<Array2<f64>> {
    if self.components.is_none() {
        return Err(TextError::ModelNotFitted(
            "LDA model not fitted yet".to_string(),
        ));
    }
    let ntopics = self.config.ntopics;
    let prior = self.config.doc_topic_prior.unwrap_or(1.0 / ntopics as f64);
    let exp_beta = self.get_exp_dirichlet_component()?;
    let mut result = Array2::zeros((doc_termmatrix.nrows(), ntopics));
    for (row_idx, document) in doc_termmatrix.axis_iter(Axis(0)).enumerate() {
        let mut topic_weights = Array1::from_elem(ntopics, prior);
        self.update_doc_distribution(&document.to_owned(), &mut topic_weights, exp_beta, prior)?;
        let total = topic_weights.sum();
        if total > 0.0 {
            topic_weights /= total;
        }
        result.row_mut(row_idx).assign(&topic_weights);
    }
    Ok(result)
}
/// Fits the model and returns the document-topic distribution in one call.
pub fn fit_transform(&mut self, doc_termmatrix: &Array2<f64>) -> Result<Array2<f64>> {
    self.fit(doc_termmatrix)?.transform(doc_termmatrix)
}
/// Returns each topic's `n_top_words` highest-weighted words, resolved
/// through `vocabulary` (indices missing from the map are skipped, so a
/// topic may carry fewer than `n_top_words` entries).
///
/// # Errors
/// Returns `TextError::ModelNotFitted` if called before `fit`.
pub fn get_topics(
    &self,
    n_top_words: usize,
    vocabulary: &HashMap<usize, String>,
) -> Result<Vec<Topic>> {
    // ok_or_else replaces the is_none()-then-expect pattern.
    let components = self
        .components
        .as_ref()
        .ok_or_else(|| TextError::ModelNotFitted("LDA model not fitted yet".to_string()))?;
    let mut topics = Vec::with_capacity(components.nrows());
    for (topic_idx, topic_dist) in components.axis_iter(Axis(0)).enumerate() {
        let mut word_scores: Vec<(usize, f64)> =
            topic_dist.iter().copied().enumerate().collect();
        // total_cmp is a total order over f64, so sorting cannot panic on
        // NaN (partial_cmp().expect() did).
        word_scores.sort_by(|a, b| b.1.total_cmp(&a.1));
        let top_words: Vec<(String, f64)> = word_scores
            .into_iter()
            .take(n_top_words)
            .filter_map(|(idx, score)| vocabulary.get(&idx).map(|word| (word.clone(), score)))
            .collect();
        topics.push(Topic {
            id: topic_idx,
            top_words,
            coherence: None,
        });
    }
    Ok(topics)
}
/// Borrows the fitted topic-word matrix `(ntopics, n_features)`, or
/// `None` if the model has not been fitted.
pub fn get_topic_word_distribution(&self) -> Option<&Array2<f64>> {
    self.components.as_ref()
}
fn create_rng(&self) -> scirs2_core::random::rngs::StdRng {
use scirs2_core::random::SeedableRng;
match self.config.random_seed {
Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
None => {
let mut temp_rng = scirs2_core::random::rng();
scirs2_core::random::rngs::StdRng::from_rng(&mut temp_rng)
}
}
}
/// Creates a random `(ntopics, n_features)` topic-word matrix whose rows
/// are each normalized to sum to one.
fn initialize_components(
    &self,
    n_features: usize,
    rng: &mut scirs2_core::random::rngs::StdRng,
) -> Array2<f64> {
    let shape = (self.config.ntopics, n_features);
    let mut topic_word = Array2::zeros(shape);
    for mut topic_row in topic_word.axis_iter_mut(Axis(0)) {
        for weight in topic_row.iter_mut() {
            *weight = rng.random_range(0.0..1.0);
        }
        let total: f64 = topic_row.sum();
        if total > 0.0 {
            topic_row /= total;
        }
    }
    topic_word
}
/// Borrows the cached topic-word matrix used by the per-document updates.
///
/// # Errors
/// Returns `TextError::ModelNotFitted` if it has not been computed yet.
fn get_exp_dirichlet_component(&self) -> Result<&Array2<f64>> {
    // Single non-panicking lookup instead of is_none() + expect().
    self.exp_dirichlet_component
        .as_ref()
        .ok_or_else(|| TextError::ModelNotFitted("Components not initialized".to_string()))
}
/// Batch fitting: each iteration updates every document's topic weights
/// against the current topic-word matrix, then re-estimates the
/// topic-word matrix from the whole corpus, stopping early once the mean
/// per-document change drops below `mean_change_tol`.
fn fit_batch(
    &mut self,
    doc_term_matrix: &Array2<f64>,
    doc_topic_prior: f64,
    topic_word_prior: f64,
) -> Result<()> {
    let n_samples = doc_term_matrix.nrows();
    let ntopics = self.config.ntopics;
    let mut doc_topic_distr = Array2::from_elem((n_samples, ntopics), doc_topic_prior);
    for iter in 0..self.config.maxiter {
        self.update_exp_dirichlet_component()?;
        let mut mean_change = 0.0;
        for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
            let mut gamma = doc_topic_distr.row(doc_idx).to_owned();
            let old_gamma = gamma.clone();
            self.update_doc_distribution(
                &doc.to_owned(),
                &mut gamma,
                self.get_exp_dirichlet_component()?,
                doc_topic_prior,
            )?;
            let change: f64 = (&gamma - &old_gamma).iter().map(|&x| x.abs()).sum();
            mean_change += change / ntopics as f64;
            doc_topic_distr.row_mut(doc_idx).assign(&gamma);
        }
        mean_change /= n_samples as f64;
        self.update_topic_distribution(doc_term_matrix, &doc_topic_distr, topic_word_prior)?;
        // Count this pass BEFORE testing convergence; previously a
        // first-pass convergence broke out with n_iter still 0 even
        // though one full pass had run.
        self.n_iter = iter + 1;
        if mean_change < self.config.mean_change_tol {
            break;
        }
    }
    Ok(())
}
/// Online (mini-batch) fitting: documents are shuffled each epoch and
/// processed in batches; after each batch the topic-word matrix is
/// blended toward the batch estimate with a decaying learning rate. A
/// simple per-epoch bound is recorded and used for early stopping.
fn fit_online(
    &mut self,
    doc_term_matrix: &Array2<f64>,
    doc_topic_prior: f64,
    topic_word_prior: f64,
) -> Result<()> {
    let (n_samples, n_features) = doc_term_matrix.dim();
    // Placeholder "word_i" vocabulary if none was supplied earlier.
    self.vocabulary
        .get_or_insert_with(|| (0..n_features).map(|i| (i, format!("word_{i}"))).collect());
    self.bound.get_or_insert_with(Vec::new);
    if self.components.is_none() {
        let mut rng = if let Some(seed) = self.config.random_seed {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_rng(&mut scirs2_core::random::rng())
        };
        let mut components = Array2::<f64>::zeros((self.config.ntopics, n_features));
        for i in 0..self.config.ntopics {
            for j in 0..n_features {
                components[[i, j]] = rng.random::<f64>() + topic_word_prior;
            }
        }
        self.components = Some(components);
    }
    let batch_size = self.config.batch_size.min(n_samples);
    let n_batches = n_samples.div_ceil(batch_size);
    for epoch in 0..self.config.maxiter {
        let mut total_bound = 0.0;
        let mut doc_indices: Vec<usize> = (0..n_samples).collect();
        // Per-epoch seeding keeps shuffles reproducible for a fixed seed
        // while still varying between epochs.
        let mut rng = if let Some(seed) = self.config.random_seed {
            StdRng::seed_from_u64(seed + epoch as u64)
        } else {
            StdRng::from_rng(&mut scirs2_core::random::rng())
        };
        doc_indices.shuffle(&mut rng);
        for batch_idx in 0..n_batches {
            let start_idx = batch_idx * batch_size;
            let end_idx = ((batch_idx + 1) * batch_size).min(n_samples);
            let batch_docs: Vec<usize> = doc_indices[start_idx..end_idx].to_vec();
            let mut batch_gamma = Array2::<f64>::zeros((batch_docs.len(), self.config.ntopics));
            let mut batch_bound = 0.0;
            // Hoisted out of the per-document loop: the topic-word matrix
            // is constant within a batch, so exponentiating it once per
            // batch avoids an O(ntopics * n_features) recomputation for
            // every single document.
            let exp_topic_word_distr = self
                .components
                .as_ref()
                .expect("components initialized above")
                .map(|x| x.exp());
            for (local_idx, &doc_idx) in batch_docs.iter().enumerate() {
                let doc = doc_term_matrix.row(doc_idx);
                let mut gamma = Array1::<f64>::from_elem(self.config.ntopics, doc_topic_prior);
                self.update_doc_distribution(
                    &doc.to_owned(),
                    &mut gamma,
                    &exp_topic_word_distr,
                    doc_topic_prior,
                )?;
                batch_gamma.row_mut(local_idx).assign(&gamma);
                batch_bound += gamma.sum();
            }
            let learning_rate = self.compute_learning_rate(epoch * n_batches + batch_idx);
            self.update_topic_word_distribution(
                &batch_docs,
                doc_term_matrix,
                &batch_gamma,
                topic_word_prior,
                learning_rate,
                n_samples,
            )?;
            total_bound += batch_bound;
        }
        if let Some(ref mut bound) = self.bound {
            bound.push(total_bound / n_samples as f64);
        }
        // Count this epoch before the convergence test so n_iter reflects
        // the number of epochs actually performed.
        self.n_iter = epoch + 1;
        // Early stop when the bound change between epochs falls below tol.
        if let Some(ref bound) = self.bound {
            if bound.len() > 1 {
                let change = (bound[bound.len() - 1] - bound[bound.len() - 2]).abs();
                if change < self.config.mean_change_tol {
                    break;
                }
            }
        }
    }
    self.n_documents = n_samples;
    Ok(())
}
/// Online learning rate: `rho_t = (offset + t) ^ (-decay)`.
fn compute_learning_rate(&self, iteration: usize) -> f64 {
    let t = self.config.learning_offset + iteration as f64;
    t.powf(-self.config.learning_decay)
}
/// Blends the topic-word matrix toward sufficient statistics from one
/// mini-batch, scaled up to corpus size:
/// `new = (1 - rho) * old + rho * (prior + scaled batch counts)`.
fn update_topic_word_distribution(
    &mut self,
    batch_docs: &[usize],
    doc_term_matrix: &Array2<f64>,
    batch_gamma: &Array2<f64>,
    topic_word_prior: f64,
    learning_rate: f64,
    total_docs: usize,
) -> Result<()> {
    let batch_size = batch_docs.len();
    let n_features = doc_term_matrix.ncols();
    if let Some(ref mut components) = self.components {
        let mut batch_stats = Array2::<f64>::zeros((self.config.ntopics, n_features));
        for (local_idx, &doc_idx) in batch_docs.iter().enumerate() {
            let doc = doc_term_matrix.row(doc_idx);
            let gamma = batch_gamma.row(local_idx);
            let gamma_sum = gamma.sum();
            // Guard: a degenerate all-zero gamma (possible with a zero
            // doc-topic prior) previously divided by zero and spread NaNs
            // through the topic matrix.
            if gamma_sum <= 0.0 {
                continue;
            }
            for (word_idx, &count) in doc.iter().enumerate() {
                if count > 0.0 {
                    for topic_idx in 0..self.config.ntopics {
                        let phi = gamma[topic_idx] / gamma_sum;
                        batch_stats[[topic_idx, word_idx]] += count * phi;
                    }
                }
            }
        }
        // Scale the batch statistics as if the whole corpus had been seen.
        let scale_factor = total_docs as f64 / batch_size as f64;
        batch_stats.mapv_inplace(|x| x * scale_factor);
        for topic_idx in 0..self.config.ntopics {
            for word_idx in 0..n_features {
                let old_val = components[[topic_idx, word_idx]];
                let new_val = topic_word_prior + batch_stats[[topic_idx, word_idx]];
                components[[topic_idx, word_idx]] =
                    (1.0 - learning_rate) * old_val + learning_rate * new_val;
            }
        }
    }
    Ok(())
}
fn update_doc_distribution(
&self,
doc: &Array1<f64>,
gamma: &mut Array1<f64>,
exp_topic_word_distr: &Array2<f64>,
doc_topic_prior: f64,
) -> Result<()> {
for _ in 0..self.config.max_doc_update_iter {
let old_gamma = gamma.clone();
gamma.fill(doc_topic_prior);
for (word_idx, &count) in doc.iter().enumerate() {
}
let change: f64 = (&*gamma - &old_gamma).iter().map(|&x| x.abs()).sum();
if change < self.config.mean_change_tol {
break;
}
}
Ok(())
}
/// Re-estimates the topic-word matrix from the full corpus (batch mode):
/// accumulate expected word counts on top of the prior, then normalize
/// each topic row into a distribution.
fn update_topic_distribution(
    &mut self,
    doc_term_matrix: &Array2<f64>,
    doc_topic_distr: &Array2<f64>,
    topic_word_prior: f64,
) -> Result<()> {
    // (Removed an unused `_n_features` local.)
    if let Some(ref mut components) = self.components {
        components.fill(topic_word_prior);
        for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
            let doc_topics = doc_topic_distr.row(doc_idx);
            for (word_idx, &count) in doc.iter().enumerate() {
                if count > 0.0 {
                    for topic_idx in 0..self.config.ntopics {
                        components[[topic_idx, word_idx]] += count * doc_topics[topic_idx];
                    }
                }
            }
        }
        for mut topic in components.axis_iter_mut(Axis(0)) {
            let topic_sum = topic.sum();
            if topic_sum > 0.0 {
                topic /= topic_sum;
            }
        }
    }
    Ok(())
}
/// Refreshes the cached matrix consumed by `update_doc_distribution`.
///
/// NOTE(review): despite the name, this is a plain clone of `components`;
/// no exp(digamma(...)) Dirichlet expectation is applied here — confirm
/// whether that simplification is intentional.
fn update_exp_dirichlet_component(&mut self) -> Result<()> {
    if let Some(ref components) = self.components {
        self.exp_dirichlet_component = Some(components.clone());
    }
    Ok(())
}
}
/// Fluent builder for [`LatentDirichletAllocation`].
pub struct LdaBuilder {
    // Accumulated configuration, finalized by `build`.
    config: LdaConfig,
}
impl LdaBuilder {
pub fn new() -> Self {
Self {
config: LdaConfig::default(),
}
}
pub fn ntopics(mut self, ntopics: usize) -> Self {
self.config.ntopics = ntopics;
self
}
pub fn doc_topic_prior(mut self, prior: f64) -> Self {
self.config.doc_topic_prior = Some(prior);
self
}
pub fn topic_word_prior(mut self, prior: f64) -> Self {
self.config.topic_word_prior = Some(prior);
self
}
pub fn learning_method(mut self, method: LdaLearningMethod) -> Self {
self.config.learning_method = method;
self
}
pub fn maxiter(mut self, maxiter: usize) -> Self {
self.config.maxiter = maxiter;
self
}
pub fn random_seed(mut self, seed: u64) -> Self {
self.config.random_seed = Some(seed);
self
}
pub fn build(self) -> LatentDirichletAllocation {
LatentDirichletAllocation::new(self.config)
}
}
impl Default for LdaBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lda_creation() {
        let lda = LatentDirichletAllocation::with_ntopics(5);
        assert_eq!(lda.config.ntopics, 5);
    }

    #[test]
    fn test_lda_builder() {
        let lda = LdaBuilder::new()
            .ntopics(10)
            .doc_topic_prior(0.1)
            .maxiter(20)
            .random_seed(42)
            .build();
        assert_eq!(lda.config.ntopics, 10);
        assert_eq!(lda.config.doc_topic_prior, Some(0.1));
        assert_eq!(lda.config.maxiter, 20);
        assert_eq!(lda.config.random_seed, Some(42));
    }

    #[test]
    fn test_lda_fit_transform() {
        // Four documents over a six-word vocabulary, one row per document.
        let doc_term_matrix = Array2::from_shape_vec(
            (4, 6),
            vec![
                1.0, 1.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 1.0, 1.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 1.0, 1.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 1.0, 1.0,
            ],
        )
        .expect("Operation failed");
        let mut lda = LatentDirichletAllocation::with_ntopics(2);
        let doc_topics = lda
            .fit_transform(&doc_term_matrix)
            .expect("Operation failed");
        assert_eq!(doc_topics.nrows(), 4);
        assert_eq!(doc_topics.ncols(), 2);
        // Every document's topic mixture must be a probability distribution.
        for row in doc_topics.axis_iter(Axis(0)) {
            let sum: f64 = row.sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_get_topics() {
        let doc_term_matrix = Array2::from_shape_vec(
            (4, 3),
            vec![2.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0, 2.0, 2.0, 1.0, 1.0],
        )
        .expect("Operation failed");
        let vocabulary: HashMap<usize, String> = [(0, "word1"), (1, "word2"), (2, "word3")]
            .into_iter()
            .map(|(i, w)| (i, w.to_string()))
            .collect();
        let mut lda = LatentDirichletAllocation::with_ntopics(2);
        lda.fit(&doc_term_matrix).expect("Operation failed");
        let topics = lda.get_topics(3, &vocabulary).expect("Operation failed");
        assert_eq!(topics.len(), 2);
        for topic in &topics {
            assert_eq!(topic.top_words.len(), 3);
        }
    }
}