rsfgsea 0.3.4

High-performance fgsea-compatible preranked Gene Set Enrichment Analysis in Rust
Documentation
use crate::algo::calculate_es_fgsea;
use crate::algo_support::{
    apply_bh_adjustment, build_gene_index, compute_nes, extract_pathway_hits, leading_edge,
    mode_fraction_count, multilevel_error, selected_tail_count, should_refine_multilevel,
};
use crate::core::{EnrichmentResult, Pathway, RankedList, ScoreType};
use crate::multilevel::run_multilevel_gsea_impl;
use anyhow::Result;
use rayon::prelude::*;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

#[allow(clippy::too_many_arguments)]
pub(crate) fn run_gsea_gpu_with_config_impl(
    ranks: &RankedList,
    pathways: &[Pathway],
    n_perm: usize,
    seed: u64,
    min_size: usize,
    max_size: usize,
    eps: f64,
    score_type: ScoreType,
    gsea_param: f64,
    sample_size: usize,
    allow_multilevel: bool,
) -> Result<Vec<EnrichmentResult>> {
    use crate::gpu::GpuEngine;

    let sample_size = sample_size.max(1);
    let eps = eps.clamp(0.0, 1.0);

    let (abs_weights, scaled_scores, _ns_total) = ranks.prepare(gsea_param);
    let simple_stats: Vec<f64> = scaled_scores.iter().map(|&score| score as f64).collect();
    let abs_weights_f32: Vec<f32> = abs_weights.iter().map(|&w| w as f32).collect();
    let gene_to_idx = build_gene_index(ranks);

    let mut by_size: BTreeMap<usize, Vec<(usize, Vec<usize>)>> = BTreeMap::new();
    for (i, pathway) in pathways.iter().enumerate() {
        let hits = extract_pathway_hits(pathway, &gene_to_idx);
        let k = hits.len();
        if k >= min_size && k <= max_size {
            by_size.entry(k).or_default().push((i, hits));
        }
    }

    let runtime = tokio::runtime::Runtime::new()?;
    let engine = runtime.block_on(GpuEngine::new())?;
    let scores_buffer = engine.upload_scores(&abs_weights_f32);
    let gpu_score_type = match score_type {
        ScoreType::Std => 0,
        ScoreType::Pos => 1,
        ScoreType::Neg => 2,
    };

    let mut results = vec![None; pathways.len()];
    let gpu_verbose = std::env::var("RSFGSEA_GPU_VERBOSE")
        .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
        .unwrap_or(false);
    let total_groups = by_size.len();
    let mut group_idx = 0usize;

    let total_null_gen_time = Arc::new(AtomicU64::new(0));
    let total_pure_screening_time = Arc::new(AtomicU64::new(0));
    let total_multilevel_time = Arc::new(AtomicU64::new(0));

    for (k, group) in by_size {
        group_idx += 1;
        if gpu_verbose
            || group_idx == 1
            || group_idx == total_groups
            || group_idx.is_multiple_of(25)
        {
            log::info!(
                "GPU null group {}/{}: {} pathways (k={})",
                group_idx,
                total_groups,
                group.len(),
                k
            );
        }

        let gen_start = std::time::Instant::now();
        let mut null_distribution = engine.generate_null_distribution(
            &scores_buffer,
            k,
            ranks.len(),
            n_perm,
            seed,
            gpu_score_type,
        )?;
        total_null_gen_time.fetch_add(gen_start.elapsed().as_micros() as u64, Ordering::Relaxed);
        null_distribution.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());

        let (_n_le_all, _n_ge_all, n_le_zero, n_ge_zero, le_zero_sum, ge_zero_sum) =
            GpuEngine::calculate_null_stats(&null_distribution, 0.0);
        let le_zero_mean = if n_le_zero > 0 {
            le_zero_sum / n_le_zero as f64
        } else {
            0.0
        };
        let ge_zero_mean = if n_ge_zero > 0 {
            ge_zero_sum / n_ge_zero as f64
        } else {
            0.0
        };

        let group_results: Vec<Option<EnrichmentResult>> = group
            .par_iter()
            .map(|(orig_idx, hits)| {
                let screening_start = std::time::Instant::now();
                let mut sorted_hits = hits.clone();
                sorted_hits.sort_unstable();
                let (obs_es, peak_idx) =
                    calculate_es_fgsea(&simple_stats, &sorted_hits, ranks.len(), score_type);

                let (n_le_es, n_ge_es, _, _, _, _) =
                    GpuEngine::calculate_null_stats(&null_distribution, obs_es);

                let p_le = (n_le_es + 1) as f64 / (n_le_zero + 1) as f64;
                let p_ge = (n_ge_es + 1) as f64 / (n_ge_zero + 1) as f64;
                let p_value_simple = match score_type {
                    ScoreType::Std => p_le.min(p_ge),
                    ScoreType::Pos => p_ge,
                    ScoreType::Neg => p_le,
                };
                total_pure_screening_time.fetch_add(
                    screening_start.elapsed().as_micros() as u64,
                    Ordering::Relaxed,
                );

                let nes = compute_nes(obs_es, score_type, le_zero_mean, ge_zero_mean);
                let leading_edge = leading_edge(&sorted_hits, peak_idx, obs_es, score_type, ranks);
                let n_more_extreme = selected_tail_count(score_type, obs_es, n_le_es, n_ge_es);
                let mode_fraction = mode_fraction_count(score_type, obs_es, n_le_zero, n_ge_zero);
                let simple_log2err = crate::algo_support::simple_log2err(n_more_extreme, n_perm);

                if allow_multilevel
                    && should_refine_multilevel(
                        n_more_extreme,
                        mode_fraction,
                        n_perm,
                        sample_size,
                        p_value_simple,
                    )
                {
                    let ml_start = std::time::Instant::now();
                    let (m_p, is_cp_ge_half, _m_err) = run_multilevel_gsea_impl(
                        ranks.len(),
                        &scaled_scores,
                        k,
                        obs_es,
                        score_type,
                        sample_size,
                        seed + *orig_idx as u64,
                        eps,
                    );
                    total_multilevel_time
                        .fetch_add(ml_start.elapsed().as_micros() as u64, Ordering::Relaxed);

                    let denom_prob = (mode_fraction + 1) as f64 / (n_perm + 1) as f64;
                    let mut p_value_ml = (m_p / denom_prob).min(1.0);
                    let log2err = if p_value_ml < eps {
                        p_value_ml = eps;
                        None
                    } else if is_cp_ge_half {
                        Some(multilevel_error(p_value_ml, sample_size))
                    } else {
                        None
                    };

                    Some(EnrichmentResult {
                        pathway_name: pathways[*orig_idx].name.clone(),
                        size: k,
                        es: obs_es,
                        nes,
                        p_value: p_value_ml,
                        padj: None,
                        log2err,
                        leading_edge,
                    })
                } else {
                    Some(EnrichmentResult {
                        pathway_name: pathways[*orig_idx].name.clone(),
                        size: k,
                        es: obs_es,
                        nes,
                        p_value: p_value_simple,
                        padj: None,
                        log2err: simple_log2err,
                        leading_edge,
                    })
                }
            })
            .collect();

        for (i, res) in group_results.into_iter().enumerate() {
            if let Some(result) = res {
                results[group[i].0] = Some(result);
            }
        }
    }

    log::info!("GPU execution timings:");
    log::info!(
        "null distribution generation: {} ms",
        total_null_gen_time.load(Ordering::Relaxed) / 1000
    );
    log::info!(
        "pure screening pass: {} ms",
        total_pure_screening_time.load(Ordering::Relaxed) / 1000
    );
    log::info!(
        "multilevel pass: {} ms",
        total_multilevel_time.load(Ordering::Relaxed) / 1000
    );

    let mut final_results: Vec<EnrichmentResult> = results.into_iter().flatten().collect();
    apply_bh_adjustment(&mut final_results);
    final_results.sort_by(|a, b| a.pathway_name.cmp(&b.pathway_name));
    Ok(final_results)
}