rust_sasa/
options.rs

1use crate::structures::atomic::{ChainResult, ProteinResult, ResidueResult};
2use crate::utils::consts::{POLAR_AMINO_ACIDS, get_protor_radius, load_radii_from_file};
3use crate::utils::{serialize_chain_id, simd_sum};
4use crate::{Atom, calculate_sasa_internal};
5use nalgebra::Point3;
6use pdbtbx::PDB;
7use snafu::OptionExt;
8use snafu::prelude::*;
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
/// Options for configuring SASA (Solvent Accessible Surface Area) calculations.
///
/// This struct provides configuration options for SASA calculations at different levels
/// of granularity (atom, residue, chain, or protein level). The type parameter `T`
/// determines the output type and processing behavior.
///
/// Note that the derived `Debug` and `Clone` implementations carry the usual
/// derive bounds, i.e. they are only available when `T: Debug` / `T: Clone`.
///
/// # Type Parameters
///
/// * `T` - The processing level, which must implement [`SASAProcessor`] for
///   [`SASAOptions::process`] to be callable. Available levels:
///   - [`AtomLevel`] - Returns SASA values for individual atoms
///   - [`ResidueLevel`] - Returns SASA values aggregated by residue
///   - [`ChainLevel`] - Returns SASA values aggregated by chain
///   - [`ProteinLevel`] - Returns SASA values aggregated for the entire protein
///
/// # Fields
///
/// * `probe_radius` - Radius of the solvent probe sphere in Angstroms (default: 1.4)
/// * `n_points` - Number of points on the sphere surface for sampling (default: 100)
/// * `parallel` - Whether to use parallel processing (default: true)
/// * `include_hydrogens` - Whether to include hydrogen atoms in calculations (default: false)
/// * `radii_config` - Optional custom radii configuration (default: `None`, in which case
///   radii come from `get_protor_radius`)
///
/// # Examples
///
/// ```rust
/// use rust_sasa::options::{SASAOptions, ResidueLevel};
/// use pdbtbx::PDB;
///
/// // Create options with default settings
/// let options = SASAOptions::<ResidueLevel>::new();
///
/// // Customize the configuration
/// let custom_options = SASAOptions::<ResidueLevel>::new()
///     .with_probe_radius(1.2)
///     .with_n_points(200)
///     .with_parallel(true)
///     .with_include_hydrogens(false);
///
/// // Process a PDB structure
/// # let pdb = PDB::new();
/// let result = custom_options.process(&pdb)?;
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[derive(Debug, Clone)]
pub struct SASAOptions<T> {
    /// Solvent probe radius handed to `calculate_sasa_internal`; `new()` sets 1.4.
    probe_radius: f32,
    /// Number of sampling points handed to `calculate_sasa_internal`; `new()` sets 100.
    n_points: usize,
    /// Parallelism flag handed to `calculate_sasa_internal`; `new()` sets `true`.
    parallel: bool,
    /// Hydrogen-inclusion flag handed to `calculate_sasa_internal`; `new()` sets `false`.
    include_hydrogens: bool,
    /// Custom radii, looked up as `config[residue_name][atom_name]` before the
    /// built-in table is consulted. `None` means only the built-in table is used.
    radii_config: Option<HashMap<String, HashMap<String, f32>>>,
    /// Carries the level type `T` without storing a value of it.
    _marker: PhantomData<T>,
}
64
// Zero-sized marker types for each level.
//
// They derive `Debug`, `Clone` and `Copy` because `SASAOptions<T>` uses
// `#[derive(Debug, Clone)]`, which only applies when `T` itself implements
// those traits. Without the derives here, `SASAOptions<ResidueLevel>` (and the
// other levels) could be neither printed nor cloned.

/// Marker selecting per-atom output.
#[derive(Debug, Clone, Copy)]
pub struct AtomLevel;

/// Marker selecting per-residue output.
#[derive(Debug, Clone, Copy)]
pub struct ResidueLevel;

