use crate::core::error::{RedicatError, Result};
use crate::core::sparse::SparseOps;
use anndata::data::array::dataframe::DataFrameIndex;
use anndata::{
data::*,
traits::{AnnDataOp, AxisArraysOp},
AnnData, Backend,
};
use anndata_hdf5::H5;
use log::{debug, info, warn};
use nalgebra_sparse::CsrMatrix;
use polars::prelude::*;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tempfile::TempDir;
/// Soft cap on how many bytes of sparse layers may stay resident in memory
/// before [`AnnDataContainer::auto_spill_if_needed`] starts spilling to disk.
#[derive(Debug, Clone)]
pub struct MemoryBudget {
    /// Maximum resident layer size, in bytes.
    pub limit_bytes: u64,
}
impl MemoryBudget {
    /// Creates a budget with an explicit byte limit.
    pub fn new(limit_bytes: u64) -> Self {
        Self { limit_bytes }
    }
    /// Default budget of 128 GiB.
    pub fn default_budget() -> Self {
        Self::new(128 * 1024 * 1024 * 1024)
    }
}
impl Default for MemoryBudget {
    /// Same as [`MemoryBudget::default_budget`] (128 GiB).
    fn default() -> Self {
        Self::default_budget()
    }
}
/// In-memory AnnData-like container: annotation frames, an optional X matrix,
/// and named u32 count layers, with optional spill-to-disk bookkeeping.
#[derive(Debug, Clone)]
pub struct AnnDataContainer {
    /// Per-observation (row) annotations.
    pub obs: DataFrame,
    /// Per-variable (column) annotations.
    pub var: DataFrame,
    /// Optional primary data matrix (n_obs × n_vars).
    pub x: Option<CsrMatrix<f64>>,
    /// Named count layers, each expected to be n_obs × n_vars.
    pub layers: HashMap<String, CsrMatrix<u32>>,
    /// Number of observations (rows).
    pub n_obs: usize,
    /// Number of variables (columns).
    pub n_vars: usize,
    /// Variable names; length must equal `n_vars`.
    pub var_names: Vec<String>,
    /// Observation names; length must equal `n_obs`.
    pub obs_names: Vec<String>,
    // Layers currently spilled to disk: layer name -> on-disk path.
    spilled_layers: HashMap<String, PathBuf>,
    // Lazily created temp directory holding spilled layers; Arc so the
    // directory outlives clones of this container.
    spill_dir: Option<Arc<TempDir>>,
    // Optional resident-memory cap consulted by auto_spill_if_needed.
    memory_budget: Option<MemoryBudget>,
}
impl AnnDataContainer {
    /// Creates an empty container of shape `n_obs` × `n_vars` with
    /// auto-generated `cell_{i}` / `gene_{i}` names, no X matrix, no layers.
    pub fn new(n_obs: usize, n_vars: usize) -> Self {
        let obs_names: Vec<String> = (0..n_obs).map(|i| format!("cell_{}", i)).collect();
        let var_names: Vec<String> = (0..n_vars).map(|i| format!("gene_{}", i)).collect();
        // A single-column frame over a fresh Vec<String> cannot fail to build.
        let obs = DataFrame::new(vec![
            Series::new("obs_names".into(), obs_names.clone()).into()
        ])
        .expect("single-column obs DataFrame cannot fail");
        let var = DataFrame::new(vec![
            Series::new("var_names".into(), var_names.clone()).into()
        ])
        .expect("single-column var DataFrame cannot fail");
        Self {
            obs,
            var,
            x: None,
            layers: HashMap::new(),
            n_obs,
            n_vars,
            var_names,
            obs_names,
            spilled_layers: HashMap::new(),
            spill_dir: None,
            memory_budget: None,
        }
    }

    /// Installs the soft memory limit consulted by [`Self::auto_spill_if_needed`].
    pub fn set_memory_budget(&mut self, budget: MemoryBudget) {
        info!("Memory budget set to {} bytes", budget.limit_bytes);
        self.memory_budget = Some(budget);
    }

    /// Estimated bytes of all layers currently resident in memory.
    /// Spilled layers are not counted.
    pub fn resident_layer_bytes(&self) -> usize {
        self.layers
            .values()
            .map(SparseOps::estimate_csr_bytes)
            .sum()
    }

    /// Moves `layer_name` from memory to a temp file.
    ///
    /// Returns `Ok(false)` when the layer is not resident (nothing to do).
    /// On any failure the layer stays resident, so spilling is never destructive.
    pub fn spill_layer(&mut self, layer_name: &str) -> Result<bool> {
        if !self.layers.contains_key(layer_name) {
            return Ok(false);
        }
        // Create the spill directory *before* detaching the matrix so a
        // TempDir failure cannot lose the layer (the original code removed
        // the matrix first and dropped it on error).
        if self.spill_dir.is_none() {
            self.spill_dir = Some(Arc::new(TempDir::new().map_err(RedicatError::Io)?));
        }
        let path = self
            .spill_dir
            .as_ref()
            .expect("spill_dir initialized above")
            .path()
            .join(format!("{}.csr", layer_name));
        let matrix = self
            .layers
            .remove(layer_name)
            .expect("presence checked above");
        match SparseOps::spill_to_file(&matrix, &path) {
            Ok(_) => {
                let bytes = SparseOps::estimate_csr_bytes(&matrix);
                info!("Spilled layer '{}' ({} KB) to {:?}", layer_name, bytes / 1024, path);
                self.spilled_layers.insert(layer_name.to_string(), path);
                Ok(true)
            }
            Err(e) => {
                // Restore the layer so a failed write leaves state unchanged.
                self.layers.insert(layer_name.to_string(), matrix);
                Err(e)
            }
        }
    }

