1use std::collections::HashMap;
30
31use cyanea_core::{CyaneaError, Result, Summarizable};
32
33use crate::sparse::SparseMatrix;
34
35#[derive(Debug, Clone)]
37pub enum MatrixData {
38 Dense(Vec<Vec<f64>>),
40 Sparse(SparseMatrix),
42}
43
44impl MatrixData {
45 pub fn shape(&self) -> (usize, usize) {
47 match self {
48 MatrixData::Dense(rows) => {
49 let n_obs = rows.len();
50 let n_vars = rows.first().map_or(0, |r| r.len());
51 (n_obs, n_vars)
52 }
53 MatrixData::Sparse(s) => s.shape(),
54 }
55 }
56
57 pub fn get(&self, obs: usize, var: usize) -> f64 {
59 match self {
60 MatrixData::Dense(rows) => {
61 rows.get(obs).and_then(|r| r.get(var)).copied().unwrap_or(0.0)
62 }
63 MatrixData::Sparse(s) => s.get(obs, var),
64 }
65 }
66
67 pub fn set(&mut self, obs: usize, var: usize, val: f64) {
72 match self {
73 MatrixData::Dense(rows) => {
74 if let Some(row) = rows.get_mut(obs) {
75 if let Some(cell) = row.get_mut(var) {
76 *cell = val;
77 }
78 }
79 }
80 MatrixData::Sparse(s) => {
81 let _ = s.insert(obs, var, val);
82 }
83 }
84 }
85
86 pub fn column_sums(&self) -> Vec<f64> {
88 match self {
89 MatrixData::Dense(rows) => {
90 let n_vars = rows.first().map_or(0, |r| r.len());
91 let mut sums = vec![0.0; n_vars];
92 for row in rows {
93 for (j, &v) in row.iter().enumerate() {
94 sums[j] += v;
95 }
96 }
97 sums
98 }
99 MatrixData::Sparse(s) => s.column_sums(),
100 }
101 }
102
103 pub fn column_means(&self) -> Vec<f64> {
105 match self {
106 MatrixData::Dense(rows) => {
107 let n_obs = rows.len();
108 if n_obs == 0 {
109 return vec![];
110 }
111 let sums = self.column_sums();
112 let n = n_obs as f64;
113 sums.into_iter().map(|s| s / n).collect()
114 }
115 MatrixData::Sparse(s) => s.column_means(),
116 }
117 }
118
119 pub fn row_sums(&self) -> Vec<f64> {
121 match self {
122 MatrixData::Dense(rows) => rows.iter().map(|r| r.iter().sum()).collect(),
123 MatrixData::Sparse(s) => s.row_sums(),
124 }
125 }
126
127 pub fn to_flat_row_major(&self) -> Vec<f64> {
129 let (n_obs, n_vars) = self.shape();
130 match self {
131 MatrixData::Dense(rows) => {
132 let mut flat = Vec::with_capacity(n_obs * n_vars);
133 for row in rows {
134 flat.extend_from_slice(row);
135 }
136 flat
137 }
138 MatrixData::Sparse(s) => {
139 let mut flat = vec![0.0; n_obs * n_vars];
140 for (r, c, v) in s.iter() {
141 flat[r * n_vars + c] = v;
142 }
143 flat
144 }
145 }
146 }
147}
148
/// A single annotation column attached to observations or variables.
#[derive(Clone, Debug, PartialEq)]
pub enum ColumnData {
    /// Free-form string values, one per entry.
    Strings(Vec<String>),
    /// Floating-point values, one per entry.
    Numeric(Vec<f64>),
    /// Category-encoded values.
    Categorical {
        /// One integer code per entry; the column length is `codes.len()`.
        // NOTE(review): presumably each code indexes into `categories` — confirm.
        codes: Vec<i32>,
        /// Category labels shared by all entries.
        categories: Vec<String>,
    },
}

impl ColumnData {
    /// Number of entries in the column.
    ///
    /// For categorical data this is the number of codes, not the number of
    /// distinct categories.
    pub fn len(&self) -> usize {
        match self {
            Self::Strings(items) => items.len(),
            Self::Numeric(items) => items.len(),
            Self::Categorical { codes, .. } => codes.len(),
        }
    }