/// Marker selecting per-chain output.
#[derive(Debug, Clone, Copy)]
pub struct ChainLevel;

/// Marker selecting whole-protein output.
#[derive(Debug, Clone, Copy)]
pub struct ProteinLevel;
70
/// Return type of [`SASAProcessor::build_atoms_and_mapping`].
///
/// On success it holds the flat list of atoms together with a map from a
/// level-specific key to the indices (into that atom list) of the atoms that
/// belong to the keyed element.
pub type AtomsMappingResult = Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError>;
72
73/// Helper function to get atomic radius from custom config or default protor config
74fn get_radius(
75    residue_name: &str,
76    atom_name: &str,
77    radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
78) -> Option<f32> {
79    // Check custom config first
80    if let Some(config) = radii_config {
81        if let Some(radius) = config
82            .get(residue_name)
83            .and_then(|inner| inner.get(atom_name))
84        {
85            return Some(*radius);
86        }
87    }
88    // Fall back to default protor config
89    get_protor_radius(residue_name, atom_name)
90}
91
/// Shared atom-construction step used by every `build_atoms_and_mapping`.
///
/// Looks up the radius through `get_radius`; if that yields nothing, falls back
/// to the element's van der Waals radius and propagates `VanDerWaalsMissing`
/// from the enclosing function when that is absent too. The finished `Atom` is
/// pushed onto `$atoms`.
macro_rules! build_atom {
    ($atoms:expr, $atom:expr, $element:expr, $residue_name:expr, $atom_name:expr, $parent_id:expr, $radii_config:expr) => {{
        let radius = if let Some(known) = get_radius($residue_name, $atom_name, $radii_config) {
            known
        } else {
            let van_der_waals = $element
                .atomic_radius()
                .van_der_waals
                .context(VanDerWaalsMissingSnafu)?;
            van_der_waals as f32
        };

        // Read the coordinates once and narrow each component to f32.
        let (x, y, z) = $atom.pos();

        $atoms.push(Atom {
            position: Point3::new(x as f32, y as f32, z as f32),
            radius,
            id: $atom.serial_number(),
            parent_id: $parent_id,
            is_hydrogen: $element == &pdbtbx::Element::H,
        });
    }};
}
116
/// Trait that defines the processing behavior for each level.
///
/// `SASAOptions::process` first calls [`SASAProcessor::build_atoms_and_mapping`],
/// runs the per-atom SASA calculation on the returned atoms, and then hands the
/// per-atom values plus the mapping to [`SASAProcessor::process_atoms`].
pub trait SASAProcessor {
    /// Result type produced for this level.
    type Output;

    /// Turn per-atom SASA values into this level's output.
    ///
    /// * `atoms` - the atoms returned by `build_atoms_and_mapping`
    /// * `atom_sasa` - one SASA value per atom, indexed like `atoms`
    /// * `pdb` - the structure the atoms were built from
    /// * `parent_to_atoms` - the mapping returned by `build_atoms_and_mapping`;
    ///   its key scheme is chosen by each implementation
    fn process_atoms(
        atoms: &[Atom],
        atom_sasa: &[f32],
        pdb: &PDB,
        parent_to_atoms: &HashMap<isize, Vec<usize>>,
    ) -> Result<Self::Output, SASACalcError>;

