Skip to main content

transcriptomic_rs/
normalization.rs

//! Normalization methods for expression matrices
//!
//! All normalization methods are **explicit and composable**.
//! They take a reference to an `ExpressionMatrix` and return a new
//! `ExpressionMatrix` with transformed values. No hidden defaults are applied.
6
7use arrow::{
8    array::{Array, Float64Array},
9    datatypes::{DataType, Field, Schema},
10    record_batch::RecordBatch,
11};
12
13use crate::{Error, ExpressionMatrix, Result};
14
/// Namespace for the normalization methods.
///
/// `Normalize` carries no state; it only groups the associated functions
/// (`Normalize::log2`, `Normalize::quantile`, `Normalize::z_score_per_gene`).
/// Each of them borrows an `ExpressionMatrix` and returns a freshly built
/// one, leaving the input untouched.
#[derive(Debug, Clone, Copy, Default)]
pub struct Normalize;
21
22impl Normalize {
23    /// Log2 transformation: log2(x+1)
24    ///
25    /// Applies `log2(x + 1)` to all non-null expression values.
26    /// This transformation compresses the dynamic range and handles
27    /// zero values gracefully (log2(0+1) = 0).
28    ///
29    /// # Errors
30    ///
31    /// Returns an error if Arrow data construction fails.
32    pub fn log2(matrix: &ExpressionMatrix) -> Result<ExpressionMatrix> {
33        let mut columns: Vec<std::sync::Arc<dyn Array>> = Vec::with_capacity(matrix.samples.len());
34
35        for col_idx in 0..matrix.values.num_columns() {
36            let col = matrix.values.column(col_idx);
37            let array = col
38                .as_any()
39                .downcast_ref::<Float64Array>()
40                .ok_or_else(|| Error::Normalization("Expected Float64Array".to_string()))?;
41
42            let transformed: Vec<Option<f64>> = (0..array.len())
43                .map(|i| {
44                    if array.is_null(i) {
45                        None
46                    } else {
47                        let x = array.value(i);
48                        Some((x + 1.0).log2())
49                    }
50                })
51                .collect();
52
53            columns.push(std::sync::Arc::new(Float64Array::from(transformed)));
54        }
55
56        let schema = Schema::new(
57            matrix
58                .samples
59                .iter()
60                .map(|s| Field::new(s.clone(), DataType::Float64, true))
61                .collect::<Vec<_>>(),
62        );
63
64        let batch = RecordBatch::try_new(std::sync::Arc::new(schema), columns)?;
65
66        Ok(ExpressionMatrix {
67            genes: matrix.genes.clone(),
68            samples: matrix.samples.clone(),
69            values: batch,
70        })
71    }
72
73    /// Quantile normalization across samples
74    ///
75    /// Normalizes the distribution of expression values across all samples
76    /// to have the same distribution (the average quantiles across samples).
77    /// This ensures that differences in expression are due to biology, not
78    /// technical variation.
79    ///
80    /// # Algorithm
81    ///
82    /// 1. Sort values within each sample (column) and compute mean ranks
83    /// 2. Replace each value with the mean of values at that rank across
84    ///    samples
85    /// 3. Unsort to restore original gene order
86    ///
87    /// # Errors
88    ///
89    /// Returns an error if Arrow data construction fails.
90    pub fn quantile(matrix: &ExpressionMatrix) -> Result<ExpressionMatrix> {
91        let num_genes = matrix.genes.len();
92        let num_samples = matrix.samples.len();
93
94        if num_genes == 0 || num_samples == 0 {
95            return Ok(matrix.clone());
96        }
97
98        // Collect all columns into Vec<Vec<Option<f64>>>
99        let mut sample_values: Vec<Vec<Option<f64>>> = Vec::with_capacity(num_samples);
100        for col_idx in 0..num_samples {
101            let col = matrix.values.column(col_idx);
102            let array = col
103                .as_any()
104                .downcast_ref::<Float64Array>()
105                .ok_or_else(|| Error::Normalization("Expected Float64Array".to_string()))?;
106
107            let values: Vec<Option<f64>> = (0..num_genes)
108                .map(|i| {
109                    if array.is_null(i) {
110                        None
111                    } else {
112                        Some(array.value(i))
113                    }
114                })
115                .collect();
116            sample_values.push(values);
117        }
118
119        // For each gene position, collect all non-null values across samples
120        // Sort them and compute target distribution (mean at each rank)
121        let mut target_distribution: Vec<f64> = Vec::with_capacity(num_genes);
122
123        // Create sorted non-null values per sample
124        let mut sorted_per_sample: Vec<Vec<f64>> = Vec::with_capacity(num_samples);
125        for values in &sample_values {
126            let mut sorted: Vec<f64> = values.iter().flatten().copied().collect();
127            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
128            sorted_per_sample.push(sorted);
129        }
130
131        // Compute max length of sorted arrays
132        let max_sorted_len = sorted_per_sample.iter().map(Vec::len).max().unwrap_or(0);
133
134        // For each rank, compute mean across samples that have a value at that rank
135        for rank in 0..max_sorted_len {
136            let mut sum = 0.0;
137            let mut count = 0;
138            for sorted in &sorted_per_sample {
139                if let Some(&val) = sorted.get(rank) {
140                    sum += val;
141                    count += 1;
142                }
143            }
144            if count > 0 {
145                target_distribution.push(sum / f64::from(count));
146            }
147        }
148
149        // For each sample, assign quantile-normalized values
150        let mut normalized_columns: Vec<std::sync::Arc<dyn Array>> =
151            Vec::with_capacity(num_samples);
152
153        for (sample_idx, values) in sample_values.iter().enumerate() {
154            let sorted = &sorted_per_sample[sample_idx];
155
156            let normalized: Vec<Option<f64>> = values
157                .iter()
158                .map(|&opt_val| {
159                    if let Some(val) = opt_val {
160                        // Find rank of this value in the sorted array
161                        if let Ok(pos) = sorted.binary_search_by(|probe| {
162                            probe.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Equal)
163                        }) {
164                            // Handle ties: find middle rank
165                            let mut start = pos;
166                            let mut end = pos;
167                            while start > 0 && sorted.get(start - 1) == Some(&val) {
168                                start -= 1;
169                            }
170                            while end + 1 < sorted.len() && sorted.get(end + 1) == Some(&val) {
171                                end += 1;
172                            }
173                            let rank = usize::midpoint(start, end);
174                            target_distribution.get(rank).copied()
175                        } else {
176                            // Should not happen if value is from the sorted list
177                            None
178                        }
179                    } else {
180                        None
181                    }
182                })
183                .collect();
184
185            normalized_columns.push(std::sync::Arc::new(Float64Array::from(normalized)));
186        }
187
188        let schema = Schema::new(
189            matrix
190                .samples
191                .iter()
192                .map(|s| Field::new(s.clone(), DataType::Float64, true))
193                .collect::<Vec<_>>(),
194        );
195
196        let batch = RecordBatch::try_new(std::sync::Arc::new(schema), normalized_columns)?;
197
198        Ok(ExpressionMatrix {
199            genes: matrix.genes.clone(),
200            samples: matrix.samples.clone(),
201            values: batch,
202        })
203    }
204
205    /// Z-score normalization per gene (row-wise)
206    ///
207    /// For each gene (row), computes: `(x - mean) / std`
208    /// where mean and std are calculated across all samples for that gene.
209    /// Genes with zero variance (std = 0) are left unchanged.
210    ///
211    /// # Errors
212    ///
213    /// Returns an error if Arrow data construction fails.
214    pub fn z_score_per_gene(matrix: &ExpressionMatrix) -> Result<ExpressionMatrix> {
215        let num_genes = matrix.genes.len();
216        let num_samples = matrix.samples.len();
217
218        if num_genes == 0 || num_samples == 0 {
219            return Ok(matrix.clone());
220        }
221
222        // Collect values per gene (row-wise)
223        let mut gene_values: Vec<Vec<Option<f64>>> =
224            vec![Vec::with_capacity(num_samples); num_genes];
225
226        for col_idx in 0..num_samples {
227            let col = matrix.values.column(col_idx);
228            let array = col
229                .as_any()
230                .downcast_ref::<Float64Array>()
231                .ok_or_else(|| Error::Normalization("Expected Float64Array".to_string()))?;
232
233            for (gene_idx, opt_val) in (0..num_genes)
234                .map(|i| {
235                    if array.is_null(i) {
236                        None
237                    } else {
238                        Some(array.value(i))
239                    }
240                })
241                .enumerate()
242            {
243                gene_values[gene_idx].push(opt_val);
244            }
245        }
246
247        // Compute z-scores per gene
248        let mut z_score_columns: Vec<Vec<Option<f64>>> =
249            vec![Vec::with_capacity(num_genes); num_samples];
250
251        for gene_row in &gene_values {
252            let non_null_values: Vec<f64> = gene_row.iter().flatten().copied().collect();
253
254            if non_null_values.len() < 2 {
255                // Not enough values for z-score, keep original
256                for (col_idx, &orig) in gene_row.iter().enumerate() {
257                    z_score_columns[col_idx].push(orig);
258                }
259                continue;
260            }
261
262            #[allow(clippy::cast_precision_loss)]
263            let n = non_null_values.len() as f64;
264            let mean = non_null_values.iter().sum::<f64>() / n;
265            let variance = non_null_values
266                .iter()
267                .map(|&x| (x - mean).powi(2))
268                .sum::<f64>()
269                / n;
270            let std = variance.sqrt();
271
272            if std < f64::EPSILON {
273                // Zero variance, keep original values
274                for (col_idx, &orig) in gene_row.iter().enumerate() {
275                    z_score_columns[col_idx].push(orig);
276                }
277            } else {
278                for (col_idx, &opt_val) in gene_row.iter().enumerate() {
279                    let z = opt_val.map(|v| (v - mean) / std);
280                    z_score_columns[col_idx].push(z);
281                }
282            }
283        }
284
285        // Build Arrow arrays
286        let mut columns: Vec<std::sync::Arc<dyn Array>> = Vec::with_capacity(num_samples);
287        for col_values in z_score_columns {
288            columns.push(std::sync::Arc::new(Float64Array::from(col_values)));
289        }
290
291        let schema = Schema::new(
292            matrix
293                .samples
294                .iter()
295                .map(|s| Field::new(s.clone(), DataType::Float64, true))
296                .collect::<Vec<_>>(),
297        );
298
299        let batch = RecordBatch::try_new(std::sync::Arc::new(schema), columns)?;
300
301        Ok(ExpressionMatrix {
302            genes: matrix.genes.clone(),
303            samples: matrix.samples.clone(),
304            values: batch,
305        })
306    }
307}