use chematic_core::{AtomIdx, Molecule};
use crate::coords::{Coords3D, Point3};
use crate::dg::generate_coords;
/// Bondi van der Waals radius (Å) for the given atomic number.
///
/// Only H, C, N, O, F, P, S, Cl, Br and I are tabulated; every other element
/// falls back to the carbon radius of 1.70.
fn bondi_vdw_radius(atomic_number: u8) -> f64 {
    match atomic_number {
        1 => 1.20,       // H
        7 => 1.55,       // N
        8 => 1.52,       // O
        9 => 1.47,       // F
        15 | 16 => 1.80, // P, S
        17 => 1.75,      // Cl
        35 => 1.85,      // Br
        53 => 1.98,      // I
        // Carbon (6), and the fallback for anything not listed above.
        _ => 1.70,
    }
}
/// Total SASA of `mol` at `coords` using the module defaults.
///
/// The defaults are a probe radius of 1.4 (same units as the Bondi radii) and
/// 100 sample points per atom.
pub fn sasa(mol: &Molecule, coords: &Coords3D) -> f64 {
    const PROBE_RADIUS: f64 = 1.4;
    const SPHERE_POINTS: usize = 100;
    shrake_rupley_sasa(mol, coords, PROBE_RADIUS, SPHERE_POINTS)
}
/// Per-atom SASA of `mol` at `coords` using the module defaults.
///
/// The defaults are a probe radius of 1.4 and 100 sample points per atom.
/// Entry `i` of the result belongs to atom `AtomIdx(i)`.
pub fn sasa_per_atom_default(mol: &Molecule, coords: &Coords3D) -> Vec<f64> {
    const PROBE_RADIUS: f64 = 1.4;
    const SPHERE_POINTS: usize = 100;
    sasa_per_atom(mol, coords, PROBE_RADIUS, SPHERE_POINTS)
}
/// Embeds `mol` with distance geometry (`generate_coords`) and returns its
/// total SASA under the defaults used by `sasa`.
///
/// This implementation always returns `Ok`.
pub fn sasa_from_dg(mol: &Molecule) -> Result<f64, String> {
    Ok(sasa(mol, &generate_coords(mol)))
}
/// Embeds `mol` with distance geometry (`generate_coords`) and returns the
/// per-atom SASA under the defaults used by `sasa_per_atom_default`.
///
/// This implementation always returns `Ok`.
pub fn sasa_per_atom_from_dg(mol: &Molecule) -> Result<Vec<f64>, String> {
    Ok(sasa_per_atom_default(mol, &generate_coords(mol)))
}
pub fn shrake_rupley_sasa(
mol: &Molecule,
coords: &Coords3D,
probe_radius: f64,
sphere_points: usize,
) -> f64 {
if mol.atom_count() == 0 {
return 0.0;
}
let mut total_sasa = 0.0;
for i in 0..mol.atom_count() {
let idx = AtomIdx(i as u32);
let atom_i = mol.atom(idx);
let vdw_i = bondi_vdw_radius(atom_i.element.atomic_number());
let radius_i = vdw_i + probe_radius;
let pos_i = coords.get(idx);
let exposed_count = count_exposed_points(
mol,
coords,
pos_i,
idx,
probe_radius,
sphere_points,
);
let sphere_area = 4.0 * std::f64::consts::PI * radius_i * radius_i;
let atom_sasa = (exposed_count as f64 / sphere_points as f64) * sphere_area;
total_sasa += atom_sasa;
}
total_sasa
}
/// Per-atom solvent-accessible surface area by the Shrake–Rupley method.
///
/// Entry `i` is the SASA of atom `AtomIdx(i)`: the fraction of its
/// `sphere_points` samples that are not buried, times the area of its sphere
/// of radius `vdW + probe_radius` (Bondi radii from `bondi_vdw_radius`).
///
/// When `sphere_points` is `0` every entry is `0.0`. This avoids a 0/0
/// exposed fraction that would otherwise put NaN/inf into totals and
/// descriptors.
///
/// NOTE(review): `coords` is assumed to hold a position for every atom index.
pub fn sasa_per_atom(
    mol: &Molecule,
    coords: &Coords3D,
    probe_radius: f64,
    sphere_points: usize,
) -> Vec<f64> {
    let atom_count = mol.atom_count();
    if sphere_points == 0 {
        return vec![0.0; atom_count];
    }
    (0..atom_count)
        .map(|i| {
            let idx = AtomIdx(i as u32);
            let radius = bondi_vdw_radius(mol.atom(idx).element.atomic_number()) + probe_radius;
            let exposed = count_exposed_points(
                mol,
                coords,
                coords.get(idx),
                idx,
                probe_radius,
                sphere_points,
            );
            let sphere_area = 4.0 * std::f64::consts::PI * radius * radius;
            (exposed as f64 / sphere_points as f64) * sphere_area
        })
        .collect()
}
/// Counts the sample points on atom `atom_idx`'s expanded sphere that are not
/// buried inside any other atom's expanded sphere.
///
/// Every sphere has radius `bondi_vdw_radius(Z) + probe_radius`. `atom_pos` is
/// the centre of the sampled atom; callers pass `coords.get(atom_idx)`. A
/// point is buried when its distance to a neighbour's centre is strictly less
/// than that neighbour's radius minus 1e-6. Points lying exactly on a boundary
/// therefore count as exposed.
fn count_exposed_points(
    mol: &Molecule,
    coords: &Coords3D,
    atom_pos: Point3,
    atom_idx: AtomIdx,
    probe_radius: f64,
    sphere_points: usize,
) -> usize {
    fn distance(a: Point3, b: Point3) -> f64 {
        let dx = a.x - b.x;
        let dy = a.y - b.y;
        let dz = a.z - b.z;
        (dx * dx + dy * dy + dz * dz).sqrt()
    }

    let atom_radius =
        bondi_vdw_radius(mol.atom(atom_idx).element.atomic_number()) + probe_radius;

    // Only atoms whose expanded spheres overlap this one can bury one of its
    // surface points. If |p - c_j| < r_j for a point with |p - c_i| = r_i, the
    // triangle inequality gives |c_i - c_j| < r_i + r_j. Dropping the other
    // atoms up front leaves the count unchanged but avoids testing every
    // sample point against the whole molecule.
    let neighbors: Vec<(Point3, f64)> = (0..mol.atom_count())
        .filter(|&j| j as u32 != atom_idx.0)
        .map(|j| {
            let nbr_idx = AtomIdx(j as u32);
            let nbr_radius =
                bondi_vdw_radius(mol.atom(nbr_idx).element.atomic_number()) + probe_radius;
            (coords.get(nbr_idx), nbr_radius)
        })
        .filter(|&(nbr_pos, nbr_radius)| distance(atom_pos, nbr_pos) < atom_radius + nbr_radius)
        .collect();

    generate_sphere_points(atom_pos, atom_radius, sphere_points)
        .into_iter()
        .filter(|&point| {
            // Keep the exact `<` test: a NaN distance never buries a point.
            !neighbors
                .iter()
                .any(|&(nbr_pos, nbr_radius)| distance(point, nbr_pos) < nbr_radius - 1e-6)
        })
        .count()
}
/// Returns `num_points` points spread over the sphere of `radius` around
/// `center`, laid out on a golden-angle (Fibonacci) spiral.
///
/// Points run from the +y pole (`i = 0`) to the −y pole
/// (`i = num_points - 1`). `num_points == 0` yields an empty vector.
/// `num_points == 1` yields only the +y pole, which is also the spiral's first
/// point. Every returned point therefore lies on the sphere surface.
fn generate_sphere_points(center: Point3, radius: f64, num_points: usize) -> Vec<Point3> {
    if num_points == 0 {
        return Vec::new();
    }
    if num_points == 1 {
        // The spiral below divides by `num_points - 1`. A single sample must
        // still be on the surface, not at the centre (which sits inside every
        // overlapping neighbour sphere).
        return vec![Point3::new(center.x, center.y + radius, center.z)];
    }
    let golden_angle = std::f64::consts::PI * (3.0 - 5_f64.sqrt());
    (0..num_points)
        .map(|i| {
            let y = 1.0 - (i as f64) / (num_points as f64 - 1.0) * 2.0;
            let ring_radius = (1.0 - y * y).sqrt();
            let theta = golden_angle * (i as f64);
            let x = theta.cos() * ring_radius;
            let z = theta.sin() * ring_radius;
            Point3::new(
                center.x + x * radius,
                center.y + y * radius,
                center.z + z * radius,
            )
        })
        .collect()
}
/// Summary statistics over a per-atom SASA vector.
#[derive(Clone, Debug)]
pub struct SasaDescriptor {
    /// Sum of all per-atom values.
    pub total: f64,
    /// Arithmetic mean of the per-atom values (`0.0` when there are none).
    pub mean: f64,
    /// Population standard deviation, dividing by `n` (`0.0` when there are none).
    pub std_dev: f64,
    /// The per-atom values the statistics were computed from.
    pub per_atom: Vec<f64>,
}
impl SasaDescriptor {
    /// Builds the descriptor from per-atom SASA values, taking ownership of them.
    pub fn from_per_atom(per_atom: Vec<f64>) -> Self {
        let total: f64 = per_atom.iter().sum();
        let (mean, std_dev) = if per_atom.is_empty() {
            (0.0, 0.0)
        } else {
            let count = per_atom.len() as f64;
            let avg = total / count;
            let sum_sq: f64 = per_atom.iter().map(|&v| (v - avg).powi(2)).sum();
            (avg, (sum_sq / count).sqrt())
        };
        Self {
            total,
            mean,
            std_dev,
            per_atom,
        }
    }
}
/// Computes per-atom SASA with the module defaults and summarises it.
///
/// The per-atom values come from `sasa_per_atom_default`; the summary is a
/// `SasaDescriptor`.
pub fn sasa_descriptor(mol: &Molecule, coords: &Coords3D) -> SasaDescriptor {
    SasaDescriptor::from_per_atom(sasa_per_atom_default(mol, coords))
}
/// Embeds `mol` with distance geometry (`generate_coords`) and returns its
/// `SasaDescriptor`.
///
/// This implementation always returns `Ok`.
pub fn sasa_descriptor_from_dg(mol: &Molecule) -> Result<SasaDescriptor, String> {
    Ok(sasa_descriptor(mol, &generate_coords(mol)))
}
/// SASA totals bucketed by element.
///
/// `by_element[z]` is the summed SASA of all atoms with atomic number `z`.
#[derive(Clone, Debug)]
pub struct PerElementSasa {
    pub by_element: Vec<f64>,
}
impl PerElementSasa {
    /// Summed SASA for atomic number `atomic_number`.
    ///
    /// Returns `0.0` when that atomic number is outside the table.
    pub fn get(&self, atomic_number: u8) -> f64 {
        self.by_element
            .get(usize::from(atomic_number))
            .copied()
            .unwrap_or(0.0)
    }
}
/// Sums per-atom SASA (module defaults) into buckets indexed by atomic number.
///
/// The table has 119 slots (atomic numbers 0..=118). Atoms whose atomic
/// number falls outside it are skipped.
pub fn sasa_per_element(mol: &Molecule, coords: &Coords3D) -> PerElementSasa {
    const TABLE_LEN: usize = 119;
    let mut by_element = vec![0.0; TABLE_LEN];
    for (i, area) in sasa_per_atom_default(mol, coords).into_iter().enumerate() {
        let z = usize::from(mol.atom(AtomIdx(i as u32)).element.atomic_number());
        if let Some(bucket) = by_element.get_mut(z) {
            *bucket += area;
        }
    }
    PerElementSasa { by_element }
}
/// Embeds `mol` with distance geometry (`generate_coords`) and returns its
/// per-element SASA breakdown.
///
/// This implementation always returns `Ok`.
pub fn sasa_per_element_from_dg(mol: &Molecule) -> Result<PerElementSasa, String> {
    Ok(sasa_per_element(mol, &generate_coords(mol)))
}
/// Alias of `sasa`: total SASA with the module defaults.
///
/// The defaults are a probe radius of 1.4 and 100 sample points.
pub fn calc_mol_sasa(mol: &Molecule, coords: &Coords3D) -> f64 {
    let total = sasa(mol, coords);
    total
}
/// Total SASA with a caller-chosen `probe_radius` and the default 100 sample
/// points per atom.
pub fn calc_mol_sasa_with_probe(mol: &Molecule, coords: &Coords3D, probe_radius: f64) -> f64 {
    const SPHERE_POINTS: usize = 100;
    shrake_rupley_sasa(mol, coords, probe_radius, SPHERE_POINTS)
}
#[cfg(test)]
mod tests {
    use super::*;
    use chematic_smiles::parse;

