use ahash::{AHashMap, AHashSet};
use statrs::distribution::{Continuous, ContinuousCDF, Normal};
use super::{
gsea::{GSEAConfig, GSEAResult, RankListItem},
ora::{get_ora, ORAConfig, ORAResult},
};
use crate::{
methods::gsea::gsea,
readers::utils::Item,
stat::{adjust, AdjustmentMethod},
};
/// Strategy used to analyse several lists together.
///
/// `Max` and `Mean` merge the lists into a single list (see `combine_lists`) after applying the
/// carried normalization. `Meta` keeps the lists separate and merges the per-set p-values.
pub enum MultiListMethod {
/// Merge lists with `max_combine`, normalizing each list first.
Max(NormalizationMethod),
/// Merge lists with `mean_combine` (per-analyte average), normalizing each list first.
Mean(NormalizationMethod),
/// Analyse every list on its own and combine the p-values with the given method.
Meta(MetaAnalysisMethod),
}
/// How per-list p-values for the same set are merged into one p-value.
pub enum MetaAnalysisMethod {
/// Combine with `stouffer_with_normal` (unweighted, standard normal).
Stouffer,
/// Combine with `fisher`.
Fisher,
}
/// Tag naming an analysis family.
///
/// NOTE(review): not referenced anywhere in the visible part of this file; the variant
/// meanings below are inferred from the sibling entry points and should be confirmed.
pub enum AnalysisType {
/// Presumably the rank-list analysis served by `multilist_gsea`.
GSEA,
/// Presumably the interest-list analysis served by `multilist_ora`.
ORA,
}
/// Inputs for one GSEA run inside a multi-list analysis (see `multilist_gsea`).
pub struct GSEAJob {
/// Set definitions, handed to `gsea` as its second argument.
pub gmt: Vec<Item>,
/// Analytes with their scores, handed to `gsea` as its first argument.
pub rank_list: Vec<RankListItem>,
/// Run settings. On the `Max`/`Mean` path of `multilist_gsea` only the first job's config
/// is used.
pub config: GSEAConfig,
}
/// Inputs for one ORA run inside a multi-list analysis (see `multilist_ora`).
pub struct ORAJob {
/// Set definitions, handed to `get_ora` as its third argument.
pub gmt: Vec<Item>,
/// Analytes of interest, passed by reference as the first argument of `get_ora`.
pub interest_list: AHashSet<String>,
/// Reference analytes, passed by reference as the second argument of `get_ora`.
pub reference_list: AHashSet<String>,
/// Run settings, handed to `get_ora` as its last argument.
pub config: ORAConfig,
}
/// Per-list rescaling applied by `normalize` before lists are merged.
#[derive(Copy, Clone)]
pub enum NormalizationMethod {
/// Replace each score with `(position - n / 2) / n`, where `position` is the item's index
/// after sorting the list by ascending score and `n` is the list length.
MedianRank,
/// Rescale scores using the gap between the list minimum and the score at index `n / 2` of
/// the descending-sorted list.
MedianValue,
/// Rescale scores using a mean-based scale computed from the scores shifted by the list
/// minimum.
MeanValue,
/// Leave the scores untouched.
None,
}
pub fn multilist_gsea(
jobs: Vec<GSEAJob>,
method: MultiListMethod,
fdr_method: AdjustmentMethod,
) -> Vec<Vec<GSEAResult>> {
if let MultiListMethod::Meta(meta_method) = method {
let mut phash: AHashMap<String, Vec<f64>> = AHashMap::default();
let mut results: Vec<Vec<GSEAResult>> = Vec::new();
for job in jobs {
let res = gsea(job.rank_list, job.gmt, job.config, None);
for row in res.iter() {
let set = row.set.clone();
phash.entry(set).or_default().push(row.p);
}
results.push(res);
}
let mut final_result: Vec<GSEAResult> = Vec::new();
let mut meta_p = Vec::new();
match meta_method {
MetaAnalysisMethod::Stouffer => {
let normal = Normal::new(0.0, 1.0).unwrap();
for set in phash.keys() {
meta_p.push(stouffer_with_normal(&phash[set], &normal))
}
}
MetaAnalysisMethod::Fisher => {
for set in phash.keys() {
meta_p.push(fisher(&phash[set]));
}
}
}
let meta_fdr = adjust(&meta_p, fdr_method);
for (i, set) in phash.keys().enumerate() {
final_result.push(GSEAResult {
set: set.clone(),
p: meta_p[i],
fdr: meta_fdr[i],
nes: 0.0,
es: 0.0,
running_sum: Vec::new(),
leading_edge: 0,
})
}
results.insert(0, final_result);
results
} else {
let lists = jobs.iter().map(|x| x.rank_list.clone()).collect();
let combined_list = combine_lists(lists, method);
let gmts = jobs.iter().map(|x| x.gmt.clone()).collect();
let combined_gmt = combine_gmts(&gmts);
vec![gsea(
combined_list,
combined_gmt,
jobs.first().unwrap().config.clone(),
None,
)]
}
}
/// Runs ORA for every job and merges the per-set p-values by meta-analysis.
///
/// Each job is analysed with `get_ora`. The p-values reported for each set are merged with the
/// chosen meta-analysis method and adjusted with `fdr_method`. The returned vector holds the
/// meta-analysis table at index 0, followed by the per-job tables in job order. Rows of the
/// meta table only carry `set`, `p` and `fdr`; `overlap`, `expected` and `enrichment_ratio`
/// are zero placeholders.
///
/// # Panics
///
/// Panics when `method` is not `MultiListMethod::Meta`.
pub fn multilist_ora(
    jobs: Vec<ORAJob>,
    method: MultiListMethod,
    fdr_method: AdjustmentMethod,
) -> Vec<Vec<ORAResult>> {
    // Reject list-merging strategies up front; only meta-analysis is supported here.
    let meta_method = match method {
        MultiListMethod::Meta(meta_method) => meta_method,
        MultiListMethod::Max(_) | MultiListMethod::Mean(_) => {
            panic!("Multi-Omics ORA can only be run with meta-analysis")
        }
    };

    let mut p_by_set: AHashMap<String, Vec<f64>> = AHashMap::default();
    let mut tables: Vec<Vec<ORAResult>> = Vec::new();
    for job in jobs {
        let table = get_ora(&job.interest_list, &job.reference_list, job.gmt, job.config);
        for row in table.iter() {
            p_by_set.entry(row.set.clone()).or_default().push(row.p);
        }
        tables.push(table);
    }

    // The map is not modified from here on, so `values()` and `keys()` visit the entries in
    // the same order and position `i` refers to the same set in both.
    let meta_p: Vec<f64> = match meta_method {
        MetaAnalysisMethod::Stouffer => {
            let normal = Normal::new(0.0, 1.0).unwrap();
            p_by_set
                .values()
                .map(|p_values| stouffer_with_normal(p_values, &normal))
                .collect()
        }
        MetaAnalysisMethod::Fisher => p_by_set.values().map(|p_values| fisher(p_values)).collect(),
    };
    let meta_fdr = adjust(&meta_p, fdr_method);

    let meta_table: Vec<ORAResult> = p_by_set
        .keys()
        .enumerate()
        .map(|(i, set)| ORAResult {
            set: set.clone(),
            p: meta_p[i],
            fdr: meta_fdr[i],
            overlap: 0,
            expected: 0.0,
            enrichment_ratio: 0.0,
        })
        .collect();

    tables.insert(0, meta_table);
    tables
}
/// Merges several rank lists into a single list.
///
/// `Max(norm)` delegates to `max_combine` and `Mean(norm)` to `mean_combine`; in both cases
/// `norm` is the normalization applied to every list before merging.
///
/// # Panics
///
/// Panics for `MultiListMethod::Meta`, since meta-analysis works on separate lists.
pub fn combine_lists(
    lists: Vec<Vec<RankListItem>>,
    combination_method: MultiListMethod,
) -> Vec<RankListItem> {
    match combination_method {
        MultiListMethod::Meta(_) => panic!("Lists can not be combined for meta-analysis"),
        MultiListMethod::Mean(norm) => mean_combine(lists, norm),
        MultiListMethod::Max(norm) => max_combine(lists, norm),
    }
}
/// Merges rank lists by keeping, for every analyte, the score with the largest magnitude.
///
/// Each list is normalized with `normalization_method` first. The kept score retains its
/// sign. When two scores for the same analyte have equal magnitude, the one seen first wins.
/// The order of the returned items follows hash-map iteration and is unspecified.
///
/// Magnitudes are compared on both sides (`|candidate| > |stored|`). Comparing against the
/// signed stored value would let any candidate overwrite a stored negative score, even one
/// with a much larger magnitude.
fn max_combine(
    lists: Vec<Vec<RankListItem>>,
    normalization_method: NormalizationMethod,
) -> Vec<RankListItem> {
    let mut strongest: AHashMap<String, f64> = AHashMap::default();
    for mut list in lists {
        for RankListItem { analyte, rank } in normalize(&mut list, normalization_method) {
            strongest
                .entry(analyte)
                .and_modify(|best| {
                    if rank.abs() > best.abs() {
                        *best = rank;
                    }
                })
                .or_insert(rank);
        }
    }
    strongest
        .into_iter()
        .map(|(analyte, rank)| RankListItem { analyte, rank })
        .collect()
}
/// Merges rank lists by averaging, for every analyte, the scores found across all lists.
///
/// Each list is normalized with `normalization_method` first. An analyte missing from a list
/// simply contributes no value for that list; the average is taken over the scores that
/// exist. The order of the returned items follows hash-map iteration and is unspecified.
fn mean_combine(
    lists: Vec<Vec<RankListItem>>,
    normalization_method: NormalizationMethod,
) -> Vec<RankListItem> {
    let mut scores: AHashMap<String, Vec<f64>> = AHashMap::default();
    for mut list in lists {
        for RankListItem { analyte, rank } in normalize(&mut list, normalization_method) {
            scores.entry(analyte).or_default().push(rank);
        }
    }
    scores
        .into_iter()
        .map(|(analyte, ranks)| {
            let mean = ranks.iter().sum::<f64>() / (ranks.len() as f64);
            RankListItem {
                analyte,
                rank: mean,
            }
        })
        .collect()
}
/// Returns a rescaled copy of `list` according to `method`.
///
/// * `None`: plain copy, original order.
/// * `MedianRank`: sorts `list` by ascending score and replaces each score with
///   `(position - n / 2) / n`.
/// * `MedianValue`: sorts `list` by descending score and rescales every score as
///   `(score - min) / scale + min / scale`, with `scale` being the score at index `n / 2`
///   minus `min`.
/// * `MeanValue`: same rescaling, with `scale` being the mean of `score - min` minus
///   `min / n`.
///
/// For every method except `None`, `list` itself is sorted in place as a side effect and the
/// returned items are in that sorted order. A `scale` of zero is not special-cased and yields
/// non-finite scores.
///
/// An empty list yields an empty result for every method. Without the guard, `MedianValue`
/// and `MeanValue` would panic while looking up the minimum.
///
/// # Panics
///
/// Panics if sorting has to compare a NaN score.
fn normalize(list: &mut Vec<RankListItem>, method: NormalizationMethod) -> Vec<RankListItem> {
    /// Total-order comparison of two items by score; NaN is rejected.
    fn cmp_rank(a: &RankListItem, b: &RankListItem) -> std::cmp::Ordering {
        a.rank
            .partial_cmp(&b.rank)
            .expect("Invalid float comparison during normalization")
    }

    /// Applies `(score - min) / scale + min / scale` to every item.
    fn rescale(items: &[RankListItem], min: f64, scale: f64) -> Vec<RankListItem> {
        let shift = min / scale;
        items
            .iter()
            .map(|item| RankListItem {
                analyte: item.analyte.clone(),
                rank: (item.rank - min) / scale + shift,
            })
            .collect()
    }

    // Nothing to scale, and no minimum/median exists for the value-based methods.
    if list.is_empty() {
        return Vec::new();
    }

    match method {
        NormalizationMethod::None => list.clone(),
        NormalizationMethod::MedianRank => {
            list.sort_by(cmp_rank);
            let len = list.len() as f64;
            let median = len / 2.0;
            list.iter()
                .enumerate()
                .map(|(i, item)| RankListItem {
                    analyte: item.analyte.clone(),
                    rank: (i as f64 - median) / len,
                })
                .collect()
        }
        NormalizationMethod::MedianValue => {
            // Descending order: the minimum is the last element.
            list.sort_by(|a, b| cmp_rank(b, a));
            let min = list.last().expect("list is non-empty (checked above)").rank;
            let median = list[list.len() / 2].rank - min;
            rescale(list, min, median)
        }
        NormalizationMethod::MeanValue => {
            // Descending order: the minimum is the last element.
            list.sort_by(|a, b| cmp_rank(b, a));
            let min = list.last().expect("list is non-empty (checked above)").rank;
            // NOTE(review): the trailing `- min / n` term is kept exactly as it was; confirm
            // that this is the intended definition of the scale.
            let mean: f64 = list.iter().map(|x| x.rank - min).sum::<f64>() / (list.len() as f64)
                - min / (list.len() as f64);
            rescale(list, min, mean)
        }
    }
}
/// Merges several GMTs into one, joining sets that share an `id`.
///
/// For an `id` seen in more than one GMT, the `parts` are concatenated in the order the GMTs
/// are given (duplicates are not removed) and the `url` of the first occurrence is kept. The
/// order of the returned sets follows hash-map iteration and is unspecified.
pub fn combine_gmts(gmts: &Vec<Vec<Item>>) -> Vec<Item> {
    let mut parts_by_id: AHashMap<String, Vec<String>> = AHashMap::default();
    let mut url_by_id: AHashMap<String, String> = AHashMap::default();

    for item in gmts.iter().flatten() {
        match parts_by_id.get_mut(&item.id) {
            Some(existing) => existing.extend(item.parts.iter().cloned()),
            None => {
                // First time this id is seen: remember its parts and its url.
                parts_by_id.insert(item.id.clone(), item.parts.clone());
                url_by_id.insert(item.id.clone(), item.url.clone());
            }
        }
    }

    parts_by_id
        .into_iter()
        .map(|(id, parts)| {
            let url = url_by_id[&id].clone();
            Item { id, parts, url }
        })
        .collect()
}
/// Combines p-values with the unweighted Stouffer method using a standard normal
/// distribution.
///
/// Thin wrapper around `stouffer_with_normal` that builds the `Normal(0, 1)` itself.
pub fn stouffer(vals: &Vec<f64>) -> f64 {
    let standard_normal = Normal::new(0.0, 1.0).unwrap();
    stouffer_with_normal(vals, &standard_normal)
}
/// Combines p-values with the unweighted Stouffer method using the supplied distribution.
///
/// Every p-value is mapped through `normal.inverse_cdf`, the results are summed and divided
/// by `sqrt(k)` (with `k = vals.len()`), and the quotient is mapped back through
/// `normal.cdf`.
fn stouffer_with_normal(vals: &Vec<f64>, normal: &Normal) -> f64 {
    let z_total = vals.iter().map(|&p| normal.inverse_cdf(p)).sum::<f64>();
    let root_k = (vals.len() as f64).sqrt();
    normal.cdf(z_total / root_k)
}
/// Combines p-values with Fisher's method and returns the combined p-value.
///
/// The statistic `-2 * sum(ln p_i)` is referred to a chi-squared distribution with `2k`
/// degrees of freedom (`k = vals.len()`), and the upper-tail probability is returned.
///
/// The degrees of freedom are `2k` (not a power of two), and the result is a tail
/// probability rather than the density at the statistic. Consequently a single p-value is
/// returned unchanged.
///
/// An empty slice yields `1.0`: with no p-values the statistic is zero and carries no
/// evidence.
pub fn fisher(vals: &Vec<f64>) -> f64 {
    if vals.is_empty() {
        return 1.0;
    }
    let statistic = -2.0 * vals.iter().map(|p| p.ln()).sum::<f64>();
    let degrees_of_freedom = 2.0 * vals.len() as f64;
    let dist = statrs::distribution::ChiSquared::new(degrees_of_freedom)
        .expect("degrees of freedom are positive for non-empty input");
    // Upper tail P(X >= statistic). NOTE(review): if the statrs version in use offers a
    // dedicated survival function, it may give better precision for very small results.
    1.0 - dist.cdf(statistic)
}
/// Combines p-values with the weighted Stouffer method using a standard normal distribution.
///
/// Computes `cdf(sum(w_i * inverse_cdf(p_i)) / sqrt(sum(w_j^2)))`. The numerator runs over
/// the entries of `vals`, while the denominator runs over every entry of `weights`.
///
/// # Panics
///
/// Panics if `weights` has fewer entries than `vals`. Surplus weights are not rejected and
/// still contribute to the denominator.
pub fn stouffer_weighted(vals: Vec<f64>, weights: Vec<f64>) -> f64 {
    let standard_normal = Normal::new(0.0, 1.0).unwrap();
    let weighted_z = (0..vals.len())
        .map(|idx| weights[idx] * standard_normal.inverse_cdf(vals[idx]))
        .sum::<f64>();
    let weight_norm = weights.iter().map(|w| w * w).sum::<f64>().sqrt();
    standard_normal.cdf(weighted_z / weight_norm)
}