pub(crate) mod processing;
pub(crate) mod statistics;
pub(crate) mod utils;
pub use processing::get_select_info_obs;
pub use processing::get_select_info_vars;
pub use processing::FlavorType;
pub use processing::HVGParams;
use std::collections::HashMap;
use anndata::backend::ScalarType;
use anndata::data::DynCsrMatrix;
use anndata::data::{DynArray, DynCscMatrix, SelectInfoElem};
use anndata::{data::Shape, ArrayData, HasShape};
use anyhow::{anyhow, bail};
use nalgebra_sparse::{CscMatrix, CsrMatrix};
use ndarray::{Array2, ArrayD, Ix2};
use num_traits::{NumCast, Zero};
use single_utilities::traits::NumericOps;
use utils::select_info_elem_to_indices;
/// Strategy used to pick a subset of features.
///
/// NOTE(review): this file only declares the variants; the meaning of each
/// payload below is inferred from its name and should be confirmed against
/// the code that consumes this enum.
pub enum FeatureSelection {
// Presumably the name of an existing column that flags highly variable features — TODO confirm.
HighlyVariableCol(String),
// Presumably the number of highly variable features to keep — TODO confirm.
HighlyVariable(usize),
// Presumably the number of features to draw at random — TODO confirm.
Randomized(usize),
// Presumably a minimum variance a feature must reach to be kept — TODO confirm.
VarianceThreshold(f64),
// No feature selection requested.
None,
}
/// Selects how a computation is carried out.
///
/// `Clone` is derived; the derived implementation copies the `usize` payload
/// of [`ComputationMode::Chunked`] exactly like the previous hand-written impl.
///
/// NOTE(review): the variant descriptions are inferred from the names only;
/// confirm against the code that consumes this enum.
#[derive(Clone)]
pub enum ComputationMode {
    /// Work in chunks; the payload is presumably the chunk size — TODO confirm.
    Chunked(usize),
    /// Presumably processes the whole data set at once — TODO confirm.
    Whole,
}
/// An optional `f32` setting that is tagged as either absolute or relative.
///
/// `Clone` is derived; the derived implementation copies the `f32` payload
/// exactly like the previous hand-written impl.
///
/// NOTE(review): how the absolute/relative payloads are interpreted is not
/// visible in this file; confirm against the consumers of this enum.
#[derive(Clone)]
pub enum FlexValue {
    /// A value tagged as absolute.
    Absolute(f32),
    /// A value tagged as relative (presumably a fraction of some total — TODO confirm).
    Relative(f32),
    /// No value set.
    None,
}
impl FlexValue {
    /// Returns `true` only when this is the `Absolute` variant.
    pub fn is_absolute(&self) -> bool {
        matches!(self, Self::Absolute(_))
    }

    /// Returns `true` only when this is the `Relative` variant.
    pub fn is_relative(&self) -> bool {
        matches!(self, Self::Relative(_))
    }

    /// Returns `true` only when this is the `None` variant.
    pub fn is_none(&self) -> bool {
        matches!(self, Self::None)
    }

