transcriptomic_rs/
normalization.rs1use arrow::{
8 array::{Array, Float64Array},
9 datatypes::{DataType, Field, Schema},
10 record_batch::RecordBatch,
11};
12
13use crate::{Error, ExpressionMatrix, Result};
14
/// Stateless namespace for expression-matrix normalization routines.
///
/// The type holds no data; the normalizations are exposed as associated
/// functions that take an `ExpressionMatrix` by reference and return a new,
/// transformed matrix.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Normalize;
21
22impl Normalize {
23 pub fn log2(matrix: &ExpressionMatrix) -> Result<ExpressionMatrix> {
33 let mut columns: Vec<std::sync::Arc<dyn Array>> = Vec::with_capacity(matrix.samples.len());
34
35 for col_idx in 0..matrix.values.num_columns() {
36 let col = matrix.values.column(col_idx);
37 let array = col
38 .as_any()
39 .downcast_ref::<Float64Array>()
40 .ok_or_else(|| Error::Normalization("Expected Float64Array".to_string()))?;
41
42 let transformed: Vec<Option<f64>> = (0..array.len())
43 .map(|i| {
44 if array.is_null(i) {
45 None
46 } else {
47 let x = array.value(i);
48 Some((x + 1.0).log2())
49 }
50 })
51 .collect();
52
53 columns.push(std::sync::Arc::new(Float64Array::from(transformed)));
54 }
55
56 let schema = Schema::new(
57 matrix
58 .samples
59 .iter()
60 .map(|s| Field::new(s.clone(), DataType::Float64, true))
61 .collect::<Vec<_>>(),
62 );
63
64 let batch = RecordBatch::try_new(std::sync::Arc::new(schema), columns)?;
65
66 Ok(ExpressionMatrix {
67 genes: matrix.genes.clone(),
68 samples: matrix.samples.clone(),
69 values: batch,
70 })
71 }
72
73 pub fn quantile(matrix: &ExpressionMatrix) -> Result<ExpressionMatrix> {
91 let num_genes = matrix.genes.len();
92 let num_samples = matrix.samples.len();
93
94 if num_genes == 0 || num_samples == 0 {
95 return Ok(matrix.clone());
96 }
97
98 let mut sample_values: Vec<Vec<Option<f64>>> = Vec::with_capacity(num_samples);
100 for col_idx in 0..num_samples {
101 let col = matrix.values.column(col_idx);
102 let array = col
103 .as_any()
104 .downcast_ref::<Float64Array>()
105 .ok_or_else(|| Error::Normalization("Expected Float64Array".to_string()))?;
106
107 let values: Vec<Option<f64>> = (0..num_genes)
108 .map(|i| {
109 if array.is_null(i) {
110 None
111 } else {
112 Some(array.value(i))
113 }
114 })
115 .collect();
116 sample_values.push(values);
117 }
118
119 let mut target_distribution: Vec<f64> = Vec::with_capacity(num_genes);
122
123 let mut sorted_per_sample: Vec<Vec<f64>> = Vec::with_capacity(num_samples);
125 for values in &sample_values {
126 let mut sorted: Vec<f64> = values.iter().flatten().copied().collect();
127 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
128 sorted_per_sample.push(sorted);
129 }
130
131 let max_sorted_len = sorted_per_sample.iter().map(Vec::len).max().unwrap_or(0);
133
134 for rank in 0..max_sorted_len {
136 let mut sum = 0.0;
137 let mut count = 0;
138 for sorted in &sorted_per_sample {
139 if let Some(&val) = sorted.get(rank) {
140 sum += val;
141 count += 1;
142 }
143 }
144 if count > 0 {
145 target_distribution.push(sum / f64::from(count));
146 }
147 }
148
149 let mut normalized_columns: Vec<std::sync::Arc<dyn Array>> =
151 Vec::with_capacity(num_samples);
152
153 for (sample_idx, values) in sample_values.iter().enumerate() {
154 let sorted = &sorted_per_sample[sample_idx];
155
156 let normalized: Vec<Option<f64>> = values
157 .iter()
158 .map(|&opt_val| {
159 if let Some(val) = opt_val {
160 if let Ok(pos) = sorted.binary_search_by(|probe| {
162 probe.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Equal)
163 }) {
164 let mut start = pos;
166 let mut end = pos;
167 while start > 0 && sorted.get(start - 1) == Some(&val) {
168 start -= 1;
169 }
170 while end + 1 < sorted.len() && sorted.get(end + 1) == Some(&val) {
171 end += 1;
172 }
173 let rank = usize::midpoint(start, end);
174 target_distribution.get(rank).copied()
175 } else {
176 None
178 }
179 } else {
180 None
181 }
182 })
183 .collect();
184
185 normalized_columns.push(std::sync::Arc::new(Float64Array::from(normalized)));
186 }
187
188 let schema = Schema::new(
189 matrix
190 .samples
191 .iter()
192 .map(|s| Field::new(s.clone(), DataType::Float64, true))
193 .collect::<Vec<_>>(),
194 );
195
196 let batch = RecordBatch::try_new(std::sync::Arc::new(schema), normalized_columns)?;
197
198 Ok(ExpressionMatrix {
199 genes: matrix.genes.clone(),
200 samples: matrix.samples.clone(),
201 values: batch,
202 })
203 }
204
205 pub fn z_score_per_gene(matrix: &ExpressionMatrix) -> Result<ExpressionMatrix> {
215 let num_genes = matrix.genes.len();
216 let num_samples = matrix.samples.len();
217
218 if num_genes == 0 || num_samples == 0 {
219 return Ok(matrix.clone());
220 }
221
222 let mut gene_values: Vec<Vec<Option<f64>>> =
224 vec![Vec::with_capacity(num_samples); num_genes];
225
226 for col_idx in 0..num_samples {
227 let col = matrix.values.column(col_idx);
228 let array = col
229 .as_any()
230 .downcast_ref::<Float64Array>()
231 .ok_or_else(|| Error::Normalization("Expected Float64Array".to_string()))?;
232
233 for (gene_idx, opt_val) in (0..num_genes)
234 .map(|i| {
235 if array.is_null(i) {
236 None
237 } else {
238 Some(array.value(i))
239 }
240 })
241 .enumerate()
242 {
243 gene_values[gene_idx].push(opt_val);
244 }
245 }
246
247 let mut z_score_columns: Vec<Vec<Option<f64>>> =
249 vec![Vec::with_capacity(num_genes); num_samples];
250
251 for gene_row in &gene_values {
252 let non_null_values: Vec<f64> = gene_row.iter().flatten().copied().collect();
253
254 if non_null_values.len() < 2 {
255 for (col_idx, &orig) in gene_row.iter().enumerate() {
257 z_score_columns[col_idx].push(orig);
258 }
259 continue;
260 }
261
262 #[allow(clippy::cast_precision_loss)]
263 let n = non_null_values.len() as f64;
264 let mean = non_null_values.iter().sum::<f64>() / n;
265 let variance = non_null_values
266 .iter()
267 .map(|&x| (x - mean).powi(2))
268 .sum::<f64>()
269 / n;
270 let std = variance.sqrt();
271
272 if std < f64::EPSILON {
273 for (col_idx, &orig) in gene_row.iter().enumerate() {
275 z_score_columns[col_idx].push(orig);
276 }
277 } else {
278 for (col_idx, &opt_val) in gene_row.iter().enumerate() {
279 let z = opt_val.map(|v| (v - mean) / std);
280 z_score_columns[col_idx].push(z);
281 }
282 }
283 }
284
285 let mut columns: Vec<std::sync::Arc<dyn Array>> = Vec::with_capacity(num_samples);
287 for col_values in z_score_columns {
288 columns.push(std::sync::Arc::new(Float64Array::from(col_values)));
289 }
290
291 let schema = Schema::new(
292 matrix
293 .samples
294 .iter()
295 .map(|s| Field::new(s.clone(), DataType::Float64, true))
296 .collect::<Vec<_>>(),
297 );
298
299 let batch = RecordBatch::try_new(std::sync::Arc::new(schema), columns)?;
300
301 Ok(ExpressionMatrix {
302 genes: matrix.genes.clone(),
303 samples: matrix.samples.clone(),
304 values: batch,
305 })
306 }
307}