1use crate::structures::atomic::{ChainResult, ProteinResult, ResidueResult};
2use crate::utils::consts::{POLAR_AMINO_ACIDS, get_protor_radius, load_radii_from_file};
3use crate::utils::{serialize_chain_id, simd_sum};
4use crate::{Atom, calculate_sasa_internal};
5use nalgebra::Point3;
6use pdbtbx::PDB;
7use snafu::OptionExt;
8use snafu::prelude::*;
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
/// Builder-style configuration for a SASA calculation.
///
/// `T` is a marker type (`AtomLevel`, `ResidueLevel`, `ChainLevel`, `ProteinLevel`)
/// selecting how per-atom results are aggregated. It is only held through
/// `PhantomData`, so `Clone` and `Debug` are implemented by hand below: a
/// `#[derive]` would add `T: Clone` / `T: Debug` bounds and make the options
/// un-cloneable (and un-printable) for marker types that don't implement them.
pub struct SASAOptions<T> {
    /// Radius of the solvent probe sphere.
    probe_radius: f32,
    /// Number of surface points per atom handed to the per-atom SASA computation.
    n_points: usize,
    /// Whether the per-atom SASA computation runs in parallel.
    parallel: bool,
    /// Whether hydrogen atoms take part in the SASA computation.
    include_hydrogens: bool,
    /// Optional radius overrides: residue name -> atom name -> radius.
    radii_config: Option<HashMap<String, HashMap<String, f32>>>,
    _marker: PhantomData<T>,
}

