use crate::Phi;
use crate::embed::l2_normalize_cols;
use ndarray::{Array2, ArrayView1, ArrayView2, Axis};
use rand::seq::SliceRandom;
use rand_chacha::ChaCha8Rng;
/// Pairwise cluster-to-cell distances `2 * (1 - y_k · z_i)`.
///
/// Column `k` of `y` is compared with column `i` of `z_cos` by dot product. When both
/// columns have unit L2 norm, this quantity equals their squared Euclidean distance.
///
/// Returns an array of shape `(y.ncols(), z_cos.ncols())`: one row per column of `y`,
/// one column per column of `z_cos`.
///
/// # Panics
/// If `y` and `z_cos` have different numbers of rows (the column dot products then
/// have mismatched lengths).
pub fn dist_mat(y: ArrayView2<'_, f64>, z_cos: ArrayView2<'_, f64>) -> Array2<f64> {
    let shape = (y.ncols(), z_cos.ncols());
    Array2::from_shape_fn(shape, |(centroid, cell)| {
        let similarity = y.column(centroid).dot(&z_cos.column(cell));
        2.0 * (1.0 - similarity)
    })
}
/// Recomputes centroids as assignment-weighted sums of `z_cos` columns, then L2-normalises them.
///
/// Entry `[row, cluster]` of the unnormalised result is
/// `Σ_i z_cos[row, i] * r[cluster, i]`, summed over the `z_cos.ncols()` cells in index
/// order. In other words, it is `z_cos · rᵀ` with shape `(z_cos.nrows(), r.nrows())`.
/// The columns are then rescaled by `l2_normalize_cols`.
///
/// # Panics
/// If `r` has fewer columns than `z_cos` (out-of-bounds indexing).
pub fn update_y(z_cos: ArrayView2<'_, f64>, r: ArrayView2<'_, f64>) -> Array2<f64> {
    let n_cells = z_cos.ncols();
    let weighted: Array2<f64> = Array2::from_shape_fn(
        (z_cos.nrows(), r.nrows()),
        |(row, cluster)| {
            // Sequential left-to-right fold from 0.0 keeps the exact summation order.
            (0..n_cells).fold(0.0, |acc, cell| acc + z_cos[[row, cell]] * r[[cluster, cell]])
        },
    );
    l2_normalize_cols(weighted.view())
}
/// Rescales each column of `m` in place so that its entries sum to one.
///
/// A column whose sum is not strictly positive is left untouched. This covers a sum of
/// zero, a negative sum, and NaN, and it avoids dividing by zero.
pub fn normalise_cols_l1(m: &mut Array2<f64>) {
    m.axis_iter_mut(Axis(1))
        .map(|column| {
            let total: f64 = column.iter().sum();
            (column, total)
        })
        .filter(|(_, total)| *total > 0.0)
        .for_each(|(mut column, total)| {
            column.iter_mut().for_each(|value| *value /= total);
        });
}
/// Recomputes the soft assignments `r` in randomly ordered blocks of cells, keeping the
/// observed (`o`) and expected (`e`) cluster-by-batch-column tallies in sync.
///
/// For cell `i` and cluster `kk`, the new (un-normalised) assignment is
///
/// `scale_dist[kk, i] * Π_c (e[kk, b_c] / (o[kk, b_c] + e[kk, b_c]))^theta[c]`
///
/// Here `b_c = phi.row_of_cell[c * phi.n + i]` is the column of `o`/`e` that cell `i`
/// maps to for covariate `c`. A factor whose base is not strictly positive zeroes the
/// product; this includes NaN from `0 / 0`. Each recomputed cell column of `r` is then
/// L1-normalised, and a column summing to zero is left as-is.
///
/// Cells are visited in an order shuffled with `rng`. For every block:
/// 1. the block's current contributions are subtracted from `o` and `e`;
/// 2. the block's assignments are recomputed against those reduced tallies;
/// 3. the new contributions are added back.
///
/// # Arguments
/// * `r` – assignments indexed `r[[cluster, cell]]`; updated in place.
/// * `o` – observed tallies indexed `[[cluster, batch column]]`. For each covariate, a
///   cell adds `r[[kk, i]]` to column `b_c`. Updated in place.
/// * `e` – expected tallies, same indexing as `o`. A cell adds `r[[kk, i]] * pr_b[b]` to
///   every column `b < o.ncols()`. Updated in place.
/// * `dist_mat` – cluster-by-cell distances; its shape defines `k` (rows) and `n` (cells).
/// * `phi` – flattened per-covariate column index of each cell (`n_cov`, `n`, `row_of_cell`).
/// * `pr_b` – per-column weight for `e`; needs at least `o.ncols()` entries.
/// * `sigma` – per-cluster bandwidth dividing the distances.
/// * `theta` – per-covariate exponent.
/// * `block_size` – fraction of cells per block. Each block holds `floor(n * block_size)`
///   cells, and the last block absorbs the remainder. If that count would be zero (for
///   `block_size` ≤ 0, NaN, or below `1 / n`), a single block of all cells is used.
///   Previously a zero `block_size` looped `usize::MAX` times.
/// * `rng` – source of the cell visiting order.
///
/// # Returns
/// `exp(-dist_mat[kk, i] / sigma[kk])`, L1-normalised per cell column.
///
/// # Panics
/// On out-of-bounds indexing if the shapes described above are inconsistent.
#[allow(clippy::too_many_arguments)]
pub fn update_r_block(
    r: &mut Array2<f64>,
    o: &mut Array2<f64>,
    e: &mut Array2<f64>,
    dist_mat: ArrayView2<'_, f64>,
    phi: &Phi,
    pr_b: ArrayView1<'_, f64>,
    sigma: ArrayView1<'_, f64>,
    theta: ArrayView1<'_, f64>,
    block_size: f64,
    rng: &mut ChaCha8Rng,
) -> Array2<f64> {
    let (k, n) = dist_mat.dim();

    // Kernelised distances, rescaled so each cell's column sums to one.
    let mut scale_dist = Array2::<f64>::zeros((k, n));
    for kk in 0..k {
        for i in 0..n {
            scale_dist[[kk, i]] = (-dist_mat[[kk, i]] / sigma[kk]).exp();
        }
    }
    normalise_cols_l1(&mut scale_dist);

    let mut order: Vec<usize> = (0..n).collect();
    order.shuffle(rng);

    let (cells_per_block, n_blocks) = block_layout(n, block_size);
    for block_i in 0..n_blocks {
        let idx_min = block_i * cells_per_block;
        if idx_min >= n {
            break;
        }
        let idx_max = if block_i == n_blocks - 1 {
            n
        } else {
            ((block_i + 1) * cells_per_block).min(n)
        };
        let slice = &order[idx_min..idx_max];

        // Step 1: take the block's current assignments out of the tallies.
        for &i in slice {
            for kk in 0..k {
                let r_ki = r[[kk, i]];
                for b in 0..o.ncols() {
                    e[[kk, b]] -= r_ki * pr_b[b];
                }
                for c in 0..phi.n_cov {
                    let b = phi.row_of_cell[c * phi.n + i] as usize;
                    o[[kk, b]] -= r_ki;
                }
            }
        }

        // Step 2: recompute assignments against the tallies of all other cells.
        for &i in slice {
            for kk in 0..k {
                let mut prod = 1.0;
                for c in 0..phi.n_cov {
                    let b = phi.row_of_cell[c * phi.n + i] as usize;
                    let ratio = e[[kk, b]] / (o[[kk, b]] + e[[kk, b]]);
                    // Non-positive or NaN (0 / 0) ratios cannot be raised to a real power;
                    // treat them as excluding the cluster.
                    if ratio > 0.0 {
                        prod *= ratio.powf(theta[c]);
                    } else {
                        prod = 0.0;
                    }
                }
                r[[kk, i]] = scale_dist[[kk, i]] * prod;
            }
        }

        // Step 3: L1-normalise each recomputed cell column.
        for &i in slice {
            let s: f64 = (0..k).map(|kk| r[[kk, i]]).sum();
            if s > 0.0 {
                for kk in 0..k {
                    r[[kk, i]] /= s;
                }
            }
        }

        // Step 4: put the block's new assignments back into the tallies.
        for &i in slice {
            for kk in 0..k {
                let r_ki = r[[kk, i]];
                for b in 0..o.ncols() {
                    e[[kk, b]] += r_ki * pr_b[b];
                }
                for c in 0..phi.n_cov {
                    let b = phi.row_of_cell[c * phi.n + i] as usize;
                    o[[kk, b]] += r_ki;
                }
            }
        }
    }
    scale_dist
}

