pub mod inference;
pub mod model;
use crate::error::Result;
pub use model::{log_likelihood, softmax, top_words, topic_correlation_matrix};
/// Tunable settings for a [`CorrelatedTopicModel`].
///
/// The code that consumes these values lives outside this file, so the
/// per-field notes below follow the field names; confirm the exact semantics
/// against the fitting implementation.
#[derive(Debug, Clone)]
pub struct CtmConfig {
    /// Number of topics. `Default` sets this to 10.
    pub n_topics: usize,

    /// Iteration limit (presumably for the fitting loop). `Default` sets this
    /// to 100.
    pub max_iter: usize,

    /// Tolerance (presumably the convergence threshold). `Default` sets this
    /// to `1e-4`.
    pub tol: f64,

    /// Vocabulary size. `Default` sets this to 0.
    ///
    /// NOTE(review): `CorrelatedTopicModel::fit_and_store` receives its own
    /// `vocab_size` argument and does not read this field itself; confirm
    /// which of the two values is authoritative.
    pub vocab_size: usize,
}
impl Default for CtmConfig {
fn default() -> Self {
Self {
n_topics: 10,
max_iter: 100,
tol: 1e-4,
vocab_size: 0,
}
}
}
/// Output of fitting a [`CorrelatedTopicModel`].
///
/// The values are produced by code outside this file; shapes and exact
/// meanings noted below are inferred from field names and should be confirmed
/// against the fitting implementation.
#[derive(Debug, Clone)]
pub struct CtmResult {
    /// Topic-word matrix; this is the matrix handed to `top_words` by
    /// `CorrelatedTopicModel::top_words_from_fitted`.
    /// Presumably one row per topic and one column per vocabulary word —
    /// TODO confirm.
    pub topic_word_matrix: Vec<Vec<f64>>,

    /// Document-topic matrix. Presumably one row per document and one column
    /// per topic — TODO confirm.
    pub doc_topic_matrix: Vec<Vec<f64>>,

    /// Presumably the mean vector of the topic prior — TODO confirm.
    pub mu: Vec<f64>,

    /// Matrix handed to `topic_correlation_matrix` by
    /// `CorrelatedTopicModel::correlation_matrix_from_fitted`; presumably the
    /// topic covariance matrix — TODO confirm.
    pub sigma: Vec<Vec<f64>>,

    /// Log-likelihood value reported by the fit.
    pub log_likelihood: f64,
}
/// A correlated topic model: a configuration plus an optionally stored fit
/// result.
///
/// Derives `Debug` and `Clone` so it matches the other public types in this
/// module ([`CtmConfig`] and [`CtmResult`]), both of which derive them.
#[derive(Debug, Clone)]
pub struct CorrelatedTopicModel {
    /// Settings this model was created with.
    pub config: CtmConfig,

    /// Stored fit result: `None` when the model is created through `new`,
    /// `Some` once `fit_and_store` has succeeded.
    fitted: Option<CtmResult>,
}
impl CorrelatedTopicModel {
    /// Creates a model with the given configuration and no stored fit result.
    pub fn new(config: CtmConfig) -> Self {
        Self {
            config,
            fitted: None,
        }
    }

    /// Returns a reference to the stored fit result, or `None` if no result
    /// is stored.
    pub fn fitted_result(&self) -> Option<&CtmResult> {
        self.fitted.as_ref()
    }

    /// Runs `fit` on the given inputs, stores the result on the model, and
    /// returns a reference to the stored result.
    ///
    /// Both arguments are forwarded unchanged to `fit`, which is defined
    /// outside this file.
    ///
    /// NOTE(review): `self.config.vocab_size` is not consulted here; the
    /// explicit `vocab_size` argument is what gets passed on. Confirm which of
    /// the two is meant to be authoritative.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `fit`. On that path this method
    /// returns before touching the stored result.
    pub fn fit_and_store(
        &mut self,
        doc_counts_list: &[Vec<f64>],
        vocab_size: usize,
    ) -> Result<&CtmResult> {
        let result = self.fit(doc_counts_list, vocab_size)?;
        // `Option::insert` stores the value and hands back a reference to it,
        // so there is no second, fallible look-up of the slot (and no
        // `expect`) after the assignment.
        let stored: &CtmResult = self.fitted.insert(result);
        Ok(stored)
    }

    /// Applies `top_words` to the stored topic-word matrix, passing `vocab`
    /// and `n` through unchanged.
    ///
    /// Returns `None` when no fit result is stored.
    pub fn top_words_from_fitted(&self, vocab: &[String], n: usize) -> Option<Vec<Vec<String>>> {
        let result = self.fitted.as_ref()?;
        Some(top_words(&result.topic_word_matrix, vocab, n))
    }

    /// Applies `topic_correlation_matrix` to the stored `sigma` matrix.
    ///
    /// Returns `None` when no fit result is stored.
    pub fn correlation_matrix_from_fitted(&self) -> Option<Vec<Vec<f64>>> {
        let result = self.fitted.as_ref()?;
        Some(topic_correlation_matrix(&result.sigma))
    }
}
impl Default for CorrelatedTopicModel {
fn default() -> Self {
Self::new(CtmConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `CtmConfig::default()` yields 10 topics, 100 iterations and a
    /// tolerance of `1e-4`.
    #[test]
    fn ctm_default_config() {
        let CtmConfig {
            n_topics,
            max_iter,
            tol,
            ..
        } = CtmConfig::default();

        assert_eq!(n_topics, 10);
        assert_eq!(max_iter, 100);
        assert!((tol - 1e-4).abs() < 1e-12);
    }

    /// A default model uses the default configuration and has nothing stored.
    #[test]
    fn ctm_model_default() {
        let model = CorrelatedTopicModel::default();

        assert_eq!(model.config.n_topics, 10);
        assert!(model.fitted_result().is_none());
    }

    /// After a successful `fit_and_store`, the result is retrievable.
    #[test]
    fn ctm_fit_and_store() {
        let config = CtmConfig {
            n_topics: 2,
            max_iter: 5,
            tol: 1e-3,
            vocab_size: 4,
        };
        let mut model = CorrelatedTopicModel::new(config);

        // Four documents over four words; the counts cycle through 0, 1, 2.
        let mut docs: Vec<Vec<f64>> = Vec::with_capacity(4);
        for doc in 0..4 {
            docs.push((0..4).map(|word| ((doc + word) % 3) as f64).collect());
        }

        model.fit_and_store(&docs, 4).expect("fit_and_store failed");
        assert!(model.fitted_result().is_some());
    }
}