impl<T> Clone for SASAOptions<T> {
    fn clone(&self) -> Self {
        Self {
            probe_radius: self.probe_radius,
            n_points: self.n_points,
            parallel: self.parallel,
            include_hydrogens: self.include_hydrogens,
            radii_config: self.radii_config.clone(),
            _marker: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for SASAOptions<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the derived output, but without requiring `T: Debug`
        // (`PhantomData<T>` is `Debug` for every `T`).
        f.debug_struct("SASAOptions")
            .field("probe_radius", &self.probe_radius)
            .field("n_points", &self.n_points)
            .field("parallel", &self.parallel)
            .field("include_hydrogens", &self.include_hydrogens)
            .field("radii_config", &self.radii_config)
            .field("_marker", &self._marker)
            .finish()
    }
}
64
/// Marker selecting atom-level output: one SASA value per atom.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AtomLevel;
/// Marker selecting residue-level output: SASA summed per residue.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ResidueLevel;
/// Marker selecting chain-level output: SASA summed per chain.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ChainLevel;
/// Marker selecting protein-level output: global, polar and non-polar totals.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ProteinLevel;
70
71pub type AtomsMappingResult = Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError>;
72
73fn get_radius(
75 residue_name: &str,
76 atom_name: &str,
77 radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
78) -> Option<f32> {
79 if let Some(config) = radii_config {
81 if let Some(radius) = config
82 .get(residue_name)
83 .and_then(|inner| inner.get(atom_name))
84 {
85 return Some(*radius);
86 }
87 }
88 get_protor_radius(residue_name, atom_name)
90}
91
/// Builds one SASA `Atom` from a `pdbtbx` atom and pushes it onto `$atoms`.
///
/// The radius comes from `get_radius` (custom config first, then its fallback). If
/// that yields nothing, the element's van der Waals radius is used. A missing van
/// der Waals radius is propagated with `?` as `SASACalcError::VanDerWaalsMissing`.
/// That is why this is a macro: the `?` must return from the *calling*
/// `build_atoms_and_mapping`.
///
/// Every argument except `$atoms` is evaluated exactly once. `$atom` and `$element`
/// are bound to locals first, so callers may pass arbitrary expressions without
/// repeated side effects. The position tuple is fetched only once.
macro_rules! build_atom {
    ($atoms:expr, $atom:expr, $element:expr, $residue_name:expr, $atom_name:expr, $parent_id:expr, $radii_config:expr) => {{
        // Borrow rather than move, so an owned atom passed by a caller stays usable.
        let src_atom = &$atom;
        let elem = $element;
        let radius = match get_radius($residue_name, $atom_name, $radii_config) {
            Some(r) => r,
            None => elem
                .atomic_radius()
                .van_der_waals
                .context(VanDerWaalsMissingSnafu)? as f32,
        };
        let (x, y, z) = src_atom.pos();

        $atoms.push(Atom {
            position: Point3::new(x as f32, y as f32, z as f32),
            radius,
            id: src_atom.serial_number(),
            parent_id: $parent_id,
            is_hydrogen: elem == &pdbtbx::Element::H,
        });
    }};
}
116
117pub trait SASAProcessor {
119 type Output;
120
121 fn process_atoms(
122 atoms: &[Atom],
123 atom_sasa: &[f32],
124 pdb: &PDB,
125 parent_to_atoms: &HashMap<isize, Vec<usize>>,
126 ) -> Result<Self::Output, SASACalcError>;
127
128 fn build_atoms_and_mapping(
129 pdb: &PDB,
130 radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
131 ) -> AtomsMappingResult;
132}
133
134impl SASAProcessor for AtomLevel {
135 type Output = Vec<f32>;
136
137 fn process_atoms(
138 _atoms: &[Atom],
139 atom_sasa: &[f32],
140 _pdb: &PDB,
141 _parent_to_atoms: &HashMap<isize, Vec<usize>>,
142 ) -> Result<Self::Output, SASACalcError> {
143 Ok(atom_sasa.to_vec())
144 }
145
146 fn build_atoms_and_mapping(
147 pdb: &PDB,
148 radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
149 ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
150 let mut atoms = vec![];
151 for residue in pdb.residues() {
152 let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
153 for atom in residue.atoms() {
154 let element = atom.element().context(ElementMissingSnafu)?;
155 let atom_name = atom.name();
156 build_atom!(
157 atoms,
158 atom,
159 element,
160 residue_name,
161 atom_name,
162 None,
163 radii_config
164 );
165 }
166 }
167 Ok((atoms, HashMap::new()))
168 }
169}
170
171impl SASAProcessor for ResidueLevel {
172 type Output = Vec<ResidueResult>;
173
174 fn process_atoms(
175 _atoms: &[Atom],
176 atom_sasa: &[f32],
177 pdb: &PDB,
178 parent_to_atoms: &HashMap<isize, Vec<usize>>,
179 ) -> Result<Self::Output, SASACalcError> {
180 let mut residue_sasa = vec![];
181 for chain in pdb.chains() {
182 for residue in chain.residues() {
183 let residue_atom_index = parent_to_atoms
184 .get(&residue.serial_number())
185 .context(AtomMapToLevelElementFailedSnafu)?;
186 let residue_atoms: Vec<_> = residue_atom_index
187 .iter()
188 .map(|&index| atom_sasa[index])
189 .collect();
190 let sum = simd_sum(residue_atoms.as_slice());
191 let name = residue
192 .name()
193 .context(FailedToGetResidueNameSnafu)?
194 .to_string();
195 residue_sasa.push(ResidueResult {
196 serial_number: residue.serial_number(),
197 value: sum,
198 is_polar: POLAR_AMINO_ACIDS.contains(&name),
199 chain_id: chain.id().to_string(),
200 name,
201 })
202 }
203 }
204 Ok(residue_sasa)
205 }
206
207 fn build_atoms_and_mapping(
208 pdb: &PDB,
209 radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
210 ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
211 let mut atoms = vec![];
212 let mut parent_to_atoms = HashMap::new();
213 let mut i = 0;
214 for residue in pdb.residues() {
215 let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
216 let mut temp = vec![];
217 for atom in residue.atoms() {
218 let element = atom.element().context(ElementMissingSnafu)?;
219 let atom_name = atom.name();
220 build_atom!(
221 atoms,
222 atom,
223 element,
224 residue_name,
225 atom_name,
226 Some(residue.serial_number()),
227 radii_config
228 );
229 temp.push(i);
230 i += 1;
231 }
232 parent_to_atoms.insert(residue.serial_number(), temp);
233 }
234 Ok((atoms, parent_to_atoms))
235 }
236}
237
238impl SASAProcessor for ChainLevel {
239 type Output = Vec<ChainResult>;
240
241 fn process_atoms(
242 _atoms: &[Atom],
243 atom_sasa: &[f32],
244 pdb: &PDB,
245 parent_to_atoms: &HashMap<isize, Vec<usize>>,
246 ) -> Result<Self::Output, SASACalcError> {
247 let mut chain_sasa = vec![];
248 for chain in pdb.chains() {
249 let chain_id = serialize_chain_id(chain.id());
250 let chain_atom_index = parent_to_atoms
251 .get(&chain_id)
252 .context(AtomMapToLevelElementFailedSnafu)?;
253 let chain_atoms: Vec<_> = chain_atom_index
254 .iter()
255 .map(|&index| atom_sasa[index])
256 .collect();
257 let sum = simd_sum(chain_atoms.as_slice());
258 chain_sasa.push(ChainResult {
259 name: chain.id().to_string(),
260 value: sum,
261 })
262 }
263 Ok(chain_sasa)
264 }
265
266 fn build_atoms_and_mapping(
267 pdb: &PDB,
268 radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
269 ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
270 let mut atoms = vec![];
271 let mut parent_to_atoms = HashMap::new();
272 let mut i = 0;
273 for chain in pdb.chains() {
274 let mut temp = vec![];
275 let chain_id = serialize_chain_id(chain.id());
276 for residue in chain.residues() {
277 let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
278 for atom in residue.atoms() {
279 let element = atom.element().context(ElementMissingSnafu)?;
280 let atom_name = atom.name();
281 build_atom!(
282 atoms,
283 atom,
284 element,
285 residue_name,
286 atom_name,
287 Some(chain_id),
288 radii_config
289 );
290 temp.push(i);
291 i += 1
292 }
293 }
294 parent_to_atoms.insert(chain_id, temp);
295 }
296 Ok((atoms, parent_to_atoms))
297 }
298}
299
300impl SASAProcessor for ProteinLevel {
301 type Output = ProteinResult;
302
303 fn process_atoms(
304 _atoms: &[Atom],
305 atom_sasa: &[f32],
306 pdb: &PDB,
307 parent_to_atoms: &HashMap<isize, Vec<usize>>,
308 ) -> Result<Self::Output, SASACalcError> {
309 let mut polar_total: f32 = 0.0;
310 let mut non_polar_total: f32 = 0.0;
311 for residue in pdb.residues() {
312 let residue_atom_index = parent_to_atoms
313 .get(&residue.serial_number())
314 .context(AtomMapToLevelElementFailedSnafu)?;
315 let residue_atoms: Vec<_> = residue_atom_index
316 .iter()
317 .map(|&index| atom_sasa[index])
318 .collect();
319 let sum = simd_sum(residue_atoms.as_slice());
320 let name = residue
321 .name()
322 .context(FailedToGetResidueNameSnafu)?
323 .to_string();
324 if POLAR_AMINO_ACIDS.contains(&name) {
325 polar_total += sum
326 } else {
327 non_polar_total += sum
328 }
329 }
330 let global_sum = simd_sum(atom_sasa);
331 Ok(ProteinResult {
332 global_total: global_sum,
333 polar_total,
334 non_polar_total,
335 })
336 }
337
338 fn build_atoms_and_mapping(
339 pdb: &PDB,
340 radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
341 ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
342 let mut atoms = vec![];
343 let mut parent_to_atoms = HashMap::new();
344 let mut i = 0;
345 for residue in pdb.residues() {
346 let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
347 let mut temp = vec![];
348 for atom in residue.atoms() {
349 let element = atom.element().context(ElementMissingSnafu)?;
350 let atom_name = atom.name();
351 build_atom!(
352 atoms,
353 atom,
354 element,
355 residue_name,
356 atom_name,
357 Some(residue.serial_number()),
358 radii_config
359 );
360 temp.push(i);
361 i += 1;
362 }
363 parent_to_atoms.insert(residue.serial_number(), temp);
364 }
365 Ok((atoms, parent_to_atoms))
366 }
367}
368
/// Errors raised while building the atom list or aggregating SASA values.
#[derive(Debug, Snafu)]
pub enum SASACalcError {
    /// An atom in the input structure has no element (`atom.element()` returned `None`).
    #[snafu(display("Element missing for atom"))]
    ElementMissing,

