af_anndata/
lib.rs

1use anndata::data::array::dataframe::DataFrameIndex;
2use anndata::{reader::MMReader, s, AnnData, AnnDataOp, ArrayData, ArrayElemOp};
3use anndata_hdf5::H5;
4use anyhow::{bail, Context};
5use polars::io::prelude::*;
6use polars::prelude::{CsvReadOptions, DataFrame, PolarsError, Series, SortMultipleOptions};
7use serde_json::Value;
8use std::path::{Path, PathBuf};
9use tracing::{error, info, trace, warn};
10
/// Tags the element type of the matrix that has
/// been populated in a `CSRMatPack`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PopulatedMatType {
    I32,
    I64,
    F32,
    F64,
}

20/// Holds the different representable types of
21/// matrices. It is expected that only
22/// one of these should be populated.
23struct CSRMatPack {
24    pub mat_i32: nalgebra_sparse::CsrMatrix<i32>,
25    pub mat_i64: nalgebra_sparse::CsrMatrix<i64>,
26    pub mat_f32: nalgebra_sparse::CsrMatrix<f32>,
27    pub mat_f64: nalgebra_sparse::CsrMatrix<f64>,
28}
29
30impl CSRMatPack {
31    /// Create a new CSRMatPack with empty matrices for all supported types
32    pub fn new(nr: usize, ngenes: usize) -> Self {
33        Self {
34            mat_i32: nalgebra_sparse::CsrMatrix::<i32>::zeros(nr, ngenes),
35            mat_i64: nalgebra_sparse::CsrMatrix::<i64>::zeros(nr, ngenes),
36            mat_f32: nalgebra_sparse::CsrMatrix::<f32>::zeros(nr, ngenes),
37            mat_f64: nalgebra_sparse::CsrMatrix::<f64>::zeros(nr, ngenes),
38        }
39    }
40
41    pub fn ncells(&self) -> usize {
42        self.mat_f64.nrows()
43    }
44
45    pub fn ngenes(&self) -> usize {
46        self.mat_f64.ncols()
47    }
48
49    /// Returns a `PopulatedMatType` enum specifying the type of matrix in this
50    /// pack that is populated.  If more than one type is populated, it returns
51    /// an error.
52    pub fn populated_type(&self) -> anyhow::Result<Option<PopulatedMatType>> {
53        let mut t: Option<PopulatedMatType> = None;
54        let mut nset = 0_usize;
55        if self.mat_i32.nnz() > 0 {
56            nset += 1;
57            t = Some(PopulatedMatType::I32);
58        } else if self.mat_i64.nnz() > 0 {
59            nset += 1;
60            t = Some(PopulatedMatType::I64);
61        } else if self.mat_f32.nnz() > 0 {
62            nset += 1;
63            t = Some(PopulatedMatType::F32);
64        } else if self.mat_f64.nnz() > 0 {
65            nset += 1;
66            t = Some(PopulatedMatType::F64);
67        }
68        if nset > 1 {
69            bail!("The CSRMatPack has > 1 set matrix type. This should not happen");
70        }
71        if t.is_some() {
72            Ok(t)
73        } else {
74            Ok(None)
75        }
76    }
77}
78
79/// accumulates the content of `slice` into the apropriately typed member of `csr_accum` and
80/// returns `OK(CSRMatPack)` with the accumulated result in the appropriately typed member upon
81/// success. If an unsupported matrix type is provided in the `slice`, an error is raised.
82fn accumulate_layer(slice: &ArrayData, mut csr_accum: CSRMatPack) -> anyhow::Result<CSRMatPack> {
83    match slice {
84        anndata::data::array::ArrayData::CsrMatrix(a) => {
85            trace!("confirmed ArrayData slice is in CSR format");
86            match a {
87                anndata::data::array::DynCsrMatrix::I8(_l) => {
88                    bail!("I8 matrix type is not supported")
89                }
90                anndata::data::array::DynCsrMatrix::I16(_l) => {
91                    bail!("I16 matrix type is not supported")
92                }
93                anndata::data::array::DynCsrMatrix::I32(l) => {
94                    csr_accum.mat_i32 = csr_accum.mat_i32 + l;
95                }
96                anndata::data::array::DynCsrMatrix::I64(l) => {
97                    csr_accum.mat_i64 = csr_accum.mat_i64 + l;
98                }
99                anndata::data::array::DynCsrMatrix::U8(_l) => {
100                    bail!("U8 matrix type is not supported")
101                }
102                anndata::data::array::DynCsrMatrix::U16(_l) => {
103                    bail!("U16 matrix type is not supported")
104                }
105                anndata::data::array::DynCsrMatrix::U32(_l) => {
106                    bail!(
107                        "Addition is not supported for U32 CSR matrices because they do not satisfy NEG<Output=T>; storing result in an i64"
108                    );
109                }
110                anndata::data::array::DynCsrMatrix::U64(_l) => {
111                    bail!(
112                        "Addition is not supported for U32 CSR matrices because they do not satisfy NEG<Output=T>; storing result in an i64"
113                    );
114                }
115                anndata::data::array::DynCsrMatrix::F32(l) => {
116                    csr_accum.mat_f32 = csr_accum.mat_f32 + l;
117                }
118                anndata::data::array::DynCsrMatrix::F64(l) => {
119                    csr_accum.mat_f64 = csr_accum.mat_f64 + l;
120                }
121                anndata::data::array::DynCsrMatrix::Bool(_l) => {
122                    bail!("Bool matrix type is not supported")
123                }
124                anndata::data::array::DynCsrMatrix::String(_l) => {
125                    bail!("String matrix type is not supported")
126                }
127            };
128        }
129        _ => warn!("expected underlying CSR matrix; cannot populate X layer with sum counts!"),
130    }
131    Ok(csr_accum)
132}
133
134/// Sets the X layer of the `AnnData` object, as well as the `n_obs` and `n_vars` values.
135/// Any existing X layer will be deleted.  The contents of the new `X` layer will depend on
136/// which matrix type in the `csr_in` `CSRMatPack` is populated.
137fn set_x_layer<B: anndata::Backend>(b: &mut AnnData<B>, csr_in: CSRMatPack) -> anyhow::Result<()> {
138    // get rid of the old X
139    b.del_x().context("unable to delete X")?;
140    b.del_obs()?;
141    b.set_n_obs(csr_in.ncells())
142        .context("unable to set n_obs")?;
143    //b.del_var()?;
144    b.set_n_vars(csr_in.ngenes())
145        .context("unable to set n_vars")?;
146    // set the new X
147    match csr_in
148        .populated_type()
149        .context("error getting populated type of the X matrix")?
150    {
151        Some(PopulatedMatType::I32) => {
152            let csr_mat = anndata::data::array::DynCsrMatrix::I32(csr_in.mat_i32);
153            b.set_x(csr_mat).context("unable to set all 0s X")?;
154        }
155        Some(PopulatedMatType::I64) => {
156            let csr_mat = anndata::data::array::DynCsrMatrix::I64(csr_in.mat_i64);
157            b.set_x(csr_mat).context("unable to set all 0s X")?;
158        }
159        Some(PopulatedMatType::F32) => {
160            let csr_mat = anndata::data::array::DynCsrMatrix::F32(csr_in.mat_f32);
161            b.set_x(csr_mat).context("unable to set all 0s X")?;
162        }
163        Some(PopulatedMatType::F64) => {
164            let csr_mat = anndata::data::array::DynCsrMatrix::F64(csr_in.mat_f64);
165            b.set_x(csr_mat).context("unable to set all 0s X")?;
166        }
167        None => {
168            warn!(
169                "None of the underlying matrices for the layers had counts; setting the output to the trivial empty matrix (of type f64)"
170            );
171            let csr_mat = anndata::data::array::DynCsrMatrix::F64(
172                nalgebra_sparse::CsrMatrix::<f64>::zeros(csr_in.ncells(), csr_in.ngenes()),
173            );
174            b.set_x(csr_mat).context("unable to set all 0s X")?;
175        }
176    }
177    Ok(())
178}
179
180fn separate_usa_layers<B: anndata::Backend>(
181    b: &mut AnnData<B>,
182    row_df: &DataFrame,
183    col_df: &DataFrame,
184    var_df: Option<DataFrame>,
185) -> anyhow::Result<()> {
186    let mut sw = libsw::Sw::new();
187    sw.start()?;
188
189    let nr = b.n_obs();
190    let nc = b.n_vars();
191    // if USA mode then the number of genes is
192    // 1/3 of the number of features
193    let ngenes = nc / 3;
194
195    let mut csr_zero = CSRMatPack::new(nr, ngenes);
196
197    // Get the unspliced, spliced and ambiguous slices
198    let vars = col_df;
199
200    let slice1: ArrayData = b.get_x().slice(s![.., 0..ngenes])?.unwrap();
201    csr_zero = accumulate_layer(&slice1, csr_zero)?;
202
203    let var1 = vars.slice(0_i64, ngenes);
204    info!("getting slice took {:#?}", sw.elapsed());
205    sw.reset();
206    sw.start()?;
207
208    let slice2: ArrayData = b.get_x().slice(s![.., ngenes..2 * ngenes])?.unwrap();
209    csr_zero = accumulate_layer(&slice2, csr_zero)?;
210    //let var2 = vars.slice(ngenes as i64, ngenes);
211    info!("getting slice took {:#?}", sw.elapsed());
212    sw.reset();
213    sw.start()?;
214
215    let slice3: ArrayData = b.get_x().slice(s![.., 2 * ngenes..3 * ngenes])?.unwrap();
216    csr_zero = accumulate_layer(&slice3, csr_zero)?;
217    //let var3 = vars.slice(2_i64 * ngenes as i64, ngenes);
218    info!("getting slice took {:#?}", sw.elapsed());
219    sw.reset();
220    sw.start()?;
221
222    set_x_layer(b, csr_zero)?;
223
224    // populate with the gene id and gene symbol if we have it
225    // otherwise just set the gene name
226    if let Some(var_info) = var_df {
227        b.set_var(var_info)?;
228    } else {
229        let mut temp_var = var1.clone();
230        temp_var.set_column_names(["gene_id"])?;
231        b.set_var(temp_var)?;
232    }
233
234    let layers = vec![
235        ("spliced".to_owned(), slice1),
236        ("unspliced".to_owned(), slice2),
237        ("ambiguous".to_owned(), slice3),
238    ];
239    b.set_layers(layers)
240        .context("unable to set layers for AnnData object")?;
241    info!("setting layers took {:#?}", sw.elapsed());
242    b.set_obs(row_df.clone())?;
243
244    Ok(())
245}
246
247pub fn convert_csr_to_anndata<P: AsRef<Path>>(root_path: P, output_path: P) -> anyhow::Result<()> {
248    let root_path = root_path.as_ref();
249    let json_path = PathBuf::from(&root_path);
250
251    let mut gpl_path = json_path.clone();
252    gpl_path.push("generate_permit_list.json");
253
254    let mut collate_path = json_path.clone();
255    collate_path.push("collate.json");
256
257    let mut quant_path = json_path.clone();
258    quant_path.push("quant.json");
259
260    let mut map_log_path = json_path.clone();
261    map_log_path.push("simpleaf_map_info.json");
262
263    let alevin_path = root_path.join("alevin");
264    let mut p = PathBuf::from(&alevin_path);
265    p.push("quants_mat.mtx");
266
267    let mut colpath = PathBuf::from(&alevin_path);
268    colpath.push("quants_mat_cols.txt");
269
270    let mut rowpath = PathBuf::from(&alevin_path);
271    rowpath.push("quants_mat_rows.txt");
272
273    let mut gene_id_to_name_path = PathBuf::from(&root_path);
274    gene_id_to_name_path.push("gene_id_to_name.tsv");
275
276    if !p.is_file() {
277        anyhow::bail!(
278            "the count file was expected at {} but could not be found",
279            p.display()
280        );
281    }
282    if !colpath.is_file() {
283        anyhow::bail!(
284            "the column annotation file was expected at {} but could not be found",
285            colpath.display()
286        );
287    }
288    if !rowpath.is_file() {
289        anyhow::bail!(
290            "the row annotation file was expected at {} but could not be found",
291            rowpath.display()
292        );
293    }
294    if !gpl_path.is_file() {
295        anyhow::bail!(
296            "the generate_permit_list json file was expected at {} but could not be found",
297            gpl_path.display()
298        );
299    }
300    if !collate_path.is_file() {
301        anyhow::bail!(
302            "the collate json file was expected at {} but could not be found",
303            collate_path.display()
304        );
305    }
306    if !quant_path.is_file() {
307        anyhow::bail!(
308            "the quant json file was expected at {} but could not be found",
309            quant_path.display()
310        );
311    }
312
313    // see if we have a valid gene id to name file
314    let gene_id_to_name_path = gene_id_to_name_path
315        .is_file()
316        .then_some(gene_id_to_name_path);
317    // otherwise, wan the user
318    if gene_id_to_name_path.is_none() {
319        warn!(
320            "Could not find the `gene_id_to_name` file, so only gene IDs and not symbols will be present in `var`"
321        );
322    }
323
324    // read in the relevant JSON files
325    let qf = std::fs::File::open(&quant_path)?;
326    let quant_json: Value = serde_json::from_reader(qf)
327        .with_context(|| format!("could not parse {} as valid JSON.", quant_path.display()))?;
328
329    let cf = std::fs::File::open(&collate_path)?;
330    let collate_json: Value = serde_json::from_reader(cf)
331        .with_context(|| format!("could not parse {} as valid JSON.", collate_path.display()))?;
332
333    let gplf = std::fs::File::open(&gpl_path)?;
334    let gpl_json: Value = serde_json::from_reader(gplf)
335        .with_context(|| format!("could not parse {} as valid JSON.", gpl_path.display()))?;
336
337    let map_json: Value = if let Ok(mapf) = std::fs::File::open(&map_log_path) {
338        serde_json::from_reader(mapf)
339            .with_context(|| format!("could not parse {} as valid JSON.", gpl_path.display()))?
340    } else {
341        warn!("Could not find a simpleaf_map_info.json in the provided directory; please upgrade to the latest version of simpleaf when possible!");
342        serde_json::json!({
343            "mapper" : "file_not_found",
344            "num_mapped": 0,
345            "num_poisoned": 0,
346            "num_reads": 0,
347            "percent_mapped": 0.
348        })
349    };
350
351    let usa_mode = if let Some(Value::Bool(v)) = quant_json.get("usa_mode") {
352        *v
353    } else {
354        false
355    };
356
357    info!("USA mode : {}", usa_mode);
358
359    let mut sw = libsw::Sw::new();
360    sw.start()?;
361
362    let r = MMReader::from_path(&p)?;
363
364    let parse_opts = CsvParseOptions::default().with_separator(b'\t');
365    // read the gene ids
366    let mut col_df = match CsvReadOptions::default()
367        .with_has_header(false)
368        .with_parse_options(parse_opts.clone())
369        .with_raise_if_empty(true)
370        .try_into_reader_with_file_path(Some(colpath))?
371        .finish()
372    {
373        Ok(dframe) => dframe,
374        Err(PolarsError::NoData(estr)) => {
375            error!("error reading column labels : {:?};", estr);
376            bail!("failed to construct the column data frame.");
377        }
378        Err(err) => {
379            bail!(err);
380        }
381    };
382    col_df.set_column_names(["gene_id"])?;
383
384    // read the barcodes
385    let mut row_df = match CsvReadOptions::default()
386        .with_has_header(false)
387        .with_parse_options(parse_opts.clone())
388        .with_raise_if_empty(true)
389        .try_into_reader_with_file_path(Some(rowpath))?
390        .finish()
391    {
392        Ok(dframe) => dframe,
393        Err(PolarsError::NoData(estr)) => {
394            error!("error reading row labels : {:?};", estr);
395            error!(
396                "this likely indicates the row labels (barcode list) was empty --- please ensure the barcode list is properly matched to the chemistry being processed"
397            );
398            bail!("failed to construct the row data frame.");
399        }
400        Err(err) => {
401            bail!(err);
402        }
403    };
404
405    let nobs_cols = row_df.get_columns().len();
406    match nobs_cols {
407        1 => row_df.set_column_names(["barcodes"])?,
408        3 => row_df.set_column_names(["barcodes", "spot_x", "spot_y"])?,
409        x => {
410            error!(
411                "quants_mat_rows.txt file should have 1 (sc/sn-RNA) or 3 columns (spatial); the provided file has {}",
412                x
413            );
414            bail!(
415                "quants_mat_rows.txt file should have 1 (sc/sn-RNA) or 3 columns (spatial); the provided file has {}",
416                x
417            );
418        }
419    }
420
421    // read the gene_id_to_name file
422    let var_df = if let Some(id_to_name) = gene_id_to_name_path {
423        // if we had the gene id to name file name, then we want to read it in, but
424        // we also want to re-order the rows to match with the row-labels given in
425        // the input file.
426        let ordered_ids = col_df.column("gene_id")?.as_materialized_series().str()?;
427        // read through the gene names in order, and build a hash map from gene_id to
428        // rank (corresponding column) of the count matrix
429        let gene_rank_hash = std::collections::HashMap::<&str, u64>::from_iter(
430            ordered_ids.iter().enumerate().map(|(i, s)| {
431                (
432                    s.expect("should not be missing a gene id in the `quants_mat_cols` file"),
433                    i as u64,
434                )
435            }),
436        );
437
438        // read the gene id to name file
439        let mut vd = CsvReadOptions::default()
440            .with_has_header(false)
441            .with_parse_options(parse_opts)
442            .with_raise_if_empty(true)
443            .try_into_reader_with_file_path(Some(id_to_name))?
444            .finish()?;
445
446        // create a column that lists the rank for each gene
447        // *in the order it appears in the gene id to name list*.
448        // we will use this to re-order the rows of the gene id to name
449        // dataframe.
450        let rank_vec: Series = Series::from_iter(
451            vd.select_at_idx(0)
452                .expect("0th column of gene id to gene name DataFrame should exist")
453                .as_materialized_series()
454                .str()?
455                .iter()
456                .map(|s| {
457                    gene_rank_hash
458                        [s.expect("should not be a missing gene id in the `gene_id_to_name` file")]
459                }),
460        );
461        // add the rank column to the Data Frame
462        vd.with_column(rank_vec)?;
463        // set the column names
464        vd.set_column_names(["gene_id", "gene_symbol", "gene_rank"])?;
465        // reorder the rows to put the ranks in order, bringing the ids and names
466        // with them.
467        vd.sort_in_place(["gene_rank"], SortMultipleOptions::default())?;
468        // we no longer need the rank column
469        vd.drop_in_place("gene_rank")?;
470        Some(vd)
471    } else {
472        None
473    };
474
475    // make the AnnData object and populate it from the MMReader
476    let mut b = AnnData::<H5>::new(output_path.as_ref())?;
477    r.finish(&b)?;
478    info!("Reading MM into AnnData took {:#?}", sw.elapsed());
479
480    // read in the feature dump data
481    let mut feat_dump_path = PathBuf::from(&root_path);
482    feat_dump_path.push("featureDump.txt");
483    let feat_parse_options =
484        polars::io::csv::read::CsvParseOptions::default().with_separator(b'\t');
485    let mut feat_dump_frame = match polars_io::csv::read::CsvReadOptions::default()
486        .with_parse_options(feat_parse_options)
487        .with_has_header(true)
488        .with_raise_if_empty(true)
489        .try_into_reader_with_file_path(Some(feat_dump_path.clone()))
490        .context("could not create TSV file reader")?
491        .finish()
492    {
493        Ok(dframe) => dframe,
494        Err(PolarsError::NoData(estr)) => {
495            error!(
496                "error reading the feature file ({}): {:?};",
497                feat_dump_path.display(),
498                estr
499            );
500            error!(
501                "this likely indicates no barcodes were processed and written to the output --- please ensure the barcode list is properly matched to the chemistry being processed"
502            );
503            bail!("failed to construct the feature data frame.");
504        }
505        Err(err) => {
506            bail!(err);
507        }
508    };
509    // add the features to the row df
510    // skip the first column since it is `CB` (the cell barcode) and is
511    // redundant with the cell barcode we already have in this dataframe
512    // CB      CorrectedReads  MappedReads     DeduplicatedReads       MappingRate     DedupRate       MeanByMax       NumGenesExpressed       NumGenesOverMean
513    let col_rename = vec![
514        ("CorrectedReads", "corrected_reads"),
515        ("MappedReads", "mapped_reads"),
516        ("DeduplicatedReads", "deduplicated_reads"),
517        ("MappingRate", "mapping_rate"),
518        ("DedupRate", "dedup_rate"),
519        ("MeanByMax", "mean_by_max"),
520        ("NumGenesExpressed", "num_genes_expressed"),
521        ("NumGenesOverMean", "num_genes_over_mean"),
522    ];
523    for (old_name, new_name) in col_rename {
524        feat_dump_frame.rename(old_name, new_name.into())?;
525    }
526    let row_df = row_df.hstack(&feat_dump_frame.take_columns()[1..])?;
527
528    // read in the quant JSON file
529    let gpl_json_str = serde_json::to_string(&gpl_json).context(
530        "could not convert generate_permit_list.json to string succesfully to place in uns data.",
531    )?;
532    let collate_json_str = serde_json::to_string(&collate_json)
533        .context("could not convert collate.json to string succesfully to place in uns data.")?;
534    let quant_json_str = serde_json::to_string(&quant_json)
535        .context("could not convert quant.json to string succesfully to place in uns data.")?;
536    let map_log_json_str = serde_json::to_string(&map_json).context(
537        "could not convert simpleaf_map_info.json to string succesfully to place in uns data.",
538    )?;
539
540    // set unstructured metadata
541    let uns: Vec<(String, anndata::Data)> = vec![
542        ("gpl_info".to_owned(), anndata::Data::from(gpl_json_str)),
543        (
544            "collate_info".to_owned(),
545            anndata::Data::from(collate_json_str),
546        ),
547        ("quant_info".to_owned(), anndata::Data::from(quant_json_str)),
548        (
549            "simpleaf_map_info".to_owned(),
550            anndata::Data::from(map_log_json_str),
551        ),
552    ];
553    b.set_uns(uns).context("failed to set \"uns\" data")?;
554
555    let ngenes;
556    if usa_mode {
557        ngenes = b.n_vars() / 3;
558        separate_usa_layers(&mut b, &row_df, &col_df, var_df)?;
559    } else {
560        ngenes = b.n_vars();
561        if let Some(var_info) = var_df {
562            b.set_var(var_info)?;
563        } else {
564            b.set_var(col_df.clone())?;
565        }
566        b.set_obs(row_df.clone())?;
567    }
568
569    let gene_ids: Vec<String> = col_df
570        .column("gene_id")?
571        .str()
572        .unwrap()
573        .iter()
574        .flatten()
575        .take(ngenes)
576        .map(|s| s.to_string())
577        .collect();
578    let var_index = DataFrameIndex::from(gene_ids);
579    b.set_var_names(var_index)?;
580
581    let barcodes: Vec<String> = row_df
582        .column("barcodes")?
583        .str()
584        .unwrap()
585        .iter()
586        .flatten()
587        .map(|s| s.to_string())
588        .collect();
589    let obs_index = DataFrameIndex::from(barcodes);
590    b.set_obs_names(obs_index)?;
591    Ok(())
592}