1use crate::structures::atomic::{ChainResult, ProteinResult, ResidueResult};
3use crate::utils::consts::{POLAR_AMINO_ACIDS, load_radii_from_file};
4use crate::utils::{combine_hash, get_radius, serialize_chain_id, simd_sum};
5use crate::{Atom, calculate_sasa_internal};
6use fnv::FnvHashMap;
7use pdbtbx::PDB;
8use snafu::OptionExt;
9use snafu::prelude::*;
10use std::marker::PhantomData;
11
12#[derive(Debug, Clone)]
60pub struct SASAOptions<T> {
61 probe_radius: f32,
62 n_points: usize,
63 threads: isize,
64 include_hydrogens: bool,
65 radii_config: Option<FnvHashMap<String, FnvHashMap<String, f32>>>,
66 allow_vdw_fallback: bool,
67 include_hetatms: bool,
68 read_radii_from_occupancy: bool,
69 _marker: PhantomData<T>,
70}
71
/// Marker for `SASAOptions`: per-atom output (`Vec<f32>`).
///
/// The derives make the zero-sized marker printable and copyable, and let it
/// satisfy the `T: Debug` / `T: Clone` bounds that derived impls on generic
/// wrapper types require.
#[derive(Debug, Clone, Copy, Default)]
pub struct AtomLevel;

/// Marker for `SASAOptions`: per-residue output (`Vec<ResidueResult>`).
#[derive(Debug, Clone, Copy, Default)]
pub struct ResidueLevel;

/// Marker for `SASAOptions`: per-chain output (`Vec<ChainResult>`).
#[derive(Debug, Clone, Copy, Default)]
pub struct ChainLevel;

