use crate::matrix::ExprMatrix;
use crate::prep;
use crate::result::EnrichmentResult;
/// Tuning parameters for [`plage`].
///
/// Both bounds are forwarded unchanged to `prep::map_and_filter_sets`.
/// NOTE(review): whether they are inclusive, and whether they apply to the
/// set size before or after mapping onto the expression matrix, is decided
/// there — confirm before relying on it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PlageParams {
    /// Lower gene-set size bound passed to the size filter (default `1`).
    pub min_size: usize,
    /// Upper gene-set size bound passed to the size filter (default
    /// `usize::MAX`, i.e. no practical upper limit).
    pub max_size: usize,
}

impl Default for PlageParams {
    /// `min_size = 1`, `max_size = usize::MAX`.
    fn default() -> Self {
        Self {
            min_size: 1,
            max_size: usize::MAX,
        }
    }
}
/// Scores every sample against every retained gene set with PLAGE.
///
/// Steps, as implemented below:
/// 1. drop the rows that `prep::filter_constant_rows` rejects;
/// 2. map `gene_sets` onto the surviving rows with `prep::map_and_filter_sets`,
///    applying `params.min_size` / `params.max_size`;
/// 3. row-scale the surviving matrix with `prep::scale_rows`
///    (presumably per-gene z-scoring — confirm in `prep`);
/// 4. for each retained set, stack its scaled rows into a row-major
///    `k × n` block (`k` genes, `n` samples). The block's leading right
///    singular vector (length `n`, sign-normalised) becomes that set's
///    scores.
///
/// `scores` is filled in chunks of `n` values, one chunk per retained set, so
/// `scores[s * n + j]` is set `s`, sample `j` — assuming
/// `crate::par::fill_chunks_mut` hands the closure chunk index `s` together
/// with the `s`-th length-`n` slice.
pub fn plage(
    expr: &ExprMatrix,
    gene_sets: &crate::geneset::GeneSets,
    params: &PlageParams,
) -> EnrichmentResult {
    let kept = prep::filter_constant_rows(expr);
    let (set_names, set_rows) =
        prep::map_and_filter_sets(&kept, gene_sets, params.min_size, params.max_size);
    let scaled = prep::scale_rows(&kept);
    let n_samples = kept.ncol();

    let mut scores = vec![0.0f64; set_names.len() * n_samples];
    crate::par::fill_chunks_mut(&mut scores, n_samples, |set_idx, dest| {
        let rows = &set_rows[set_idx];
        // Row-major block: one scaled expression row per gene in the set.
        let mut block = Vec::with_capacity(rows.len() * n_samples);
        for &gene in rows {
            let start = gene * n_samples;
            block.extend_from_slice(&scaled[start..start + n_samples]);
        }
        let leading = leading_right_singular_vector(&block, rows.len(), n_samples);
        dest.copy_from_slice(&leading);
    });

    EnrichmentResult {
        gene_sets: set_names,
        samples: kept.col_names().to_vec(),
        scores,
    }
}
/// Leading right singular vector of the row-major `k × n` matrix `a`, using
/// faer's dense SVD. Compiled only with the `faer` feature.
///
/// Copies `a` into a faer matrix, decomposes it, reads column 0 of `V`, and
/// sign-normalises it with `canonicalize_sign`. Taking column 0 assumes faer
/// orders singular values from largest to smallest. NOTE(review): confirm
/// for the faer version in use.
///
/// # Panics
/// Panics if the SVD reports failure, or if `a.len() < k * n`.
#[cfg(feature = "faer")]
fn leading_right_singular_vector(a: &[f64], k: usize, n: usize) -> Vec<f64> {
    use faer::Mat;
    let dense = Mat::from_fn(k, n, |row, col| a[row * n + col]);
    let decomposition = dense.svd().expect("faer SVD failed to converge");
    let right = decomposition.V();
    let mut leading: Vec<f64> = Vec::with_capacity(n);
    for row in 0..n {
        leading.push(right[(row, 0)]);
    }
    canonicalize_sign(&mut leading);
    leading
}
/// Leading right singular vector of the row-major `k × n` matrix `a`, computed
/// with one-sided (Hestenes) Jacobi rotations. Compiled when the `faer`
/// feature is disabled.
///
/// Pairs of columns in a working copy of `a` are rotated until every pair is
/// numerically orthogonal, or until `MAX_SWEEPS` full passes have run. The
/// same rotations are accumulated into `v`, which starts as the `n × n`
/// identity.
///
/// After that, the working column with the largest squared norm belongs to
/// the largest singular value. The matching column of `v` is returned after
/// `canonicalize_sign`. On ties the lowest column index wins; an all-zero
/// input yields column 0 of the identity.
///
/// Cost is `O(sweeps · n² · k)` time and `O(n² + k·n)` extra memory.
///
/// # Panics
/// Panics on out-of-bounds indexing if `a.len() < k * n`.
#[cfg(not(feature = "faer"))]
fn leading_right_singular_vector(a: &[f64], k: usize, n: usize) -> Vec<f64> {
    // Upper bound on full passes over all column pairs.
    const MAX_SWEEPS: usize = 60;
    // A pair (p, q) counts as orthogonal once |<a_p, a_q>| <= TOL·‖a_p‖·‖a_q‖.
    const TOL: f64 = 1e-15;

    let mut acol = a.to_vec();
    // n × n identity: the diagonal sits at every (n + 1)-th element.
    let mut v = vec![0.0f64; n * n];
    for d in v.iter_mut().step_by(n + 1) {
        *d = 1.0;
    }

    for _ in 0..MAX_SWEEPS {
        let mut rotated = false;
        for p in 0..n {
            for q in (p + 1)..n {
                // Gram entries of the column pair: ‖a_p‖², ‖a_q‖², <a_p, a_q>.
                let mut alpha = 0.0f64;
                let mut beta = 0.0f64;
                let mut gamma = 0.0f64;
                for i in 0..k {
                    let ap = acol[i * n + p];
                    let aq = acol[i * n + q];
                    alpha += ap * ap;
                    beta += aq * aq;
                    gamma += ap * aq;
                }
                // Take the square roots separately. `(alpha * beta).sqrt()`
                // overflows to +inf for large columns, which makes every pair look
                // converged. It also underflows to 0 for tiny columns.
                if gamma == 0.0 || gamma.abs() <= TOL * (alpha.sqrt() * beta.sqrt()) {
                    continue;
                }
                rotated = true;
                // Rotation that zeroes the off-diagonal Gram entry. Choosing the
                // smaller-magnitude root of t² + 2ζt − 1 = 0 keeps the angle at
                // most π/4. `hypot` avoids overflowing ζ² when ζ is huge.
                let zeta = (beta - alpha) / (2.0 * gamma);
                let sign_zeta = if zeta >= 0.0 { 1.0 } else { -1.0 };
                let t = sign_zeta / (zeta.abs() + zeta.hypot(1.0));
                let c = 1.0 / (1.0 + t * t).sqrt();
                let s = c * t;
                for i in 0..k {
                    let ap = acol[i * n + p];
                    let aq = acol[i * n + q];
                    acol[i * n + p] = c * ap - s * aq;
                    acol[i * n + q] = s * ap + c * aq;
                }
                for i in 0..n {
                    let vp = v[i * n + p];
                    let vq = v[i * n + q];
                    v[i * n + p] = c * vp - s * vq;
                    v[i * n + q] = s * vp + c * vq;
                }
            }
        }
        if !rotated {
            break;
        }
    }

    // Working column with the largest squared norm. Strict `>` keeps the first
    // index on ties, and `-1.0` ensures column 0 is taken for an all-zero input.
    let mut best = 0usize;
    let mut best_norm = -1.0f64;
    for j in 0..n {
        let mut nrm = 0.0f64;
        for i in 0..k {
            let x = acol[i * n + j];
            nrm += x * x;
        }
        if nrm > best_norm {
            best_norm = nrm;
            best = j;
        }
    }

    let mut v1: Vec<f64> = (0..n).map(|i| v[i * n + best]).collect();
    canonicalize_sign(&mut v1);
    v1
}
/// Fixes the arbitrary sign of a singular vector in place.
///
/// If the entry with the largest magnitude is negative, every entry is
/// negated, so that entry ends up non-negative. When several entries share
/// the largest magnitude, the first one decides. NaN entries are never
/// selected. An empty slice, which arises when there are zero samples, is
/// left untouched instead of panicking.
fn canonicalize_sign(v1: &mut [f64]) {
    let mut imax = 0usize;
    let mut vmax = 0.0f64;
    for (i, &val) in v1.iter().enumerate() {
        if val.abs() > vmax {
            vmax = val.abs();
            imax = i;
        }
    }
    // `get` instead of indexing: `v1[0]` would panic on an empty slice.
    if matches!(v1.get(imax), Some(&x) if x < 0.0) {
        for x in v1.iter_mut() {
            *x = -*x;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major outer product `u wᵀ` (`u.len()` rows, `w.len()` columns).
    fn outer(u: &[f64], w: &[f64]) -> Vec<f64> {
        let mut a = Vec::with_capacity(u.len() * w.len());
        for &ui in u {
            a.extend(w.iter().map(|&wj| ui * wj));
        }
        a
    }

    /// Asserts `got` matches `want` entry-wise within `tol`.
    fn assert_close(got: &[f64], want: &[f64], tol: f64) {
        assert_eq!(got.len(), want.len(), "got={got:?} want={want:?}");
        for (g, w) in got.iter().zip(want) {
            assert!((g - w).abs() < tol, "got={got:?} want={want:?}");
        }
    }

    #[test]
    fn rank_one_right_singular_vector() {
        let w = [2.0f64, 1.0];
        let u = [1.0f64, 2.0, 3.0];
        let v1 = leading_right_singular_vector(&outer(&u, &w), u.len(), w.len());
        let wn = (w[0] * w[0] + w[1] * w[1]).sqrt();
        assert_close(&v1, &[w[0] / wn, w[1] / wn], 1e-12);
        let nrm = v1.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((nrm - 1.0).abs() < 1e-12);
    }

    #[test]
    fn sign_follows_largest_magnitude_entry() {
        // The leading direction is ±(-2, 1)/√5. Canonicalisation must return
        // the representative whose largest-magnitude entry is positive.
        let w = [-2.0f64, 1.0];
        let u = [1.0f64, 2.0, 3.0];
        let v1 = leading_right_singular_vector(&outer(&u, &w), u.len(), w.len());
        let wn = 5.0f64.sqrt();
        assert_close(&v1, &[2.0 / wn, -1.0 / wn], 1e-12);
    }

    #[test]
    fn orthogonal_columns_select_largest_norm() {
        // Columns (1, 0), (0, -3), (0, 0) are already orthogonal. The largest
        // singular value (3) belongs to column 1, so the answer is e_1.
        let a = [1.0, 0.0, 0.0, 0.0, -3.0, 0.0];
        let v1 = leading_right_singular_vector(&a, 2, 3);
        assert_close(&v1, &[0.0, 1.0, 0.0], 1e-12);
    }
}