    /// Brings a previously spilled layer back into memory, deleting its
    /// on-disk copy. No-op if the layer is already resident; errors if the
    /// layer is neither resident nor spilled.
    pub fn load_layer(&mut self, layer_name: &str) -> Result<()> {
        if self.layers.contains_key(layer_name) {
            return Ok(());
        }
        // Look up (but do not yet remove) the spill record: if the read below
        // fails, the layer must remain tracked as spilled instead of being
        // silently lost (the original code removed the entry first).
        let path = match self.spilled_layers.get(layer_name) {
            Some(p) => p.clone(),
            None => {
                return Err(RedicatError::DataProcessing(format!(
                    "Layer '{}' is neither resident nor spilled",
                    layer_name
                )))
            }
        };
        let matrix = SparseOps::load_from_file(&path)?;
        info!("Loaded spilled layer '{}' ({} KB) from {:?}",
            layer_name, SparseOps::estimate_csr_bytes(&matrix) / 1024, path);
        self.layers.insert(layer_name.to_string(), matrix);
        self.spilled_layers.remove(layer_name);
        // Best-effort cleanup; the temp dir is removed on drop anyway.
        let _ = std::fs::remove_file(&path);
        Ok(())
    }

    /// Spills the largest layers (never those named in `keep`) until the
    /// resident layer bytes fit inside the configured budget.
    /// Does nothing when no budget is set.
    pub fn auto_spill_if_needed(&mut self, keep: &[&str]) -> Result<()> {
        let budget = match &self.memory_budget {
            Some(b) => b.limit_bytes,
            None => return Ok(()),
        };
        loop {
            let resident = self.resident_layer_bytes() as u64;
            if resident <= budget {
                break;
            }
            // Evict the biggest spillable layer first to converge quickly.
            let candidate = self
                .layers
                .iter()
                .filter(|(name, _)| !keep.contains(&name.as_str()))
                .max_by_key(|(_, m)| SparseOps::estimate_csr_bytes(m))
                .map(|(name, _)| name.clone());
            match candidate {
                Some(name) => {
                    self.spill_layer(&name)?;
                }
                None => {
                    warn!(
                        "Memory budget exceeded ({} > {}) but no spillable layers remain",
                        resident, budget
                    );
                    break;
                }
            }
        }
        Ok(())
    }

    /// Row sums of a resident layer, or `None` (with a warning) when the
    /// layer is not resident. Spilled layers are not consulted.
    pub fn compute_layer_row_sums(&self, layer_name: &str) -> Option<Vec<u32>> {
        match self.layers.get(layer_name) {
            Some(matrix) => {
                debug!("Computing row sums for layer: {}", layer_name);
                Some(SparseOps::compute_row_sums(matrix))
            }
            None => {
                warn!(
                    "Layer '{}' not found. Available layers: {:?}",
                    layer_name,
                    self.layers.keys().collect::<Vec<_>>()
                );
                None
            }
        }
    }

    /// Column sums of a resident layer, or `None` (with a warning) when the
    /// layer is not resident. Spilled layers are not consulted.
    pub fn compute_layer_col_sums(&self, layer_name: &str) -> Option<Vec<u32>> {
        match self.layers.get(layer_name) {
            Some(matrix) => {
                debug!("Computing column sums for layer: {}", layer_name);
                Some(SparseOps::compute_col_sums(matrix))
            }
            None => {
                warn!(
                    "Layer '{}' not found. Available layers: {:?}",
                    layer_name,
                    self.layers.keys().collect::<Vec<_>>()
                );
                None
            }
        }
    }

    /// Per-observation total coverage, summed over the per-base count layers.
    ///
    /// Uses the stranded set (A0..C1 plus A1..C1) when an "A0" layer exists,
    /// otherwise the unstranded A1/T1/G1/C1 set. Missing layers are skipped;
    /// a failed matrix addition is logged and that layer's counts are dropped
    /// (best-effort). Returns all zeros when no base layer is resident.
    pub fn compute_total_coverage(&self) -> Vec<u32> {
        let layer_names = if self.layers.contains_key("A0") {
            vec!["A0", "T0", "G0", "C0", "A1", "T1", "G1", "C1"]
        } else {
            vec!["A1", "T1", "G1", "C1"]
        };
        debug!("Computing total coverage using layers: {:?}", layer_names);
        let matrices: Vec<&CsrMatrix<u32>> = layer_names
            .iter()
            .filter_map(|&name| self.layers.get(name))
            .collect();
        if matrices.is_empty() {
            warn!("No base matrices found for coverage calculation");
            return vec![0; self.n_obs];
        }
        let total_matrix =
            matrices
                .into_iter()
                .fold(None, |acc: Option<CsrMatrix<u32>>, matrix| match acc {
                    None => Some(matrix.clone()),
                    // On addition failure, keep the running total (best-effort).
                    Some(existing) => SparseOps::add_matrices(&existing, matrix)
                        .map_err(|e| warn!("Failed to add matrices: {}", e))
                        .ok()
                        .or(Some(existing)),
                });
        match total_matrix {
            Some(matrix) => SparseOps::compute_row_sums(&matrix),
            None => {
                warn!("Failed to compute total coverage matrix");
                vec![0; self.n_obs]
            }
        }
    }

