use cyanea_core::{CyaneaError, Result};
use crate::correction;
use crate::distribution::{Distribution, Normal};
use crate::normalization;
use crate::testing;
/// Statistical test used to score each gene in [`differential_expression`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeMethod {
    /// Per-gene z-test on the log2 fold change, with the standard error derived
    /// from a negative-binomial-style variance model (`mean + alpha * mean^2`).
    NegativeBinomial,
    /// Per-gene Mann-Whitney U test between the two sample groups.
    Wilcoxon,
}
/// Test outcome for a single gene.
#[derive(Debug, Clone)]
pub struct DeGeneResult {
    /// Row index of the gene in the input count matrix.
    pub gene_index: usize,
    /// `log2((mean_condition + 0.5) / (mean_control + 0.5))` of the normalized values.
    pub log2_fold_change: f64,
    /// Mean of the normalized values across all samples.
    pub base_mean: f64,
    /// Test statistic: the z-score for [`DeMethod::NegativeBinomial`], or the
    /// statistic reported by the Mann-Whitney U test for [`DeMethod::Wilcoxon`].
    pub statistic: f64,
    /// Raw (unadjusted) p-value.
    pub p_value: f64,
    /// Benjamini-Hochberg adjusted p-value; the per-gene tests initialise it to
    /// `1.0` and [`differential_expression`] overwrites it after correction.
    pub p_adjusted: f64,
}
/// Full output of a [`differential_expression`] run.
#[derive(Debug, Clone)]
pub struct DeResults {
    /// One entry per gene, ordered by ascending raw p-value.
    pub genes: Vec<DeGeneResult>,
    /// The test that produced these results.
    pub method: DeMethod,
    /// Number of genes (rows) in the analysed matrix.
    pub n_genes: usize,
    /// Number of samples flagged as belonging to the condition group.
    pub n_condition: usize,
    /// Number of samples belonging to the control group.
    pub n_control: usize,
}
/// A single gene positioned for a volcano plot, as produced by [`volcano_plot`].
#[derive(Debug, Clone)]
pub struct VolcanoPoint {
    /// Row index of the gene in the input count matrix.
    pub gene_index: usize,
    /// X coordinate: the gene's log2 fold change.
    pub log2_fold_change: f64,
    /// Y coordinate: `-log10` of the adjusted p-value, capped at 300.
    pub neg_log10_padj: f64,
    /// Whether the gene passed both the adjusted-p and fold-change thresholds.
    pub significant: bool,
}
/// Runs a two-group differential expression analysis over a gene-by-sample count matrix.
///
/// `counts` is a flat, row-major matrix: gene `i` occupies
/// `counts[i * n_samples..(i + 1) * n_samples]`. `condition[j]` is `true` when sample `j`
/// belongs to the condition group and `false` when it belongs to the control group.
///
/// The pipeline is:
/// 1. compute size factors and normalize the counts with them,
/// 2. run the per-gene test selected by `method`,
/// 3. adjust the raw p-values with Benjamini-Hochberg,
/// 4. sort the genes by ascending raw p-value.
///
/// # Errors
///
/// Returns [`CyaneaError::InvalidInput`] when
/// - `n_genes` or `n_samples` is zero,
/// - `n_genes * n_samples` does not fit in `usize`,
/// - `counts.len()` differs from `n_genes * n_samples`,
/// - `condition.len()` differs from `n_samples`, or
/// - either group has fewer than two samples.
///
/// Errors from normalization, the per-gene test, or the p-value correction are
/// propagated unchanged.
pub fn differential_expression(
    counts: &[f64],
    n_genes: usize,
    n_samples: usize,
    condition: &[bool],
    method: DeMethod,
) -> Result<DeResults> {
    if n_genes == 0 || n_samples == 0 {
        return Err(CyaneaError::InvalidInput(
            "differential_expression: need at least 1 gene and 1 sample".into(),
        ));
    }

    // A plain `n_genes * n_samples` panics in debug builds and wraps in release
    // builds; a wrapped product could even match `counts.len()` and let bogus
    // dimensions through to the row slicing further down.
    let expected_len = n_genes.checked_mul(n_samples).ok_or_else(|| {
        CyaneaError::InvalidInput(format!(
            "differential_expression: n_genes ({}) * n_samples ({}) overflows usize",
            n_genes, n_samples,
        ))
    })?;
    if counts.len() != expected_len {
        return Err(CyaneaError::InvalidInput(format!(
            "differential_expression: counts length ({}) != n_genes ({}) * n_samples ({})",
            counts.len(),
            n_genes,
            n_samples,
        )));
    }
    if condition.len() != n_samples {
        return Err(CyaneaError::InvalidInput(format!(
            "differential_expression: condition length ({}) != n_samples ({})",
            condition.len(),
            n_samples,
        )));
    }

    // Split the sample columns into the two groups in one pass; each index list
    // stays in ascending order.
    let (cond_idx, ctrl_idx): (Vec<usize>, Vec<usize>) =
        (0..n_samples).partition(|&j| condition[j]);
    let n_cond = cond_idx.len();
    let n_ctrl = ctrl_idx.len();
    if n_cond < 2 || n_ctrl < 2 {
        return Err(CyaneaError::InvalidInput(
            "differential_expression: need at least 2 samples per group".into(),
        ));
    }

    let sf = normalization::size_factors(counts, n_genes, n_samples)?;
    let normed = normalization::normalize_by_size_factors(counts, n_genes, n_samples, &sf)?;

    let mut gene_results: Vec<DeGeneResult> = match method {
        DeMethod::NegativeBinomial => nb_wald(&normed, n_genes, n_samples, &cond_idx, &ctrl_idx)?,
        DeMethod::Wilcoxon => wilcoxon_de(&normed, n_genes, n_samples, &cond_idx, &ctrl_idx)?,
    };

    // Adjust while the results are still in gene order, so that the i-th
    // adjusted value lines up with the i-th gene.
    let raw_p: Vec<f64> = gene_results.iter().map(|g| g.p_value).collect();
    let adj_p = correction::benjamini_hochberg(&raw_p)?;
    for (g, &padj) in gene_results.iter_mut().zip(adj_p.iter()) {
        g.p_adjusted = padj;
    }

    // `total_cmp` gives a total order over f64, so NaN p-values cannot make the
    // comparator inconsistent.
    gene_results.sort_by(|a, b| a.p_value.total_cmp(&b.p_value));

    Ok(DeResults {
        genes: gene_results,
        method,
        n_genes,
        n_condition: n_cond,
        n_control: n_ctrl,
    })
}
/// Scores every gene with a z-test on its log2 fold change.
///
/// `normed` is the normalized, row-major gene-by-sample matrix; `cond_idx` and
/// `ctrl_idx` list the sample columns of the condition and control groups.
///
/// For each gene:
/// - the log2 fold change is `log2((mean_cond + 0.5) / (mean_ctrl + 0.5))`;
/// - a dispersion `alpha` is estimated by moments from the mean and sample
///   variance over *all* samples, `(var - mean) / mean^2`, clamped to
///   `[1e-8, 1e8]` (and set to `1e-8` when the mean is not positive);
/// - each group's variance is modelled as `mean + alpha * mean^2`, and the
///   standard error of the log2 fold change follows from the delta method;
/// - the two-sided p-value comes from the standard normal CDF.
///
/// A gene whose standard error is not above `1e-15` (including NaN) gets
/// statistic `0.0` and p-value `1.0`. `p_adjusted` is left at `1.0` for the
/// caller to fill in.
fn nb_wald(
    normed: &[f64],
    n_genes: usize,
    n_samples: usize,
    cond_idx: &[usize],
    ctrl_idx: &[usize],
) -> Result<Vec<DeGeneResult>> {
    let std_normal = Normal::standard();
    // Keeps the fold change finite when a group mean is zero.
    let pseudocount = 0.5;
    let n_cond = cond_idx.len() as f64;
    let n_ctrl = ctrl_idx.len() as f64;

    let per_gene: Vec<DeGeneResult> = (0..n_genes)
        .map(|gene| {
            let values = &normed[gene * n_samples..(gene + 1) * n_samples];
            let group_mean = |members: &[usize], size: f64| {
                members.iter().map(|&j| values[j]).sum::<f64>() / size
            };

            let mean_cond = group_mean(cond_idx, n_cond);
            let mean_ctrl = group_mean(ctrl_idx, n_ctrl);
            let mean_all: f64 = values.iter().sum::<f64>() / n_samples as f64;
            let log2fc = ((mean_cond + pseudocount) / (mean_ctrl + pseudocount)).log2();

            // Unbiased sample variance across every sample of the gene.
            let var_all = if n_samples > 1 {
                let squared_dev: f64 = values.iter().map(|&x| (x - mean_all).powi(2)).sum();
                squared_dev / (n_samples - 1) as f64
            } else {
                0.0
            };

            // Method-of-moments dispersion; the floor keeps it strictly positive.
            let alpha = if mean_all > 0.0 {
                ((var_all - mean_all) / (mean_all * mean_all)).clamp(1e-8, 1e8)
            } else {
                1e-8
            };

            // Standard error of a group mean relative to its (shifted) mean.
            let relative_se = |mean: f64, size: f64| {
                let variance = mean + alpha * mean * mean;
                (variance / size).sqrt() / (mean + pseudocount)
            };
            let rel_cond = relative_se(mean_cond, n_cond);
            let rel_ctrl = relative_se(mean_ctrl, n_ctrl);
            // Dividing by ln(2) converts the natural-log scale to log2.
            let se_log2fc = (rel_cond.powi(2) + rel_ctrl.powi(2)).sqrt() / 2.0_f64.ln();

            let (statistic, p_value) = if se_log2fc > 1e-15 {
                let z = log2fc / se_log2fc;
                let two_sided = 2.0 * (1.0 - std_normal.cdf(z.abs()));
                (z, two_sided.min(1.0))
            } else {
                (0.0, 1.0)
            };

            DeGeneResult {
                gene_index: gene,
                log2_fold_change: log2fc,
                base_mean: mean_all,
                statistic,
                p_value,
                p_adjusted: 1.0,
            }
        })
        .collect();

    Ok(per_gene)
}
/// Scores every gene with a Mann-Whitney U test between the two sample groups.
///
/// `normed` is the normalized, row-major gene-by-sample matrix; `cond_idx` and
/// `ctrl_idx` list the sample columns of the condition and control groups.
///
/// The statistic and raw p-value are taken from `testing::mann_whitney_u`; the
/// log2 fold change is `log2((mean_cond + 0.5) / (mean_ctrl + 0.5))`.
/// `p_adjusted` is left at `1.0` for the caller to fill in.
///
/// # Errors
///
/// Stops at, and returns, the first error reported by the U test.
fn wilcoxon_de(
    normed: &[f64],
    n_genes: usize,
    n_samples: usize,
    cond_idx: &[usize],
    ctrl_idx: &[usize],
) -> Result<Vec<DeGeneResult>> {
    // Keeps the fold change finite when a group mean is zero.
    let pseudocount = 0.5;

    (0..n_genes)
        .map(|gene| {
            let values = &normed[gene * n_samples..(gene + 1) * n_samples];
            let gather = |members: &[usize]| -> Vec<f64> {
                members.iter().map(|&j| values[j]).collect()
            };
            let cond_vals = gather(cond_idx);
            let ctrl_vals = gather(ctrl_idx);

            let mean_cond = cond_vals.iter().sum::<f64>() / cond_vals.len() as f64;
            let mean_ctrl = ctrl_vals.iter().sum::<f64>() / ctrl_vals.len() as f64;
            let mean_all: f64 = values.iter().sum::<f64>() / n_samples as f64;

            let outcome = testing::mann_whitney_u(&cond_vals, &ctrl_vals)?;

            Ok(DeGeneResult {
                gene_index: gene,
                log2_fold_change: ((mean_cond + pseudocount) / (mean_ctrl + pseudocount)).log2(),
                base_mean: mean_all,
                statistic: outcome.statistic,
                p_value: outcome.p_value,
                p_adjusted: 1.0,
            })
        })
        .collect()
}
/// Converts differential expression results into volcano plot coordinates.
///
/// Each gene becomes one [`VolcanoPoint`], in the same order as `results.genes`:
/// - x is the log2 fold change;
/// - y is `-log10(p_adjusted)`, capped at 300 so that an adjusted p-value of
///   zero (or one that underflowed to zero) stays finite;
/// - a NaN adjusted p-value carries no evidence of significance, so it is
///   placed at `0.0` rather than at the cap;
/// - `significant` is set when `p_adjusted < padj_threshold` **and**
///   `|log2_fold_change| > fc_threshold` (both comparisons strict, and both
///   false for NaN).
pub fn volcano_plot(
    results: &DeResults,
    padj_threshold: f64,
    fc_threshold: f64,
) -> Vec<VolcanoPoint> {
    // Upper bound for the y axis; -log10 of the smallest positive f64 is about 324.
    const NEG_LOG10_CAP: f64 = 300.0;

    results
        .genes
        .iter()
        .map(|g| {
            let padj = g.p_adjusted;
            let neg_log10_padj = if padj.is_nan() {
                // NaN fails `padj > 0.0` and would otherwise fall through to
                // the cap, rendering an untestable gene as the most
                // significant point on the plot.
                0.0
            } else if padj > 0.0 {
                (-padj.log10()).min(NEG_LOG10_CAP)
            } else {
                NEG_LOG10_CAP
            };
            VolcanoPoint {
                gene_index: g.gene_index,
                log2_fold_change: g.log2_fold_change,
                neg_log10_padj,
                significant: padj < padj_threshold && g.log2_fold_change.abs() > fc_threshold,
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: 5 genes × 6 samples (3 control followed by 3 condition).
    /// Gene 0 is strongly up in the condition group, gene 1 strongly down, and
    /// genes 2-4 are essentially unchanged.
    fn test_counts() -> (Vec<f64>, usize, usize, Vec<bool>) {
        #[rustfmt::skip]
        let counts = vec![
            10.0, 12.0, 11.0, 200.0, 210.0, 190.0,
            200.0, 190.0, 210.0, 10.0, 12.0, 11.0,
            100.0, 105.0, 95.0, 98.0, 102.0, 100.0,
            50.0, 52.0, 48.0, 49.0, 51.0, 50.0,
            75.0, 78.0, 72.0, 74.0, 76.0, 75.0,
        ];
        let condition = vec![false, false, false, true, true, true];
        (counts, 5, 6, condition)
    }

    /// Runs the full pipeline on the shared fixture with the given method.
    fn run_fixture(method: DeMethod) -> DeResults {
        let (counts, ng, ns, cond) = test_counts();
        differential_expression(&counts, ng, ns, &cond, method).unwrap()
    }

    /// Looks a gene up by its original row index (results are sorted by p-value).
    fn gene(res: &DeResults, idx: usize) -> &DeGeneResult {
        res.genes.iter().find(|g| g.gene_index == idx).unwrap()
    }

    /// Whether the negative-binomial pipeline rejects the given input.
    fn is_rejected(counts: &[f64], n_genes: usize, n_samples: usize, cond: &[bool]) -> bool {
        differential_expression(counts, n_genes, n_samples, cond, DeMethod::NegativeBinomial)
            .is_err()
    }

    #[test]
    fn nb_detects_upregulated() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        let up = gene(&res, 0);
        assert!(up.log2_fold_change > 2.0, "log2fc={}", up.log2_fold_change);
        assert!(up.p_adjusted < 0.05, "padj={}", up.p_adjusted);
    }

    #[test]
    fn nb_detects_downregulated() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        let down = gene(&res, 1);
        assert!(down.log2_fold_change < -2.0, "log2fc={}", down.log2_fold_change);
        assert!(down.p_adjusted < 0.05, "padj={}", down.p_adjusted);
    }

    #[test]
    fn nb_unchanged_genes_high_p() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        for idx in [2, 3, 4] {
            let p = gene(&res, idx).p_value;
            assert!(p > 0.05, "gene {idx} should not be significant: p={}", p);
        }
    }

    #[test]
    fn nb_log2fc_direction() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        assert!(gene(&res, 0).log2_fold_change > 0.0);
        assert!(gene(&res, 1).log2_fold_change < 0.0);
    }

    #[test]
    fn nb_padj_ge_pvalue() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        for entry in res.genes.iter() {
            let (padj, p) = (entry.p_adjusted, entry.p_value);
            assert!(
                padj >= p - 1e-15,
                "gene {}: padj={} < p={}",
                entry.gene_index,
                padj,
                p
            );
        }
    }

    #[test]
    fn nb_results_sorted() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        // Pair every entry with its successor.
        for (earlier, later) in res.genes.iter().zip(res.genes.iter().skip(1)) {
            assert!(
                earlier.p_value <= later.p_value + 1e-15,
                "not sorted: {} > {}",
                earlier.p_value,
                later.p_value
            );
        }
    }

    #[test]
    fn wilcoxon_detects_de() {
        let res = run_fixture(DeMethod::Wilcoxon);
        let up = gene(&res, 0);
        assert!(up.log2_fold_change > 2.0);
        assert!(up.p_value < 0.1, "p={}", up.p_value);
    }

    #[test]
    fn wilcoxon_matches_direct_mwu() {
        let (counts, ng, ns, cond) = test_counts();
        let res = differential_expression(&counts, ng, ns, &cond, DeMethod::Wilcoxon).unwrap();

        // Reproduce the normalization and group split by hand.
        let sf = normalization::size_factors(&counts, ng, ns).unwrap();
        let normed = normalization::normalize_by_size_factors(&counts, ng, ns, &sf).unwrap();
        let cond_idx: Vec<usize> = (0..ns).filter(|&j| cond[j]).collect();
        let ctrl_idx: Vec<usize> = (0..ns).filter(|&j| !cond[j]).collect();

        for gene_res in res.genes.iter() {
            let i = gene_res.gene_index;
            let row = &normed[i * ns..(i + 1) * ns];
            let pick = |members: &[usize]| -> Vec<f64> {
                members.iter().map(|&j| row[j]).collect()
            };
            let cond_vals = pick(&cond_idx);
            let ctrl_vals = pick(&ctrl_idx);
            let direct = testing::mann_whitney_u(&cond_vals, &ctrl_vals).unwrap();
            assert!(
                (gene_res.p_value - direct.p_value).abs() < 1e-10,
                "gene {}: de_p={}, direct_p={}",
                i,
                gene_res.p_value,
                direct.p_value
            );
        }
    }

    #[test]
    fn dispersion_poisson_like() {
        // Near-constant counts: no evidence of a difference between groups.
        let counts = vec![100.0, 101.0, 99.0, 100.0, 102.0, 98.0];
        let cond = vec![false, false, false, true, true, true];
        let res =
            differential_expression(&counts, 1, 6, &cond, DeMethod::NegativeBinomial).unwrap();
        let p = res.genes[0].p_value;
        assert!(p > 0.5, "p={}", p);
    }

    #[test]
    fn dispersion_overdispersed() {
        // Wildly varying counts must still be processed without an error.
        let counts = vec![1.0, 50.0, 200.0, 500.0, 1000.0, 2000.0];
        let cond = vec![false, false, false, true, true, true];
        let res = differential_expression(&counts, 1, 6, &cond, DeMethod::NegativeBinomial);
        assert!(res.is_ok());
    }

    #[test]
    fn volcano_thresholds() {
        let res = run_fixture(DeMethod::NegativeBinomial);
        let points = volcano_plot(&res, 0.05, 1.0);
        assert_eq!(points.len(), res.n_genes);

        let sig_genes: Vec<usize> = points
            .iter()
            .filter(|p| p.significant)
            .map(|p| p.gene_index)
            .collect();
        assert!(sig_genes.contains(&0), "gene 0 should be significant");
        assert!(sig_genes.contains(&1), "gene 1 should be significant");

        for idx in [2, 3, 4] {
            let pt = points.iter().find(|p| p.gene_index == idx).unwrap();
            assert!(!pt.significant, "gene {idx} should not be significant");
        }
        for pt in points.iter() {
            assert!(pt.neg_log10_padj >= 0.0);
        }
    }

    #[test]
    fn error_dimension_mismatch() {
        let cond = vec![false, true, false, true];
        assert!(is_rejected(&[1.0, 2.0], 2, 4, &cond));
    }

    #[test]
    fn error_condition_length() {
        let counts = vec![1.0; 8];
        let cond = vec![false, true];
        assert!(is_rejected(&counts, 2, 4, &cond));
    }

    #[test]
    fn error_too_few_per_group() {
        let counts = vec![10.0, 20.0, 30.0, 40.0];
        let cond = vec![false, true, true, true];
        assert!(is_rejected(&counts, 1, 4, &cond));
    }

    #[test]
    fn error_single_group() {
        let counts = vec![10.0, 20.0, 30.0, 40.0];
        let cond = vec![true, true, true, true];
        assert!(is_rejected(&counts, 1, 4, &cond));
    }

    #[test]
    fn volcano_clamps_neg_log10() {
        // An adjusted p-value of exactly zero must land on the cap, not infinity.
        let lone_gene = DeGeneResult {
            gene_index: 0,
            log2_fold_change: 5.0,
            base_mean: 100.0,
            statistic: 10.0,
            p_value: 0.0,
            p_adjusted: 0.0,
        };
        let results = DeResults {
            genes: vec![lone_gene],
            method: DeMethod::NegativeBinomial,
            n_genes: 1,
            n_condition: 3,
            n_control: 3,
        };
        let points = volcano_plot(&results, 0.05, 1.0);
        assert_eq!(points[0].neg_log10_padj, 300.0);
        assert!(points[0].significant);
    }
}