pub mod options;
pub use options::{AtomLevel, ChainLevel, ProteinLevel, ResidueLevel, SASAProcessor};
use utils::consts::ANGLE_INCREMENT;
pub mod structures;
pub mod utils;
pub use crate::options::*;
pub use crate::structures::atomic::*;
use crate::utils::ARCH;
use structures::spatial_grid::SpatialGrid;
use rayon::prelude::*;
#[cfg(feature = "serde_json")]
pub use utils::io::sasa_result_to_json;
pub use utils::io::sasa_result_to_protein_object;
#[cfg(feature = "quick-xml")]
pub use utils::io::sasa_result_to_xml;
/// Sample directions in structure-of-arrays layout: `x[i]`, `y[i]`, `z[i]` are the
/// components of point `i`, so each axis can be loaded directly into SIMD registers.
struct SpherePointsSoA {
    x: Vec<f32>,
    y: Vec<f32>,
    z: Vec<f32>,
}

impl SpherePointsSoA {
    /// Number of sample points, taken from the `x` array.
    ///
    /// Assumes the three component arrays are kept the same length (as
    /// `generate_sphere_points` does).
    fn len(&self) -> usize {
        let Self { x, .. } = self;
        x.len()
    }
}
/// Generates `n_points` unit vectors laid out along a spiral over the sphere.
///
/// Point `k` has polar angle `acos(1 - 2k/n)`, so its `z` component steps evenly from
/// `1` towards `-1`. Its azimuth is `k * ANGLE_INCREMENT` (presumably the golden angle;
/// see `utils::consts`). The result is stored per axis so the SASA kernel can read it
/// as SIMD vectors.
fn generate_sphere_points(n_points: usize) -> SpherePointsSoA {
    let step = 1.0 / n_points as f32;
    let mut points = SpherePointsSoA {
        x: Vec::with_capacity(n_points),
        y: Vec::with_capacity(n_points),
        z: Vec::with_capacity(n_points),
    };
    for k in (0..n_points).map(|k| k as f32) {
        let polar = (1.0 - 2.0 * (k * step)).acos();
        let (sin_az, cos_az) = (ANGLE_INCREMENT * k).sin_cos();
        // Radius of the horizontal circle this point lies on.
        let ring_radius = polar.sin();
        points.x.push(ring_radius * cos_az);
        points.y.push(ring_radius * sin_az);
        points.z.push(polar.cos());
    }
    points
}
/// Builds one neighbor list per entry of `active_indices` using a `SpatialGrid`.
///
/// The grid cell edge is `probe_radius + max_radii`, i.e. the largest
/// probe-expanded atomic radius. The search radius handed to the grid is
/// `2 * max_radii + 2 * probe_radius`: two probe-expanded spheres farther apart than
/// that cannot overlap. (How the grid uses these values internally is defined in
/// `structures::spatial_grid`.)
#[inline(never)]
pub fn precompute_neighbors(
    atoms: &[Atom],
    active_indices: &[usize],
    probe_radius: f32,
    max_radii: f32,
) -> Vec<Vec<NeighborData>> {
    let expanded_radius = probe_radius + max_radii;
    let overlap_reach = max_radii + max_radii + 2.0 * probe_radius;
    SpatialGrid::new(atoms, active_indices, expanded_radius, overlap_reach)
        .build_all_neighbor_lists(atoms, active_indices, probe_radius, max_radii)
}
/// Per-atom Shrake–Rupley SASA kernel, run through `pulp` dispatch so the sample loop
/// uses the SIMD instruction set selected at runtime.
struct AtomSasaKernel<'a> {
    /// Index into `atoms` of the atom being evaluated.
    atom_index: usize,
    /// All atoms; `NeighborData::idx` indexes into this slice.
    atoms: &'a [Atom],
    /// Neighbor candidates for this atom. Entries whose atom shares this atom's `id`
    /// are ignored, so the list may contain the atom itself.
    neighbors: &'a [NeighborData],
    /// Sample directions shared by every atom.
    sphere_points: &'a SpherePointsSoA,
    /// Solvent probe radius added to the atom radius.
    probe_radius: f32,
}

impl<'a> pulp::WithSimd for AtomSasaKernel<'a> {
    type Output = f32;

    /// Returns the accessible surface area of `atoms[atom_index]`.
    ///
    /// Each unit direction `s` maps to the probe-sphere point `P = c + r*s`, where
    /// `r = radius + probe_radius`. For a neighbor centred at `n`, let `v = c - n`.
    /// Then `|P - n|^2 = |v|^2 + 2r(s·v) + r^2`, so `P` lies strictly inside the
    /// neighbor's `threshold_squared` sphere exactly when
    /// `s·v < (threshold_squared - |v|^2 - r^2) / (2r)`.
    ///
    /// The result is `4πr² · accessible / total`. With no sample points it is `0.0`
    /// instead of NaN.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
        let n_samples = self.sphere_points.len();
        if n_samples == 0 {
            // 0 accessible out of 0 samples would otherwise evaluate to 0 * inf = NaN.
            return 0.0;
        }

        let atom = &self.atoms[self.atom_index];
        let center_pos = atom.position;
        let r = atom.radius + self.probe_radius;
        let r2 = r * r;

        // Direction-independent geometry of one neighbor: the offset `v = center - neighbor`
        // and the burial limit on `s·v`. `None` marks an entry that is this atom itself.
        let occlusion = |neighbor: &NeighborData| -> Option<(f32, f32, f32, f32)> {
            let other = &self.atoms[neighbor.idx as usize];
            if other.id == atom.id {
                return None;
            }
            let neighbor_pos = other.position;
            let vx = center_pos[0] - neighbor_pos[0];
            let vy = center_pos[1] - neighbor_pos[1];
            let vz = center_pos[2] - neighbor_pos[2];
            let v_mag_sq = vx * vx + vy * vy + vz * vz;
            let limit = (neighbor.threshold_squared - v_mag_sq - r2) / (2.0 * r);
            Some((vx, vy, vz, limit))
        };

