1use crate::structures::atomic::{ChainResult, ProteinResult, ResidueResult};
2use crate::utils::consts::POLAR_AMINO_ACIDS;
3use crate::utils::{serialize_chain_id, simd_sum};
4use crate::{Atom, calculate_sasa_internal};
5use nalgebra::Point3;
6use pdbtbx::PDB;
7use snafu::OptionExt;
8use snafu::prelude::*;
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
/// Builder-style configuration for a SASA calculation.
///
/// The type parameter `T` is a level marker (for example [`AtomLevel`] or
/// [`ResidueLevel`]). It is stored only as `PhantomData`, so it costs nothing
/// at runtime; when `T` implements [`SASAProcessor`] it decides what
/// `process` returns.
#[derive(Debug, Clone)]
pub struct SASAOptions<T> {
    /// Probe radius handed to `calculate_sasa_internal`
    /// (presumably in angstroms; confirm against that function).
    probe_radius: f32,
    /// Number of sample points handed to `calculate_sasa_internal`.
    n_points: usize,
    /// Parallel flag handed to `calculate_sasa_internal`.
    parallel: bool,
    /// Zero-sized marker binding these options to level `T`.
    _marker: PhantomData<T>,
}
59
/// Level marker: output is one SASA value per atom (`Vec<f32>`).
pub struct AtomLevel;
/// Level marker: output is one `ResidueResult` per residue.
pub struct ResidueLevel;
/// Level marker: output is one `ChainResult` per chain.
pub struct ChainLevel;
/// Level marker: output is a single `ProteinResult` for the whole structure.
pub struct ProteinLevel;
65
/// Result of [`SASAProcessor::build_atoms_and_mapping`]: the flattened atom
/// list plus a map from a level-specific parent key to the indices (into that
/// atom list) of the atoms belonging to that parent.
pub type AtomsMappingResult = Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError>;
67
/// Strategy for turning per-atom SASA values into a level-specific result.
///
/// Both functions are associated functions (no `self`); implementors are
/// used purely as type-level markers.
pub trait SASAProcessor {
    /// Type produced by [`SASAProcessor::process_atoms`] for this level.
    type Output;

    /// Aggregates per-atom SASA values into `Self::Output`.
    ///
    /// * `atoms` - atoms produced by `build_atoms_and_mapping`.
    /// * `atom_sasa` - per-atom SASA values, indexed like `atoms`.
    /// * `pdb` - the structure the atoms were built from.
    /// * `parent_to_atoms` - mapping produced by `build_atoms_and_mapping`;
    ///   the meaning of the key is defined by each implementation.
    ///
    /// # Errors
    /// Returns a [`SASACalcError`] when aggregation fails.
    fn process_atoms(
        atoms: &[Atom],
        atom_sasa: &[f32],
        pdb: &PDB,
        parent_to_atoms: &HashMap<isize, Vec<usize>>,
    ) -> Result<Self::Output, SASACalcError>;