    /// `get_radius` found no radius for the atom, and its element has no van der
    /// Waals radius to fall back on.
    #[snafu(display("Van der Waals radius missing for element"))]
    VanDerWaalsMissing,

    /// During aggregation, a residue or chain had no entry in the key -> atom-index map.
    #[snafu(display("Failed to map atoms back to level element"))]
    AtomMapToLevelElementFailed,

    /// A residue has no name (`residue.name()` returned `None`).
    #[snafu(display("Failed to get residue name"))]
    FailedToGetResidueName,

    /// Loading a radii file failed; wraps the underlying I/O error.
    // NOTE(review): `SASAOptions::with_radii_file` returns `std::io::Error` directly
    // and never constructs this variant; confirm whether it is used elsewhere.
    #[snafu(display("Failed to load radii file: {source}"))]
    RadiiFileLoad { source: std::io::Error },
}
386
387impl<T> SASAOptions<T> {
388 pub fn new() -> SASAOptions<T> {
390 SASAOptions {
391 probe_radius: 1.4,
392 n_points: 100,
393 parallel: true,
394 include_hydrogens: false,
395 radii_config: None,
396 _marker: PhantomData,
397 }
398 }
399
400 pub fn with_probe_radius(mut self, radius: f32) -> Self {
402 self.probe_radius = radius;
403 self
404 }
405
406 pub fn with_n_points(mut self, points: usize) -> Self {
408 self.n_points = points;
409 self
410 }
411
412 pub fn with_parallel(mut self, parallel: bool) -> Self {
414 self.parallel = parallel;
415 self
416 }
417
418 pub fn with_include_hydrogens(mut self, include_hydrogens: bool) -> Self {
420 self.include_hydrogens = include_hydrogens;
421 self
422 }
423
424 pub fn with_radii_file(mut self, path: &str) -> Result<Self, std::io::Error> {
426 self.radii_config = Some(load_radii_from_file(path)?);
427 Ok(self)
428 }
429}
430
431impl SASAOptions<AtomLevel> {
433 pub fn atom_level() -> Self {
434 Self::new()
435 }
436}
437
438impl SASAOptions<ResidueLevel> {
439 pub fn residue_level() -> Self {
440 Self::new()
441 }
442}
443
444impl SASAOptions<ChainLevel> {
445 pub fn chain_level() -> Self {
446 Self::new()
447 }
448}
449
450impl SASAOptions<ProteinLevel> {
451 pub fn protein_level() -> Self {
452 Self::new()
453 }
454}
455
456impl<T> Default for SASAOptions<T> {
457 fn default() -> Self {
458 Self::new()
459 }
460}
461
462impl<T: SASAProcessor> SASAOptions<T> {
463 pub fn process(&self, pdb: &PDB) -> Result<T::Output, SASACalcError> {
474 let (atoms, parent_to_atoms) = T::build_atoms_and_mapping(pdb, self.radii_config.as_ref())?;
475 let atom_sasa = calculate_sasa_internal(
476 &atoms,
477 self.probe_radius,
478 self.n_points,
479 self.parallel,
480 self.include_hydrogens,
481 );
482 T::process_atoms(&atoms, &atom_sasa, pdb, &parent_to_atoms)
483 }
484}