    /// Verifies that names, X, and every resident layer agree with
    /// `n_obs`/`n_vars`. DataFrame height mismatches only warn (they are
    /// repairable via [`Self::fix_dataframe_dimensions`]); everything else
    /// returns a [`RedicatError::DimensionMismatch`].
    pub fn validate_dimensions(&self) -> Result<()> {
        if self.obs_names.len() != self.n_obs {
            return Err(RedicatError::DimensionMismatch {
                expected: format!("obs_names length = {}", self.n_obs),
                actual: format!("obs_names length = {}", self.obs_names.len()),
            });
        }
        if self.var_names.len() != self.n_vars {
            return Err(RedicatError::DimensionMismatch {
                expected: format!("var_names length = {}", self.n_vars),
                actual: format!("var_names length = {}", self.var_names.len()),
            });
        }
        if !self.obs.is_empty() && self.obs.height() != self.n_obs {
            warn!(
                "obs DataFrame height ({}) doesn't match n_obs ({}), will use obs_names",
                self.obs.height(),
                self.n_obs
            );
        }
        if !self.var.is_empty() && self.var.height() != self.n_vars {
            warn!(
                "var DataFrame height ({}) doesn't match n_vars ({}), will use var_names",
                self.var.height(),
                self.n_vars
            );
        }
        if let Some(ref x_matrix) = self.x {
            if x_matrix.nrows() != self.n_obs || x_matrix.ncols() != self.n_vars {
                return Err(RedicatError::DimensionMismatch {
                    expected: format!("X matrix {}×{}", self.n_obs, self.n_vars),
                    actual: format!("X matrix {}×{}", x_matrix.nrows(), x_matrix.ncols()),
                });
            }
        }
        for (layer_name, matrix) in &self.layers {
            if matrix.nrows() != self.n_obs || matrix.ncols() != self.n_vars {
                return Err(RedicatError::DimensionMismatch {
                    expected: format!("Layer '{}' {}×{}", layer_name, self.n_obs, self.n_vars),
                    actual: format!(
                        "Layer '{}' {}×{}",
                        layer_name,
                        matrix.nrows(),
                        matrix.ncols()
                    ),
                });
            }
        }
        Ok(())
    }

    /// Rebuilds `obs`/`var` as single-column name frames whenever their
    /// height disagrees with `n_obs`/`n_vars` (or they are empty).
    pub fn fix_dataframe_dimensions(&mut self) -> Result<()> {
        if self.obs.is_empty() || self.obs.height() != self.n_obs {
            info!(
                "Reconstructing obs DataFrame with {} observations",
                self.n_obs
            );
            self.obs = DataFrame::new(vec![Series::new(
                "obs_names".into(),
                self.obs_names.clone(),
            )
            .into()])?;
        }
        if self.var.is_empty() || self.var.height() != self.n_vars {
            info!(
                "Reconstructing var DataFrame with {} variables",
                self.n_vars
            );
            self.var = DataFrame::new(vec![Series::new(
                "var_names".into(),
                self.var_names.clone(),
            )
            .into()])?;
        }
        Ok(())
    }

    /// Per-component byte estimates: "obs_bytes", "var_bytes", optional
    /// "x_bytes", one "layer_{name}_bytes" per resident layer, and
    /// "total_layer_bytes". Spilled layers are excluded.
    pub fn get_memory_stats(&self) -> HashMap<String, usize> {
        let mut stats = HashMap::new();
        stats.insert("obs_bytes".to_string(), estimate_dataframe_size(&self.obs));
        stats.insert("var_bytes".to_string(), estimate_dataframe_size(&self.var));
        if let Some(ref x_matrix) = self.x {
            stats.insert(
                "x_bytes".to_string(),
                estimate_csr_matrix_size_f64(x_matrix),
            );
        }
        let mut total_layer_bytes = 0;
        for (layer_name, matrix) in &self.layers {
            let size = estimate_csr_matrix_size_u32(matrix);
            stats.insert(format!("layer_{}_bytes", layer_name), size);
            total_layer_bytes += size;
        }
        stats.insert("total_layer_bytes".to_string(), total_layer_bytes);
        stats
    }

