Skip to main content

rust_sasa/
options.rs

1// Copyright (c) 2024 Maxwell Campbell. Licensed under the MIT License.
2use crate::structures::atomic::{ChainResult, ProteinResult, ResidueResult};
3use crate::utils::consts::{POLAR_AMINO_ACIDS, load_radii_from_file};
4use crate::utils::{combine_hash, get_radius, serialize_chain_id, simd_sum};
5use crate::{Atom, calculate_sasa_internal};
6use fnv::FnvHashMap;
7use pdbtbx::PDB;
8use snafu::OptionExt;
9use snafu::prelude::*;
10use std::marker::PhantomData;
11
/// Options for configuring SASA (Solvent Accessible Surface Area) calculations.
///
/// This struct provides configuration options for SASA calculations at different levels
/// of granularity (atom, residue, chain, or protein level). The type parameter `T`
/// determines the output type and processing behavior.
///
/// # Type Parameters
///
/// * `T` - The processing level, which must implement [`SASAProcessor`]. Available levels:
///   - [`AtomLevel`] - Returns SASA values for individual atoms
///   - [`ResidueLevel`] - Returns SASA values aggregated by residue
///   - [`ChainLevel`] - Returns SASA values aggregated by chain
///   - [`ProteinLevel`] - Returns SASA values aggregated for the entire protein
///
/// # Fields
///
/// * `probe_radius` - Radius of the solvent probe sphere in Angstroms (default: 1.4)
/// * `n_points` - Number of points on the sphere surface for sampling (default: 100)
/// * `threads` - Number of threads to use for parallel processing (default: -1 for all cores)
/// * `include_hydrogens` - Whether to include hydrogen atoms in calculations (default: false)
/// * `radii_config` - Optional custom radii configuration (default: uses embedded protor.config)
/// * `allow_vdw_fallback` - Allow fallback to PDBTBX van der Waals radii when radius is not found in radii file (default: false)
/// * `include_hetatms` - Whether to include HETATM records (e.g. non-standard amino acids) in calculations (default: false)
/// * `read_radii_from_occupancy` - Whether to take each atom's radius from its occupancy value
///   instead of looking it up in the radii configuration (default: false)
///
/// # Examples
///
/// ```rust
/// use rust_sasa::options::{SASAOptions, ResidueLevel};
/// use pdbtbx::PDB;
///
/// // Create options with default settings
/// let options = SASAOptions::<ResidueLevel>::new();
///
/// // Customize the configuration
/// let custom_options = SASAOptions::<ResidueLevel>::new()
///     .with_probe_radius(1.2)
///     .with_n_points(200)
///     .with_threads(-1)
///     .with_include_hydrogens(false)
///     .with_allow_vdw_fallback(true)
///     .with_include_hetatms(false);
///
/// // Process a PDB structure
/// # let pdb = PDB::new();
/// let result = custom_options.process(&pdb)?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
// The derives below add `T: Debug` / `T: Clone` bounds to the generated impls, so they
// only apply when the level marker type itself implements those traits.
#[derive(Debug, Clone)]
pub struct SASAOptions<T> {
    // Solvent probe radius handed to the SASA kernel.
    probe_radius: f32,
    // Number of sample points per atom sphere.
    n_points: usize,
    // Thread count; -1 means "use all cores" (see `with_threads`).
    threads: isize,
    // When false, atoms whose element is hydrogen are skipped.
    include_hydrogens: bool,
    // Custom radii table; `None` leaves the lookup to `get_radius`' default.
    radii_config: Option<FnvHashMap<String, FnvHashMap<String, f32>>>,
    // When true, a missing radius falls back to the element's van der Waals radius.
    allow_vdw_fallback: bool,
    // When false, atoms flagged as HETATM are skipped.
    include_hetatms: bool,
    // When true, the atom's occupancy value is used as its radius.
    read_radii_from_occupancy: bool,
    // Ties the options to a processing level without storing a value of it.
    _marker: PhantomData<T>,
}
71
// Zero-sized marker types for each level.
//
// They derive `Debug`, `Clone` and `Copy` because `SASAOptions<T>` derives `Debug` and
// `Clone`, and those derived impls require `T: Debug` / `T: Clone`. Without the derives
// here, no concrete `SASAOptions<...Level>` would actually be cloneable or printable.