    #[test]
    fn test_sasa_single_atom() {
        let mol = parse("C").unwrap();
        let coords = Coords3D::new_zeroed(1);
        let sasa = shrake_rupley_sasa(&mol, &coords, 1.4, 100);
        assert!(sasa > 0.0, "single atom should have positive SASA");
        // An isolated atom has every sample point exposed, so its SASA is the
        // full area of its expanded sphere (carbon vdW 1.70 + probe 1.4).
        let radius = 1.70 + 1.4;
        let expected = 4.0 * std::f64::consts::PI * radius * radius;
        assert!(
            (sasa - expected).abs() < 1e-9,
            "isolated carbon SASA {} should equal sphere area {}",
            sasa,
            expected
        );
    }

    #[test]
    fn test_sasa_multiple_atoms() {
        let mol = parse("CC").unwrap();
        let mut coords = Coords3D::new_zeroed(2);
        coords.set(AtomIdx(0), Point3::new(0.0, 0.0, 0.0));
        coords.set(AtomIdx(1), Point3::new(10.0, 0.0, 0.0));
        let sasa = shrake_rupley_sasa(&mol, &coords, 1.4, 100);
        assert!(sasa > 0.0, "multi-atom SASA should be positive");
    }

    #[test]
    fn test_sasa_per_atom_sum() {
        let mol = parse("CC").unwrap();
        let mut coords = Coords3D::new_zeroed(2);
        coords.set(AtomIdx(0), Point3::new(0.0, 0.0, 0.0));
        coords.set(AtomIdx(1), Point3::new(10.0, 0.0, 0.0));
        let per_atom = sasa_per_atom(&mol, &coords, 1.4, 100);
        let sum: f64 = per_atom.iter().sum();
        let total = shrake_rupley_sasa(&mol, &coords, 1.4, 100);
        assert!((sum - total).abs() < 1e-6, "per-atom sum should match total");
    }

