pub mod inference;
pub mod model;
use crate::error::Result;
pub use inference::{kalman_backward, kalman_forward};
pub use model::{top_words_at_time, topic_evolution};
/// Hyperparameters for a [`DynamicTopicModel`].
#[derive(Debug, Clone)]
pub struct DtmConfig {
/// Number of latent topics to learn.
pub n_topics: usize,
/// Number of discrete time slices in the corpus.
/// NOTE(review): defaults to 0 — presumably the caller sets it to match the
/// data before fitting; confirm against `fit`.
pub n_time_slices: usize,
/// Maximum number of fitting iterations.
pub max_iter: usize,
/// NOTE(review): presumably the variance of the Gaussian topic drift between
/// adjacent time slices, consumed by the `kalman_forward`/`kalman_backward`
/// passes re-exported above — confirm in `inference`.
pub sigma_sq: f64,
/// NOTE(review): presumably a Dirichlet-style smoothing/prior constant for
/// document-topic proportions — confirm in `model`.
pub alpha: f64,
}
impl Default for DtmConfig {
fn default() -> Self {
Self {
n_topics: 10,
n_time_slices: 0,
max_iter: 50,
sigma_sq: 0.5,
alpha: 0.01,
}
}
}
/// Output of a fitted dynamic topic model.
#[derive(Debug, Clone)]
pub struct DtmResult {
/// Word-probability trajectories for each topic.
/// NOTE(review): indexing appears to be [topic][time][word], since
/// `topic_evolution(&trajectories, topic_id, word_id)` yields one value per
/// time slice — confirm against `model::topic_evolution`.
pub topic_word_trajectories: Vec<Vec<Vec<f64>>>,
/// Per-document topic weights.
/// NOTE(review): presumably one row per document, one column per topic —
/// confirm against `fit`.
pub doc_topic_matrix: Vec<Vec<f64>>,
}
/// A dynamic topic model: pairs a [`DtmConfig`] with the cached result of the
/// most recent successful fit.
pub struct DynamicTopicModel {
/// Hyperparameters used when fitting.
pub config: DtmConfig,
// `None` until `fit_and_store` succeeds; then holds the latest fit result.
fitted: Option<DtmResult>,
}
impl DynamicTopicModel {
    /// Builds an unfitted model from `config`.
    pub fn new(config: DtmConfig) -> Self {
        Self { config, fitted: None }
    }

    /// Borrows the cached fit result, or `None` if the model has not been
    /// fitted yet.
    pub fn fitted_result(&self) -> Option<&DtmResult> {
        self.fitted.as_ref()
    }

    /// Fits the model on `docs_by_time`, caches the result, and returns a
    /// borrow of the cached value. On error the previous cache is untouched.
    pub fn fit_and_store(
        &mut self,
        docs_by_time: &[Vec<Vec<f64>>],
        vocab_size: usize,
    ) -> Result<&DtmResult> {
        let result = self.fit(docs_by_time, vocab_size)?;
        // `Option::insert` stores the value and hands back a reference to it,
        // avoiding a re-borrow-and-expect dance.
        Ok(self.fitted.insert(result))
    }

    /// Top `n` words for every topic at time slice `t`; `None` before fitting.
    pub fn top_words_at(&self, t: usize, vocab: &[String], n: usize) -> Option<Vec<Vec<String>>> {
        let result = self.fitted.as_ref()?;
        Some(top_words_at_time(&result.topic_word_trajectories, t, vocab, n))
    }

    /// Probability trajectory of `word_id` under `topic_id` across time;
    /// `None` before fitting.
    pub fn word_evolution(&self, topic_id: usize, word_id: usize) -> Option<Vec<f64>> {
        let result = self.fitted.as_ref()?;
        Some(topic_evolution(&result.topic_word_trajectories, topic_id, word_id))
    }
}
impl Default for DynamicTopicModel {
fn default() -> Self {
Self::new(DtmConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic synthetic corpus: `n_slices` time slices, each with
    /// `n_docs` documents of `vocab_size` pseudo-counts cycling through
    /// {0.0, 1.0, 2.0}. Shared by the fitting tests below.
    fn synthetic_corpus(n_slices: usize, n_docs: usize, vocab_size: usize) -> Vec<Vec<Vec<f64>>> {
        (0..n_slices)
            .map(|t| {
                (0..n_docs)
                    .map(|d| (0..vocab_size).map(|w| ((t + d + w) % 3) as f64).collect())
                    .collect()
            })
            .collect()
    }

    /// Small, fast configuration shared by the fitting tests.
    fn small_config(n_time_slices: usize) -> DtmConfig {
        DtmConfig {
            n_topics: 2,
            n_time_slices,
            max_iter: 3,
            sigma_sq: 0.1,
            alpha: 0.1,
        }
    }

    #[test]
    fn dtm_default_config() {
        let cfg = DtmConfig::default();
        assert_eq!(cfg.n_topics, 10);
        assert_eq!(cfg.max_iter, 50);
        assert!((cfg.sigma_sq - 0.5).abs() < 1e-12);
        assert!((cfg.alpha - 0.01).abs() < 1e-12);
    }

    #[test]
    fn dtm_default_model() {
        let m = DynamicTopicModel::default();
        assert_eq!(m.config.n_topics, 10);
        assert!(m.fitted_result().is_none());
    }

    #[test]
    fn dtm_fit_and_store() {
        let mut model = DynamicTopicModel::new(small_config(2));
        let docs_by_time = synthetic_corpus(2, 3, 4);
        model.fit_and_store(&docs_by_time, 4).expect("fit failed");
        assert!(model.fitted_result().is_some());
    }

    #[test]
    fn dtm_top_words_at_after_fit() {
        let mut model = DynamicTopicModel::new(small_config(2));
        let docs_by_time = synthetic_corpus(2, 3, 5);
        model.fit_and_store(&docs_by_time, 5).expect("fit failed");
        let vocab: Vec<String> = (0..5).map(|i| format!("w{i}")).collect();
        let tw = model.top_words_at(0, &vocab, 3).expect("no fitted result");
        assert_eq!(tw.len(), 2);
        assert_eq!(tw[0].len(), 3);
    }

    #[test]
    fn dtm_word_evolution_length_equals_t() {
        let mut model = DynamicTopicModel::new(small_config(4));
        let docs_by_time = synthetic_corpus(4, 2, 5);
        model.fit_and_store(&docs_by_time, 5).expect("fit failed");
        let ev = model.word_evolution(0, 2).expect("no fitted result");
        assert_eq!(ev.len(), 4);
    }
}