1use crate::dtm::model::normalise_to_simplex;
11use crate::dtm::{DtmConfig, DtmResult, DynamicTopicModel};
12use crate::error::{Result, TextError};
13
/// Forward (filtering) pass of a scalar Kalman filter for a Gaussian random walk.
///
/// State model: `x_s = x_{s-1} + N(0, sigma_sq)`; observation model:
/// `y_s = x_s + N(0, obs_noise)`. The prior for the first step is centred on the
/// first observation with variance `sigma_sq + obs_noise`.
///
/// Returns `(filter_means, filter_vars, pred_means, pred_vars)`, each with one entry
/// per observation (all empty when `observations` is empty). The predicted moments
/// are the inputs `kalman_backward` needs for smoothing.
///
/// The innovation variance `pred_var + obs_noise` is floored at `1e-15`, so
/// degenerate zero variances give a zero gain instead of `0/0 = NaN`.
pub fn kalman_forward(
    observations: &[f64],
    sigma_sq: f64,
    obs_noise: f64,
) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
    let t = observations.len();
    let mut filter_means = Vec::with_capacity(t);
    let mut filter_vars = Vec::with_capacity(t);
    let mut pred_means = Vec::with_capacity(t);
    let mut pred_vars = Vec::with_capacity(t);

    for (s, &y) in observations.iter().enumerate() {
        // Predict: the first step uses a prior centred on the first observation;
        // later steps propagate the previous filtered state through the random walk.
        let (m_pred, v_pred) = if s == 0 {
            (y, sigma_sq + obs_noise)
        } else {
            (filter_means[s - 1], filter_vars[s - 1] + sigma_sq)
        };

        // Update: floor the denominator so zero variances cannot produce NaN.
        let gain = v_pred / (v_pred + obs_noise).max(1e-15);

        pred_means.push(m_pred);
        pred_vars.push(v_pred);
        filter_means.push(m_pred + gain * (y - m_pred));
        filter_vars.push((1.0 - gain) * v_pred);
    }

    (filter_means, filter_vars, pred_means, pred_vars)
}
75
/// Backward (Rauch–Tung–Striebel) smoothing pass for the scalar random walk.
///
/// Consumes the outputs of `kalman_forward` and returns
/// `(smoother_means, smoother_vars)`, one entry per time step; both are empty when
/// `filter_means` is empty. Interior smoothed variances and the predicted variance
/// used as a divisor are floored at `1e-15`. The final step is copied from the
/// filter unchanged.
///
/// Panics (index out of bounds) if the other slices are shorter than `filter_means`.
pub fn kalman_backward(
    filter_means: &[f64],
    filter_vars: &[f64],
    pred_means: &[f64],
    pred_vars: &[f64],
    sigma_sq: f64,
) -> (Vec<f64>, Vec<f64>) {
    // The recursion never reads `sigma_sq`. When `pred_vars` comes from
    // `kalman_forward`, the process noise is already included in it.
    let _ = sigma_sq;

    let n = filter_means.len();
    if n == 0 {
        return (Vec::new(), Vec::new());
    }

    let mut means = vec![0.0_f64; n];
    let mut vars = vec![0.0_f64; n];
    let last = n - 1;
    means[last] = filter_means[last];
    vars[last] = filter_vars[last];

    // Walk backwards from the second-to-last step down to step 0.
    let mut s = last;
    while s > 0 {
        s -= 1;
        let next = s + 1;
        let p_next = pred_vars[next].max(1e-15);
        let gain = filter_vars[s] / p_next;
        means[s] = filter_means[s] + gain * (means[next] - pred_means[next]);
        let raw_var = filter_vars[s] + gain * gain * (vars[next] - p_next);
        vars[s] = raw_var.max(1e-15);
    }

    (means, vars)
}
122
/// Digamma function ψ(x) for positive `x`.
///
/// Uses the recurrence ψ(x) = ψ(x + 1) − 1/x to shift the argument to at least 6,
/// then a truncated asymptotic expansion. Non-positive `x` returns the sentinel
/// `-1e10` (whose `exp` underflows to `0.0`); NaN propagates as NaN.
fn digamma(x: f64) -> f64 {
    if x <= 0.0 {
        return -1e10;
    }

    // Accumulate the −1/z corrections while stepping z up into the asymptotic range.
    let mut shift = 0.0_f64;
    let mut z = x;
    while z < 6.0 {
        shift -= 1.0 / z;
        z += 1.0;
    }

    // ψ(z) ≈ ln z − 1/(2z) − 1/(12z²) + 1/(120z⁴) − 1/(252z⁶)
    let series = z.ln() - 0.5 / z - 1.0 / (12.0 * z * z) + 1.0 / (120.0 * z * z * z * z)
        - 1.0 / (252.0 * z * z * z * z * z * z);
    shift + series
}
143
144fn e_step_doc(
154 doc_counts: &[f64],
155 gamma: &mut [f64],
156 phi: &mut [Vec<f64>],
157 beta_t: &[Vec<f64>],
158 alpha: f64,
159 max_inner: usize,
160) {
161 let k = gamma.len();
162 let vocab = doc_counts.len();
163
164 for _ in 0..max_inner {
165 let dg: Vec<f64> = gamma.iter().map(|&g| digamma(g)).collect();
167 for w in 0..vocab {
168 if doc_counts[w] <= 0.0 {
169 continue;
170 }
171 let mut row_sum = 0.0_f64;
172 for t in 0..k {
173 let beta_val = if t < beta_t.len() && w < beta_t[t].len() {
174 beta_t[t][w].max(1e-15)
175 } else {
176 1e-15
177 };
178 phi[t][w] = beta_val * dg[t].exp();
179 row_sum += phi[t][w];
180 }
181 if row_sum > 1e-15 {
182 for t in 0..k {
183 phi[t][w] /= row_sum;
184 }
185 }
186 }
187
188 for t in 0..k {
190 let weighted: f64 = (0..vocab).map(|w| doc_counts[w] * phi[t][w]).sum();
191 gamma[t] = alpha + weighted;
192 }
193 }
194}
195
196impl DynamicTopicModel {
201 pub fn fit(&self, docs_by_time: &[Vec<Vec<f64>>], vocab_size: usize) -> Result<DtmResult> {
210 let n_time = docs_by_time.len();
211 if n_time == 0 {
212 return Err(TextError::InvalidInput(
213 "Empty time-slice collection".into(),
214 ));
215 }
216
217 let k = self.config.n_topics;
218 let v = if vocab_size > 0 {
219 vocab_size
220 } else {
221 docs_by_time
222 .iter()
223 .flat_map(|slice| slice.iter())
224 .map(|d| d.len())
225 .max()
226 .unwrap_or(1)
227 };
228 let sigma_sq = self.config.sigma_sq;
229 let alpha = self.config.alpha;
230 let obs_noise = sigma_sq * 0.1_f64; let mut trajectories: Vec<Vec<Vec<f64>>> = (0..k)
235 .map(|ki| {
236 (0..n_time)
237 .map(|ti| {
238 let mut row: Vec<f64> = (0..v)
239 .map(|wi| {
240 1.0 / v as f64
241 + ((ki * 1009 + ti * 997 + wi * 991) % 1000) as f64 * 1e-5
242 })
243 .collect();
244 normalise_to_simplex(&mut row);
245 row
246 })
247 .collect()
248 })
249 .collect();
250
251 let mut doc_gammas: Vec<Vec<Vec<f64>>> = docs_by_time
254 .iter()
255 .map(|slice| {
256 slice
257 .iter()
258 .map(|_| vec![alpha + 1.0_f64 / k as f64; k])
259 .collect::<Vec<_>>()
260 })
261 .collect();
262
263 for _iter in 0..self.config.max_iter {
264 let mut suff_stats: Vec<Vec<Vec<f64>>> = vec![vec![vec![0.0_f64; v]; n_time]; k];
268
269 for (ti, slice) in docs_by_time.iter().enumerate() {
270 let beta_t: Vec<Vec<f64>> = (0..k).map(|ki| trajectories[ki][ti].clone()).collect();
271
272 for (di, doc_counts) in slice.iter().enumerate() {
273 let mut phi = vec![vec![0.0_f64; v]; k];
274 e_step_doc(
275 doc_counts,
276 &mut doc_gammas[ti][di],
277 &mut phi,
278 &beta_t,
279 alpha,
280 5,
281 );
282 for ki in 0..k {
284 for w in 0..v {
285 suff_stats[ki][ti][w] +=
286 doc_counts.get(w).copied().unwrap_or(0.0) * phi[ki][w];
287 }
288 }
289 }
290 }
291
292 for ki in 0..k {
294 for w in 0..v {
295 let obs: Vec<f64> = (0..n_time)
297 .map(|ti| {
298 let total: f64 = (0..v).map(|ww| suff_stats[ki][ti][ww]).sum();
299 if total > 1e-15 {
300 (suff_stats[ki][ti][w] / total).max(1e-15)
301 } else {
302 1.0 / v as f64
303 }
304 })
305 .collect();
306
307 let (fm, fv, pm, pv) = kalman_forward(&obs, sigma_sq, obs_noise);
308 let (sm, _sv) = kalman_backward(&fm, &fv, &pm, &pv, sigma_sq);
309
310 for ti in 0..n_time {
311 trajectories[ki][ti][w] = sm[ti].max(1e-15);
312 }
313 }
314
315 for ti in 0..n_time {
317 normalise_to_simplex(&mut trajectories[ki][ti]);
318 }
319 }
320 }
321
322 let mut doc_topic_matrix: Vec<Vec<f64>> = Vec::new();
325 for slice_gammas in &doc_gammas {
326 for gamma in slice_gammas {
327 let s: f64 = gamma.iter().sum();
328 let theta: Vec<f64> = gamma.iter().map(|&g| g / s.max(1e-15)).collect();
329 doc_topic_matrix.push(theta);
330 }
331 }
332
333 Ok(DtmResult {
334 topic_word_trajectories: trajectories,
335 doc_topic_matrix,
336 })
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtm::{DtmConfig, DynamicTopicModel};

    /// Builds `n_docs` synthetic count vectors over `vocab` words; `seed` shifts the pattern.
    fn make_slice(n_docs: usize, vocab: usize, seed: usize) -> Vec<Vec<f64>> {
        let doc = |d: usize| -> Vec<f64> {
            (0..vocab)
                .map(|w| ((d * 3 + w * 7 + seed) % 5) as f64)
                .collect()
        };
        (0..n_docs).map(doc).collect()
    }

    /// Two-topic, three-slice model shared by the `fit` tests.
    fn two_topic_model(max_iter: usize) -> DynamicTopicModel {
        DynamicTopicModel::new(DtmConfig {
            n_topics: 2,
            n_time_slices: 3,
            max_iter,
            sigma_sq: 0.1,
            alpha: 0.1,
        })
    }

    /// Three time slices with `docs_per_slice` documents over a 5-word vocabulary.
    fn three_slices(docs_per_slice: usize) -> Vec<Vec<Vec<f64>>> {
        (0..3).map(|t| make_slice(docs_per_slice, 5, t)).collect()
    }

    #[test]
    fn kalman_forward_correct_shape() {
        let obs = [0.1_f64, 0.15, 0.12, 0.18, 0.14];
        let (fm, fv, pm, pv) = kalman_forward(&obs, 0.01, 0.001);
        for len in [fm.len(), fv.len(), pm.len(), pv.len()] {
            assert_eq!(len, 5);
        }
    }

    #[test]
    fn kalman_backward_smoother_variance_le_filter_variance() {
        let obs = [0.1_f64, 0.15, 0.12, 0.18, 0.14, 0.13];
        let (fm, fv, pm, pv) = kalman_forward(&obs, 0.01, 0.001);
        let (_, sv) = kalman_backward(&fm, &fv, &pm, &pv, 0.01);
        for (i, (sv_i, fv_i)) in sv.into_iter().zip(fv).enumerate() {
            assert!(
                sv_i <= fv_i + 1e-10,
                "smoother_var[{i}]={sv_i} > filter_var[{i}]={fv_i}"
            );
        }
    }

    #[test]
    fn kalman_roundtrip_recovers_trajectory() {
        let truth = 0.2_f64;
        let obs = vec![truth; 10];
        let (fm, fv, pm, pv) = kalman_forward(&obs, 1e-4, 1e-3);
        let (sm, _) = kalman_backward(&fm, &fv, &pm, &pv, 1e-4);
        for (i, m) in sm.into_iter().enumerate() {
            assert!((m - truth).abs() < 0.05, "smoother[{i}]={m}, truth={truth}");
        }
    }

    #[test]
    fn dtm_fit_trajectories_shape() {
        let res = two_topic_model(5)
            .fit(&three_slices(4), 5)
            .expect("fit failed");
        let trajectories = &res.topic_word_trajectories;
        assert_eq!(trajectories.len(), 2);
        assert_eq!(trajectories[0].len(), 3);
        assert_eq!(trajectories[0][0].len(), 5);
    }

    #[test]
    fn dtm_fit_doc_topic_rows_sum_to_one() {
        let res = two_topic_model(3)
            .fit(&three_slices(3), 5)
            .expect("fit failed");
        for (d, row) in res.doc_topic_matrix.iter().enumerate() {
            let s = row.iter().sum::<f64>();
            assert!((s - 1.0).abs() < 1e-6, "doc {d} topic sum = {s}");
        }
    }

    #[test]
    fn dtm_fit_trajectories_row_sums_to_one() {
        let res = two_topic_model(3)
            .fit(&three_slices(3), 5)
            .expect("fit failed");
        for (ki, topic_traj) in res.topic_word_trajectories.iter().enumerate() {
            for (ti, row) in topic_traj.iter().enumerate() {
                let s = row.iter().sum::<f64>();
                assert!(
                    (s - 1.0).abs() < 1e-4,
                    "topic {ki} time {ti} word sum = {s}"
                );
            }
        }
    }
}
453}