use gsva::EnrichmentResult;
use crate::data::SpillModel;
/// Corrects enrichment scores for spillover between gene sets.
///
/// Only gene sets for which `spill.k.row_of` returns an index are kept, in
/// their original order; all other gene sets are dropped from the result.
///
/// For the kept sets a square system matrix is built:
/// - `1.0` on the diagonal;
/// - `alpha * spill.k.get(i, j)` off the diagonal.
///
/// Each sample's column of kept scores is solved against that matrix with
/// `nnls`, and any negative component of the solution is replaced by `0.0`.
///
/// Scores are read from and written to row-major storage indexed as
/// `gene_set * n_samples + sample`. The sample list is copied unchanged.
/// Per-sample solves are dispatched through `crate::par::map_collect`
/// (presumably parallel — confirm in `crate::par`).
pub fn spill_over(
    transformed: &EnrichmentResult,
    spill: &SpillModel,
    alpha: f64,
) -> EnrichmentResult {
    let n_samples = transformed.samples.len();

    // One entry per covered gene set:
    // (name, row in `transformed.scores`, index into `spill.k`).
    let kept: Vec<_> = transformed
        .gene_sets
        .iter()
        .enumerate()
        .filter_map(|(score_row, name)| {
            spill.k.row_of(name).map(|k_index| (name, score_row, k_index))
        })
        .collect();
    let gene_sets: Vec<_> = kept.iter().map(|&(name, _, _)| name.clone()).collect();
    let score_rows: Vec<usize> = kept.iter().map(|&(_, row, _)| row).collect();
    let k_index: Vec<_> = kept.iter().map(|&(_, _, k)| k).collect();
    let dim = gene_sets.len();

    // Row-major dim x dim system: identity diagonal, alpha-scaled spill
    // coefficients everywhere else.
    let mut system = vec![0.0f64; dim * dim];
    for (cell, entry) in system.iter_mut().enumerate() {
        let (row, col) = (cell / dim, cell % dim);
        *entry = if row == col {
            1.0
        } else {
            spill.k.get(k_index[row], k_index[col]) * alpha
        };
    }

    // Solve every sample independently; negatives (if any) are clamped to 0.
    let solutions: Vec<Vec<f64>> = crate::par::map_collect(n_samples, |sample| {
        let column: Vec<f64> = score_rows
            .iter()
            .map(|&row| transformed.scores[row * n_samples + sample])
            .collect();
        nnls(&system, dim, dim, &column)
            .into_iter()
            .map(|v| if v < 0.0 { 0.0 } else { v })
            .collect()
    });

    // Scatter the per-sample columns back into gene-set-major storage.
    let mut scores = vec![0.0f64; dim * n_samples];
    for (sample, solution) in solutions.iter().enumerate() {
        solution
            .iter()
            .enumerate()
            .for_each(|(row, &value)| scores[row * n_samples + sample] = value);
    }

    EnrichmentResult {
        gene_sets,
        samples: transformed.samples.clone(),
        scores,
    }
}
/// Solves `min ‖A·x − b‖₂` subject to `x ≥ 0` with the Lawson–Hanson
/// active-set method, working on the normal equations.
///
/// * `a` — the `m × n` matrix `A` in row-major order (`a[row * n + col]`).
/// * `m`, `n` — row and column counts of `A`.
/// * `b` — right-hand side, length `m`.
///
/// Returns `x` of length `n`. `G = AᵀA` and `c = Aᵀb` are formed once. Each
/// passive-set subproblem is solved on the matching principal submatrix of `G`
/// with `solve_spd_subset`.
///
/// The outer loop ends when no inactive coordinate has a component of
/// `w = c − G·x` above `1e-12 · max(1, max|c|)`. It also ends after `3n + 1`
/// coordinates have been added to the passive set, as a guard against cycling.
fn nnls(a: &[f64], m: usize, n: usize, b: &[f64]) -> Vec<f64> {
    // Normal equations: G = AᵀA (n × n, row-major) and c = Aᵀb.
    let mut g = vec![0.0f64; n * n];
    for i in 0..n {
        for j in 0..n {
            g[i * n + j] = (0..m).fold(0.0, |acc, k| acc + a[k * n + i] * a[k * n + j]);
        }
    }
    let c: Vec<f64> = (0..n)
        .map(|i| (0..m).fold(0.0, |acc, k| acc + a[k * n + i] * b[k]))
        .collect();

    let mut x = vec![0.0f64; n];
    let mut passive = vec![false; n];
    // w = c − G·x is the negative gradient of ½‖Ax − b‖²; it equals c while x = 0.
    let mut w = c.clone();
    let c_inf = c.iter().fold(0.0f64, |mx, &v| mx.max(v.abs()));
    let tol = 1e-12 * c_inf.max(1.0);
    let max_outer = 3 * n + 1;
    let mut outer = 0;
    loop {
        // Choose the inactive coordinate with the largest w above `tol`; stop if none.
        let mut chosen = None;
        let mut best = tol;
        for j in 0..n {
            if !passive[j] && w[j] > best {
                best = w[j];
                chosen = Some(j);
            }
        }
        let Some(t) = chosen else { break };
        if outer >= max_outer {
            break;
        }
        outer += 1;
        passive[t] = true;

        // Each inner pass either accepts the unconstrained passive-set solution, or
        // releases at least one coordinate (the blocking one) from the passive set.
        // So the loop runs at most n + 1 times.
        loop {
            let idx: Vec<usize> = (0..n).filter(|&j| passive[j]).collect();
            let s_p = solve_spd_subset(&g, n, &idx, &c);
            let mut s = vec![0.0f64; n];
            let mut min_s = f64::INFINITY;
            for (k, &j) in idx.iter().enumerate() {
                s[j] = s_p[k];
                if s_p[k] < min_s {
                    min_s = s_p[k];
                }
            }
            if min_s > 0.0 {
                x = s;
                break;
            }

            // Largest step from x toward s that keeps every passive coordinate >= 0.
            // `blocking` is the coordinate that reaches zero first.
            let mut step = f64::INFINITY;
            let mut blocking = None;
            for &j in &idx {
                if s[j] <= 0.0 {
                    let denom = x[j] - s[j];
                    if denom > 0.0 {
                        let ratio = x[j] / denom;
                        if ratio < step {
                            step = ratio;
                            blocking = Some(j);
                        }
                    }
                }
            }
            let Some(blocking) = blocking else {
                // No coordinate limits the step: accept s as is.
                x = s;
                break;
            };

            for (xj, &sj) in x.iter_mut().zip(&s) {
                *xj += step * (sj - *xj);
            }
            // In exact arithmetic x[blocking] is now 0. Rounding in the update above
            // can leave it marginally positive, which would keep it passive and stall
            // the inner loop. Pin it to 0 so it is always released below.
            x[blocking] = 0.0;
            for j in 0..n {
                if passive[j] && x[j] <= 0.0 {
                    passive[j] = false;
                    x[j] = 0.0;
                }
            }
        }

        // Refresh w = c − G·x for the next selection.
        for i in 0..n {
            let gx = (0..n).fold(0.0, |acc, j| acc + g[i * n + j] * x[j]);
            w[i] = c[i] - gx;
        }
    }
    x
}
/// Solves `M·z = r` for a principal submatrix of `g`.
///
/// `M` is the submatrix of the `n × n` row-major matrix `g` selected by the
/// rows and columns listed in `idx`; `r` holds `rhs[i]` for each `i` in `idx`.
///
/// Uses Gaussian elimination with partial pivoting followed by back
/// substitution. Symmetry or definiteness of `M` is neither assumed nor checked.
///
/// A column whose pivot is exactly `0.0` is skipped, and the matching unknown
/// is set to `0.0` during back substitution. An exactly singular system
/// therefore yields zeros for its degenerate unknowns instead of dividing by
/// zero.
///
/// Returns `z` with one entry per element of `idx`, in `idx` order.
fn solve_spd_subset(g: &[f64], n: usize, idx: &[usize], rhs: &[f64]) -> Vec<f64> {
    let dim = idx.len();
    let mut mat: Vec<f64> = idx
        .iter()
        .flat_map(|&r| idx.iter().map(move |&c| g[r * n + c]))
        .collect();
    let mut y: Vec<f64> = idx.iter().map(|&r| rhs[r]).collect();

    // Forward elimination to upper-triangular form.
    for p in 0..dim {
        // Partial pivoting: the first row (from p down) holding the largest
        // |entry| in column p.
        let best = ((p + 1)..dim).fold(p, |best, r| {
            if mat[r * dim + p].abs() > mat[best * dim + p].abs() {
                r
            } else {
                best
            }
        });
        if best != p {
            let (upper, lower) = mat.split_at_mut(best * dim);
            upper[p * dim..(p + 1) * dim].swap_with_slice(&mut lower[..dim]);
            y.swap(p, best);
        }

        let diag = mat[p * dim + p];
        if diag == 0.0 {
            // Largest remaining |entry| is zero: nothing to eliminate here.
            continue;
        }
        let y_pivot = y[p];
        let (head, tail) = mat.split_at_mut((p + 1) * dim);
        let pivot_row = &head[p * dim..];
        for (offset, row) in tail.chunks_exact_mut(dim).enumerate() {
            let factor = row[p] / diag;
            if factor == 0.0 {
                continue;
            }
            for (dst, &src) in row[p..].iter_mut().zip(&pivot_row[p..]) {
                *dst -= factor * src;
            }
            y[p + 1 + offset] -= factor * y_pivot;
        }
    }

    // Back substitution, last unknown first.
    let mut z = vec![0.0f64; dim];
    for r in (0..dim).rev() {
        let row = &mat[r * dim..(r + 1) * dim];
        let acc = row[r + 1..]
            .iter()
            .zip(&z[r + 1..])
            .fold(y[r], |acc, (&coef, &zc)| acc - coef * zc);
        z[r] = if row[r] == 0.0 { 0.0 } else { acc / row[r] };
    }
    z
}