    #[test]
    fn test_sasa_empty_molecule() {
        let mol = parse("").unwrap_or_else(|_| chematic_core::MoleculeBuilder::new().build());
        let coords = Coords3D::new_zeroed(0);
        let sasa = shrake_rupley_sasa(&mol, &coords, 1.4, 100);
        assert_eq!(sasa, 0.0, "empty molecule should have zero SASA");
    }

    #[test]
    fn test_sasa_with_default_params() {
        let mol = parse("C").unwrap();
        let coords = Coords3D::new_zeroed(1);
        let sasa_default = sasa(&mol, &coords);
        let sasa_explicit = shrake_rupley_sasa(&mol, &coords, 1.4, 100);
        assert!((sasa_default - sasa_explicit).abs() < 1e-6);
    }

    #[test]
    fn test_sasa_per_atom_default() {
        let mol = parse("CC").unwrap();
        let mut coords = Coords3D::new_zeroed(2);
        coords.set(AtomIdx(0), Point3::new(0.0, 0.0, 0.0));
        coords.set(AtomIdx(1), Point3::new(10.0, 0.0, 0.0));
        let per_atom = sasa_per_atom_default(&mol, &coords);
        assert_eq!(per_atom.len(), 2);
        assert!(per_atom[0] > 0.0 && per_atom[1] > 0.0);
    }

