1use cyanea_core::{CyaneaError, Result};
19
20use crate::correction;
21use crate::distribution::{Distribution, Normal};
22use crate::normalization;
23use crate::testing;
24
/// Statistical test used to call differential expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeMethod {
    /// Per-gene Wald z-test using a negative-binomial style variance
    /// (`mu + alpha * mu^2`).
    NegativeBinomial,
    /// Per-gene Mann-Whitney U (Wilcoxon rank-sum) test on normalised counts.
    Wilcoxon,
}
35
/// Differential-expression statistics for a single gene.
#[derive(Clone, Debug)]
pub struct DeGeneResult {
    /// Row index of the gene in the input count matrix.
    pub gene_index: usize,
    /// `log2((mean_condition + 0.5) / (mean_control + 0.5))` on normalised counts.
    pub log2_fold_change: f64,
    /// Mean normalised count across all samples of both groups.
    pub base_mean: f64,
    /// Test statistic: the Wald z-score or the Mann-Whitney statistic,
    /// depending on the method that produced the result.
    pub statistic: f64,
    /// Raw (unadjusted) p-value.
    pub p_value: f64,
    /// Multiple-testing adjusted p-value; initialised to `1.0` by the per-gene
    /// tests and overwritten with the Benjamini-Hochberg value afterwards.
    pub p_adjusted: f64,
}
52
53#[derive(Debug, Clone)]
55pub struct DeResults {
56 pub genes: Vec<DeGeneResult>,
58 pub method: DeMethod,
60 pub n_genes: usize,
62 pub n_condition: usize,
64 pub n_control: usize,
66}
67
/// A single gene positioned for a volcano plot.
#[derive(Clone, Debug)]
pub struct VolcanoPoint {
    /// Row index of the gene in the input count matrix.
    pub gene_index: usize,
    /// X coordinate: log2 fold change of the gene.
    pub log2_fold_change: f64,
    /// Y coordinate: `-log10(adjusted p-value)`, capped at 300 so that an
    /// adjusted p-value of zero still yields a finite coordinate.
    pub neg_log10_padj: f64,
    /// `true` when the adjusted p-value is below the p-value threshold and the
    /// absolute log2 fold change is above the fold-change threshold.
    pub significant: bool,
}
80
81pub fn differential_expression(
92 counts: &[f64],
93 n_genes: usize,
94 n_samples: usize,
95 condition: &[bool],
96 method: DeMethod,
97) -> Result<DeResults> {
98 if n_genes == 0 || n_samples == 0 {
100 return Err(CyaneaError::InvalidInput(
101 "differential_expression: need at least 1 gene and 1 sample".into(),
102 ));
103 }
104 if counts.len() != n_genes * n_samples {
105 return Err(CyaneaError::InvalidInput(format!(
106 "differential_expression: counts length ({}) != n_genes ({}) * n_samples ({})",
107 counts.len(),
108 n_genes,
109 n_samples,
110 )));
111 }
112 if condition.len() != n_samples {
113 return Err(CyaneaError::InvalidInput(format!(
114 "differential_expression: condition length ({}) != n_samples ({})",
115 condition.len(),
116 n_samples,
117 )));
118 }
119
120 let n_cond = condition.iter().filter(|&&c| c).count();
121 let n_ctrl = n_samples - n_cond;
122 if n_cond < 2 || n_ctrl < 2 {
123 return Err(CyaneaError::InvalidInput(
124 "differential_expression: need at least 2 samples per group".into(),
125 ));
126 }
127
128 let cond_idx: Vec<usize> = (0..n_samples).filter(|&j| condition[j]).collect();
130 let ctrl_idx: Vec<usize> = (0..n_samples).filter(|&j| !condition[j]).collect();
131
132 let sf = normalization::size_factors(counts, n_genes, n_samples)?;
134 let normed = normalization::normalize_by_size_factors(counts, n_genes, n_samples, &sf)?;
135
136 let mut gene_results: Vec<DeGeneResult> = match method {
137 DeMethod::NegativeBinomial => nb_wald(&normed, n_genes, n_samples, &cond_idx, &ctrl_idx)?,
138 DeMethod::Wilcoxon => wilcoxon_de(&normed, n_genes, n_samples, &cond_idx, &ctrl_idx)?,
139 };
140
141 let raw_p: Vec<f64> = gene_results.iter().map(|g| g.p_value).collect();
143 let adj_p = correction::benjamini_hochberg(&raw_p)?;
144 for (g, &padj) in gene_results.iter_mut().zip(adj_p.iter()) {
145 g.p_adjusted = padj;
146 }
147
148 gene_results.sort_by(|a, b| a.p_value.total_cmp(&b.p_value));
150
151 Ok(DeResults {
152 genes: gene_results,
153 method,
154 n_genes,
155 n_condition: n_cond,
156 n_control: n_ctrl,
157 })
158}
159
160fn nb_wald(
163 normed: &[f64],
164 n_genes: usize,
165 n_samples: usize,
166 cond_idx: &[usize],
167 ctrl_idx: &[usize],
168) -> Result<Vec<DeGeneResult>> {
169 let normal = Normal::standard();
170 let pseudo = 0.5;
171
172 let mut results = Vec::with_capacity(n_genes);
173
174 for i in 0..n_genes {
175 let row = &normed[i * n_samples..(i + 1) * n_samples];
176
177 let mu_cond: f64 = cond_idx.iter().map(|&j| row[j]).sum::<f64>() / cond_idx.len() as f64;
179 let mu_ctrl: f64 = ctrl_idx.iter().map(|&j| row[j]).sum::<f64>() / ctrl_idx.len() as f64;
180 let base_mean: f64 = row.iter().sum::<f64>() / n_samples as f64;
181
182 let log2fc = ((mu_cond + pseudo) / (mu_ctrl + pseudo)).log2();
184
185 let overall_mean = base_mean;
187 let overall_var = if n_samples > 1 {
188 let ss: f64 = row.iter().map(|&x| (x - overall_mean).powi(2)).sum();
189 ss / (n_samples - 1) as f64
190 } else {
191 0.0
192 };
193
194 let alpha = if overall_mean > 0.0 {
196 ((overall_var - overall_mean) / (overall_mean * overall_mean)).clamp(1e-8, 1e8)
197 } else {
198 1e-8
199 };
200
201 let var_cond = mu_cond + alpha * mu_cond * mu_cond;
205 let var_ctrl = mu_ctrl + alpha * mu_ctrl * mu_ctrl;
206 let se_cond = (var_cond / cond_idx.len() as f64).sqrt();
207 let se_ctrl = (var_ctrl / ctrl_idx.len() as f64).sqrt();
208 let se_log2fc = ((se_cond / (mu_cond + pseudo)).powi(2)
210 + (se_ctrl / (mu_ctrl + pseudo)).powi(2))
211 .sqrt()
212 / 2.0_f64.ln();
213
214 let (z, p_value) = if se_log2fc > 1e-15 {
216 let z = log2fc / se_log2fc;
217 let p = 2.0 * (1.0 - normal.cdf(z.abs()));
218 (z, p.min(1.0))
219 } else {
220 (0.0, 1.0)
221 };
222
223 results.push(DeGeneResult {
224 gene_index: i,
225 log2_fold_change: log2fc,
226 base_mean,
227 statistic: z,
228 p_value,
229 p_adjusted: 1.0, });
231 }
232
233 Ok(results)
234}
235
236fn wilcoxon_de(
239 normed: &[f64],
240 n_genes: usize,
241 n_samples: usize,
242 cond_idx: &[usize],
243 ctrl_idx: &[usize],
244) -> Result<Vec<DeGeneResult>> {
245 let pseudo = 0.5;
246 let mut results = Vec::with_capacity(n_genes);
247
248 for i in 0..n_genes {
249 let row = &normed[i * n_samples..(i + 1) * n_samples];
250
251 let cond_vals: Vec<f64> = cond_idx.iter().map(|&j| row[j]).collect();
252 let ctrl_vals: Vec<f64> = ctrl_idx.iter().map(|&j| row[j]).collect();
253
254 let mu_cond = cond_vals.iter().sum::<f64>() / cond_vals.len() as f64;
255 let mu_ctrl = ctrl_vals.iter().sum::<f64>() / ctrl_vals.len() as f64;
256 let base_mean: f64 = row.iter().sum::<f64>() / n_samples as f64;
257 let log2fc = ((mu_cond + pseudo) / (mu_ctrl + pseudo)).log2();
258
259 let test_result = testing::mann_whitney_u(&cond_vals, &ctrl_vals)?;
260
261 results.push(DeGeneResult {
262 gene_index: i,
263 log2_fold_change: log2fc,
264 base_mean,
265 statistic: test_result.statistic,
266 p_value: test_result.p_value,
267 p_adjusted: 1.0,
268 });
269 }
270
271 Ok(results)
272}
273
274pub fn volcano_plot(
284 results: &DeResults,
285 padj_threshold: f64,
286 fc_threshold: f64,
287) -> Vec<VolcanoPoint> {
288 results
289 .genes
290 .iter()
291 .map(|g| {
292 let neg_log10 = if g.p_adjusted > 0.0 {
293 (-g.p_adjusted.log10()).min(300.0)
294 } else {
295 300.0
296 };
297 VolcanoPoint {
298 gene_index: g.gene_index,
299 log2_fold_change: g.log2_fold_change,
300 neg_log10_padj: neg_log10,
301 significant: g.p_adjusted < padj_threshold
302 && g.log2_fold_change.abs() > fc_threshold,
303 }
304 })
305 .collect()
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Group layout shared by the fixtures: samples 0-2 are controls,
    /// samples 3-5 carry the condition.
    const GROUPS: [bool; 6] = [false, false, false, true, true, true];

    /// Five genes x six samples: gene 0 is strongly up in the condition group,
    /// gene 1 strongly down, genes 2-4 are essentially unchanged.
    fn test_counts() -> (Vec<f64>, usize, usize, Vec<bool>) {
        #[rustfmt::skip]
        let counts = vec![
            10.0,  12.0,  11.0,  200.0, 210.0, 190.0,
            200.0, 190.0, 210.0, 10.0,  12.0,  11.0,
            100.0, 105.0, 95.0,  98.0,  102.0, 100.0,
            50.0,  52.0,  48.0,  49.0,  51.0,  50.0,
            75.0,  78.0,  72.0,  74.0,  76.0,  75.0,
        ];
        (counts, 5, 6, GROUPS.to_vec())
    }

    /// Runs the full pipeline on the shared fixture with the given method.
    fn run(method: DeMethod) -> DeResults {
        let (counts, ng, ns, cond) = test_counts();
        differential_expression(&counts, ng, ns, &cond, method).unwrap()
    }

    /// Looks a gene up by its original row index (results are sorted by p-value).
    fn gene(res: &DeResults, idx: usize) -> &DeGeneResult {
        res.genes.iter().find(|g| g.gene_index == idx).unwrap()
    }

    #[test]
    fn nb_detects_upregulated() {
        let res = run(DeMethod::NegativeBinomial);
        let up = gene(&res, 0);
        assert!(up.log2_fold_change > 2.0, "log2fc={}", up.log2_fold_change);
        assert!(up.p_adjusted < 0.05, "padj={}", up.p_adjusted);
    }

    #[test]
    fn nb_detects_downregulated() {
        let res = run(DeMethod::NegativeBinomial);
        let down = gene(&res, 1);
        assert!(down.log2_fold_change < -2.0, "log2fc={}", down.log2_fold_change);
        assert!(down.p_adjusted < 0.05, "padj={}", down.p_adjusted);
    }

    #[test]
    fn nb_unchanged_genes_high_p() {
        let res = run(DeMethod::NegativeBinomial);
        for idx in [2, 3, 4] {
            let flat = gene(&res, idx);
            assert!(
                flat.p_value > 0.05,
                "gene {idx} should not be significant: p={}",
                flat.p_value
            );
        }
    }

    #[test]
    fn nb_log2fc_direction() {
        let res = run(DeMethod::NegativeBinomial);
        assert!(gene(&res, 0).log2_fold_change > 0.0);
        assert!(gene(&res, 1).log2_fold_change < 0.0);
    }

    #[test]
    fn nb_padj_ge_pvalue() {
        let res = run(DeMethod::NegativeBinomial);
        for g in &res.genes {
            assert!(
                g.p_adjusted >= g.p_value - 1e-15,
                "gene {}: padj={} < p={}",
                g.gene_index,
                g.p_adjusted,
                g.p_value
            );
        }
    }

    #[test]
    fn nb_results_sorted() {
        let res = run(DeMethod::NegativeBinomial);
        for pair in res.genes.windows(2) {
            let (earlier, later) = (&pair[0], &pair[1]);
            assert!(
                earlier.p_value <= later.p_value + 1e-15,
                "not sorted: {} > {}",
                earlier.p_value,
                later.p_value
            );
        }
    }

    #[test]
    fn wilcoxon_detects_de() {
        let res = run(DeMethod::Wilcoxon);
        let up = gene(&res, 0);
        assert!(up.log2_fold_change > 2.0);
        assert!(up.p_value < 0.1, "p={}", up.p_value);
    }

    #[test]
    fn wilcoxon_matches_direct_mwu() {
        let (counts, ng, ns, cond) = test_counts();

        // Reproduce the normalisation step so the direct test sees the same values.
        let sf = normalization::size_factors(&counts, ng, ns).unwrap();
        let normed = normalization::normalize_by_size_factors(&counts, ng, ns, &sf).unwrap();

        let res = differential_expression(&counts, ng, ns, &cond, DeMethod::Wilcoxon).unwrap();

        let treated_cols: Vec<usize> = (0..ns).filter(|&j| cond[j]).collect();
        let control_cols: Vec<usize> = (0..ns).filter(|&j| !cond[j]).collect();

        for reported in &res.genes {
            let i = reported.gene_index;
            let row = &normed[i * ns..(i + 1) * ns];
            let treated: Vec<f64> = treated_cols.iter().map(|&j| row[j]).collect();
            let control: Vec<f64> = control_cols.iter().map(|&j| row[j]).collect();
            let direct = testing::mann_whitney_u(&treated, &control).unwrap();
            assert!(
                (reported.p_value - direct.p_value).abs() < 1e-10,
                "gene {}: de_p={}, direct_p={}",
                i,
                reported.p_value,
                direct.p_value
            );
        }
    }

    #[test]
    fn dispersion_poisson_like() {
        // A single gene with near-identical counts in both groups.
        let counts = [100.0, 101.0, 99.0, 100.0, 102.0, 98.0];
        let res =
            differential_expression(&counts, 1, 6, &GROUPS, DeMethod::NegativeBinomial).unwrap();
        let p = res.genes[0].p_value;
        assert!(p > 0.5, "p={}", p);
    }

    #[test]
    fn dispersion_overdispersed() {
        // A single gene with wildly varying counts must still be processed.
        let counts = [1.0, 50.0, 200.0, 500.0, 1000.0, 2000.0];
        let res = differential_expression(&counts, 1, 6, &GROUPS, DeMethod::NegativeBinomial);
        assert!(res.is_ok());
    }

    #[test]
    fn volcano_thresholds() {
        let (counts, ng, ns, cond) = test_counts();
        let res =
            differential_expression(&counts, ng, ns, &cond, DeMethod::NegativeBinomial).unwrap();
        let points = volcano_plot(&res, 0.05, 1.0);

        assert_eq!(points.len(), ng);

        let sig_genes: Vec<usize> = points
            .iter()
            .filter(|p| p.significant)
            .map(|p| p.gene_index)
            .collect();
        assert!(sig_genes.contains(&0), "gene 0 should be significant");
        assert!(sig_genes.contains(&1), "gene 1 should be significant");

        for idx in [2, 3, 4] {
            let pt = points.iter().find(|p| p.gene_index == idx).unwrap();
            assert!(!pt.significant, "gene {idx} should not be significant");
        }

        for pt in &points {
            assert!(pt.neg_log10_padj >= 0.0);
        }
    }

    #[test]
    fn error_dimension_mismatch() {
        let cond = [false, true, false, true];
        let res = differential_expression(&[1.0, 2.0], 2, 4, &cond, DeMethod::NegativeBinomial);
        assert!(res.is_err());
    }

    #[test]
    fn error_condition_length() {
        let counts = [1.0; 8];
        let cond = [false, true];
        let res = differential_expression(&counts, 2, 4, &cond, DeMethod::NegativeBinomial);
        assert!(res.is_err());
    }

    #[test]
    fn error_too_few_per_group() {
        let counts = [10.0, 20.0, 30.0, 40.0];
        // Only one control sample.
        let cond = [false, true, true, true];
        let res = differential_expression(&counts, 1, 4, &cond, DeMethod::NegativeBinomial);
        assert!(res.is_err());
    }

    #[test]
    fn error_single_group() {
        let counts = [10.0, 20.0, 30.0, 40.0];
        let cond = [true, true, true, true];
        let res = differential_expression(&counts, 1, 4, &cond, DeMethod::NegativeBinomial);
        assert!(res.is_err());
    }

    #[test]
    fn volcano_clamps_neg_log10() {
        // An adjusted p-value of exactly zero must land on the 300 cap.
        let zero_p_gene = DeGeneResult {
            gene_index: 0,
            log2_fold_change: 5.0,
            base_mean: 100.0,
            statistic: 10.0,
            p_value: 0.0,
            p_adjusted: 0.0,
        };
        let results = DeResults {
            genes: vec![zero_p_gene],
            method: DeMethod::NegativeBinomial,
            n_genes: 1,
            n_condition: 3,
            n_control: 3,
        };
        let points = volcano_plot(&results, 0.05, 1.0);
        assert_eq!(points[0].neg_log10_padj, 300.0);
        assert!(points[0].significant);
    }
}