    /// Returns `true` when the column holds no entries.
    pub fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Borrows the values when this is a [`ColumnData::Strings`] column,
    /// otherwise returns `None`.
    pub fn as_strings(&self) -> Option<&Vec<String>> {
        if let Self::Strings(items) = self {
            Some(items)
        } else {
            None
        }
    }

    /// Borrows the values when this is a [`ColumnData::Numeric`] column,
    /// otherwise returns `None`.
    pub fn as_numeric(&self) -> Option<&Vec<f64>> {
        if let Self::Numeric(items) = self {
            Some(items)
        } else {
            None
        }
    }

    /// Builds a new column containing the entries at `indices`, in the given
    /// order (repeated indices are repeated in the output).
    ///
    /// Categorical columns keep the complete `categories` list.
    ///
    /// # Panics
    /// Panics if any index is out of range for the column.
    fn subset(&self, indices: &[usize]) -> Self {
        match self {
            Self::Strings(items) => {
                let mut picked = Vec::with_capacity(indices.len());
                for &idx in indices {
                    picked.push(items[idx].clone());
                }
                Self::Strings(picked)
            }
            Self::Numeric(items) => {
                let mut picked = Vec::with_capacity(indices.len());
                for &idx in indices {
                    picked.push(items[idx]);
                }
                Self::Numeric(picked)
            }
            Self::Categorical { codes, categories } => {
                let mut picked = Vec::with_capacity(indices.len());
                for &idx in indices {
                    picked.push(codes[idx]);
                }
                Self::Categorical {
                    codes: picked,
                    categories: categories.clone(),
                }
            }
        }
    }
}
212
/// Per-observation quality-control summary.
///
/// Both vectors hold one entry per observation.
#[derive(Clone, Debug)]
pub struct QcMetrics {
    /// Sum of all values in each observation's row.
    pub total_counts: Vec<f64>,
    /// Number of strictly positive values in each observation's row.
    pub n_features: Vec<usize>,
}
221
222#[derive(Debug, Clone)]
224pub struct AnnData {
225 x: MatrixData,
227 obs_names: Vec<String>,
229 var_names: Vec<String>,
231 obs: HashMap<String, ColumnData>,
233 var: HashMap<String, ColumnData>,
235 obsm: HashMap<String, Vec<Vec<f64>>>,
237 varm: HashMap<String, Vec<Vec<f64>>>,
239 layers: HashMap<String, MatrixData>,
241 obsp: HashMap<String, SparseMatrix>,
243 uns: HashMap<String, String>,
245}
246
247impl AnnData {
248 pub fn new(
254 x: MatrixData,
255 obs_names: Vec<String>,
256 var_names: Vec<String>,
257 ) -> Result<Self> {
258 let (n_obs, n_vars) = x.shape();
259 if obs_names.len() != n_obs {
260 return Err(CyaneaError::InvalidInput(format!(
261 "obs_names length ({}) does not match n_obs ({})",
262 obs_names.len(),
263 n_obs
264 )));
265 }
266 if var_names.len() != n_vars {
267 return Err(CyaneaError::InvalidInput(format!(
268 "var_names length ({}) does not match n_vars ({})",
269 var_names.len(),
270 n_vars
271 )));
272 }
273
274 Ok(Self {
275 x,
276 obs_names,
277 var_names,
278 obs: HashMap::new(),
279 var: HashMap::new(),
280 obsm: HashMap::new(),
281 varm: HashMap::new(),
282 layers: HashMap::new(),
283 obsp: HashMap::new(),
284 uns: HashMap::new(),
285 })
286 }
287
288 pub fn n_obs(&self) -> usize {
290 self.obs_names.len()
291 }
292
293 pub fn n_vars(&self) -> usize {
295 self.var_names.len()
296 }
297
298 pub fn shape(&self) -> (usize, usize) {
300 self.x.shape()
301 }
302
303 pub fn x(&self) -> &MatrixData {
305 &self.x
306 }
307
308 pub fn obs_names(&self) -> &[String] {
310 &self.obs_names
311 }
312
313 pub fn var_names(&self) -> &[String] {
315 &self.var_names
316 }
317
318 pub fn add_obs(&mut self, key: &str, values: Vec<String>) -> Result<()> {
320 self.add_obs_column(key, ColumnData::Strings(values))
321 }
322
323 pub fn add_obs_numeric(&mut self, key: &str, values: Vec<f64>) -> Result<()> {
325 self.add_obs_column(key, ColumnData::Numeric(values))
326 }
327
328 pub fn add_obs_column(&mut self, key: &str, data: ColumnData) -> Result<()> {
330 if data.len() != self.n_obs() {
331 return Err(CyaneaError::InvalidInput(format!(
332 "obs '{}' length ({}) does not match n_obs ({})",
333 key,
334 data.len(),
335 self.n_obs()
336 )));
337 }
338 self.obs.insert(key.to_string(), data);
339 Ok(())
340 }
341
342 pub fn get_obs(&self, key: &str) -> Option<&ColumnData> {
344 self.obs.get(key)
345 }
346
347 pub fn get_obs_strings(&self, key: &str) -> Option<&Vec<String>> {
349 self.obs.get(key).and_then(|c| c.as_strings())
350 }
351
352 pub fn obs_columns(&self) -> &HashMap<String, ColumnData> {
354 &self.obs
355 }
356
357 pub fn add_var(&mut self, key: &str, values: Vec<String>) -> Result<()> {
359 self.add_var_column(key, ColumnData::Strings(values))
360 }
361
362 pub fn add_var_numeric(&mut self, key: &str, values: Vec<f64>) -> Result<()> {
364 self.add_var_column(key, ColumnData::Numeric(values))
365 }
366
367 pub fn add_var_column(&mut self, key: &str, data: ColumnData) -> Result<()> {
369 if data.len() != self.n_vars() {
370 return Err(CyaneaError::InvalidInput(format!(
371 "var '{}' length ({}) does not match n_vars ({})",
372 key,
373 data.len(),
374 self.n_vars()
375 )));
376 }
377 self.var.insert(key.to_string(), data);
378 Ok(())
379 }
380
381 pub fn get_var(&self, key: &str) -> Option<&ColumnData> {
383 self.var.get(key)
384 }
385
386 pub fn get_var_strings(&self, key: &str) -> Option<&Vec<String>> {
388 self.var.get(key).and_then(|c| c.as_strings())
389 }
390
391 pub fn var_columns(&self) -> &HashMap<String, ColumnData> {
393 &self.var
394 }
395
396 pub fn add_obsm(&mut self, key: &str, data: Vec<Vec<f64>>) -> Result<()> {
398 if data.len() != self.n_obs() {
399 return Err(CyaneaError::InvalidInput(format!(
400 "obsm '{}' length ({}) does not match n_obs ({})",
401 key,
402 data.len(),
403 self.n_obs()
404 )));
405 }
406 self.obsm.insert(key.to_string(), data);
407 Ok(())
408 }
409
410 pub fn get_obsm(&self, key: &str) -> Option<&Vec<Vec<f64>>> {
412 self.obsm.get(key)
413 }
414
415 pub fn add_varm(&mut self, key: &str, data: Vec<Vec<f64>>) -> Result<()> {
417 if data.len() != self.n_vars() {
418 return Err(CyaneaError::InvalidInput(format!(
419 "varm '{}' length ({}) does not match n_vars ({})",
420 key,
421 data.len(),
422 self.n_vars()
423 )));
424 }
425 self.varm.insert(key.to_string(), data);
426 Ok(())
427 }
428
429 pub fn get_varm(&self, key: &str) -> Option<&Vec<Vec<f64>>> {
431 self.varm.get(key)
432 }
433
434 pub fn add_layer(&mut self, key: &str, layer: MatrixData) -> Result<()> {
436 let (n_obs, n_vars) = layer.shape();
437 if n_obs != self.n_obs() || n_vars != self.n_vars() {
438 return Err(CyaneaError::InvalidInput(format!(
439 "layer '{}' shape ({}, {}) does not match ({}, {})",
440 key,
441 n_obs,
442 n_vars,
443 self.n_obs(),
444 self.n_vars()
445 )));
446 }
447 self.layers.insert(key.to_string(), layer);
448 Ok(())
449 }
450
451 pub fn get_layer(&self, key: &str) -> Option<&MatrixData> {
453 self.layers.get(key)
454 }
455
456 pub fn obsm_keys(&self) -> &HashMap<String, Vec<Vec<f64>>> {
458 &self.obsm
459 }
460
461 pub fn varm_keys(&self) -> &HashMap<String, Vec<Vec<f64>>> {
463 &self.varm
464 }
465
466 pub fn layers_keys(&self) -> &HashMap<String, MatrixData> {
468 &self.layers
469 }
470
471 pub fn x_mut(&mut self) -> &mut MatrixData {
473 &mut self.x
474 }
475
476 pub fn set_x(&mut self, new_x: MatrixData) -> Result<()> {
478 let (n_obs, n_vars) = new_x.shape();
479 if n_obs != self.n_obs() || n_vars != self.n_vars() {
480 return Err(CyaneaError::InvalidInput(format!(
481 "new X shape ({}, {}) does not match ({}, {})",
482 n_obs, n_vars, self.n_obs(), self.n_vars()
483 )));
484 }
485 self.x = new_x;
486 Ok(())
487 }
488
489 pub fn subset_vars(&self, indices: &[usize]) -> Result<AnnData> {
491 for &i in indices {
492 if i >= self.n_vars() {
493 return Err(CyaneaError::InvalidInput(format!(
494 "var index {} out of bounds (n_vars={})",
495 i, self.n_vars()
496 )));
497 }
498 }
499
500 let x = subset_matrix_cols(&self.x, indices, self.n_obs());
501 let var_names: Vec<String> = indices.iter().map(|&i| self.var_names[i].clone()).collect();
502
503 let mut adata = AnnData::new(x, self.obs_names.clone(), var_names)?;
504
505 adata.obs = self.obs.clone();
507 for (key, col) in &self.var {
509 adata.var.insert(key.clone(), col.subset(indices));
510 }
511 adata.obsm = self.obsm.clone();
513 for (key, layer) in &self.layers {
515 let sub = subset_matrix_cols(layer, indices, self.n_obs());
516 adata.layers.insert(key.clone(), sub);
517 }
518 adata.obsp = self.obsp.clone();
520 adata.uns = self.uns.clone();
521
522 Ok(adata)
523 }
524
525 pub fn add_obsp(&mut self, key: &str, matrix: SparseMatrix) -> Result<()> {
529 let (r, c) = matrix.shape();
530 if r != self.n_obs() || c != self.n_obs() {
531 return Err(CyaneaError::InvalidInput(format!(
532 "obsp '{}' shape ({}, {}) does not match n_obs ({})",
533 key, r, c, self.n_obs()
534 )));
535 }
536 self.obsp.insert(key.to_string(), matrix);
537 Ok(())
538 }
539
540 pub fn get_obsp(&self, key: &str) -> Option<&SparseMatrix> {
542 self.obsp.get(key)
543 }
544
545 pub fn add_uns(&mut self, key: &str, value: String) {
547 self.uns.insert(key.to_string(), value);
548 }
549
550 pub fn get_uns(&self, key: &str) -> Option<&str> {
552 self.uns.get(key).map(|s| s.as_str())
553 }
554
555 pub fn obsp_keys(&self) -> &HashMap<String, SparseMatrix> {
557 &self.obsp
558 }
559
560 pub fn uns_keys(&self) -> &HashMap<String, String> {
562 &self.uns
563 }
564
565 pub fn get_layer_mut(&mut self, key: &str) -> Option<&mut MatrixData> {
567 self.layers.get_mut(key)
568 }
569
570 pub fn subset_obs(&self, indices: &[usize]) -> Result<AnnData> {
572 for &i in indices {
573 if i >= self.n_obs() {
574 return Err(CyaneaError::InvalidInput(format!(
575 "obs index {} out of bounds (n_obs={})",
576 i,
577 self.n_obs()
578 )));
579 }
580 }
581
582 let x = subset_matrix_rows(&self.x, indices, self.n_vars());
583 let obs_names: Vec<String> = indices.iter().map(|&i| self.obs_names[i].clone()).collect();
584
585 let mut adata = AnnData::new(x, obs_names, self.var_names.clone())?;
586
587 for (key, col) in &self.obs {
589 adata.obs.insert(key.clone(), col.subset(indices));
590 }
591 adata.var = self.var.clone();
593
594 for (key, data) in &self.obsm {
596 let sub: Vec<Vec<f64>> = indices.iter().map(|&i| data[i].clone()).collect();
597 adata.obsm.insert(key.clone(), sub);
598 }
599 adata.varm = self.varm.clone();
600 adata.uns = self.uns.clone();
601
602 Ok(adata)
603 }
604
605 pub fn qc_metrics(&self) -> QcMetrics {
607 let n = self.n_obs();
608 let p = self.n_vars();
609 let mut total_counts = vec![0.0; n];
610 let mut n_features = vec![0usize; n];
611
612 for i in 0..n {
613 for j in 0..p {
614 let v = self.x.get(i, j);
615 total_counts[i] += v;
616 if v > 0.0 {
617 n_features[i] += 1;
618 }
619 }
620 }
621
622 QcMetrics {
623 total_counts,
624 n_features,
625 }
626 }
627}
628
629fn subset_matrix_cols(x: &MatrixData, col_indices: &[usize], n_obs: usize) -> MatrixData {
630 match x {
631 MatrixData::Dense(rows) => {
632 let sub: Vec<Vec<f64>> = rows
633 .iter()
634 .map(|row| col_indices.iter().map(|&j| row[j]).collect())
635 .collect();
636 MatrixData::Dense(sub)
637 }
638 MatrixData::Sparse(s) => {
639 let n_new_cols = col_indices.len();
640 let mut new_s = SparseMatrix::new(n_obs, n_new_cols);
641 let mut col_map = HashMap::new();
643 for (new_j, &old_j) in col_indices.iter().enumerate() {
644 col_map.insert(old_j, new_j);
645 }
646 for (r, c, v) in s.iter() {
647 if let Some(&new_c) = col_map.get(&c) {
648 let _ = new_s.insert(r, new_c, v);
649 }
650 }
651 MatrixData::Sparse(new_s)
652 }
653 }
654}
655
656fn subset_matrix_rows(x: &MatrixData, indices: &[usize], n_vars: usize) -> MatrixData {
657 match x {
658 MatrixData::Dense(rows) => {
659 let sub: Vec<Vec<f64>> = indices.iter().map(|&i| rows[i].clone()).collect();
660 MatrixData::Dense(sub)
661 }
662 MatrixData::Sparse(s) => {
663 let n_new = indices.len();
664 let mut new_s = SparseMatrix::new(n_new, n_vars);
665 for (new_row, &old_row) in indices.iter().enumerate() {
666 for j in 0..n_vars {
667 let v = s.get(old_row, j);
668 if v != 0.0 {
669 let _ = new_s.insert(new_row, j, v);
670 }
671 }
672 }
673 MatrixData::Sparse(new_s)
674 }
675 }
676}
677
678impl Summarizable for AnnData {
679 fn summary(&self) -> String {
680 format!(
681 "AnnData: {} obs \u{00d7} {} vars, {} layers, {} obsm, {} varm, {} obsp, {} uns",
682 self.n_obs(),
683 self.n_vars(),
684 self.layers.len(),
685 self.obsm.len(),
686 self.varm.len(),
687 self.obsp.len(),
688 self.uns.len(),
689 )
690 }
691}
692
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned names the API expects.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// 3×3 dense fixture with three named cells and three named genes.
    fn sample_adata() -> AnnData {
        let rows = vec![
            vec![1.0, 2.0, 0.0],
            vec![3.0, 0.0, 4.0],
            vec![0.0, 5.0, 6.0],
        ];
        let obs = owned(&["cell_1", "cell_2", "cell_3"]);
        let vars = owned(&["gene_a", "gene_b", "gene_c"]);
        AnnData::new(MatrixData::Dense(rows), obs, vars).unwrap()
    }

    /// 3×3 dense matrix used as a layer in several tests.
    fn scaled_layer() -> MatrixData {
        MatrixData::Dense(vec![
            vec![10.0, 20.0, 0.0],
            vec![30.0, 0.0, 40.0],
            vec![0.0, 50.0, 60.0],
        ])
    }

    #[test]
    fn basic_construction() {
        let adata = sample_adata();
        assert_eq!(adata.n_obs(), 3);
        assert_eq!(adata.n_vars(), 3);
        assert_eq!(adata.shape(), (3, 3));
    }

    #[test]
    fn dimension_mismatch_error() {
        // One row of data but two observation names.
        let matrix = MatrixData::Dense(vec![vec![1.0, 2.0]]);
        let outcome = AnnData::new(
            matrix,
            owned(&["cell_1", "cell_2"]),
            owned(&["gene_a", "gene_b"]),
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn obs_metadata() {
        let mut adata = sample_adata();
        let cell_types = owned(&["T-cell", "B-cell", "NK"]);
        adata.add_obs("cell_type", cell_types).unwrap();
        let ct = adata.get_obs_strings("cell_type").unwrap();
        assert_eq!(ct[0], "T-cell");
        assert!(adata.get_obs("missing").is_none());
    }

    #[test]
    fn obs_metadata_length_mismatch() {
        let mut adata = sample_adata();
        let outcome = adata.add_obs("bad", owned(&["a"]));
        assert!(outcome.is_err());
    }

    #[test]
    fn var_metadata() {
        let mut adata = sample_adata();
        let gene_types = owned(&["coding", "coding", "lncRNA"]);
        adata.add_var("gene_type", gene_types).unwrap();
        let gt = adata.get_var_strings("gene_type").unwrap();
        assert_eq!(gt[2], "lncRNA");
    }

    #[test]
    fn obsm_embedding() {
        let mut adata = sample_adata();
        let coords = vec![vec![0.1, 0.2], vec![0.3, 0.4], vec![0.5, 0.6]];
        adata.add_obsm("X_pca", coords).unwrap();
        let emb = adata.get_obsm("X_pca").unwrap();
        assert_eq!(emb.len(), 3);
        assert_eq!(emb[0], vec![0.1, 0.2]);
    }

    #[test]
    fn layers() {
        let mut adata = sample_adata();
        adata.add_layer("raw_counts", scaled_layer()).unwrap();
        let stored = adata.get_layer("raw_counts").unwrap();
        assert_eq!(stored.get(0, 0), 10.0);
    }

    #[test]
    fn layer_shape_mismatch() {
        let mut adata = sample_adata();
        let too_small = MatrixData::Dense(vec![vec![1.0]]);
        assert!(adata.add_layer("bad", too_small).is_err());
    }

    #[test]
    fn subset_obs() {
        let mut adata = sample_adata();
        adata.add_obs("label", owned(&["a", "b", "c"])).unwrap();
        let sub = adata.subset_obs(&[0, 2]).unwrap();
        assert_eq!(sub.n_obs(), 2);
        assert_eq!(sub.n_vars(), 3);
        assert_eq!(sub.obs_names(), &["cell_1", "cell_3"]);
        let labels = sub.get_obs_strings("label").unwrap();
        assert_eq!(labels, &["a", "c"]);
    }

    #[test]
    fn qc_metrics() {
        let metrics = sample_adata().qc_metrics();
        assert_eq!(metrics.total_counts, vec![3.0, 7.0, 11.0]);
        assert_eq!(metrics.n_features, vec![2, 2, 2]);
    }

    #[test]
    fn sparse_x() {
        let sparse =
            SparseMatrix::from_triplets(vec![0, 1], vec![0, 1], vec![5.0, 10.0], 2, 2).unwrap();
        let adata = AnnData::new(
            MatrixData::Sparse(sparse),
            owned(&["c1", "c2"]),
            owned(&["g1", "g2"]),
        )
        .unwrap();
        assert_eq!(adata.x().get(0, 0), 5.0);
        assert_eq!(adata.x().get(0, 1), 0.0);
    }

    #[test]
    fn summary_format() {
        let text = sample_adata().summary();
        assert!(text.contains("3 obs"));
        assert!(text.contains("3 vars"));
        assert!(text.contains("0 obsp"));
        assert!(text.contains("0 uns"));
    }

    #[test]
    fn matrix_data_set_dense() {
        let mut matrix = MatrixData::Dense(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        matrix.set(0, 1, 99.0);
        assert_eq!(matrix.get(0, 1), 99.0);
        assert_eq!(matrix.get(0, 0), 1.0);
    }

    #[test]
    fn matrix_data_set_sparse() {
        let mut matrix = MatrixData::Sparse(SparseMatrix::new(2, 2));
        matrix.set(0, 1, 5.0);
        assert_eq!(matrix.get(0, 1), 5.0);
    }

    #[test]
    fn matrix_data_column_sums_dense() {
        let matrix = MatrixData::Dense(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
        assert_eq!(matrix.column_sums(), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn matrix_data_column_means_dense() {
        let matrix = MatrixData::Dense(vec![vec![2.0, 4.0], vec![6.0, 8.0]]);
        let means = matrix.column_means();
        assert!((means[0] - 4.0).abs() < 1e-10);
        assert!((means[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn matrix_data_row_sums_dense() {
        let matrix = MatrixData::Dense(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
        assert_eq!(matrix.row_sums(), vec![6.0, 15.0]);
    }

    #[test]
    fn matrix_data_to_flat_row_major_dense() {
        let matrix = MatrixData::Dense(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        assert_eq!(matrix.to_flat_row_major(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn matrix_data_to_flat_row_major_sparse() {
        let sparse =
            SparseMatrix::from_triplets(vec![0, 1], vec![1, 0], vec![2.0, 3.0], 2, 2).unwrap();
        let matrix = MatrixData::Sparse(sparse);
        assert_eq!(matrix.to_flat_row_major(), vec![0.0, 2.0, 3.0, 0.0]);
    }

    #[test]
    fn x_mut_modify() {
        let mut adata = sample_adata();
        adata.x_mut().set(0, 0, 42.0);
        assert_eq!(adata.x().get(0, 0), 42.0);
    }

    #[test]
    fn set_x_valid() {
        let mut adata = sample_adata();
        let replacement = MatrixData::Dense(vec![
            vec![10.0, 20.0, 30.0],
            vec![40.0, 50.0, 60.0],
            vec![70.0, 80.0, 90.0],
        ]);
        adata.set_x(replacement).unwrap();
        assert_eq!(adata.x().get(0, 0), 10.0);
    }

    #[test]
    fn set_x_shape_mismatch() {
        let mut adata = sample_adata();
        let too_small = MatrixData::Dense(vec![vec![1.0]]);
        assert!(adata.set_x(too_small).is_err());
    }

    #[test]
    fn subset_vars_basic() {
        let mut adata = sample_adata();
        adata.add_var("type", owned(&["a", "b", "c"])).unwrap();
        let sub = adata.subset_vars(&[0, 2]).unwrap();
        assert_eq!(sub.n_vars(), 2);
        assert_eq!(sub.n_obs(), 3);
        assert_eq!(sub.var_names(), &["gene_a", "gene_c"]);
        assert_eq!(sub.x().get(0, 0), 1.0);
        // Column 1 of the subset is the original column 2.
        assert_eq!(sub.x().get(0, 1), 0.0);
        let types = sub.get_var_strings("type").unwrap();
        assert_eq!(types, &["a", "c"]);
    }

    #[test]
    fn subset_vars_out_of_bounds() {
        let adata = sample_adata();
        assert!(adata.subset_vars(&[0, 10]).is_err());
    }

    #[test]
    fn obsp_add_get() {
        let mut adata = sample_adata();
        let mut graph = SparseMatrix::new(3, 3);
        graph.insert(0, 1, 0.5).unwrap();
        graph.insert(1, 2, 0.3).unwrap();
        adata.add_obsp("connectivities", graph).unwrap();
        let conn = adata.get_obsp("connectivities").unwrap();
        assert_eq!(conn.get(0, 1), 0.5);
        assert!(adata.get_obsp("missing").is_none());
    }

    #[test]
    fn obsp_wrong_shape() {
        let mut adata = sample_adata();
        // 2×2 does not match the three observations of the fixture.
        let graph = SparseMatrix::new(2, 2);
        assert!(adata.add_obsp("bad", graph).is_err());
    }

    #[test]
    fn uns_add_get() {
        let mut adata = sample_adata();
        adata.add_uns("method", "leiden".to_string());
        assert_eq!(adata.get_uns("method"), Some("leiden"));
        assert_eq!(adata.get_uns("missing"), None);
    }

    #[test]
    fn get_layer_mut() {
        let mut adata = sample_adata();
        adata.add_layer("counts", scaled_layer()).unwrap();
        adata.get_layer_mut("counts").unwrap().set(0, 0, 99.0);
        assert_eq!(adata.get_layer("counts").unwrap().get(0, 0), 99.0);
    }
}