    #[test]
    fn test_sasa_from_distance_geometry_methane() {
        let mol = parse("C").unwrap();
        let result = sasa_from_dg(&mol);
        assert!(result.is_ok());
        let sasa = result.unwrap();
        assert!(sasa > 80.0 && sasa < 160.0, "methane SASA out of expected range: {}", sasa);
    }

    #[test]
    fn test_sasa_from_dg_ethane() {
        let mol = parse("CC").unwrap();
        let result = sasa_from_dg(&mol);
        assert!(result.is_ok());
        let sasa = result.unwrap();
        assert!(sasa > 100.0, "ethane SASA should be substantial: {}", sasa);
    }

    #[test]
    fn test_sasa_per_atom_from_dg() {
        let mol = parse("CCO").unwrap();
        let result = sasa_per_atom_from_dg(&mol);
        assert!(result.is_ok());
        let per_atom = result.unwrap();
        assert_eq!(per_atom.len(), 3);
        for (i, &sasa) in per_atom.iter().enumerate() {
            assert!(sasa > 0.0, "atom {} SASA should be positive", i);
        }
    }

    #[test]
    fn test_sasa_is_additive_separated_atoms() {
        let mol = parse("CC").unwrap();
        let mut coords = Coords3D::new_zeroed(2);
        coords.set(AtomIdx(0), Point3::new(0.0, 0.0, 0.0));
        coords.set(AtomIdx(1), Point3::new(100.0, 0.0, 0.0));
        let total = sasa(&mol, &coords);
        let per_atom = sasa_per_atom_default(&mol, &coords);
        let sum: f64 = per_atom.iter().sum();
        assert!((total - sum).abs() < 1.0, "far atoms should have additive SASA");
        // Atoms far outside each other's expanded spheres cannot occlude one
        // another, so the pair's SASA is exactly twice one isolated atom's.
        let single_mol = parse("C").unwrap();
        let single = sasa(&single_mol, &Coords3D::new_zeroed(1));
        assert!(
            (total - 2.0 * single).abs() < 1e-6,
            "separated pair SASA {} should be twice the isolated atom SASA {}",
            total,
            single
        );
    }

