use std::{
collections::HashMap,
fmt,
fs::File,
io::{BufRead, BufReader, Write}, path::Path,
};
use sprs::{
io::{read_matrix_market, write_matrix_market}, TriMat
};
/// A sparse integer count matrix together with string labels for its two axes.
///
/// Entries at `(row, col)` are labelled with `cbs[row]` and `genes[col]`
/// (see `to_map`). Nothing in this type enforces that the label vectors have
/// the same lengths as the matrix dimensions.
#[derive(Debug)]
pub struct CountMatrix {
/// The sparse count data.
pub matrix: sprs::CsMat<i32>,
// Row labels; read from / written to `gene.barcodes.txt`, one per line.
cbs: Vec<String>,
// Column labels; read from / written to `gene.genes.txt`, one per line.
genes: Vec<String>,
}
impl CountMatrix {
    /// Bundles a sparse matrix with its row labels (`cbs`) and column labels (`genes`).
    ///
    /// The label lengths are not checked against the matrix shape.
    pub fn new(matrix: sprs::CsMat<i32>, cbs: Vec<String>, genes: Vec<String>) -> CountMatrix {
        CountMatrix { matrix, cbs, genes }
    }

    /// Returns every stored entry keyed by its `(row label, column label)` pair.
    ///
    /// Only entries explicitly stored in the sparse matrix are included.
    ///
    /// # Panics
    ///
    /// Panics if a stored entry has a row index `>= cbs.len()` or a column
    /// index `>= genes.len()`.
    pub fn to_map(&self) -> HashMap<(String, String), i32> {
        self.matrix
            .iter()
            .map(|(value, (i, j))| ((self.cbs[i].clone(), self.genes[j].clone()), *value))
            .collect()
    }

    /// Returns the `(rows, columns)` shape of the underlying sparse matrix.
    pub fn get_shape(&self) -> (usize, usize) {
        self.matrix.shape()
    }

    /// Loads a count matrix from a Matrix Market file plus two label files.
    ///
    /// `cbfile` supplies the row labels and `genefile` the column labels, one
    /// label per line. Values are parsed as `f32` and rounded to the nearest
    /// `i32`. Progress messages are printed to stdout.
    ///
    /// # Panics
    ///
    /// Panics if the matrix cannot be loaded or if either label file cannot be
    /// opened or read.
    pub fn from_disk(mtx_file: &Path, cbfile: &Path, genefile: &Path) -> Self {
        let mat: TriMat<f32> = read_matrix_market(mtx_file)
            .unwrap_or_else(|e| panic!("cant load {:?}: {:?}", mtx_file, e));

        println!("Convertting f32 -> i32");
        let intdata: Vec<i32> = mat.data().iter().map(|x| x.round() as i32).collect();
        let intmat: TriMat<i32> = TriMat::from_triplets(
            mat.shape(),
            mat.row_inds().to_vec(),
            mat.col_inds().to_vec(),
            intdata,
        );
        println!("Done Convertting f32 -> i32");
        let matrix: sprs::CsMat<i32> = intmat.to_csr();

        let cbs = Self::read_lines(cbfile);
        let genes = Self::read_lines(genefile);
        CountMatrix { matrix, cbs, genes }
    }

    /// Loads a count matrix from `foldername`, using the fixed file names
    /// `gene.mtx`, `gene.barcodes.txt` and `gene.genes.txt`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`CountMatrix::from_disk`].
    pub fn from_folder(foldername: &Path) -> Self {
        let mfile = foldername.join("gene.mtx");
        let cbfile = foldername.join("gene.barcodes.txt");
        let genefile = foldername.join("gene.genes.txt");
        CountMatrix::from_disk(&mfile, &cbfile, &genefile)
    }

    /// Writes the matrix and both label lists into `foldername`, using the
    /// same file names that [`CountMatrix::from_folder`] reads.
    ///
    /// The values are written as `f32`. This method does not create
    /// `foldername`. Progress messages are printed to stdout.
    ///
    /// # Panics
    ///
    /// Panics if any of the three files cannot be created or written.
    pub fn write(&self, foldername: &Path) {
        let mfile = foldername.join("gene.mtx");
        let cbfile = foldername.join("gene.barcodes.txt");
        let genefile = foldername.join("gene.genes.txt");

        println!("Convertting i32 -> f32");
        // The number of stored entries is known up front, so allocate once.
        let nnz = self.matrix.nnz();
        let mut floatdata: Vec<f32> = Vec::with_capacity(nnz);
        let mut rows: Vec<usize> = Vec::with_capacity(nnz);
        let mut cols: Vec<usize> = Vec::with_capacity(nnz);
        for (&v, (r, c)) in self.matrix.iter() {
            floatdata.push(v as f32);
            rows.push(r);
            cols.push(c);
        }
        let fmat: TriMat<f32> = TriMat::from_triplets(self.matrix.shape(), rows, cols, floatdata);
        println!("Done Convertting i32 -> f32");

        write_matrix_market(&mfile, &fmat)
            .unwrap_or_else(|e| panic!("cant write {:?}: {:?}", mfile, e));
        Self::write_lines(&cbfile, &self.cbs);
        Self::write_lines(&genefile, &self.genes);
    }

