Skip to main content

limma/
geneset.rs

//! Gene-set tests ported from limma: the competitive rank-based and `camera`
//! tests, together with the self-contained `fry` and `roast` tests.
//!
4//! Ported here:
5//! * [`rank_sum_test_with_correlation`] — Wilcoxon–Mann–Whitney rank-sum test
6//!   adjusted for inter-gene correlation (`rankSumTestWithCorrelation`).
7//! * [`gene_set_test`] / [`wilcox_gst`] — competitive set test from gene ranks
8//!   (`geneSetTest(..., ranks.only=TRUE)` / `wilcoxGST`). The permutation path
9//!   (`ranks.only=FALSE`) is RNG-dependent and not ported.
10//! * [`camera_pr`] — pre-ranked competitive test with inter-gene-correlation
11//!   correction (`cameraPR`, `directional=TRUE`, fixed scalar correlation).
12//! * [`camera`] — competitive test from an expression matrix and design,
13//!   computing the moderated-t statistics internally (`camera`/`camera.default`,
14//!   fixed `inter.gene.cor`).
15//! * [`inter_gene_correlation`] — variance-inflation factor and mean inter-gene
16//!   correlation of the residuals (`interGeneCorrelation`).
17//! * [`fry`] — fast approximation to `roast` (the `nrot = Inf`, `prior.df = Inf`
18//!   limit), giving deterministic directional and mixed p-values (`fry`).
19//! * [`roast`] — self-contained rotation gene-set test (`roast`), Monte-Carlo
20//!   over random rotations of the residual space. Reproduces limma's p-values
21//!   bit-for-bit via the [`RRng`] port of R's Mersenne-Twister.
22//! * [`ids2indices`] — map gene-set identifier lists to row indices.
23
24use std::collections::HashSet;
25
26use anyhow::{bail, Result};
27use ndarray::{Array1, Array2};
28use statrs::distribution::{Beta, ContinuousCDF, Normal, StudentsT};
29
30use crate::ebayes::{fit_fdist, squeeze_var, squeeze_var_post, tmixture_vector};
31use crate::linalg::{eigh, matrix_rank, qr_econ, qr_full_q};
32use crate::proptruenull::{prop_true_null, PropTrueNullMethod};
33use crate::rng::RRng;
34use crate::special::gauss_legendre_01;
35use crate::zscore::{zscore_t, ZscoreTMethod};
36
/// Average ranks (R's `rank(x, ties.method = "average")`), 1-based, ascending.
///
/// Exactly-equal values share the mean of the ranks they jointly occupy.
/// `NaN` entries do not abort the sort: they are ordered after every number,
/// keep their input order among themselves and, because `NaN != NaN`, never
/// tie, so each receives its own trailing rank (the convention of R's
/// `rank(..., na.last = TRUE)`).
///
/// Returns a vector the same length as `x`; an empty input yields an empty
/// vector.
fn rank_average(x: &[f64]) -> Vec<f64> {
    let n = x.len();
    let mut order: Vec<usize> = (0..n).collect();
    // Stable sort. Incomparable pairs (at least one NaN) fall back to
    // "NaN after numbers, NaN equal to NaN", which keeps the order total.
    order.sort_by(|&a, &b| {
        x[a].partial_cmp(&x[b])
            .unwrap_or_else(|| x[a].is_nan().cmp(&x[b].is_nan()))
    });

    let mut ranks = vec![0.0; n];
    let mut start = 0;
    while start < n {
        // Extend `end` over the run of values equal to the one at `start`.
        let mut end = start;
        while end + 1 < n && x[order[end + 1]] == x[order[start]] {
            end += 1;
        }
        // Sorted positions start..=end occupy ranks (start+1)..=(end+1);
        // every member of the run gets their mean.
        let shared = ((start + 1 + end + 1) as f64) / 2.0;
        for &k in &order[start..=end] {
            ranks[k] = shared;
        }
        start = end + 1;
    }
    ranks
}
58
/// Average ranks together with the sizes of the exactly-equal-value groups,
/// both derived from a single sort of `x`.
///
/// The first element of the result holds the 1-based average ranks (same
/// convention as [`rank_average`]); the second holds one entry per run of equal
/// values, in ascending value order, giving the length of that run. Both
/// depend only on `x`, so callers testing many sets against the same statistic
/// vector (e.g. [`camera_pr`]) can compute them once and reuse them.
///
/// `NaN` entries do not abort the sort: they are ordered after every number in
/// input order and never tie (`NaN != NaN`), so each forms its own group of
/// size 1 with its own trailing rank.
fn rank_average_and_ties(x: &[f64]) -> (Vec<f64>, Vec<usize>) {
    let n = x.len();
    let mut order: Vec<usize> = (0..n).collect();
    // Stable sort. Incomparable pairs (at least one NaN) fall back to
    // "NaN after numbers, NaN equal to NaN", which keeps the order total.
    order.sort_by(|&a, &b| {
        x[a].partial_cmp(&x[b])
            .unwrap_or_else(|| x[a].is_nan().cmp(&x[b].is_nan()))
    });

    let mut ranks = vec![0.0; n];
    let mut group_sizes = Vec::new();
    let mut start = 0;
    while start < n {
        // Extend `end` over the run of values equal to the one at `start`.
        let mut end = start;
        while end + 1 < n && x[order[end + 1]] == x[order[start]] {
            end += 1;
        }
        // Sorted positions start..=end occupy ranks (start+1)..=(end+1);
        // every member of the run gets their mean.
        let shared = ((start + 1 + end + 1) as f64) / 2.0;
        for &k in &order[start..=end] {
            ranks[k] = shared;
        }
        group_sizes.push(end - start + 1);
        start = end + 1;
    }
    (ranks, group_sizes)
}
87
88/// Lower-tail Student's t CDF, treating an infinite `df` as standard normal.
89fn pt_lower(x: f64, df: f64) -> f64 {
90    if df.is_infinite() {
91        Normal::new(0.0, 1.0).unwrap().cdf(x)
92    } else {
93        StudentsT::new(0.0, 1.0, df).unwrap().cdf(x)
94    }
95}
96
97/// Upper-tail Student's t CDF, via central symmetry (`P(T>x)=P(T<-x)`).
98fn pt_upper(x: f64, df: f64) -> f64 {
99    pt_lower(-x, df)
100}
101
102/// Rank-sum test (two-sample Wilcoxon–Mann–Whitney) allowing for correlation
103/// between members of the test set.
104///
105/// `index` are 0-based positions of the test set within `statistics`. Returns
106/// `(less, greater)` one-sided p-values, matching the `c(less, greater)` output
107/// of limma's `rankSumTestWithCorrelation`.
108pub fn rank_sum_test_with_correlation(
109    index: &[usize],
110    statistics: &[f64],
111    correlation: f64,
112    df: f64,
113) -> (f64, f64) {
114    let (r, tie_sizes) = rank_average_and_ties(statistics);
115    rank_sum_core(index, statistics.len(), &r, &tie_sizes, correlation, df)
116}
117
118/// Core of [`rank_sum_test_with_correlation`] given the average ranks `r` and
119/// tie-group sizes `tie_sizes` already computed over the full statistic vector
120/// of length `n`. Factored out so competitive callers that test many sets
121/// against the *same* statistics (e.g. [`camera_pr`]) rank the universe once and
122/// reuse it, rather than re-sorting per set.
123fn rank_sum_core(
124    index: &[usize],
125    n: usize,
126    r: &[f64],
127    tie_sizes: &[usize],
128    correlation: f64,
129    df: f64,
130) -> (f64, f64) {
131    let n1 = index.len();
132    let n2 = n - n1;
133    let sum_r1: f64 = index.iter().map(|&i| r[i]).sum();
134
135    let n1f = n1 as f64;
136    let n2f = n2 as f64;
137    let nf = n as f64;
138
139    let u = n1f * n2f + n1f * (n1f + 1.0) / 2.0 - sum_r1;
140    let mu = n1f * n2f / 2.0;
141
142    let mut sigma2 = if correlation == 0.0 || n1 == 1 {
143        n1f * n2f * (nf + 1.0) / 12.0
144    } else {
145        // asin(1) = pi/2.
146        let s = std::f64::consts::FRAC_PI_2 * n1f * n2f
147            + 0.5_f64.asin() * n1f * n2f * (n2f - 1.0)
148            + (correlation / 2.0).asin() * n1f * (n1f - 1.0) * n2f * (n2f - 1.0)
149            + ((correlation + 1.0) / 2.0).asin() * n1f * (n1f - 1.0) * n2f;
150        s / 2.0 / std::f64::consts::PI
151    };
152
153    if tie_sizes.iter().any(|&c| c > 1) {
154        let adjustment: f64 = tie_sizes
155            .iter()
156            .map(|&c| {
157                let cf = c as f64;
158                cf * (cf + 1.0) * (cf - 1.0)
159            })
160            .sum::<f64>()
161            / (nf * (nf + 1.0) * (nf - 1.0));
162        sigma2 *= 1.0 - adjustment;
163    }
164
165    let sd = sigma2.sqrt();
166    let zlowertail = (u + 0.5 - mu) / sd;
167    let zuppertail = (u - 0.5 - mu) / sd;
168
169    // Tails reversed on output: R's ranks are the reverse of Mann–Whitney's.
170    let less = pt_upper(zuppertail, df);
171    let greater = pt_lower(zlowertail, df);
172    (less, greater)
173}
174
/// Which alternative hypothesis [`gene_set_test`] evaluates.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Alternative {
    /// The set's genes tend to be up-regulated.
    Up,
    /// The set's genes tend to be down-regulated.
    Down,
    /// Two-sided test: up-regulated or down-regulated (limma's
    /// `"either"`/`"two.sided"`).
    Either,
    /// The set's genes change, regardless of direction (limma's `"mixed"`,
    /// its default).
    Mixed,
}
187
188/// Competitive gene-set test from gene ranks (`geneSetTest(..., ranks.only=TRUE)`).
189///
190/// `index` are 0-based positions of the set within `statistics`. Returns the
191/// p-value for the requested alternative.
192pub fn gene_set_test(index: &[usize], statistics: &[f64], alternative: Alternative) -> f64 {
193    let mut stats = statistics.to_vec();
194    let mut alt = alternative;
195    match alt {
196        Alternative::Mixed => {
197            for s in stats.iter_mut() {
198                *s = s.abs();
199            }
200        }
201        Alternative::Down => {
202            for s in stats.iter_mut() {
203                *s = -*s;
204            }
205            alt = Alternative::Up;
206        }
207        _ => {}
208    }
209    let (less, greater) = rank_sum_test_with_correlation(index, &stats, 0.0, f64::INFINITY);
210    match alt {
211        Alternative::Up => greater,
212        Alternative::Either => 2.0 * less.min(greater),
213        Alternative::Mixed => greater,
214        // Down is rewritten to Up above; this arm is unreachable.
215        Alternative::Down => less,
216    }
217}
218
219/// Mean-rank gene-set test (`wilcoxGST`): [`gene_set_test`] with the default
220/// `"mixed"` alternative.
221pub fn wilcox_gst(index: &[usize], statistics: &[f64]) -> f64 {
222    gene_set_test(index, statistics, Alternative::Mixed)
223}
224
/// Net direction of enrichment attached to a gene-set result (used by
/// [`camera_pr`] and [`fry`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Direction {
    /// The set's statistics sit above those of the remaining genes.
    Up,
    /// The set's statistics sit below those of the remaining genes.
    Down,
}
233
234/// One gene set's result from [`camera_pr`].
235#[derive(Clone, Debug)]
236pub struct CameraResult {
237    /// 0-based position of this set in the input `index` list.
238    pub set: usize,
239    /// Number of genes in the set.
240    pub n_genes: usize,
241    /// Net direction of enrichment.
242    pub direction: Direction,
243    /// Two-sided competitive p-value.
244    pub p_value: f64,
245    /// Benjamini–Hochberg adjusted p-value across all sets.
246    pub fdr: f64,
247}
248
/// Arithmetic mean of `x`. An empty slice divides by zero and yields `NaN`.
fn mean(x: &[f64]) -> f64 {
    let total: f64 = x.iter().sum();
    let count = x.len() as f64;
    total / count
}
252
253/// Sample variance (divisor `n-1`), matching R's `var`.
254fn sample_var(x: &[f64]) -> f64 {
255    let n = x.len();
256    let m = mean(x);
257    x.iter().map(|&v| (v - m) * (v - m)).sum::<f64>() / (n as f64 - 1.0)
258}
259
260/// Pre-ranked competitive gene-set test (`cameraPR`) for the default
261/// `directional=TRUE`, fixed scalar `inter.gene.cor` case.
262///
263/// `statistic` is a gene-level statistic (e.g. a moderated t); `index` lists
264/// the sets as 0-based gene positions. With `use_ranks=false` the parametric
265/// two-sample t form is used (df = `G-2`); with `use_ranks=true` the
266/// correlation-adjusted rank-sum test is used. Results carry the
267/// Benjamini–Hochberg FDR and, when `sort` is set and there is more than one
268/// set, are ordered by ascending p-value (stable on ties, as in R).
269pub fn camera_pr(
270    statistic: &[f64],
271    index: &[Vec<usize>],
272    inter_gene_cor: f64,
273    use_ranks: bool,
274    sort: bool,
275) -> Vec<CameraResult> {
276    let g = statistic.len();
277    let gf = g as f64;
278    let mean_stat = mean(statistic);
279    let var_stat = sample_var(statistic);
280    let df = if use_ranks { f64::INFINITY } else { gf - 2.0 };
281
282    // The rank path's average ranks and tie-group sizes depend only on the
283    // shared `statistic`, so compute them once rather than re-sorting per set.
284    let ranks = use_ranks.then(|| rank_average_and_ties(statistic));
285
286    let mut rows: Vec<CameraResult> = Vec::with_capacity(index.len());
287    for (si, iset) in index.iter().enumerate() {
288        let m = iset.len();
289        let (down, up) = if let Some((r, tie_sizes)) = ranks.as_ref() {
290            rank_sum_core(iset, g, r, tie_sizes, inter_gene_cor, df)
291        } else {
292            let mf = m as f64;
293            let m2 = gf - mf;
294            let vif = 1.0 + (mf - 1.0) * inter_gene_cor;
295            let mean_in_set = iset.iter().map(|&i| statistic[i]).sum::<f64>() / mf;
296            let delta = gf / m2 * (mean_in_set - mean_stat);
297            let var_pooled = ((gf - 1.0) * var_stat - delta * delta * mf * m2 / gf) / (gf - 2.0);
298            let t = delta / (var_pooled * (vif / mf + 1.0 / m2)).sqrt();
299            (pt_lower(t, df), pt_upper(t, df))
300        };
301        let p_value = 2.0 * down.min(up);
302        let direction = if down < up {
303            Direction::Down
304        } else {
305            Direction::Up
306        };
307        rows.push(CameraResult {
308            set: si,
309            n_genes: m,
310            direction,
311            p_value,
312            fdr: f64::NAN,
313        });
314    }
315
316    // BH adjustment across sets (only when there is more than one set).
317    if rows.len() > 1 {
318        let pvals: Vec<f64> = rows.iter().map(|r| r.p_value).collect();
319        let fdr = crate::toptable::p_adjust_bh(&pvals);
320        for (r, f) in rows.iter_mut().zip(fdr) {
321            r.fdr = f;
322        }
323    } else if let Some(r) = rows.first_mut() {
324        r.fdr = r.p_value;
325    }
326
327    if sort && rows.len() > 1 {
328        rows.sort_by(|a, b| a.p_value.partial_cmp(&b.p_value).unwrap());
329    }
330    rows
331}
332
333/// Variance-inflation factor and mean inter-gene correlation of the residuals
334/// (`interGeneCorrelation`).
335///
336/// `y` is `G x n` (genes by samples), `design` is `n x p`. The residual effects
337/// are the trailing `n - rank(design)` rows of `Q' y'`; each gene is scaled to
338/// unit mean square, then averaged across genes per residual coordinate.
339/// Returns `(vif, correlation)` with `correlation = (vif - 1) / (G - 1)`.
340pub fn inter_gene_correlation(y: &Array2<f64>, design: &Array2<f64>) -> (f64, f64) {
341    let g = y.nrows();
342    let n = y.ncols();
343    let rank = matrix_rank(design);
344    let nres = n - rank;
345    let qfull = qr_full_q(design);
346    let effects = qfull.t().dot(&y.t()); // n x G = Q' t(y)
347
348    let mut sigma = vec![0.0; g];
349    for (gi, s) in sigma.iter_mut().enumerate() {
350        let mut acc = 0.0;
351        for k in rank..n {
352            let e = effects[[k, gi]];
353            acc += e * e;
354        }
355        *s = (acc / nres as f64).sqrt();
356    }
357    let mut sumsq = 0.0;
358    for k in rank..n {
359        let mut ubar = 0.0;
360        for gi in 0..g {
361            ubar += effects[[k, gi]] / sigma[gi];
362        }
363        ubar /= g as f64;
364        sumsq += ubar * ubar;
365    }
366    let vif = g as f64 * sumsq / nres as f64;
367    let correlation = (vif - 1.0) / (g as f64 - 1.0);
368    (vif, correlation)
369}
370
/// Converts a t-statistic `x` on `df` degrees of freedom into the equivalent
/// standard-normal deviate using Hill's 1970 approximation (`.zscoreTHill`;
/// `zscoreT(approx=TRUE, method="hill")`). Accurate for `df >= 2`; requires
/// `df > 0.5`.
fn zscore_t_hill(x: f64, df: f64) -> f64 {
    let a = df - 0.5;
    let b = 48.0 * a * a;
    let z = a * (x * x / df).ln_1p();

    // Rational correction term, evaluated in Horner form.
    let numer = ((-0.4 * z - 3.3) * z - 24.0) * z - 85.5;
    let denom = 0.8 * z * z + 100.0 + b;
    let inner = numer / denom + z + 3.0;
    let scale = inner / b + 1.0;

    // The magnitude comes from sqrt(z); the sign is restored from `x`.
    let magnitude = scale * z.sqrt();
    magnitude * x.signum()
}
383
384/// Reorder design columns so column `coef` becomes the last one, preserving the
385/// order of the rest (limma's `design[,c((1:p)[-contrast],contrast)]`).
386fn move_coef_last(design: &Array2<f64>, coef: usize) -> Array2<f64> {
387    let p = design.ncols();
388    if coef == p - 1 {
389        return design.to_owned();
390    }
391    let n = design.nrows();
392    let mut order: Vec<usize> = (0..p).filter(|&c| c != coef).collect();
393    order.push(coef);
394    let mut out = Array2::<f64>::zeros((n, p));
395    for (newj, &oldj) in order.iter().enumerate() {
396        out.column_mut(newj).assign(&design.column(oldj));
397    }
398    out
399}
400
401/// Result of [`contrast_as_coef`].
402#[derive(Clone, Debug)]
403pub struct ContrastAsCoef {
404    /// Reformed design matrix (`n x p`) in which the requested contrasts appear
405    /// as plain coefficients.
406    pub design: Array2<f64>,
407    /// 0-based columns of `design` that hold the contrast coefficients.
408    pub coef: Vec<usize>,
409    /// Rank of the contrast matrix (the number of contrast coefficients).
410    pub rank: usize,
411}
412
413/// Reform a design matrix so that one or more contrasts become simple
414/// coefficients (`contrastAsCoef`).
415///
416/// `design` is `n x p`; `contrast` is `p x ncontrasts`. With `first = true` the
417/// contrast coefficients occupy the leading columns of the reformed design,
418/// otherwise the trailing columns (limma's `first` argument). The non-contrast
419/// columns are the orthogonal completion of the contrast space.
420///
421/// Only full-column-rank contrasts are supported: limma's rank-deficient path
422/// relies on LINPACK column pivoting in `qr`, which this port does not
423/// replicate. The completion columns follow the same Householder convention as
424/// R's `qr`, so the reformed design matches limma to rounding.
425pub fn contrast_as_coef(
426    design: &Array2<f64>,
427    contrast: &Array2<f64>,
428    first: bool,
429) -> Result<ContrastAsCoef> {
430    let n = design.nrows();
431    let p = design.ncols();
432    if contrast.nrows() != p {
433        bail!(
434            "contrast_as_coef: contrast rows ({}) must match design cols ({})",
435            contrast.nrows(),
436            p
437        );
438    }
439    let nc = contrast.ncols();
440    let rank = matrix_rank(contrast);
441    if rank == 0 {
442        bail!("contrast_as_coef: contrast is all zero");
443    }
444    if rank != nc {
445        bail!(
446            "contrast_as_coef: only full-column-rank contrasts are supported (rank {} of {} columns)",
447            rank,
448            nc
449        );
450    }
451    let k = nc;
452
453    // designT = Q' t(design) using the full orthogonal factor of the contrast.
454    let qfull = qr_full_q(contrast); // p x p
455    let (_, rmat) = qr_econ(contrast); // k x k upper triangular
456    let mut designt = qfull.t().dot(&design.t()); // p x n
457
458    // Replace the leading k rows with R^-1 designT[0..k] (back-substitution),
459    // turning the contrast directions into plain coefficients.
460    for col in 0..n {
461        for i in (0..k).rev() {
462            let mut s = designt[[i, col]];
463            for j in (i + 1)..k {
464                s -= rmat[[i, j]] * designt[[j, col]];
465            }
466            designt[[i, col]] = s / rmat[[i, i]];
467        }
468    }
469    let reformed = designt.t().to_owned(); // n x p, columns 0..k the contrasts
470
471    // Place contrast coefficients first or last, as requested.
472    if first {
473        Ok(ContrastAsCoef {
474            design: reformed,
475            coef: (0..k).collect(),
476            rank,
477        })
478    } else {
479        let mut out = Array2::<f64>::zeros((n, p));
480        for (newj, oldj) in (k..p).chain(0..k).enumerate() {
481            out.column_mut(newj).assign(&reformed.column(oldj));
482        }
483        Ok(ContrastAsCoef {
484            design: out,
485            coef: (p - k..p).collect(),
486            rank,
487        })
488    }
489}
490
491/// Competitive gene-set test from an expression matrix and design
492/// (`camera`/`camera.default`) with a fixed inter-gene correlation.
493///
494/// `exprs` is `G x n` (genes by samples), `design` is `n x p`, and `coef` is the
495/// 0-based design column whose contrast is tested (limma's `contrast`, default
496/// the last column). Moderated-t statistics are computed internally from the QR
497/// effects and [`squeeze_var`] (`trend.var = FALSE`, `robust = FALSE`), then
498/// converted to z-scores with Hill's approximation (`use_ranks = false`) or used
499/// directly (`use_ranks = true`) before the per-set machinery of [`camera_pr`].
500///
501/// `inter_gene_cor` is clamped at 0 (limma's `allow.neg.cor = FALSE` default), a
502/// no-op for the usual small positive correlation.
503pub fn camera(
504    exprs: &Array2<f64>,
505    design: &Array2<f64>,
506    coef: usize,
507    index: &[Vec<usize>],
508    inter_gene_cor: f64,
509    use_ranks: bool,
510    sort: bool,
511) -> Result<Vec<CameraResult>> {
512    let g = exprs.nrows();
513    let n = exprs.ncols();
514    let p = design.ncols();
515    assert!(g >= 3, "camera: need at least 3 genes");
516    let df_residual = n as f64 - p as f64;
517    assert!(df_residual >= 1.0, "camera: no residual df");
518
519    // Reorder so the tested contrast is the final column, then take QR effects.
520    let design = move_coef_last(design, coef);
521    let qfull = qr_full_q(&design);
522    let (_, r) = qr_econ(&design);
523    let effects = qfull.t().dot(&exprs.t()); // n x G = Q' t(y)
524
525    // Unscaled t = the p-th effect, signed by the R pivot. The product of the
526    // effect and sign(R[p,p]) is invariant to the QR sign convention, so it
527    // matches limma even though the Householder signs may differ from LAPACK.
528    let sign = if r[[p - 1, p - 1]] < 0.0 { -1.0 } else { 1.0 };
529    let unscaledt: Vec<f64> = (0..g).map(|gi| effects[[p - 1, gi]] * sign).collect();
530
531    // Residual variance per gene = mean square of the trailing effects.
532    let mut sigma2 = Array1::<f64>::zeros(g);
533    for (gi, s) in sigma2.iter_mut().enumerate() {
534        let mut acc = 0.0;
535        for k in p..n {
536            let e = effects[[k, gi]];
537            acc += e * e;
538        }
539        *s = acc / df_residual;
540    }
541
542    let sv = squeeze_var(&sigma2, &Array1::from_elem(g, df_residual), None, false)?;
543
544    let mut stat = vec![0.0; g];
545    if use_ranks {
546        for gi in 0..g {
547            stat[gi] = unscaledt[gi] / sv.var_post[gi].sqrt();
548        }
549    } else {
550        let df_total = (df_residual + sv.df_prior[0]).min(g as f64 * df_residual);
551        for gi in 0..g {
552            let modt = unscaledt[gi] / sv.var_post[gi].sqrt();
553            stat[gi] = zscore_t_hill(modt, df_total);
554        }
555    }
556
557    let cor = inter_gene_cor.max(0.0);
558    Ok(camera_pr(&stat, index, cor, use_ranks, sort))
559}
560
561/// Matrix of genewise effects with `n - p + 1` columns (`.lmEffects`, no
562/// weights/blocks): column 0 is the sign-corrected contrast effect, the rest are
563/// the residual effects. `exprs` is `G x n`, `design` is `n x p`, `coef` the
564/// 0-based contrast column.
565fn lm_effects(exprs: &Array2<f64>, design: &Array2<f64>, coef: usize) -> Array2<f64> {
566    let g = exprs.nrows();
567    let n = exprs.ncols();
568    let p = design.ncols();
569    let design = move_coef_last(design, coef);
570    let qfull = qr_full_q(&design);
571    let (_, r) = qr_econ(&design);
572    let full = qfull.t().dot(&exprs.t()); // n x G = Q' t(y)
573    let signc = if r[[p - 1, p - 1]] < 0.0 { -1.0 } else { 1.0 };
574    let neff = n - p + 1;
575    let mut eff = Array2::<f64>::zeros((g, neff));
576    for gi in 0..g {
577        eff[[gi, 0]] = full[[p - 1, gi]] * signc;
578        for k in 1..neff {
579            eff[[gi, k]] = full[[p - 1 + k, gi]];
580        }
581    }
582    eff
583}
584
/// How [`fry`] orders its results (limma's `sort` argument).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FrySort {
    /// Order by directional p-value, breaking ties by larger set size and then
    /// by mixed p-value.
    Directional,
    /// Order by mixed p-value, breaking ties by larger set size and then by
    /// directional p-value.
    Mixed,
    /// Keep the sets in the order they were supplied.
    NoSort,
}
595
596/// One gene set's result from [`fry`].
597#[derive(Clone, Debug)]
598pub struct FryResult {
599    /// 0-based position of this set in the input `index` list.
600    pub set: usize,
601    /// Number of genes in the set.
602    pub n_genes: usize,
603    /// Net direction of enrichment.
604    pub direction: Direction,
605    /// Directional (two-sided) p-value.
606    pub p_value: f64,
607    /// Benjamini–Hochberg FDR for `p_value` (equals `p_value` for a single set).
608    pub fdr: f64,
609    /// Mixed (non-directional) p-value.
610    pub p_value_mixed: f64,
611    /// Benjamini–Hochberg FDR for `p_value_mixed` (equals it for a single set).
612    pub fdr_mixed: f64,
613}
614
615/// Mixed (non-directional) p-value for one set via the `nrot = Inf` Beta
616/// approximation of `.fryEffects` (`m > 1`). `eff` is the standardized effects
617/// matrix; `iset` the 0-based set members.
618fn fry_mixed_pvalue(eff: &Array2<f64>, iset: &[usize]) -> f64 {
619    let neff = eff.ncols();
620    let m = iset.len();
621
622    // Squared singular values of the m x neff set block = eigenvalues of the
623    // smaller Gram matrix, descending.
624    let mut a: Vec<f64> = if neff <= m {
625        let mut gram = Array2::<f64>::zeros((neff, neff));
626        for &gi in iset {
627            for i in 0..neff {
628                for j in 0..neff {
629                    gram[[i, j]] += eff[[gi, i]] * eff[[gi, j]];
630                }
631            }
632        }
633        eigh(&gram).0.to_vec()
634    } else {
635        let mut gram = Array2::<f64>::zeros((m, m));
636        for (ai, &gi) in iset.iter().enumerate() {
637            for (bi, &gj) in iset.iter().enumerate() {
638                let mut s = 0.0;
639                for k in 0..neff {
640                    s += eff[[gi, k]] * eff[[gj, k]];
641                }
642                gram[[ai, bi]] = s;
643            }
644        }
645        eigh(&gram).0.to_vec()
646    };
647    a.reverse(); // descending
648
649    let d1 = a.len();
650    let d1f = d1 as f64;
651    let d = d1f - 1.0;
652    let beta_mean = 1.0 / d1f;
653    let beta_var = d / d1f / d1f / (d1f / 2.0 + 1.0);
654
655    let a1 = a[0];
656    let ad1 = a[d1 - 1];
657    let span = a1 - ad1;
658    let sum_col1_sq: f64 = iset.iter().map(|&gi| eff[[gi, 0]] * eff[[gi, 0]]).sum();
659    let fobs = (sum_col1_sq - ad1) / span;
660
661    let suma: f64 = a.iter().sum();
662    let suma2: f64 = a.iter().map(|&v| v * v).sum();
663    let frb_mean = (suma * beta_mean - ad1) / span;
664    // A' COV A with COV = beta_var I - (beta_var/d)(J - I).
665    let quad = beta_var * suma2 - (beta_var / d) * (suma * suma - suma2);
666    let frb_var = quad / (span * span);
667
668    let alphaplusbeta = frb_mean * (1.0 - frb_mean) / frb_var - 1.0;
669    let alpha = alphaplusbeta * frb_mean;
670    let beta = alphaplusbeta - alpha;
671    let dist = Beta::new(alpha, beta).unwrap();
672    1.0 - dist.cdf(fobs)
673}
674
675/// Fast approximation to `roast` (`fry`): the `nrot = Inf`, `prior.df = Inf`
676/// limit, giving deterministic directional and mixed competitive p-values.
677///
678/// `exprs` is `G x n` (genes by samples), `design` is `n x p`, and `coef` is the
679/// 0-based contrast column (limma's `contrast`, default the last). Uses the
680/// default `standardize = "posterior.sd"`: robust genewise variances squeezed
681/// toward an [`fit_fdist`] prior estimated from the residual variances.
682pub fn fry(
683    exprs: &Array2<f64>,
684    design: &Array2<f64>,
685    coef: usize,
686    index: &[Vec<usize>],
687    sort: FrySort,
688) -> Result<Vec<FryResult>> {
689    let mut eff = lm_effects(exprs, design, coef);
690    let g = eff.nrows();
691    let neff = eff.ncols();
692    let df_residual = (neff - 1) as f64;
693
694    // Expected maximum squared effect under the null, by Gauss–Legendre
695    // quadrature; `qchisq(x, df=1) = qnorm((x+1)/2)^2`.
696    let (nodes, weights) = gauss_legendre_01(128);
697    let normal = Normal::new(0.0, 1.0).unwrap();
698    let mut eu2max = 0.0;
699    for (&x, &w) in nodes.iter().zip(weights.iter()) {
700        let q = normal.inverse_cdf((x + 1.0) / 2.0);
701        eu2max += (df_residual + 1.0) * x.powf(df_residual) * (q * q) * w;
702    }
703
704    // Robust variance (drop the largest squared effect) and residual variance.
705    let mut s2_robust = Array1::<f64>::zeros(g);
706    let mut s2 = Array1::<f64>::zeros(g);
707    for gi in 0..g {
708        let mut sumsq = 0.0;
709        let mut maxsq = f64::NEG_INFINITY;
710        let mut sumsq_resid = 0.0;
711        for k in 0..neff {
712            let e2 = eff[[gi, k]] * eff[[gi, k]];
713            sumsq += e2;
714            if e2 > maxsq {
715                maxsq = e2;
716            }
717            if k >= 1 {
718                sumsq_resid += e2;
719            }
720        }
721        s2_robust[gi] = (sumsq - maxsq) / (df_residual + 1.0 - eu2max);
722        s2[gi] = sumsq_resid / df_residual;
723    }
724
725    // Empirical-Bayes squeeze: prior from residual variances, applied to robust.
726    let (scale, df2) = fit_fdist(&s2, &Array1::from_elem(g, df_residual));
727    let s2_robust = squeeze_var_post(
728        &s2_robust,
729        &Array1::from_elem(g, 0.92 * df_residual),
730        &Array1::from_elem(g, scale),
731        &Array1::from_elem(g, df2),
732    );
733    for gi in 0..g {
734        let s = s2_robust[gi].sqrt();
735        for k in 0..neff {
736            eff[[gi, k]] /= s;
737        }
738    }
739
740    // Per-set directional and mixed statistics.
741    let mut rows: Vec<FryResult> = Vec::with_capacity(index.len());
742    for (si, iset) in index.iter().enumerate() {
743        let m = iset.len();
744        let mut colmean = vec![0.0; neff];
745        for &gi in iset {
746            for (k, cm) in colmean.iter_mut().enumerate() {
747                *cm += eff[[gi, k]];
748            }
749        }
750        for cm in colmean.iter_mut() {
751            *cm /= m as f64;
752        }
753        let mean_resid_sq = colmean[1..].iter().map(|&v| v * v).sum::<f64>() / (neff - 1) as f64;
754        let t_stat = colmean[0] / mean_resid_sq.sqrt();
755        let direction = if t_stat < 0.0 {
756            Direction::Down
757        } else {
758            Direction::Up
759        };
760        let p_value = 2.0 * pt_lower(-t_stat.abs(), df_residual);
761        let p_value_mixed = if m > 1 {
762            fry_mixed_pvalue(&eff, iset)
763        } else {
764            p_value
765        };
766        rows.push(FryResult {
767            set: si,
768            n_genes: m,
769            direction,
770            p_value,
771            fdr: f64::NAN,
772            p_value_mixed,
773            fdr_mixed: f64::NAN,
774        });
775    }
776
777    if rows.len() > 1 {
778        let p: Vec<f64> = rows.iter().map(|r| r.p_value).collect();
779        let pm: Vec<f64> = rows.iter().map(|r| r.p_value_mixed).collect();
780        let fdr = crate::toptable::p_adjust_bh(&p);
781        let fdr_mixed = crate::toptable::p_adjust_bh(&pm);
782        for (r, (f, fm)) in rows.iter_mut().zip(fdr.into_iter().zip(fdr_mixed)) {
783            r.fdr = f;
784            r.fdr_mixed = fm;
785        }
786    } else if let Some(r) = rows.first_mut() {
787        r.fdr = r.p_value;
788        r.fdr_mixed = r.p_value_mixed;
789    }
790
791    match sort {
792        FrySort::Directional => rows.sort_by(|a, b| {
793            a.p_value
794                .partial_cmp(&b.p_value)
795                .unwrap()
796                .then(b.n_genes.cmp(&a.n_genes))
797                .then(a.p_value_mixed.partial_cmp(&b.p_value_mixed).unwrap())
798        }),
799        FrySort::Mixed => rows.sort_by(|a, b| {
800            a.p_value_mixed
801                .partial_cmp(&b.p_value_mixed)
802                .unwrap()
803                .then(b.n_genes.cmp(&a.n_genes))
804                .then(a.p_value.partial_cmp(&b.p_value).unwrap())
805        }),
806        FrySort::NoSort => {}
807    }
808    Ok(rows)
809}
810
/// Outcome of the single-set rotation test ([`roast`]).
///
/// limma presents the result as a data frame with the four rows `Down`, `Up`,
/// `UpOrDown` and `Mixed`; both arrays below are laid out in that row order.
#[derive(Clone, Debug)]
pub struct Roast {
    /// Active proportions in the order `[Down, Up, UpOrDown, Mixed]`, that is
    /// `[a2, a1, max(a1, a2), a1 + a2]`, with `a1` the fraction of the set whose
    /// moderated z exceeds `+sqrt(2)` and `a2` the fraction below `-sqrt(2)`.
    pub active_prop: [f64; 4],
    /// Rotation p-values in the order `[Down, Up, UpOrDown, Mixed]`.
    pub p_value: [f64; 4],
    /// Size of the tested gene set.
    pub n_genes_in_set: usize,
}
826
827/// Rotation gene-set test for a single set (`roast`).
828///
829/// Ports limma's default configuration: `set.statistic = "mean"`,
830/// `approx.zscore = TRUE`, `legacy = FALSE`, with no gene weights, array weights
831/// or blocking. `exprs` is `G x n` (genes by samples), `design` is `n x p`,
832/// `coef` the 0-based contrast column, `index` the 0-based members of the set
833/// and `nrot` the number of rotations (limma's default is `1999`).
834///
835/// `rng` is supplied already seeded by the caller — equivalent to calling R's
836/// `set.seed` immediately before `roast`. The rotations are the test's only
837/// source of randomness; they draw from `rng` exactly as `.roastEffects` does
838/// (`rnorm(nroti * neffects)` per chunk of `1000`, filled column-major), so a
839/// bit-exact [`RRng`] reproduces limma's Monte-Carlo counts.
840pub fn roast(
841    exprs: &Array2<f64>,
842    design: &Array2<f64>,
843    coef: usize,
844    index: &[usize],
845    nrot: usize,
846    rng: &mut RRng,
847) -> Result<Roast> {
848    let (eff, var_prior, df_prior, var_post_all) = roast_prepare(exprs, design, coef)?;
849    let (set_eff, var_post) = subset_effects(&eff, &var_post_all, index);
850    Ok(roast_effects(
851        &set_eff, var_prior, df_prior, &var_post, nrot, rng,
852    ))
853}
854
855/// Shared preprocessing for [`roast`]/[`mroast`]: returns the gene-wise effects
856/// matrix (`G x neffects`, column 0 the primary effect), the squeezed prior
857/// `(var_prior, df_prior)` estimated over all genes, and the per-gene posterior
858/// variances. limma computes these once and reuses them across every set.
859fn roast_prepare(
860    exprs: &Array2<f64>,
861    design: &Array2<f64>,
862    coef: usize,
863) -> Result<(Array2<f64>, f64, f64, Array1<f64>)> {
864    let eff = lm_effects(exprs, design, coef);
865    let g = eff.nrows();
866    let neff = eff.ncols();
867    let df_residual = (neff - 1) as f64;
868
869    let mut s2 = Array1::<f64>::zeros(g);
870    for gi in 0..g {
871        let mut acc = 0.0;
872        for k in 1..neff {
873            acc += eff[[gi, k]] * eff[[gi, k]];
874        }
875        s2[gi] = acc / df_residual;
876    }
877    let sv = squeeze_var(&s2, &Array1::from_elem(g, df_residual), None, false)?;
878    Ok((eff, sv.var_prior[0], sv.df_prior[0], sv.var_post))
879}
880
881/// Slice the effects matrix and posterior variances down to one gene set
882/// (0-based `index`), preserving the set's member order.
883fn subset_effects(
884    eff: &Array2<f64>,
885    var_post_all: &Array1<f64>,
886    index: &[usize],
887) -> (Array2<f64>, Vec<f64>) {
888    let neff = eff.ncols();
889    let nset = index.len();
890    let mut set_eff = Array2::<f64>::zeros((nset, neff));
891    let mut var_post = vec![0.0; nset];
892    for (si, &gi) in index.iter().enumerate() {
893        for k in 0..neff {
894            set_eff[[si, k]] = eff[[gi, k]];
895        }
896        var_post[si] = var_post_all[gi];
897    }
898    (set_eff, var_post)
899}
900
901/// Rotation core (`.roastEffects`) for the default `set.statistic = "mean"`,
902/// `approx.zscore = TRUE`, `legacy = FALSE`, no-gene-weights path. `effects` is
903/// the `nset x neffects` block for one set (column 0 is the primary effect);
904/// `var_prior` / `df_prior` are the scalar squeezed prior and `var_post` the
905/// per-gene posterior variances.
906fn roast_effects(
907    effects: &Array2<f64>,
908    var_prior: f64,
909    df_prior: f64,
910    var_post: &[f64],
911    nrot: usize,
912    rng: &mut RRng,
913) -> Roast {
914    let nset = effects.nrows();
915    let neff = effects.ncols();
916    let df_residual = (neff - 1) as f64;
917    let df_total = df_prior + df_residual;
918    let df_total_winsor = df_total.min(10000.0);
919    let prior_term = df_prior * var_prior;
920    let nset_f = nset as f64;
921    let sqrt2 = std::f64::consts::SQRT_2;
922
923    // Observed moderated z-statistics, active proportions and set statistics.
924    let mut sum_modt = 0.0;
925    let mut sum_abs_modt = 0.0;
926    let mut n_up = 0usize;
927    let mut n_down = 0usize;
928    for gi in 0..nset {
929        let modt = zscore_t(
930            effects[[gi, 0]] / var_post[gi].sqrt(),
931            df_total_winsor,
932            ZscoreTMethod::Bailey,
933        );
934        sum_modt += modt;
935        sum_abs_modt += modt.abs();
936        if modt > sqrt2 {
937            n_up += 1;
938        }
939        if modt < -sqrt2 {
940            n_down += 1;
941        }
942    }
943    let a1 = n_up as f64 / nset_f;
944    let a2 = n_down as f64 / nset_f;
945    let m = sum_modt / nset_f;
946    let statobs_down = -m;
947    let statobs_up = m;
948    let statobs_mixed = sum_abs_modt / nset_f;
949
950    // Per-gene sum of squared effects (the rotation-invariant total).
951    let mut rowsq = vec![0.0; nset];
952    for gi in 0..nset {
953        let mut acc = 0.0;
954        for k in 0..neff {
955            acc += effects[[gi, k]] * effects[[gi, k]];
956        }
957        rowsq[gi] = acc;
958    }
959
960    // Rotations are conducted in chunks; the chunk sizes fix the per-chunk RNG
961    // draw counts, which must match limma exactly. The draw itself (`rng.rnorm`)
962    // stays serial to preserve limma's exact Mersenne-Twister draw order — the
963    // source of the bit-exact p-values — but the per-rotation work that *consumes*
964    // those draws is independent across rotations and feeds only integer counters,
965    // whose sum is order-independent. So that inner loop is parallelised across
966    // rotations (behind the `parallel` feature) with a result that is bit-identical
967    // to the serial path, and to limma.
968    let chunk = 1000usize;
969    let nchunk = nrot.div_ceil(chunk);
970    let nroti0 = nrot.div_ceil(nchunk);
971    let overshoot = nchunk * nroti0 - nrot;
972
973    let mut count = [0i64; 4];
974    for chunki in 0..nchunk {
975        let nroti = if chunki == nchunk - 1 {
976            nroti0 - overshoot
977        } else {
978            nroti0
979        };
980        // rnorm(nroti * neffects), interpreted column-major as nroti x neffects.
981        let draws = rng.rnorm(nroti * neff);
982        let ctx = RotationCtx {
983            draws: &draws,
984            nroti,
985            neff,
986            nset,
987            effects,
988            rowsq: &rowsq,
989            prior_term,
990            df_residual,
991            df_total,
992            df_total_winsor,
993            nset_f,
994            statobs_down,
995            statobs_up,
996            statobs_mixed,
997        };
998        // [down, up, mixed] counts summed over this chunk's rotations.
999        let part = ctx.count_rotations();
1000        count[0] += part[0];
1001        count[1] += part[1];
1002        count[3] += part[2];
1003    }
1004    // For "mean", UpOrDown is the more significant of the one-sided counts.
1005    count[2] = count[0].min(count[1]);
1006
1007    let nrot_i = nrot as i64;
1008    let denom = [2 * nrot_i + 1, 2 * nrot_i + 1, nrot_i + 1, nrot_i + 1];
1009    let mut p_value = [0.0; 4];
1010    for i in 0..4 {
1011        p_value[i] = (count[i] as f64 + 1.0) / denom[i] as f64;
1012    }
1013
1014    Roast {
1015        active_prop: [a2, a1, a1.max(a2), a1 + a2],
1016        p_value,
1017        n_genes_in_set: nset,
1018    }
1019}
1020
1021/// Borrowed, fully-immutable context for one chunk's worth of [`roast_effects`]
1022/// rotations. Every field a single rotation reads is here, so the per-rotation
1023/// statistic is a pure function of the rotation index `r`: rotations within a
1024/// chunk are independent and accumulate only into integer counters. That makes
1025/// the chunk's total order-independent, so it can be summed with a parallel (or
1026/// serial) reduction that is bit-identical either way — and identical to limma.
1027struct RotationCtx<'a> {
1028    /// `rnorm(nroti * neff)` for this chunk, column-major as `nroti x neff`.
1029    draws: &'a [f64],
1030    nroti: usize,
1031    neff: usize,
1032    nset: usize,
1033    /// `nset x neff` effects block for the set (column 0 the primary effect).
1034    effects: &'a Array2<f64>,
1035    /// Per-gene sum of squared effects (the rotation-invariant total).
1036    rowsq: &'a [f64],
1037    prior_term: f64,
1038    df_residual: f64,
1039    df_total: f64,
1040    df_total_winsor: f64,
1041    nset_f: f64,
1042    statobs_down: f64,
1043    statobs_up: f64,
1044    statobs_mixed: f64,
1045}
1046
1047impl RotationCtx<'_> {
1048    /// Counts contributed by a single rotation `r`, as
1049    /// `[#{rot > statobs_down}, #{rot > statobs_up}, #{mixed > statobs_mixed}]`
1050    /// (the first two each count both the down and up rotated statistics, exactly
1051    /// as limma tallies `statrot[,c("down","up")]`). `zrow` is a caller-owned
1052    /// scratch buffer of length `neff`, reused across rotations to avoid
1053    /// allocating inside the hot loop.
1054    #[inline]
1055    fn count_one(&self, r: usize, zrow: &mut [f64]) -> [i64; 3] {
1056        // Unit-normalize the rotation row (limma's modtr / sqrt(rowSums^2)).
1057        let mut znorm = 0.0;
1058        for (k, z) in zrow.iter_mut().enumerate() {
1059            let v = self.draws[k * self.nroti + r];
1060            *z = v;
1061            znorm += v * v;
1062        }
1063        let znorm = znorm.sqrt();
1064        for z in zrow.iter_mut() {
1065            *z /= znorm;
1066        }
1067        // Rotated, moderated z-statistics for each gene in the set.
1068        let mut sum_z = 0.0;
1069        let mut sum_abs_z = 0.0;
1070        for gi in 0..self.nset {
1071            // zrow.len() == neff, so this is k = 0..neff in order (bit-identical
1072            // accumulation), just without the redundant bounds check clippy flags.
1073            let mut t = 0.0;
1074            for (k, &zv) in zrow.iter().enumerate() {
1075                t += self.effects[[gi, k]] * zv;
1076            }
1077            let s2r0 = (self.rowsq[gi] - t * t) / self.df_residual;
1078            let s2r = (self.prior_term + self.df_residual * s2r0) / self.df_total;
1079            let z = zscore_t(t / s2r.sqrt(), self.df_total_winsor, ZscoreTMethod::Bailey);
1080            sum_z += z;
1081            sum_abs_z += z.abs();
1082        }
1083        let up_r = sum_z / self.nset_f;
1084        let down_r = -up_r;
1085        let mixed_r = sum_abs_z / self.nset_f;
1086        [
1087            (down_r > self.statobs_down) as i64 + (up_r > self.statobs_down) as i64,
1088            (down_r > self.statobs_up) as i64 + (up_r > self.statobs_up) as i64,
1089            (mixed_r > self.statobs_mixed) as i64,
1090        ]
1091    }
1092
1093    /// Sum [`Self::count_one`] over all `nroti` rotations in the chunk. Parallel
1094    /// across rotations under the `parallel` feature; because the reduction is
1095    /// over integer counters it is bit-identical to the serial fold (and limma).
1096    #[cfg(feature = "parallel")]
1097    fn count_rotations(&self) -> [i64; 3] {
1098        use rayon::prelude::*;
1099        (0..self.nroti)
1100            .into_par_iter()
1101            .fold(
1102                || ([0i64; 3], vec![0.0; self.neff]),
1103                |(mut acc, mut zrow), r| {
1104                    let c = self.count_one(r, &mut zrow);
1105                    acc[0] += c[0];
1106                    acc[1] += c[1];
1107                    acc[2] += c[2];
1108                    (acc, zrow)
1109                },
1110            )
1111            .map(|(acc, _)| acc)
1112            .reduce(|| [0i64; 3], |a, b| [a[0] + b[0], a[1] + b[1], a[2] + b[2]])
1113    }
1114
1115    /// Serial fallback (`--no-default-features`): a single reused scratch buffer.
1116    #[cfg(not(feature = "parallel"))]
1117    fn count_rotations(&self) -> [i64; 3] {
1118        let mut acc = [0i64; 3];
1119        let mut zrow = vec![0.0; self.neff];
1120        for r in 0..self.nroti {
1121            let c = self.count_one(r, &mut zrow);
1122            acc[0] += c[0];
1123            acc[1] += c[1];
1124            acc[2] += c[2];
1125        }
1126        acc
1127    }
1128}
1129
1130/// One row of the [`mroast`] result table (limma's `mroast` data frame).
1131#[derive(Clone, Debug)]
1132pub struct MroastRow {
1133    /// 0-based position of this set in the input `index` list, recorded before
1134    /// any sorting so the caller can recover the original order.
1135    pub set: usize,
1136    /// Number of genes in the set (`NGenes`).
1137    pub n_genes: usize,
1138    /// Active proportion in the down direction (`PropDown`).
1139    pub prop_down: f64,
1140    /// Active proportion in the up direction (`PropUp`).
1141    pub prop_up: f64,
1142    /// Net direction (`Direction`): [`Direction::Up`] when the up p-value is the
1143    /// smaller of the two one-sided p-values, otherwise [`Direction::Down`].
1144    pub direction: Direction,
1145    /// Two-sided (UpOrDown) rotation p-value (`PValue`).
1146    pub p_value: f64,
1147    /// Benjamini-Hochberg FDR across sets over the two-sided p-values (`FDR`).
1148    pub fdr: f64,
1149    /// Mixed, non-directional rotation p-value (`PValue.Mixed`).
1150    pub p_value_mixed: f64,
1151    /// Benjamini-Hochberg FDR across sets over the mixed p-values (`FDR.Mixed`).
1152    pub fdr_mixed: f64,
1153}
1154
1155/// Multi-set rotation gene-set test (`mroast`).
1156///
1157/// Runs the [`roast`] rotation test for every set in `index` (each a slice of
1158/// 0-based gene indices), sharing the effects matrix and empirical-Bayes prior
1159/// across sets, then assembles limma's `mroast` table with Benjamini-Hochberg
1160/// FDRs computed across the sets. The default `set.statistic = "mean"`,
1161/// `approx.zscore = TRUE`, `legacy = FALSE`, no-gene-weights path is ported.
1162///
1163/// `midp` toggles limma's default mid-p correction (`midp = TRUE`): the FDRs are
1164/// computed from p-values shifted down by `1/2/(nrot+1)` and then floored back
1165/// at the raw rotation p-value. `sort` orders the rows ([`FrySort::Directional`]
1166/// is limma's default; [`FrySort::Mixed`] / [`FrySort::NoSort`] match
1167/// `sort = "mixed"` / `"none"`). Each [`MroastRow::set`] records the row's
1168/// original 0-based position in `index`.
1169///
1170/// `rng` is supplied already seeded by the caller. The sets are processed in
1171/// input order through a single shared `rng`, exactly as limma reuses the
1172/// rotation stream across sets, so a bit-exact [`RRng`] reproduces limma's
1173/// Monte-Carlo counts.
1174#[allow(clippy::too_many_arguments)]
1175pub fn mroast(
1176    exprs: &Array2<f64>,
1177    design: &Array2<f64>,
1178    coef: usize,
1179    index: &[Vec<usize>],
1180    nrot: usize,
1181    midp: bool,
1182    sort: FrySort,
1183    rng: &mut RRng,
1184) -> Result<Vec<MroastRow>> {
1185    let (eff, var_prior, df_prior, var_post_all) = roast_prepare(exprs, design, coef)?;
1186
1187    let mut rows = Vec::with_capacity(index.len());
1188    for (si, set) in index.iter().enumerate() {
1189        let (set_eff, var_post) = subset_effects(&eff, &var_post_all, set);
1190        let r = roast_effects(&set_eff, var_prior, df_prior, &var_post, nrot, rng);
1191        // Direction follows the smaller one-sided p-value (ties resolve to Down,
1192        // matching R's `pv[,"Up"] < pv[,"Down"]`).
1193        let direction = if r.p_value[1] < r.p_value[0] {
1194            Direction::Up
1195        } else {
1196            Direction::Down
1197        };
1198        rows.push(MroastRow {
1199            set: si,
1200            n_genes: r.n_genes_in_set,
1201            prop_down: r.active_prop[0],
1202            prop_up: r.active_prop[1],
1203            direction,
1204            p_value: r.p_value[2],
1205            fdr: f64::NAN,
1206            p_value_mixed: r.p_value[3],
1207            fdr_mixed: f64::NAN,
1208        });
1209    }
1210
1211    // Mid-p shift, then Benjamini-Hochberg across sets, then (for mid-p) floor
1212    // each FDR back at its raw rotation p-value.
1213    let midp_adj = if midp { 0.5 / (nrot as f64 + 1.0) } else { 0.0 };
1214    let two_sided: Vec<f64> = rows.iter().map(|r| r.p_value - midp_adj).collect();
1215    let mixed: Vec<f64> = rows.iter().map(|r| r.p_value_mixed - midp_adj).collect();
1216    let mut fdr = crate::toptable::p_adjust_bh(&two_sided);
1217    let mut fdr_mixed = crate::toptable::p_adjust_bh(&mixed);
1218    if midp {
1219        for (i, r) in rows.iter().enumerate() {
1220            fdr[i] = fdr[i].max(r.p_value);
1221            fdr_mixed[i] = fdr_mixed[i].max(r.p_value_mixed);
1222        }
1223    }
1224    for (r, (f, fm)) in rows.iter_mut().zip(fdr.into_iter().zip(fdr_mixed)) {
1225        r.fdr = f;
1226        r.fdr_mixed = fm;
1227    }
1228
1229    match sort {
1230        FrySort::Directional => rows.sort_by(|a, b| {
1231            a.p_value
1232                .partial_cmp(&b.p_value)
1233                .unwrap()
1234                .then(
1235                    b.prop_up
1236                        .max(b.prop_down)
1237                        .partial_cmp(&a.prop_up.max(a.prop_down))
1238                        .unwrap(),
1239                )
1240                .then(b.n_genes.cmp(&a.n_genes))
1241                .then(a.p_value_mixed.partial_cmp(&b.p_value_mixed).unwrap())
1242        }),
1243        FrySort::Mixed => rows.sort_by(|a, b| {
1244            a.p_value_mixed
1245                .partial_cmp(&b.p_value_mixed)
1246                .unwrap()
1247                .then(
1248                    (b.prop_up + b.prop_down)
1249                        .partial_cmp(&(a.prop_up + a.prop_down))
1250                        .unwrap(),
1251                )
1252                .then(b.n_genes.cmp(&a.n_genes))
1253                .then(a.p_value.partial_cmp(&b.p_value).unwrap())
1254        }),
1255        FrySort::NoSort => {}
1256    }
1257    Ok(rows)
1258}
1259
/// Per-set result of [`romer`]: the size of the set together with its three
/// rotation p-values.
#[derive(Clone, Debug)]
pub struct RomerRow {
    /// Index (0-based) of the set within the `index` list passed to `romer`.
    pub set: usize,
    /// `NGenes`: how many genes the set contains.
    pub n_genes: usize,
    /// `Up`: p-value for up-regulation, i.e. a high mean rank of the
    /// moderated t.
    pub p_up: f64,
    /// `Down`: p-value for down-regulation, i.e. a low mean rank of the
    /// moderated t.
    pub p_down: f64,
    /// `Mixed`: p-value for regulation in either direction, i.e. a high mean
    /// rank of the absolute moderated t.
    pub p_mixed: f64,
}
1274
/// Which alternative hypothesis [`top_romer`] ranks by (the `alternative`
/// argument of `topRomer`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RomerAlternative {
    /// `"up"`: sets ordered from most up-regulated.
    Up,
    /// `"down"`: sets ordered from most down-regulated.
    Down,
    /// `"mixed"`: sets ordered from most differentially expressed, whatever
    /// the direction.
    Mixed,
}
1285
/// How [`romer`] condenses the gene ranks of a set into one statistic (the
/// `set.statistic` argument of `romer`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RomerStatistic {
    /// `"mean"`, limma's default: the mean rank of the moderated t across the
    /// set.
    Mean,
    /// `"floormean"`: the mean rank of a floored statistic, with the up, down
    /// and mixed directions each ranked separately on a non-negative scale.
    FloorMean,
    /// `"mean50"`: the mean over the more extreme half of the set's ranks.
    Mean50,
}
1297
1298/// Rotation mean-rank GSEA for linear models (`romer`).
1299///
1300/// Ports all three `set.statistic` options ([`RomerStatistic`]); array weights
1301/// and blocking are out of scope. `exprs` is `G x n` (genes by samples),
1302/// `design` is `n x p`, `coef` the 0-based contrast column, and `index` the gene
1303/// sets as 0-based member indices. `shrink_resid` toggles the empirical-Bayes
1304/// shrinkage of the contrast effect (`shrink.resid`, limma's default `true`).
1305///
1306/// `rng` is supplied already seeded by the caller. The rotation loop is the
1307/// test's only source of randomness, drawing `rnorm(n - p + 1)` once per
1308/// rotation exactly as limma does, so a bit-exact [`RRng`] reproduces limma's
1309/// Monte-Carlo counts. Rows are returned in input order (use [`top_romer`] to
1310/// rank them).
1311#[allow(clippy::too_many_arguments)]
1312pub fn romer(
1313    exprs: &Array2<f64>,
1314    design: &Array2<f64>,
1315    coef: usize,
1316    index: &[Vec<usize>],
1317    set_statistic: RomerStatistic,
1318    nrot: usize,
1319    shrink_resid: bool,
1320    rng: &mut RRng,
1321) -> Result<Vec<RomerRow>> {
1322    let g = exprs.nrows();
1323    let n = exprs.ncols();
1324    let p = design.ncols();
1325    let d = (n - p) as f64;
1326    let p0 = p - 1;
1327    let neff = n - p0; // d + 1: contrast effect plus residual effects.
1328
1329    // Reorder so the tested contrast is last, then take QR effects (raw, with the
1330    // pivot sign applied explicitly at the statistic stage as romer does).
1331    let design = move_coef_last(design, coef);
1332    let qfull = qr_full_q(&design);
1333    let (_, rmat) = qr_econ(&design);
1334    let full = qfull.t().dot(&exprs.t()); // n x G = Q' t(y)
1335    let signc = if rmat[[p - 1, p - 1]] < 0.0 {
1336        -1.0
1337    } else {
1338        1.0
1339    };
1340
1341    // Residual variance per gene and the empirical-Bayes posterior.
1342    let mut s2 = Array1::<f64>::zeros(g);
1343    for (gi, s) in s2.iter_mut().enumerate() {
1344        let mut acc = 0.0;
1345        for k in p..n {
1346            let e = full[[k, gi]];
1347            acc += e * e;
1348        }
1349        *s = acc / d;
1350    }
1351    let sv = squeeze_var(&s2, &Array1::from_elem(g, d), None, false)?;
1352    let d0 = sv.df_prior[0];
1353    let s02 = sv.var_prior[0];
1354
1355    // Y (gene by effect): column 0 the contrast effect, the rest residuals. yy is
1356    // the per-gene sum of squares, captured before any shrinkage of column 0.
1357    let mut ymat = Array2::<f64>::zeros((g, neff));
1358    let mut yy = vec![0.0; g];
1359    let mut modt = vec![0.0; g];
1360    for gi in 0..g {
1361        let mut acc = 0.0;
1362        for k in 0..neff {
1363            let e = full[[p0 + k, gi]];
1364            ymat[[gi, k]] = e;
1365            acc += e * e;
1366        }
1367        yy[gi] = acc;
1368        modt[gi] = signc * ymat[[gi, 0]] / sv.var_post[gi].sqrt();
1369    }
1370
1371    // Empirical-Bayes shrinkage of the contrast effect toward a residual.
1372    if shrink_resid {
1373        let pvals: Vec<f64> = modt
1374            .iter()
1375            .map(|&m| 2.0 * pt_upper(m.abs(), d0 + d))
1376            .collect();
1377        let proportion = 1.0 - prop_true_null(&pvals, PropTrueNullMethod::Lfdr, 20);
1378        let stdev_unscaled = 1.0 / rmat[[p - 1, p - 1]].abs();
1379        let var_unscaled = stdev_unscaled * stdev_unscaled;
1380        let df_total = d + d0;
1381        let var_prior_lim = (0.01 / s02, 16.0 / s02);
1382        let su = vec![stdev_unscaled; g];
1383        let dt = vec![df_total; g];
1384        let mut var_prior = tmixture_vector(&modt, &su, &dt, proportion, var_prior_lim);
1385        if var_prior.is_nan() {
1386            var_prior = 1.0 / s02;
1387        }
1388        let r = (var_unscaled + var_prior) / var_unscaled;
1389        let logodds = (proportion / (1.0 - proportion)).ln() - r.ln() / 2.0;
1390        for gi in 0..g {
1391            let t2 = modt[gi] * modt[gi];
1392            let kernel = if d0 > 1e6 {
1393                t2 * (1.0 - 1.0 / r) / 2.0
1394            } else {
1395                (1.0 + df_total) / 2.0 * ((t2 + df_total) / (t2 / r + df_total)).ln()
1396            };
1397            let lods = logodds + kernel;
1398            let prob_de = lods.exp() / (1.0 + lods.exp());
1399            ymat[[gi, 0]] *= (var_unscaled / (var_unscaled + var_prior * prob_de)).sqrt();
1400        }
1401    }
1402
1403    // Observed per-set statistic `[Up, Down, Mixed]` for the chosen aggregation.
1404    let gf = g as f64;
1405    let obs = romer_set_stats(&modt, index, gf, set_statistic);
1406
1407    // For "mean50" the Down statistic is the *small* half of the ranks, so a
1408    // rotation supports Down when it falls at or below the observed value.
1409    let down_low = matches!(set_statistic, RomerStatistic::Mean50);
1410
1411    // Rotations: draw a unit direction in the (d+1)-dim effect space, recompute
1412    // the moderated t, and tally how often each rotated statistic beats observed.
1413    //
1414    // The rotation directions are the test's only randomness and must follow
1415    // limma's exact Mersenne-Twister order, so draw all `nrot` of them serially
1416    // up front (row-major, rotation `r` occupies `draws[r*neff..]`). Consuming a
1417    // drawn row — recomputing modt and tallying per-set hits — is a pure
1418    // function of that row feeding integer counters, so it parallelizes (under
1419    // the `parallel` feature) with a reduction that is bit-identical to the
1420    // serial fold and to limma.
1421    let mut draws = Vec::with_capacity(nrot * neff);
1422    for _ in 0..nrot {
1423        draws.extend_from_slice(&rng.rnorm(neff));
1424    }
1425    let ctx = RomerRotCtx {
1426        draws: &draws,
1427        neff,
1428        ymat: &ymat,
1429        yy: &yy,
1430        index,
1431        obs: &obs,
1432        gf,
1433        set_statistic,
1434        d,
1435        d0,
1436        s02,
1437        signc,
1438        g,
1439        down_low,
1440    };
1441    let nset = index.len();
1442    let count: Vec<[i64; 3]> = {
1443        #[cfg(feature = "parallel")]
1444        {
1445            use rayon::prelude::*;
1446            (0..nrot)
1447                .into_par_iter()
1448                .fold(
1449                    || (vec![[0i64; 3]; nset], vec![0.0f64; neff], vec![0.0f64; g]),
1450                    |(mut acc, mut rvec, mut modtr), r| {
1451                        ctx.add_rotation(r, &mut rvec, &mut modtr, &mut acc);
1452                        (acc, rvec, modtr)
1453                    },
1454                )
1455                .map(|(acc, _, _)| acc)
1456                .reduce(
1457                    || vec![[0i64; 3]; nset],
1458                    |mut a, b| {
1459                        for (x, y) in a.iter_mut().zip(&b) {
1460                            x[0] += y[0];
1461                            x[1] += y[1];
1462                            x[2] += y[2];
1463                        }
1464                        a
1465                    },
1466                )
1467        }
1468        #[cfg(not(feature = "parallel"))]
1469        {
1470            let mut acc = vec![[0i64; 3]; nset];
1471            let mut rvec = vec![0.0f64; neff];
1472            let mut modtr = vec![0.0f64; g];
1473            for r in 0..nrot {
1474                ctx.add_rotation(r, &mut rvec, &mut modtr, &mut acc);
1475            }
1476            acc
1477        }
1478    };
1479
1480    let denom = nrot as f64 + 1.0;
1481    Ok(index
1482        .iter()
1483        .enumerate()
1484        .map(|(si, set)| RomerRow {
1485            set: si,
1486            n_genes: set.len(),
1487            p_up: (count[si][0] as f64 + 1.0) / denom,
1488            p_down: (count[si][1] as f64 + 1.0) / denom,
1489            p_mixed: (count[si][2] as f64 + 1.0) / denom,
1490        })
1491        .collect())
1492}
1493
1494/// Everything a single [`romer`] rotation needs to read, so consuming a rotation
1495/// is a pure function of its index `r`. Holds only shared borrows and `Copy`
1496/// scalars, so it is `Sync` and can be shared across rayon workers; the mutable
1497/// per-rotation scratch (`rvec`, `modtr`) and the accumulator are passed in.
1498struct RomerRotCtx<'a> {
1499    /// All rotation directions, row-major: rotation `r` is `draws[r*neff..]`.
1500    draws: &'a [f64],
1501    neff: usize,
1502    /// Gene-by-effect matrix (column 0 the contrast effect, rest residuals).
1503    ymat: &'a Array2<f64>,
1504    /// Per-gene sum of squared effects (rotation-invariant total).
1505    yy: &'a [f64],
1506    index: &'a [Vec<usize>],
1507    /// Observed per-set `[Up, Down, Mixed]` statistic.
1508    obs: &'a [[f64; 3]],
1509    gf: f64,
1510    set_statistic: RomerStatistic,
1511    d: f64,
1512    d0: f64,
1513    s02: f64,
1514    signc: f64,
1515    g: usize,
1516    down_low: bool,
1517}
1518
1519impl RomerRotCtx<'_> {
1520    /// Recompute the moderated t under rotation `r` and add this rotation's
1521    /// per-set hits into `acc`. `rvec` (len `neff`) and `modtr` (len `g`) are
1522    /// caller-owned scratch reused across rotations. The arithmetic — unit-norm
1523    /// of the drawn row, the per-gene rotated statistic, and the `>=` tallies —
1524    /// is identical to the serial loop, so `acc` is bit-identical regardless of
1525    /// how rotations are split across threads.
1526    fn add_rotation(&self, r: usize, rvec: &mut [f64], modtr: &mut [f64], acc: &mut [[i64; 3]]) {
1527        let row = &self.draws[r * self.neff..(r + 1) * self.neff];
1528        let mut nrm = 0.0;
1529        for (k, &v) in row.iter().enumerate() {
1530            rvec[k] = v;
1531            nrm += v * v;
1532        }
1533        let nrm = nrm.sqrt();
1534        for v in rvec.iter_mut() {
1535            *v /= nrm;
1536        }
1537        for (gi, m) in modtr.iter_mut().enumerate().take(self.g) {
1538            let mut br = 0.0;
1539            for (k, &rv) in rvec.iter().enumerate().take(self.neff) {
1540                br += rv * self.ymat[[gi, k]];
1541            }
1542            let s2r = (self.yy[gi] - br * br) / self.d;
1543            let sdr_post = if self.d0.is_finite() {
1544                ((self.d0 * self.s02 + self.d * s2r) / (self.d0 + self.d)).sqrt()
1545            } else {
1546                self.s02.sqrt()
1547            };
1548            *m = self.signc * br / sdr_post;
1549        }
1550        let rot = romer_set_stats(modtr, self.index, self.gf, self.set_statistic);
1551        for (c, (o, rr)) in acc.iter_mut().zip(self.obs.iter().zip(&rot)) {
1552            if rr[0] >= o[0] {
1553                c[0] += 1;
1554            }
1555            let down_hit = if self.down_low {
1556                rr[1] <= o[1]
1557            } else {
1558                rr[1] >= o[1]
1559            };
1560            if down_hit {
1561                c[1] += 1;
1562            }
1563            if rr[2] >= o[2] {
1564                c[2] += 1;
1565            }
1566        }
1567    }
1568}
1569
1570/// Per-set mean ranks `[Up, Down, Mixed]` from a vector of statistics:
1571/// `Up = mean rank(stat)`, `Down = mean (N - rank(stat) + 1)`,
1572/// `Mixed = mean rank(|stat|)`.
1573fn set_mean_ranks(stat: &[f64], index: &[Vec<usize>], gf: f64) -> Vec<[f64; 3]> {
1574    let r = rank_average(stat);
1575    let abs: Vec<f64> = stat.iter().map(|v| v.abs()).collect();
1576    let ra = rank_average(&abs);
1577    index
1578        .iter()
1579        .map(|set| {
1580            let sz = set.len() as f64;
1581            let mut up = 0.0;
1582            let mut dn = 0.0;
1583            let mut mx = 0.0;
1584            for &gi in set {
1585                up += r[gi];
1586                dn += gf - r[gi] + 1.0;
1587                mx += ra[gi];
1588            }
1589            [up / sz, dn / sz, mx / sz]
1590        })
1591        .collect()
1592}
1593
1594/// Per-set `[Up, Down, Mixed]` statistic for the chosen [`RomerStatistic`].
1595fn romer_set_stats(
1596    stat: &[f64],
1597    index: &[Vec<usize>],
1598    gf: f64,
1599    set_statistic: RomerStatistic,
1600) -> Vec<[f64; 3]> {
1601    match set_statistic {
1602        RomerStatistic::Mean => set_mean_ranks(stat, index, gf),
1603        RomerStatistic::FloorMean => {
1604            // Separate non-negative ranks per direction (limma's pmax flooring).
1605            let up_r = rank_average(&stat.iter().map(|&v| v.max(0.0)).collect::<Vec<_>>());
1606            let dn_r = rank_average(&stat.iter().map(|&v| (-v).max(0.0)).collect::<Vec<_>>());
1607            let mx_r = rank_average(&stat.iter().map(|&v| v.abs().max(1.0)).collect::<Vec<_>>());
1608            index
1609                .iter()
1610                .map(|set| {
1611                    let sz = set.len() as f64;
1612                    let mut up = 0.0;
1613                    let mut dn = 0.0;
1614                    let mut mx = 0.0;
1615                    for &gi in set {
1616                        up += up_r[gi];
1617                        dn += dn_r[gi];
1618                        mx += mx_r[gi];
1619                    }
1620                    [up / sz, dn / sz, mx / sz]
1621                })
1622                .collect()
1623        }
1624        RomerStatistic::Mean50 => {
1625            let r = rank_average(stat);
1626            let ra = rank_average(&stat.iter().map(|&v| v.abs()).collect::<Vec<_>>());
1627            index
1628                .iter()
1629                .map(|set| {
1630                    let m = set.len().div_ceil(2); // floor((|set| + 1) / 2)
1631                    let r_set: Vec<f64> = set.iter().map(|&gi| r[gi]).collect();
1632                    let ra_set: Vec<f64> = set.iter().map(|&gi| ra[gi]).collect();
1633                    let (small, large) = mean_half(&r_set, m);
1634                    let (_, large_abs) = mean_half(&ra_set, m);
1635                    // Up = larger half of the signed ranks, Down = smaller half,
1636                    // Mixed = larger half of the absolute ranks.
1637                    [large, small, large_abs]
1638                })
1639                .collect()
1640        }
1641    }
1642}
1643
/// Mean of the smaller and larger halves of `x` (`.meanHalf`).
///
/// `n` is the 1-based split point `floor((len + 1) / 2)`. The smaller half is
/// the first `n` sorted values. For even lengths the larger half is everything
/// after them; for odd lengths it starts at the median, so the median is
/// counted in both halves, matching limma.
///
/// Values are ordered with [`f64::total_cmp`], so a NaN in `x` no longer
/// panics in the comparator; it simply propagates into the affected mean.
///
/// Returns `(small_half, large_half)`. Empty input with `n == 0` yields
/// `(NaN, NaN)`.
///
/// # Panics
///
/// Panics if `n > x.len()`, or if `n == 0` while `x` has odd length (the
/// caller-supplied split point is then inconsistent with the slice).
fn mean_half(x: &[f64], n: usize) -> (f64, f64) {
    let mut sorted = x.to_vec();
    sorted.sort_by(f64::total_cmp);

    let lower = &sorted[..n];
    // Odd length: step back one so the median is shared by both halves.
    let upper_start = if sorted.len() % 2 == 0 { n } else { n - 1 };
    let upper = &sorted[upper_start..];

    let small = lower.iter().sum::<f64>() / lower.len() as f64;
    let large = upper.iter().sum::<f64>() / upper.len() as f64;
    (small, large)
}
1659
1660/// Rank gene sets from a [`romer`] result and keep the top `n` (`topRomer`).
1661///
1662/// Mirrors `topRomer`'s ordering: by the chosen alternative's p-value, then the
1663/// mixed p-value (for up/down) or `min(Up, Down)` (for mixed), then descending
1664/// set size. Ties keep input order, matching R's stable `order`.
1665pub fn top_romer(rows: &[RomerRow], n: usize, alternative: RomerAlternative) -> Vec<RomerRow> {
1666    let mut idx: Vec<usize> = (0..rows.len()).collect();
1667    let key = |r: &RomerRow| match alternative {
1668        RomerAlternative::Up => r.p_up,
1669        RomerAlternative::Down => r.p_down,
1670        RomerAlternative::Mixed => r.p_mixed,
1671    };
1672    idx.sort_by(|&a, &b| {
1673        let primary = key(&rows[a]).partial_cmp(&key(&rows[b])).unwrap();
1674        let secondary = match alternative {
1675            RomerAlternative::Mixed => rows[a]
1676                .p_up
1677                .min(rows[a].p_down)
1678                .partial_cmp(&rows[b].p_up.min(rows[b].p_down))
1679                .unwrap(),
1680            _ => rows[a].p_mixed.partial_cmp(&rows[b].p_mixed).unwrap(),
1681        };
1682        primary
1683            .then(secondary)
1684            .then(rows[b].n_genes.cmp(&rows[a].n_genes))
1685    });
1686    idx.into_iter()
1687        .take(n.min(rows.len()))
1688        .map(|i| rows[i].clone())
1689        .collect()
1690}
1691
/// Map a list of gene sets (each a list of identifiers) to 0-based indices into
/// `identifiers` (`ids2indices`).
///
/// For each gene set the result lists, in ascending order, every position of
/// `identifiers` whose value occurs in the set. An identifier that appears
/// several times in `identifiers` contributes all of its positions; an
/// identifier repeated within a set is counted once. With `remove_empty`, sets
/// that match nothing are dropped, so the output may be shorter than
/// `gene_sets`.
///
/// `identifiers` is indexed once up front, so the cost is proportional to the
/// number of identifiers plus the total size of the gene sets rather than
/// their product.
pub fn ids2indices(
    gene_sets: &[Vec<String>],
    identifiers: &[String],
    remove_empty: bool,
) -> Vec<Vec<usize>> {
    use std::collections::HashMap;

    // identifier -> every position at which it occurs (ascending).
    let mut positions: HashMap<&str, Vec<usize>> = HashMap::with_capacity(identifiers.len());
    for (i, id) in identifiers.iter().enumerate() {
        positions.entry(id.as_str()).or_default().push(i);
    }

    let mut out = Vec::with_capacity(gene_sets.len());
    for set in gene_sets {
        // Look each distinct identifier up once so repeats add no duplicates.
        let mut seen: HashSet<&str> = HashSet::with_capacity(set.len());
        let mut idx: Vec<usize> = set
            .iter()
            .map(String::as_str)
            .filter(|id| seen.insert(*id))
            .filter_map(|id| positions.get(id))
            .flatten()
            .copied()
            .collect();
        // Distinct identifiers own disjoint positions; sorting restores the
        // order of `identifiers`.
        idx.sort_unstable();
        if remove_empty && idx.is_empty() {
            continue;
        }
        out.push(idx);
    }
    out
}
1715
1716#[cfg(test)]
1717mod tests {
1718    use super::*;
1719
1720    // 20 genes, mixed signs, with a tie (stats[2] == stats[10] == 1.8).
1721    fn fixture() -> Vec<f64> {
1722        vec![
1723            2.1, -0.5, 1.8, 0.3, -1.2, 2.5, -0.1, 1.1, -2.2, 0.7, 1.8, -0.9, 0.4, -1.5, 2.0, -0.3,
1724            1.3, -0.8, 0.6, -2.1,
1725        ]
1726    }
1727
1728    // 1-based R index c(1,3,6,8,15,17) -> 0-based.
1729    fn up_set() -> Vec<usize> {
1730        vec![0, 2, 5, 7, 14, 16]
1731    }
1732
1733    #[test]
1734    fn rank_average_handles_ties() {
1735        // Two tied 1.8 values share rank (their averaged position).
1736        let r = rank_average(&fixture());
1737        assert_eq!(r[2], r[10]);
1738    }
1739
1740    #[test]
1741    fn gene_set_test_matches_r() {
1742        let stats = fixture();
1743        let idx = up_set();
1744        let cases = [
1745            (Alternative::Up, 0.000645718763498011),
1746            (Alternative::Down, 0.999517239270778),
1747            (Alternative::Either, 0.00129143752699602),
1748            (Alternative::Mixed, 0.0143292516670446),
1749        ];
1750        for (alt, want) in cases {
1751            let got = gene_set_test(&idx, &stats, alt);
1752            assert!(
1753                (got - want).abs() < 1e-9,
1754                "gene_set_test({alt:?}): got {got}, want {want}"
1755            );
1756        }
1757        // wilcoxGST == geneSetTest mixed.
1758        let w = wilcox_gst(&idx, &stats);
1759        assert!((w - 0.0143292516670446).abs() < 1e-9);
1760    }
1761
1762    #[test]
1763    fn rank_sum_test_matches_r() {
1764        let stats = fixture();
1765        let idx = up_set();
1766
1767        // correlation = 0.1, df = 10.
1768        let (less, greater) = rank_sum_test_with_correlation(&idx, &stats, 0.1, 10.0);
1769        assert!((less - 0.991665460749303).abs() < 1e-9);
1770        assert!((greater - 0.0094257162710415).abs() < 1e-9);
1771
1772        // correlation = 0, df = Inf (normal).
1773        let (less, greater) = rank_sum_test_with_correlation(&idx, &stats, 0.0, f64::INFINITY);
1774        assert!((less - 0.999517239270778).abs() < 1e-9);
1775        assert!((greater - 0.000645718763498011).abs() < 1e-9);
1776
1777        // A different set, correlation = 0.25, df = 18. R index c(2,5,9,14,20).
1778        let idx2 = [1, 4, 8, 13, 19];
1779        let (less, greater) = rank_sum_test_with_correlation(&idx2, &stats, 0.25, 18.0);
1780        assert!((less - 0.0152351621428473).abs() < 1e-9);
1781        assert!((greater - 0.986720588745113).abs() < 1e-9);
1782    }
1783
1784    fn three_sets() -> Vec<Vec<usize>> {
1785        vec![
1786            vec![0, 2, 5, 7, 14, 16], // set1 = c(1,3,6,8,15,17)
1787            vec![1, 4, 8, 13, 19],    // set2 = c(2,5,9,14,20)
1788            vec![3, 6, 9, 12],        // set3 = c(4,7,10,13)
1789        ]
1790    }
1791
1792    #[test]
1793    fn camera_pr_parametric_matches_r() {
1794        let stat = fixture();
1795        let sets = three_sets();
1796        let rows = camera_pr(&stat, &sets, 0.01, false, true);
1797        // R output order (sorted by p-value): set1, set2, set3.
1798        let want = [
1799            (
1800                0usize,
1801                6usize,
1802                Direction::Up,
1803                0.000305279883783743,
1804                0.000489203611923099,
1805            ),
1806            (
1807                1,
1808                5,
1809                Direction::Down,
1810                0.000326135741282066,
1811                0.000489203611923099,
1812            ),
1813            (2, 4, Direction::Up, 0.91114902618042, 0.91114902618042),
1814        ];
1815        assert_eq!(rows.len(), want.len());
1816        for (r, (set, ng, dir, p, fdr)) in rows.iter().zip(want) {
1817            assert_eq!(r.set, set);
1818            assert_eq!(r.n_genes, ng);
1819            assert_eq!(r.direction, dir);
1820            assert!(
1821                (r.p_value - p).abs() < 1e-9,
1822                "p: got {}, want {p}",
1823                r.p_value
1824            );
1825            assert!((r.fdr - fdr).abs() < 1e-9, "fdr: got {}, want {fdr}", r.fdr);
1826        }
1827    }
1828
1829    #[test]
1830    fn camera_pr_use_ranks_matches_r() {
1831        let stat = fixture();
1832        let sets = three_sets();
1833        let rows = camera_pr(&stat, &sets, 0.01, true, true);
1834        let want = [
1835            (
1836                0usize,
1837                Direction::Up,
1838                0.00153858324497317,
1839                0.00385566266453651,
1840            ),
1841            (1, Direction::Down, 0.00257044177635767, 0.00385566266453651),
1842            (2, Direction::Up, 0.962711741316641, 0.962711741316641),
1843        ];
1844        for (r, (set, dir, p, fdr)) in rows.iter().zip(want) {
1845            assert_eq!(r.set, set);
1846            assert_eq!(r.direction, dir);
1847            assert!(
1848                (r.p_value - p).abs() < 1e-9,
1849                "p: got {}, want {p}",
1850                r.p_value
1851            );
1852            assert!((r.fdr - fdr).abs() < 1e-9, "fdr: got {}, want {fdr}", r.fdr);
1853        }
1854    }
1855
1856    // 12 genes x 6 samples; design = model.matrix(~group), group=A,A,A,B,B,B.
1857    fn camera_exprs() -> Array2<f64> {
1858        Array2::from_shape_vec(
1859            (12, 6),
1860            vec![
1861                4.871, 4.629, 4.697, 5.807, 4.798, 5.195, //
1862                6.356, 6.349, 6.764, 4.125, 3.125, 4.752, //
1863                4.298, 4.659, 4.508, 5.936, 4.075, 7.367, //
1864                8.896, 9.420, 8.915, 9.165, 9.466, 8.598, //
1865                6.563, 6.610, 6.813, 6.123, 6.155, 7.309, //
1866                4.443, 4.283, 3.851, 5.435, 5.304, 5.784, //
1867                7.247, 7.184, 7.620, 6.533, 7.878, 6.820, //
1868                7.456, 7.644, 8.368, 9.096, 7.422, 10.245, //
1869                7.229, 6.945, 6.986, 8.178, 7.445, 10.159, //
1870                5.378, 5.177, 4.919, 7.692, 6.023, 7.432, //
1871                8.748, 9.133, 9.280, 9.431, 10.394, 11.954, //
1872                6.697, 7.010, 6.719, 4.293, 3.114, 5.796, //
1873            ],
1874        )
1875        .unwrap()
1876    }
1877
1878    fn camera_design() -> Array2<f64> {
1879        Array2::from_shape_vec(
1880            (6, 2),
1881            vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
1882        )
1883        .unwrap()
1884    }
1885
1886    fn camera_sets() -> Vec<Vec<usize>> {
1887        vec![
1888            vec![0, 1, 2, 3],    // set1 = c(1,2,3,4)
1889            vec![4, 5, 6, 7, 8], // set2 = c(5,6,7,8,9)
1890            vec![9, 10, 11],     // set3 = c(10,11,12)
1891        ]
1892    }
1893
1894    #[test]
1895    fn inter_gene_correlation_matches_r() {
1896        let (vif, cor) = inter_gene_correlation(&camera_exprs(), &camera_design());
1897        assert!((vif - 3.54050719052897).abs() < 1e-9, "vif: {vif}");
1898        assert!((cor - 0.230955199138998).abs() < 1e-9, "cor: {cor}");
1899    }
1900
1901    #[test]
1902    fn camera_parametric_matches_r() {
1903        let rows = camera(
1904            &camera_exprs(),
1905            &camera_design(),
1906            1,
1907            &camera_sets(),
1908            0.01,
1909            false,
1910            true,
1911        )
1912        .unwrap();
1913        // R output order (sorted by p-value): set1, set2, set3.
1914        let want = [
1915            (
1916                0usize,
1917                4usize,
1918                Direction::Down,
1919                0.42121952380793,
1920                0.753808705609041,
1921            ),
1922            (1, 5, Direction::Up, 0.502539137072694, 0.753808705609041),
1923            (2, 3, Direction::Up, 0.916115986180527, 0.916115986180527),
1924        ];
1925        assert_eq!(rows.len(), want.len());
1926        for (r, (set, ng, dir, p, fdr)) in rows.iter().zip(want) {
1927            assert_eq!(r.set, set);
1928            assert_eq!(r.n_genes, ng);
1929            assert_eq!(r.direction, dir);
1930            assert!(
1931                (r.p_value - p).abs() < 1e-7,
1932                "p: got {}, want {p}",
1933                r.p_value
1934            );
1935            assert!((r.fdr - fdr).abs() < 1e-7, "fdr: got {}, want {fdr}", r.fdr);
1936        }
1937    }
1938
1939    #[test]
1940    fn camera_use_ranks_matches_r() {
1941        let rows = camera(
1942            &camera_exprs(),
1943            &camera_design(),
1944            1,
1945            &camera_sets(),
1946            0.01,
1947            true,
1948            true,
1949        )
1950        .unwrap();
1951        // R output order (sorted by p-value): set1, set3, set2.
1952        let want = [
1953            (
1954                0usize,
1955                Direction::Down,
1956                0.354526685759271,
1957                0.693805951243274,
1958            ),
1959            (2, Direction::Up, 0.462537300828849, 0.693805951243274),
1960            (1, Direction::Up, 0.872315053291437, 0.872315053291437),
1961        ];
1962        assert_eq!(rows.len(), want.len());
1963        for (r, (set, dir, p, fdr)) in rows.iter().zip(want) {
1964            assert_eq!(r.set, set);
1965            assert_eq!(r.direction, dir);
1966            assert!(
1967                (r.p_value - p).abs() < 1e-7,
1968                "p: got {}, want {p}",
1969                r.p_value
1970            );
1971            assert!((r.fdr - fdr).abs() < 1e-7, "fdr: got {}, want {fdr}", r.fdr);
1972        }
1973    }
1974
1975    #[test]
1976    fn fry_matches_r() {
1977        let rows = fry(
1978            &camera_exprs(),
1979            &camera_design(),
1980            1,
1981            &camera_sets(),
1982            FrySort::Directional,
1983        )
1984        .unwrap();
1985        // R output order (directional sort): set2, set1, set3.
1986        let want = [
1987            (
1988                1usize,
1989                5usize,
1990                Direction::Up,
1991                0.124433966893834,
1992                0.373301900681503,
1993                0.0667070665318511,
1994                0.0667070665318511,
1995            ),
1996            (
1997                0,
1998                4,
1999                Direction::Down,
2000                0.44028222758847,
2001                0.45071371571786,
2002                0.00113516116620128,
2003                0.00170274174930192,
2004            ),
2005            (
2006                2,
2007                3,
2008                Direction::Up,
2009                0.45071371571786,
2010                0.45071371571786,
2011                0.000139022937183932,
2012                0.000417068811551796,
2013            ),
2014        ];
2015        assert_eq!(rows.len(), want.len());
2016        for (r, (set, ng, dir, p, fdr, pm, fdrm)) in rows.iter().zip(want) {
2017            assert_eq!(r.set, set);
2018            assert_eq!(r.n_genes, ng);
2019            assert_eq!(r.direction, dir);
2020            assert!(
2021                (r.p_value - p).abs() < 1e-6,
2022                "p: got {}, want {p}",
2023                r.p_value
2024            );
2025            assert!((r.fdr - fdr).abs() < 1e-6, "fdr: got {}, want {fdr}", r.fdr);
2026            assert!(
2027                (r.p_value_mixed - pm).abs() < 1e-6,
2028                "pm: got {}, want {pm}",
2029                r.p_value_mixed
2030            );
2031            assert!(
2032                (r.fdr_mixed - fdrm).abs() < 1e-6,
2033                "fdrm: got {}, want {fdrm}",
2034                r.fdr_mixed
2035            );
2036        }
2037    }
2038
2039    #[test]
2040    fn ids2indices_maps_and_drops_empty() {
2041        let ids: Vec<String> = ["a", "b", "c", "d", "e"]
2042            .iter()
2043            .map(|s| s.to_string())
2044            .collect();
2045        let sets = vec![
2046            vec!["b".to_string(), "d".to_string()],
2047            vec!["x".to_string()],
2048            vec!["a".to_string(), "e".to_string(), "c".to_string()],
2049        ];
2050        let with_empty = ids2indices(&sets, &ids, false);
2051        assert_eq!(with_empty, vec![vec![1, 3], vec![], vec![0, 2, 4]]);
2052        let without = ids2indices(&sets, &ids, true);
2053        assert_eq!(without, vec![vec![1, 3], vec![0, 2, 4]]);
2054    }
2055
2056    #[test]
2057    #[allow(clippy::excessive_precision)]
2058    fn roast_matches_r() {
2059        // y = matrix(rnorm(50*6), 50, 6) after set.seed(2024). RRng is bit-exact
2060        // to R's rnorm, so regenerating here yields the same matrix R's reference
2061        // used (column-major fill, matching R's matrix() storage order).
2062        let g = 50usize;
2063        let n = 6usize;
2064        let y_data = RRng::new(2024).rnorm(g * n);
2065        let y = Array2::from_shape_vec((n, g), y_data)
2066            .unwrap()
2067            .t()
2068            .to_owned();
2069
2070        // design = cbind(Intercept=1, Group=c(0,0,0,1,1,1)); contrast=2 -> coef=1.
2071        let design = Array2::from_shape_vec(
2072            (2, n),
2073            vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
2074        )
2075        .unwrap()
2076        .t()
2077        .to_owned();
2078
2079        let check = |tag: &str,
2080                     seed: i32,
2081                     index: &[usize],
2082                     nrot: usize,
2083                     want_active: [f64; 4],
2084                     want_p: [f64; 4]| {
2085            let mut rng = RRng::new(seed);
2086            let out = roast(&y, &design, 1, index, nrot, &mut rng).unwrap();
2087            assert_eq!(out.n_genes_in_set, index.len());
2088            for i in 0..4 {
2089                assert!(
2090                    (out.active_prop[i] - want_active[i]).abs() < 1e-12,
2091                    "{tag} active[{i}]: got {}, want {}",
2092                    out.active_prop[i],
2093                    want_active[i]
2094                );
2095                // Counts are integers; if they match R the p-value is bit-exact.
2096                // A loose tolerance here would still flag any single-count drift
2097                // (which moves a p-value by ~1/nrot, far above this threshold).
2098                assert!(
2099                    (out.p_value[i] - want_p[i]).abs() < 1e-12,
2100                    "{tag} p[{i}]: got {}, want {}",
2101                    out.p_value[i],
2102                    want_p[i]
2103                );
2104            }
2105        };
2106
2107        let idx_a: Vec<usize> = (0..10).collect();
2108        check(
2109            "A",
2110            99,
2111            &idx_a,
2112            1999,
2113            [0.1, 0.0, 0.1, 0.1],
2114            [0.47211802950737686, 0.52813203300825207, 0.944, 0.344],
2115        );
2116
2117        let idx_b: Vec<usize> = (10..35).collect();
2118        check(
2119            "B",
2120            7,
2121            &idx_b,
2122            1999,
2123            [0.0, 0.12, 0.12, 0.12],
2124            [0.90547636909227303, 0.094773693423355843, 0.1895, 0.384],
2125        );
2126
2127        check(
2128            "C",
2129            123,
2130            &idx_a,
2131            999,
2132            [0.1, 0.0, 0.1, 0.1],
2133            [0.47423711855927964, 0.52626313156578286, 0.948, 0.364],
2134        );
2135    }
2136
2137    #[test]
2138    #[allow(clippy::excessive_precision)]
2139    fn mroast_matches_r() {
2140        // Same y/design fixture as roast_matches_r (set.seed(2024)); reference
2141        // from scratch/mroast_ref.R with seed 314, nrot 1999, midp = TRUE.
2142        let g = 50usize;
2143        let n = 6usize;
2144        let y = Array2::from_shape_vec((n, g), RRng::new(2024).rnorm(g * n))
2145            .unwrap()
2146            .t()
2147            .to_owned();
2148        let design = Array2::from_shape_vec(
2149            (2, n),
2150            vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
2151        )
2152        .unwrap()
2153        .t()
2154        .to_owned();
2155
2156        // index = list(S1=1:10, S2=11:35, S3=c(5:15,40:50), S4=20:24) (0-based).
2157        let index: Vec<Vec<usize>> = vec![
2158            (0..10).collect(),
2159            (10..35).collect(),
2160            (4..15).chain(39..50).collect(),
2161            (19..24).collect(),
2162        ];
2163
2164        // sort = "none": validate every per-set number in input order.
2165        let mut rng = RRng::new(314);
2166        let tab = mroast(
2167            &y,
2168            &design,
2169            1,
2170            &index,
2171            1999,
2172            true,
2173            FrySort::NoSort,
2174            &mut rng,
2175        )
2176        .unwrap();
2177        let want_ngenes = [10usize, 25, 22, 5];
2178        let want_propdown = [0.1, 0.0, 0.090909090909090912, 0.0];
2179        let want_propup = [0.0, 0.12, 0.090909090909090912, 0.4];
2180        let want_dir = [
2181            Direction::Down,
2182            Direction::Up,
2183            Direction::Down,
2184            Direction::Up,
2185        ];
2186        let want_p = [0.943, 0.1875, 0.6645, 0.124];
2187        let want_fdr = [0.943, 0.3745, 0.8856666666666666, 0.3745];
2188        let want_pm = [0.348, 0.403, 0.5915, 0.1265];
2189        let want_fdrm = [0.537, 0.537, 0.5915, 0.505];
2190        for i in 0..4 {
2191            assert_eq!(tab[i].set, i, "row {i} set");
2192            assert_eq!(tab[i].n_genes, want_ngenes[i], "row {i} ngenes");
2193            assert_eq!(tab[i].direction, want_dir[i], "row {i} direction");
2194            assert!(
2195                (tab[i].prop_down - want_propdown[i]).abs() < 1e-12,
2196                "row {i} propdown: got {}, want {}",
2197                tab[i].prop_down,
2198                want_propdown[i]
2199            );
2200            assert!(
2201                (tab[i].prop_up - want_propup[i]).abs() < 1e-12,
2202                "row {i} propup: got {}, want {}",
2203                tab[i].prop_up,
2204                want_propup[i]
2205            );
2206            assert!(
2207                (tab[i].p_value - want_p[i]).abs() < 1e-12,
2208                "row {i} pvalue: got {}, want {}",
2209                tab[i].p_value,
2210                want_p[i]
2211            );
2212            assert!(
2213                (tab[i].fdr - want_fdr[i]).abs() < 1e-12,
2214                "row {i} fdr: got {}, want {}",
2215                tab[i].fdr,
2216                want_fdr[i]
2217            );
2218            assert!(
2219                (tab[i].p_value_mixed - want_pm[i]).abs() < 1e-12,
2220                "row {i} pvalue_mixed: got {}, want {}",
2221                tab[i].p_value_mixed,
2222                want_pm[i]
2223            );
2224            assert!(
2225                (tab[i].fdr_mixed - want_fdrm[i]).abs() < 1e-12,
2226                "row {i} fdr_mixed: got {}, want {}",
2227                tab[i].fdr_mixed,
2228                want_fdrm[i]
2229            );
2230        }
2231
2232        // sort = "directional": rows ordered S4,S2,S3,S1.
2233        let mut rng = RRng::new(314);
2234        let td = mroast(
2235            &y,
2236            &design,
2237            1,
2238            &index,
2239            1999,
2240            true,
2241            FrySort::Directional,
2242            &mut rng,
2243        )
2244        .unwrap();
2245        let order_d: Vec<usize> = td.iter().map(|r| r.set).collect();
2246        assert_eq!(order_d, vec![3, 1, 2, 0], "directional order");
2247
2248        // sort = "mixed": rows ordered S4,S1,S2,S3.
2249        let mut rng = RRng::new(314);
2250        let tm = mroast(&y, &design, 1, &index, 1999, true, FrySort::Mixed, &mut rng).unwrap();
2251        let order_m: Vec<usize> = tm.iter().map(|r| r.set).collect();
2252        assert_eq!(order_m, vec![3, 0, 1, 2], "mixed order");
2253    }
2254
2255    #[test]
2256    #[allow(clippy::excessive_precision)]
2257    fn romer_matches_r() {
2258        // Same y/design fixture as roast_matches_r (set.seed(2024)); reference
2259        // from scratch/romer_ref.R with seed 271, nrot 999, shrink.resid = TRUE.
2260        let g = 50usize;
2261        let n = 6usize;
2262        let y = Array2::from_shape_vec((n, g), RRng::new(2024).rnorm(g * n))
2263            .unwrap()
2264            .t()
2265            .to_owned();
2266        let design = Array2::from_shape_vec(
2267            (2, n),
2268            vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
2269        )
2270        .unwrap()
2271        .t()
2272        .to_owned();
2273
2274        // index = list(S1=1:10, S2=11:35, S3=c(5:15,40:50), S4=20:24) (0-based).
2275        let index: Vec<Vec<usize>> = vec![
2276            (0..10).collect(),
2277            (10..35).collect(),
2278            (4..15).chain(39..50).collect(),
2279            (19..24).collect(),
2280        ];
2281
2282        let mut rng = RRng::new(271);
2283        let tab = romer(
2284            &y,
2285            &design,
2286            1,
2287            &index,
2288            RomerStatistic::Mean,
2289            999,
2290            true,
2291            &mut rng,
2292        )
2293        .unwrap();
2294        let want_ngenes = [10usize, 25, 22, 5];
2295        let want_up = [0.563, 0.476, 0.987, 0.492];
2296        let want_down = [0.441, 0.527, 0.016, 0.526];
2297        let want_mixed = [0.231, 0.341, 0.632, 0.139];
2298        for i in 0..4 {
2299            assert_eq!(tab[i].set, i, "row {i} set");
2300            assert_eq!(tab[i].n_genes, want_ngenes[i], "row {i} ngenes");
2301            // Counts are integers; a bit-exact p-value implies the counts match.
2302            assert!(
2303                (tab[i].p_up - want_up[i]).abs() < 1e-12,
2304                "row {i} up: got {}, want {}",
2305                tab[i].p_up,
2306                want_up[i]
2307            );
2308            assert!(
2309                (tab[i].p_down - want_down[i]).abs() < 1e-12,
2310                "row {i} down: got {}, want {}",
2311                tab[i].p_down,
2312                want_down[i]
2313            );
2314            assert!(
2315                (tab[i].p_mixed - want_mixed[i]).abs() < 1e-12,
2316                "row {i} mixed: got {}, want {}",
2317                tab[i].p_mixed,
2318                want_mixed[i]
2319            );
2320        }
2321
2322        // topRomer orderings (S1..S4 -> 0..3).
2323        let order = |a: RomerAlternative| -> Vec<usize> {
2324            top_romer(&tab, 4, a).iter().map(|r| r.set).collect()
2325        };
2326        assert_eq!(order(RomerAlternative::Up), vec![1, 3, 0, 2], "top up");
2327        assert_eq!(order(RomerAlternative::Down), vec![2, 0, 3, 1], "top down");
2328        assert_eq!(
2329            order(RomerAlternative::Mixed),
2330            vec![3, 0, 1, 2],
2331            "top mixed"
2332        );
2333
2334        // Non-default set.statistic options on the same fixture/seed; reference
2335        // integer rotation counts from scratch/romer_stats_ref.R. The p-value is
2336        // (count + 1) / (nrot + 1), so matching p implies matching counts.
2337        let check_stat =
2338            |stat: RomerStatistic, up: [i64; 4], down: [i64; 4], mixed: [i64; 4], tag: &str| {
2339                let mut rng = RRng::new(271);
2340                let t = romer(&y, &design, 1, &index, stat, 999, true, &mut rng).unwrap();
2341                let p = |c: i64| (c as f64 + 1.0) / 1000.0;
2342                for i in 0..4 {
2343                    assert!((t[i].p_up - p(up[i])).abs() < 1e-12, "{tag} row {i} up");
2344                    assert!(
2345                        (t[i].p_down - p(down[i])).abs() < 1e-12,
2346                        "{tag} row {i} down"
2347                    );
2348                    assert!(
2349                        (t[i].p_mixed - p(mixed[i])).abs() < 1e-12,
2350                        "{tag} row {i} mixed"
2351                    );
2352                }
2353            };
2354        check_stat(
2355            RomerStatistic::FloorMean,
2356            [426, 505, 925, 452],
2357            [477, 367, 121, 169],
2358            [900, 366, 666, 201],
2359            "floormean",
2360        );
2361        check_stat(
2362            RomerStatistic::Mean50,
2363            [479, 351, 935, 329],
2364            [361, 396, 109, 165],
2365            [690, 246, 707, 169],
2366            "mean50",
2367        );
2368    }
2369
2370    #[test]
2371    #[allow(clippy::excessive_precision)]
2372    fn contrast_as_coef_matches_r() {
2373        // design = cbind(A = 1, B = c(0,0,1,1,0,0), C = c(0,0,0,0,1,1)), 6 x 3.
2374        let design = Array2::from_shape_vec(
2375            (3, 6),
2376            vec![
2377                1.0, 1.0, 1.0, 1.0, 1.0, 1.0, // A
2378                0.0, 0.0, 1.0, 1.0, 0.0, 0.0, // B
2379                0.0, 0.0, 0.0, 0.0, 1.0, 1.0, // C
2380            ],
2381        )
2382        .unwrap()
2383        .t()
2384        .to_owned();
2385
2386        // Compare reformed design (column-major), coef and rank against limma.
2387        let check = |out: &ContrastAsCoef, rank: usize, coef: &[usize], cm: &[f64]| {
2388            assert_eq!(out.rank, rank, "rank");
2389            assert_eq!(out.coef.as_slice(), coef, "coef");
2390            let n = out.design.nrows();
2391            let p = out.design.ncols();
2392            let mut flat = Vec::with_capacity(n * p);
2393            for j in 0..p {
2394                for i in 0..n {
2395                    flat.push(out.design[[i, j]]);
2396                }
2397            }
2398            assert_eq!(flat.len(), cm.len(), "design length");
2399            for (idx, (&a, &b)) in flat.iter().zip(cm).enumerate() {
2400                assert!((a - b).abs() < 1e-12, "design[{idx}]: {a} vs {b}");
2401            }
2402        };
2403
2404        // Single contrast B - C.
2405        let v1 = Array2::from_shape_vec((3, 1), vec![0.0, 1.0, -1.0]).unwrap();
2406        check(
2407            &contrast_as_coef(&design, &v1, true).unwrap(),
2408            1,
2409            &[0],
2410            &[
2411                0.0,
2412                0.0,
2413                0.49999999999999994,
2414                0.49999999999999994,
2415                -0.49999999999999994,
2416                -0.49999999999999994,
2417                -0.70710678118654746,
2418                -0.70710678118654746,
2419                -0.20710678118654746,
2420                -0.20710678118654746,
2421                -0.20710678118654754,
2422                -0.20710678118654754,
2423                0.70710678118654746,
2424                0.70710678118654746,
2425                1.2071067811865475,
2426                1.2071067811865475,
2427                1.2071067811865475,
2428                1.2071067811865475,
2429            ],
2430        );
2431        check(
2432            &contrast_as_coef(&design, &v1, false).unwrap(),
2433            1,
2434            &[2],
2435            &[
2436                -0.70710678118654746,
2437                -0.70710678118654746,
2438                -0.20710678118654746,
2439                -0.20710678118654746,
2440                -0.20710678118654754,
2441                -0.20710678118654754,
2442                0.70710678118654746,
2443                0.70710678118654746,
2444                1.2071067811865475,
2445                1.2071067811865475,
2446                1.2071067811865475,
2447                1.2071067811865475,
2448                0.0,
2449                0.0,
2450                0.49999999999999994,
2451                0.49999999999999994,
2452                -0.49999999999999994,
2453                -0.49999999999999994,
2454            ],
2455        );
2456
2457        // Two-column full-rank contrast: (B-C) and (B+C-2A).
2458        let m2 = Array2::from_shape_vec((3, 2), vec![0.0, -2.0, 1.0, 1.0, -1.0, 1.0]).unwrap();
2459        check(
2460            &contrast_as_coef(&design, &m2, true).unwrap(),
2461            2,
2462            &[0, 1],
2463            &[
2464                0.0,
2465                0.0,
2466                0.49999999999999994,
2467                0.49999999999999994,
2468                -0.49999999999999994,
2469                -0.49999999999999994,
2470                -0.33333333333333337,
2471                -0.33333333333333337,
2472                -0.16666666666666663,
2473                -0.16666666666666663,
2474                -0.16666666666666669,
2475                -0.16666666666666669,
2476                0.57735026918962573,
2477                0.57735026918962573,
2478                1.1547005383792515,
2479                1.1547005383792515,
2480                1.1547005383792515,
2481                1.1547005383792515,
2482            ],
2483        );
2484        check(
2485            &contrast_as_coef(&design, &m2, false).unwrap(),
2486            2,
2487            &[1, 2],
2488            &[
2489                0.57735026918962573,
2490                0.57735026918962573,
2491                1.1547005383792515,
2492                1.1547005383792515,
2493                1.1547005383792515,
2494                1.1547005383792515,
2495                0.0,
2496                0.0,
2497                0.49999999999999994,
2498                0.49999999999999994,
2499                -0.49999999999999994,
2500                -0.49999999999999994,
2501                -0.33333333333333337,
2502                -0.33333333333333337,
2503                -0.16666666666666663,
2504                -0.16666666666666663,
2505                -0.16666666666666669,
2506                -0.16666666666666669,
2507            ],
2508        );
2509    }
2510}