1use cyanea_core::{CyaneaError, Result, Summarizable};
9
/// A dense features-by-samples matrix of `f64` values with a name for every
/// row (feature) and every column (sample).
///
/// Values are stored in one contiguous row-major buffer: the value for
/// feature `r` and sample `c` lives at `data[r * n_samples + c]`.
///
/// Equality is field-wise; because the values are `f64`, a matrix that
/// contains `NaN` never compares equal to anything, including itself.
#[derive(Debug, Clone, PartialEq)]
pub struct ExpressionMatrix {
    /// Row-major values; its length is `n_features * n_samples`.
    data: Vec<f64>,
    /// Number of rows.
    n_features: usize,
    /// Number of columns.
    n_samples: usize,
    /// One name per row, in row order.
    feature_names: Vec<String>,
    /// One name per column, in column order.
    sample_names: Vec<String>,
}
19
20impl ExpressionMatrix {
21 pub fn new(
25 data: Vec<Vec<f64>>,
26 feature_names: Vec<String>,
27 sample_names: Vec<String>,
28 ) -> Result<Self> {
29 let n_features = data.len();
30 let n_samples = sample_names.len();
31
32 if feature_names.len() != n_features {
33 return Err(CyaneaError::InvalidInput(format!(
34 "feature_names length ({}) does not match row count ({n_features})",
35 feature_names.len()
36 )));
37 }
38
39 let mut flat = Vec::with_capacity(n_features * n_samples);
40 for (i, row) in data.iter().enumerate() {
41 if row.len() != n_samples {
42 return Err(CyaneaError::InvalidInput(format!(
43 "row {i} has {} columns, expected {n_samples}",
44 row.len()
45 )));
46 }
47 flat.extend_from_slice(row);
48 }
49
50 Ok(Self {
51 data: flat,
52 n_features,
53 n_samples,
54 feature_names,
55 sample_names,
56 })
57 }
58
59 pub fn zeros(
61 n_features: usize,
62 n_samples: usize,
63 feature_names: Vec<String>,
64 sample_names: Vec<String>,
65 ) -> Result<Self> {
66 if feature_names.len() != n_features {
67 return Err(CyaneaError::InvalidInput(format!(
68 "feature_names length ({}) does not match n_features ({n_features})",
69 feature_names.len()
70 )));
71 }
72 if sample_names.len() != n_samples {
73 return Err(CyaneaError::InvalidInput(format!(
74 "sample_names length ({}) does not match n_samples ({n_samples})",
75 sample_names.len()
76 )));
77 }
78 Ok(Self {
79 data: vec![0.0; n_features * n_samples],
80 n_features,
81 n_samples,
82 feature_names,
83 sample_names,
84 })
85 }
86
87 pub fn shape(&self) -> (usize, usize) {
89 (self.n_features, self.n_samples)
90 }
91
92 pub fn get(&self, feature_idx: usize, sample_idx: usize) -> Option<f64> {
94 if feature_idx < self.n_features && sample_idx < self.n_samples {
95 Some(self.data[feature_idx * self.n_samples + sample_idx])
96 } else {
97 None
98 }
99 }
100
101 pub fn set(&mut self, feature_idx: usize, sample_idx: usize, value: f64) -> Result<()> {
103 if feature_idx >= self.n_features || sample_idx >= self.n_samples {
104 return Err(CyaneaError::InvalidInput(format!(
105 "index ({feature_idx}, {sample_idx}) out of bounds for ({}, {})",
106 self.n_features, self.n_samples
107 )));
108 }
109 self.data[feature_idx * self.n_samples + sample_idx] = value;
110 Ok(())
111 }
112
113 pub fn row(&self, feature_idx: usize) -> Option<&[f64]> {
115 if feature_idx < self.n_features {
116 let start = feature_idx * self.n_samples;
117 Some(&self.data[start..start + self.n_samples])
118 } else {
119 None
120 }
121 }
122
123 pub fn column(&self, sample_idx: usize) -> Option<Vec<f64>> {
125 if sample_idx >= self.n_samples {
126 return None;
127 }
128 let col: Vec<f64> = (0..self.n_features)
129 .map(|r| self.data[r * self.n_samples + sample_idx])
130 .collect();
131 Some(col)
132 }
133
134 pub fn row_mean(&self, feature_idx: usize) -> Option<f64> {
136 let row = self.row(feature_idx)?;
137 if row.is_empty() {
138 return Some(0.0);
139 }
140 Some(row.iter().sum::<f64>() / row.len() as f64)
141 }
142
143 pub fn column_mean(&self, sample_idx: usize) -> Option<f64> {
145 let col = self.column(sample_idx)?;
146 if col.is_empty() {
147 return Some(0.0);
148 }
149 Some(col.iter().sum::<f64>() / col.len() as f64)
150 }
151
152 pub fn transpose(&self) -> ExpressionMatrix {
154 let mut transposed = vec![0.0; self.data.len()];
155 for r in 0..self.n_features {
156 for c in 0..self.n_samples {
157 transposed[c * self.n_features + r] = self.data[r * self.n_samples + c];
158 }
159 }
160 ExpressionMatrix {
161 data: transposed,
162 n_features: self.n_samples,
163 n_samples: self.n_features,
164 feature_names: self.sample_names.clone(),
165 sample_names: self.feature_names.clone(),
166 }
167 }
168
169 pub fn filter_features(&self, indices: &[usize]) -> Result<ExpressionMatrix> {
171 let mut data = Vec::with_capacity(indices.len() * self.n_samples);
172 let mut names = Vec::with_capacity(indices.len());
173
174 for &i in indices {
175 if i >= self.n_features {
176 return Err(CyaneaError::InvalidInput(format!(
177 "feature index {i} out of bounds (n_features={})",
178 self.n_features
179 )));
180 }
181 let start = i * self.n_samples;
182 data.extend_from_slice(&self.data[start..start + self.n_samples]);
183 names.push(self.feature_names[i].clone());
184 }
185
186 Ok(ExpressionMatrix {
187 data,
188 n_features: indices.len(),
189 n_samples: self.n_samples,
190 feature_names: names,
191 sample_names: self.sample_names.clone(),
192 })
193 }
194
195 pub fn filter_samples(&self, indices: &[usize]) -> Result<ExpressionMatrix> {
197 for &i in indices {
198 if i >= self.n_samples {
199 return Err(CyaneaError::InvalidInput(format!(
200 "sample index {i} out of bounds (n_samples={})",
201 self.n_samples
202 )));
203 }
204 }
205
206 let mut data = Vec::with_capacity(self.n_features * indices.len());
207 let mut names = Vec::with_capacity(indices.len());
208
209 for &i in indices {
210 names.push(self.sample_names[i].clone());
211 }
212
213 for r in 0..self.n_features {
214 for &c in indices {
215 data.push(self.data[r * self.n_samples + c]);
216 }
217 }
218
219 Ok(ExpressionMatrix {
220 data,
221 n_features: self.n_features,
222 n_samples: indices.len(),
223 feature_names: self.feature_names.clone(),
224 sample_names: names,
225 })
226 }
227
228 pub fn as_slice(&self) -> &[f64] {
230 &self.data
231 }
232
233 pub fn feature_names(&self) -> &[String] {
235 &self.feature_names
236 }
237
238 pub fn sample_names(&self) -> &[String] {
240 &self.sample_names
241 }
242
243 pub fn log_transform(&self, pseudocount: f64) -> ExpressionMatrix {
247 let data: Vec<f64> = self
248 .data
249 .iter()
250 .map(|&x| (x + pseudocount).log2())
251 .collect();
252 ExpressionMatrix {
253 data,
254 n_features: self.n_features,
255 n_samples: self.n_samples,
256 feature_names: self.feature_names.clone(),
257 sample_names: self.sample_names.clone(),
258 }
259 }
260}
261
262impl Summarizable for ExpressionMatrix {
263 fn summary(&self) -> String {
264 format!(
265 "ExpressionMatrix: {} features \u{00d7} {} samples",
266 self.n_features, self.n_samples
267 )
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned names the constructors take.
    fn names(labels: &[&str]) -> Vec<String> {
        labels.iter().map(|label| label.to_string()).collect()
    }

    /// 2 x 3 fixture: gene1 = [1, 2, 3], gene2 = [4, 5, 6].
    fn sample_matrix() -> ExpressionMatrix {
        let rows = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        ExpressionMatrix::new(rows, names(&["gene1", "gene2"]), names(&["s1", "s2", "s3"]))
            .unwrap()
    }

    #[test]
    fn test_construction() {
        assert_eq!(sample_matrix().shape(), (2, 3));
    }

    #[test]
    fn test_dimension_mismatch() {
        // One row of data but two feature names.
        let one_row = vec![vec![1.0, 2.0]];
        let outcome =
            ExpressionMatrix::new(one_row, names(&["gene1", "gene2"]), names(&["s1", "s2"]));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_row_length_mismatch() {
        // Second row is one value short.
        let ragged = vec![vec![1.0, 2.0], vec![3.0]];
        let outcome =
            ExpressionMatrix::new(ragged, names(&["gene1", "gene2"]), names(&["s1", "s2"]));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_zeros() {
        let m = ExpressionMatrix::zeros(2, 3, names(&["a", "b"]), names(&["x", "y", "z"]))
            .unwrap();
        for (r, c) in [(0, 0), (1, 2)] {
            assert_eq!(m.get(r, c), Some(0.0));
        }
    }

    #[test]
    fn test_get_set() {
        let mut m = sample_matrix();
        let reads = [((0, 0), Some(1.0)), ((1, 2), Some(6.0)), ((2, 0), None)];
        for ((r, c), expected) in reads {
            assert_eq!(m.get(r, c), expected);
        }

        m.set(0, 0, 99.0).unwrap();
        assert_eq!(m.get(0, 0), Some(99.0));
        assert!(m.set(5, 0, 1.0).is_err());
    }

    #[test]
    fn test_row() {
        let m = sample_matrix();
        let first: &[f64] = &[1.0, 2.0, 3.0];
        let second: &[f64] = &[4.0, 5.0, 6.0];
        assert_eq!(m.row(0), Some(first));
        assert_eq!(m.row(1), Some(second));
        assert_eq!(m.row(2), None);
    }

    #[test]
    fn test_column() {
        let m = sample_matrix();
        let cases = [
            (0, Some(vec![1.0, 4.0])),
            (2, Some(vec![3.0, 6.0])),
            (3, None),
        ];
        for (idx, expected) in cases {
            assert_eq!(m.column(idx), expected);
        }
    }

    #[test]
    fn test_row_mean() {
        let m = sample_matrix();
        for (idx, expected) in [(0, 2.0), (1, 5.0)] {
            assert_eq!(m.row_mean(idx), Some(expected));
        }
    }

    #[test]
    fn test_column_mean() {
        let m = sample_matrix();
        for (idx, expected) in [(0, 2.5), (1, 3.5)] {
            assert_eq!(m.column_mean(idx), Some(expected));
        }
    }

    #[test]
    fn test_transpose() {
        let flipped = sample_matrix().transpose();
        assert_eq!(flipped.shape(), (3, 2));
        for ((r, c), expected) in [((0, 0), 1.0), ((0, 1), 4.0), ((2, 1), 6.0)] {
            assert_eq!(flipped.get(r, c), Some(expected));
        }
    }

    #[test]
    fn test_filter_features() {
        let m = sample_matrix();
        let kept = m.filter_features(&[1]).unwrap();
        assert_eq!(kept.shape(), (1, 3));
        assert_eq!(kept.get(0, 0), Some(4.0));

        assert!(m.filter_features(&[5]).is_err());
    }

    #[test]
    fn test_filter_samples() {
        let m = sample_matrix();
        let kept = m.filter_samples(&[0, 2]).unwrap();
        assert_eq!(kept.shape(), (2, 2));
        for ((r, c), expected) in [((0, 0), 1.0), ((0, 1), 3.0), ((1, 0), 4.0)] {
            assert_eq!(kept.get(r, c), Some(expected));
        }

        assert!(m.filter_samples(&[5]).is_err());
    }

    #[test]
    fn test_log_transform() {
        let logged = sample_matrix().log_transform(1.0);
        // log2(1 + 1) = 1 and log2(4 + 1) = log2(5).
        let checks = [((0, 0), 1.0), ((1, 0), 5.0_f64.log2())];
        for ((r, c), expected) in checks {
            let actual = logged.get(r, c).unwrap();
            assert!((actual - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_summary() {
        let text = sample_matrix().summary();
        assert_eq!(text, "ExpressionMatrix: 2 features \u{00d7} 3 samples");
    }

    #[test]
    fn test_empty_matrix() {
        let m = ExpressionMatrix::new(Vec::new(), Vec::new(), names(&["s1"])).unwrap();
        assert_eq!(m.shape(), (0, 1));
    }

    #[test]
    fn test_as_slice() {
        let m = sample_matrix();
        assert_eq!(m.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_feature_names() {
        let m = sample_matrix();
        assert_eq!(m.feature_names(), &["gene1", "gene2"]);
    }

    #[test]
    fn test_sample_names() {
        let m = sample_matrix();
        assert_eq!(m.sample_names(), &["s1", "s2", "s3"]);
    }
}