Skip to main content

cyanea_stats/
normalization.rs

1//! Count normalization for RNA-seq and related assays.
2//!
3//! All functions operate on row-major `&[f64]` slices with dimensions
4//! `(n_genes, n_samples)`, matching `cyanea_omics::ExpressionMatrix` layout.
5//!
6//! - [`cpm`] — Counts per million
7//! - [`tpm`] — Transcripts per million (requires gene lengths)
8//! - [`fpkm`] — Fragments per kilobase of transcript per million mapped reads
9//! - [`size_factors`] — DESeq2-style median-of-ratios normalization factors
10//! - [`normalize_by_size_factors`] — Divide counts by per-sample size factors
11
12use cyanea_core::{CyaneaError, Result};
13
14use crate::descriptive;
15
16// ── Helpers ──────────────────────────────────────────────────────────────────
17
18fn validate_matrix(counts: &[f64], n_genes: usize, n_samples: usize) -> Result<()> {
19    if n_genes == 0 || n_samples == 0 {
20        return Err(CyaneaError::InvalidInput(
21            "normalization: matrix must have at least 1 gene and 1 sample".into(),
22        ));
23    }
24    if counts.len() != n_genes * n_samples {
25        return Err(CyaneaError::InvalidInput(format!(
26            "normalization: counts length ({}) != n_genes ({}) * n_samples ({})",
27            counts.len(),
28            n_genes,
29            n_samples,
30        )));
31    }
32    Ok(())
33}
34
/// Sums each column (sample) of a row-major `(n_genes, n_samples)` matrix.
///
/// Returns a vector of `n_samples` totals. Rows are accumulated in order, so
/// the floating-point summation order is gene 0, gene 1, … for every column.
fn column_sums(counts: &[f64], n_genes: usize, n_samples: usize) -> Vec<f64> {
    let mut totals = vec![0.0; n_samples];
    for gene in 0..n_genes {
        let start = gene * n_samples;
        let row = &counts[start..start + n_samples];
        for (total, &value) in totals.iter_mut().zip(row) {
            *total += value;
        }
    }
    totals
}
45
46// ── CPM ──────────────────────────────────────────────────────────────────────
47
48/// Counts per million: `CPM_ij = count_ij / library_size_j * 1e6`.
49pub fn cpm(counts: &[f64], n_genes: usize, n_samples: usize) -> Result<Vec<f64>> {
50    validate_matrix(counts, n_genes, n_samples)?;
51    let lib_sizes = column_sums(counts, n_genes, n_samples);
52    let mut out = vec![0.0; counts.len()];
53    for i in 0..n_genes {
54        for j in 0..n_samples {
55            let idx = i * n_samples + j;
56            out[idx] = if lib_sizes[j] > 0.0 {
57                counts[idx] / lib_sizes[j] * 1e6
58            } else {
59                0.0
60            };
61        }
62    }
63    Ok(out)
64}
65
66// ── TPM ──────────────────────────────────────────────────────────────────────
67
68/// Transcripts per million.
69///
70/// 1. Divide each count by gene length (→ reads per kilobase, RPK).
71/// 2. Sum RPK per sample, then scale each sample's RPK values to sum to 1M.
72///
73/// `gene_lengths` must have `n_genes` elements (in bases or kilobases — units
74/// are consistent because the per-kilobase step cancels in the ratio).
75pub fn tpm(
76    counts: &[f64],
77    n_genes: usize,
78    n_samples: usize,
79    gene_lengths: &[f64],
80) -> Result<Vec<f64>> {
81    validate_matrix(counts, n_genes, n_samples)?;
82    if gene_lengths.len() != n_genes {
83        return Err(CyaneaError::InvalidInput(format!(
84            "tpm: gene_lengths length ({}) != n_genes ({})",
85            gene_lengths.len(),
86            n_genes,
87        )));
88    }
89
90    // RPK: count / (length / 1000)
91    let mut rpk = vec![0.0; counts.len()];
92    for i in 0..n_genes {
93        let len_kb = gene_lengths[i] / 1000.0;
94        if len_kb <= 0.0 {
95            return Err(CyaneaError::InvalidInput(format!(
96                "tpm: gene_lengths[{i}] must be positive",
97            )));
98        }
99        for j in 0..n_samples {
100            rpk[i * n_samples + j] = counts[i * n_samples + j] / len_kb;
101        }
102    }
103
104    // Per-sample RPK sums
105    let rpk_sums = column_sums(&rpk, n_genes, n_samples);
106
107    // Scale to 1M
108    let mut out = vec![0.0; counts.len()];
109    for i in 0..n_genes {
110        for j in 0..n_samples {
111            let idx = i * n_samples + j;
112            out[idx] = if rpk_sums[j] > 0.0 {
113                rpk[idx] / rpk_sums[j] * 1e6
114            } else {
115                0.0
116            };
117        }
118    }
119    Ok(out)
120}
121
122// ── FPKM ─────────────────────────────────────────────────────────────────────
123
124/// Fragments per kilobase of transcript per million mapped reads.
125///
126/// `FPKM_ij = count_ij * 1e9 / (library_size_j * length_i)`
127pub fn fpkm(
128    counts: &[f64],
129    n_genes: usize,
130    n_samples: usize,
131    gene_lengths: &[f64],
132) -> Result<Vec<f64>> {
133    validate_matrix(counts, n_genes, n_samples)?;
134    if gene_lengths.len() != n_genes {
135        return Err(CyaneaError::InvalidInput(format!(
136            "fpkm: gene_lengths length ({}) != n_genes ({})",
137            gene_lengths.len(),
138            n_genes,
139        )));
140    }
141    let lib_sizes = column_sums(counts, n_genes, n_samples);
142    let mut out = vec![0.0; counts.len()];
143    for i in 0..n_genes {
144        if gene_lengths[i] <= 0.0 {
145            return Err(CyaneaError::InvalidInput(format!(
146                "fpkm: gene_lengths[{i}] must be positive",
147            )));
148        }
149        for j in 0..n_samples {
150            let idx = i * n_samples + j;
151            out[idx] = if lib_sizes[j] > 0.0 {
152                counts[idx] * 1e9 / (lib_sizes[j] * gene_lengths[i])
153            } else {
154                0.0
155            };
156        }
157    }
158    Ok(out)
159}
160
161// ── Size factors (median-of-ratios) ──────────────────────────────────────────
162
163/// DESeq2-style size factors via the median-of-ratios method (Anders & Huber 2010).
164///
165/// 1. Compute the geometric mean of each gene across samples.
166/// 2. For each sample, compute the ratio `count / geometric_mean` for every gene.
167/// 3. The size factor for a sample is the median of those ratios.
168///
169/// Genes with any zero count are excluded from the geometric mean calculation.
170/// Returns one size factor per sample.
171pub fn size_factors(counts: &[f64], n_genes: usize, n_samples: usize) -> Result<Vec<f64>> {
172    validate_matrix(counts, n_genes, n_samples)?;
173
174    // Compute geometric means, excluding genes with any zero.
175    let mut geo_means = Vec::with_capacity(n_genes);
176    let mut usable_genes = Vec::with_capacity(n_genes);
177
178    for i in 0..n_genes {
179        let row = &counts[i * n_samples..(i + 1) * n_samples];
180        if row.iter().any(|&v| v <= 0.0) {
181            continue;
182        }
183        let log_sum: f64 = row.iter().map(|v| v.ln()).sum();
184        let geo_mean = (log_sum / n_samples as f64).exp();
185        geo_means.push(geo_mean);
186        usable_genes.push(i);
187    }
188
189    if usable_genes.is_empty() {
190        return Err(CyaneaError::InvalidInput(
191            "size_factors: no genes with all non-zero counts".into(),
192        ));
193    }
194
195    // For each sample, compute median of ratios.
196    let mut factors = Vec::with_capacity(n_samples);
197    for j in 0..n_samples {
198        let ratios: Vec<f64> = usable_genes
199            .iter()
200            .zip(geo_means.iter())
201            .map(|(&gene_i, &gm)| counts[gene_i * n_samples + j] / gm)
202            .collect();
203        let med = descriptive::median(&ratios)?;
204        factors.push(med);
205    }
206
207    Ok(factors)
208}
209
210/// Divide each count by the corresponding sample's size factor.
211pub fn normalize_by_size_factors(
212    counts: &[f64],
213    n_genes: usize,
214    n_samples: usize,
215    factors: &[f64],
216) -> Result<Vec<f64>> {
217    validate_matrix(counts, n_genes, n_samples)?;
218    if factors.len() != n_samples {
219        return Err(CyaneaError::InvalidInput(format!(
220            "normalize_by_size_factors: factors length ({}) != n_samples ({})",
221            factors.len(),
222            n_samples,
223        )));
224    }
225    let mut out = vec![0.0; counts.len()];
226    for i in 0..n_genes {
227        for j in 0..n_samples {
228            let idx = i * n_samples + j;
229            out[idx] = counts[idx] / factors[j];
230        }
231    }
232    Ok(out)
233}
234
235// ── Tests ────────────────────────────────────────────────────────────────────
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for results expected to be exact up to rounding.
    const TOL: f64 = 1e-6;

    /// Sums column `j` of a row-major matrix that has `n_cols` columns.
    fn column_total(matrix: &[f64], n_cols: usize, j: usize) -> f64 {
        matrix.iter().skip(j).step_by(n_cols).sum()
    }

    #[test]
    fn cpm_column_sums_to_1m() {
        // 2 genes x 3 samples; raw library sizes are [50, 70, 90].
        let raw = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0];
        let scaled = cpm(&raw, 2, 3).unwrap();
        for j in 0..3 {
            let col_sum = column_total(&scaled, 3, j);
            assert!((col_sum - 1e6).abs() < 1.0, "col {j} sum={col_sum}");
        }
    }

    #[test]
    fn tpm_column_sums_to_1m() {
        let raw = [100.0, 200.0, 300.0, 400.0];
        let gene_lengths = [1000.0, 2000.0];
        let scaled = tpm(&raw, 2, 2, &gene_lengths).unwrap();
        for j in 0..2 {
            let col_sum = column_total(&scaled, 2, j);
            assert!((col_sum - 1e6).abs() < 1.0, "col {j} sum={col_sum}");
        }
    }

    #[test]
    fn tpm_length_normalization() {
        // Equal counts, unequal lengths: the shorter gene must score higher.
        let raw = [100.0, 100.0];
        let gene_lengths = [500.0, 2000.0];
        let scaled = tpm(&raw, 2, 1, &gene_lengths).unwrap();
        let (short_gene, long_gene) = (scaled[0], scaled[1]);
        assert!(
            short_gene > long_gene,
            "shorter gene should have higher TPM"
        );
    }

    #[test]
    fn fpkm_known_values() {
        // Single gene, single sample: count = library size = 100, length = 1000,
        // so FPKM = 100 * 1e9 / (100 * 1000) = 1_000_000.
        let scaled = fpkm(&[100.0], 1, 1, &[1000.0]).unwrap();
        let deviation = (scaled[0] - 1_000_000.0).abs();
        assert!(deviation < TOL);
    }

    #[test]
    fn fpkm_to_tpm_relationship() {
        // TPM_i = FPKM_i / sum(FPKM_j) * 1e6
        let raw = [100.0, 200.0, 50.0, 300.0];
        let gene_lengths = [1000.0, 2000.0];
        let fpkm_vals = fpkm(&raw, 2, 2, &gene_lengths).unwrap();
        let tpm_vals = tpm(&raw, 2, 2, &gene_lengths).unwrap();
        for j in 0..2 {
            let fpkm_sum = column_total(&fpkm_vals, 2, j);
            for i in 0..2 {
                let cell = i * 2 + j;
                let tpm_from_fpkm = fpkm_vals[cell] / fpkm_sum * 1e6;
                assert!(
                    (tpm_from_fpkm - tpm_vals[cell]).abs() < 1.0,
                    "gene {i} sample {j}: tpm_from_fpkm={tpm_from_fpkm}, tpm={}",
                    tpm_vals[cell]
                );
            }
        }
    }

    #[test]
    fn size_factors_equal_libraries() {
        // Both samples are identical, so each size factor should be ~1.
        let raw = [10.0, 10.0, 20.0, 20.0, 30.0, 30.0];
        let sf = size_factors(&raw, 3, 2).unwrap();
        assert!((sf[0] - 1.0).abs() < TOL);
        assert!((sf[1] - 1.0).abs() < TOL);
    }

    #[test]
    fn size_factors_doubled_library() {
        // Sample 1 carries twice the counts of sample 0 → factor ratio ~2.
        let raw = [10.0, 20.0, 20.0, 40.0, 30.0, 60.0];
        let sf = size_factors(&raw, 3, 2).unwrap();
        let ratio = sf[1] / sf[0];
        assert!((ratio - 2.0).abs() < TOL, "ratio={ratio}");
    }

    #[test]
    fn size_factors_skip_zeros() {
        // Gene 0 has a zero in sample 0, so it is left out of the calculation.
        let raw = [0.0, 10.0, 20.0, 20.0, 30.0, 30.0];
        let sf = size_factors(&raw, 3, 2).unwrap();
        // The remaining genes 1 and 2 are equal across samples → factors ~1.
        assert!((sf[0] - 1.0).abs() < TOL);
        assert!((sf[1] - 1.0).abs() < TOL);
    }

    #[test]
    fn normalize_roundtrip() {
        let raw = [10.0, 20.0, 30.0, 60.0];
        let sf = size_factors(&raw, 2, 2).unwrap();
        let normed = normalize_by_size_factors(&raw, 2, 2, &sf).unwrap();
        // Size factors recomputed on normalized data should all be ~1.
        let sf2 = size_factors(&normed, 2, 2).unwrap();
        for &s in sf2.iter() {
            assert!((s - 1.0).abs() < 1e-4, "s={s}");
        }
    }

    #[test]
    fn dimension_mismatch() {
        // Two values cannot fill a 3 x 1 matrix.
        assert!(cpm(&[1.0, 2.0], 3, 1).is_err());
        // One gene length supplied for two genes.
        assert!(tpm(&[1.0, 2.0], 2, 1, &[100.0]).is_err());
        // Two gene lengths supplied for one gene.
        assert!(fpkm(&[1.0], 1, 1, &[100.0, 200.0]).is_err());
        // One size factor supplied for two samples.
        assert!(normalize_by_size_factors(&[1.0, 2.0], 1, 2, &[1.0]).is_err());
    }

    #[test]
    fn empty_matrix() {
        let outcome = cpm(&[], 0, 0);
        assert!(outcome.is_err());
    }
}