        let (sx_chunks, sx_rem) = S::as_simd_f32s(&self.sphere_points.x);
        let (sy_chunks, sy_rem) = S::as_simd_f32s(&self.sphere_points.y);
        let (sz_chunks, sz_rem) = S::as_simd_f32s(&self.sphere_points.z);

        let mut accessible_points = 0.0;
        let zero = simd.splat_f32s(0.0);
        let true_mask = simd.equal_f32s(zero, zero);
        let false_mask = simd.less_than_f32s(simd.splat_f32s(1.0), zero);

        // SIMD body: a lane's bit in `chunk_mask` is set once any neighbor buries that point.
        for ((&sx, &sy), &sz) in sx_chunks.iter().zip(sy_chunks).zip(sz_chunks) {
            let mut chunk_mask = false_mask;
            for neighbor in self.neighbors {
                let (vx, vy, vz, limit) = match occlusion(neighbor) {
                    Some(params) => params,
                    None => continue,
                };
                let dot = simd.mul_add_f32s(
                    sx,
                    simd.splat_f32s(vx),
                    simd.mul_add_f32s(
                        sy,
                        simd.splat_f32s(vy),
                        simd.mul_f32s(sz, simd.splat_f32s(vz)),
                    ),
                );
                let occ = simd.less_than_f32s(dot, simd.splat_f32s(limit));
                chunk_mask = simd.or_m32s(chunk_mask, occ);
                // Once every lane is buried, no later neighbor can change this chunk.
                let not_occ = simd.xor_m32s(chunk_mask, true_mask);
                if simd.first_true_m32s(not_occ) == S::F32_LANES {
                    break;
                }
            }
            let not_occ = simd.xor_m32s(chunk_mask, true_mask);
            let contribution = simd.select_f32s(not_occ, simd.splat_f32s(1.0), zero);
            accessible_points += simd.reduce_sum_f32s(contribution);
        }

        // Scalar tail for points that do not fill a whole SIMD register. The neighbor that
        // buried the previous tail point is tried first as a cheap cache.
        let mut current_nb = 0;
        for ((&sx, &sy), &sz) in sx_rem.iter().zip(sy_rem).zip(sz_rem) {
            // Strict `<`, matching `less_than_f32s` above, so a point's classification does
            // not depend on whether the SIMD width puts it in a chunk or in this tail.
            let buried_by = |neighbor: &NeighborData| match occlusion(neighbor) {
                Some((vx, vy, vz, limit)) => sx * vx + sy * vy + sz * vz < limit,
                None => false,
            };
            let occluded = match self.neighbors.get(current_nb) {
                Some(cached) if buried_by(cached) => true,
                _ => match self.neighbors.iter().position(|nb| buried_by(nb)) {
                    Some(idx) => {
                        current_nb = idx;
                        true
                    }
                    None => false,
                },
            };
            if !occluded {
                accessible_points += 1.0;
            }
        }

        let surface_area = 4.0 * std::f32::consts::PI * r2;
        let inv_n_points = 1.0 / (n_samples as f32);
        surface_area * accessible_points * inv_n_points
    }
}
/// Computes the solvent-accessible surface area of every atom in `atoms`.
///
/// # Arguments
/// * `atoms` – atoms to evaluate; all of them also take part in the neighbor search.
/// * `probe_radius` – solvent probe radius added to every atomic radius.
/// * `n_points` – number of sample directions per atom sphere. `0` yields all-zero
///   areas instead of NaN.
/// * `threads` – `1` runs sequentially. Any other value (including `0` or negative)
///   runs on rayon's current thread pool; this function does not size the pool itself.
///
/// # Returns
/// One area per atom, in the same order as `atoms`.
pub fn calculate_sasa_internal(
    atoms: &[Atom],
    probe_radius: f32,
    n_points: usize,
    threads: isize,
) -> Vec<f32> {
    if atoms.is_empty() {
        // Nothing to evaluate; avoid building a spatial grid over zero atoms.
        return Vec::new();
    }
    if n_points == 0 {
        // With no sample directions no point can be accessible, and the per-atom
        // normalisation by the sample count would otherwise produce NaN.
        return vec![0.0; atoms.len()];
    }

    // Currently every atom is active, so `active_indices` is the identity mapping.
    let active_indices: Vec<usize> = (0..atoms.len()).collect();
    let sphere_points = generate_sphere_points(n_points);
    let max_radii = active_indices
        .iter()
        .map(|&i| atoms[i].radius)
        .fold(0.0f32, f32::max);
    let neighbor_lists = precompute_neighbors(atoms, &active_indices, probe_radius, max_radii);

    let process_atom = |(list_idx, neighbors): (usize, &Vec<NeighborData>)| {
        ARCH.dispatch(AtomSasaKernel {
            atom_index: active_indices[list_idx],
            atoms,
            neighbors,
            sphere_points: &sphere_points,
            probe_radius,
        })
    };

    let active_results: Vec<f32> = if threads == 1 {
        neighbor_lists
            .iter()
            .enumerate()
            .map(process_atom)
            .collect()
    } else {
        neighbor_lists
            .par_iter()
            .enumerate()
            .map(process_atom)
            .collect()
    };

    // Scatter per-active-atom results back to their original atom positions.
    let mut results = vec![0.0; atoms.len()];
    for (&orig_idx, area) in active_indices.iter().zip(active_results) {
        results[orig_idx] = area;
    }
    results
}