/// Marker selecting per-atom output.
#[derive(Debug, Clone, Copy)]
pub struct AtomLevel;

/// Marker selecting per-residue output.
#[derive(Debug, Clone, Copy)]
pub struct ResidueLevel;

/// Marker selecting per-chain output.
#[derive(Debug, Clone, Copy)]
pub struct ChainLevel;

/// Marker selecting whole-protein output.
#[derive(Debug, Clone, Copy)]
pub struct ProteinLevel;
77
/// Result of building the atom list for a processing level.
///
/// On success it carries the flattened atom list together with a map from a level-specific
/// parent key to the indices (into that atom list) of the atoms belonging to the parent.
pub type AtomsMappingResult = Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError>;
79
/// Shared atom-construction logic for every processing level.
///
/// Resolves the atom's radius and appends a new `Atom` to `$atoms`. The radius comes from,
/// in order of precedence:
/// 1. the atom's occupancy value, when `$read_radii_from_occupancy` is true;
/// 2. `get_radius` for the residue/atom name pair;
/// 3. the element's van der Waals radius, when `$allow_vdw_fallback` is true.
///
/// If none of these yields a radius the macro returns early from the *enclosing function*
/// with a `SASACalcError`, so it may only be used inside functions returning
/// `Result<_, SASACalcError>`.
macro_rules! build_atom {
    (
        $atoms:expr,
        $atom:expr,
        $element:expr,
        $residue_name:expr,
        $atom_name:expr,
        $parent_id:expr,
        $radii_config:expr,
        $allow_vdw_fallback:expr,
        $read_radii_from_occupancy:expr,
        $id:expr
    ) => {{
        let resolved_radius = if $read_radii_from_occupancy {
            $atom.occupancy() as f32
        } else if let Some(known) = get_radius($residue_name, $atom_name, $radii_config) {
            known
        } else if $allow_vdw_fallback {
            $element
                .atomic_radius()
                .van_der_waals
                .context(VanDerWaalsMissingSnafu)? as f32
        } else {
            return Err(SASACalcError::RadiusMissing {
                residue_name: $residue_name.to_string(),
                atom_name: $atom_name.to_string(),
                element: $element.to_string(),
            });
        };

        // Coordinates are narrowed to f32 for the SASA kernel.
        let (x, y, z) = $atom.pos();
        $atoms.push(Atom {
            position: [x as f32, y as f32, z as f32],
            radius: resolved_radius,
            id: $id as usize,
            parent_id: $parent_id,
        });
    }};
}
117
/// Trait that defines the processing behavior for each level.
///
/// An implementor first turns a [`PDB`] into a flat atom list (plus a parent-to-atoms index
/// map) and later converts the per-atom SASA values into its own `Output` type.
pub trait SASAProcessor {
    /// The result type produced for this level.
    type Output;

    /// Converts per-atom SASA values into this level's output.
    ///
    /// * `atoms` - the atoms returned by [`SASAProcessor::build_atoms_and_mapping`]
    /// * `atom_sasa` - one SASA value per atom; implementations index it with the indices
    ///   stored in `parent_to_atoms`
    /// * `pdb` - the structure the atoms were built from
    /// * `parent_to_atoms` - the mapping returned by [`SASAProcessor::build_atoms_and_mapping`]
    fn process_atoms(
        atoms: &[Atom],
        atom_sasa: &[f32],
        pdb: &PDB,
        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
    ) -> Result<Self::Output, SASACalcError>;

