rsfgsea 0.3.4

High-performance fgsea-compatible preranked Gene Set Enrichment Analysis in Rust
Documentation
// Element type of the `results` storage buffer: one entry per gene set, written by `main`.
// Field order and types define the buffer's memory layout, so they must match whatever
// the host uses to read it back.
struct GpuResult {
    // Selected extreme of the running-sum statistic (positive max, negative min, or the
    // larger-magnitude of the two, depending on params.score_type); 0.0 on `main`'s
    // early-return path.
    es: f32,
    // Entry of `subsets_indices` (i.e. a position into `abs_scores`) of the hit at which
    // `es` was reached; stays 0 if no extreme was ever recorded.
    peak_idx: u32,
}

// Per-gene weights indexed by position; `main` reads them as abs_scores[hit_idx].
// Presumably the absolute ranking statistics in rank order (per the name) — confirm on the host.
@group(0) @binding(0) var<storage, read> abs_scores: array<f32>;
// Flattened gene-set members: consecutive groups of `k` positions into `abs_scores`,
// group b occupying [b * k, b * k + k). `main` counts misses as (hit_idx - j), which
// assumes each group is sorted ascending with no duplicates — TODO confirm on the host.
@group(0) @binding(1) var<storage, read> subsets_indices: array<u32>;
// One result per gene set, indexed by global_invocation_id.x in `main`.
@group(0) @binding(2) var<storage, read_write> results: array<GpuResult>;

// Dispatch-wide parameters, shared by every invocation of `main`.
// Every field is carried as f32 and converted inside `main`; the integral counts
// (k, n_total, batch_size) are only exactly representable up to 2^24.
struct Params {
    k: f32,          // gene-set size: number of positions per group in `subsets_indices`
    n_total: f32,    // presumably the total number of ranked genes; n_total - k is used as the miss count
    batch_size: f32, // number of gene sets (groups) handled by this dispatch
    score_type: f32, // 0: Std, 1: Pos, 2: Neg (`main` treats any value other than 1 or 2 as Std)
}
@group(0) @binding(3) var<uniform> params: Params;

// Computes one preranked GSEA enrichment score per gene set; one invocation per set.
//
// Invocation `batch_idx` handles the k positions stored at
// subsets_indices[batch_idx * k .. batch_idx * k + k) and walks them in storage order.
// At each hit j (position hit_idx) the running statistic is sampled on both sides:
//   es_before = (hit weight before this gene) / sum_weights - p_miss
//   es_at     = (hit weight incl. this gene)  / sum_weights - p_miss
// with p_miss = (hit_idx - j) / (n_total - k), the fraction of misses preceding the hit.
// Using hit_idx - j as the miss count assumes each subset is sorted ascending with no
// duplicates — TODO confirm on the host side.
//
// results[batch_idx] receives, depending on u32(params.score_type):
//   1 (Pos)       : the maximum (>= 0) and the hit position where it is first reached,
//   2 (Neg)       : the minimum (<= 0) and the hit position where it is first reached,
//   anything else : (Std) whichever extreme has the larger magnitude; an exact tie
//                   resolves to the maximum.
// Degenerate subsets (no misses at all because k >= n_total, or zero total hit weight)
// yield GpuResult(0.0, 0).
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let batch_idx = global_id.x;
    // Compare counts in the integer domain: f32(batch_idx) stops being exact above 2^24,
    // while u32(params.batch_size) is exact for the integral value the host stores.
    if (batch_idx >= u32(params.batch_size)) {
        return;
    }

    let k = u32(params.k);
    let n_miss = params.n_total - params.k;

    // With no misses the p_miss denominator below is zero and would produce 0/0.
    // WGSL lets implementations assume NaN/Inf never occur, so the comparisons (and the
    // stored result) would be indeterminate; emit a deterministic zero result instead.
    if (n_miss <= 0.0) {
        results[batch_idx] = GpuResult(0.0, 0u);
        return;
    }

    let start_offset = batch_idx * k;

    // Total weight of the hits; normalises the hit-side running sum to [0, 1].
    var sum_weights: f32 = 0.0;
    for (var i: u32 = 0u; i < k; i++) {
        sum_weights += abs_scores[subsets_indices[start_offset + i]];
    }

    // No hit weight: the hit fraction is undefined, so report a zero score.
    if (sum_weights == 0.0) {
        results[batch_idx] = GpuResult(0.0, 0u);
        return;
    }

    // Extremes start at 0.0, so the maximum is never negative and the minimum never positive.
    var curr_max: f32 = 0.0;
    var curr_min: f32 = 0.0;
    var max_idx: u32 = 0u;
    var min_idx: u32 = 0u;

    var curr_sum_weight: f32 = 0.0;
    for (var j: u32 = 0u; j < k; j++) {
        let hit_idx = subsets_indices[start_offset + j];
        let prev_p_hit = curr_sum_weight / sum_weights;
        curr_sum_weight += abs_scores[hit_idx];
        let p_hit = curr_sum_weight / sum_weights;
        // hit_idx - j genes ahead of this hit are not members of the set.
        let p_miss = (f32(hit_idx) - f32(j)) / n_miss;

        let es_before = prev_p_hit - p_miss;
        let es_at = p_hit - p_miss;

        // Both samples are tested against both extremes; strict comparisons keep the
        // first hit at which an extreme value is reached.
        if (es_before > curr_max) {
            curr_max = es_before;
            max_idx = hit_idx;
        }
        if (es_before < curr_min) {
            curr_min = es_before;
            min_idx = hit_idx;
        }
        if (es_at > curr_max) {
            curr_max = es_at;
            max_idx = hit_idx;
        }
        if (es_at < curr_min) {
            curr_min = es_at;
            min_idx = hit_idx;
        }
    }

    switch (u32(params.score_type)) {
        case 1u: { // Pos
            results[batch_idx] = GpuResult(curr_max, max_idx);
        }
        case 2u: { // Neg
            results[batch_idx] = GpuResult(curr_min, min_idx);
        }
        default: { // Std: larger magnitude wins, ties go to the positive extreme.
            if (abs(curr_max) >= abs(curr_min)) {
                results[batch_idx] = GpuResult(curr_max, max_idx);
            } else {
                results[batch_idx] = GpuResult(curr_min, min_idx);
            }
        }
    }
}