    /// Reads all lines of `path` into a vector, panicking with the underlying
    /// I/O error if the file cannot be opened or read.
    fn read_lines(path: &Path) -> Vec<String> {
        let fh = File::open(path).unwrap_or_else(|e| panic!("cant open {:?}: {:?}", path, e));
        BufReader::new(fh)
            .lines()
            .collect::<Result<_, _>>()
            .unwrap_or_else(|e| panic!("cant read {:?}: {:?}", path, e))
    }

    /// Writes `lines` to `path`, one per line, panicking with the underlying
    /// I/O error if the file cannot be created or written.
    fn write_lines(path: &Path, lines: &[String]) {
        let fh = File::create(path).unwrap_or_else(|e| panic!("cant create {:?}: {:?}", path, e));
        let mut writer = std::io::BufWriter::new(fh);
        for line in lines {
            writeln!(writer, "{}", line)
                .unwrap_or_else(|e| panic!("cant write {:?}: {:?}", path, e));
        }
        // Dropping a BufWriter ignores flush errors, so flush explicitly.
        writer
            .flush()
            .unwrap_or_else(|e| panic!("cant write {:?}: {:?}", path, e));
    }
}
impl PartialEq for CountMatrix {
    /// Two count matrices compare equal when their `to_map` views are equal,
    /// i.e. when they hold the same values under the same
    /// `(row label, column label)` keys.
    fn eq(&self, other: &Self) -> bool {
        self.to_map() == other.to_map()
    }
}
// Marker impl: `PartialEq` compares `HashMap<(String, String), i32>` values,
// whose equality is a full equivalence relation.
impl Eq for CountMatrix {}
impl fmt::Display for CountMatrix {
    /// Formats a one-line summary: the matrix shape followed by the number of
    /// stored entries.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shape = self.get_shape();
        let stored = self.matrix.nnz();
        write!(f, "Shape: {:?}; nnz {}", shape, stored)
    }
}
#[cfg(test)]
mod test {
    use super::CountMatrix;
    use crate::count2::countmap_to_matrix;
    use bustools_core::{
        consistent_genes::{GeneId, Genename, CB},
        hashmap,
    };
    use ndarray::arr2;
    use std::collections::HashMap;
    use tempfile::tempdir;

    /// Wraps plain labels into the `Genename` vector passed to `countmap_to_matrix`.
    fn gene_names(labels: &[&str]) -> Vec<Genename> {
        labels
            .iter()
            .map(|label| Genename(label.to_string()))
            .collect()
    }

    /// Builds the two-cell, two-gene matrix shared by the tests below.
    fn two_by_two() -> CountMatrix {
        let counts: HashMap<(CB, GeneId), usize> = hashmap!(
            (CB(0), GeneId(0)) => 10,
            (CB(0), GeneId(1)) => 1,
            (CB(1), GeneId(0)) => 0,
            (CB(1), GeneId(1)) => 5
        );
        countmap_to_matrix(&counts, gene_names(&["geneA", "geneB"]))
    }

    /// The dense view and the barcode labels match the input count map.
    #[test]
    fn test_countmatrix() {
        let cmat = two_by_two();

        let expected = arr2(&[[10_i32, 1], [0, 5]]);
        assert_eq!(cmat.matrix.to_dense(), expected);

        let expected_cbs = vec![
            "AAAAAAAAAAAAAAAA".to_string(),
            "AAAAAAAAAAAAAAAC".to_string(),
        ];
        assert_eq!(cmat.cbs, expected_cbs);
    }

    /// Writing a matrix to disk and loading it again yields an equal matrix.
    #[test]
    fn test_read_write() {
        let original = two_by_two();

        let dir = tempdir().unwrap();
        let folder = dir.path().join("bustools_test_read_write");
        if !folder.exists() {
            std::fs::create_dir(&folder).unwrap();
        }
        original.write(&folder);

        let mtx = folder.join("gene.mtx");
        let barcodes = folder.join("gene.barcodes.txt");
        let genes = folder.join("gene.genes.txt");
        let reloaded = CountMatrix::from_disk(&mtx, &barcodes, &genes);

        assert!(original == reloaded);
    }

    /// Equality depends on the labels, not on the column order of the genes.
    #[test]
    fn test_countmatrix_equal() {
        let cmat1 = two_by_two();

        // Same counts per (cell, gene name), but with the gene ids swapped.
        let swapped: HashMap<(CB, GeneId), usize> = hashmap!(
            (CB(0), GeneId(1)) => 10,
            (CB(0), GeneId(0)) => 1,
            (CB(1), GeneId(1)) => 0,
            (CB(1), GeneId(0)) => 5
        );
        let cmat2 = countmap_to_matrix(&swapped, gene_names(&["geneB", "geneA"]));

        println!("{:?}", cmat1.to_map());
        println!("{:?}", cmat2.to_map());
        assert!(cmat1 == cmat2);
    }
}