Skip to main content

cyanea_omics/
expr.rs

1//! Dense expression matrix for bulk omics data.
2//!
3//! [`ExpressionMatrix`] stores a row-major dense matrix of `f64` values
4//! (n_features × n_samples) with associated feature and sample names.
5//! Typical use cases include RNA-seq gene expression, proteomics intensity
6//! values, and metabolomics abundances.
7
8use cyanea_core::{CyaneaError, Result, Summarizable};
9
/// A dense, row-major expression matrix (features × samples).
///
/// Values live in a single flat buffer; the entry for feature `f` and sample
/// `s` is stored at offset `f * n_samples + s`.
///
/// Equality is element-wise over the values, the dimensions and both name
/// lists. Because the values are `f64`, a matrix containing `NaN` never
/// compares equal to anything, itself included.
#[derive(Debug, Clone, PartialEq)]
pub struct ExpressionMatrix {
    /// Flat row-major values, one contiguous run of `n_samples` per feature.
    data: Vec<f64>,
    /// Number of rows (features).
    n_features: usize,
    /// Number of columns (samples).
    n_samples: usize,
    /// One name per feature (row).
    feature_names: Vec<String>,
    /// One name per sample (column).
    sample_names: Vec<String>,
}
19
20impl ExpressionMatrix {
21    /// Create a matrix from row-major 2D data.
22    ///
23    /// Each inner `Vec` is one feature (row) with `n_samples` values.
24    pub fn new(
25        data: Vec<Vec<f64>>,
26        feature_names: Vec<String>,
27        sample_names: Vec<String>,
28    ) -> Result<Self> {
29        let n_features = data.len();
30        let n_samples = sample_names.len();
31
32        if feature_names.len() != n_features {
33            return Err(CyaneaError::InvalidInput(format!(
34                "feature_names length ({}) does not match row count ({n_features})",
35                feature_names.len()
36            )));
37        }
38
39        let mut flat = Vec::with_capacity(n_features * n_samples);
40        for (i, row) in data.iter().enumerate() {
41            if row.len() != n_samples {
42                return Err(CyaneaError::InvalidInput(format!(
43                    "row {i} has {} columns, expected {n_samples}",
44                    row.len()
45                )));
46            }
47            flat.extend_from_slice(row);
48        }
49
50        Ok(Self {
51            data: flat,
52            n_features,
53            n_samples,
54            feature_names,
55            sample_names,
56        })
57    }
58
59    /// Create a zero-filled matrix.
60    pub fn zeros(
61        n_features: usize,
62        n_samples: usize,
63        feature_names: Vec<String>,
64        sample_names: Vec<String>,
65    ) -> Result<Self> {
66        if feature_names.len() != n_features {
67            return Err(CyaneaError::InvalidInput(format!(
68                "feature_names length ({}) does not match n_features ({n_features})",
69                feature_names.len()
70            )));
71        }
72        if sample_names.len() != n_samples {
73            return Err(CyaneaError::InvalidInput(format!(
74                "sample_names length ({}) does not match n_samples ({n_samples})",
75                sample_names.len()
76            )));
77        }
78        Ok(Self {
79            data: vec![0.0; n_features * n_samples],
80            n_features,
81            n_samples,
82            feature_names,
83            sample_names,
84        })
85    }
86
87    /// (n_features, n_samples).
88    pub fn shape(&self) -> (usize, usize) {
89        (self.n_features, self.n_samples)
90    }
91
92    /// Get a single value by feature and sample index.
93    pub fn get(&self, feature_idx: usize, sample_idx: usize) -> Option<f64> {
94        if feature_idx < self.n_features && sample_idx < self.n_samples {
95            Some(self.data[feature_idx * self.n_samples + sample_idx])
96        } else {
97            None
98        }
99    }
100
101    /// Set a single value. Returns an error if indices are out of bounds.
102    pub fn set(&mut self, feature_idx: usize, sample_idx: usize, value: f64) -> Result<()> {
103        if feature_idx >= self.n_features || sample_idx >= self.n_samples {
104            return Err(CyaneaError::InvalidInput(format!(
105                "index ({feature_idx}, {sample_idx}) out of bounds for ({}, {})",
106                self.n_features, self.n_samples
107            )));
108        }
109        self.data[feature_idx * self.n_samples + sample_idx] = value;
110        Ok(())
111    }
112
113    /// A slice of one feature's expression across all samples.
114    pub fn row(&self, feature_idx: usize) -> Option<&[f64]> {
115        if feature_idx < self.n_features {
116            let start = feature_idx * self.n_samples;
117            Some(&self.data[start..start + self.n_samples])
118        } else {
119            None
120        }
121    }
122
123    /// All feature values for a single sample (column copy, since data is row-major).
124    pub fn column(&self, sample_idx: usize) -> Option<Vec<f64>> {
125        if sample_idx >= self.n_samples {
126            return None;
127        }
128        let col: Vec<f64> = (0..self.n_features)
129            .map(|r| self.data[r * self.n_samples + sample_idx])
130            .collect();
131        Some(col)
132    }
133
134    /// Mean expression of a feature across all samples.
135    pub fn row_mean(&self, feature_idx: usize) -> Option<f64> {
136        let row = self.row(feature_idx)?;
137        if row.is_empty() {
138            return Some(0.0);
139        }
140        Some(row.iter().sum::<f64>() / row.len() as f64)
141    }
142
143    /// Mean expression across all features for a sample.
144    pub fn column_mean(&self, sample_idx: usize) -> Option<f64> {
145        let col = self.column(sample_idx)?;
146        if col.is_empty() {
147            return Some(0.0);
148        }
149        Some(col.iter().sum::<f64>() / col.len() as f64)
150    }
151
152    /// Transpose the matrix, swapping features and samples.
153    pub fn transpose(&self) -> ExpressionMatrix {
154        let mut transposed = vec![0.0; self.data.len()];
155        for r in 0..self.n_features {
156            for c in 0..self.n_samples {
157                transposed[c * self.n_features + r] = self.data[r * self.n_samples + c];
158            }
159        }
160        ExpressionMatrix {
161            data: transposed,
162            n_features: self.n_samples,
163            n_samples: self.n_features,
164            feature_names: self.sample_names.clone(),
165            sample_names: self.feature_names.clone(),
166        }
167    }
168
169    /// Subset the matrix to the given feature (row) indices.
170    pub fn filter_features(&self, indices: &[usize]) -> Result<ExpressionMatrix> {
171        let mut data = Vec::with_capacity(indices.len() * self.n_samples);
172        let mut names = Vec::with_capacity(indices.len());
173
174        for &i in indices {
175            if i >= self.n_features {
176                return Err(CyaneaError::InvalidInput(format!(
177                    "feature index {i} out of bounds (n_features={})",
178                    self.n_features
179                )));
180            }
181            let start = i * self.n_samples;
182            data.extend_from_slice(&self.data[start..start + self.n_samples]);
183            names.push(self.feature_names[i].clone());
184        }
185
186        Ok(ExpressionMatrix {
187            data,
188            n_features: indices.len(),
189            n_samples: self.n_samples,
190            feature_names: names,
191            sample_names: self.sample_names.clone(),
192        })
193    }
194
195    /// Subset the matrix to the given sample (column) indices.
196    pub fn filter_samples(&self, indices: &[usize]) -> Result<ExpressionMatrix> {
197        for &i in indices {
198            if i >= self.n_samples {
199                return Err(CyaneaError::InvalidInput(format!(
200                    "sample index {i} out of bounds (n_samples={})",
201                    self.n_samples
202                )));
203            }
204        }
205
206        let mut data = Vec::with_capacity(self.n_features * indices.len());
207        let mut names = Vec::with_capacity(indices.len());
208
209        for &i in indices {
210            names.push(self.sample_names[i].clone());
211        }
212
213        for r in 0..self.n_features {
214            for &c in indices {
215                data.push(self.data[r * self.n_samples + c]);
216            }
217        }
218
219        Ok(ExpressionMatrix {
220            data,
221            n_features: self.n_features,
222            n_samples: indices.len(),
223            feature_names: self.feature_names.clone(),
224            sample_names: names,
225        })
226    }
227
228    /// The underlying flat data as a slice (row-major, n_features × n_samples).
229    pub fn as_slice(&self) -> &[f64] {
230        &self.data
231    }
232
233    /// Feature (gene/protein) names.
234    pub fn feature_names(&self) -> &[String] {
235        &self.feature_names
236    }
237
238    /// Sample names.
239    pub fn sample_names(&self) -> &[String] {
240        &self.sample_names
241    }
242
243    /// Log2-transform all values: `log2(x + pseudocount)`.
244    ///
245    /// Commonly used to normalize RNA-seq count data.
246    pub fn log_transform(&self, pseudocount: f64) -> ExpressionMatrix {
247        let data: Vec<f64> = self
248            .data
249            .iter()
250            .map(|&x| (x + pseudocount).log2())
251            .collect();
252        ExpressionMatrix {
253            data,
254            n_features: self.n_features,
255            n_samples: self.n_samples,
256            feature_names: self.feature_names.clone(),
257            sample_names: self.sample_names.clone(),
258        }
259    }
260}
261
262impl Summarizable for ExpressionMatrix {
263    fn summary(&self) -> String {
264        format!(
265            "ExpressionMatrix: {} features \u{00d7} {} samples",
266            self.n_features, self.n_samples
267        )
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// 2 features × 3 samples fixture:
    ///
    /// ```text
    ///         s1   s2   s3
    /// gene1  1.0  2.0  3.0
    /// gene2  4.0  5.0  6.0
    /// ```
    fn sample_matrix() -> ExpressionMatrix {
        ExpressionMatrix::new(
            vec![
                vec![1.0, 2.0, 3.0],
                vec![4.0, 5.0, 6.0],
            ],
            vec!["gene1".into(), "gene2".into()],
            vec!["s1".into(), "s2".into(), "s3".into()],
        )
        .unwrap()
    }

    #[test]
    fn test_construction() {
        let m = sample_matrix();
        assert_eq!(m.shape(), (2, 3));
    }

    #[test]
    fn test_dimension_mismatch() {
        let result = ExpressionMatrix::new(
            vec![vec![1.0, 2.0]],
            vec!["gene1".into(), "gene2".into()], // 2 names, 1 row
            vec!["s1".into(), "s2".into()],
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_row_length_mismatch() {
        let result = ExpressionMatrix::new(
            vec![vec![1.0, 2.0], vec![3.0]], // second row too short
            vec!["gene1".into(), "gene2".into()],
            vec!["s1".into(), "s2".into()],
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_zeros() {
        let m = ExpressionMatrix::zeros(
            2,
            3,
            vec!["a".into(), "b".into()],
            vec!["x".into(), "y".into(), "z".into()],
        )
        .unwrap();
        assert_eq!(m.shape(), (2, 3));
        assert_eq!(m.get(0, 0), Some(0.0));
        assert_eq!(m.get(1, 2), Some(0.0));
        assert_eq!(m.as_slice(), &[0.0; 6]);
    }

    #[test]
    fn test_zeros_name_mismatch() {
        // One feature name for two requested features.
        let too_few_features = ExpressionMatrix::zeros(
            2,
            3,
            vec!["a".into()],
            vec!["x".into(), "y".into(), "z".into()],
        );
        assert!(too_few_features.is_err());

        // Two sample names for three requested samples.
        let too_few_samples = ExpressionMatrix::zeros(
            2,
            3,
            vec!["a".into(), "b".into()],
            vec!["x".into(), "y".into()],
        );
        assert!(too_few_samples.is_err());
    }

    #[test]
    fn test_get_set() {
        let mut m = sample_matrix();
        assert_eq!(m.get(0, 0), Some(1.0));
        assert_eq!(m.get(1, 2), Some(6.0));
        assert_eq!(m.get(2, 0), None);
        assert_eq!(m.get(0, 3), None); // sample index out of range

        m.set(0, 0, 99.0).unwrap();
        assert_eq!(m.get(0, 0), Some(99.0));
        assert!(m.set(5, 0, 1.0).is_err());
        assert!(m.set(0, 3, 1.0).is_err()); // sample index out of range

        // Failed writes must leave the data untouched.
        assert_eq!(m.as_slice(), &[99.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_row() {
        let m = sample_matrix();
        assert_eq!(m.row(0), Some(&[1.0, 2.0, 3.0][..]));
        assert_eq!(m.row(1), Some(&[4.0, 5.0, 6.0][..]));
        assert_eq!(m.row(2), None);
    }

    #[test]
    fn test_column() {
        let m = sample_matrix();
        assert_eq!(m.column(0), Some(vec![1.0, 4.0]));
        assert_eq!(m.column(2), Some(vec![3.0, 6.0]));
        assert_eq!(m.column(3), None);
    }

    #[test]
    fn test_row_mean() {
        let m = sample_matrix();
        assert_eq!(m.row_mean(0), Some(2.0)); // (1+2+3)/3
        assert_eq!(m.row_mean(1), Some(5.0)); // (4+5+6)/3
        assert_eq!(m.row_mean(2), None);
    }

    #[test]
    fn test_column_mean() {
        let m = sample_matrix();
        assert_eq!(m.column_mean(0), Some(2.5)); // (1+4)/2
        assert_eq!(m.column_mean(1), Some(3.5)); // (2+5)/2
        assert_eq!(m.column_mean(3), None);
    }

    #[test]
    fn test_transpose() {
        let m = sample_matrix();
        let t = m.transpose();
        assert_eq!(t.shape(), (3, 2));
        assert_eq!(t.get(0, 0), Some(1.0));
        assert_eq!(t.get(0, 1), Some(4.0));
        assert_eq!(t.get(2, 1), Some(6.0));

        // Names travel with their axis.
        assert_eq!(t.feature_names(), &["s1", "s2", "s3"]);
        assert_eq!(t.sample_names(), &["gene1", "gene2"]);

        // Transposing twice gives back the original layout.
        let back = t.transpose();
        assert_eq!(back.shape(), m.shape());
        assert_eq!(back.as_slice(), m.as_slice());
    }

    #[test]
    fn test_filter_features() {
        let m = sample_matrix();
        let filtered = m.filter_features(&[1]).unwrap();
        assert_eq!(filtered.shape(), (1, 3));
        assert_eq!(filtered.get(0, 0), Some(4.0));
        assert_eq!(filtered.feature_names(), &["gene2"]);
        assert_eq!(filtered.sample_names(), &["s1", "s2", "s3"]);

        // The requested order is honoured, so this reverses the rows.
        let reordered = m.filter_features(&[1, 0]).unwrap();
        assert_eq!(reordered.as_slice(), &[4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
        assert_eq!(reordered.feature_names(), &["gene2", "gene1"]);

        assert!(m.filter_features(&[5]).is_err());
    }

    #[test]
    fn test_filter_samples() {
        let m = sample_matrix();
        let filtered = m.filter_samples(&[0, 2]).unwrap();
        assert_eq!(filtered.shape(), (2, 2));
        assert_eq!(filtered.get(0, 0), Some(1.0));
        assert_eq!(filtered.get(0, 1), Some(3.0));
        assert_eq!(filtered.get(1, 0), Some(4.0));
        assert_eq!(filtered.get(1, 1), Some(6.0));
        assert_eq!(filtered.sample_names(), &["s1", "s3"]);
        assert_eq!(filtered.feature_names(), &["gene1", "gene2"]);

        assert!(m.filter_samples(&[5]).is_err());
    }

    #[test]
    fn test_log_transform() {
        let m = sample_matrix();
        let logged = m.log_transform(1.0);
        // log2(1.0 + 1.0) = 1.0
        assert!((logged.get(0, 0).unwrap() - 1.0).abs() < 1e-10);
        // log2(4.0 + 1.0) = log2(5) ≈ 2.3219
        assert!((logged.get(1, 0).unwrap() - 5.0_f64.log2()).abs() < 1e-10);

        // Shape and names are carried over; the source matrix is not modified.
        assert_eq!(logged.shape(), m.shape());
        assert_eq!(logged.feature_names(), m.feature_names());
        assert_eq!(logged.sample_names(), m.sample_names());
        assert_eq!(m.get(0, 0), Some(1.0));
    }

    #[test]
    fn test_summary() {
        let m = sample_matrix();
        assert_eq!(m.summary(), "ExpressionMatrix: 2 features \u{00d7} 3 samples");
    }

    #[test]
    fn test_empty_matrix() {
        let m = ExpressionMatrix::new(
            vec![],
            vec![],
            vec!["s1".into()],
        )
        .unwrap();
        assert_eq!(m.shape(), (0, 1));

        // No features: rows do not exist, columns exist but are empty.
        assert_eq!(m.row(0), None);
        assert_eq!(m.column(0), Some(Vec::new()));
        assert_eq!(m.column_mean(0), Some(0.0));
    }

    #[test]
    fn test_zero_samples() {
        let m = ExpressionMatrix::new(vec![vec![]], vec!["gene1".into()], vec![]).unwrap();
        assert_eq!(m.shape(), (1, 0));

        // No samples: the row exists but is empty, and its mean falls back to 0.0.
        let empty: &[f64] = &[];
        assert_eq!(m.row(0), Some(empty));
        assert_eq!(m.row_mean(0), Some(0.0));
        assert_eq!(m.column(0), None);
    }

    #[test]
    fn test_as_slice() {
        let m = sample_matrix();
        assert_eq!(m.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_feature_names() {
        let m = sample_matrix();
        assert_eq!(m.feature_names(), &["gene1", "gene2"]);
    }

    #[test]
    fn test_sample_names() {
        let m = sample_matrix();
        assert_eq!(m.sample_names(), &["s1", "s2", "s3"]);
    }
}