    /// Flattens `pdb` into a list of [`Atom`]s and builds the parent-to-atom
    /// index mapping later consumed by `process_atoms`.
    fn build_atoms_and_mapping(pdb: &PDB) -> AtomsMappingResult;
}
81
82impl SASAProcessor for AtomLevel {
83 type Output = Vec<f32>;
84
85 fn process_atoms(
86 _atoms: &[Atom],
87 atom_sasa: &[f32],
88 _pdb: &PDB,
89 _parent_to_atoms: &HashMap<isize, Vec<usize>>,
90 ) -> Result<Self::Output, SASACalcError> {
91 Ok(atom_sasa.to_vec())
92 }
93
94 fn build_atoms_and_mapping(
95 pdb: &PDB,
96 ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
97 let mut atoms = vec![];
98 for atom in pdb.atoms() {
99 atoms.push(Atom {
100 position: Point3::new(
101 atom.pos().0 as f32,
102 atom.pos().1 as f32,
103 atom.pos().2 as f32,
104 ),
105 radius: atom
106 .element()
107 .context(ElementMissingSnafu)?
108 .atomic_radius()
109 .van_der_waals
110 .context(VanDerWaalsMissingSnafu)? as f32,
111 id: atom.serial_number(),
112 parent_id: None,
113 })
114 }
115 Ok((atoms, HashMap::new()))
116 }
117}
118
119impl SASAProcessor for ResidueLevel {
120 type Output = Vec<ResidueResult>;
121
122 fn process_atoms(
123 _atoms: &[Atom],
124 atom_sasa: &[f32],
125 pdb: &PDB,
126 parent_to_atoms: &HashMap<isize, Vec<usize>>,
127 ) -> Result<Self::Output, SASACalcError> {
128 let mut residue_sasa = vec![];
129 for chain in pdb.chains() {
130 for residue in chain.residues() {
131 let residue_atom_index = parent_to_atoms
132 .get(&residue.serial_number())
133 .context(AtomMapToLevelElementFailedSnafu)?;
134 let residue_atoms: Vec<_> = residue_atom_index
135 .iter()
136 .map(|&index| atom_sasa[index])
137 .collect();
138 let sum = simd_sum(residue_atoms.as_slice());
139 let name = residue
140 .name()
141 .context(FailedToGetResidueNameSnafu)?
142 .to_string();
143 residue_sasa.push(ResidueResult {
144 serial_number: residue.serial_number(),
145 value: sum,
146 is_polar: POLAR_AMINO_ACIDS.contains(&name),
147 chain_id: chain.id().to_string(),
148 name,
149 })
150 }
151 }
152 Ok(residue_sasa)
153 }
154
155 fn build_atoms_and_mapping(
156 pdb: &PDB,
157 ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
158 let mut atoms = vec![];
159 let mut parent_to_atoms = HashMap::new();
160 let mut i = 0;
161 for residue in pdb.residues() {
162 let mut temp = vec![];
163 for atom in residue.atoms() {
164 atoms.push(Atom {
165 position: Point3::new(
166 atom.pos().0 as f32,
167 atom.pos().1 as f32,
168 atom.pos().2 as f32,
169 ),
170 radius: atom
171 .element()
172 .context(ElementMissingSnafu)?
173 .atomic_radius()
174 .van_der_waals
175 .context(VanDerWaalsMissingSnafu)? as f32,
176 id: atom.serial_number(),
177 parent_id: Some(residue.serial_number()),
178 });
179 temp.push(i);
180 i += 1;
181 }
182 parent_to_atoms.insert(residue.serial_number(), temp);
183 }
184 Ok((atoms, parent_to_atoms))
185 }
186}
187
188impl SASAProcessor for ChainLevel {
189 type Output = Vec<ChainResult>;
190
191 fn process_atoms(
192 _atoms: &[Atom],
193 atom_sasa: &[f32],
194 pdb: &PDB,
195 parent_to_atoms: &HashMap<isize, Vec<usize>>,
196 ) -> Result<Self::Output, SASACalcError> {
197 let mut chain_sasa = vec![];
198 for chain in pdb.chains() {
199 let chain_id = serialize_chain_id(chain.id());
200 let chain_atom_index = parent_to_atoms
201 .get(&chain_id)
202 .context(AtomMapToLevelElementFailedSnafu)?;
203 let chain_atoms: Vec<_> = chain_atom_index
204 .iter()
205 .map(|&index| atom_sasa[index])
206 .collect();
207 let sum = simd_sum(chain_atoms.as_slice());
208 chain_sasa.push(ChainResult {
209 name: chain.id().to_string(),
210 value: sum,
211 })
212 }
213 Ok(chain_sasa)
214 }
215
216 fn build_atoms_and_mapping(
217 pdb: &PDB,
218 ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
219 let mut atoms = vec![];
220 let mut parent_to_atoms = HashMap::new();
221 let mut i = 0;
222 for chain in pdb.chains() {
223 let mut temp = vec![];
224 let chain_id = serialize_chain_id(chain.id());
225 for atom in chain.atoms() {
226 atoms.push(Atom {
227 position: Point3::new(
228 atom.pos().0 as f32,
229 atom.pos().1 as f32,
230 atom.pos().2 as f32,
231 ),
232 radius: atom
233 .element()
234 .context(ElementMissingSnafu)?
235 .atomic_radius()
236 .van_der_waals
237 .context(VanDerWaalsMissingSnafu)? as f32,
238 id: atom.serial_number(),
239 parent_id: Some(chain_id),
240 });
241 temp.push(i);
242 i += 1
243 }
244 parent_to_atoms.insert(chain_id, temp);
245 }
246 Ok((atoms, parent_to_atoms))
247 }
248}
249
250impl SASAProcessor for ProteinLevel {
251 type Output = ProteinResult;
252
253 fn process_atoms(
254 _atoms: &[Atom],
255 atom_sasa: &[f32],
256 pdb: &PDB,
257 parent_to_atoms: &HashMap<isize, Vec<usize>>,
258 ) -> Result<Self::Output, SASACalcError> {
259 let mut polar_total: f32 = 0.0;
260 let mut non_polar_total: f32 = 0.0;
261 for residue in pdb.residues() {
262 let residue_atom_index = parent_to_atoms
263 .get(&residue.serial_number())
264 .context(AtomMapToLevelElementFailedSnafu)?;
265 let residue_atoms: Vec<_> = residue_atom_index
266 .iter()
267 .map(|&index| atom_sasa[index])
268 .collect();
269 let sum = simd_sum(residue_atoms.as_slice());
270 let name = residue
271 .name()
272 .context(FailedToGetResidueNameSnafu)?
273 .to_string();
274 if POLAR_AMINO_ACIDS.contains(&name) {
275 polar_total += sum
276 } else {
277 non_polar_total += sum
278 }
279 }
280 let global_sum = simd_sum(atom_sasa);
281 Ok(ProteinResult {
282 global_total: global_sum,
283 polar_total,
284 non_polar_total,
285 })
286 }
287
288 fn build_atoms_and_mapping(
289 pdb: &PDB,
290 ) -> Result<(Vec<Atom>, HashMap<isize, Vec<usize>>), SASACalcError> {
291 let mut atoms = vec![];
292 let mut parent_to_atoms = HashMap::new();
293 let mut i = 0;
294 for residue in pdb.residues() {
295 let mut temp = vec![];
296 for atom in residue.atoms() {
297 atoms.push(Atom {
298 position: Point3::new(
299 atom.pos().0 as f32,
300 atom.pos().1 as f32,
301 atom.pos().2 as f32,
302 ),
303 radius: atom
304 .element()
305 .context(ElementMissingSnafu)?
306 .atomic_radius()
307 .van_der_waals
308 .context(VanDerWaalsMissingSnafu)? as f32,
309 id: atom.serial_number(),
310 parent_id: Some(residue.serial_number()),
311 });
312 temp.push(i);
313 i += 1;
314 }
315 parent_to_atoms.insert(residue.serial_number(), temp);
316 }
317 Ok((atoms, parent_to_atoms))
318 }
319}
320
/// Errors produced while building atoms from a structure or while mapping
/// per-atom SASA values back to residues / chains.
#[derive(Debug, Snafu)]
pub enum SASACalcError {
    /// An atom's `element()` returned `None`.
    #[snafu(display("Element missing for atom"))]
    ElementMissing,

