use std::collections::{HashMap, HashSet};
use gsva::{EnrichmentResult, ExprMatrix, SsgseaParams};
use crate::data::XCellModel;
/// Smallest number of input rows whose gene names must also appear in the
/// model's gene universe before enrichment is attempted; `raw_enrichment`
/// asserts on this threshold.
pub const MIN_GENES: usize = 5_000;
/// Computes raw (un-transformed) xCell enrichment scores, one row per cell type.
///
/// Steps:
/// 1. Keep only rows of `expr` whose name is in `model.genes`. When a gene name
///    occurs on several rows, only the first such row is used.
/// 2. Replace each sample column by its within-sample ranks, computed with
///    `gsva::rank::rank_average`.
/// 3. Run ssGSEA (`normalize: false`) of `model.signatures` on the ranked matrix.
/// 4. Shift every signature's scores so that its minimum across samples is 0.
/// 5. Average the shifted scores of all signatures that share a cell type. A
///    signature's cell type is the part of its name before the first `%` (the
///    whole name if there is no `%`).
///
/// The returned rows follow the order of `model.cell_types`. Cell types with no
/// matching signature are omitted. `scores` is laid out row-major
/// (`cell_type_row * n_samples + sample`), and `samples` are `expr`'s column names.
///
/// # Panics
///
/// Panics if fewer than [`MIN_GENES`] distinct gene names of `expr` are present
/// in `model.genes`.
pub fn raw_enrichment(expr: &ExprMatrix, model: &XCellModel) -> EnrichmentResult {
    // --- 1. Restrict to genes shared with the model universe. ---
    let universe: HashSet<&str> = model.genes.iter().map(String::as_str).collect();
    let mut shared_rows = Vec::new();
    let mut shared_names = Vec::new();
    // `seen` drops later duplicates of a gene name so each gene contributes one row.
    let mut seen: HashSet<&str> = HashSet::new();
    for (i, name) in expr.row_names().iter().enumerate() {
        let s = name.as_str();
        if universe.contains(s) && seen.insert(s) {
            shared_rows.push(i);
            shared_names.push(name.clone());
        }
    }
    assert!(
        shared_names.len() >= MIN_GENES,
        "xCell needs at least {MIN_GENES} genes shared with its universe, found {}",
        shared_names.len()
    );

    // --- 2. Rank genes within each sample. ---
    // The buffer is filled as `gene * nsamp + sample`.
    // NOTE(review): assumes `ExprMatrix::new` expects row-major data — confirm.
    let nsamp = expr.ncol();
    let p = shared_rows.len();
    let mut ranked = vec![0.0f64; p * nsamp];
    for j in 0..nsamp {
        let col: Vec<f64> = shared_rows.iter().map(|&i| expr.get(i, j)).collect();
        let r = gsva::rank::rank_average(&col);
        for (i, &rv) in r.iter().enumerate() {
            ranked[i * nsamp + j] = rv;
        }
    }
    let ranked = ExprMatrix::new(shared_names, expr.col_names().to_vec(), ranked);

    // --- 3. Un-normalized ssGSEA on the ranked matrix. ---
    let params = SsgseaParams {
        normalize: false,
        ..SsgseaParams::default()
    };
    let ss = gsva::ssgsea(&ranked, &model.signatures, &params);
    let nsig = ss.gene_sets.len();

    // --- 4. Shift each signature row so its minimum across samples is zero. ---
    let mut sc = ss.scores;
    for s in 0..nsig {
        let row = &mut sc[s * nsamp..(s + 1) * nsamp];
        // `f64::min` ignores a NaN operand, so NaN scores do not poison the minimum.
        let mn = row.iter().copied().fold(f64::INFINITY, f64::min);
        row.iter_mut().for_each(|v| *v -= mn);
    }

    // --- 5. Average signatures that belong to the same cell type. ---
    let mut groups: HashMap<&str, Vec<usize>> = HashMap::new();
    for (s, name) in ss.gene_sets.iter().enumerate() {
        let ct = name.split('%').next().unwrap_or(name.as_str());
        groups.entry(ct).or_default().push(s);
    }

    let mut out_types = Vec::new();
    let mut scores = Vec::new();
    for ct in &model.cell_types {
        let Some(idxs) = groups.get(ct.as_str()) else {
            continue;
        };
        out_types.push(ct.clone());
        // Non-zero: a group exists only after at least one index was pushed.
        let k = idxs.len() as f64;
        for j in 0..nsamp {
            let sum: f64 = idxs.iter().map(|&s| sc[s * nsamp + j]).sum();
            scores.push(sum / k);
        }
    }

    EnrichmentResult {
        gene_sets: out_types,
        samples: expr.col_names().to_vec(),
        scores,
    }
}