    #[test]
    fn test_sasa_occlusion_effect() {
        let mol = parse("CC").unwrap();
        let mut coords_far = Coords3D::new_zeroed(2);
        coords_far.set(AtomIdx(0), Point3::new(0.0, 0.0, 0.0));
        coords_far.set(AtomIdx(1), Point3::new(50.0, 0.0, 0.0));
        let sasa_far = sasa(&mol, &coords_far);
        let mut coords_close = Coords3D::new_zeroed(2);
        coords_close.set(AtomIdx(0), Point3::new(0.0, 0.0, 0.0));
        coords_close.set(AtomIdx(1), Point3::new(2.0, 0.0, 0.0));
        let sasa_close = sasa(&mol, &coords_close);
        assert!(sasa_close < sasa_far, "close atoms should have lower SASA due to occlusion");
    }

    #[test]
    fn test_sasa_probe_radius_effect() {
        let mol = parse("C").unwrap();
        let coords = Coords3D::new_zeroed(1);
        let sasa_small = shrake_rupley_sasa(&mol, &coords, 1.0, 100);
        let sasa_large = shrake_rupley_sasa(&mol, &coords, 2.0, 100);
        assert!(sasa_large > sasa_small, "larger probe radius should increase SASA");
    }

    #[test]
    fn test_sasa_sphere_points_convergence() {
        let mol = parse("CC").unwrap();
        let mut coords = Coords3D::new_zeroed(2);
        coords.set(AtomIdx(0), Point3::new(0.0, 0.0, 0.0));
        coords.set(AtomIdx(1), Point3::new(1.5, 0.0, 0.0));
        let sasa_100 = shrake_rupley_sasa(&mol, &coords, 1.4, 100);
        let sasa_500 = shrake_rupley_sasa(&mol, &coords, 1.4, 500);
        assert!(
            (sasa_100 - sasa_500).abs() < sasa_100 * 0.2,
            "different sphere points should give similar results"
        );
    }