/// Returns `(cells_per_block, n_blocks)` for partitioning `n` shuffled cells.
///
/// Normally there are `ceil(1 / block_size)` blocks of `floor(n * block_size)` cells, and
/// the caller lets the last block absorb the remainder. If a block would hold zero cells,
/// a single all-cell block is returned instead. In that case the old layout produced only
/// empty (no-op) blocks followed by one all-cell block, so results are unchanged, but the
/// old layout spun through up to `usize::MAX` empty iterations when `block_size == 0`.
fn block_layout(n: usize, block_size: f64) -> (usize, usize) {
    // Float-to-int `as` saturates: NaN/negative -> 0, +inf -> usize::MAX.
    let cells_per_block = ((n as f64) * block_size).floor() as usize;
    if cells_per_block == 0 {
        return (n, 1);
    }
    let n_blocks = ((1.0 / block_size).ceil() as usize).max(1);
    (cells_per_block, n_blocks)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the distance, centroid and normalisation kernels.
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn dist_mat_unit_cosine() {
        let y = array![[1.0, 0.0], [0.0, 1.0]];
        let z_cos = array![[1.0, 0.0], [0.0, 1.0]];
        let d = dist_mat(y.view(), z_cos.view());
        assert_eq!(d.dim(), (2, 2));
        assert_abs_diff_eq!(d[[0, 0]], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(d[[1, 1]], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(d[[0, 1]], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(d[[1, 0]], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn dist_mat_is_clusters_by_cells() {
        // One centroid against three cells: identical, orthogonal, opposite.
        let y = array![[1.0], [0.0]];
        let z_cos = array![[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]];
        let d = dist_mat(y.view(), z_cos.view());
        assert_eq!(d.dim(), (1, 3));
        assert_abs_diff_eq!(d[[0, 0]], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(d[[0, 1]], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(d[[0, 2]], 4.0, epsilon = 1e-12);
    }

    #[test]
    fn update_y_preserves_unit_norm() {
        let z = array![[0.6, 0.8, 0.0], [0.8, -0.6, 1.0]];
        let r = array![[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]];
        let y = update_y(z.view(), r.view());
        for k in 0..y.ncols() {
            let n: f64 = y.column(k).iter().map(|v| v * v).sum::<f64>().sqrt();
            assert_abs_diff_eq!(n, 1.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn update_y_points_along_weighted_sum() {
        // Cluster 0 owns cell 0 only; cluster 1 owns cells 1 and 2, whose sum is (0.8, 0.4).
        let z = array![[0.6, 0.8, 0.0], [0.8, -0.6, 1.0]];
        let r = array![[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]];
        let y = update_y(z.view(), r.view());
        assert_eq!(y.dim(), (2, 2));
        assert_abs_diff_eq!(y[[0, 0]], 0.6, epsilon = 1e-9);
        assert_abs_diff_eq!(y[[1, 0]], 0.8, epsilon = 1e-9);
        assert_abs_diff_eq!(y[[0, 1]], 2.0 / 5f64.sqrt(), epsilon = 1e-9);
        assert_abs_diff_eq!(y[[1, 1]], 1.0 / 5f64.sqrt(), epsilon = 1e-9);
    }

    #[test]
    fn normalise_cols_l1_sums_to_one() {
        let mut m = array![[1.0, 2.0], [3.0, 0.0]];
        normalise_cols_l1(&mut m);
        assert_abs_diff_eq!(m[[0, 0]] + m[[1, 0]], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[0, 1]] + m[[1, 1]], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[0, 0]], 0.25, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[1, 0]], 0.75, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[0, 1]], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[1, 1]], 0.0, epsilon = 1e-12);
    }

    #[test]
    fn normalise_cols_l1_leaves_zero_columns_untouched() {
        let mut m = array![[0.0, 2.0], [0.0, 6.0]];
        normalise_cols_l1(&mut m);
        assert_eq!(m[[0, 0]], 0.0);
        assert_eq!(m[[1, 0]], 0.0);
        assert_abs_diff_eq!(m[[0, 1]], 0.25, epsilon = 1e-12);
        assert_abs_diff_eq!(m[[1, 1]], 0.75, epsilon = 1e-12);
    }
}