    /// The element's atomic radius has no van der Waals value.
    #[snafu(display("Van der Waals radius missing for element"))]
    VanDerWaalsMissing,

    /// A residue or chain had no entry in the parent-to-atoms mapping.
    #[snafu(display("Failed to map atoms back to level element"))]
    AtomMapToLevelElementFailed,

    /// A residue's `name()` returned `None`.
    #[snafu(display("Failed to get residue name"))]
    FailedToGetResidueName,
}
335
336impl Default for SASAOptions<ResidueLevel> {
337 fn default() -> Self {
338 Self {
339 probe_radius: 1.4, n_points: 100, parallel: true, _marker: PhantomData,
343 }
344 }
345}
346
347impl<T> SASAOptions<T> {
348 pub fn new() -> SASAOptions<T> {
350 SASAOptions {
351 probe_radius: 1.4,
352 n_points: 100,
353 parallel: false,
354 _marker: PhantomData,
355 }
356 }
357
358 pub fn with_probe_radius(mut self, radius: f32) -> Self {
360 self.probe_radius = radius;
361 self
362 }
363
364 pub fn with_n_points(mut self, points: usize) -> Self {
366 self.n_points = points;
367 self
368 }
369
370 pub fn with_parallel(mut self, parallel: bool) -> Self {
372 self.parallel = parallel;
373 self
374 }
375}
376
377impl SASAOptions<AtomLevel> {
379 pub fn atom_level() -> Self {
380 Self::new()
381 }
382}
383
384impl SASAOptions<ResidueLevel> {
385 pub fn residue_level() -> Self {
386 Self::new()
387 }
388}
389
390impl SASAOptions<ChainLevel> {
391 pub fn chain_level() -> Self {
392 Self::new()
393 }
394}
395
396impl SASAOptions<ProteinLevel> {
397 pub fn protein_level() -> Self {
398 Self::new()
399 }
400}
401
402impl<T: SASAProcessor> SASAOptions<T> {
403 pub fn process(&self, pdb: &PDB) -> Result<T::Output, SASACalcError> {
414 let (atoms, parent_to_atoms) = T::build_atoms_and_mapping(pdb)?;
415 let atom_sasa =
416 calculate_sasa_internal(&atoms, self.probe_radius, self.n_points, self.parallel);
417 T::process_atoms(&atoms, &atom_sasa, pdb, &parent_to_atoms)
418 }
419}