Skip to main content

scirs2_text/dtm/
mod.rs

1//! # Dynamic Topic Model (DTM)
2//!
3//! Implements the Dynamic Topic Model of Blei & Lafferty (2006), which extends
4//! LDA by modelling topic evolution over discrete time slices via a Gaussian
5//! state-space model.
6//!
7//! ## Model
8//!
9//! ```text
10//!   β_{t,k} | β_{t-1,k} ~ N(β_{t-1,k}, σ² I)   (topic word evolution)
11//!   θ_d      ~ Dir(α)                              (document-topic)
12//!   z_{dn}  ~ Categorical(θ_d)                    (topic assignment)
//!   w_{dn}  ~ Categorical(softmax(β_{t,z_{dn}}))  (word generation)
14//! ```
15//!
16//! Inference is performed via variational EM with a Kalman smoother on the
17//! topic-word parameters.
18//!
19//! ## Example
20//!
21//! ```rust
22//! use scirs2_text::dtm::{DynamicTopicModel, DtmConfig};
23//!
24//! let config = DtmConfig {
25//!     n_topics: 2,
26//!     n_time_slices: 3,
27//!     max_iter: 5,
28//!     sigma_sq: 0.1,
29//!     alpha: 0.1,
30//! };
31//! let model = DynamicTopicModel::new(config);
32//!
33//! // 3 time slices, each with 4 documents of 6 words each
34//! let docs_by_time: Vec<Vec<Vec<f64>>> = (0..3)
35//!     .map(|t| {
36//!         (0..4)
37//!             .map(|d| (0..6).map(|w| ((t + d + w) % 3) as f64).collect())
38//!             .collect()
39//!     })
40//!     .collect();
41//!
42//! let result = model.fit(&docs_by_time, 6).expect("DTM fit failed");
43//! assert_eq!(result.topic_word_trajectories.len(), 2); // K topics
44//! ```
45
46pub mod inference;
47pub mod model;
48
49use crate::error::Result;
50
51// ────────────────────────────────────────────────────────────────────────────
52// Public re-exports
53// ────────────────────────────────────────────────────────────────────────────
54
55pub use inference::{kalman_backward, kalman_forward};
56pub use model::{top_words_at_time, topic_evolution};
57
58// ────────────────────────────────────────────────────────────────────────────
59// Configuration
60// ────────────────────────────────────────────────────────────────────────────
61
/// Configuration for the Dynamic Topic Model.
///
/// Every field is public, so a configuration can be written as a struct
/// literal or derived from [`DtmConfig::default`] with struct-update syntax.
/// Two configurations compare equal when all five fields are equal.
#[derive(Debug, Clone, PartialEq)]
pub struct DtmConfig {
    /// Number of latent topics K.
    pub n_topics: usize,
    /// Number of time slices T (may be 0; inferred from data if so).
    pub n_time_slices: usize,
    /// Maximum number of variational EM iterations.
    pub max_iter: usize,
    /// State-transition variance σ² for the Gaussian random walk.
    pub sigma_sq: f64,
    /// Dirichlet concentration parameter α for document-topic prior.
    pub alpha: f64,
}