    #[test]
    fn test_sphere_points_lie_on_sphere() {
        let center = Point3::new(1.0, -2.0, 3.0);
        let radius = 3.1;
        let points = generate_sphere_points(center, radius, 50);
        assert_eq!(points.len(), 50);
        for p in &points {
            let dx = p.x - center.x;
            let dy = p.y - center.y;
            let dz = p.z - center.z;
            let d = (dx * dx + dy * dy + dz * dz).sqrt();
            assert!((d - radius).abs() < 1e-9, "sample point off the sphere: distance {}", d);
        }
    }

    #[test]
    fn test_sasa_benzene() {
        let mol = parse("c1ccccc1").unwrap();
        let result = sasa_from_dg(&mol);
        assert!(result.is_ok());
        let sasa = result.unwrap();
        assert!(sasa > 150.0, "benzene SASA should be substantial: {}", sasa);
    }

    #[test]
    fn test_sasa_finite_nonzero() {
        let mol = parse("C1CCCCC1").unwrap();
        let result = sasa_from_dg(&mol);
        assert!(result.is_ok());
        let sasa = result.unwrap();
        assert!(sasa.is_finite(), "SASA should be finite");
        assert!(sasa > 0.0, "SASA should be positive");
        assert!(sasa < 1e6, "SASA should not be unreasonably large");
    }

    #[test]
    fn test_sasa_per_atom_polar_molecule() {
        let mol = parse("CCO").unwrap();
        let coords = generate_coords(&mol);
        let per_atom = sasa_per_atom_default(&mol, &coords);
        assert_eq!(per_atom.len(), 3);
        for &sasa in &per_atom {
            assert!(sasa > 0.0 && sasa.is_finite());
        }
    }

    #[test]
    fn test_sasa_descriptor_basic() {
        let mol = parse("C").unwrap();
        let coords = Coords3D::new_zeroed(1);
        let desc = sasa_descriptor(&mol, &coords);
        assert!(desc.total > 0.0);
        assert_eq!(desc.per_atom.len(), 1);
        assert!((desc.mean - desc.total).abs() < 1e-6);
    }

    #[test]
    fn test_sasa_descriptor_multiple_atoms() {
        let mol = parse("CC").unwrap();
        let mut coords = Coords3D::new_zeroed(2);
        coords.set(AtomIdx(0), Point3::new(0.0, 0.0, 0.0));
        coords.set(AtomIdx(1), Point3::new(10.0, 0.0, 0.0));
        let desc = sasa_descriptor(&mol, &coords);
        assert_eq!(desc.per_atom.len(), 2);
        let sum: f64 = desc.per_atom.iter().sum();
        assert!((desc.total - sum).abs() < 1e-6);
        let expected_mean = desc.total / 2.0;
        assert!((desc.mean - expected_mean).abs() < 1e-6);
    }