    /// Builds the atom list for `pdb` and the parent-key to atom-index mapping.
    ///
    /// * `radii_config` - optional custom radii table forwarded to the radius lookup
    /// * `allow_vdw_fallback` - use the element's van der Waals radius when the lookup fails
    /// * `include_hydrogens` - keep hydrogen atoms when true
    /// * `include_hetatms` - keep HETATM atoms when true
    /// * `read_radii_from_occupancy` - use each atom's occupancy value as its radius
    fn build_atoms_and_mapping(
        pdb: &PDB,
        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
        allow_vdw_fallback: bool,
        include_hydrogens: bool,
        include_hetatms: bool,
        read_radii_from_occupancy: bool,
    ) -> AtomsMappingResult;
}
138
139impl SASAProcessor for AtomLevel {
140    type Output = Vec<f32>;
141
142    fn process_atoms(
143        _atoms: &[Atom],
144        atom_sasa: &[f32],
145        _pdb: &PDB,
146        _parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
147    ) -> Result<Self::Output, SASACalcError> {
148        Ok(atom_sasa.to_vec())
149    }
150
151    fn build_atoms_and_mapping(
152        pdb: &PDB,
153        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
154        allow_vdw_fallback: bool,
155        include_hydrogens: bool,
156        include_hetatms: bool,
157        read_radii_from_occupancy: bool,
158    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
159        let mut atoms = vec![];
160        for residue in pdb.residues() {
161            let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
162            if let Some(conformer) = residue.conformers().next() {
163                for atom in conformer.atoms() {
164                    let element = atom.element().context(ElementMissingSnafu)?;
165                    let atom_name = atom.name();
166                    if element == &pdbtbx::Element::H && !include_hydrogens {
167                        continue;
168                    };
169                    if atom.hetero() && !include_hetatms {
170                        continue;
171                    }
172                    let conformer_alt = conformer.alternative_location().unwrap_or("");
173                    build_atom!(
174                        atoms,
175                        atom,
176                        element,
177                        residue_name,
178                        atom_name,
179                        None,
180                        radii_config,
181                        allow_vdw_fallback,
182                        read_radii_from_occupancy,
183                        combine_hash(&(conformer_alt, atom.serial_number()))
184                    );
185                }
186            }
187        }
188        Ok((atoms, FnvHashMap::default()))
189    }
190}
191
192impl SASAProcessor for ResidueLevel {
193    type Output = Vec<ResidueResult>;
194
195    fn process_atoms(
196        _atoms: &[Atom],
197        atom_sasa: &[f32],
198        pdb: &PDB,
199        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
200    ) -> Result<Self::Output, SASACalcError> {
201        let mut residue_sasa = vec![];
202        for chain in pdb.chains() {
203            for residue in chain.residues() {
204                let residue_key = combine_hash(&(
205                    chain.id(),
206                    residue.serial_number(),
207                    residue.insertion_code().unwrap_or_default(),
208                ));
209                let residue_atom_index = parent_to_atoms
210                    .get(&residue_key)
211                    .context(AtomMapToLevelElementFailedSnafu)?;
212                let residue_atoms: Vec<_> = residue_atom_index
213                    .iter()
214                    .map(|&index| atom_sasa[index])
215                    .collect();
216                let sum = simd_sum(residue_atoms.as_slice());
217                let name = residue
218                    .name()
219                    .context(FailedToGetResidueNameSnafu)?
220                    .to_string();
221                residue_sasa.push(ResidueResult {
222                    serial_number: residue.serial_number(),
223                    insertion_code: residue.insertion_code().unwrap_or_default().to_string(),
224                    value: sum,
225                    is_polar: POLAR_AMINO_ACIDS.contains(&name),
226                    chain_id: chain.id().to_string(),
227                    name,
228                })
229            }
230        }
231        Ok(residue_sasa)
232    }
233
234    fn build_atoms_and_mapping(
235        pdb: &PDB,
236        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
237        allow_vdw_fallback: bool,
238        include_hydrogens: bool,
239        include_hetatms: bool,
240        read_radii_from_occupancy: bool,
241    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
242        let mut atoms = vec![];
243        let mut parent_to_atoms = FnvHashMap::default();
244        let mut i = 0;
245        for chain in pdb.chains() {
246            let chain_id = chain.id();
247            for residue in chain.residues() {
248                let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
249                let residue_key = combine_hash(&(
250                    chain_id,
251                    residue.serial_number(),
252                    residue.insertion_code().unwrap_or_default(),
253                ));
254                let mut temp = vec![];
255                if let Some(conformer) = residue.conformers().next() {
256                    for atom in conformer.atoms() {
257                        let element = atom.element().context(ElementMissingSnafu)?;
258                        let atom_name = atom.name();
259                        if element == &pdbtbx::Element::H && !include_hydrogens {
260                            continue;
261                        };
262                        if atom.hetero() && !include_hetatms {
263                            continue;
264                        }
265                        let conformer_alt = conformer.alternative_location().unwrap_or("");
266                        build_atom!(
267                            atoms,
268                            atom,
269                            element,
270                            residue_name,
271                            atom_name,
272                            Some(residue.serial_number()),
273                            radii_config,
274                            allow_vdw_fallback,
275                            read_radii_from_occupancy,
276                            combine_hash(&(conformer_alt, atom.serial_number()))
277                        );
278                        temp.push(i);
279                        i += 1;
280                    }
281                    parent_to_atoms.insert(residue_key, temp);
282                }
283            }
284        }
285        Ok((atoms, parent_to_atoms))
286    }
287}
288
289impl SASAProcessor for ChainLevel {
290    type Output = Vec<ChainResult>;
291
292    fn process_atoms(
293        _atoms: &[Atom],
294        atom_sasa: &[f32],
295        pdb: &PDB,
296        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
297    ) -> Result<Self::Output, SASACalcError> {
298        let mut chain_sasa = vec![];
299        for chain in pdb.chains() {
300            let chain_id = serialize_chain_id(chain.id());
301            let chain_atom_index = parent_to_atoms
302                .get(&chain_id)
303                .context(AtomMapToLevelElementFailedSnafu)?;
304            let chain_atoms: Vec<_> = chain_atom_index
305                .iter()
306                .map(|&index| atom_sasa[index])
307                .collect();
308            let sum = simd_sum(chain_atoms.as_slice());
309            chain_sasa.push(ChainResult {
310                name: chain.id().to_string(),
311                value: sum,
312            })
313        }
314        Ok(chain_sasa)
315    }
316
317    fn build_atoms_and_mapping(
318        pdb: &PDB,
319        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
320        allow_vdw_fallback: bool,
321        include_hydrogens: bool,
322        include_hetatms: bool,
323        read_radii_from_occupancy: bool,
324    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
325        let mut atoms = vec![];
326        let mut parent_to_atoms = FnvHashMap::default();
327        let mut i = 0;
328        for chain in pdb.chains() {
329            let chain_id = serialize_chain_id(chain.id());
330            let mut temp = vec![];
331            for residue in chain.residues() {
332                let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
333                if let Some(conformer) = residue.conformers().next() {
334                    for atom in conformer.atoms() {
335                        let element = atom.element().context(ElementMissingSnafu)?;
336                        let atom_name = atom.name();
337                        let conformer_alt = conformer.alternative_location().unwrap_or("");
338                        if element == &pdbtbx::Element::H && !include_hydrogens {
339                            continue;
340                        };
341                        if atom.hetero() && !include_hetatms {
342                            continue;
343                        }
344                        build_atom!(
345                            atoms,
346                            atom,
347                            element,
348                            residue_name,
349                            atom_name,
350                            Some(chain_id),
351                            radii_config,
352                            allow_vdw_fallback,
353                            read_radii_from_occupancy,
354                            combine_hash(&(conformer_alt, atom.serial_number()))
355                        );
356                        temp.push(i);
357                        i += 1
358                    }
359                }
360            }
361            parent_to_atoms.insert(chain_id, temp);
362        }
363        Ok((atoms, parent_to_atoms))
364    }
365}
366
367impl SASAProcessor for ProteinLevel {
368    type Output = ProteinResult;
369
370    fn process_atoms(
371        _atoms: &[Atom],
372        atom_sasa: &[f32],
373        pdb: &PDB,
374        parent_to_atoms: &FnvHashMap<isize, Vec<usize>>,
375    ) -> Result<Self::Output, SASACalcError> {
376        let mut polar_total: f32 = 0.0;
377        let mut non_polar_total: f32 = 0.0;
378        for chain in pdb.chains() {
379            for residue in chain.residues() {
380                let residue_key = combine_hash(&(
381                    chain.id(),
382                    residue.serial_number(),
383                    residue.insertion_code().unwrap_or_default(),
384                ));
385                let residue_atom_index = parent_to_atoms
386                    .get(&residue_key)
387                    .context(AtomMapToLevelElementFailedSnafu)?;
388                let residue_atoms: Vec<_> = residue_atom_index
389                    .iter()
390                    .map(|&index| atom_sasa[index])
391                    .collect();
392                let sum = simd_sum(residue_atoms.as_slice());
393                let name = residue
394                    .name()
395                    .context(FailedToGetResidueNameSnafu)?
396                    .to_string();
397                if POLAR_AMINO_ACIDS.contains(&name) {
398                    polar_total += sum
399                } else {
400                    non_polar_total += sum
401                }
402            }
403        }
404        let global_sum = simd_sum(atom_sasa);
405        Ok(ProteinResult {
406            global_total: global_sum,
407            polar_total,
408            non_polar_total,
409        })
410    }
411
412    fn build_atoms_and_mapping(
413        pdb: &PDB,
414        radii_config: Option<&FnvHashMap<String, FnvHashMap<String, f32>>>,
415        allow_vdw_fallback: bool,
416        include_hydrogens: bool,
417        include_hetatms: bool,
418        read_radii_from_occupancy: bool,
419    ) -> Result<(Vec<Atom>, FnvHashMap<isize, Vec<usize>>), SASACalcError> {
420        let mut atoms = vec![];
421        let mut parent_to_atoms = FnvHashMap::default();
422        let mut i = 0;
423        for chain in pdb.chains() {
424            let chain_id = chain.id();
425            for residue in chain.residues() {
426                let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
427                let residue_key = combine_hash(&(
428                    chain_id,
429                    residue.serial_number(),
430                    residue.insertion_code().unwrap_or_default(),
431                ));
432                let mut temp = vec![];
433                if let Some(conformer) = residue.conformers().next() {
434                    for atom in conformer.atoms() {
435                        let element = atom.element().context(ElementMissingSnafu)?;
436                        let atom_name = atom.name();
437                        if element == &pdbtbx::Element::H && !include_hydrogens {
438                            continue;
439                        };
440                        if atom.hetero() && !include_hetatms {
441                            continue;
442                        }
443                        build_atom!(
444                            atoms,
445                            atom,
446                            element,
447                            residue_name,
448                            atom_name,
449                            Some(residue.serial_number()),
450                            radii_config,
451                            allow_vdw_fallback,
452                            read_radii_from_occupancy,
453                            combine_hash(&("", atom.serial_number()))
454                        );
455                        temp.push(i);
456                        i += 1;
457                    }
458                    parent_to_atoms.insert(residue_key, temp);
459                }
460            }
461        }
462        Ok((atoms, parent_to_atoms))
463    }
464}
465
466#[derive(Debug, Snafu)]
467pub enum SASACalcError {
468    #[snafu(display("Element missing for atom"))]
469    ElementMissing,
470
471    #[snafu(display("Van der Waals radius missing for element"))]
472    VanDerWaalsMissing,
473
474    #[snafu(display(
475        "Radius not found for residue '{}' atom '{}' of type '{}'. This error can can be ignored, if you are using the CLI pass --allow-vdw-fallback or use with_allow_vdw_fallback if you are using the API.",
476        residue_name,
477        atom_name,
478        element
479    ))]
480    RadiusMissing {
481        residue_name: String,
482        atom_name: String,
483        element: String,
484    },
485
486    #[snafu(display("Failed to map atoms back to level element"))]
487    AtomMapToLevelElementFailed,
488
489    #[snafu(display("Failed to get residue name"))]
490    FailedToGetResidueName,
491
492    #[snafu(display("Failed to load radii file: {source}"))]
493    RadiiFileLoad { source: std::io::Error },
494}
495
496impl<T> SASAOptions<T> {
497    /// Create a new SASAOptions with the specified level type
498    pub fn new() -> SASAOptions<T> {
499        SASAOptions {
500            probe_radius: 1.4,
501            n_points: 100,
502            threads: -1,
503            include_hydrogens: false,
504            radii_config: None,
505            allow_vdw_fallback: false,
506            include_hetatms: false,
507            read_radii_from_occupancy: false,
508            _marker: PhantomData,
509        }
510    }
511
512    /// Set the probe radius (default: 1.4 Angstroms)
513    pub fn with_probe_radius(mut self, radius: f32) -> Self {
514        self.probe_radius = radius;
515        self
516    }
517
518    /// Include or exclude HETATM records in protein.
519    pub fn with_include_hetatms(mut self, include_hetatms: bool) -> Self {
520        self.include_hetatms = include_hetatms;
521        self
522    }
523
524    /// Set the number of points on the sphere for sampling (default: 100)
525    pub fn with_n_points(mut self, points: usize) -> Self {
526        self.n_points = points;
527        self
528    }
529
530    /// Set whether radii should be read from input protein occupancy values. (default: false)
531    pub fn with_read_radii_from_occupancy(mut self, read_radii_from_occupancy: bool) -> Self {
532        self.read_radii_from_occupancy = read_radii_from_occupancy;
533        self
534    }
535
536    /// Configure the number of threads to use for parallel processing
537    ///   - `-1`: Use all available CPU cores (default)
538    ///   - `1`: Single-threaded execution (disables parallelism)
539    ///   - `> 1`: Use specified number of threads
540    pub fn with_threads(mut self, threads: isize) -> Self {
541        self.threads = threads;
542        self
543    }
544
545    /// Include or exclude hydrogen atoms in calculations (default: false)
546    pub fn with_include_hydrogens(mut self, include_hydrogens: bool) -> Self {
547        self.include_hydrogens = include_hydrogens;
548        self
549    }
550
551    /// Load custom radii configuration from a file (default: uses embedded protor.config)
552    pub fn with_radii_file(mut self, path: &str) -> Result<Self, std::io::Error> {
553        self.radii_config = Some(load_radii_from_file(path)?);
554        Ok(self)
555    }
556
557    /// Allow fallback to PDBTBX van der Waals radii when radius is not found in radii config file (default: false)
558    pub fn with_allow_vdw_fallback(mut self, allow: bool) -> Self {
559        self.allow_vdw_fallback = allow;
560        self
561    }
562}
563
564// Convenience constructors for each level
565impl SASAOptions<AtomLevel> {
566    pub fn atom_level() -> Self {
567        Self::new()
568    }
569}
570
571impl SASAOptions<ResidueLevel> {
572    pub fn residue_level() -> Self {
573        Self::new()
574    }
575}
576
577impl SASAOptions<ChainLevel> {
578    pub fn chain_level() -> Self {
579        Self::new()
580    }
581}
582
583impl SASAOptions<ProteinLevel> {
584    pub fn protein_level() -> Self {
585        Self::new()
586    }
587}
588
589impl<T> Default for SASAOptions<T> {
590    fn default() -> Self {
591        Self::new()
592    }
593}
594
595impl<T: SASAProcessor> SASAOptions<T> {
596    /// This function calculates the SASA for a given protein. The output type is determined by the level type parameter.
597    /// Probe radius and n_points can be customized, defaulting to 1.4 and 100 respectively.
598    /// If you want more fine-grained control you may want to use [calculate_sasa_internal] instead.
599    /// ## Example
600    /// ```
601    /// use pdbtbx::StrictnessLevel;
602    /// use rust_sasa::options::{SASAOptions, ResidueLevel};
603    /// let (mut pdb, _errors) = pdbtbx::open("./tests/data/pdbs/example.cif").unwrap();
604    /// let result = SASAOptions::<ResidueLevel>::new().process(&pdb);
605    /// ```
606    pub fn process(&self, pdb: &PDB) -> Result<T::Output, SASACalcError> {
607        let (atoms, parent_to_atoms) = T::build_atoms_and_mapping(
608            pdb,
609            self.radii_config.as_ref(),
610            self.allow_vdw_fallback,
611            self.include_hydrogens,
612            self.include_hetatms,
613            self.read_radii_from_occupancy,
614        )?;
615        let atom_sasa =
616            calculate_sasa_internal(&atoms, self.probe_radius, self.n_points, self.threads);
617        T::process_atoms(&atoms, &atom_sasa, pdb, &parent_to_atoms)
618    }
619}