impl Default for DtmConfig {
    /// Returns the default configuration: 10 topics, `n_time_slices = 0`,
    /// at most 50 iterations, `sigma_sq = 0.5` and `alpha = 0.01`.
    fn default() -> Self {
        Self {
            n_topics: 10,
            n_time_slices: 0,
            max_iter: 50,
            sigma_sq: 0.5,
            alpha: 0.01,
        }
    }
}
88
89// ────────────────────────────────────────────────────────────────────────────
90// Result type
91// ────────────────────────────────────────────────────────────────────────────
92
/// Output of a fitted Dynamic Topic Model.
///
/// Results compare equal when both nested vectors are element-wise equal,
/// which makes it possible to check two fits against each other directly.
#[derive(Debug, Clone, PartialEq)]
pub struct DtmResult {
    /// Topic-word trajectories `K × T × V`.
    ///
    /// `trajectories[k][t][w]` is the probability of word `w` under topic `k`
    /// at time slice `t`.  Each slice `trajectories[k][t]` sums to 1.
    pub topic_word_trajectories: Vec<Vec<Vec<f64>>>,
    /// Flattened document-topic distribution (all documents across all time
    /// slices concatenated).  Each row sums to 1.
    pub doc_topic_matrix: Vec<Vec<f64>>,
}
105
106// ────────────────────────────────────────────────────────────────────────────
107// Model struct
108// ────────────────────────────────────────────────────────────────────────────
109
110/// Dynamic Topic Model estimator.
111///
112/// Fit via [`DynamicTopicModel::fit`]; the result is returned as a [`DtmResult`].
113pub struct DynamicTopicModel {
114    /// Model configuration.
115    pub config: DtmConfig,
116    /// Fitted result (populated after `fit_and_store`).
117    fitted: Option<DtmResult>,
118}
119
120impl DynamicTopicModel {
121    /// Construct a new (unfitted) DTM with the given configuration.
122    pub fn new(config: DtmConfig) -> Self {
123        Self {
124            config,
125            fitted: None,
126        }
127    }
128
129    /// Return a reference to the fitted result, if available.
130    pub fn fitted_result(&self) -> Option<&DtmResult> {
131        self.fitted.as_ref()
132    }
133
134    /// Fit the model and store the result internally, also returning it.
135    pub fn fit_and_store(
136        &mut self,
137        docs_by_time: &[Vec<Vec<f64>>],
138        vocab_size: usize,
139    ) -> Result<&DtmResult> {
140        let result = self.fit(docs_by_time, vocab_size)?;
141        self.fitted = Some(result);
142        Ok(self.fitted.as_ref().expect("just set"))
143    }
144
145    /// Return the top-`n` words for each topic at time `t` using the fitted model.
146    pub fn top_words_at(&self, t: usize, vocab: &[String], n: usize) -> Option<Vec<Vec<String>>> {
147        self.fitted
148            .as_ref()
149            .map(|r| top_words_at_time(&r.topic_word_trajectories, t, vocab, n))
150    }
151
152    /// Return the evolution of word `word_id` in topic `topic_id` over time.
153    pub fn word_evolution(&self, topic_id: usize, word_id: usize) -> Option<Vec<f64>> {
154        self.fitted
155            .as_ref()
156            .map(|r| topic_evolution(&r.topic_word_trajectories, topic_id, word_id))
157    }
158}
159
160impl Default for DynamicTopicModel {
161    fn default() -> Self {
162        Self::new(DtmConfig::default())
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration shared by the fitting tests; only the number of
    /// time slices differs between them.
    fn small_config(n_time_slices: usize) -> DtmConfig {
        DtmConfig {
            n_topics: 2,
            n_time_slices,
            max_iter: 3,
            sigma_sq: 0.1,
            alpha: 0.1,
        }
    }

    /// Deterministic toy corpus: entry `w` of document `d` in slice `t` is
    /// `(t + d + w) % 3`.
    fn toy_corpus(n_slices: usize, docs_per_slice: usize, doc_len: usize) -> Vec<Vec<Vec<f64>>> {
        let mut corpus = Vec::with_capacity(n_slices);
        for t in 0..n_slices {
            let mut slice = Vec::with_capacity(docs_per_slice);
            for d in 0..docs_per_slice {
                let doc: Vec<f64> = (0..doc_len).map(|w| ((t + d + w) % 3) as f64).collect();
                slice.push(doc);
            }
            corpus.push(slice);
        }
        corpus
    }

    #[test]
    fn dtm_default_config() {
        let DtmConfig {
            n_topics,
            max_iter,
            sigma_sq,
            alpha,
            ..
        } = DtmConfig::default();
        assert_eq!(n_topics, 10);
        assert_eq!(max_iter, 50);
        assert!((sigma_sq - 0.5).abs() < 1e-12);
        assert!((alpha - 0.01).abs() < 1e-12);
    }

    #[test]
    fn dtm_default_model() {
        let model = DynamicTopicModel::default();
        assert_eq!(model.config.n_topics, 10);
        assert!(model.fitted_result().is_none());
    }

    #[test]
    fn dtm_fit_and_store() {
        // 2 slices, 3 documents per slice, 4 entries per document.
        let corpus = toy_corpus(2, 3, 4);
        let mut model = DynamicTopicModel::new(small_config(2));
        model.fit_and_store(&corpus, 4).expect("fit failed");
        assert!(model.fitted_result().is_some());
    }

    #[test]
    fn dtm_top_words_at_after_fit() {
        // 2 slices, 3 documents per slice, 5 entries per document.
        let corpus = toy_corpus(2, 3, 5);
        let mut model = DynamicTopicModel::new(small_config(2));
        model.fit_and_store(&corpus, 5).expect("fit failed");

        let vocab: Vec<String> = (0..5).map(|i| format!("w{i}")).collect();
        let top = model.top_words_at(0, &vocab, 3).expect("no fitted result");
        assert_eq!(top.len(), 2); // K topics
        assert_eq!(top[0].len(), 3); // n words
    }

    #[test]
    fn dtm_word_evolution_length_equals_t() {
        // 4 slices, 2 documents per slice, 5 entries per document.
        let corpus = toy_corpus(4, 2, 5);
        let mut model = DynamicTopicModel::new(small_config(4));
        model.fit_and_store(&corpus, 5).expect("fit failed");

        let evolution = model.word_evolution(0, 2).expect("no fitted result");
        assert_eq!(evolution.len(), 4);
    }
}