use ndarray::{Array2, Axis};
use polars::prelude::*;
use rayon::prelude::*;
use std::collections::{HashMap, HashSet};
/// Exact Gaussian-weighted ligand field received by each cell, with no distance cutoff.
///
/// For cell `i` and ligand column `k` the result is
/// `(1 / n) * Σ_j scale_factor * exp(-d²(i, j) / (2 * radius²)) * lig_values[j, k]`,
/// summed over all `n = xy.nrows()` cells including `i` itself. This delegates to
/// [`calculate_weighted_ligands_with_cutoff`] with `max_neighbor_distance = None`.
pub fn calculate_weighted_ligands(
    xy: &Array2<f64>,
    lig_values: &Array2<f64>,
    radius: f64,
    scale_factor: f64,
) -> Array2<f64> {
    let no_cutoff: Option<f64> = None;
    calculate_weighted_ligands_with_cutoff(xy, lig_values, radius, scale_factor, no_cutoff)
}
/// Exact Gaussian-weighted ligand field received by each cell, with an optional distance cutoff.
///
/// Cell `i` is row `i` of `xy` (x in column 0, y in column 1). For ligand column `k`:
///
/// ```text
/// out[i, k] = (1 / n) * Σ_j  scale_factor * exp(-d²(i, j) / (2 * radius²)) * lig_values[j, k]
/// ```
///
/// Here `n = xy.nrows()` and `d(i, j)` is the Euclidean distance between cells `i` and `j`.
/// The sum runs over every cell, including `i` itself.
///
/// `max_neighbor_distance`: `Some(m)` with `m` finite and `> 0` skips every pair with
/// `d(i, j) > m`. `None`, or a non-positive, NaN or infinite `m`, disables the cutoff.
///
/// Cost is O(n² · n_ligands). Output rows are filled in parallel with rayon. A zero `radius`
/// turns every self term into `exp(0 * -inf) = NaN`.
///
/// # Panics
/// Panics on out-of-bounds indexing in either of these cases:
/// - `xy` has fewer than two columns.
/// - `lig_values` has ligand columns but fewer rows than `xy`.
pub fn calculate_weighted_ligands_with_cutoff(
    xy: &Array2<f64>,
    lig_values: &Array2<f64>,
    radius: f64,
    scale_factor: f64,
    max_neighbor_distance: Option<f64>,
) -> Array2<f64> {
    let n = xy.nrows();
    let n_lig = lig_values.ncols();
    let mut out = Array2::<f64>::zeros((n, n_lig));
    if n == 0 {
        return out;
    }

    // exp(d² * neg_half_inv_r2) == exp(-d² / (2 r²)); hoisted out of the O(n²) loop.
    let neg_half_inv_r2 = -1.0 / (2.0 * radius * radius);
    // Squared cutoff, so the inner loop never needs a square root.
    let cutoff_sq = max_neighbor_distance
        .filter(|m| m.is_finite() && *m > 0.0)
        .map(|m| m * m);
    let inv_n = 1.0 / n as f64;

    out.axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .for_each(|(i, mut acc)| {
            let (px, py) = (xy[[i, 0]], xy[[i, 1]]);
            for j in 0..n {
                let ddx = px - xy[[j, 0]];
                let ddy = py - xy[[j, 1]];
                let dist_sq = ddx * ddx + ddy * ddy;
                if matches!(cutoff_sq, Some(limit) if dist_sq > limit) {
                    continue;
                }
                let weight = scale_factor * (dist_sq * neg_half_inv_r2).exp();
                for (k, slot) in acc.iter_mut().enumerate() {
                    *slot += weight * lig_values[[j, k]];
                }
            }
            acc.mapv_inplace(|v| v * inv_n);
        });
    out
}
/// Grid-accelerated approximation of the Gaussian-weighted ligand field, with no distance
/// cutoff.
///
/// `radius * grid_factor` is the spacing of the interpolation grid. This delegates to
/// [`calculate_weighted_ligands_grid_with_cutoff`] with `max_neighbor_distance = None`.
pub fn calculate_weighted_ligands_grid(
    xy: &Array2<f64>,
    lig_values: &Array2<f64>,
    radius: f64,
    scale_factor: f64,
    grid_factor: f64,
) -> Array2<f64> {
    let no_cutoff: Option<f64> = None;
    calculate_weighted_ligands_grid_with_cutoff(
        xy,
        lig_values,
        radius,
        scale_factor,
        grid_factor,
        no_cutoff,
    )
}
/// Grid-accelerated approximation of the Gaussian-weighted ligand field, with an optional
/// distance cutoff.
///
/// The exact form sums over all cell pairs:
/// `(1 / n) * Σ_j scale_factor * exp(-d² / (2 * radius²)) * lig_values[j, k]`.
/// This version evaluates that sum only at the nodes ("anchors") of a regular grid with
/// spacing `radius * grid_factor`. The grid covers the bounding box of `xy`, padded by one
/// spacing on every side. Each cell's value is then bilinearly interpolated from the four
/// anchors around it.
///
/// The cost is O(n_anchors · n · n_ligands) instead of O(n² · n_ligands). A smaller
/// `grid_factor` gives more anchors and a closer match to the exact sum.
///
/// The exact computation ([`calculate_weighted_ligands_with_cutoff`]) is used instead when:
/// - the grid would have at least as many anchors as there are cells, or
/// - no grid can be built (spacing not finite and positive, or a non-finite coordinate
///   extent).
///
/// `max_neighbor_distance` is enabled only for finite, positive values, as in the exact
/// version. It is applied between each anchor and each cell, so with a cutoff the result
/// approximates the exact cutoff sum rather than reproducing it.
///
/// `xy` holds x in column 0 and y in column 1. `lig_values` may have any memory layout;
/// non-C-contiguous inputs are copied once into row-major order.
///
/// # Panics
/// When there is at least one cell and one ligand column, panics if either:
/// - `lig_values` has fewer rows than `xy`, or
/// - `xy` has fewer than two columns.
pub fn calculate_weighted_ligands_grid_with_cutoff(
    xy: &Array2<f64>,
    lig_values: &Array2<f64>,
    radius: f64,
    scale_factor: f64,
    grid_factor: f64,
    max_neighbor_distance: Option<f64>,
) -> Array2<f64> {
    let n_cells = xy.nrows();
    let n_ligands = lig_values.ncols();
    // Nothing to compute. This also avoids `par_chunks_mut(0)` below, which panics.
    if n_cells == 0 || n_ligands == 0 {
        return Array2::zeros((n_cells, n_ligands));
    }
    assert!(
        lig_values.nrows() >= n_cells,
        "lig_values has {} rows but xy has {} cells",
        lig_values.nrows(),
        n_cells
    );
    let exact = || {
        calculate_weighted_ligands_with_cutoff(
            xy,
            lig_values,
            radius,
            scale_factor,
            max_neighbor_distance,
        )
    };

    let grid_spacing = radius * grid_factor;
    // A zero, negative, NaN or infinite spacing cannot define a grid.
    if !(grid_spacing.is_finite() && grid_spacing > 0.0) {
        return exact();
    }
    let inv_2r2 = -1.0 / (2.0 * radius * radius);
    let d2_cut = max_neighbor_distance
        .filter(|m| m.is_finite() && *m > 0.0)
        .map(|m| m * m);

    // Bounding box of the cells. NaN coordinates fail every comparison and are ignored here.
    let mut x_min = f64::INFINITY;
    let mut x_max = f64::NEG_INFINITY;
    let mut y_min = f64::INFINITY;
    let mut y_max = f64::NEG_INFINITY;
    for i in 0..n_cells {
        let (x, y) = (xy[[i, 0]], xy[[i, 1]]);
        if x < x_min {
            x_min = x;
        }
        if x > x_max {
            x_max = x;
        }
        if y < y_min {
            y_min = y;
        }
        if y > y_max {
            y_max = y;
        }
    }
    // Pad by one spacing so every cell lies strictly inside the grid, away from its edges.
    x_min -= grid_spacing;
    y_min -= grid_spacing;
    x_max += grid_spacing;
    y_max += grid_spacing;

    // Count anchors in f64. A huge or infinite extent must trigger the fallback rather than
    // saturate or overflow usize arithmetic.
    let nx_f = ((x_max - x_min) / grid_spacing).ceil() + 1.0;
    let ny_f = ((y_max - y_min) / grid_spacing).ceil() + 1.0;
    // The grid only pays off when it has fewer anchors than there are cells.
    if !(nx_f.is_finite() && ny_f.is_finite() && nx_f * ny_f < n_cells as f64) {
        return exact();
    }
    let nx = nx_f as usize;
    let ny = ny_f as usize;
    let n_anchors = nx * ny;

    // Row-major view of the ligand matrix. `as_standard_layout` borrows when `lig_values` is
    // already C-contiguous and copies otherwise. A bare `as_slice()` would be `None` for
    // Fortran-order or transposed arrays.
    let lig_std = lig_values.as_standard_layout();
    let lig_flat = lig_std
        .as_slice()
        .expect("a standard-layout array is contiguous");
    let n_inv = 1.0 / n_cells as f64;

    // Pass 1: exact weighted sum at every anchor. Anchor `a` sits at grid column `a % nx`
    // and grid row `a / nx`; its values are `anchor_vals[a * n_ligands..(a + 1) * n_ligands]`.
    let mut anchor_vals = vec![0.0f64; n_anchors * n_ligands];
    anchor_vals
        .par_chunks_mut(n_ligands)
        .enumerate()
        .for_each(|(a, acc)| {
            let ax = x_min + (a % nx) as f64 * grid_spacing;
            let ay = y_min + (a / nx) as f64 * grid_spacing;
            for (j, lig_row) in lig_flat
                .chunks_exact(n_ligands)
                .take(n_cells)
                .enumerate()
            {
                let dx = ax - xy[[j, 0]];
                let dy = ay - xy[[j, 1]];
                let d2 = dx * dx + dy * dy;
                if d2_cut.is_some_and(|c| d2 > c) {
                    continue;
                }
                let w = scale_factor * (d2 * inv_2r2).exp();
                for (dst, &v) in acc.iter_mut().zip(lig_row) {
                    *dst += w * v;
                }
            }
            for v in acc.iter_mut() {
                *v *= n_inv;
            }
        });

    // Pass 2: bilinear interpolation of each cell from the four anchors around it.
    let mut result = Array2::<f64>::zeros((n_cells, n_ligands));
    let res_flat = result
        .as_slice_mut()
        .expect("a freshly allocated array is contiguous");
    res_flat
        .par_chunks_mut(n_ligands)
        .enumerate()
        .for_each(|(i, out)| {
            let gx_f = (xy[[i, 0]] - x_min) / grid_spacing;
            let gy_f = (xy[[i, 1]] - y_min) / grid_spacing;
            let gx0 = gx_f.floor() as usize;
            let gy0 = gy_f.floor() as usize;
            let gx1 = (gx0 + 1).min(nx - 1);
            let gy1 = (gy0 + 1).min(ny - 1);
            let fx = gx_f - gx0 as f64;
            let fy = gy_f - gy0 as f64;
            let w00 = (1.0 - fx) * (1.0 - fy);
            let w10 = fx * (1.0 - fy);
            let w01 = (1.0 - fx) * fy;
            let w11 = fx * fy;
            // Checked slicing: a bad anchor index panics instead of reading out of bounds.
            let start = |gx: usize, gy: usize| (gy * nx + gx) * n_ligands;
            let v00 = &anchor_vals[start(gx0, gy0)..][..n_ligands];
            let v10 = &anchor_vals[start(gx1, gy0)..][..n_ligands];
            let v01 = &anchor_vals[start(gx0, gy1)..][..n_ligands];
            let v11 = &anchor_vals[start(gx1, gy1)..][..n_ligands];
            for (k, dst) in out.iter_mut().enumerate() {
                *dst = w00 * v00[k] + w10 * v10[k] + w01 * v01[k] + w11 * v11[k];
            }
        });
    result
}
pub fn compute_received_ligands(
xy: &Array2<f64>,
ligands_df: &DataFrame,
lr_info: &DataFrame,
scale_factor: f64,
) -> anyhow::Result<DataFrame> {
let radius_col = lr_info.column("radius")?.f64()?;
let ligand_col = lr_info.column("ligand")?.str()?;
let mut radius_to_ligands: HashMap<u64, Vec<String>> = HashMap::new();
for i in 0..lr_info.height() {
if let (Some(r), Some(l)) = (radius_col.get(i), ligand_col.get(i)) {
let r_bits = r.to_bits();
radius_to_ligands
.entry(r_bits)
.or_default()
.push(l.to_string());
}
}
let mut results_cols = Vec::new();
for (r_bits, ligands) in radius_to_ligands {
let radius = f64::from_bits(r_bits);
let mut valid_ligands = Vec::new();
let mut lig_indices = Vec::new();
for (idx, name) in ligands_df.get_column_names().iter().enumerate() {
if ligands.contains(&name.to_string()) {
valid_ligands.push(name.to_string());
lig_indices.push(idx);
}
}
if valid_ligands.is_empty() {
continue;
}
let sub_df = ligands_df.select(&valid_ligands)?;
let lig_values = sub_df.to_ndarray::<Float64Type>(IndexOrder::C)?;
let weighted = calculate_weighted_ligands(xy, &lig_values, radius, scale_factor);
for (i, name) in valid_ligands.into_iter().enumerate() {
let col_data: Vec<f64> = weighted.column(i).to_vec();
results_cols.push(Column::new(name.into(), col_data));
}
}
let result_df = DataFrame::new(results_cols)?;
let original_col_names = ligands_df.get_column_names();
let sorted_df = result_df.select(original_col_names.iter().map(|s| s.as_str()))?;
Ok(sorted_df)
}
// Unit tests for the exact and grid-approximated weighted-ligand kernels.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;
    #[test]
    fn single_cell_self_contribution() {
        let xy = array![[0.0, 0.0]];
        let lig = array![[1.0]];
        let result = calculate_weighted_ligands(&xy, &lig, 1.0, 1.0);
        assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-10);
        let r2 = calculate_weighted_ligands(&xy, &lig, 1.0, 2.5);
        assert_abs_diff_eq!(r2[[0, 0]], 2.5, epsilon = 1e-10);
    }
    #[test]
    fn two_cells_symmetry() {
        let xy = array![[0.0, 0.0], [1.0, 0.0]];
        let lig = array![[1.0], [1.0]];
        let result = calculate_weighted_ligands(&xy, &lig, 1.0, 1.0);
        assert_abs_diff_eq!(result[[0, 0]], result[[1, 0]], epsilon = 1e-10);
    }
    #[test]
    fn gaussian_decay_with_distance() {
        let d = 2.0;
        let r = 1.0;
        let n = 2.0_f64;
        let scale = 1.0_f64;
        let xy = array![[0.0, 0.0], [d, 0.0]];
        let lig = array![[0.0], [1.0]];
        let result = calculate_weighted_ligands(&xy, &lig, r, scale);
        let w1 = (-(d * d) / (2.0 * r * r)).exp();
        let expected = scale * w1 / n;
        assert_abs_diff_eq!(result[[0, 0]], expected, epsilon = 1e-10);
    }
    #[test]
    fn scale_factor_scales_output_linearly() {
        let xy = array![[0.0, 0.0], [1.0, 0.0], [2.0, 1.0]];
        let lig = array![[1.0], [2.0], [4.0]];
        let r = 1.7_f64;
        let base = calculate_weighted_ligands(&xy, &lig, r, 1.0);
        let doubled = calculate_weighted_ligands(&xy, &lig, r, 2.0);
        let pi = calculate_weighted_ligands(&xy, &lig, r, std::f64::consts::PI);
        for i in 0..3 {
            assert_abs_diff_eq!(doubled[[i, 0]], 2.0 * base[[i, 0]], epsilon = 1e-9);
            assert_abs_diff_eq!(
                pi[[i, 0]],
                std::f64::consts::PI * base[[i, 0]],
                epsilon = 1e-9
            );
        }
    }
    #[test]
    fn large_radius_uniform_weights() {
        let xy = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let lig = array![[3.0], [6.0], [9.0]];
        let result = calculate_weighted_ligands(&xy, &lig, 1e6, 1.0);
        // With an effectively infinite radius every weight is ~1, so each cell sees the mean.
        let mean = 6.0;
        for i in 0..3 {
            assert_abs_diff_eq!(result[[i, 0]], mean, epsilon = 0.01);
        }
    }
    #[test]
    fn small_radius_self_dominant() {
        let xy = array![[0.0, 0.0], [100.0, 0.0]];
        let lig = array![[5.0], [10.0]];
        let result = calculate_weighted_ligands(&xy, &lig, 0.001, 1.0);
        assert_abs_diff_eq!(result[[0, 0]], 2.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result[[1, 0]], 5.0, epsilon = 1e-3);
    }
    #[test]
    fn multiple_ligands() {
        let xy = array![[0.0, 0.0], [1.0, 0.0]];
        let lig = array![[1.0, 2.0], [3.0, 4.0]];
        let result = calculate_weighted_ligands(&xy, &lig, 1.0, 1.0);
        assert_eq!(result.ncols(), 2);
        assert_eq!(result.nrows(), 2);
    }
    #[test]
    fn result_shape_matches_input() {
        let xy = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];
        let lig = Array2::from_shape_fn((4, 3), |(i, j)| (i + j) as f64);
        let result = calculate_weighted_ligands(&xy, &lig, 1.0, 1.0);
        assert_eq!(result.shape(), &[4, 3]);
    }
    #[test]
    fn nonnegative_output_for_nonneg_input() {
        let xy = array![[0.0, 0.0], [1.0, 1.0], [2.0, 0.5]];
        let lig = array![[1.0], [2.0], [3.0]];
        let result = calculate_weighted_ligands(&xy, &lig, 1.0, 1.0);
        for &v in result.iter() {
            assert!(
                v >= 0.0,
                "Output should be non-negative for non-negative input"
            );
        }
    }
    #[test]
    fn zero_ligand_values_give_zero() {
        let xy = array![[0.0, 0.0], [1.0, 0.0]];
        let lig = array![[0.0], [0.0]];
        let result = calculate_weighted_ligands(&xy, &lig, 1.0, 1.0);
        assert_abs_diff_eq!(result[[0, 0]], 0.0, epsilon = 1e-15);
        assert_abs_diff_eq!(result[[1, 0]], 0.0, epsilon = 1e-15);
    }
    #[test]
    fn gaussian_formula_verification() {
        let d = 1.5_f64;
        let r = 2.0_f64;
        let xy = array![[0.0, 0.0], [d, 0.0]];
        let lig = array![[0.0], [1.0]];
        let result = calculate_weighted_ligands(&xy, &lig, r, 1.0);
        let w1 = (-d * d / (2.0 * r * r)).exp();
        assert_abs_diff_eq!(result[[0, 0]], w1 / 2.0, epsilon = 1e-12);
    }
    #[test]
    fn grid_of_cells() {
        let xy = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let lig = array![[1.0], [1.0], [1.0], [1.0]];
        let result = calculate_weighted_ligands(&xy, &lig, 1.0, 1.0);
        let r0 = result[[0, 0]];
        let r3 = result[[3, 0]];
        assert_abs_diff_eq!(r0, r3, epsilon = 1e-10);
        let r1 = result[[1, 0]];
        let r2 = result[[2, 0]];
        assert_abs_diff_eq!(r1, r2, epsilon = 1e-10);
    }
    #[test]
    fn cutoff_excludes_far_neighbors() {
        let xy = array![[0.0, 0.0], [3.0, 0.0]];
        let lig = array![[0.0], [1.0]];
        // Cell 1 is 3 units away; a cutoff of 1 removes its contribution to cell 0.
        let with_cut = calculate_weighted_ligands_with_cutoff(&xy, &lig, 10.0, 1.0, Some(1.0));
        assert_abs_diff_eq!(with_cut[[0, 0]], 0.0, epsilon = 1e-15);
        // A non-finite cutoff is ignored and matches the uncut computation.
        let nan_cut =
            calculate_weighted_ligands_with_cutoff(&xy, &lig, 10.0, 1.0, Some(f64::NAN));
        let plain = calculate_weighted_ligands(&xy, &lig, 10.0, 1.0);
        assert_abs_diff_eq!(nan_cut[[0, 0]], plain[[0, 0]], epsilon = 1e-15);
    }
    #[test]
    fn empty_input_gives_empty_output() {
        let xy = Array2::<f64>::zeros((0, 2));
        let lig = Array2::<f64>::zeros((0, 3));
        assert_eq!(
            calculate_weighted_ligands(&xy, &lig, 1.0, 1.0).shape(),
            &[0, 3]
        );
        assert_eq!(
            calculate_weighted_ligands_grid(&xy, &lig, 1.0, 1.0, 0.5).shape(),
            &[0, 3]
        );
    }
    // Deterministic pseudo-random cell layout with two non-negative ligand columns.
    fn make_random_cells(n: usize, spread: f64) -> (Array2<f64>, Array2<f64>) {
        let mut xy = Array2::zeros((n, 2));
        let mut lig = Array2::zeros((n, 2));
        for i in 0..n {
            let t = i as f64 / n as f64;
            xy[[i, 0]] = (t * 7.3 + 0.5).sin() * spread;
            xy[[i, 1]] = (t * 11.1 + 1.3).cos() * spread;
            lig[[i, 0]] = ((t * 3.7).sin() + 1.0).max(0.0);
            lig[[i, 1]] = ((t * 5.1).cos() + 1.0).max(0.0);
        }
        (xy, lig)
    }
    #[test]
    fn grid_approx_matches_exact_uniform_field() {
        let xy = Array2::from_shape_fn((100, 2), |(i, d)| {
            if d == 0 {
                (i % 10) as f64 * 10.0
            } else {
                (i / 10) as f64 * 10.0
            }
        });
        let lig = Array2::ones((100, 1));
        let r = 50.0;
        let exact = calculate_weighted_ligands(&xy, &lig, r, 1.0);
        let grid = calculate_weighted_ligands_grid(&xy, &lig, r, 1.0, 0.3);
        let max_err = exact
            .iter()
            .zip(grid.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f64, f64::max);
        assert!(
            max_err < 0.05,
            "uniform field max error {:.4e} too large",
            max_err
        );
    }
    #[test]
    fn grid_approx_accuracy_vs_exact() {
        let (xy, lig) = make_random_cells(200, 500.0);
        let r = 100.0;
        let exact = calculate_weighted_ligands(&xy, &lig, r, 1.0);
        let grid = calculate_weighted_ligands_grid(&xy, &lig, r, 1.0, 0.5);
        let diffs: Vec<f64> = exact
            .iter()
            .zip(grid.iter())
            .map(|(a, b)| (a - b).abs())
            .collect();
        let max_err = diffs.iter().cloned().fold(0.0f64, f64::max);
        let mean_err = diffs.iter().sum::<f64>() / diffs.len() as f64;
        assert!(max_err < 0.5, "max error {:.4e}", max_err);
        assert!(mean_err < 0.01, "mean error {:.4e}", mean_err);
    }
    #[test]
    fn grid_approx_tighter_factor_less_error() {
        let (xy, lig) = make_random_cells(500, 200.0);
        let r = 150.0;
        let exact = calculate_weighted_ligands(&xy, &lig, r, 1.0);
        let coarse = calculate_weighted_ligands_grid(&xy, &lig, r, 1.0, 0.8);
        let fine = calculate_weighted_ligands_grid(&xy, &lig, r, 1.0, 0.3);
        let err_coarse = exact
            .iter()
            .zip(coarse.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f64, f64::max);
        let err_fine = exact
            .iter()
            .zip(fine.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f64, f64::max);
        assert!(
            err_coarse > 1e-12,
            "coarse grid should differ from exact, got {:.4e}",
            err_coarse
        );
        assert!(
            err_fine <= err_coarse,
            "finer grid should be more accurate: fine={:.4e} > coarse={:.4e}",
            err_fine,
            err_coarse
        );
    }
    #[test]
    fn grid_approx_preserves_shape_and_sign() {
        let (xy, lig) = make_random_cells(50, 200.0);
        let result = calculate_weighted_ligands_grid(&xy, &lig, 50.0, 1.0, 0.5);
        assert_eq!(result.shape(), &[50, 2]);
        for &v in result.iter() {
            assert!(v >= 0.0, "grid approx produced negative value: {}", v);
        }
    }
    #[test]
    fn grid_approx_scale_factor_scales_linearly() {
        let (xy, lig) = make_random_cells(80, 300.0);
        let r = 80.0;
        let r1 = calculate_weighted_ligands_grid(&xy, &lig, r, 1.0, 0.5);
        let r3 = calculate_weighted_ligands_grid(&xy, &lig, r, 3.0, 0.5);
        for (a, b) in r1.iter().zip(r3.iter()) {
            assert_abs_diff_eq!(*b, 3.0 * *a, epsilon = 1e-8);
        }
    }
    #[test]
    fn grid_approx_symmetry() {
        let xy = Array2::from_shape_fn((64, 2), |(i, d)| {
            if d == 0 {
                (i % 8) as f64 * 20.0
            } else {
                (i / 8) as f64 * 20.0
            }
        });
        let lig = Array2::ones((64, 1));
        let result = calculate_weighted_ligands_grid(&xy, &lig, 40.0, 1.0, 0.3);
        assert_abs_diff_eq!(result[[0, 0]], result[[7, 0]], epsilon = 1e-6);
        assert_abs_diff_eq!(result[[0, 0]], result[[56, 0]], epsilon = 1e-6);
        assert_abs_diff_eq!(result[[0, 0]], result[[63, 0]], epsilon = 1e-6);
    }
    #[test]
    fn grid_falls_back_to_exact_when_few_cells() {
        let xy = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let lig = array![[1.0], [2.0], [3.0]];
        let exact = calculate_weighted_ligands(&xy, &lig, 0.5, 1.0);
        let grid = calculate_weighted_ligands_grid(&xy, &lig, 0.5, 1.0, 0.3);
        for (a, b) in exact.iter().zip(grid.iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-14);
        }
    }
    #[test]
    fn grid_approx_multiple_ligands() {
        let (xy, lig) = make_random_cells(150, 400.0);
        let r = 80.0;
        let exact = calculate_weighted_ligands(&xy, &lig, r, 1.0);
        let grid = calculate_weighted_ligands_grid(&xy, &lig, r, 1.0, 0.3);
        for col in 0..2 {
            let max_err = (0..150)
                .map(|i| (exact[[i, col]] - grid[[i, col]]).abs())
                .fold(0.0f64, f64::max);
            assert!(
                max_err < 0.3,
                "ligand column {} max error {:.4e} too large",
                col,
                max_err
            );
        }
    }
}