use crate::ctm::model::softmax;
use crate::ctm::{CorrelatedTopicModel, CtmResult};
use crate::error::{Result, TextError};
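/// Inverts a symmetric positive-definite matrix via Cholesky decomposition.
///
/// Factors `A = L * L^T`, inverts the lower-triangular factor `L` by forward
/// substitution, and assembles `A^{-1} = L^{-T} * L^{-1}`. Returns `None` when
/// a non-positive pivot is encountered, i.e. when `A` is not positive definite.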
pub fn cholesky_inverse(a: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
let k = a.len();
let mut l = vec![vec![0.0_f64; k]; k];
for i in 0..k {
for j in 0..=i {
let mut sum = a[i][j];
for p in 0..j {
sum -= l[i][p] * l[j][p];
}
if i == j {
if sum <= 0.0 {
return None;
}
l[i][j] = sum.sqrt();
} else {
l[i][j] = sum / l[j][j];
}
}
}
let mut l_inv = vec![vec![0.0_f64; k]; k];
for i in 0..k {
l_inv[i][i] = 1.0 / l[i][i];
for j in 0..i {
let mut sum = 0.0_f64;
for p in j..i {
sum -= l[i][p] * l_inv[p][j];
}
l_inv[i][j] = sum / l[i][i];
}
}
let mut inv = vec![vec![0.0_f64; k]; k];
for i in 0..k {
for j in 0..k {
let mut s = 0.0_f64;
for p in 0..k {
s += l_inv[p][i] * l_inv[p][j];
}
inv[i][j] = s;
}
}
Some(inv)
}
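/// Adds `eps` to the diagonal of `sigma` so the covariance stays positive
/// definite and safely invertible.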
fn regularise_sigma(sigma: &mut [Vec<f64>], eps: f64) {
let k = sigma.len();
for i in 0..k {
sigma[i][i] += eps;
}
}
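/// Log-density of the logistic-normal prior at `eta`, up to an additive
/// constant: `-0.5 * (eta - mu)^T * Sigma^{-1} * (eta - mu)`. The
/// log-determinant normaliser is omitted because it is constant during the
/// per-document E-step.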
pub fn logistic_normal_ll(eta: &[f64], mu: &[f64], sigma_inv: &[Vec<f64>]) -> f64 {
let k = eta.len();
let mut ll = 0.0_f64;
for i in 0..k {
let di = eta[i] - mu[i];
for j in 0..k {
let dj = eta[j] - mu[j];
ll -= 0.5 * di * sigma_inv[i][j] * dj;
}
}
ll
}
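/// First-order approximation of `E[softmax(eta)]` under
/// `q(eta) = N(nu, diag(sigma2))`: the variances are ignored and the softmax
/// is evaluated at the mean.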
fn expected_theta(nu: &[f64], _sigma2: &[f64]) -> Vec<f64> {
softmax(nu)
}
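/// Variational E-step for a single document.
///
/// Runs `max_inner` rounds of damped Newton updates on the variational means
/// `nu`, refreshes the diagonal variances `sigma2` from the prior precision,
/// and returns an approximate per-document ELBO (reconstruction term,
/// logistic-normal prior term, and Gaussian entropy of the variational
/// posterior).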
pub fn e_step_doc(
doc_counts: &[f64],
nu: &mut [f64],
sigma2: &mut [f64],
mu: &[f64],
sigma_inv: &[Vec<f64>],
beta: &[Vec<f64>],
max_inner: usize,
) -> f64 {
let k = nu.len();
let vocab = doc_counts.len();
let n_words: f64 = doc_counts.iter().sum();
for _ in 0..max_inner {
// Point estimate of the topic proportions at the current variational mean.
let theta = expected_theta(nu, sigma2);
// Diagonal approximation: refresh each variational variance from the
// corresponding prior precision entry.
for t in 0..k {
let prec = sigma_inv[t][t].max(1e-10);
sigma2[t] = (1.0 / prec).max(1e-8);
}
for t in 0..k {
// Likelihood gradient wrt nu[t]: sum_w c_w * (phi_tw - theta_t),
// where phi_tw is topic t's responsibility for word w.
let mut grad = 0.0_f64;
for w in 0..vocab {
if doc_counts[w] <= 0.0 {
continue;
}
let mut mix = 0.0_f64;
for s in 0..k {
if s < beta.len() && w < beta[s].len() {
mix += theta[s] * beta[s][w];
}
}
if mix > 1e-15 {
let phi = if t < beta.len() && w < beta[t].len() {
theta[t] * beta[t][w] / mix
} else {
0.0
};
grad += doc_counts[w] * (phi - theta[t]);
}
}
// Prior gradient: -Sigma^{-1} * (nu - mu).
for j in 0..k {
grad -= sigma_inv[t][j] * (nu[j] - mu[j]);
}
// Diagonal Hessian approximation, kept strictly negative so the
// clamped Newton step always moves in the gradient direction.
let hess = -(n_words * theta[t] * (1.0 - theta[t]) + sigma_inv[t][t])
.abs()
.max(1e-10);
let step = (grad / hess).clamp(-2.0, 2.0);
nu[t] -= step;
}
}
let theta = expected_theta(nu, sigma2);
let mut elbo = 0.0_f64;
// Reconstruction term: sum_w c_w * ln(sum_t theta_t * beta_tw).
for w in 0..vocab {
if doc_counts[w] <= 0.0 {
continue;
}
let mut mix = 0.0_f64;
for t in 0..k {
if t < beta.len() && w < beta[t].len() {
mix += theta[t] * beta[t][w];
}
}
if mix > 0.0 {
elbo += doc_counts[w] * mix.ln();
}
}
// Logistic-normal prior term (up to an additive constant).
elbo += logistic_normal_ll(nu, mu, sigma_inv);
// Entropy of the Gaussian variational posterior: 0.5 * ln(2*pi*e*sigma^2)
// per topic. The previous form added an extra 1.0 inside the parentheses,
// double-counting the factor of e already inside the logarithm.
for t in 0..k {
elbo += 0.5 * (2.0 * std::f64::consts::PI * std::f64::consts::E * sigma2[t]).ln();
}
elbo
}
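/// Expected token-to-topic responsibilities: `phi[t][w]` is the count of
/// word `w` softly assigned to topic `t` under the current proportions
/// `theta` and topic-word distributions `beta`.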
fn compute_phi(doc_counts: &[f64], theta: &[f64], beta: &[Vec<f64>]) -> Vec<Vec<f64>> {
let k = theta.len();
let vocab = doc_counts.len();
let mut phi = vec![vec![0.0_f64; vocab]; k];
for w in 0..vocab {
if doc_counts[w] <= 0.0 {
continue;
}
let mut mix = 0.0_f64;
for t in 0..k {
if t < beta.len() && w < beta[t].len() {
mix += theta[t] * beta[t][w];
}
}
if mix < 1e-15 {
continue;
}
for t in 0..k {
if t < beta.len() && w < beta[t].len() {
phi[t][w] = doc_counts[w] * theta[t] * beta[t][w] / mix;
}
}
}
phi
}
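/// M-step over the whole corpus: re-estimates the prior mean `mu`, the prior
/// covariance `sigma` (empirical covariance of the variational means plus the
/// average variational variances on the diagonal), and the topic-word matrix
/// `beta` from accumulated responsibilities.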
pub fn m_step_global(
doc_counts_list: &[Vec<f64>],
nus: &[Vec<f64>],
sigma2s: &[Vec<f64>],
mu: &mut [f64],
sigma: &mut [Vec<f64>],
beta: &mut [Vec<f64>],
) {
let n_docs = nus.len();
let k = mu.len();
let vocab = beta[0].len();
if n_docs == 0 {
return;
}
// Prior mean: average of the per-document variational means.
for t in 0..k {
mu[t] = nus.iter().map(|nu| nu[t]).sum::<f64>() / n_docs as f64;
}
// Prior covariance: empirical covariance of the variational means,
// plus the average variational variance on the diagonal.
for i in 0..k {
for j in 0..k {
let cov = nus
.iter()
.map(|nu| (nu[i] - mu[i]) * (nu[j] - mu[j]))
.sum::<f64>()
/ n_docs as f64;
sigma[i][j] = cov;
}
let avg_s2 = sigma2s.iter().map(|s2| s2[i]).sum::<f64>() / n_docs as f64;
sigma[i][i] += avg_s2;
}
regularise_sigma(sigma, 1e-6);
// Accumulate expected word-topic counts, then normalise each topic row.
let mut beta_num = vec![vec![0.0_f64; vocab]; k];
for (d, doc_counts) in doc_counts_list.iter().enumerate() {
if d >= nus.len() {
break;
}
let theta = expected_theta(&nus[d], &sigma2s[d]);
let phi = compute_phi(doc_counts, &theta, beta);
for t in 0..k {
for w in 0..vocab {
beta_num[t][w] += phi[t][w];
}
}
}
for t in 0..k {
let row_sum: f64 = beta_num[t].iter().sum();
if row_sum > 1e-15 {
for w in 0..vocab {
beta[t][w] = (beta_num[t][w] / row_sum).max(1e-15);
}
} else {
let uniform = 1.0 / vocab as f64;
for w in 0..vocab {
beta[t][w] = uniform;
}
}
}
}
impl CorrelatedTopicModel {
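/// Fits the correlated topic model with variational EM.
///
/// Alternates per-document E-steps with a global M-step until the change in
/// the total ELBO falls below `config.tol` or `config.max_iter` is reached.
/// `beta` is initialised near-uniform with a small deterministic
/// perturbation to break symmetry between topics.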
pub fn fit(&self, doc_counts_list: &[Vec<f64>], vocab_size: usize) -> Result<CtmResult> {
let k = self.config.n_topics;
let n_docs = doc_counts_list.len();
if n_docs == 0 {
return Err(TextError::InvalidInput("Empty document collection".into()));
}
let v = if vocab_size > 0 {
vocab_size
} else {
doc_counts_list.iter().map(|d| d.len()).max().unwrap_or(1)
};
let mut mu = vec![0.0_f64; k];
let mut sigma: Vec<Vec<f64>> = (0..k)
.map(|i| (0..k).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
.collect();
let mut beta: Vec<Vec<f64>> = (0..k)
.map(|t| {
let mut row = vec![1.0_f64 / v as f64; v];
// Deterministic pseudo-noise breaks symmetry between topics
// without requiring a random-number generator.
for w in 0..v {
let noise = ((t * 1009 + w * 997) % 1000) as f64 * 1e-4;
row[w] += noise;
}
let s: f64 = row.iter().sum();
row.iter().map(|&x| x / s).collect()
})
.collect();
let mut nus: Vec<Vec<f64>> = (0..n_docs).map(|_| vec![0.0_f64; k]).collect();
let mut sigma2s: Vec<Vec<f64>> = (0..n_docs).map(|_| vec![1.0_f64; k]).collect();
// Fixed number of Newton updates per document in each E-step.
let inner_iters = 5_usize;
let mut prev_elbo = f64::NEG_INFINITY;
for _iter in 0..self.config.max_iter {
let sigma_inv_opt = cholesky_inverse(&sigma);
// Fall back to a diagonal inverse if the Cholesky factorisation fails.
let sigma_inv = sigma_inv_opt.unwrap_or_else(|| {
(0..k)
.map(|i| {
(0..k)
.map(|j| {
if i == j {
1.0 / sigma[i][i].max(1e-10)
} else {
0.0
}
})
.collect()
})
.collect()
});
let mut total_elbo = 0.0_f64;
for d in 0..n_docs {
let elbo = e_step_doc(
&doc_counts_list[d],
&mut nus[d],
&mut sigma2s[d],
&mu,
&sigma_inv,
&beta,
inner_iters,
);
total_elbo += elbo;
}
m_step_global(
doc_counts_list,
&nus,
&sigma2s,
&mut mu,
&mut sigma,
&mut beta,
);
if (total_elbo - prev_elbo).abs() < self.config.tol * (1.0 + total_elbo.abs()) {
break;
}
prev_elbo = total_elbo;
}
let doc_topic_matrix: Vec<Vec<f64>> = nus
.iter()
.zip(sigma2s.iter())
.map(|(nu, s2)| expected_theta(nu, s2))
.collect();
let log_likelihood: f64 = doc_counts_list
.iter()
.zip(doc_topic_matrix.iter())
.map(|(doc, theta)| crate::ctm::model::log_likelihood(doc, theta, &beta))
.sum();
Ok(CtmResult {
topic_word_matrix: beta,
doc_topic_matrix,
mu,
sigma,
log_likelihood,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ctm::{CorrelatedTopicModel, CtmConfig};
fn make_docs(n_docs: usize, vocab: usize) -> Vec<Vec<f64>> {
(0..n_docs)
.map(|d| (0..vocab).map(|w| ((d * 3 + w * 7) % 5) as f64).collect())
.collect()
}
#[test]
fn ctm_fit_returns_n_topics() {
let config = CtmConfig {
n_topics: 3,
max_iter: 10,
tol: 1e-3,
vocab_size: 8,
};
let model = CorrelatedTopicModel::new(config);
let docs = make_docs(6, 8);
let res = model.fit(&docs, 8).expect("fit failed");
assert_eq!(res.topic_word_matrix.len(), 3);
assert_eq!(res.doc_topic_matrix.len(), 6);
}
#[test]
fn ctm_fit_topics_sum_to_one() {
let config = CtmConfig {
n_topics: 2,
max_iter: 5,
tol: 1e-3,
vocab_size: 5,
};
let model = CorrelatedTopicModel::new(config);
let docs = make_docs(4, 5);
let res = model.fit(&docs, 5).expect("fit failed");
for (t, row) in res.topic_word_matrix.iter().enumerate() {
let s: f64 = row.iter().sum();
assert!((s - 1.0).abs() < 1e-6, "topic {t} word sum = {s}");
}
}
#[test]
fn ctm_doc_topic_rows_sum_to_one() {
let config = CtmConfig {
n_topics: 2,
max_iter: 5,
tol: 1e-3,
vocab_size: 5,
};
let model = CorrelatedTopicModel::new(config);
let docs = make_docs(4, 5);
let res = model.fit(&docs, 5).expect("fit failed");
for (d, row) in res.doc_topic_matrix.iter().enumerate() {
let s: f64 = row.iter().sum();
assert!((s - 1.0).abs() < 1e-6, "doc {d} topic sum = {s}");
}
}
#[test]
fn cholesky_inverse_identity() {
let a = vec![
vec![1.0_f64, 0.0, 0.0],
vec![0.0, 2.0, 0.0],
vec![0.0, 0.0, 3.0],
];
let inv = cholesky_inverse(&a).expect("inverse failed");
assert!((inv[0][0] - 1.0).abs() < 1e-10);
assert!((inv[1][1] - 0.5).abs() < 1e-10);
assert!((inv[2][2] - 1.0 / 3.0).abs() < 1e-10);
}
#[test]
fn ctm_log_likelihood_finite_across_iteration_counts() {
// The reported log-likelihood is recomputed from the point estimate of
// theta rather than the ELBO, so monotonicity is not guaranteed; assert
// the weaker property that it stays finite as the iteration budget grows.
let vocab = 6_usize;
let docs = make_docs(8, vocab);
for iters in (1..=10).step_by(2) {
let config = CtmConfig {
n_topics: 2,
max_iter: iters,
tol: 1e-12,
vocab_size: vocab,
};
let model = CorrelatedTopicModel::new(config);
let res = model.fit(&docs, vocab).expect("fit failed");
assert!(
res.log_likelihood.is_finite(),
"log-likelihood not finite after {iters} iters"
);
}
}
}