Skip to main content

cyanea_omics/
single_cell.rs

1//! AnnData-like container for single-cell omics data.
2//!
3//! Provides an in-memory representation inspired by the Python AnnData format,
4//! the standard data structure in the scverse ecosystem (scanpy, scvi-tools).
5//!
6//! # Structure
7//!
8//! - `X` — primary data matrix (cells × genes), dense or sparse
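//! - `obs` — per-cell metadata columns (string, numeric, or categorical)
//! - `var` — per-gene metadata columns (string, numeric, or categorical)
//! - `obsm` / `varm` — multi-dimensional annotations (e.g. PCA, UMAP embeddings)
//! - `layers` — alternative data matrices (e.g. raw counts, normalized)
//! - `obsp` — pairwise per-cell annotations (e.g. kNN graphs), stored sparse
//! - `uns` — unstructured string metadata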
13//!
14//! # Example
15//!
16//! ```
17//! use cyanea_omics::single_cell::{AnnData, MatrixData};
18//!
19//! let x = MatrixData::Dense(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
20//! let adata = AnnData::new(
21//!     x,
22//!     vec!["cell_1".into(), "cell_2".into()],
23//!     vec!["gene_a".into(), "gene_b".into()],
24//! ).unwrap();
25//! assert_eq!(adata.n_obs(), 2);
26//! assert_eq!(adata.n_vars(), 2);
27//! ```
28
29use std::collections::HashMap;
30
31use cyanea_core::{CyaneaError, Result, Summarizable};
32
33use crate::sparse::SparseMatrix;
34
35/// The primary data matrix, either dense or sparse.
36#[derive(Debug, Clone)]
37pub enum MatrixData {
38    /// Dense row-major matrix (n_obs × n_vars).
39    Dense(Vec<Vec<f64>>),
40    /// Sparse COO matrix.
41    Sparse(SparseMatrix),
42}
43
44impl MatrixData {
45    /// (n_obs, n_vars).
46    pub fn shape(&self) -> (usize, usize) {
47        match self {
48            MatrixData::Dense(rows) => {
49                let n_obs = rows.len();
50                let n_vars = rows.first().map_or(0, |r| r.len());
51                (n_obs, n_vars)
52            }
53            MatrixData::Sparse(s) => s.shape(),
54        }
55    }
56
57    /// Get a value at (obs_idx, var_idx).
58    pub fn get(&self, obs: usize, var: usize) -> f64 {
59        match self {
60            MatrixData::Dense(rows) => {
61                rows.get(obs).and_then(|r| r.get(var)).copied().unwrap_or(0.0)
62            }
63            MatrixData::Sparse(s) => s.get(obs, var),
64        }
65    }
66
67    /// Set a value at (obs_idx, var_idx).
68    ///
69    /// For dense matrices, sets directly. For sparse matrices, inserts a triplet
70    /// (does not deduplicate — the last-inserted value wins on `get()`).
71    pub fn set(&mut self, obs: usize, var: usize, val: f64) {
72        match self {
73            MatrixData::Dense(rows) => {
74                if let Some(row) = rows.get_mut(obs) {
75                    if let Some(cell) = row.get_mut(var) {
76                        *cell = val;
77                    }
78                }
79            }
80            MatrixData::Sparse(s) => {
81                let _ = s.insert(obs, var, val);
82            }
83        }
84    }
85
86    /// Sum of values in each column.
87    pub fn column_sums(&self) -> Vec<f64> {
88        match self {
89            MatrixData::Dense(rows) => {
90                let n_vars = rows.first().map_or(0, |r| r.len());
91                let mut sums = vec![0.0; n_vars];
92                for row in rows {
93                    for (j, &v) in row.iter().enumerate() {
94                        sums[j] += v;
95                    }
96                }
97                sums
98            }
99            MatrixData::Sparse(s) => s.column_sums(),
100        }
101    }
102
103    /// Mean value of each column.
104    pub fn column_means(&self) -> Vec<f64> {
105        match self {
106            MatrixData::Dense(rows) => {
107                let n_obs = rows.len();
108                if n_obs == 0 {
109                    return vec![];
110                }
111                let sums = self.column_sums();
112                let n = n_obs as f64;
113                sums.into_iter().map(|s| s / n).collect()
114            }
115            MatrixData::Sparse(s) => s.column_means(),
116        }
117    }
118
119    /// Sum of values in each row.
120    pub fn row_sums(&self) -> Vec<f64> {
121        match self {
122            MatrixData::Dense(rows) => rows.iter().map(|r| r.iter().sum()).collect(),
123            MatrixData::Sparse(s) => s.row_sums(),
124        }
125    }
126
127    /// Flatten to a row-major `Vec<f64>` for interop with cyanea-stats/ml.
128    pub fn to_flat_row_major(&self) -> Vec<f64> {
129        let (n_obs, n_vars) = self.shape();
130        match self {
131            MatrixData::Dense(rows) => {
132                let mut flat = Vec::with_capacity(n_obs * n_vars);
133                for row in rows {
134                    flat.extend_from_slice(row);
135                }
136                flat
137            }
138            MatrixData::Sparse(s) => {
139                let mut flat = vec![0.0; n_obs * n_vars];
140                for (r, c, v) in s.iter() {
141                    flat[r * n_vars + c] = v;
142                }
143                flat
144            }
145        }
146    }
147}
148
/// Typed metadata column holding one entry per observation or per variable.
///
/// Covers the string, numeric, and categorical column kinds found in `.h5ad` files.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnData {
    /// Arbitrary text values.
    Strings(Vec<String>),
    /// Floating-point (`f64`) values.
    Numeric(Vec<f64>),
    /// Integer `codes`, each indexing into the shared `categories` list.
    Categorical {
        codes: Vec<i32>,
        categories: Vec<String>,
    },
}

impl ColumnData {
    /// Number of entries in the column (the number of codes for categorical data).
    pub fn len(&self) -> usize {
        match self {
            Self::Strings(values) => values.len(),
            Self::Numeric(values) => values.len(),
            Self::Categorical { codes, .. } => codes.len(),
        }
    }

    /// `true` when the column holds no entries.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Strings(values) => values.is_empty(),
            Self::Numeric(values) => values.is_empty(),
            Self::Categorical { codes, .. } => codes.is_empty(),
        }
    }

    /// Borrow the values of a `Strings` column; `None` for any other variant.
    pub fn as_strings(&self) -> Option<&Vec<String>> {
        if let Self::Strings(values) = self {
            Some(values)
        } else {
            None
        }
    }

    /// Borrow the values of a `Numeric` column; `None` for any other variant.
    pub fn as_numeric(&self) -> Option<&Vec<f64>> {
        if let Self::Numeric(values) = self {
            Some(values)
        } else {
            None
        }
    }

    /// New column holding the entries at `indices`, in that order.
    ///
    /// Categorical columns keep their full category list. Panics if any index is
    /// out of range, so callers must validate indices first.
    fn subset(&self, indices: &[usize]) -> Self {
        fn pick<T: Clone>(src: &[T], indices: &[usize]) -> Vec<T> {
            indices.iter().map(|&i| src[i].clone()).collect()
        }

        match self {
            Self::Strings(values) => Self::Strings(pick(values, indices)),
            Self::Numeric(values) => Self::Numeric(pick(values, indices)),
            Self::Categorical { codes, categories } => Self::Categorical {
                codes: pick(codes, indices),
                categories: categories.clone(),
            },
        }
    }
}
212
/// Per-observation (per-cell) quality-control metrics, as returned by
/// [`AnnData::qc_metrics`]. Both vectors have one entry per observation.
#[derive(Debug, Clone, PartialEq)]
pub struct QcMetrics {
    /// Sum of all values in each observation's row of `X`.
    pub total_counts: Vec<f64>,
    /// Number of strictly positive entries in each observation's row of `X`.
    pub n_features: Vec<usize>,
}
221
222/// AnnData-like container for single-cell data.
223#[derive(Debug, Clone)]
224pub struct AnnData {
225    /// Primary data matrix (n_obs × n_vars).
226    x: MatrixData,
227    /// Observation (cell) names.
228    obs_names: Vec<String>,
229    /// Variable (gene) names.
230    var_names: Vec<String>,
231    /// Per-cell metadata.
232    obs: HashMap<String, ColumnData>,
233    /// Per-gene metadata.
234    var: HashMap<String, ColumnData>,
235    /// Multi-dimensional observation annotations (e.g. PCA embeddings).
236    obsm: HashMap<String, Vec<Vec<f64>>>,
237    /// Multi-dimensional variable annotations.
238    varm: HashMap<String, Vec<Vec<f64>>>,
239    /// Alternative data layers (same shape as X).
240    layers: HashMap<String, MatrixData>,
241    /// Pairwise observation annotations (e.g. kNN graphs, distance matrices).
242    obsp: HashMap<String, SparseMatrix>,
243    /// Unstructured metadata (free-form key-value pairs).
244    uns: HashMap<String, String>,
245}
246
247impl AnnData {
248    /// Create a new AnnData container.
249    ///
250    /// # Errors
251    ///
252    /// Returns an error if the matrix dimensions don't match the name vectors.
253    pub fn new(
254        x: MatrixData,
255        obs_names: Vec<String>,
256        var_names: Vec<String>,
257    ) -> Result<Self> {
258        let (n_obs, n_vars) = x.shape();
259        if obs_names.len() != n_obs {
260            return Err(CyaneaError::InvalidInput(format!(
261                "obs_names length ({}) does not match n_obs ({})",
262                obs_names.len(),
263                n_obs
264            )));
265        }
266        if var_names.len() != n_vars {
267            return Err(CyaneaError::InvalidInput(format!(
268                "var_names length ({}) does not match n_vars ({})",
269                var_names.len(),
270                n_vars
271            )));
272        }
273
274        Ok(Self {
275            x,
276            obs_names,
277            var_names,
278            obs: HashMap::new(),
279            var: HashMap::new(),
280            obsm: HashMap::new(),
281            varm: HashMap::new(),
282            layers: HashMap::new(),
283            obsp: HashMap::new(),
284            uns: HashMap::new(),
285        })
286    }
287
288    /// Number of observations (cells).
289    pub fn n_obs(&self) -> usize {
290        self.obs_names.len()
291    }
292
293    /// Number of variables (genes).
294    pub fn n_vars(&self) -> usize {
295        self.var_names.len()
296    }
297
298    /// Shape of the primary data matrix.
299    pub fn shape(&self) -> (usize, usize) {
300        self.x.shape()
301    }
302
303    /// Access the primary data matrix.
304    pub fn x(&self) -> &MatrixData {
305        &self.x
306    }
307
308    /// Observation names.
309    pub fn obs_names(&self) -> &[String] {
310        &self.obs_names
311    }
312
313    /// Variable names.
314    pub fn var_names(&self) -> &[String] {
315        &self.var_names
316    }
317
318    /// Add a per-cell string metadata column.
319    pub fn add_obs(&mut self, key: &str, values: Vec<String>) -> Result<()> {
320        self.add_obs_column(key, ColumnData::Strings(values))
321    }
322
323    /// Add a per-cell numeric metadata column.
324    pub fn add_obs_numeric(&mut self, key: &str, values: Vec<f64>) -> Result<()> {
325        self.add_obs_column(key, ColumnData::Numeric(values))
326    }
327
328    /// Add a per-cell metadata column of any type.
329    pub fn add_obs_column(&mut self, key: &str, data: ColumnData) -> Result<()> {
330        if data.len() != self.n_obs() {
331            return Err(CyaneaError::InvalidInput(format!(
332                "obs '{}' length ({}) does not match n_obs ({})",
333                key,
334                data.len(),
335                self.n_obs()
336            )));
337        }
338        self.obs.insert(key.to_string(), data);
339        Ok(())
340    }
341
342    /// Get per-cell metadata column as typed data.
343    pub fn get_obs(&self, key: &str) -> Option<&ColumnData> {
344        self.obs.get(key)
345    }
346
347    /// Get per-cell metadata column as strings (convenience for backward compat).
348    pub fn get_obs_strings(&self, key: &str) -> Option<&Vec<String>> {
349        self.obs.get(key).and_then(|c| c.as_strings())
350    }
351
352    /// All observation metadata columns.
353    pub fn obs_columns(&self) -> &HashMap<String, ColumnData> {
354        &self.obs
355    }
356
357    /// Add a per-gene string metadata column.
358    pub fn add_var(&mut self, key: &str, values: Vec<String>) -> Result<()> {
359        self.add_var_column(key, ColumnData::Strings(values))
360    }
361
362    /// Add a per-gene numeric metadata column.
363    pub fn add_var_numeric(&mut self, key: &str, values: Vec<f64>) -> Result<()> {
364        self.add_var_column(key, ColumnData::Numeric(values))
365    }
366
367    /// Add a per-gene metadata column of any type.
368    pub fn add_var_column(&mut self, key: &str, data: ColumnData) -> Result<()> {
369        if data.len() != self.n_vars() {
370            return Err(CyaneaError::InvalidInput(format!(
371                "var '{}' length ({}) does not match n_vars ({})",
372                key,
373                data.len(),
374                self.n_vars()
375            )));
376        }
377        self.var.insert(key.to_string(), data);
378        Ok(())
379    }
380
381    /// Get per-gene metadata column as typed data.
382    pub fn get_var(&self, key: &str) -> Option<&ColumnData> {
383        self.var.get(key)
384    }
385
386    /// Get per-gene metadata column as strings (convenience for backward compat).
387    pub fn get_var_strings(&self, key: &str) -> Option<&Vec<String>> {
388        self.var.get(key).and_then(|c| c.as_strings())
389    }
390
391    /// All variable metadata columns.
392    pub fn var_columns(&self) -> &HashMap<String, ColumnData> {
393        &self.var
394    }
395
396    /// Add a multi-dimensional observation annotation (e.g. PCA embedding).
397    pub fn add_obsm(&mut self, key: &str, data: Vec<Vec<f64>>) -> Result<()> {
398        if data.len() != self.n_obs() {
399            return Err(CyaneaError::InvalidInput(format!(
400                "obsm '{}' length ({}) does not match n_obs ({})",
401                key,
402                data.len(),
403                self.n_obs()
404            )));
405        }
406        self.obsm.insert(key.to_string(), data);
407        Ok(())
408    }
409
410    /// Get a multi-dimensional observation annotation.
411    pub fn get_obsm(&self, key: &str) -> Option<&Vec<Vec<f64>>> {
412        self.obsm.get(key)
413    }
414
415    /// Add a multi-dimensional variable annotation.
416    pub fn add_varm(&mut self, key: &str, data: Vec<Vec<f64>>) -> Result<()> {
417        if data.len() != self.n_vars() {
418            return Err(CyaneaError::InvalidInput(format!(
419                "varm '{}' length ({}) does not match n_vars ({})",
420                key,
421                data.len(),
422                self.n_vars()
423            )));
424        }
425        self.varm.insert(key.to_string(), data);
426        Ok(())
427    }
428
429    /// Get a multi-dimensional variable annotation.
430    pub fn get_varm(&self, key: &str) -> Option<&Vec<Vec<f64>>> {
431        self.varm.get(key)
432    }
433
434    /// Add an alternative data layer.
435    pub fn add_layer(&mut self, key: &str, layer: MatrixData) -> Result<()> {
436        let (n_obs, n_vars) = layer.shape();
437        if n_obs != self.n_obs() || n_vars != self.n_vars() {
438            return Err(CyaneaError::InvalidInput(format!(
439                "layer '{}' shape ({}, {}) does not match ({}, {})",
440                key,
441                n_obs,
442                n_vars,
443                self.n_obs(),
444                self.n_vars()
445            )));
446        }
447        self.layers.insert(key.to_string(), layer);
448        Ok(())
449    }
450
451    /// Get an alternative data layer.
452    pub fn get_layer(&self, key: &str) -> Option<&MatrixData> {
453        self.layers.get(key)
454    }
455
456    /// All observation multi-dimensional annotations.
457    pub fn obsm_keys(&self) -> &HashMap<String, Vec<Vec<f64>>> {
458        &self.obsm
459    }
460
461    /// All variable multi-dimensional annotations.
462    pub fn varm_keys(&self) -> &HashMap<String, Vec<Vec<f64>>> {
463        &self.varm
464    }
465
466    /// All alternative data layers.
467    pub fn layers_keys(&self) -> &HashMap<String, MatrixData> {
468        &self.layers
469    }
470
471    /// Mutable access to the primary data matrix.
472    pub fn x_mut(&mut self) -> &mut MatrixData {
473        &mut self.x
474    }
475
476    /// Replace the primary data matrix. The new matrix must have the same shape.
477    pub fn set_x(&mut self, new_x: MatrixData) -> Result<()> {
478        let (n_obs, n_vars) = new_x.shape();
479        if n_obs != self.n_obs() || n_vars != self.n_vars() {
480            return Err(CyaneaError::InvalidInput(format!(
481                "new X shape ({}, {}) does not match ({}, {})",
482                n_obs, n_vars, self.n_obs(), self.n_vars()
483            )));
484        }
485        self.x = new_x;
486        Ok(())
487    }
488
489    /// Subset to the given variable (gene) indices.
490    pub fn subset_vars(&self, indices: &[usize]) -> Result<AnnData> {
491        for &i in indices {
492            if i >= self.n_vars() {
493                return Err(CyaneaError::InvalidInput(format!(
494                    "var index {} out of bounds (n_vars={})",
495                    i, self.n_vars()
496                )));
497            }
498        }
499
500        let x = subset_matrix_cols(&self.x, indices, self.n_obs());
501        let var_names: Vec<String> = indices.iter().map(|&i| self.var_names[i].clone()).collect();
502
503        let mut adata = AnnData::new(x, self.obs_names.clone(), var_names)?;
504
505        // Copy obs metadata
506        adata.obs = self.obs.clone();
507        // Subset var metadata
508        for (key, col) in &self.var {
509            adata.var.insert(key.clone(), col.subset(indices));
510        }
511        // Copy obsm/varm
512        adata.obsm = self.obsm.clone();
513        // Subset layers
514        for (key, layer) in &self.layers {
515            let sub = subset_matrix_cols(layer, indices, self.n_obs());
516            adata.layers.insert(key.clone(), sub);
517        }
518        // Copy obsp and uns
519        adata.obsp = self.obsp.clone();
520        adata.uns = self.uns.clone();
521
522        Ok(adata)
523    }
524
525    /// Add a pairwise observation annotation (e.g. kNN graph).
526    ///
527    /// The matrix must be n_obs × n_obs.
528    pub fn add_obsp(&mut self, key: &str, matrix: SparseMatrix) -> Result<()> {
529        let (r, c) = matrix.shape();
530        if r != self.n_obs() || c != self.n_obs() {
531            return Err(CyaneaError::InvalidInput(format!(
532                "obsp '{}' shape ({}, {}) does not match n_obs ({})",
533                key, r, c, self.n_obs()
534            )));
535        }
536        self.obsp.insert(key.to_string(), matrix);
537        Ok(())
538    }
539
540    /// Get a pairwise observation annotation.
541    pub fn get_obsp(&self, key: &str) -> Option<&SparseMatrix> {
542        self.obsp.get(key)
543    }
544
545    /// Add unstructured metadata.
546    pub fn add_uns(&mut self, key: &str, value: String) {
547        self.uns.insert(key.to_string(), value);
548    }
549
550    /// Get unstructured metadata.
551    pub fn get_uns(&self, key: &str) -> Option<&str> {
552        self.uns.get(key).map(|s| s.as_str())
553    }
554
555    /// All pairwise observation annotation keys.
556    pub fn obsp_keys(&self) -> &HashMap<String, SparseMatrix> {
557        &self.obsp
558    }
559
560    /// All unstructured metadata.
561    pub fn uns_keys(&self) -> &HashMap<String, String> {
562        &self.uns
563    }
564
565    /// Mutable access to a layer.
566    pub fn get_layer_mut(&mut self, key: &str) -> Option<&mut MatrixData> {
567        self.layers.get_mut(key)
568    }
569
570    /// Subset to the given observation indices.
571    pub fn subset_obs(&self, indices: &[usize]) -> Result<AnnData> {
572        for &i in indices {
573            if i >= self.n_obs() {
574                return Err(CyaneaError::InvalidInput(format!(
575                    "obs index {} out of bounds (n_obs={})",
576                    i,
577                    self.n_obs()
578                )));
579            }
580        }
581
582        let x = subset_matrix_rows(&self.x, indices, self.n_vars());
583        let obs_names: Vec<String> = indices.iter().map(|&i| self.obs_names[i].clone()).collect();
584
585        let mut adata = AnnData::new(x, obs_names, self.var_names.clone())?;
586
587        // Subset obs metadata
588        for (key, col) in &self.obs {
589            adata.obs.insert(key.clone(), col.subset(indices));
590        }
591        // Copy var metadata
592        adata.var = self.var.clone();
593
594        // Subset obsm
595        for (key, data) in &self.obsm {
596            let sub: Vec<Vec<f64>> = indices.iter().map(|&i| data[i].clone()).collect();
597            adata.obsm.insert(key.clone(), sub);
598        }
599        adata.varm = self.varm.clone();
600        adata.uns = self.uns.clone();
601
602        Ok(adata)
603    }
604
605    /// Compute basic QC metrics.
606    pub fn qc_metrics(&self) -> QcMetrics {
607        let n = self.n_obs();
608        let p = self.n_vars();
609        let mut total_counts = vec![0.0; n];
610        let mut n_features = vec![0usize; n];
611
612        for i in 0..n {
613            for j in 0..p {
614                let v = self.x.get(i, j);
615                total_counts[i] += v;
616                if v > 0.0 {
617                    n_features[i] += 1;
618                }
619            }
620        }
621
622        QcMetrics {
623            total_counts,
624            n_features,
625        }
626    }
627}
628
629fn subset_matrix_cols(x: &MatrixData, col_indices: &[usize], n_obs: usize) -> MatrixData {
630    match x {
631        MatrixData::Dense(rows) => {
632            let sub: Vec<Vec<f64>> = rows
633                .iter()
634                .map(|row| col_indices.iter().map(|&j| row[j]).collect())
635                .collect();
636            MatrixData::Dense(sub)
637        }
638        MatrixData::Sparse(s) => {
639            let n_new_cols = col_indices.len();
640            let mut new_s = SparseMatrix::new(n_obs, n_new_cols);
641            // Build reverse map: old_col -> new_col
642            let mut col_map = HashMap::new();
643            for (new_j, &old_j) in col_indices.iter().enumerate() {
644                col_map.insert(old_j, new_j);
645            }
646            for (r, c, v) in s.iter() {
647                if let Some(&new_c) = col_map.get(&c) {
648                    let _ = new_s.insert(r, new_c, v);
649                }
650            }
651            MatrixData::Sparse(new_s)
652        }
653    }
654}
655
656fn subset_matrix_rows(x: &MatrixData, indices: &[usize], n_vars: usize) -> MatrixData {
657    match x {
658        MatrixData::Dense(rows) => {
659            let sub: Vec<Vec<f64>> = indices.iter().map(|&i| rows[i].clone()).collect();
660            MatrixData::Dense(sub)
661        }
662        MatrixData::Sparse(s) => {
663            let n_new = indices.len();
664            let mut new_s = SparseMatrix::new(n_new, n_vars);
665            for (new_row, &old_row) in indices.iter().enumerate() {
666                for j in 0..n_vars {
667                    let v = s.get(old_row, j);
668                    if v != 0.0 {
669                        let _ = new_s.insert(new_row, j, v);
670                    }
671                }
672            }
673            MatrixData::Sparse(new_s)
674        }
675    }
676}
677
678impl Summarizable for AnnData {
679    fn summary(&self) -> String {
680        format!(
681            "AnnData: {} obs \u{00d7} {} vars, {} layers, {} obsm, {} varm, {} obsp, {} uns",
682            self.n_obs(),
683            self.n_vars(),
684            self.layers.len(),
685            self.obsm.len(),
686            self.varm.len(),
687            self.obsp.len(),
688            self.uns.len(),
689        )
690    }
691}
692
#[cfg(test)]
mod tests {
    // Unit tests for `MatrixData` and `AnnData`, mostly built on a 3 × 3 dense fixture.
    use super::*;

    // 3 cells × 3 genes, dense; each row has exactly two positive entries and the
    // row totals are 3, 7 and 11.
    fn sample_adata() -> AnnData {
        let x = MatrixData::Dense(vec![
            vec![1.0, 2.0, 0.0],
            vec![3.0, 0.0, 4.0],
            vec![0.0, 5.0, 6.0],
        ]);
        AnnData::new(
            x,
            vec!["cell_1".into(), "cell_2".into(), "cell_3".into()],
            vec!["gene_a".into(), "gene_b".into(), "gene_c".into()],
        )
        .unwrap()
    }

    #[test]
    fn basic_construction() {
        let adata = sample_adata();
        assert_eq!(adata.n_obs(), 3);
        assert_eq!(adata.n_vars(), 3);
        assert_eq!(adata.shape(), (3, 3));
    }

    #[test]
    fn dimension_mismatch_error() {
        let x = MatrixData::Dense(vec![vec![1.0, 2.0]]);
        let result = AnnData::new(
            x,
            vec!["cell_1".into(), "cell_2".into()], // 2 names, 1 row
            vec!["gene_a".into(), "gene_b".into()],
        );
        assert!(result.is_err());
    }

    #[test]
    fn obs_metadata() {
        let mut adata = sample_adata();
        adata
            .add_obs(
                "cell_type",
                vec!["T-cell".into(), "B-cell".into(), "NK".into()],
            )
            .unwrap();
        let ct = adata.get_obs_strings("cell_type").unwrap();
        assert_eq!(ct[0], "T-cell");
        assert!(adata.get_obs("missing").is_none());
    }

    #[test]
    fn obs_metadata_length_mismatch() {
        let mut adata = sample_adata();
        // One value for three cells must be rejected.
        let result = adata.add_obs("bad", vec!["a".into()]);
        assert!(result.is_err());
    }

    #[test]
    fn var_metadata() {
        let mut adata = sample_adata();
        adata
            .add_var(
                "gene_type",
                vec!["coding".into(), "coding".into(), "lncRNA".into()],
            )
            .unwrap();
        let gt = adata.get_var_strings("gene_type").unwrap();
        assert_eq!(gt[2], "lncRNA");
    }

    #[test]
    fn obsm_embedding() {
        let mut adata = sample_adata();
        let pca = vec![vec![0.1, 0.2], vec![0.3, 0.4], vec![0.5, 0.6]];
        adata.add_obsm("X_pca", pca).unwrap();
        let emb = adata.get_obsm("X_pca").unwrap();
        assert_eq!(emb.len(), 3);
        assert_eq!(emb[0], vec![0.1, 0.2]);
    }

    #[test]
    fn layers() {
        let mut adata = sample_adata();
        let raw = MatrixData::Dense(vec![
            vec![10.0, 20.0, 0.0],
            vec![30.0, 0.0, 40.0],
            vec![0.0, 50.0, 60.0],
        ]);
        adata.add_layer("raw_counts", raw).unwrap();
        let layer = adata.get_layer("raw_counts").unwrap();
        assert_eq!(layer.get(0, 0), 10.0);
    }

    #[test]
    fn layer_shape_mismatch() {
        let mut adata = sample_adata();
        // 1 × 1 layer against a 3 × 3 container.
        let bad = MatrixData::Dense(vec![vec![1.0]]);
        assert!(adata.add_layer("bad", bad).is_err());
    }

    #[test]
    fn subset_obs() {
        let mut adata = sample_adata();
        adata
            .add_obs(
                "label",
                vec!["a".into(), "b".into(), "c".into()],
            )
            .unwrap();
        // Keep the first and last cells; metadata must follow the selection.
        let sub = adata.subset_obs(&[0, 2]).unwrap();
        assert_eq!(sub.n_obs(), 2);
        assert_eq!(sub.n_vars(), 3);
        assert_eq!(sub.obs_names(), &["cell_1", "cell_3"]);
        let labels = sub.get_obs_strings("label").unwrap();
        assert_eq!(labels, &["a", "c"]);
    }

    #[test]
    fn qc_metrics() {
        let adata = sample_adata();
        let qc = adata.qc_metrics();
        assert_eq!(qc.total_counts, vec![3.0, 7.0, 11.0]);
        assert_eq!(qc.n_features, vec![2, 2, 2]);
    }

    #[test]
    fn sparse_x() {
        // Diagonal 2 × 2 sparse matrix; off-diagonal reads must be 0.0.
        let s = SparseMatrix::from_triplets(
            vec![0, 1],
            vec![0, 1],
            vec![5.0, 10.0],
            2,
            2,
        )
        .unwrap();
        let x = MatrixData::Sparse(s);
        let adata = AnnData::new(
            x,
            vec!["c1".into(), "c2".into()],
            vec!["g1".into(), "g2".into()],
        )
        .unwrap();
        assert_eq!(adata.x().get(0, 0), 5.0);
        assert_eq!(adata.x().get(0, 1), 0.0);
    }

    #[test]
    fn summary_format() {
        let adata = sample_adata();
        let s = adata.summary();
        assert!(s.contains("3 obs"));
        assert!(s.contains("3 vars"));
        assert!(s.contains("0 obsp"));
        assert!(s.contains("0 uns"));
    }

    #[test]
    fn matrix_data_set_dense() {
        let mut x = MatrixData::Dense(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        x.set(0, 1, 99.0);
        assert_eq!(x.get(0, 1), 99.0);
        assert_eq!(x.get(0, 0), 1.0);
    }

    #[test]
    fn matrix_data_set_sparse() {
        let s = SparseMatrix::new(2, 2);
        let mut x = MatrixData::Sparse(s);
        x.set(0, 1, 5.0);
        assert_eq!(x.get(0, 1), 5.0);
    }

    #[test]
    fn matrix_data_column_sums_dense() {
        let x = MatrixData::Dense(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ]);
        assert_eq!(x.column_sums(), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn matrix_data_column_means_dense() {
        let x = MatrixData::Dense(vec![
            vec![2.0, 4.0],
            vec![6.0, 8.0],
        ]);
        let means = x.column_means();
        // Tolerance comparison because the means are computed in floating point.
        assert!((means[0] - 4.0).abs() < 1e-10);
        assert!((means[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn matrix_data_row_sums_dense() {
        let x = MatrixData::Dense(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
        ]);
        assert_eq!(x.row_sums(), vec![6.0, 15.0]);
    }

    #[test]
    fn matrix_data_to_flat_row_major_dense() {
        let x = MatrixData::Dense(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        assert_eq!(x.to_flat_row_major(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn matrix_data_to_flat_row_major_sparse() {
        // Anti-diagonal entries land at flat positions 0*2+1 and 1*2+0.
        let s = SparseMatrix::from_triplets(
            vec![0, 1],
            vec![1, 0],
            vec![2.0, 3.0],
            2,
            2,
        )
        .unwrap();
        let x = MatrixData::Sparse(s);
        assert_eq!(x.to_flat_row_major(), vec![0.0, 2.0, 3.0, 0.0]);
    }

    #[test]
    fn x_mut_modify() {
        let mut adata = sample_adata();
        adata.x_mut().set(0, 0, 42.0);
        assert_eq!(adata.x().get(0, 0), 42.0);
    }

    #[test]
    fn set_x_valid() {
        let mut adata = sample_adata();
        let new_x = MatrixData::Dense(vec![
            vec![10.0, 20.0, 30.0],
            vec![40.0, 50.0, 60.0],
            vec![70.0, 80.0, 90.0],
        ]);
        adata.set_x(new_x).unwrap();
        assert_eq!(adata.x().get(0, 0), 10.0);
    }

    #[test]
    fn set_x_shape_mismatch() {
        let mut adata = sample_adata();
        let bad = MatrixData::Dense(vec![vec![1.0]]);
        assert!(adata.set_x(bad).is_err());
    }

    #[test]
    fn subset_vars_basic() {
        let mut adata = sample_adata();
        adata.add_var("type", vec!["a".into(), "b".into(), "c".into()]).unwrap();
        // Keep gene_a and gene_c; columns of X and var metadata must follow.
        let sub = adata.subset_vars(&[0, 2]).unwrap();
        assert_eq!(sub.n_vars(), 2);
        assert_eq!(sub.n_obs(), 3);
        assert_eq!(sub.var_names(), &["gene_a", "gene_c"]);
        assert_eq!(sub.x().get(0, 0), 1.0);
        assert_eq!(sub.x().get(0, 1), 0.0); // was gene_c at col 2
        let types = sub.get_var_strings("type").unwrap();
        assert_eq!(types, &["a", "c"]);
    }

    #[test]
    fn subset_vars_out_of_bounds() {
        let adata = sample_adata();
        assert!(adata.subset_vars(&[0, 10]).is_err());
    }

    #[test]
    fn obsp_add_get() {
        let mut adata = sample_adata();
        let mut m = SparseMatrix::new(3, 3);
        m.insert(0, 1, 0.5).unwrap();
        m.insert(1, 2, 0.3).unwrap();
        adata.add_obsp("connectivities", m).unwrap();
        let conn = adata.get_obsp("connectivities").unwrap();
        assert_eq!(conn.get(0, 1), 0.5);
        assert!(adata.get_obsp("missing").is_none());
    }

    #[test]
    fn obsp_wrong_shape() {
        let mut adata = sample_adata();
        let m = SparseMatrix::new(2, 2); // 3x3 needed
        assert!(adata.add_obsp("bad", m).is_err());
    }

    #[test]
    fn uns_add_get() {
        let mut adata = sample_adata();
        adata.add_uns("method", "leiden".into());
        assert_eq!(adata.get_uns("method"), Some("leiden"));
        assert_eq!(adata.get_uns("missing"), None);
    }

    #[test]
    fn get_layer_mut() {
        let mut adata = sample_adata();
        let raw = MatrixData::Dense(vec![
            vec![10.0, 20.0, 0.0],
            vec![30.0, 0.0, 40.0],
            vec![0.0, 50.0, 60.0],
        ]);
        adata.add_layer("counts", raw).unwrap();
        // Edits through the mutable handle must be visible via `get_layer`.
        let layer = adata.get_layer_mut("counts").unwrap();
        layer.set(0, 0, 99.0);
        assert_eq!(adata.get_layer("counts").unwrap().get(0, 0), 99.0);
    }
}