/// Marker for `SASAOptions`: whole-structure output (`ProteinResult`).
#[derive(Debug, Clone, Copy, Default)]
pub struct ProteinLevel;
77
78pub type AtomsMappingResult = Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError>;
79
// Resolves the radius for one `pdbtbx` atom and appends the resulting `Atom`
// to `$atoms`.
//
// Radius resolution order:
//   1. the atom's occupancy value, when `$read_radii_from_occupancy` is set;
//   2. `get_radius($residue_name, $atom_name, $radii_config)`;
//   3. the element's van der Waals radius, when `$allow_vdw_fallback` is set;
//   4. otherwise the enclosing function returns `SASACalcError::RadiusMissing`.
//
// The macro uses `?` and `return`, so it can only be expanded inside a function
// whose error type is `SASACalcError`.
macro_rules! build_atom {
    (
        $atoms:expr,
        $atom:expr,
        $element:expr,
        $residue_name:expr,
        $atom_name:expr,
        $parent_id:expr,
        $radii_config:expr,
        $allow_vdw_fallback:expr,
        $read_radii_from_occupancy:expr,
        $id:expr
    ) => {{
        let radius = if $read_radii_from_occupancy {
            $atom.occupancy() as f32
        } else if let Some(known) = get_radius($residue_name, $atom_name, $radii_config) {
            known
        } else if $allow_vdw_fallback {
            $element
                .atomic_radius()
                .van_der_waals
                .context(VanDerWaalsMissingSnafu)? as f32
        } else {
            return Err(SASACalcError::RadiusMissing {
                residue_name: $residue_name.to_string(),
                atom_name: $atom_name.to_string(),
                element: $element.to_string(),
            });
        };

        // Coordinates are narrowed to `f32` for the calculation.
        let (x, y, z) = $atom.pos();
        $atoms.push(Atom {
            position: [x as f32, y as f32, z as f32],
            radius,
            id: $id as usize,
            parent_id: $parent_id,
        });
    }};
}
117
118pub trait SASAProcessor {
120 type Output;
121
122 fn process_atoms(
123 atoms: &[Atom],
124 atom_sasa: &[f32],
125 pdb: &PDB,
126 parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
127 ) -> Result<Self::Output, SASACalcError>;
128
129 fn build_atoms_and_mapping(
130 pdb: &PDB,
131 radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
132 allow_vdw_fallback: bool,
133 include_hydrogens: bool,
134 include_hetatms: bool,
135 read_radii_from_occupancy: bool,
136 ) -> AtomsMappingResult;
137}
138
139impl SASAProcessor for AtomLevel {
140 type Output = Vec<f32>;
141
142 fn process_atoms(
143 _atoms: &[Atom],
144 atom_sasa: &[f32],
145 _pdb: &PDB,
146 _parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
147 ) -> Result<Self::Output, SASACalcError> {
148 Ok(atom_sasa.to_vec())
149 }
150
151 fn build_atoms_and_mapping(
152 pdb: &PDB,
153 radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
154 allow_vdw_fallback: bool,
155 include_hydrogens: bool,
156 include_hetatms: bool,
157 read_radii_from_occupancy: bool,
158 ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
159 let mut atoms = vec![];
160 for residue in pdb.residues() {
161 let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
162 if let Some(conformer) = residue.conformers().next() {
163 for atom in conformer.atoms() {
164 let element = atom.element().context(ElementMissingSnafu)?;
165 let atom_name = atom.name();
166 if element == &pdbtbx::Element::H && !include_hydrogens {
167 continue;
168 };
169 if atom.hetero() && !include_hetatms {
170 continue;
171 }
172 let conformer_alt = conformer.alternative_location().unwrap_or("");
173 build_atom!(
174 atoms,
175 atom,
176 element,
177 residue_name,
178 atom_name,
179 None,
180 radii_config,
181 allow_vdw_fallback,
182 read_radii_from_occupancy,
183 combine_hash(conformer_alt, atom.serial_number())
184 );
185 }
186 }
187 }
188 Ok((atoms, FnvHashMap::default()))
189 }
190}
191
192impl SASAProcessor for ResidueLevel {
193 type Output = Vec<ResidueResult>;
194
195 fn process_atoms(
196 _atoms: &[Atom],
197 atom_sasa: &[f32],
198 pdb: &PDB,
199 parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
200 ) -> Result<Self::Output, SASACalcError> {
201 let mut residue_sasa = vec![];
202 for chain in pdb.chains() {
203 for residue in chain.residues() {
204 let residue_key = combine_hash(chain.id(), residue.serial_number());
205 let residue_atom_index = parent_to_atoms
206 .get(&residue_key)
207 .context(AtomMapToLevelElementFailedSnafu)?;
208 let residue_atoms: Vec<_> = residue_atom_index
209 .iter()
210 .map(|&index| atom_sasa[index])
211 .collect();
212 let sum = simd_sum(residue_atoms.as_slice());
213 let name = residue
214 .name()
215 .context(FailedToGetResidueNameSnafu)?
216 .to_string();
217 residue_sasa.push(ResidueResult {
218 serial_number: residue.serial_number(),
219 value: sum,
220 is_polar: POLAR_AMINO_ACIDS.contains(&name),
221 chain_id: chain.id().to_string(),
222 name,
223 })
224 }
225 }
226 Ok(residue_sasa)
227 }
228
229 fn build_atoms_and_mapping(
230 pdb: &PDB,
231 radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
232 allow_vdw_fallback: bool,
233 include_hydrogens: bool,
234 include_hetatms: bool,
235 read_radii_from_occupancy: bool,
236 ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
237 let mut atoms = vec![];
238 let mut parent_to_atoms = FnvHashMap::default();
239 let mut i = 0;
240 for chain in pdb.chains() {
241 let chain_id = chain.id();
242 for residue in chain.residues() {
243 let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
244 let residue_key = combine_hash(chain_id, residue.serial_number());
245 let mut temp = vec![];
246 if let Some(conformer) = residue.conformers().next() {
247 for atom in conformer.atoms() {
248 let element = atom.element().context(ElementMissingSnafu)?;
249 let atom_name = atom.name();
250 if element == &pdbtbx::Element::H && !include_hydrogens {
251 continue;
252 };
253 if atom.hetero() && !include_hetatms {
254 continue;
255 }
256 let conformer_alt = conformer.alternative_location().unwrap_or("");
257 build_atom!(
258 atoms,
259 atom,
260 element,
261 residue_name,
262 atom_name,
263 Some(residue.serial_number()),
264 radii_config,
265 allow_vdw_fallback,
266 read_radii_from_occupancy,
267 combine_hash(conformer_alt, atom.serial_number())
268 );
269 temp.push(i);
270 i += 1;
271 }
272 parent_to_atoms.insert(residue_key, temp);
273 }
274 }
275 }
276 Ok((atoms, parent_to_atoms))
277 }
278}
279
280impl SASAProcessor for ChainLevel {
281 type Output = Vec<ChainResult>;
282
283 fn process_atoms(
284 _atoms: &[Atom],
285 atom_sasa: &[f32],
286 pdb: &PDB,
287 parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
288 ) -> Result<Self::Output, SASACalcError> {
289 let mut chain_sasa = vec![];
290 for chain in pdb.chains() {
291 let chain_id = serialize_chain_id(chain.id());
292 let chain_atom_index = parent_to_atoms
293 .get(&chain_id)
294 .context(AtomMapToLevelElementFailedSnafu)?;
295 let chain_atoms: Vec<_> = chain_atom_index
296 .iter()
297 .map(|&index| atom_sasa[index])
298 .collect();
299 let sum = simd_sum(chain_atoms.as_slice());
300 chain_sasa.push(ChainResult {
301 name: chain.id().to_string(),
302 value: sum,
303 })
304 }
305 Ok(chain_sasa)
306 }
307
308 fn build_atoms_and_mapping(
309 pdb: &PDB,
310 radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
311 allow_vdw_fallback: bool,
312 include_hydrogens: bool,
313 include_hetatms: bool,
314 read_radii_from_occupancy: bool,
315 ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
316 let mut atoms = vec![];
317 let mut parent_to_atoms = FnvHashMap::default();
318 let mut i = 0;
319 for chain in pdb.chains() {
320 let chain_id = serialize_chain_id(chain.id());
321 let mut temp = vec![];
322 for residue in chain.residues() {
323 let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
324 if let Some(conformer) = residue.conformers().next() {
325 for atom in conformer.atoms() {
326 let element = atom.element().context(ElementMissingSnafu)?;
327 let atom_name = atom.name();
328 let conformer_alt = conformer.alternative_location().unwrap_or("");
329 if element == &pdbtbx::Element::H && !include_hydrogens {
330 continue;
331 };
332 if atom.hetero() && !include_hetatms {
333 continue;
334 }
335 build_atom!(
336 atoms,
337 atom,
338 element,
339 residue_name,
340 atom_name,
341 Some(chain_id),
342 radii_config,
343 allow_vdw_fallback,
344 read_radii_from_occupancy,
345 combine_hash(conformer_alt, atom.serial_number())
346 );
347 temp.push(i);
348 i += 1
349 }
350 }
351 }
352 parent_to_atoms.insert(chain_id, temp);
353 }
354 Ok((atoms, parent_to_atoms))
355 }
356}
357
358impl SASAProcessor for ProteinLevel {
359 type Output = ProteinResult;
360
361 fn process_atoms(
362 _atoms: &[Atom],
363 atom_sasa: &[f32],
364 pdb: &PDB,
365 parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
366 ) -> Result<Self::Output, SASACalcError> {
367 let mut polar_total: f32 = 0.0;
368 let mut non_polar_total: f32 = 0.0;
369 for chain in pdb.chains() {
370 for residue in chain.residues() {
371 let residue_key = combine_hash(chain.id(), residue.serial_number());
372 let residue_atom_index = parent_to_atoms
373 .get(&residue_key)
374 .context(AtomMapToLevelElementFailedSnafu)?;
375 let residue_atoms: Vec<_> = residue_atom_index
376 .iter()
377 .map(|&index| atom_sasa[index])
378 .collect();
379 let sum = simd_sum(residue_atoms.as_slice());
380 let name = residue
381 .name()
382 .context(FailedToGetResidueNameSnafu)?
383 .to_string();
384 if POLAR_AMINO_ACIDS.contains(&name) {
385 polar_total += sum
386 } else {
387 non_polar_total += sum
388 }
389 }
390 }
391 let global_sum = simd_sum(atom_sasa);
392 Ok(ProteinResult {
393 global_total: global_sum,
394 polar_total,
395 non_polar_total,
396 })
397 }
398
399 fn build_atoms_and_mapping(
400 pdb: &PDB,
401 radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
402 allow_vdw_fallback: bool,
403 include_hydrogens: bool,
404 include_hetatms: bool,
405 read_radii_from_occupancy: bool,
406 ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
407 let mut atoms = vec![];
408 let mut parent_to_atoms = FnvHashMap::default();
409 let mut i = 0;
410 for chain in pdb.chains() {
411 let chain_id = chain.id();
412 for residue in chain.residues() {
413 let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
414 let residue_key = combine_hash(chain_id, residue.serial_number());
415 let mut temp = vec![];
416 if let Some(conformer) = residue.conformers().next() {
417 for atom in conformer.atoms() {
418 let element = atom.element().context(ElementMissingSnafu)?;
419 let atom_name = atom.name();
420 if element == &pdbtbx::Element::H && !include_hydrogens {
421 continue;
422 };
423 if atom.hetero() && !include_hetatms {
424 continue;
425 }
426 build_atom!(
427 atoms,
428 atom,
429 element,
430 residue_name,
431 atom_name,
432 Some(residue.serial_number()),
433 radii_config,
434 allow_vdw_fallback,
435 read_radii_from_occupancy,
436 combine_hash("", atom.serial_number())
437 );
438 temp.push(i);
439 i += 1;
440 }
441 parent_to_atoms.insert(residue_key, temp);
442 }
443 }
444 }
445 Ok((atoms, parent_to_atoms))
446 }
447}
448
449#[derive(Debug, Snafu)]
450pub enum SASACalcError {
451 #[snafu(display("Element missing for atom"))]
452 ElementMissing,
453
454 #[snafu(display("Van der Waals radius missing for element"))]
455 VanDerWaalsMissing,
456
457 #[snafu(display(
458 "Radius not found for residue '{}' atom '{}' of type '{}'. This error can can be ignored, if you are using the CLI pass --allow-vdw-fallback or use with_allow_vdw_fallback if you are using the API.",
459 residue_name,
460 atom_name,
461 element
462 ))]
463 RadiusMissing {
464 residue_name: String,
465 atom_name: String,
466 element: String,
467 },
468
469 #[snafu(display("Failed to map atoms back to level element"))]
470 AtomMapToLevelElementFailed,
471
472 #[snafu(display("Failed to get residue name"))]
473 FailedToGetResidueName,
474
475 #[snafu(display("Failed to load radii file: {source}"))]
476 RadiiFileLoad { source: std::io::Error },
477}
478
479impl<T> SASAOptions<T> {
480 pub fn new() -> SASAOptions<T> {
482 SASAOptions {
483 probe_radius: 1.4,
484 n_points: 100,
485 threads: -1,
486 include_hydrogens: false,
487 radii_config: None,
488 allow_vdw_fallback: false,
489 include_hetatms: false,
490 read_radii_from_occupancy: false,
491 _marker: PhantomData,
492 }
493 }
494
495 pub fn with_probe_radius(mut self, radius: f32) -> Self {
497 self.probe_radius = radius;
498 self
499 }
500
501 pub fn with_include_hetatms(mut self, include_hetatms: bool) -> Self {
503 self.include_hetatms = include_hetatms;
504 self
505 }
506
507 pub fn with_n_points(mut self, points: usize) -> Self {
509 self.n_points = points;
510 self
511 }
512
513 pub fn with_read_radii_from_occupancy(mut self, read_radii_from_occupancy: bool) -> Self {
515 self.read_radii_from_occupancy = read_radii_from_occupancy;
516 self
517 }
518
519 pub fn with_threads(mut self, threads: isize) -> Self {
524 self.threads = threads;
525 self
526 }
527
528 pub fn with_include_hydrogens(mut self, include_hydrogens: bool) -> Self {
530 self.include_hydrogens = include_hydrogens;
531 self
532 }
533
534 pub fn with_radii_file(mut self, path: &str) -> Result<Self, std::io::Error> {
536 self.radii_config = Some(load_radii_from_file(path)?);
537 Ok(self)
538 }
539
540 pub fn with_allow_vdw_fallback(mut self, allow: bool) -> Self {
542 self.allow_vdw_fallback = allow;
543 self
544 }
545}
546
547impl SASAOptions<AtomLevel> {
549 pub fn atom_level() -> Self {
550 Self::new()
551 }
552}
553
554impl SASAOptions<ResidueLevel> {
555 pub fn residue_level() -> Self {
556 Self::new()
557 }
558}
559
560impl SASAOptions<ChainLevel> {
561 pub fn chain_level() -> Self {
562 Self::new()
563 }
564}
565
566impl SASAOptions<ProteinLevel> {
567 pub fn protein_level() -> Self {
568 Self::new()
569 }
570}
571
572impl<T> Default for SASAOptions<T> {
573 fn default() -> Self {
574 Self::new()
575 }
576}
577
578impl<T: SASAProcessor> SASAOptions<T> {
579 pub fn process(&self, pdb: &PDB) -> Result<T::Output, SASACalcError> {
590 let (atoms, parent_to_atoms) = T::build_atoms_and_mapping(
591 pdb,
592 self.radii_config.as_ref(),
593 self.allow_vdw_fallback,
594 self.include_hydrogens,
595 self.include_hetatms,
596 self.read_radii_from_occupancy,
597 )?;
598 let atom_sasa =
599 calculate_sasa_internal(&atoms, self.probe_radius, self.n_points, self.threads);
600 T::process_atoms(&atoms, &atom_sasa, pdb, &parent_to_atoms)
601 }
602}