use anndata::data::SelectInfoElem;
use log::{log, Level};
use ndarray::{ArrayView, Ix1};
use polars::prelude::DataType;
/// Builds a two-slot selection that optionally restricts the first slot.
///
/// * `obs_mask` — optional boolean mask. When present, the positions of its
///   `true` entries are placed in slot 0 as a [`SelectInfoElem::Index`].
///
/// Returns `[slot 0, slot 1]`, where slot 1 is always
/// [`SelectInfoElem::full`] and slot 0 is `full` when no mask is given.
/// The function itself never produces an `Err`.
pub fn get_select_info_obs(
    obs_mask: Option<ArrayView<'_, bool, Ix1>>,
) -> anyhow::Result<Vec<SelectInfoElem>> {
    let first_slot = match obs_mask {
        Some(mask) => {
            // Positions whose mask entry is `true`, in ascending order.
            let kept = mask
                .iter()
                .enumerate()
                .filter(|&(_, &keep)| keep)
                .map(|(position, _)| position)
                .collect::<Vec<_>>();
            SelectInfoElem::Index(kept)
        }
        None => SelectInfoElem::full(),
    };
    Ok(vec![first_slot, SelectInfoElem::full()])
}
/// Builds a two-slot selection that optionally restricts the second slot.
///
/// * `vars_mask` — optional boolean mask. When present, the positions of its
///   `true` entries are placed in slot 1 as a [`SelectInfoElem::Index`], and
///   the number of selected positions is logged at debug level.
///
/// Returns `[slot 0, slot 1]`, where slot 0 is always
/// [`SelectInfoElem::full`] and slot 1 is `full` when no mask is given.
/// The function itself never produces an `Err`.
pub fn get_select_info_vars(
    vars_mask: Option<ArrayView<'_, bool, Ix1>>,
) -> anyhow::Result<Vec<SelectInfoElem>> {
    let second_slot = match vars_mask {
        Some(mask) => {
            // Positions whose mask entry is `true`, in ascending order.
            let kept = mask
                .iter()
                .enumerate()
                .filter(|&(_, &keep)| keep)
                .map(|(position, _)| position)
                .collect::<Vec<_>>();
            log!(Level::Debug, "VarIDx length: {}", kept.len());
            SelectInfoElem::Index(kept)
        }
        None => SelectInfoElem::full(),
    };
    Ok(vec![SelectInfoElem::full(), second_slot])
}
/// Flavor selector stored in `HVGParams::flavor`.
///
/// NOTE(review): the variant names presumably refer to the Seurat,
/// Cell Ranger and SVR-based gene-selection methods — confirm against the
/// code that matches on this enum, which is not visible here.
#[derive(Debug)]
pub enum FlavorType {
/// The flavor chosen by `HVGParams::default()`.
Seurat,
/// NOTE(review): presumably the Cell Ranger flavor — confirm.
CellRanger,
/// NOTE(review): presumably the regression-based flavor — confirm.
SVR,
}
/// Parameter bundle; its `Default` implementation supplies the defaults
/// quoted on each field below.
///
/// NOTE(review): the field meanings are inferred from their names only —
/// presumably these configure highly-variable-gene selection; confirm
/// against the consumers of this struct.
pub struct HVGParams {
/// Presumably the lower cutoff on mean expression (default `0.0125`).
pub min_mean: f64,
/// Presumably the upper cutoff on mean expression (default `3.0`).
pub max_mean: f64,
/// Presumably the lower cutoff on dispersion (default `0.5`).
pub min_dispersion: f64,
/// Presumably the upper cutoff on dispersion (default `f64::INFINITY`).
pub max_dispersion: f64,
/// Number of bins (default `20`).
pub n_bins: usize,
/// Optional gene count; presumably overrides the cutoffs when set
/// (default `None`) — TODO confirm.
pub n_top_genes: Option<usize>,
/// Selected flavor (default `FlavorType::Seurat`).
pub flavor: FlavorType,
/// Presumably a smoothing span for a fitted trend (default `0.3`) —
/// TODO confirm which flavor reads it.
pub span: f64,
/// Optional name of a batch column (default `None`).
pub batch_key: Option<String>,
}
impl Default for HVGParams {
    /// Returns the stock parameter set: mean window `[0.0125, 3.0]`,
    /// dispersion window `[0.5, +inf)`, 20 bins, no top-gene limit,
    /// the `Seurat` flavor, span `0.3` and no batch key.
    fn default() -> Self {
        Self {
            // Mean window.
            min_mean: 0.0125,
            max_mean: 3.0,
            // Dispersion window (unbounded above).
            min_dispersion: 0.5,
            max_dispersion: f64::INFINITY,
            // Remaining knobs.
            n_bins: 20,
            n_top_genes: None,
            flavor: FlavorType::Seurat,
            span: 0.3,
            batch_key: None,
        }
    }
}
/// Returns the z-score `(x - mu) / sigma`.
///
/// A `sigma` of exactly zero (either sign) yields `0.0` instead of dividing
/// by zero. A NaN `sigma` is not special-cased and propagates as NaN.
pub fn standardize_log(x: f64, mu: f64, sigma: f64) -> f64 {
    if sigma != 0.0 {
        (x - mu) / sigma
    } else {
        0.0
    }
}
/// Z-scores the finite entries of `vec`; every non-finite entry maps to `0.0`.
///
/// The mean and the sample standard deviation (divisor `n - 1`) are computed
/// over the finite entries only. The result has the same length as `vec`.
/// A vector of zeros is returned when fewer than two entries are finite, or
/// when the standard deviation is zero or not finite.
///
/// NOTE(review): diagnostics are written to stdout with `println!` on every
/// call; routing them through the `log` crate would be the usual choice for
/// library code, but that changes observable output and is left as-is here.
pub fn standardize_log_form_vec(vec: &[f64]) -> Vec<f64> {
    let all_zeros = || vec![0.0; vec.len()];
    // Fresh lazy pass over the finite entries each time it is called.
    let finite = || vec.iter().copied().filter(|value| value.is_finite());

    let finite_count = finite().count();
    let n = finite_count as f64;
    if finite_count < 2 {
        println!(
            "DEBUG: Too few finite values ({}) for standardization, returning zeros",
            n
        );
        return all_zeros();
    }

    let mu: f64 = finite().sum::<f64>() / n;
    let squared_deviation_sum = finite().map(|value| (value - mu).powi(2)).sum::<f64>();
    let sigma = (squared_deviation_sum / (n - 1.0)).sqrt();

    let (nan_count, inf_count) = vec.iter().fold((0usize, 0usize), |(nans, infs), value| {
        (
            nans + usize::from(value.is_nan()),
            infs + usize::from(value.is_infinite()),
        )
    });
    println!(
        "DEBUG Standardization: n={}, mu={:.6}, sigma={:.6}, finite={}, nan={}, inf={}",
        vec.len(),
        mu,
        sigma,
        finite_count,
        nan_count,
        inf_count
    );

    if sigma == 0.0 || !sigma.is_finite() {
        println!("DEBUG: Sigma is problematic, returning zeros");
        return all_zeros();
    }

    let mut standardized = Vec::with_capacity(vec.len());
    for &value in vec {
        standardized.push(if value.is_finite() {
            standardize_log(value, mu, sigma)
        } else {
            0.0
        });
    }
    standardized
}
#[allow(dead_code)]
fn normalize_per_bin(
log_means: &[f64],
log_dispersions: &[f64],
n_bins: usize,
) -> anyhow::Result<Vec<f64>> {
let min_mean = log_means.iter().fold(f64::INFINITY, |a, &b| a.min(b));
let max_mean = log_means.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let bin_width = (max_mean - min_mean) / n_bins as f64;
let mut norm_dispersions = vec![0.0; log_means.len()];
let mut bin_counts = vec![0; n_bins];
let mut bin_disp_means = vec![0.0; n_bins];
let mut bin_disp_stds = vec![0.0; n_bins];
for i in 0..log_means.len() {
let bin_idx = if log_means[i] == max_mean {
n_bins - 1
} else {
((log_means[i] - min_mean) / bin_width) as usize
};
bin_counts[bin_idx] += 1;
bin_disp_means[bin_idx] += log_dispersions[i];
}
for i in 0..n_bins {
if bin_counts[i] > 0 {
bin_disp_means[i] /= bin_counts[i] as f64;
}
}
for i in 0..log_means.len() {
let bin_idx = if log_means[i] == max_mean {
n_bins - 1
} else {
((log_means[i] - min_mean) / bin_width) as usize
};
if bin_counts[bin_idx] > 1 {
bin_disp_stds[bin_idx] += (log_dispersions[i] - bin_disp_means[bin_idx]).powi(2);
}
}
for i in 0..n_bins {
if bin_counts[i] > 1 {
bin_disp_stds[i] = (bin_disp_stds[i] / (bin_counts[i] - 1) as f64).sqrt();
}
}
for i in 0..log_means.len() {
let bin_idx = if log_means[i] == max_mean {
n_bins - 1
} else {
((log_means[i] - min_mean) / bin_width) as usize
};
if bin_disp_stds[bin_idx] > 0.0 {
norm_dispersions[i] =
(log_dispersions[i] - bin_disp_means[bin_idx]) / bin_disp_stds[bin_idx];
}
}
Ok(norm_dispersions)
}
pub fn _normalize_by_batch(
log_means: &[f64],
log_dispersions: &[f64],
batch_col: &polars::prelude::Column,
n_bins: usize,
) -> anyhow::Result<Vec<f64>> {
let mut unique_batches = Vec::new();
if let DataType::String = batch_col.dtype() {
unique_batches = batch_col.str()?.iter().flatten().collect::<Vec<&str>>();
}
let mut norm_dispersions = vec![0.0; log_means.len()];
for batch in unique_batches {
let batch_mask: Vec<bool> = batch_col
.str()?
.into_iter()
.map(|x| x == Some(batch))
.collect();
let batch_size = batch_mask.iter().filter(|&&x| x).count();
if batch_size > 0 {
let batch_norm = normalize_per_bin(
&log_means
.iter()
.zip(batch_mask.iter())
.filter(|(_, &mask)| mask)
.map(|(&x, _)| x)
.collect::<Vec<_>>(),
&log_dispersions
.iter()
.zip(batch_mask.iter())
.filter(|(_, &mask)| mask)
.map(|(&x, _)| x)
.collect::<Vec<_>>(),
n_bins,
)?;
let mut j = 0;
for i in 0..log_means.len() {
if batch_mask[i] {
norm_dispersions[i] = batch_norm[j];
j += 1;
}
}
}
}
Ok(norm_dispersions)
}
pub fn fit_svr(x: &[f64], y: &[f64]) -> anyhow::Result<(Vec<f64>, Vec<f64>)> {
let n = x.len();
if n == 0 {
return Ok((vec![], vec![]));
}
let x_min = x.iter().fold(f64::INFINITY, |a, &b| a.min(b));
let x_max = x.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
let x_norm: Vec<f64> = x.iter().map(|&xi| (xi - x_min) / (x_max - x_min)).collect();
let gamma = 1.0 / n as f64;
let mut kernel = vec![vec![0.0; n]; n];
for i in 0..n {
for j in 0..n {
let diff = x_norm[i] - x_norm[j];
kernel[j][i] = (-gamma * diff.powi(2)).exp();
}
}
let lambda = 1.0; let mut alpha = vec![0.0; n];
for _ in 0..100 {
for i in 0..n {
let mut sum = 0.0;
for j in 0..n {
if i != j {
sum += kernel[j][i] * alpha[j];
}
}
alpha[i] = (y[i] - sum) / (kernel[i][i] + lambda);
}
}
let mut y_pred = vec![0.0; n];
for i in 0..n {
#[allow(clippy::needless_range_loop)]
for j in 0..n {
y_pred[i] += alpha[j] * kernel[i][j];
}
}
let residuals: Vec<f64> = y
.iter()
.zip(y_pred.iter())
.map(|(&yi, &yp)| yi - yp)
.collect();
Ok((residuals, y_pred))
}
pub fn _get_mean_bins(
log_means: &[f64],
n_bins: usize,
) -> anyhow::Result<(Vec<usize>, Vec<usize>)> {
let mut sorted_means: Vec<(usize, f64)> = log_means
.iter()
.enumerate()
.map(|(i, &mean)| (i, mean))
.collect();
sorted_means.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let n_genes = log_means.len();
let genes_per_bin = n_genes / n_bins;
let remainder = n_genes % n_bins;
let mut bin_indices = vec![0; log_means.len()];
let mut mean_bins = vec![0; n_bins];
let mut current_gene_idx = 0;
(0..n_bins).for_each(|bin_idx| {
let genes_in_this_bin = if bin_idx < remainder {
genes_per_bin + 1
} else {
genes_per_bin
};
for _ in 0..genes_in_this_bin {
if current_gene_idx < sorted_means.len() {
let original_idx = sorted_means[current_gene_idx].0;
bin_indices[original_idx] = bin_idx;
current_gene_idx += 1;
}
}
mean_bins[bin_idx] = genes_in_this_bin;
});
Ok((mean_bins, bin_indices))
}
pub fn _calculate_dispersion_stats(
log_dispersions: &[f64],
bin_indices: &[usize],
mean_bins: &[usize],
) -> anyhow::Result<(Vec<f64>, Vec<f64>)> {
let n_bins = mean_bins.len();
let mut bin_means = vec![0.0; n_bins];
let mut bin_stds = vec![0.0; n_bins];
let mut bin_sums = vec![0.0; n_bins];
let mut bin_sum_squares = vec![0.0; n_bins];
for (i, &bin_idx) in bin_indices.iter().enumerate() {
let disp = log_dispersions[i];
if !disp.is_nan() {
bin_sums[bin_idx] += disp;
bin_sum_squares[bin_idx] += disp * disp;
}
}
for bin_idx in 0..n_bins {
let count = mean_bins[bin_idx] as f64;
if count > 0.0 {
bin_means[bin_idx] = bin_sums[bin_idx] / count;
if count > 1.0 {
let variance =
(bin_sum_squares[bin_idx] - bin_sums[bin_idx].powi(2) / count) / (count - 1.0);
let min_variance = 1e-12;
bin_stds[bin_idx] = (variance.max(min_variance)).sqrt();
} else {
bin_stds[bin_idx] = f64::NAN;
}
} else {
bin_means[bin_idx] = f64::NAN;
bin_stds[bin_idx] = f64::NAN;
}
}
Ok((bin_means, bin_stds))
}