    /// Drops resident layers that contain no non-zero entries.
    pub fn optimize_memory(&mut self) {
        // retain() removes in place in one pass (no intermediate name list).
        self.layers.retain(|layer_name, matrix| {
            let keep = matrix.nnz() > 0;
            if !keep {
                info!("Removing empty layer: {}", layer_name);
            }
            keep
        });
    }
}
/// Persists the container as an `.h5ad` file at `path`.
///
/// Writes X (or an all-zero f32 placeholder — the backed format apparently
/// requires an X entry; TODO confirm against the anndata backend), the
/// `ref`/`alt`/`others`/`coverage` layers converted to f32, and the obs/var
/// annotation frames. Layers outside that priority list (e.g. the per-base
/// A0..C1 counts) are deliberately NOT persisted by this function.
///
/// # Errors
/// Fails on dimension mismatches, layer conversion failures, or any backend
/// write error.
pub fn write_anndata_h5ad(adata: &mut AnnDataContainer, path: &str) -> Result<()> {
    info!("Writing AnnData to: {}", path);
    adata.fix_dataframe_dimensions()?;
    adata.validate_dimensions()?;
    let stats = adata.get_memory_stats();
    info!(
        "Memory usage: obs={} KB, var={} KB, layers={} KB",
        stats.get("obs_bytes").unwrap_or(&0) / 1024,
        stats.get("var_bytes").unwrap_or(&0) / 1024,
        stats.get("total_layer_bytes").unwrap_or(&0) / 1024
    );
    let h5_adata = AnnData::<H5>::new(Path::new(path))?;
    let obs_index: DataFrameIndex = adata.obs_names.iter().cloned().collect();
    let var_index: DataFrameIndex = adata.var_names.iter().cloned().collect();
    h5_adata.set_obs_names(obs_index)?;
    h5_adata.set_var_names(var_index)?;
    if let Some(ref x_matrix) = adata.x {
        let x_f32 = convert_f64_to_f32_csr(x_matrix)?;
        h5_adata.set_x(x_f32)?;
        info!(
            " - Written X matrix: {}×{} with {} non-zeros",
            x_matrix.nrows(),
            x_matrix.ncols(),
            x_matrix.nnz()
        );
    } else {
        let zero_matrix = CsrMatrix::<f32>::zeros(adata.n_obs, adata.n_vars);
        h5_adata.set_x(zero_matrix)?;
        info!(
            " - Written empty X matrix: {}×{}",
            adata.n_obs, adata.n_vars
        );
    }
    let priority_layers = ["ref", "alt", "others", "coverage"];
    let mut written_layers = 0;
    for layer_name in &priority_layers {
        // Bug fix: a priority layer that had been spilled to disk was
        // silently skipped here. Reload it before writing.
        if !adata.layers.contains_key(*layer_name)
            && adata.spilled_layers.contains_key(*layer_name)
        {
            adata.load_layer(layer_name)?;
        }
        if let Some(layer_matrix) = adata.layers.get(*layer_name) {
            info!(
                " - Writing layer: {} ({}×{}, {} non-zeros)",
                layer_name,
                layer_matrix.nrows(),
                layer_matrix.ncols(),
                layer_matrix.nnz()
            );
            let f32_matrix = convert_u32_to_f32_csr(layer_matrix)?;
            h5_adata.layers().add(layer_name, f32_matrix)?;
            written_layers += 1;
        }
    }
    if !adata.obs.is_empty() {
        h5_adata.set_obs(adata.obs.clone())?;
        info!(
            " - Written obs annotations: {} rows, {} columns",
            adata.obs.height(),
            adata.obs.width()
        );
    }
    if !adata.var.is_empty() {
        h5_adata.set_var(adata.var.clone())?;
        info!(
            " - Written var annotations: {} rows, {} columns",
            adata.var.height(),
            adata.var.width()
        );
    }
    h5_adata.set_n_obs(adata.n_obs)?;
    h5_adata.set_n_vars(adata.n_vars)?;
    info!(
        "Successfully wrote AnnData with shape: {} × {}, {} layers",
        adata.n_obs, adata.n_vars, written_layers
    );
    Ok(())
}
pub fn read_anndata_h5ad(path: &str) -> Result<AnnDataContainer> {
info!("Reading H5AD file: {}", path);
if !std::path::Path::new(path).exists() {
return Err(RedicatError::FileNotFound(format!(
"File not found: {}",
path
)));
}
let adata =
AnnData::<H5>::open(H5::open(path).map_err(|e| {
RedicatError::DataProcessing(format!("Failed to open H5 file: {:?}", e))
})?)?;
let n_obs = adata.n_obs();
let n_vars = adata.n_vars();
info!("AnnData shape: {} obs × {} vars", n_obs, n_vars);
if n_obs == 0 || n_vars == 0 {
return Err(RedicatError::EmptyData(format!(
"Empty AnnData: {} obs × {} vars",
n_obs, n_vars
)));
}
let obs_names = read_names(&adata.obs_names())?;
let var_names = read_names(&adata.var_names())?;
if obs_names.len() != n_obs {
return Err(RedicatError::DimensionMismatch {
expected: format!("obs_names length = {}", n_obs),
actual: format!("obs_names length = {}", obs_names.len()),
});
}
if var_names.len() != n_vars {
return Err(RedicatError::DimensionMismatch {
expected: format!("var_names length = {}", n_vars),
actual: format!("var_names length = {}", var_names.len()),
});
}
let obs = read_obs_dataframe(&adata, &obs_names, n_obs)?;
let var = read_var_dataframe(&adata, &var_names, n_vars)?;
let x = read_x_matrix(&adata)?;
let layers = read_layers_as_u32(&adata)?;
info!(
"Successfully loaded AnnData with {} layers: {:?}",
layers.len(),
layers.keys().collect::<Vec<_>>()
);
let mut container = AnnDataContainer {
obs,
var,
x,
layers,
n_obs,
n_vars,
var_names,
obs_names,
spilled_layers: HashMap::new(),
spill_dir: None,
memory_budget: None,
};
container.fix_dataframe_dimensions()?;
container.validate_dimensions()?;
Ok(container)
}
/// Extracts the index labels as owned strings.
fn read_names(index: &DataFrameIndex) -> Result<Vec<String>> {
    // The index only offers an owning conversion, hence the clone.
    let owned = index.clone();
    Ok(owned.into_vec())
}
/// Reads the obs annotation frame from the backed file, falling back to a
/// single-column frame built from `obs_names` when the stored frame is
/// unreadable or has the wrong height.
fn read_obs_dataframe(
    adata: &AnnData<H5>,
    obs_names: &[String],
    n_obs: usize,
) -> Result<DataFrame> {
    // Shared fallback (was duplicated verbatim in both failure arms).
    let frame_from_names = || {
        DataFrame::new(vec![
            Series::new("obs_names".into(), obs_names.to_vec()).into()
        ])
        .map_err(|e| {
            RedicatError::DataProcessing(format!("Failed to create obs DataFrame: {}", e))
        })
    };
    match adata.read_obs() {
        Ok(obs_df) => {
            debug!(
                "Read obs DataFrame: {} rows, {} columns",
                obs_df.height(),
                obs_df.width()
            );
            if obs_df.height() == n_obs {
                Ok(obs_df)
            } else {
                warn!(
                    "obs DataFrame height ({}) doesn't match n_obs ({}), creating from names",
                    obs_df.height(),
                    n_obs
                );
                frame_from_names()
            }
        }
        Err(e) => {
            warn!("Failed to read obs DataFrame: {:?}, creating from names", e);
            frame_from_names()
        }
    }
}
/// Reads the var annotation frame from the backed file, falling back to a
/// single-column frame built from `var_names` when the stored frame is
/// unreadable or has the wrong height.
fn read_var_dataframe(
    adata: &AnnData<H5>,
    var_names: &[String],
    n_vars: usize,
) -> Result<DataFrame> {
    // Shared fallback (was duplicated verbatim in both failure arms).
    let frame_from_names = || {
        DataFrame::new(vec![
            Series::new("var_names".into(), var_names.to_vec()).into()
        ])
        .map_err(|e| {
            RedicatError::DataProcessing(format!("Failed to create var DataFrame: {}", e))
        })
    };
    match adata.read_var() {
        Ok(var_df) => {
            debug!(
                "Read var DataFrame: {} rows, {} columns",
                var_df.height(),
                var_df.width()
            );
            if var_df.height() == n_vars {
                Ok(var_df)
            } else {
                warn!(
                    "var DataFrame height ({}) doesn't match n_vars ({}), creating from names",
                    var_df.height(),
                    n_vars
                );
                frame_from_names()
            }
        }
        Err(e) => {
            warn!("Failed to read var DataFrame: {:?}, creating from names", e);
            frame_from_names()
        }
    }
}
/// Reads the X matrix as f64 CSR, if present.
///
/// Returns `Ok(None)` (never an error) when X is absent, empty-shaped, or
/// cannot be extracted/converted — those conditions are only logged.
fn read_x_matrix(adata: &AnnData<H5>) -> Result<Option<CsrMatrix<f64>>> {
    let Some(mut x_elem) = adata.x().extract() else {
        debug!("No X matrix found");
        return Ok(None);
    };
    let shape = x_elem.shape();
    if shape.ndim() == 0 || shape.as_ref().contains(&0) {
        debug!("Empty X matrix shape: {:?}", shape.as_ref());
        return Ok(None);
    }
    let array_data = match x_elem.data() {
        Ok(data) => data,
        Err(e) => {
            warn!("Failed to extract X matrix data: {:?}", e);
            return Ok(None);
        }
    };
    match convert_array_to_csr_f64(array_data) {
        Ok(matrix) => {
            info!(
                "Read X matrix: {}×{} with {} non-zeros",
                matrix.nrows(),
                matrix.ncols(),
                matrix.nnz()
            );
            Ok(Some(matrix))
        }
        Err(e) => {
            warn!("Failed to convert X matrix: {}", e);
            Ok(None)
        }
    }
}
/// Loads the well-known count layers as u32 CSR matrices.
///
/// Missing layers and inaccessible entries are logged at debug level and
/// skipped; conversion failures are logged as warnings and skipped.
fn read_layers_as_u32(adata: &AnnData<H5>) -> Result<HashMap<String, CsrMatrix<u32>>> {
    const COMMON_LAYER_NAMES: [&str; 12] = [
        "A0", "T0", "G0", "C0", "A1", "T1", "G1", "C1", "ref", "alt", "others", "coverage",
    ];
    let layer_store = adata.layers();
    info!("Attempting to load common layers: {:?}", COMMON_LAYER_NAMES);
    let mut loaded: HashMap<String, CsrMatrix<u32>> = HashMap::new();
    for name in COMMON_LAYER_NAMES {
        match layer_store.get_item::<ArrayData>(name) {
            Ok(Some(array_data)) => match convert_array_to_csr_u32(array_data) {
                Ok(matrix) => {
                    info!(
                        " - Loaded layer '{}': {}×{} with {} non-zeros",
                        name,
                        matrix.nrows(),
                        matrix.ncols(),
                        matrix.nnz()
                    );
                    loaded.insert(name.to_string(), matrix);
                }
                Err(e) => {
                    warn!(" - Failed to convert layer '{}': {}", name, e);
                }
            },
            Ok(None) => {
                debug!(" - Layer '{}' not found (normal)", name);
            }
            Err(_) => {
                debug!(
                    " - Could not access layer '{}' (normal if it doesn't exist)",
                    name
                );
            }
        }
    }
    info!("Successfully loaded {} layers", loaded.len());
    Ok(loaded)
}
/// Rebuilds a CSR matrix with the same sparsity pattern but every stored
/// value mapped through `convert`; `label` names the conversion in the error
/// message. Shared by the four typed converters below, which were previously
/// four copy-pasted bodies.
fn map_csr_values<T, U>(
    matrix: &CsrMatrix<T>,
    label: &str,
    convert: impl Fn(&T) -> U,
) -> Result<CsrMatrix<U>> {
    let (row_offsets, col_indices, values) = matrix.csr_data();
    let mapped: Vec<U> = values.iter().map(convert).collect();
    CsrMatrix::try_from_csr_data(
        matrix.nrows(),
        matrix.ncols(),
        row_offsets.to_vec(),
        col_indices.to_vec(),
        mapped,
    )
    .map_err(|e| RedicatError::DataProcessing(format!("Failed to convert {}: {:?}", label, e)))
}
/// Narrows f64 values to f32 (lossy beyond f32 precision).
fn convert_f64_to_f32_csr(matrix: &CsrMatrix<f64>) -> Result<CsrMatrix<f32>> {
    map_csr_values(matrix, "f64 to f32", |&x| x as f32)
}
/// Widens f32 values to f64 (lossless).
fn convert_f32_to_f64_csr(matrix: &CsrMatrix<f32>) -> Result<CsrMatrix<f64>> {
    map_csr_values(matrix, "f32 to f64", |&x| f64::from(x))
}
/// Converts u32 counts to f32 (lossy above 2^24).
fn convert_u32_to_f32_csr(matrix: &CsrMatrix<u32>) -> Result<CsrMatrix<f32>> {
    map_csr_values(matrix, "u32 to f32", |&x| x as f32)
}
/// Converts u32 counts to f64 (lossless).
fn convert_u32_to_f64_csr(matrix: &CsrMatrix<u32>) -> Result<CsrMatrix<f64>> {
    map_csr_values(matrix, "u32 to f64", |&x| f64::from(x))
}
/// Interprets generic backend array data as an f64 CSR matrix.
///
/// Tries the native f64 representation first, then widens from f32 or u32;
/// any other element type is an error.
fn convert_array_to_csr_f64(array_data: ArrayData) -> Result<CsrMatrix<f64>> {
    if let Ok(native) = CsrMatrix::<f64>::try_from(array_data.clone()) {
        return Ok(native);
    }
    if let Ok(single_precision) = CsrMatrix::<f32>::try_from(array_data.clone()) {
        return convert_f32_to_f64_csr(&single_precision);
    }
    if let Ok(counts) = CsrMatrix::<u32>::try_from(array_data.clone()) {
        return convert_u32_to_f64_csr(&counts);
    }
    Err(RedicatError::DataProcessing(format!(
        "Unsupported array data type for X matrix: {:?}",
        array_data.data_type()
    )))
}
/// Interprets generic backend array data as a u32 CSR matrix.
///
/// Tries the native u32 representation first, then truncates from f32 or f64
/// (`as u32` saturates: negatives become 0, fractions are dropped); any other
/// element type is an error.
fn convert_array_to_csr_u32(array_data: ArrayData) -> Result<CsrMatrix<u32>> {
    // Shared float→u32 path (the f32 and f64 branches were duplicated).
    fn demote<T: Copy>(
        matrix: &CsrMatrix<T>,
        label: &str,
        cast: impl Fn(T) -> u32,
    ) -> Result<CsrMatrix<u32>> {
        let (row_offsets, col_indices, values) = matrix.csr_data();
        let values_u32: Vec<u32> = values.iter().map(|&x| cast(x)).collect();
        CsrMatrix::try_from_csr_data(
            matrix.nrows(),
            matrix.ncols(),
            row_offsets.to_vec(),
            col_indices.to_vec(),
            values_u32,
        )
        .map_err(|e| {
            RedicatError::DataProcessing(format!("Failed to convert {} to u32: {:?}", label, e))
        })
    }
    if let Ok(matrix) = CsrMatrix::<u32>::try_from(array_data.clone()) {
        return Ok(matrix);
    }
    if let Ok(matrix_f32) = CsrMatrix::<f32>::try_from(array_data.clone()) {
        return demote(&matrix_f32, "f32", |x: f32| x as u32);
    }
    if let Ok(matrix_f64) = CsrMatrix::<f64>::try_from(array_data.clone()) {
        return demote(&matrix_f64, "f64", |x: f64| x as u32);
    }
    Err(RedicatError::DataProcessing(format!(
        "Unsupported array data type for layer: {:?}",
        array_data.data_type()
    )))
}
/// Rough heap-size estimate of a DataFrame, summed over its columns.
fn estimate_dataframe_size(df: &DataFrame) -> usize {
    let mut total = 0usize;
    for column in df.get_columns() {
        total += column.as_materialized_series().estimated_size();
    }
    total
}
/// Byte size of the three CSR backing buffers (offsets, indices, values).
/// Generic helper shared by the two typed wrappers below, which previously
/// duplicated the same body.
fn estimate_csr_buffer_bytes<T>(matrix: &CsrMatrix<T>) -> usize {
    let (row_offsets, col_indices, values) = matrix.csr_data();
    std::mem::size_of_val(row_offsets)
        + std::mem::size_of_val(col_indices)
        + std::mem::size_of_val(values)
}
/// Estimated in-memory size of an f64 CSR matrix.
fn estimate_csr_matrix_size_f64(matrix: &CsrMatrix<f64>) -> usize {
    estimate_csr_buffer_bytes(matrix)
}
/// Estimated in-memory size of a u32 CSR matrix.
fn estimate_csr_matrix_size_u32(matrix: &CsrMatrix<u32>) -> usize {
    estimate_csr_buffer_bytes(matrix)
}
pub fn estimate_anndata_memory_usage(adata: &AnnDataContainer) -> usize {
let mut total = 0;
total += estimate_dataframe_size(&adata.obs);
total += estimate_dataframe_size(&adata.var);
if let Some(ref x) = adata.x {
total += estimate_csr_matrix_size_f64(x);
}
for matrix in adata.layers.values() {
total += estimate_csr_matrix_size_u32(matrix);
}
total
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::sparse::SparseOps;

    /// Builds a small u32 CSR matrix from (row, col, value) triplets.
    fn csr_u32(nrows: usize, ncols: usize, triplets: &[(usize, usize, u32)]) -> CsrMatrix<u32> {
        SparseOps::from_triplets_u32(nrows, ncols, triplets.to_vec()).unwrap()
    }

    // --- construction ---

    #[test]
    fn test_new_creates_correct_dimensions() {
        let adata = AnnDataContainer::new(10, 20);
        assert_eq!(adata.n_obs, 10);
        assert_eq!(adata.n_vars, 20);
        assert_eq!(adata.obs_names.len(), 10);
        assert_eq!(adata.var_names.len(), 20);
        assert!(adata.x.is_none());
        assert!(adata.layers.is_empty());
        assert_eq!(adata.obs.height(), 10);
        assert_eq!(adata.var.height(), 20);
    }

    #[test]
    fn test_new_zero_dimensions() {
        let adata = AnnDataContainer::new(0, 0);
        assert_eq!(adata.n_obs, 0);
        assert_eq!(adata.n_vars, 0);
        assert!(adata.obs_names.is_empty());
        assert!(adata.var_names.is_empty());
    }

    // --- validation ---

    #[test]
    fn test_validate_dimensions_valid() {
        let adata = AnnDataContainer::new(5, 10);
        assert!(adata.validate_dimensions().is_ok());
    }

    #[test]
    fn test_validate_dimensions_obs_names_mismatch() {
        let mut adata = AnnDataContainer::new(5, 10);
        adata.obs_names.push("extra".to_string());
        assert!(adata.validate_dimensions().is_err());
    }

    #[test]
    fn test_validate_dimensions_var_names_mismatch() {
        let mut adata = AnnDataContainer::new(5, 10);
        adata.var_names.pop();
        assert!(adata.validate_dimensions().is_err());
    }

    #[test]
    fn test_validate_dimensions_x_matrix_wrong_shape() {
        let mut adata = AnnDataContainer::new(3, 4);
        // Was jammed onto one line with the assert; split for readability.
        adata.x = Some(CsrMatrix::<f64>::zeros(3, 5));
        assert!(adata.validate_dimensions().is_err());
    }

    #[test]
    fn test_validate_dimensions_layer_wrong_shape() {
        let mut adata = AnnDataContainer::new(3, 4);
        adata.layers.insert("bad".into(), CsrMatrix::<u32>::zeros(2, 4));
        assert!(adata.validate_dimensions().is_err());
    }

    #[test]
    fn test_fix_dataframe_dimensions_rebuilds_when_empty() {
        let mut adata = AnnDataContainer::new(3, 2);
        adata.obs = DataFrame::default();
        adata.var = DataFrame::default();
        adata.fix_dataframe_dimensions().unwrap();
        assert_eq!(adata.obs.height(), 3);
        assert_eq!(adata.var.height(), 2);
    }

    #[test]
    fn test_fix_dataframe_dimensions_keeps_correct() {
        let mut adata = AnnDataContainer::new(2, 3);
        let obs_before = adata.obs.clone();
        adata.fix_dataframe_dimensions().unwrap();
        assert_eq!(adata.obs.height(), obs_before.height());
    }

    // --- layer reductions ---

    #[test]
    fn test_compute_layer_row_sums_existing_layer() {
        let mut adata = AnnDataContainer::new(3, 2);
        adata.layers.insert("A1".into(), csr_u32(3, 2, &[
            (0, 0, 1), (0, 1, 2), (1, 0, 3), (2, 1, 4),
        ]));
        let sums = adata.compute_layer_row_sums("A1").unwrap();
        assert_eq!(sums, vec![3, 3, 4]);
    }

    #[test]
    fn test_compute_layer_row_sums_missing_layer() {
        let adata = AnnDataContainer::new(2, 2);
        assert!(adata.compute_layer_row_sums("nonexistent").is_none());
    }

    #[test]
    fn test_compute_layer_col_sums_existing_layer() {
        let mut adata = AnnDataContainer::new(3, 2);
        adata.layers.insert("G1".into(), csr_u32(3, 2, &[
            (0, 0, 10), (1, 0, 20), (2, 1, 5),
        ]));
        let sums = adata.compute_layer_col_sums("G1").unwrap();
        assert_eq!(sums, vec![30, 5]);
    }

    #[test]
    fn test_compute_layer_col_sums_missing_layer() {
        let adata = AnnDataContainer::new(2, 2);
        assert!(adata.compute_layer_col_sums("nonexistent").is_none());
    }

    #[test]
    fn test_compute_total_coverage_stranded() {
        let mut adata = AnnDataContainer::new(2, 2);
        adata.layers.insert("A0".into(), csr_u32(2, 2, &[(0, 0, 1)]));
        adata.layers.insert("A1".into(), csr_u32(2, 2, &[(0, 0, 2)]));
        adata.layers.insert("T0".into(), csr_u32(2, 2, &[(1, 1, 3)]));
        adata.layers.insert("T1".into(), csr_u32(2, 2, &[(1, 1, 4)]));
        let cov = adata.compute_total_coverage();
        assert_eq!(cov, vec![3, 7]);
    }

    #[test]
    fn test_compute_total_coverage_unstranded() {
        let mut adata = AnnDataContainer::new(2, 2);
        adata.layers.insert("A1".into(), csr_u32(2, 2, &[(0, 0, 5)]));
        adata.layers.insert("T1".into(), csr_u32(2, 2, &[(1, 1, 3)]));
        adata.layers.insert("G1".into(), CsrMatrix::<u32>::zeros(2, 2));
        adata.layers.insert("C1".into(), CsrMatrix::<u32>::zeros(2, 2));
        let cov = adata.compute_total_coverage();
        assert_eq!(cov, vec![5, 3]);
    }

    #[test]
    fn test_compute_total_coverage_no_layers() {
        let adata = AnnDataContainer::new(3, 2);
        let cov = adata.compute_total_coverage();
        assert_eq!(cov, vec![0, 0, 0]);
    }

    // --- memory accounting ---

    #[test]
    fn test_get_memory_stats_includes_all_components() {
        let mut adata = AnnDataContainer::new(2, 2);
        adata.layers.insert("ref".into(), csr_u32(2, 2, &[(0, 0, 1)]));
        adata.x = Some(CsrMatrix::<f64>::zeros(2, 2));
        let stats = adata.get_memory_stats();
        assert!(stats.contains_key("obs_bytes"));
        assert!(stats.contains_key("var_bytes"));
        assert!(stats.contains_key("x_bytes"));
        assert!(stats.contains_key("layer_ref_bytes"));
        assert!(stats.contains_key("total_layer_bytes"));
    }

    #[test]
    fn test_get_memory_stats_no_x() {
        let adata = AnnDataContainer::new(2, 2);
        let stats = adata.get_memory_stats();
        assert!(!stats.contains_key("x_bytes"));
    }

    #[test]
    fn test_optimize_memory_removes_empty_layers() {
        let mut adata = AnnDataContainer::new(2, 2);
        adata.layers.insert("nonempty".into(), csr_u32(2, 2, &[(0, 0, 1)]));
        adata.layers.insert("empty".into(), CsrMatrix::<u32>::zeros(2, 2));
        assert_eq!(adata.layers.len(), 2);
        adata.optimize_memory();
        assert_eq!(adata.layers.len(), 1);
        assert!(adata.layers.contains_key("nonempty"));
    }

    #[test]
    fn test_optimize_memory_keeps_all_nonempty() {
        let mut adata = AnnDataContainer::new(2, 2);
        adata.layers.insert("a".into(), csr_u32(2, 2, &[(0, 0, 1)]));
        adata.layers.insert("b".into(), csr_u32(2, 2, &[(1, 1, 2)]));
        adata.optimize_memory();
        assert_eq!(adata.layers.len(), 2);
    }

    #[test]
    fn test_estimate_anndata_memory_usage_nonzero() {
        let mut adata = AnnDataContainer::new(5, 5);
        adata.layers.insert("A1".into(), csr_u32(5, 5, &[(0, 0, 1), (1, 1, 2)]));
        let usage = estimate_anndata_memory_usage(&adata);
        assert!(usage > 0);
    }

    // --- round-trip I/O ---

    #[test]
    fn test_write_creates_valid_file() {
        use tempfile::tempdir;
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.h5ad");
        let mut adata = AnnDataContainer::new(2, 3);
        adata.layers.insert("ref".into(), csr_u32(2, 3, &[(0, 0, 1), (1, 2, 5)]));
        adata.layers.insert("alt".into(), csr_u32(2, 3, &[(0, 1, 3)]));
        adata.layers.insert("others".into(), CsrMatrix::<u32>::zeros(2, 3));
        adata.layers.insert("coverage".into(), csr_u32(2, 3, &[(0, 0, 10), (1, 2, 20)]));
        write_anndata_h5ad(&mut adata, path.to_str().unwrap()).unwrap();
        assert!(path.exists());
        let loaded = read_anndata_h5ad(path.to_str().unwrap()).unwrap();
        assert_eq!(loaded.n_obs, 2);
        assert_eq!(loaded.n_vars, 3);
    }

    #[test]
    fn test_read_nonexistent_file_returns_error() {
        let result = read_anndata_h5ad("/tmp/definitely_not_exist_redicat_test.h5ad");
        assert!(result.is_err());
    }

    // --- value conversions ---

    #[test]
    fn test_convert_f64_to_f32_roundtrip() {
        let m = CsrMatrix::try_from_csr_data(
            2, 2, vec![0, 1, 2], vec![0, 1], vec![1.5_f64, 2.5_f64],
        ).unwrap();
        let f32_m = convert_f64_to_f32_csr(&m).unwrap();
        assert_eq!(f32_m.nrows(), 2);
        assert_eq!(f32_m.ncols(), 2);
        assert!((f32_m.csr_data().2[0] - 1.5f32).abs() < 1e-6);
    }

    #[test]
    fn test_convert_u32_to_f32_preserves_values() {
        let m = csr_u32(2, 2, &[(0, 0, 100), (1, 1, 200)]);
        let f32_m = convert_u32_to_f32_csr(&m).unwrap();
        assert_eq!(f32_m.csr_data().2[0] as u32, 100);
        assert_eq!(f32_m.csr_data().2[1] as u32, 200);
    }
}