    /// Build the flat atom list for `pdb` and the map from level element to
    /// atom indices.
    ///
    /// `radii_config`, when present, is consulted for atom radii before the
    /// built-in lookup.
    fn build_atoms_and_mapping(
        pdb: &PDB,
        radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
    ) -> AtomsMappingResult;
}
133
134impl SASAProcessor for AtomLevel {
135    type Output = Vec<f32>;
136
137    fn process_atoms(
138        _atoms: &[Atom],
139        atom_sasa: &[f32],
140        _pdb: &PDB,
141        _parent_to_atoms: &HashMap<isize, Vec<usize>>,
142    ) -> Result<Self::Output, SASACalcError> {
143        Ok(atom_sasa.to_vec())
144    }
145
146    fn build_atoms_and_mapping(
147        pdb: &PDB,
148        radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
149    ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
150        let mut atoms = vec![];
151        for residue in pdb.residues() {
152            let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
153            for atom in residue.atoms() {
154                let element = atom.element().context(ElementMissingSnafu)?;
155                let atom_name = atom.name();
156                build_atom!(
157                    atoms,
158                    atom,
159                    element,
160                    residue_name,
161                    atom_name,
162                    None,
163                    radii_config
164                );
165            }
166        }
167        Ok((atoms, HashMap::new()))
168    }
169}
170
171impl SASAProcessor for ResidueLevel {
172    type Output = Vec<ResidueResult>;
173
174    fn process_atoms(
175        _atoms: &[Atom],
176        atom_sasa: &[f32],
177        pdb: &PDB,
178        parent_to_atoms: &HashMap<isize, Vec<usize>>,
179    ) -> Result<Self::Output, SASACalcError> {
180        let mut residue_sasa = vec![];
181        for chain in pdb.chains() {
182            for residue in chain.residues() {
183                let residue_atom_index = parent_to_atoms
184                    .get(&residue.serial_number())
185                    .context(AtomMapToLevelElementFailedSnafu)?;
186                let residue_atoms: Vec<_> = residue_atom_index
187                    .iter()
188                    .map(|&index| atom_sasa[index])
189                    .collect();
190                let sum = simd_sum(residue_atoms.as_slice());
191                let name = residue
192                    .name()
193                    .context(FailedToGetResidueNameSnafu)?
194                    .to_string();
195                residue_sasa.push(ResidueResult {
196                    serial_number: residue.serial_number(),
197                    value: sum,
198                    is_polar: POLAR_AMINO_ACIDS.contains(&name),
199                    chain_id: chain.id().to_string(),
200                    name,
201                })
202            }
203        }
204        Ok(residue_sasa)
205    }
206
207    fn build_atoms_and_mapping(
208        pdb: &PDB,
209        radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
210    ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
211        let mut atoms = vec![];
212        let mut parent_to_atoms = HashMap::new();
213        let mut i = 0;
214        for residue in pdb.residues() {
215            let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
216            let mut temp = vec![];
217            for atom in residue.atoms() {
218                let element = atom.element().context(ElementMissingSnafu)?;
219                let atom_name = atom.name();
220                build_atom!(
221                    atoms,
222                    atom,
223                    element,
224                    residue_name,
225                    atom_name,
226                    Some(residue.serial_number()),
227                    radii_config
228                );
229                temp.push(i);
230                i += 1;
231            }
232            parent_to_atoms.insert(residue.serial_number(), temp);
233        }
234        Ok((atoms, parent_to_atoms))
235    }
236}
237
238impl SASAProcessor for ChainLevel {
239    type Output = Vec<ChainResult>;
240
241    fn process_atoms(
242        _atoms: &[Atom],
243        atom_sasa: &[f32],
244        pdb: &PDB,
245        parent_to_atoms: &HashMap<isize, Vec<usize>>,
246    ) -> Result<Self::Output, SASACalcError> {
247        let mut chain_sasa = vec![];
248        for chain in pdb.chains() {
249            let chain_id = serialize_chain_id(chain.id());
250            let chain_atom_index = parent_to_atoms
251                .get(&chain_id)
252                .context(AtomMapToLevelElementFailedSnafu)?;
253            let chain_atoms: Vec<_> = chain_atom_index
254                .iter()
255                .map(|&index| atom_sasa[index])
256                .collect();
257            let sum = simd_sum(chain_atoms.as_slice());
258            chain_sasa.push(ChainResult {
259                name: chain.id().to_string(),
260                value: sum,
261            })
262        }
263        Ok(chain_sasa)
264    }
265
266    fn build_atoms_and_mapping(
267        pdb: &PDB,
268        radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
269    ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
270        let mut atoms = vec![];
271        let mut parent_to_atoms = HashMap::new();
272        let mut i = 0;
273        for chain in pdb.chains() {
274            let mut temp = vec![];
275            let chain_id = serialize_chain_id(chain.id());
276            for residue in chain.residues() {
277                let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
278                for atom in residue.atoms() {
279                    let element = atom.element().context(ElementMissingSnafu)?;
280                    let atom_name = atom.name();
281                    build_atom!(
282                        atoms,
283                        atom,
284                        element,
285                        residue_name,
286                        atom_name,
287                        Some(chain_id),
288                        radii_config
289                    );
290                    temp.push(i);
291                    i += 1
292                }
293            }
294            parent_to_atoms.insert(chain_id, temp);
295        }
296        Ok((atoms, parent_to_atoms))
297    }
298}
299
300impl SASAProcessor for ProteinLevel {
301    type Output = ProteinResult;
302
303    fn process_atoms(
304        _atoms: &[Atom],
305        atom_sasa: &[f32],
306        pdb: &PDB,
307        parent_to_atoms: &HashMap<isize, Vec<usize>>,
308    ) -> Result<Self::Output, SASACalcError> {
309        let mut polar_total: f32 = 0.0;
310        let mut non_polar_total: f32 = 0.0;
311        for residue in pdb.residues() {
312            let residue_atom_index = parent_to_atoms
313                .get(&residue.serial_number())
314                .context(AtomMapToLevelElementFailedSnafu)?;
315            let residue_atoms: Vec<_> = residue_atom_index
316                .iter()
317                .map(|&index| atom_sasa[index])
318                .collect();
319            let sum = simd_sum(residue_atoms.as_slice());
320            let name = residue
321                .name()
322                .context(FailedToGetResidueNameSnafu)?
323                .to_string();
324            if POLAR_AMINO_ACIDS.contains(&name) {
325                polar_total += sum
326            } else {
327                non_polar_total += sum
328            }
329        }
330        let global_sum = simd_sum(atom_sasa);
331        Ok(ProteinResult {
332            global_total: global_sum,
333            polar_total,
334            non_polar_total,
335        })
336    }
337
338    fn build_atoms_and_mapping(
339        pdb: &PDB,
340        radii_config: Option<&HashMap<String, HashMap<String, f32>>>,
341    ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
342        let mut atoms = vec![];
343        let mut parent_to_atoms = HashMap::new();
344        let mut i = 0;
345        for residue in pdb.residues() {
346            let residue_name = residue.name().context(FailedToGetResidueNameSnafu)?;
347            let mut temp = vec![];
348            for atom in residue.atoms() {
349                let element = atom.element().context(ElementMissingSnafu)?;
350                let atom_name = atom.name();
351                build_atom!(
352                    atoms,
353                    atom,
354                    element,
355                    residue_name,
356                    atom_name,
357                    Some(residue.serial_number()),
358                    radii_config
359                );
360                temp.push(i);
361                i += 1;
362            }
363            parent_to_atoms.insert(residue.serial_number(), temp);
364        }
365        Ok((atoms, parent_to_atoms))
366    }
367}
368
/// Errors that can occur while building atoms or aggregating SASA values.
#[derive(Debug, Snafu)]
pub enum SASACalcError {
    /// An atom's `element()` returned `None` while building the atom list.
    #[snafu(display("Element missing for atom"))]
    ElementMissing,

