use std::collections::HashMap;
use cyanea_core::{CyaneaError, Result};
use crate::single_cell::{AnnData, ColumnData, MatrixData};
use crate::sparse::SparseMatrix;
#[derive(Debug, Clone)]
pub struct DiffusionConfig {
pub n_components: usize,
pub alpha: f64,
}
impl Default for DiffusionConfig {
fn default() -> Self {
Self {
n_components: 15,
alpha: 1.0,
}
}
}
#[derive(Debug, Clone)]
pub struct DiffusionResult {
pub eigenvalues: Vec<f64>,
pub components: Vec<Vec<f64>>,
}
pub fn diffusion_map(adata: &mut AnnData, config: &DiffusionConfig) -> Result<DiffusionResult> {
let conn = adata
.get_obsp("connectivities")
.ok_or_else(|| {
CyaneaError::InvalidInput(
"obsp['connectivities'] not found; run neighbors() first".into(),
)
})?
.clone();
let n = adata.n_obs();
let n_comps = config.n_components.min(n.saturating_sub(1)).max(1);
let w = conn.to_dense();
let d: Vec<f64> = w.iter().map(|row| row.iter().sum::<f64>().max(1e-15)).collect();
let mut w_prime = vec![vec![0.0; n]; n];
for i in 0..n {
for j in 0..n {
w_prime[i][j] = w[i][j] / (d[i].powf(config.alpha) * d[j].powf(config.alpha));
}
}
let d_prime: Vec<f64> = w_prime
.iter()
.map(|row| row.iter().sum::<f64>().max(1e-15))
.collect();
let mut transition = vec![vec![0.0; n]; n];
for i in 0..n {
for j in 0..n {
transition[i][j] = w_prime[i][j] / d_prime[i];
}
}
let mut eigenvalues = Vec::with_capacity(n_comps);
let mut eigenvectors: Vec<Vec<f64>> = Vec::with_capacity(n_comps);
let d_sqrt: Vec<f64> = d_prime.iter().map(|&di| di.sqrt()).collect();
let d_inv_sqrt: Vec<f64> = d_prime.iter().map(|&di| 1.0 / di.sqrt().max(1e-15)).collect();
let mut sym = vec![vec![0.0; n]; n];
for i in 0..n {
for j in 0..n {
sym[i][j] = d_sqrt[i] * transition[i][j] * d_inv_sqrt[j];
}
}
let mut deflated = sym.clone();
for comp in 0..n_comps + 1 {
let (eval, evec) = power_iteration(&deflated, n, 1000);
if comp == 0 {
deflate_matrix(&mut deflated, &evec, eval, n);
continue;
}
eigenvalues.push(eval);
let v: Vec<f64> = (0..n).map(|i| evec[i] * d_inv_sqrt[i]).collect();
eigenvectors.push(v);
if comp < n_comps {
deflate_matrix(&mut deflated, &evec, eval, n);
}
}
let mut pairs: Vec<(f64, Vec<f64>)> = eigenvalues
.into_iter()
.zip(eigenvectors.into_iter())
.collect();
pairs.sort_by(|a, b| {
b.0.abs()
.partial_cmp(&a.0.abs())
.unwrap_or(std::cmp::Ordering::Equal)
});
let eigenvalues: Vec<f64> = pairs.iter().map(|(e, _)| *e).collect();
let eigenvectors: Vec<Vec<f64>> = pairs.into_iter().map(|(_, v)| v).collect();
let embedding: Vec<Vec<f64>> = (0..n)
.map(|i| eigenvectors.iter().map(|ev| ev[i]).collect())
.collect();
adata.add_obsm("X_diffmap", embedding.clone())?;
let evals_str: Vec<String> = eigenvalues.iter().map(|e| format!("{:.6}", e)).collect();
adata.add_uns("diffmap_evals", evals_str.join(","));
Ok(DiffusionResult {
eigenvalues,
components: embedding,
})
}
fn power_iteration(matrix: &[Vec<f64>], n: usize, max_iter: usize) -> (f64, Vec<f64>) {
let mut v = vec![1.0 / (n as f64).sqrt(); n];
let mut eigenvalue = 0.0;
for _ in 0..max_iter {
let mut w = vec![0.0; n];
for i in 0..n {
for j in 0..n {
w[i] += matrix[i][j] * v[j];
}
}
let norm: f64 = w.iter().map(|x| x * x).sum::<f64>().sqrt();
if norm < 1e-15 {
break;
}
let new_eigenvalue = w
.iter()
.zip(v.iter())
.map(|(wi, vi)| wi * vi)
.sum::<f64>();
for x in &mut w {
*x /= norm;
}
let converged = (new_eigenvalue - eigenvalue).abs() < 1e-12;
eigenvalue = new_eigenvalue;
v = w;
if converged {
break;
}
}
(eigenvalue, v)
}
fn deflate_matrix(matrix: &mut [Vec<f64>], eigenvec: &[f64], eigenvalue: f64, n: usize) {
for i in 0..n {
for j in 0..n {
matrix[i][j] -= eigenvalue * eigenvec[i] * eigenvec[j];
}
}
}
#[derive(Debug, Clone)]
pub struct DptConfig {
pub root_cell: usize,
pub n_branchings: usize,
}
impl Default for DptConfig {
fn default() -> Self {
Self {
root_cell: 0,
n_branchings: 0,
}
}
}
pub fn dpt(adata: &mut AnnData, config: &DptConfig) -> Result<()> {
let diffmap = adata
.get_obsm("X_diffmap")
.ok_or_else(|| {
CyaneaError::InvalidInput(
"obsm['X_diffmap'] not found; run diffusion_map() first".into(),
)
})?
.clone();
let evals_str = adata
.get_uns("diffmap_evals")
.ok_or_else(|| CyaneaError::InvalidInput("diffmap_evals not found in uns".into()))?
.to_string();
let eigenvalues: Vec<f64> = evals_str
.split(',')
.filter_map(|s| s.trim().parse::<f64>().ok())
.collect();
let n = adata.n_obs();
if config.root_cell >= n {
return Err(CyaneaError::InvalidInput(format!(
"root_cell {} out of bounds (n_obs={})",
config.root_cell, n
)));
}
let n_comps = diffmap[0].len().min(eigenvalues.len());
let root = config.root_cell;
let weights: Vec<f64> = eigenvalues
.iter()
.take(n_comps)
.map(|&lam| {
let denom = (1.0 - lam).max(1e-10);
1.0 / (denom * denom)
})
.collect();
let mut pseudotime = vec![0.0; n];
for i in 0..n {
let mut dist_sq = 0.0;
for k in 0..n_comps {
let diff = diffmap[i][k] - diffmap[root][k];
dist_sq += weights[k] * diff * diff;
}
pseudotime[i] = dist_sq.sqrt();
}
let max_pt = pseudotime.iter().cloned().fold(0.0f64, f64::max);
if max_pt > 0.0 {
for pt in &mut pseudotime {
*pt /= max_pt;
}
}
adata.add_obs_column("dpt_pseudotime", ColumnData::Numeric(pseudotime))?;
Ok(())
}
#[derive(Debug, Clone)]
pub struct PagaResult {
pub connectivities: Vec<Vec<f64>>,
pub groups: Vec<String>,
pub cluster_sizes: Vec<usize>,
}
pub fn paga(adata: &AnnData, cluster_key: &str) -> Result<PagaResult> {
let conn = adata
.get_obsp("connectivities")
.ok_or_else(|| {
CyaneaError::InvalidInput(
"obsp['connectivities'] not found; run neighbors() first".into(),
)
})?;
let cluster_col = adata
.get_obs(cluster_key)
.ok_or_else(|| {
CyaneaError::InvalidInput(format!("obs['{}'] not found; run clustering first", cluster_key))
})?;
let labels: Vec<String> = match cluster_col {
ColumnData::Strings(v) => v.clone(),
ColumnData::Numeric(v) => v.iter().map(|x| x.to_string()).collect(),
ColumnData::Categorical { codes, categories } => {
codes.iter().map(|&c| {
categories.get(c as usize).cloned().unwrap_or_else(|| c.to_string())
}).collect()
}
};
let mut unique_labels: Vec<String> = labels.clone();
unique_labels.sort();
unique_labels.dedup();
let n_clusters = unique_labels.len();
let label_to_idx: HashMap<&str, usize> = unique_labels
.iter()
.enumerate()
.map(|(i, l)| (l.as_str(), i))
.collect();
let assignments: Vec<usize> = labels
.iter()
.map(|l| *label_to_idx.get(l.as_str()).unwrap())
.collect();
let cluster_sizes: Vec<usize> = (0..n_clusters)
.map(|c| assignments.iter().filter(|&&a| a == c).count())
.collect();
let mut inter_conn = vec![vec![0.0; n_clusters]; n_clusters];
for (i, j, w) in conn.iter() {
let ci = assignments[i];
let cj = assignments[j];
if ci != cj {
inter_conn[ci][cj] += w;
}
}
let n = adata.n_obs();
let total_weight: f64 = conn.iter().map(|(_, _, w)| w).sum();
let mut normalized = vec![vec![0.0; n_clusters]; n_clusters];
for ci in 0..n_clusters {
for cj in (ci + 1)..n_clusters {
let observed = inter_conn[ci][cj] + inter_conn[cj][ci];
let expected = if n > 0 {
2.0 * cluster_sizes[ci] as f64 * cluster_sizes[cj] as f64 * total_weight
/ (n as f64 * n as f64)
} else {
1.0
};
let ratio = if expected > 0.0 {
observed / expected
} else {
0.0
};
normalized[ci][cj] = ratio;
normalized[cj][ci] = ratio;
}
}
Ok(PagaResult {
connectivities: normalized,
groups: unique_labels,
cluster_sizes,
})
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum VelocityMode {
SteadyState,
}
#[derive(Debug, Clone)]
pub struct VelocityConfig {
pub mode: VelocityMode,
pub n_neighbors: usize,
pub min_counts: usize,
}
impl Default for VelocityConfig {
fn default() -> Self {
Self {
mode: VelocityMode::SteadyState,
n_neighbors: 30,
min_counts: 20,
}
}
}
pub fn rna_velocity(adata: &mut AnnData, config: &VelocityConfig) -> Result<()> {
let n_obs = adata.n_obs();
let n_vars = adata.n_vars();
let spliced = adata
.get_layer("spliced")
.ok_or_else(|| {
CyaneaError::InvalidInput("layers['spliced'] not found".into())
})?
.clone();
let unspliced = adata
.get_layer("unspliced")
.ok_or_else(|| {
CyaneaError::InvalidInput("layers['unspliced'] not found".into())
})?
.clone();
let mut gammas = vec![0.0; n_vars];
let mut velocity_data = vec![vec![0.0; n_vars]; n_obs];
for j in 0..n_vars {
let mut sum_us = 0.0;
let mut sum_ss = 0.0;
let mut total_counts = 0.0;
for i in 0..n_obs {
let s = spliced.get(i, j);
let u = unspliced.get(i, j);
sum_us += u * s;
sum_ss += s * s;
total_counts += s + u;
}
if total_counts < config.min_counts as f64 || sum_ss < 1e-10 {
gammas[j] = 0.0;
continue;
}
gammas[j] = sum_us / sum_ss;
for i in 0..n_obs {
let s = spliced.get(i, j);
let u = unspliced.get(i, j);
velocity_data[i][j] = u - gammas[j] * s;
}
}
adata.add_layer("velocity", MatrixData::Dense(velocity_data.clone()))?;
let mut vel_graph = SparseMatrix::new(n_obs, n_obs);
let k = config.n_neighbors.min(n_obs - 1);
for i in 0..n_obs {
let mut sims: Vec<(usize, f64)> = (0..n_obs)
.filter(|&j| j != i)
.map(|j| {
let cos_sim = cosine_sim(&velocity_data[i], &velocity_data[j]);
(j, cos_sim)
})
.collect();
sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
for &(j, sim) in sims.iter().take(k) {
if sim > 0.0 {
let _ = vel_graph.insert(i, j, sim);
}
}
}
adata.add_obsp("velocity_graph", vel_graph)?;
Ok(())
}
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
let denom = norm_a * norm_b;
if denom < 1e-15 {
0.0
} else {
dot / denom
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_adata(n_obs: usize, n_vars: usize) -> AnnData {
let x = MatrixData::Dense(vec![vec![0.0; n_vars]; n_obs]);
let obs_names: Vec<String> = (0..n_obs).map(|i| format!("cell_{}", i)).collect();
let var_names: Vec<String> = (0..n_vars).map(|j| format!("gene_{}", j)).collect();
AnnData::new(x, obs_names, var_names).unwrap()
}
fn make_connectivity(n: usize, clusters: &[Vec<usize>]) -> SparseMatrix {
let mut conn = SparseMatrix::new(n, n);
for cluster in clusters {
for &i in cluster {
for &j in cluster {
if i != j {
let _ = conn.insert(i, j, 1.0);
}
}
}
}
for (ci, c1) in clusters.iter().enumerate() {
for c2 in clusters.iter().skip(ci + 1) {
if let (Some(&a), Some(&b)) = (c1.last(), c2.first()) {
let _ = conn.insert(a, b, 0.01);
let _ = conn.insert(b, a, 0.01);
}
}
}
conn
}
#[test]
fn diffusion_map_basic() {
let n = 10;
let mut adata = make_adata(n, 2);
let conn = make_connectivity(n, &[vec![0, 1, 2, 3, 4], vec![5, 6, 7, 8, 9]]);
adata.add_obsp("connectivities", conn).unwrap();
let result = diffusion_map(&mut adata, &DiffusionConfig::default()).unwrap();
assert!(!result.eigenvalues.is_empty());
assert_eq!(result.components.len(), n);
assert!(adata.get_obsm("X_diffmap").is_some());
assert!(adata.get_uns("diffmap_evals").is_some());
}
#[test]
fn diffusion_map_eigenvalues_descending() {
let n = 8;
let mut adata = make_adata(n, 2);
let conn = make_connectivity(n, &[vec![0, 1, 2, 3], vec![4, 5, 6, 7]]);
adata.add_obsp("connectivities", conn).unwrap();
let result = diffusion_map(
&mut adata,
&DiffusionConfig {
n_components: 3,
..Default::default()
},
)
.unwrap();
for i in 1..result.eigenvalues.len() {
assert!(
result.eigenvalues[i - 1].abs() >= result.eigenvalues[i].abs() - 1e-10,
"eigenvalues not descending by magnitude: |{}| < |{}|",
result.eigenvalues[i - 1],
result.eigenvalues[i]
);
}
}
#[test]
fn diffusion_map_missing_connectivities() {
let mut adata = make_adata(5, 2);
let result = diffusion_map(&mut adata, &DiffusionConfig::default());
assert!(result.is_err());
}
#[test]
fn diffusion_map_n_components_capped() {
let n = 5;
let mut adata = make_adata(n, 2);
let conn = make_connectivity(n, &[(0..n).collect()]);
adata.add_obsp("connectivities", conn).unwrap();
let result = diffusion_map(
&mut adata,
&DiffusionConfig {
n_components: 100, ..Default::default()
},
)
.unwrap();
assert!(result.eigenvalues.len() <= n - 1);
}
#[test]
fn diffusion_map_anisotropic_vs_standard() {
let n = 6;
let mut adata1 = make_adata(n, 2);
let mut adata2 = make_adata(n, 2);
let conn = make_connectivity(n, &[vec![0, 1, 2], vec![3, 4, 5]]);
adata1.add_obsp("connectivities", conn.clone()).unwrap();
adata2.add_obsp("connectivities", conn).unwrap();
let r1 = diffusion_map(
&mut adata1,
&DiffusionConfig { alpha: 0.0, ..Default::default() },
).unwrap();
let r2 = diffusion_map(
&mut adata2,
&DiffusionConfig { alpha: 1.0, ..Default::default() },
).unwrap();
if !r1.eigenvalues.is_empty() && !r2.eigenvalues.is_empty() {
assert!(r1.eigenvalues[0].is_finite());
assert!(r2.eigenvalues[0].is_finite());
}
}
#[test]
fn dpt_basic() {
let n = 8;
let mut adata = make_adata(n, 2);
let conn = make_connectivity(n, &[(0..n).collect()]);
adata.add_obsp("connectivities", conn).unwrap();
diffusion_map(&mut adata, &DiffusionConfig { n_components: 3, ..Default::default() }).unwrap();
dpt(&mut adata, &DptConfig { root_cell: 0, ..Default::default() }).unwrap();
let pt = adata.get_obs("dpt_pseudotime").unwrap().as_numeric().unwrap();
assert_eq!(pt.len(), n);
assert!((pt[0] - 0.0).abs() < 1e-10);
assert!(pt.iter().all(|&v| (0.0..=1.0).contains(&v)));
}
#[test]
fn dpt_root_at_end() {
let n = 6;
let mut adata = make_adata(n, 2);
let conn = make_connectivity(n, &[(0..n).collect()]);
adata.add_obsp("connectivities", conn).unwrap();
diffusion_map(&mut adata, &DiffusionConfig { n_components: 3, ..Default::default() }).unwrap();
dpt(&mut adata, &DptConfig { root_cell: n - 1, ..Default::default() }).unwrap();
let pt = adata.get_obs("dpt_pseudotime").unwrap().as_numeric().unwrap();
assert!((pt[n - 1] - 0.0).abs() < 1e-10);
}
#[test]
fn dpt_root_out_of_bounds() {
let n = 4;
let mut adata = make_adata(n, 2);
let conn = make_connectivity(n, &[(0..n).collect()]);
adata.add_obsp("connectivities", conn).unwrap();
diffusion_map(&mut adata, &DiffusionConfig::default()).unwrap();
let result = dpt(&mut adata, &DptConfig { root_cell: 100, ..Default::default() });
assert!(result.is_err());
}
#[test]
fn dpt_missing_diffmap() {
let mut adata = make_adata(4, 2);
let result = dpt(&mut adata, &DptConfig::default());
assert!(result.is_err());
}
#[test]
fn dpt_two_branches() {
let n = 5;
let mut adata = make_adata(n, 2);
let mut conn = SparseMatrix::new(n, n);
for i in 0..n - 1 {
let _ = conn.insert(i, i + 1, 1.0);
let _ = conn.insert(i + 1, i, 1.0);
}
adata.add_obsp("connectivities", conn).unwrap();
diffusion_map(&mut adata, &DiffusionConfig { n_components: 3, ..Default::default() }).unwrap();
dpt(&mut adata, &DptConfig { root_cell: 0, ..Default::default() }).unwrap();
let pt = adata.get_obs("dpt_pseudotime").unwrap().as_numeric().unwrap();
assert!(pt[0] < pt[2] || pt[0] < pt[4]);
}
#[test]
fn paga_two_clusters() {
let n = 10;
let mut adata = make_adata(n, 2);
let conn = make_connectivity(n, &[vec![0, 1, 2, 3, 4], vec![5, 6, 7, 8, 9]]);
adata.add_obsp("connectivities", conn).unwrap();
let labels: Vec<String> = (0..n).map(|i| if i < 5 { "A".into() } else { "B".into() }).collect();
adata.add_obs("cluster", labels).unwrap();
let result = paga(&adata, "cluster").unwrap();
assert_eq!(result.groups.len(), 2);
assert_eq!(result.cluster_sizes.len(), 2);
assert_eq!(result.connectivities.len(), 2);
assert_eq!(result.connectivities[0][0], 0.0);
assert_eq!(result.connectivities[1][1], 0.0);
assert!(result.connectivities[0][1] > 0.0);
}
#[test]
fn paga_three_clusters() {
let n = 9;
let mut adata = make_adata(n, 2);
let conn = make_connectivity(n, &[vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]]);
adata.add_obsp("connectivities", conn).unwrap();
let labels: Vec<String> = (0..n)
.map(|i| match i / 3 { 0 => "A", 1 => "B", _ => "C" }.into())
.collect();
adata.add_obs("cluster", labels).unwrap();
let result = paga(&adata, "cluster").unwrap();
assert_eq!(result.groups.len(), 3);
assert_eq!(result.connectivities.len(), 3);
}
#[test]
fn paga_missing_connectivities() {
let mut adata = make_adata(5, 2);
adata.add_obs("cluster", vec!["A"; 5].into_iter().map(String::from).collect()).unwrap();
let result = paga(&adata, "cluster");
assert!(result.is_err());
}
#[test]
fn paga_missing_cluster_key() {
let mut adata = make_adata(5, 2);
let conn = SparseMatrix::new(5, 5);
adata.add_obsp("connectivities", conn).unwrap();
let result = paga(&adata, "nonexistent");
assert!(result.is_err());
}
#[test]
fn paga_symmetric_connectivity() {
let n = 6;
let mut adata = make_adata(n, 2);
let conn = make_connectivity(n, &[vec![0, 1, 2], vec![3, 4, 5]]);
adata.add_obsp("connectivities", conn).unwrap();
let labels: Vec<String> = (0..n).map(|i| if i < 3 { "X".into() } else { "Y".into() }).collect();
adata.add_obs("cluster", labels).unwrap();
let result = paga(&adata, "cluster").unwrap();
assert!((result.connectivities[0][1] - result.connectivities[1][0]).abs() < 1e-10);
}
#[test]
fn velocity_steady_state() {
let n = 10;
let n_vars = 3;
let mut adata = make_adata(n, n_vars);
let mut spliced_data = vec![vec![0.0; n_vars]; n];
let mut unspliced_data = vec![vec![0.0; n_vars]; n];
for i in 0..n {
for j in 0..n_vars {
spliced_data[i][j] = (i + 1) as f64 * (j + 1) as f64;
unspliced_data[i][j] = spliced_data[i][j] * 0.5; }
}
adata.add_layer("spliced", MatrixData::Dense(spliced_data)).unwrap();
adata.add_layer("unspliced", MatrixData::Dense(unspliced_data)).unwrap();
rna_velocity(&mut adata, &VelocityConfig::default()).unwrap();
assert!(adata.get_layer("velocity").is_some());
assert!(adata.get_obsp("velocity_graph").is_some());
let vel = adata.get_layer("velocity").unwrap();
for i in 0..n {
for j in 0..n_vars {
assert!(
vel.get(i, j).abs() < 1e-10,
"velocity[{},{}] = {} (expected ~0)",
i, j, vel.get(i, j)
);
}
}
}
#[test]
fn velocity_positive_velocity() {
let n = 5;
let mut adata = make_adata(n, 2);
let spliced = vec![vec![1.0, 2.0]; n];
let unspliced = vec![vec![5.0, 10.0]; n]; adata.add_layer("spliced", MatrixData::Dense(spliced)).unwrap();
adata.add_layer("unspliced", MatrixData::Dense(unspliced)).unwrap();
rna_velocity(&mut adata, &VelocityConfig { min_counts: 0, ..Default::default() }).unwrap();
}
#[test]
fn velocity_missing_spliced() {
let mut adata = make_adata(5, 2);
adata.add_layer("unspliced", MatrixData::Dense(vec![vec![1.0, 2.0]; 5])).unwrap();
let result = rna_velocity(&mut adata, &VelocityConfig::default());
assert!(result.is_err());
}
#[test]
fn velocity_missing_unspliced() {
let mut adata = make_adata(5, 2);
adata.add_layer("spliced", MatrixData::Dense(vec![vec![1.0, 2.0]; 5])).unwrap();
let result = rna_velocity(&mut adata, &VelocityConfig::default());
assert!(result.is_err());
}
#[test]
fn velocity_low_count_genes_skipped() {
let n = 5;
let mut adata = make_adata(n, 2);
let spliced = vec![vec![10.0, 0.001]; n];
let unspliced = vec![vec![5.0, 0.001]; n];
adata.add_layer("spliced", MatrixData::Dense(spliced)).unwrap();
adata.add_layer("unspliced", MatrixData::Dense(unspliced)).unwrap();
rna_velocity(
&mut adata,
&VelocityConfig {
min_counts: 20,
..Default::default()
},
)
.unwrap();
let vel = adata.get_layer("velocity").unwrap();
for i in 0..n {
assert_eq!(vel.get(i, 1), 0.0);
}
}
#[test]
fn velocity_graph_has_entries() {
let n = 8;
let mut adata = make_adata(n, 3);
let mut spliced = vec![vec![0.0; 3]; n];
let mut unspliced = vec![vec![0.0; 3]; n];
for i in 0..n {
for j in 0..3 {
spliced[i][j] = (i * 3 + j + 1) as f64;
unspliced[i][j] = spliced[i][j] * 0.3 + (i as f64) * 0.1;
}
}
adata.add_layer("spliced", MatrixData::Dense(spliced)).unwrap();
adata.add_layer("unspliced", MatrixData::Dense(unspliced)).unwrap();
rna_velocity(&mut adata, &VelocityConfig { min_counts: 0, ..Default::default() }).unwrap();
let vg = adata.get_obsp("velocity_graph").unwrap();
assert!(vg.nnz() > 0, "velocity graph should have entries");
}
#[test]
fn diffmap_then_dpt_pipeline() {
let n = 8;
let mut adata = make_adata(n, 2);
let mut conn = SparseMatrix::new(n, n);
for i in 0..n - 1 {
let _ = conn.insert(i, i + 1, 1.0);
let _ = conn.insert(i + 1, i, 1.0);
}
adata.add_obsp("connectivities", conn).unwrap();
diffusion_map(&mut adata, &DiffusionConfig { n_components: 3, ..Default::default() }).unwrap();
dpt(&mut adata, &DptConfig { root_cell: 0, ..Default::default() }).unwrap();
let pt = adata.get_obs("dpt_pseudotime").unwrap().as_numeric().unwrap();
assert_eq!(pt.len(), n);
assert!(pt[0].abs() < 1e-10, "root pseudotime should be 0");
assert!(pt.iter().all(|&v| v >= 0.0 && v <= 1.0));
}
#[test]
fn cosine_sim_identical() {
let sim = cosine_sim(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);
assert!((sim - 1.0).abs() < 1e-10);
}
#[test]
fn cosine_sim_orthogonal() {
let sim = cosine_sim(&[1.0, 0.0], &[0.0, 1.0]);
assert!(sim.abs() < 1e-10);
}
#[test]
fn cosine_sim_zero() {
let sim = cosine_sim(&[0.0, 0.0], &[1.0, 2.0]);
assert_eq!(sim, 0.0);
}
}