    /// Returns `true` when a value is present, i.e. for every variant except `None`.
    pub fn is_some(&self) -> bool {
        !matches!(self, Self::None)
    }
}
/// Dispatches on the element-type variant of a `DynCsrMatrix` and calls the
/// free function `$fun` with the inner matrix followed by the extra arguments.
///
/// Usage shape: `match_dyn_csr_matrix!(matrix, function, arg1, arg2, ...)`.
/// The matcher requires the comma after the function name, so a call without
/// any extra argument must still be written with a trailing comma.
///
/// The variant paths are written as plain `DynCsrMatrix::...`, so the name
/// `DynCsrMatrix` has to be in scope wherever the macro is invoked.
///
/// # Panics
/// The expansion panics for the `I64`, `U64`, `Bool` and `String` variants.
#[macro_export]
macro_rules! match_dyn_csr_matrix {
($csr:expr, $fun:ident, $($arg:expr),*) => {
match $csr {
DynCsrMatrix::I8(d) => $fun(d, $($arg),*),
DynCsrMatrix::I16(d) => $fun(d, $($arg),*),
DynCsrMatrix::I32(d) => $fun(d, $($arg),*),
// 64-bit integers are rejected at runtime rather than forwarded to `$fun`.
DynCsrMatrix::I64(_d) => panic!("I64 CSR matrices are not supported for this operation"),
DynCsrMatrix::U8(d) => $fun(d, $($arg),*),
DynCsrMatrix::U16(d) => $fun(d, $($arg),*),
DynCsrMatrix::U32(d) => $fun(d, $($arg),*),
DynCsrMatrix::U64(_d) => panic!("U64 CSR matrices are not supported for this operation"),
DynCsrMatrix::F32(d) => $fun(d, $($arg),*),
DynCsrMatrix::F64(d) => $fun(d, $($arg),*),
// Non-numeric element types are rejected at runtime.
DynCsrMatrix::Bool(_) => panic!("Boolean CSR matrices are not supported for this operation"),
DynCsrMatrix::String(_) => panic!("String CSR matrices are not supported for this operation"),
}
};
}
/// Dispatches on the element-type variant of a `DynCscMatrix` and calls the
/// free function `$fun` with the inner matrix followed by the extra arguments.
///
/// Usage shape: `match_dyn_csc_matrix!(matrix, function, arg1, arg2, ...)`.
/// The matcher requires the comma after the function name, so a call without
/// any extra argument must still be written with a trailing comma.
///
/// The variant paths are written as plain `DynCscMatrix::...`, so the name
/// `DynCscMatrix` has to be in scope wherever the macro is invoked.
///
/// # Panics
/// The expansion panics for the `I64`, `U64`, `Bool` and `String` variants.
#[macro_export]
macro_rules! match_dyn_csc_matrix {
($csc:expr, $fun:ident, $($arg:expr),*) => {
match $csc {
DynCscMatrix::I8(d) => $fun(d, $($arg),*),
DynCscMatrix::I16(d) => $fun(d, $($arg),*),
DynCscMatrix::I32(d) => $fun(d, $($arg),*),
// 64-bit integers are rejected at runtime rather than forwarded to `$fun`.
DynCscMatrix::I64(_d) => panic!("I64 CSC matrices are not supported for this operation"),
DynCscMatrix::U8(d) => $fun(d, $($arg),*),
DynCscMatrix::U16(d) => $fun(d, $($arg),*),
DynCscMatrix::U32(d) => $fun(d, $($arg),*),
DynCscMatrix::U64(_d) => panic!("U64 CSC matrices are not supported for this operation"),
DynCscMatrix::F32(d) => $fun(d, $($arg),*),
DynCscMatrix::F64(d) => $fun(d, $($arg),*),
// Non-numeric element types are rejected at runtime.
DynCscMatrix::Bool(_) => panic!("Boolean CSC matrices are not supported for this operation"),
DynCscMatrix::String(_) => panic!("String CSC matrices are not supported for this operation"),
}
};
}
/// Calls the method `$fun` on the matrix stored inside an `anndata::ArrayData`.
///
/// Two forms are accepted:
/// * `match_array_data_apply_function!(data, method)` calls `matrix.method()`;
/// * `match_array_data_apply_function!(data, method, a, b, ...)` calls
///   `matrix.method(a, b, ...)`.
///
/// Only the `CsrMatrix` and `CscMatrix` variants are handled, and within those
/// only the integer and float element types (`I8`..`I64`, `U8`..`U64`, `F32`,
/// `F64`). Every other case expands to `bail!(...)`.
///
/// Because `bail!` is written without a path, it must be in scope at the call
/// site, and the macro can only be used inside a function whose return type
/// `bail!` can return from.
#[macro_export]
macro_rules! match_array_data_apply_function {
// Form without extra arguments: `matrix.$fun()`.
($data:expr, $fun:ident) => {
match $data {
anndata::ArrayData::CsrMatrix(dyn_csr_matrix) => {
match dyn_csr_matrix {
anndata::data::DynCsrMatrix::I8(matrix) => matrix.$fun(),
anndata::data::DynCsrMatrix::I16(matrix) => matrix.$fun(),
anndata::data::DynCsrMatrix::I32(matrix) => matrix.$fun(),
anndata::data::DynCsrMatrix::I64(matrix) => matrix.$fun(),
anndata::data::DynCsrMatrix::U8(matrix) => matrix.$fun(),
anndata::data::DynCsrMatrix::U16(matrix) => matrix.$fun(),
anndata::data::DynCsrMatrix::U32(matrix) => matrix.$fun(),
anndata::data::DynCsrMatrix::U64(matrix) => matrix.$fun(),
anndata::data::DynCsrMatrix::F32(matrix) => matrix.$fun(),
anndata::data::DynCsrMatrix::F64(matrix) => matrix.$fun(),
// Remaining (non-numeric) element types.
_ => bail!("This operation is only supported on numeric types!")
}
},
anndata::ArrayData::CscMatrix(dyn_csc_matrix) => {
match dyn_csc_matrix {
anndata::data::DynCscMatrix::I8(matrix) => matrix.$fun(),
anndata::data::DynCscMatrix::I16(matrix) => matrix.$fun(),
anndata::data::DynCscMatrix::I32(matrix) => matrix.$fun(),
anndata::data::DynCscMatrix::I64(matrix) => matrix.$fun(),
anndata::data::DynCscMatrix::U8(matrix) => matrix.$fun(),
anndata::data::DynCscMatrix::U16(matrix) => matrix.$fun(),
anndata::data::DynCscMatrix::U32(matrix) => matrix.$fun(),
anndata::data::DynCscMatrix::U64(matrix) => matrix.$fun(),
anndata::data::DynCscMatrix::F32(matrix) => matrix.$fun(),
anndata::data::DynCscMatrix::F64(matrix) => matrix.$fun(),
// Remaining (non-numeric) element types.
_ => bail!("This operation is only supported on numeric types!")
}
},
// Any `ArrayData` variant other than CSR / CSC.
_ => bail!("This operation is currently only supported for CSC and CSR matrices.")
}
};
// Form with one or more extra arguments: `matrix.$fun(args...)`.
($data:expr, $fun:ident, $($arg:expr),+) => {
match $data {
anndata::ArrayData::CsrMatrix(dyn_csr_matrix) => {
match dyn_csr_matrix {
anndata::data::DynCsrMatrix::I8(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCsrMatrix::I16(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCsrMatrix::I32(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCsrMatrix::I64(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCsrMatrix::U8(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCsrMatrix::U16(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCsrMatrix::U32(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCsrMatrix::U64(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCsrMatrix::F32(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCsrMatrix::F64(matrix) => matrix.$fun($($arg),*),
// Remaining (non-numeric) element types.
_ => bail!("This operation is only supported on numeric types!")
}
},
anndata::ArrayData::CscMatrix(dyn_csc_matrix) => {
match dyn_csc_matrix {
anndata::data::DynCscMatrix::I8(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCscMatrix::I16(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCscMatrix::I32(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCscMatrix::I64(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCscMatrix::U8(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCscMatrix::U16(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCscMatrix::U32(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCscMatrix::U64(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCscMatrix::F32(matrix) => matrix.$fun($($arg),*),
anndata::data::DynCscMatrix::F64(matrix) => matrix.$fun($($arg),*),
// Remaining (non-numeric) element types.
_ => bail!("This operation is only supported on numeric types!")
}
},
// Any `ArrayData` variant other than CSR / CSC.
_ => bail!("This operation is currently only supported for CSC and CSR matrices.")
}
};
}
/// Calls the generic method `$fun::<types...>` on the matrix stored inside an
/// `anndata::ArrayData`.
///
/// Two forms are accepted:
/// * `match_array_data_apply_function_with_generics!(data, method, [T, U])`
///   calls `matrix.method::<T, U>()`;
/// * `match_array_data_apply_function_with_generics!(data, method, [T, U], a, b)`
///   calls `matrix.method::<T, U>(a, b)`.
///
/// Only the `CsrMatrix` and `CscMatrix` variants are handled, and within those
/// only the integer and float element types (`I8`..`I64`, `U8`..`U64`, `F32`,
/// `F64`). Every other case expands to `bail!(...)`.
///
/// Because `bail!` is written without a path, it must be in scope at the call
/// site, and the macro can only be used inside a function whose return type
/// `bail!` can return from.
#[macro_export]
macro_rules! match_array_data_apply_function_with_generics {
// Form without call arguments: `matrix.$fun::<types...>()`.
($data:expr, $fun:ident, [$($types:ty),+]) => {
match $data {
anndata::ArrayData::CsrMatrix(dyn_csr_matrix) => {
match dyn_csr_matrix {
anndata::data::DynCsrMatrix::I8(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCsrMatrix::I16(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCsrMatrix::I32(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCsrMatrix::I64(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCsrMatrix::U8(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCsrMatrix::U16(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCsrMatrix::U32(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCsrMatrix::U64(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCsrMatrix::F32(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCsrMatrix::F64(matrix) => matrix.$fun::<$($types),+>(),
// Remaining (non-numeric) element types.
_ => bail!("This operation is only supported on numeric types!")
}
},
anndata::ArrayData::CscMatrix(dyn_csc_matrix) => {
match dyn_csc_matrix {
anndata::data::DynCscMatrix::I8(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCscMatrix::I16(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCscMatrix::I32(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCscMatrix::I64(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCscMatrix::U8(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCscMatrix::U16(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCscMatrix::U32(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCscMatrix::U64(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCscMatrix::F32(matrix) => matrix.$fun::<$($types),+>(),
anndata::data::DynCscMatrix::F64(matrix) => matrix.$fun::<$($types),+>(),
// Remaining (non-numeric) element types.
_ => bail!("This operation is only supported on numeric types!")
}
},
// Any `ArrayData` variant other than CSR / CSC.
_ => bail!("This operation is currently only supported for CSC and CSR matrices.")
}
};
// Form with one or more call arguments: `matrix.$fun::<types...>(args...)`.
($data:expr, $fun:ident, [$($types:ty),+], $($arg:expr),+) => {
match $data {
anndata::ArrayData::CsrMatrix(dyn_csr_matrix) => {
match dyn_csr_matrix {
anndata::data::DynCsrMatrix::I8(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCsrMatrix::I16(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCsrMatrix::I32(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCsrMatrix::I64(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCsrMatrix::U8(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCsrMatrix::U16(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCsrMatrix::U32(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCsrMatrix::U64(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCsrMatrix::F32(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCsrMatrix::F64(matrix) => matrix.$fun::<$($types),+>($($arg),+),
// Remaining (non-numeric) element types.
_ => bail!("This operation is only supported on numeric types!")
}
},
anndata::ArrayData::CscMatrix(dyn_csc_matrix) => {
match dyn_csc_matrix {
anndata::data::DynCscMatrix::I8(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCscMatrix::I16(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCscMatrix::I32(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCscMatrix::I64(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCscMatrix::U8(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCscMatrix::U16(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCscMatrix::U32(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCscMatrix::U64(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCscMatrix::F32(matrix) => matrix.$fun::<$($types),+>($($arg),+),
anndata::data::DynCscMatrix::F64(matrix) => matrix.$fun::<$($types),+>($($arg),+),
// Remaining (non-numeric) element types.
_ => bail!("This operation is only supported on numeric types!")
}
},
// Any `ArrayData` variant other than CSR / CSC.
_ => bail!("This operation is currently only supported for CSC and CSR matrices.")
}
};
}
/// Converts an [`ArrayData`] value into a dense `Array2<f64>`.
///
/// Dense arrays are delegated to `convert_to_array_f64_array`; canonical CSR
/// and CSC matrices are expanded into a zero-initialised dense matrix of the
/// shape reported by `arr_data.shape()`.
///
/// # Errors
/// Returns an error for the `CsrNonCanonical` and `DataFrame` variants, which
/// have no conversion implemented, and forwards any error produced by the
/// per-format helpers.
///
/// # Panics
/// The sparse dispatch macros panic for `I64`, `U64`, `Bool` and `String`
/// element types.
pub fn convert_to_array_f64(arr_data: &ArrayData) -> anyhow::Result<Array2<f64>> {
    let shape = arr_data.shape();
    match arr_data {
        ArrayData::Array(array) => convert_to_array_f64_array(array),
        ArrayData::CsrMatrix(csr) => match_dyn_csr_matrix!(csr, convert_to_array_f64_csr, shape),
        ArrayData::CscMatrix(csc) => match_dyn_csc_matrix!(csc, convert_to_array_f64_csc, shape),
        // Report unsupported inputs as errors instead of aborting via `todo!()`.
        ArrayData::CsrNonCanonical(_) => anyhow::bail!(
            "Conversion of non-canonical CSR matrices to Array2<f64> is not supported"
        ),
        ArrayData::DataFrame(_) => {
            anyhow::bail!("Conversion of data frames to Array2<f64> is not supported")
        }
    }
}
/// Converts a dynamically typed dense array into an `Array2<f64>`.
///
/// The 8/16/32-bit integer variants and both float variants are forwarded to
/// `convert_arrayd_to_array2_f64`.
///
/// # Errors
/// Returns an error for `I64`, `U64`, `Bool` and `String` arrays, for which no
/// conversion is implemented, and forwards any error from the element-wise
/// conversion helper.
fn convert_to_array_f64_array(darray: &DynArray) -> anyhow::Result<Array2<f64>> {
    match darray {
        DynArray::I8(arr) => convert_arrayd_to_array2_f64(arr),
        DynArray::I16(arr) => convert_arrayd_to_array2_f64(arr),
        DynArray::I32(arr) => convert_arrayd_to_array2_f64(arr),
        DynArray::U8(arr) => convert_arrayd_to_array2_f64(arr),
        DynArray::U16(arr) => convert_arrayd_to_array2_f64(arr),
        DynArray::U32(arr) => convert_arrayd_to_array2_f64(arr),
        DynArray::F32(arr) => convert_arrayd_to_array2_f64(arr),
        DynArray::F64(arr) => convert_arrayd_to_array2_f64(arr),
        // Report unsupported element types as errors instead of aborting via `todo!()`.
        DynArray::I64(_) => {
            anyhow::bail!("Conversion of I64 arrays to Array2<f64> is not supported")
        }
        DynArray::U64(_) => {
            anyhow::bail!("Conversion of U64 arrays to Array2<f64> is not supported")
        }
        DynArray::Bool(_) => {
            anyhow::bail!("Conversion of Bool arrays to Array2<f64> is not supported")
        }
        DynArray::String(_) => {
            anyhow::bail!("Conversion of String arrays to Array2<f64> is not supported")
        }
    }
}
/// Converts a dynamic-dimensional array into a two-dimensional `f64` array.
///
/// * Two-dimensional input keeps its shape.
/// * Input with more than two dimensions keeps its first axis as the row axis;
///   all remaining axes are flattened into the column axis, with elements
///   taken in the order produced by `arrayd.iter()`.
///
/// Elements that cannot be represented as `f64` are replaced by `0.0`.
///
/// # Errors
/// Returns an error when the input has fewer than two dimensions, or when the
/// reshaping into two dimensions fails.
fn convert_arrayd_to_array2_f64<T: NumericOps>(arrayd: &ArrayD<T>) -> anyhow::Result<Array2<f64>> {
    let shape = arrayd.shape();
    match shape.len() {
        // A 0-dimensional array has no `shape[0]`; reject it together with the
        // 1-dimensional case instead of panicking on the index below.
        0 | 1 => Err(anyhow!("The ArrayD must have at least two dimensions!")),
        2 => Ok(arrayd
            .mapv(|x| NumCast::from(x).unwrap_or_else(f64::zero))
            .into_dimensionality::<Ix2>()?),
        _ => {
            let rows = shape[0];
            let cols = shape[1..].iter().product();
            let flat_data: Vec<f64> = arrayd
                .iter()
                .map(|&x| NumCast::from(x).unwrap_or_else(f64::zero))
                .collect();
            Ok(Array2::from_shape_vec((rows, cols), flat_data)?)
        }
    }
}
fn convert_to_array_f64_csc<T: NumericOps>(
csc: &CscMatrix<T>,
shape: Shape,
) -> anyhow::Result<Array2<f64>> {
let mut dense = Array2::<f64>::zeros((shape[0], shape[1]));
for (col, vec) in csc.col_iter().enumerate() {
for (&row, val) in vec.row_indices().iter().zip(csc.values()) {
dense[[row, col]] = NumCast::from(*val).unwrap();
}
}
Ok(dense)
}
fn convert_to_array_f64_csr<T: NumericOps>(
csr: &CsrMatrix<T>,
shape: Shape,
) -> anyhow::Result<Array2<f64>> {
let mut dense = Array2::<f64>::zeros((shape[0], shape[1]));
for (row, vec) in csr.row_iter().enumerate() {
for (&col, val) in vec.col_indices().iter().zip(csr.values()) {
dense[[row, col]] = NumCast::from(*val).unwrap();
}
}
Ok(dense)
}
fn convert_to_array_f64_csr_selected<T: NumericOps>(
csr: &CsrMatrix<T>,
shape: Shape,
row_selection: &SelectInfoElem,
col_selection: &SelectInfoElem,
) -> anyhow::Result<Array2<f64>> {
let row_indices = select_info_elem_to_indices(row_selection, shape[0])?;
let col_indices = select_info_elem_to_indices(col_selection, shape[1])?;
let mut dense = Array2::<f64>::zeros((row_indices.len(), col_indices.len()));
let col_map: HashMap<usize, usize> = col_indices
.iter()
.enumerate()
.map(|(i, &col)| (col, i))
.collect();
for (out_row, &row) in row_indices.iter().enumerate() {
if row < csr.nrows() {
let row_start = csr.row_offsets()[row];
let row_end = csr.row_offsets()[row + 1];
for (&col, &value) in csr.col_indices()[row_start..row_end]
.iter()
.zip(csr.values()[row_start..row_end].iter())
{
if let Some(&out_col) = col_map.get(&col) {
dense[[out_row, out_col]] = NumCast::from(value)
.ok_or_else(|| anyhow!("Failed to convert value to f64"))?;
}
}
}
}
Ok(dense)
}
fn convert_to_array_f64_csc_selected<T: NumericOps>(
csc: &CscMatrix<T>,
shape: Shape,
row_selection: &SelectInfoElem,
col_selection: &SelectInfoElem,
) -> anyhow::Result<Array2<f64>> {
let row_indices = select_info_elem_to_indices(row_selection, shape[0])?;
let col_indices = select_info_elem_to_indices(col_selection, shape[1])?;
let mut dense = Array2::<f64>::zeros((row_indices.len(), col_indices.len()));
let row_map: HashMap<usize, usize> = row_indices
.iter()
.enumerate()
.map(|(i, &row)| (row, i))
.collect();
for (out_col, &col) in col_indices.iter().enumerate() {
if col < csc.ncols() {
let col_start = csc.col_offsets()[col];
let col_end = csc.col_offsets()[col + 1];
for (&row, &value) in csc.row_indices()[col_start..col_end]
.iter()
.zip(csc.values()[col_start..col_end].iter())
{
if let Some(&out_row) = row_map.get(&row) {
dense[[out_row, out_col]] = NumCast::from(value)
.ok_or_else(|| anyhow!("Failed to convert value to f64"))?;
}
}
}
}
Ok(dense)
}
/// Converts the selected rows and columns of a sparse [`ArrayData`] into a
/// dense `Array2<f64>`.
///
/// CSR and CSC inputs are forwarded, together with `shape` and both
/// selections, to the matching per-format helper.
///
/// # Errors
/// Returns an error for every `ArrayData` variant other than `CsrMatrix` and
/// `CscMatrix`, and forwards errors from the per-format helpers.
///
/// # Panics
/// The sparse dispatch macros panic for `I64`, `U64`, `Bool` and `String`
/// element types.
pub fn convert_to_array_f64_selected(
    data: &ArrayData,
    shape: Shape,
    row_selection: &SelectInfoElem,
    col_selection: &SelectInfoElem,
) -> anyhow::Result<Array2<f64>> {
    if let ArrayData::CsrMatrix(csr) = data {
        return match_dyn_csr_matrix!(
            csr,
            convert_to_array_f64_csr_selected,
            shape,
            row_selection,
            col_selection
        );
    }
    if let ArrayData::CscMatrix(csc) = data {
        return match_dyn_csc_matrix!(
            csc,
            convert_to_array_f64_csc_selected,
            shape,
            row_selection,
            col_selection
        );
    }
    anyhow::bail!("Unsupported data type for conversion to Array2<f64>")
}
/// Reports whether data of the given scalar type has to be converted to a
/// floating-point type.
///
/// Returns `Ok(true)` for every integer type and `Ok(false)` for `F32` and
/// `F64`.
///
/// # Errors
/// Returns an error for `Bool` and `String`.
pub fn need_conversion_target_float_type(scalar_type: &ScalarType) -> anyhow::Result<bool> {
    let needs_conversion = match scalar_type {
        ScalarType::F32 | ScalarType::F64 => false,
        ScalarType::I8
        | ScalarType::I16
        | ScalarType::I32
        | ScalarType::I64
        | ScalarType::U8
        | ScalarType::U16
        | ScalarType::U32
        | ScalarType::U64 => true,
        ScalarType::Bool => {
            bail!("Cannot use a Scalar of type <Bool> in the normalization procedure.")
        }
        ScalarType::String => {
            bail!("Cannot use a Scalar of type <String> in the normalization procedure.")
        }
    };
    Ok(needs_conversion)
}
/// Floating-point precision selector.
///
/// `Single` is the value produced by `Precision::default()`.
///
/// NOTE(review): presumably `Single` maps to `f32` and `Double` to `f64`;
/// this file does not show where the enum is consumed — confirm.
#[derive(Default, Debug, Clone, Copy)]
pub enum Precision {
// Default variant.
#[default]
Single,
Double,
}