    /// No radius was found via the radii lookup and the element has no
    /// van der Waals radius to fall back on.
    #[snafu(display("Van der Waals radius missing for element"))]
    VanDerWaalsMissing,

    /// A residue or chain had no entry in the parent-to-atoms map during
    /// aggregation.
    #[snafu(display("Failed to map atoms back to level element"))]
    AtomMapToLevelElementFailed,

    /// A residue's `name()` returned `None`.
    #[snafu(display("Failed to get residue name"))]
    FailedToGetResidueName,

    /// Wraps an I/O error from loading a radii file.
    // NOTE(review): nothing in this file constructs this variant;
    // `with_radii_file` returns the raw `std::io::Error` instead.
    #[snafu(display("Failed to load radii file: {source}"))]
    RadiiFileLoad { source: std::io::Error },
}
386
387impl<T> SASAOptions<T> {
388    /// Create a new SASAOptions with the specified level type
389    pub fn new() -> SASAOptions<T> {
390        SASAOptions {
391            probe_radius: 1.4,
392            n_points: 100,
393            parallel: true,
394            include_hydrogens: false,
395            radii_config: None,
396            _marker: PhantomData,
397        }
398    }
399
400    /// Set the probe radius (default: 1.4 Angstroms)
401    pub fn with_probe_radius(mut self, radius: f32) -> Self {
402        self.probe_radius = radius;
403        self
404    }
405
406    /// Set the number of points on the sphere for sampling (default: 100)
407    pub fn with_n_points(mut self, points: usize) -> Self {
408        self.n_points = points;
409        self
410    }
411
412    /// Enable or disable parallel processing (default: true)
413    pub fn with_parallel(mut self, parallel: bool) -> Self {
414        self.parallel = parallel;
415        self
416    }
417
418    /// Include or exclude hydrogen atoms in calculations (default: false)
419    pub fn with_include_hydrogens(mut self, include_hydrogens: bool) -> Self {
420        self.include_hydrogens = include_hydrogens;
421        self
422    }
423
424    /// Load custom radii configuration from a file
425    pub fn with_radii_file(mut self, path: &str) -> Result<Self, std::io::Error> {
426        self.radii_config = Some(load_radii_from_file(path)?);
427        Ok(self)
428    }
429}
430
431// Convenience constructors for each level
432impl SASAOptions<AtomLevel> {
433    pub fn atom_level() -> Self {
434        Self::new()
435    }
436}
437
438impl SASAOptions<ResidueLevel> {
439    pub fn residue_level() -> Self {
440        Self::new()
441    }
442}
443
444impl SASAOptions<ChainLevel> {
445    pub fn chain_level() -> Self {
446        Self::new()
447    }
448}
449
450impl SASAOptions<ProteinLevel> {
451    pub fn protein_level() -> Self {
452        Self::new()
453    }
454}
455
456impl<T> Default for SASAOptions<T> {
457    fn default() -> Self {
458        Self::new()
459    }
460}
461
462impl<T: SASAProcessor> SASAOptions<T> {
463    /// This function calculates the SASA for a given protein. The output type is determined by the level type parameter.
464    /// Probe radius and n_points can be customized, defaulting to 1.4 and 100 respectively.
465    /// If you want more fine-grained control you may want to use [calculate_sasa_internal] instead.
466    /// ## Example
467    /// ```
468    /// use pdbtbx::StrictnessLevel;
469    /// use rust_sasa::options::{SASAOptions, ResidueLevel};
470    /// let (mut pdb, _errors) = pdbtbx::open("./pdbs/example.cif").unwrap();
471    /// let result = SASAOptions::<ResidueLevel>::new().process(&pdb);
472    /// ```
473    pub fn process(&self, pdb: &PDB) -> Result<T::Output, SASACalcError> {
474        let (atoms, parent_to_atoms) = T::build_atoms_and_mapping(pdb, self.radii_config.as_ref())?;
475        let atom_sasa = calculate_sasa_internal(
476            &atoms,
477            self.probe_radius,
478            self.n_points,
479            self.parallel,
480            self.include_hydrogens,
481        );
482        T::process_atoms(&atoms, &atom_sasa, pdb, &parent_to_atoms)
483    }
484}