Skip to main content

cyanea_stats/
diffexpr.rs

//! Differential expression analysis for count data.
//!
//! Provides two methods for identifying differentially expressed genes between
//! two conditions:
//!
//! - **Negative binomial Wald test** ([`DeMethod::NegativeBinomial`]) — models
//!   count overdispersion (DESeq2-style pipeline with size-factor normalization,
//!   method-of-moments dispersion, and a Wald z-test).
//! - **Wilcoxon rank-sum test** ([`DeMethod::Wilcoxon`]) — non-parametric
//!   alternative that applies [`crate::testing::mann_whitney_u`] to the
//!   size-factor-normalized counts.
//!
//! Both methods apply Benjamini-Hochberg correction via
//! [`crate::correction::benjamini_hochberg`].
//!
//! The [`volcano_plot`] function converts results into points suitable for
//! plotting (log2 fold-change vs. −log10 adjusted p-value).
17
18use cyanea_core::{CyaneaError, Result};
19
20use crate::correction;
21use crate::distribution::{Distribution, Normal};
22use crate::normalization;
23use crate::testing;
24
25// ── Result types ─────────────────────────────────────────────────────────────
26
/// Statistical test used to call differentially expressed genes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeMethod {
    /// DESeq2-style negative binomial model evaluated with a Wald z-test.
    NegativeBinomial,
    /// Non-parametric Wilcoxon rank-sum test (also known as Mann-Whitney U).
    Wilcoxon,
}
35
/// Differential expression result for a single gene.
///
/// Derives `PartialEq` so results can be compared directly (for example in
/// regression tests). Because the fields are `f64`, a result containing NaN
/// never compares equal to anything, including itself.
#[derive(Debug, Clone, PartialEq)]
pub struct DeGeneResult {
    /// Row index of the gene in the input count matrix.
    pub gene_index: usize,
    /// Log2 fold-change of treatment (condition) over control.
    pub log2_fold_change: f64,
    /// Mean of the normalized counts across all samples.
    pub base_mean: f64,
    /// Test statistic: the Wald z for the negative binomial test, or U for
    /// the Wilcoxon test.
    pub statistic: f64,
    /// Raw (unadjusted) p-value.
    pub p_value: f64,
    /// Benjamini-Hochberg adjusted p-value.
    pub p_adjusted: f64,
}
52
53/// Aggregate results from a differential expression analysis.
54#[derive(Debug, Clone)]
55pub struct DeResults {
56    /// Per-gene results, sorted by p_value ascending.
57    pub genes: Vec<DeGeneResult>,
58    /// Method used.
59    pub method: DeMethod,
60    /// Number of genes tested.
61    pub n_genes: usize,
62    /// Number of condition (treatment) samples.
63    pub n_condition: usize,
64    /// Number of control samples.
65    pub n_control: usize,
66}
67
/// One gene rendered as a point on a volcano plot.
///
/// Derives `PartialEq` so plot data can be compared directly (for example in
/// tests). A point holding a NaN coordinate never compares equal.
#[derive(Debug, Clone, PartialEq)]
pub struct VolcanoPoint {
    /// Row index of the gene in the original count matrix.
    pub gene_index: usize,
    /// Log2 fold-change (x axis).
    pub log2_fold_change: f64,
    /// −log10(adjusted p-value) (y axis), capped at 300.
    pub neg_log10_padj: f64,
    /// Whether this gene passes the significance thresholds.
    pub significant: bool,
}
80
81// ── Main entry point ─────────────────────────────────────────────────────────
82
83/// Run differential expression analysis on a count matrix.
84///
85/// - `counts`: row-major `n_genes × n_samples` count matrix.
86/// - `condition`: boolean mask (`true` = treatment, `false` = control), one
87///   per sample.
88/// - `method`: which test to apply.
89///
90/// Returns [`DeResults`] with genes sorted by ascending p-value.
91pub fn differential_expression(
92    counts: &[f64],
93    n_genes: usize,
94    n_samples: usize,
95    condition: &[bool],
96    method: DeMethod,
97) -> Result<DeResults> {
98    // Validate dimensions
99    if n_genes == 0 || n_samples == 0 {
100        return Err(CyaneaError::InvalidInput(
101            "differential_expression: need at least 1 gene and 1 sample".into(),
102        ));
103    }
104    if counts.len() != n_genes * n_samples {
105        return Err(CyaneaError::InvalidInput(format!(
106            "differential_expression: counts length ({}) != n_genes ({}) * n_samples ({})",
107            counts.len(),
108            n_genes,
109            n_samples,
110        )));
111    }
112    if condition.len() != n_samples {
113        return Err(CyaneaError::InvalidInput(format!(
114            "differential_expression: condition length ({}) != n_samples ({})",
115            condition.len(),
116            n_samples,
117        )));
118    }
119
120    let n_cond = condition.iter().filter(|&&c| c).count();
121    let n_ctrl = n_samples - n_cond;
122    if n_cond < 2 || n_ctrl < 2 {
123        return Err(CyaneaError::InvalidInput(
124            "differential_expression: need at least 2 samples per group".into(),
125        ));
126    }
127
128    // Build sample index lists
129    let cond_idx: Vec<usize> = (0..n_samples).filter(|&j| condition[j]).collect();
130    let ctrl_idx: Vec<usize> = (0..n_samples).filter(|&j| !condition[j]).collect();
131
132    // Size-factor normalization
133    let sf = normalization::size_factors(counts, n_genes, n_samples)?;
134    let normed = normalization::normalize_by_size_factors(counts, n_genes, n_samples, &sf)?;
135
136    let mut gene_results: Vec<DeGeneResult> = match method {
137        DeMethod::NegativeBinomial => nb_wald(&normed, n_genes, n_samples, &cond_idx, &ctrl_idx)?,
138        DeMethod::Wilcoxon => wilcoxon_de(&normed, n_genes, n_samples, &cond_idx, &ctrl_idx)?,
139    };
140
141    // BH correction
142    let raw_p: Vec<f64> = gene_results.iter().map(|g| g.p_value).collect();
143    let adj_p = correction::benjamini_hochberg(&raw_p)?;
144    for (g, &padj) in gene_results.iter_mut().zip(adj_p.iter()) {
145        g.p_adjusted = padj;
146    }
147
148    // Sort by p-value ascending
149    gene_results.sort_by(|a, b| a.p_value.total_cmp(&b.p_value));
150
151    Ok(DeResults {
152        genes: gene_results,
153        method,
154        n_genes,
155        n_condition: n_cond,
156        n_control: n_ctrl,
157    })
158}
159
160// ── Negative binomial Wald test ──────────────────────────────────────────────
161
162fn nb_wald(
163    normed: &[f64],
164    n_genes: usize,
165    n_samples: usize,
166    cond_idx: &[usize],
167    ctrl_idx: &[usize],
168) -> Result<Vec<DeGeneResult>> {
169    let normal = Normal::standard();
170    let pseudo = 0.5;
171
172    let mut results = Vec::with_capacity(n_genes);
173
174    for i in 0..n_genes {
175        let row = &normed[i * n_samples..(i + 1) * n_samples];
176
177        // Group means
178        let mu_cond: f64 = cond_idx.iter().map(|&j| row[j]).sum::<f64>() / cond_idx.len() as f64;
179        let mu_ctrl: f64 = ctrl_idx.iter().map(|&j| row[j]).sum::<f64>() / ctrl_idx.len() as f64;
180        let base_mean: f64 = row.iter().sum::<f64>() / n_samples as f64;
181
182        // log2 fold-change with pseudocount
183        let log2fc = ((mu_cond + pseudo) / (mu_ctrl + pseudo)).log2();
184
185        // Overall mean and variance for dispersion estimation
186        let overall_mean = base_mean;
187        let overall_var = if n_samples > 1 {
188            let ss: f64 = row.iter().map(|&x| (x - overall_mean).powi(2)).sum();
189            ss / (n_samples - 1) as f64
190        } else {
191            0.0
192        };
193
194        // Method-of-moments dispersion: alpha = (var - mean) / mean^2
195        let alpha = if overall_mean > 0.0 {
196            ((overall_var - overall_mean) / (overall_mean * overall_mean)).clamp(1e-8, 1e8)
197        } else {
198            1e-8
199        };
200
201        // Standard error via delta method on NB variance
202        // Var(X) = mu + alpha * mu^2 for NB
203        // SE of group mean = sqrt(Var / n), SE of log2FC ≈ sqrt(SE_cond^2 + SE_ctrl^2) / ln(2)
204        let var_cond = mu_cond + alpha * mu_cond * mu_cond;
205        let var_ctrl = mu_ctrl + alpha * mu_ctrl * mu_ctrl;
206        let se_cond = (var_cond / cond_idx.len() as f64).sqrt();
207        let se_ctrl = (var_ctrl / ctrl_idx.len() as f64).sqrt();
208        // log2FC = log2((mu_c + pc) / (mu_t + pc)), SE on log2 scale via delta method
209        let se_log2fc = ((se_cond / (mu_cond + pseudo)).powi(2)
210            + (se_ctrl / (mu_ctrl + pseudo)).powi(2))
211        .sqrt()
212            / 2.0_f64.ln();
213
214        // Wald z-statistic and two-tailed p-value
215        let (z, p_value) = if se_log2fc > 1e-15 {
216            let z = log2fc / se_log2fc;
217            let p = 2.0 * (1.0 - normal.cdf(z.abs()));
218            (z, p.min(1.0))
219        } else {
220            (0.0, 1.0)
221        };
222
223        results.push(DeGeneResult {
224            gene_index: i,
225            log2_fold_change: log2fc,
226            base_mean,
227            statistic: z,
228            p_value,
229            p_adjusted: 1.0, // filled in later
230        });
231    }
232
233    Ok(results)
234}
235
236// ── Wilcoxon rank-sum DE ─────────────────────────────────────────────────────
237
238fn wilcoxon_de(
239    normed: &[f64],
240    n_genes: usize,
241    n_samples: usize,
242    cond_idx: &[usize],
243    ctrl_idx: &[usize],
244) -> Result<Vec<DeGeneResult>> {
245    let pseudo = 0.5;
246    let mut results = Vec::with_capacity(n_genes);
247
248    for i in 0..n_genes {
249        let row = &normed[i * n_samples..(i + 1) * n_samples];
250
251        let cond_vals: Vec<f64> = cond_idx.iter().map(|&j| row[j]).collect();
252        let ctrl_vals: Vec<f64> = ctrl_idx.iter().map(|&j| row[j]).collect();
253
254        let mu_cond = cond_vals.iter().sum::<f64>() / cond_vals.len() as f64;
255        let mu_ctrl = ctrl_vals.iter().sum::<f64>() / ctrl_vals.len() as f64;
256        let base_mean: f64 = row.iter().sum::<f64>() / n_samples as f64;
257        let log2fc = ((mu_cond + pseudo) / (mu_ctrl + pseudo)).log2();
258
259        let test_result = testing::mann_whitney_u(&cond_vals, &ctrl_vals)?;
260
261        results.push(DeGeneResult {
262            gene_index: i,
263            log2_fold_change: log2fc,
264            base_mean,
265            statistic: test_result.statistic,
266            p_value: test_result.p_value,
267            p_adjusted: 1.0,
268        });
269    }
270
271    Ok(results)
272}
273
274// ── Volcano plot ─────────────────────────────────────────────────────────────
275
276/// Convert DE results into volcano-plot points.
277///
278/// - `padj_threshold`: adjusted p-value cutoff (e.g. 0.05).
279/// - `fc_threshold`: absolute log2 fold-change cutoff (e.g. 1.0).
280///
281/// A gene is marked `significant` if `p_adjusted < padj_threshold` **and**
282/// `|log2_fold_change| > fc_threshold`.
283pub fn volcano_plot(
284    results: &DeResults,
285    padj_threshold: f64,
286    fc_threshold: f64,
287) -> Vec<VolcanoPoint> {
288    results
289        .genes
290        .iter()
291        .map(|g| {
292            let neg_log10 = if g.p_adjusted > 0.0 {
293                (-g.p_adjusted.log10()).min(300.0)
294            } else {
295                300.0
296            };
297            VolcanoPoint {
298                gene_index: g.gene_index,
299                log2_fold_change: g.log2_fold_change,
300                neg_log10_padj: neg_log10,
301                significant: g.p_adjusted < padj_threshold
302                    && g.log2_fold_change.abs() > fc_threshold,
303            }
304        })
305        .collect()
306}
307
308// ── Tests ────────────────────────────────────────────────────────────────────
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture shared by most tests: 5 genes × 6 samples, where samples 0–2
    /// are control and samples 3–5 are treatment. Gene 0 is strongly up in
    /// treatment, gene 1 is strongly down, and genes 2–4 are flat.
    fn test_counts() -> (Vec<f64>, usize, usize, Vec<bool>) {
        #[rustfmt::skip]
        let counts = vec![
            // gene 0: ctrl ~10, cond ~200 (up)
            10.0, 12.0, 11.0, 200.0, 210.0, 190.0,
            // gene 1: ctrl ~200, cond ~10 (down)
            200.0, 190.0, 210.0, 10.0, 12.0, 11.0,
            // gene 2: flat around 100
            100.0, 105.0, 95.0, 98.0, 102.0, 100.0,
            // gene 3: flat around 50
            50.0, 52.0, 48.0, 49.0, 51.0, 50.0,
            // gene 4: flat around 75
            75.0, 78.0, 72.0, 74.0, 76.0, 75.0,
        ];
        let condition = vec![false, false, false, true, true, true];
        (counts, 5, 6, condition)
    }

    /// Runs the shared fixture through `differential_expression`.
    fn run_fixture(method: DeMethod) -> DeResults {
        let (counts, ng, ns, cond) = test_counts();
        differential_expression(&counts, ng, ns, &cond, method).unwrap()
    }

    /// Finds the result for the gene at row `idx` of the input matrix
    /// (results are sorted by p-value, so position is not the gene index).
    fn gene(res: &DeResults, idx: usize) -> &DeGeneResult {
        res.genes.iter().find(|g| g.gene_index == idx).unwrap()
    }

    #[test]
    fn nb_detects_upregulated() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        let up = gene(&res, 0);
        assert!(up.log2_fold_change > 2.0, "log2fc={}", up.log2_fold_change);
        assert!(up.p_adjusted < 0.05, "padj={}", up.p_adjusted);
    }

    #[test]
    fn nb_detects_downregulated() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        let down = gene(&res, 1);
        assert!(down.log2_fold_change < -2.0, "log2fc={}", down.log2_fold_change);
        assert!(down.p_adjusted < 0.05, "padj={}", down.p_adjusted);
    }

    #[test]
    fn nb_unchanged_genes_high_p() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        for idx in [2, 3, 4] {
            let flat = gene(&res, idx);
            assert!(
                flat.p_value > 0.05,
                "gene {idx} should not be significant: p={}",
                flat.p_value
            );
        }
    }

    #[test]
    fn nb_log2fc_direction() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        assert!(gene(&res, 0).log2_fold_change > 0.0);
        assert!(gene(&res, 1).log2_fold_change < 0.0);
    }

    #[test]
    fn nb_padj_ge_pvalue() {
        // BH adjustment can only increase a p-value (up to rounding).
        let res = run_fixture(DeMethod::NegativeBinomial);
        for g in &res.genes {
            assert!(
                g.p_adjusted >= g.p_value - 1e-15,
                "gene {}: padj={} < p={}",
                g.gene_index,
                g.p_adjusted,
                g.p_value
            );
        }
    }

    #[test]
    fn nb_results_sorted() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        for pair in res.genes.windows(2) {
            let (earlier, later) = (&pair[0], &pair[1]);
            assert!(
                earlier.p_value <= later.p_value + 1e-15,
                "not sorted: {} > {}",
                earlier.p_value,
                later.p_value
            );
        }
    }

    #[test]
    fn wilcoxon_detects_de() {
        let res = run_fixture(DeMethod::Wilcoxon);
        let up = gene(&res, 0);
        assert!(up.log2_fold_change > 2.0);
        assert!(up.p_value < 0.1, "p={}", up.p_value);
    }

    #[test]
    fn wilcoxon_matches_direct_mwu() {
        // The Wilcoxon pathway must report exactly the p-value obtained by
        // calling mann_whitney_u directly on the same normalized data.
        let (counts, ng, ns, cond) = test_counts();
        let sf = normalization::size_factors(&counts, ng, ns).unwrap();
        let normed = normalization::normalize_by_size_factors(&counts, ng, ns, &sf).unwrap();

        let res = differential_expression(&counts, ng, ns, &cond, DeMethod::Wilcoxon).unwrap();

        let (cond_idx, ctrl_idx): (Vec<usize>, Vec<usize>) = (0..ns).partition(|&j| cond[j]);

        for gene_res in &res.genes {
            let i = gene_res.gene_index;
            let row = &normed[i * ns..(i + 1) * ns];
            let pick = |cols: &[usize]| -> Vec<f64> { cols.iter().map(|&j| row[j]).collect() };
            let direct = testing::mann_whitney_u(&pick(&cond_idx), &pick(&ctrl_idx)).unwrap();
            assert!(
                (gene_res.p_value - direct.p_value).abs() < 1e-10,
                "gene {}: de_p={}, direct_p={}",
                i,
                gene_res.p_value,
                direct.p_value
            );
        }
    }

    #[test]
    fn dispersion_poisson_like() {
        // A single gene whose two groups are nearly identical: the test
        // should report a large p-value.
        let counts = [100.0, 101.0, 99.0, 100.0, 102.0, 98.0];
        let cond = [false, false, false, true, true, true];
        let res =
            differential_expression(&counts, 1, 6, &cond, DeMethod::NegativeBinomial).unwrap();
        let p = res.genes[0].p_value;
        assert!(p > 0.5, "p={}", p);
    }

    #[test]
    fn dispersion_overdispersed() {
        // Highly variable data must still run to completion.
        let counts = [1.0, 50.0, 200.0, 500.0, 1000.0, 2000.0];
        let cond = [false, false, false, true, true, true];
        let res = differential_expression(&counts, 1, 6, &cond, DeMethod::NegativeBinomial);
        assert!(res.is_ok());
    }

    #[test]
    fn volcano_thresholds() {
        let (counts, ng, ns, cond) = test_counts();
        let res =
            differential_expression(&counts, ng, ns, &cond, DeMethod::NegativeBinomial).unwrap();
        let points = volcano_plot(&res, 0.05, 1.0);
        assert_eq!(points.len(), ng);

        // Genes 0 and 1 combine a large fold-change with a low padj.
        let sig_genes: Vec<usize> = points
            .iter()
            .filter_map(|p| p.significant.then_some(p.gene_index))
            .collect();
        assert!(sig_genes.contains(&0), "gene 0 should be significant");
        assert!(sig_genes.contains(&1), "gene 1 should be significant");

        // The flat genes must not be flagged.
        for idx in [2, 3, 4] {
            let pt = points.iter().find(|p| p.gene_index == idx).unwrap();
            assert!(!pt.significant, "gene {idx} should not be significant");
        }

        // The y coordinate is never negative.
        for pt in &points {
            assert!(pt.neg_log10_padj >= 0.0);
        }
    }

    #[test]
    fn error_dimension_mismatch() {
        // 2 values supplied for a 2 × 4 matrix.
        let cond = [false, true, false, true];
        let res = differential_expression(&[1.0, 2.0], 2, 4, &cond, DeMethod::NegativeBinomial);
        assert!(res.is_err());
    }

    #[test]
    fn error_condition_length() {
        let counts = [1.0; 8];
        let cond = [false, true]; // too short
        let res = differential_expression(&counts, 2, 4, &cond, DeMethod::NegativeBinomial);
        assert!(res.is_err());
    }

    #[test]
    fn error_too_few_per_group() {
        let counts = [10.0, 20.0, 30.0, 40.0];
        // Only 1 control
        let cond = [false, true, true, true];
        let res = differential_expression(&counts, 1, 4, &cond, DeMethod::NegativeBinomial);
        assert!(res.is_err());
    }

    #[test]
    fn error_single_group() {
        let counts = [10.0, 20.0, 30.0, 40.0];
        let cond = [true; 4];
        let res = differential_expression(&counts, 1, 4, &cond, DeMethod::NegativeBinomial);
        assert!(res.is_err());
    }

    #[test]
    fn volcano_clamps_neg_log10() {
        // A hand-built result with p_adjusted = 0 exercises the cap.
        let zero_padj = DeGeneResult {
            gene_index: 0,
            log2_fold_change: 5.0,
            base_mean: 100.0,
            statistic: 10.0,
            p_value: 0.0,
            p_adjusted: 0.0,
        };
        let results = DeResults {
            genes: vec![zero_padj],
            method: DeMethod::NegativeBinomial,
            n_genes: 1,
            n_condition: 3,
            n_control: 3,
        };
        let points = volcano_plot(&results, 0.05, 1.0);
        assert_eq!(points[0].neg_log10_padj, 300.0);
        assert!(points[0].significant);
    }
}