    #[test]
    fn test_sasa_descriptor_std_dev() {
        let mol = parse("CC").unwrap();
        let mut coords = Coords3D::new_zeroed(2);
        coords.set(AtomIdx(0), Point3::new(0.0, 0.0, 0.0));
        coords.set(AtomIdx(1), Point3::new(2.0, 0.0, 0.0));
        let desc = sasa_descriptor(&mol, &coords);
        assert!(desc.std_dev >= 0.0);
    }

    #[test]
    fn test_sasa_descriptor_from_dg() {
        let mol = parse("CCO").unwrap();
        let result = sasa_descriptor_from_dg(&mol);
        assert!(result.is_ok());
        let desc = result.unwrap();
        assert_eq!(desc.per_atom.len(), 3);
        assert!(desc.total > 0.0);
        assert!(desc.mean > 0.0);
        assert!(desc.std_dev >= 0.0);
    }

    #[test]
    fn test_sasa_per_element_single_element() {
        let mol = parse("C").unwrap();
        let coords = Coords3D::new_zeroed(1);
        let per_elem = sasa_per_element(&mol, &coords);
        let carbon_sasa = per_elem.get(6);
        assert!(carbon_sasa > 0.0);
        assert_eq!(per_elem.get(1), 0.0);
        assert_eq!(per_elem.get(8), 0.0);
    }

    #[test]
    fn test_sasa_per_element_mixed() {
        let mol = parse("CCO").unwrap();
        let coords = generate_coords(&mol);
        let per_elem = sasa_per_element(&mol, &coords);
        let carbon_sasa = per_elem.get(6);
        let oxygen_sasa = per_elem.get(8);
        assert!(carbon_sasa > 0.0);
        assert!(oxygen_sasa > 0.0);
        let total: f64 = per_elem.by_element.iter().sum();
        let expected = sasa(&mol, &coords);
        assert!((total - expected).abs() < 1e-6);
    }

    #[test]
    fn test_sasa_per_element_from_dg() {
        let mol = parse("C1CCCCC1").unwrap();
        let result = sasa_per_element_from_dg(&mol);
        assert!(result.is_ok());
        let per_elem = result.unwrap();
        let carbon_sasa = per_elem.get(6);
        assert!(carbon_sasa > 0.0);
    }

    #[test]
    fn test_sasa_rdkit_compatible() {
        let mol = parse("C").unwrap();
        let coords = Coords3D::new_zeroed(1);
        let sasa_std = sasa(&mol, &coords);
        let sasa_rdkit = calc_mol_sasa(&mol, &coords);
        assert!((sasa_std - sasa_rdkit).abs() < 1e-10);
    }

    #[test]
    fn test_sasa_rdkit_probe_radius() {
        let mol = parse("CC").unwrap();
        let coords = generate_coords(&mol);
        let sasa_default = calc_mol_sasa(&mol, &coords);
        let sasa_custom = calc_mol_sasa_with_probe(&mol, &coords, 1.4);
        assert!((sasa_default - sasa_custom).abs() < 1e-6);
        let sasa_large_probe = calc_mol_sasa_with_probe(&mol, &coords, 2.0);
        assert!((sasa_large_probe - sasa_default).abs() > 0.1);
    }

    #[test]
    fn test_sasa_descriptor_consistency() {
        let mol = parse("CCO").unwrap();
        let coords = generate_coords(&mol);
        let desc = sasa_descriptor(&mol, &coords);
        let total_sasa = sasa(&mol, &coords);
        assert!((desc.total - total_sasa).abs() < 1e-6);
    }

    #[test]
    fn test_sasa_descriptor_nonzero_atoms() {
        let mol = parse("c1ccccc1").unwrap();
        let coords = generate_coords(&mol);
        let desc = sasa_descriptor(&mol, &coords);
        for (i, &sasa_val) in desc.per_atom.iter().enumerate() {
            assert!(sasa_val > 0.0, "atom {} should have positive SASA", i);
        }
    }
}