/// Returns the dot product of `a` and `b`.
///
/// Debug builds assert that both slices have the same length. When that
/// assertion is compiled out, only the common prefix contributes because
/// `zip` stops at the shorter slice.
#[inline]
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let products = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
/// Returns the Euclidean (L2) norm of `v`: the square root of the sum of the
/// squared components.
#[inline]
pub fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|component| component * component).sum();
    sum_of_squares.sqrt()
}
/// Returns the element-wise difference `a - b` as a new vector.
///
/// The result has as many elements as the shorter of the two inputs; any
/// trailing elements of the longer slice are ignored.
#[inline]
pub fn vec_sub(a: &[f32], b: &[f32]) -> Vec<f32> {
    let mut difference = Vec::with_capacity(a.len().min(b.len()));
    for (lhs, rhs) in a.iter().zip(b) {
        difference.push(lhs - rhs);
    }
    difference
}
/// Returns the element-wise sum `a + b` as a new vector.
///
/// The result has as many elements as the shorter of the two inputs; any
/// trailing elements of the longer slice are ignored.
#[inline]
pub fn vec_add(a: &[f32], b: &[f32]) -> Vec<f32> {
    let mut total = Vec::with_capacity(a.len().min(b.len()));
    for (lhs, rhs) in a.iter().zip(b) {
        total.push(lhs + rhs);
    }
    total
}
/// Returns a copy of `v` with every component multiplied by the scalar `s`.
#[inline]
pub fn vec_scale(v: &[f32], s: f32) -> Vec<f32> {
    let mut scaled = v.to_vec();
    for component in &mut scaled {
        *component *= s;
    }
    scaled
}
/// Multiplies a row-major `rows x cols` matrix by the vector `v`.
///
/// `m_flat` holds the matrix one row after another, so row `r` occupies
/// `m_flat[r * cols..(r + 1) * cols]`. The result has one entry per row: the
/// dot product of that row with `v`.
///
/// Debug builds assert `m_flat.len() == rows * cols` and `v.len() == cols`.
///
/// # Panics
///
/// Panics if a row's slice range falls outside `m_flat`.
#[inline]
pub fn mat_vec_mul(m_flat: &[f32], rows: usize, cols: usize, v: &[f32]) -> Vec<f32> {
    debug_assert_eq!(m_flat.len(), rows * cols);
    debug_assert_eq!(v.len(), cols);
    let mut product = Vec::with_capacity(rows);
    for r in 0..rows {
        let start = r * cols;
        product.push(dot(&m_flat[start..start + cols], v));
    }
    product
}
/// Applies the soft-thresholding (shrinkage) operator to every component.
///
/// A component whose magnitude exceeds `threshold` is moved toward zero by
/// `threshold` while keeping its sign; every other component (including NaN,
/// which fails the comparison) becomes `0.0`.
#[inline]
pub fn soft_threshold(x: &[f32], threshold: f32) -> Vec<f32> {
    let mut shrunk = Vec::with_capacity(x.len());
    for &value in x {
        let magnitude = value.abs();
        let kept = if magnitude > threshold {
            value.signum() * (magnitude - threshold)
        } else {
            0.0
        };
        shrunk.push(kept);
    }
    shrunk
}
/// Approximates `query` as a sparse linear combination of `entities` using
/// accelerated proximal-gradient (FISTA-style) iterations.
///
/// Each iteration takes a gradient step on the least-squares term (gradient
/// `EᵀE·y − Eᵀq`), soft-thresholds the result by `lambda * step`, and then
/// extrapolates with the usual momentum sequence `t`.
///
/// # Arguments
///
/// * `query` - target vector.
/// * `entities` - dictionary vectors; each is dotted with `query`, so they are
///   expected to share its length (checked by `dot` in debug builds).
/// * `lambda` - L1 penalty weight.
/// * `max_iter` - number of iterations; there is no early exit.
///
/// # Returns
///
/// `(alpha, residual, residual_norm)` where `alpha` has one coefficient per
/// entity, `residual = query − Σ alpha[i]·entities[i]` (coefficients with
/// magnitude at most `1e-10` are skipped), and `residual_norm` is its L2 norm.
///
/// With no entities, or when the step-size bound is below `1e-10`, the
/// residual is `query` itself and `alpha` is empty or all zeros respectively.
///
/// # Panics
///
/// Panics during reconstruction if an entity is longer than `query`.
pub fn fista_solve(
    query: &[f32],
    entities: &[Vec<f32>],
    lambda: f32,
    max_iter: usize,
) -> (Vec<f32>, Vec<f32>, f32) {
    let m = entities.len();
    let d = query.len();
    if m == 0 {
        return (Vec::new(), query.to_vec(), l2_norm(query));
    }

    // Gram matrix EᵀE (row-major, m x m); only the upper triangle is computed
    // and then mirrored.
    let mut gram = vec![0.0f32; m * m];
    for (i, lhs) in entities.iter().enumerate() {
        for (j, rhs) in entities.iter().enumerate().skip(i) {
            let g = dot(lhs, rhs);
            gram[i * m + j] = g;
            gram[j * m + i] = g;
        }
    }
    let etq: Vec<f32> = entities.iter().map(|e| dot(e, query)).collect();

    // The largest absolute row sum bounds the largest eigenvalue of the Gram
    // matrix, so its inverse is a safe step size.
    let lip = gram
        .chunks(m)
        .map(|row| row.iter().map(|g| g.abs()).sum::<f32>())
        .fold(0.0f32, |best, row_sum| if row_sum > best { row_sum } else { best });
    if lip < 1e-10 {
        return (vec![0.0; m], query.to_vec(), l2_norm(query));
    }
    let step = 1.0 / lip;
    let shrink = lambda * step;

    let mut alpha = vec![0.0f32; m];
    let mut y = vec![0.0f32; m];
    let mut t = 1.0f32;
    for _ in 0..max_iter {
        // Gradient step at the extrapolated point: y − step·(EᵀE·y − Eᵀq).
        let stepped: Vec<f32> = (0..m)
            .map(|i| {
                let row = &gram[i * m..(i + 1) * m];
                y[i] - (dot(row, &y) - etq[i]) * step
            })
            .collect();
        let alpha_next = soft_threshold(&stepped, shrink);

        let t_next = (1.0 + (1.0 + 4.0 * t * t).sqrt()) / 2.0;
        let momentum = (t - 1.0) / t_next;
        for ((y_i, &fresh), &previous) in y.iter_mut().zip(&alpha_next).zip(&alpha) {
            *y_i = fresh + (fresh - previous) * momentum;
        }

        alpha = alpha_next;
        t = t_next;
    }

    let mut reconstruction = vec![0.0f32; d];
    for (&coeff, entity) in alpha.iter().zip(entities) {
        if coeff.abs() > 1e-10 {
            for (j, &component) in entity.iter().enumerate() {
                reconstruction[j] += coeff * component;
            }
        }
    }

    let residual = vec_sub(query, &reconstruction);
    let residual_norm = l2_norm(&residual);
    (alpha, residual, residual_norm)
}
/// Greedily selects up to `k` items that balance quality and diversity.
///
/// Items are scored with the kernel
/// `L[i][j] = q[i] * cos(vecs[i], vecs[j]) * q[j]`, where
/// `q[i] = max(scores[i], 1e-10).powf(quality_weight)` and the cosine comes
/// from L2-normalised copies of `vecs`. Each round picks the unselected item
/// with the largest remaining marginal gain, then updates every gain with an
/// incremental Cholesky-style row, so no full kernel matrix is built.
///
/// # Arguments
///
/// * `vecs` - item embeddings; must contain at least `scores.len()` entries.
/// * `scores` - per-item quality scores; their count defines the item count.
/// * `k` - maximum number of items to select.
/// * `quality_weight` - exponent applied to the clamped scores.
///
/// # Returns
///
/// Selected indices in selection order, without repeats. If there are at most
/// `k` items, all indices are returned in order. The result can be shorter
/// than `k`: selection stops after picking an item whose gain is below
/// `1e-10`, or when no remaining item has a comparable (non-NaN) gain.
///
/// # Panics
///
/// Panics if `vecs` has fewer entries than `scores` and a selection round runs.
pub fn dpp_greedy(
    vecs: &[Vec<f32>],
    scores: &[f32],
    k: usize,
    quality_weight: f32,
) -> Vec<usize> {
    let n = scores.len();
    if n <= k {
        return (0..n).collect();
    }

    let normed: Vec<Vec<f32>> = vecs
        .iter()
        .map(|v| {
            let inv = 1.0 / l2_norm(v).max(1e-10);
            v.iter().map(|&x| x * inv).collect()
        })
        .collect();
    let q: Vec<f32> = scores
        .iter()
        .map(|&s| s.max(1e-10).powf(quality_weight))
        .collect();

    // Remaining marginal gain per item; starts at the kernel diagonal plus a
    // small jitter.
    let mut gain: Vec<f32> = q.iter().map(|&qi| qi * qi + 1e-8).collect();
    // One Cholesky-style row per completed round.
    let mut rows: Vec<Vec<f32>> = Vec::with_capacity(k);
    let mut selected = Vec::with_capacity(k);
    // Membership mask so the candidate scan does not search `selected`.
    let mut taken = vec![false; n];

    for round in 0..k {
        // First unselected item with the strictly largest gain. NaN gains
        // never compare greater, so they are never candidates.
        let mut candidate: Option<usize> = None;
        let mut best_val = f32::NEG_INFINITY;
        for (i, &g) in gain.iter().enumerate() {
            if !taken[i] && g > best_val {
                best_val = g;
                candidate = Some(i);
            }
        }
        // Without a candidate there is nothing valid left to pick; stopping
        // here avoids re-emitting an index that was already selected.
        let best = match candidate {
            Some(i) => i,
            None => break,
        };
        selected.push(best);
        taken[best] = true;

        if round + 1 == k || gain[best] < 1e-10 {
            break;
        }

        // Kernel row of the chosen item, orthogonalised against earlier rows.
        let mut row: Vec<f32> = (0..n)
            .map(|i| q[best] * dot(&normed[best], &normed[i]) * q[i])
            .collect();
        for earlier in &rows {
            let pivot = earlier[best];
            for (r, &e) in row.iter_mut().zip(earlier) {
                *r -= pivot * e;
            }
        }
        let inv_sqrt = 1.0 / gain[best].sqrt();
        for r in &mut row {
            *r *= inv_sqrt;
        }

        for (g, &r) in gain.iter_mut().zip(&row) {
            *g -= r * r;
            // Clamp rounding noise; NaN is deliberately left as NaN.
            if *g < 0.0 {
                *g = 0.0;
            }
        }
        rows.push(row);
    }
    selected
}
/// Factorises `|V| ≈ W·H` with multiplicative updates.
///
/// `v_flat` is read as a row-major `m x d` matrix and its absolute values are
/// factorised into `W` (`m x k`, row-major) and `H` (`k x d`, row-major).
/// Each iteration applies `H ← H ∘ (WᵀV) / (WᵀWH + eps)` and then
/// `W ← W ∘ (VHᵀ) / (WHHᵀ + eps)`, using the freshly updated `H`.
///
/// Initial factors come from a fixed-seed linear congruential generator, so
/// the output is deterministic for a given input.
///
/// # Arguments
///
/// * `v_flat` - input matrix; entries beyond `v_flat.len()` are treated as 0.
/// * `m`, `d` - row and column count of the input.
/// * `k` - number of components.
/// * `max_iter` - upper bound on iterations.
/// * `tol` - relative Frobenius error below which iteration stops; checked
///   only on iterations 10, 20, 30, ...
///
/// # Returns
///
/// `(w, h)` with `w.len() == m * k` and `h.len() == k * d`.
///
/// # Panics
///
/// Panics if `v_flat` has more than `m * d` elements.
pub fn nmf_multiplicative_update(
    v_flat: &[f32],
    m: usize,
    d: usize,
    k: usize,
    max_iter: usize,
    tol: f32,
) -> (Vec<f32>, Vec<f32>) {
    let eps = 1e-10f32;

    let mut v_abs = vec![0.0f32; m * d];
    for (slot, value) in v_flat.iter().enumerate() {
        v_abs[slot] = value.abs();
    }

    // Scale the initial factors so that W·H starts near the mean of |V|.
    let v_mean = v_abs.iter().sum::<f32>() / (m * d) as f32;
    let avg = (v_mean / k as f32).sqrt().max(0.01);

    // Fixed-seed LCG; `W` is drawn first, then `H`.
    // NOTE(review): `state >> 33` keeps 31 bits but is divided by `u32::MAX`,
    // so `unit` only spans [0, 0.5] and the perturbation is never positive.
    // Left as is because changing it would alter every result.
    let mut state: u64 = 42;
    let mut draw = |count: usize| -> Vec<f32> {
        let mut values = Vec::with_capacity(count);
        for _ in 0..count {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            let unit = ((state >> 33) as f32) / (u32::MAX as f32);
            values.push((avg + avg * 0.5 * (unit - 0.5)).abs() + eps);
        }
        values
    };
    let mut w = draw(m * k);
    let mut h = draw(k * d);

    for iter in 0..max_iter {
        // ---- H update -------------------------------------------------
        // WᵀV (k x d).
        let mut wtv = Vec::with_capacity(k * d);
        for ki in 0..k {
            for di in 0..d {
                wtv.push((0..m).fold(0.0f32, |acc, mi| {
                    acc + w[mi * k + ki] * v_abs[mi * d + di]
                }));
            }
        }
        // WᵀW (k x k), symmetric: fill the upper triangle and mirror it.
        let mut wtw = vec![0.0f32; k * k];
        for i in 0..k {
            for j in i..k {
                let entry = (0..m).fold(0.0f32, |acc, mi| acc + w[mi * k + i] * w[mi * k + j]);
                wtw[i * k + j] = entry;
                wtw[j * k + i] = entry;
            }
        }
        // WᵀWH (k x d).
        let mut wtwh = Vec::with_capacity(k * d);
        for i in 0..k {
            for di in 0..d {
                wtwh.push((0..k).fold(0.0f32, |acc, j| acc + wtw[i * k + j] * h[j * d + di]));
            }
        }
        for ((h_entry, &numer), &denom) in h.iter_mut().zip(&wtv).zip(&wtwh) {
            *h_entry *= numer / (denom + eps);
        }

        // ---- W update (uses the H computed above) -----------------------
        // VHᵀ (m x k).
        let mut vht = Vec::with_capacity(m * k);
        for mi in 0..m {
            for ki in 0..k {
                vht.push((0..d).fold(0.0f32, |acc, di| {
                    acc + v_abs[mi * d + di] * h[ki * d + di]
                }));
            }
        }
        // HHᵀ (k x k), symmetric.
        let mut hht = vec![0.0f32; k * k];
        for i in 0..k {
            for j in i..k {
                let entry = (0..d).fold(0.0f32, |acc, di| acc + h[i * d + di] * h[j * d + di]);
                hht[i * k + j] = entry;
                hht[j * k + i] = entry;
            }
        }
        // WHHᵀ (m x k).
        let mut whht = Vec::with_capacity(m * k);
        for mi in 0..m {
            for ki in 0..k {
                whht.push((0..k).fold(0.0f32, |acc, j| acc + w[mi * k + j] * hht[j * k + ki]));
            }
        }
        for ((w_entry, &numer), &denom) in w.iter_mut().zip(&vht).zip(&whht) {
            *w_entry *= numer / (denom + eps);
        }

        // ---- Convergence check, every tenth iteration -------------------
        if iter > 0 && iter % 10 == 0 {
            let mut res_sq = 0.0f32;
            let mut v_sq = 0.0f32;
            for mi in 0..m {
                for di in 0..d {
                    let approx = (0..k).fold(0.0f32, |acc, ki| acc + w[mi * k + ki] * h[ki * d + di]);
                    let target = v_abs[mi * d + di];
                    let diff = target - approx;
                    res_sq += diff * diff;
                    v_sq += target * target;
                }
            }
            if v_sq > 0.0 && (res_sq / v_sq).sqrt() < tol {
                break;
            }
        }
    }

    (w, h)
}
/// Projects a query onto NMF topic rows and summarises the resulting topic
/// distribution.
///
/// The absolute value of `query` is dotted with each of the `k` rows of
/// `h_flat` (row-major, `k x d`), and the scores are turned into a
/// distribution with a max-shifted softmax.
///
/// # Returns
///
/// `(semantic_depth, topic_coverage, novelty, q_topics)`:
///
/// * `semantic_depth` - `1 − entropy / ln(k)` of the topic distribution (the
///   divisor is `1.0` when `k <= 1`), so a peaked distribution scores high.
/// * `topic_coverage` - number of topics whose probability exceeds `0.5 / k`.
/// * `novelty` - L2 distance between `|query|` and the probability-weighted
///   mix of topic rows, relative to the norm of `|query|`, capped at `1.0`.
/// * `q_topics` - the topic distribution itself (`k` entries).
///
/// # Panics
///
/// Panics if `h_flat` holds fewer than `k * d` elements.
pub fn nmf_analyze_query(
    query: &[f32],
    h_flat: &[f32],
    k: usize,
    d: usize,
) -> (f32, usize, f32, Vec<f32>) {
    let q_abs: Vec<f32> = query.iter().map(|x| x.abs()).collect();

    // Affinity of the query to each topic row.
    let raw_scores: Vec<f32> = (0..k)
        .map(|ki| dot(&q_abs, &h_flat[ki * d..(ki + 1) * d]))
        .collect();

    // Softmax, shifted by the maximum score for numerical stability.
    let mut peak = f32::NEG_INFINITY;
    for &score in &raw_scores {
        peak = peak.max(score);
    }
    let exp_scores: Vec<f32> = raw_scores.iter().map(|score| (score - peak).exp()).collect();
    let sum_exp: f32 = exp_scores.iter().sum();
    let normaliser = sum_exp + 1e-10;
    let q_topics: Vec<f32> = exp_scores.iter().map(|e| e / normaliser).collect();

    // Entropy of the topic distribution, normalised by its maximum ln(k).
    let entropy: f32 = q_topics
        .iter()
        .map(|&p| if p > 1e-10 { -p * p.ln() } else { 0.0 })
        .sum();
    let max_entropy = if k > 1 { (k as f32).ln() } else { 1.0 };
    let semantic_depth = 1.0 - entropy / max_entropy;

    let threshold = 0.5 / k as f32;
    let mut topic_coverage = 0usize;
    for &p in &q_topics {
        if p > threshold {
            topic_coverage += 1;
        }
    }

    // Probability-weighted mix of the topic rows; negligible weights are skipped.
    let mut q_recon = vec![0.0f32; d];
    for (ki, &weight) in q_topics.iter().enumerate() {
        if weight > 1e-10 {
            let topic_row = &h_flat[ki * d..(ki + 1) * d];
            for (acc, &component) in q_recon.iter_mut().zip(topic_row) {
                *acc += weight * component;
            }
        }
    }

    let diff = vec_sub(&q_abs, &q_recon);
    let q_norm = l2_norm(&q_abs).max(1e-10);
    let novelty = (l2_norm(&diff) / q_norm).min(1.0);

    (semantic_depth, topic_coverage, novelty, q_topics)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With two orthonormal entities the recovered coefficients are the
    /// query's first two components, shrunk slightly by the L1 penalty.
    #[test]
    fn test_fista_basic() {
        let query = vec![0.8, 0.6, 0.0, 0.0];
        let entities = vec![vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]];

        let (alpha, _residual, norm) = fista_solve(&query, &entities, 0.01, 100);

        assert!((alpha[0] - 0.79).abs() < 0.1, "alpha[0]={}", alpha[0]);
        assert!((alpha[1] - 0.59).abs() < 0.1, "alpha[1]={}", alpha[1]);
        assert!(norm < 0.1, "residual_norm={}", norm);
    }

    /// Items 0 and 1 are nearly parallel, so after picking item 0 the
    /// orthogonal item 2 must win over the higher-scored item 1.
    #[test]
    fn test_dpp_diversity() {
        let scores = vec![1.0, 0.9, 0.8];
        let vecs = vec![vec![1.0, 0.0], vec![0.99, 0.1], vec![0.0, 1.0]];

        let selected = dpp_greedy(&vecs, &scores, 2, 1.0);

        assert_eq!(selected.len(), 2);
        assert_eq!(selected[0], 0);
        assert_eq!(selected[1], 2);
    }

    /// Checks only the shapes of the returned factors for a 3x4 input
    /// factorised into 2 components.
    #[test]
    fn test_nmf() {
        #[rustfmt::skip]
        let v = vec![
            1.0, 0.0, 0.5, 0.0,
            0.0, 1.0, 0.0, 0.5,
            0.5, 0.5, 0.5, 0.5,
        ];

        let (w, h) = nmf_multiplicative_update(&v, 3, 4, 2, 100, 1e-3);

        assert_eq!(w.len(), 3 * 2);
        assert_eq!(h.len(), 2 * 4);
    }
}