1use crate::db;
8use crate::model::{
9 atom::Atom,
10 chain::Chain,
11 grid::Grid,
12 residue::Residue,
13 structure::Structure,
14 types::{Element, Point, ResidueCategory, StandardResidue},
15};
16use crate::ops::error::Error;
17use crate::utils::parallel::*;
18use nalgebra::{Rotation3, Vector3};
19use rand::rngs::StdRng;
20use rand::seq::{IndexedRandom, SliceRandom};
21use rand::{Rng, SeedableRng};
22
/// Positively charged monatomic ion that can replace a water during ionization.
///
/// Its residue/atom name, element and formal charge come from the inherent methods on this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Cation {
    /// Sodium (charge +1).
    Na,
    /// Potassium (charge +1).
    K,
    /// Magnesium (charge +2).
    Mg,
    /// Calcium (charge +2).
    Ca,
    /// Lithium (charge +1).
    Li,
    /// Zinc (charge +2).
    Zn,
}
39
/// Negatively charged monatomic ion that can replace a water during ionization.
///
/// Every variant carries a formal charge of -1.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Anion {
    /// Chloride.
    Cl,
    /// Bromide.
    Br,
    /// Iodide.
    I,
    /// Fluoride.
    F,
}
52
/// Parameters controlling [`solvate_structure`].
#[derive(Debug, Clone)]
pub struct SolvateConfig {
    /// Padding added on every side of the solute's bounding box when sizing the box.
    /// It is also the offset at which the solute's minimum corner is placed. Uses the same length
    /// unit as the atom coordinates (presumably Å — confirm).
    pub margin: f64,
    /// Distance between neighbouring points of the cubic lattice on which water oxygens are placed.
    pub water_spacing: f64,
    /// A lattice point is rejected when a non-hydrogen solute atom lies within this distance of it.
    pub vdw_cutoff: f64,
    /// When `true`, existing water (`HOH`) and ion residues are removed before solvating.
    pub remove_existing: bool,
    /// Cation types drawn at random whenever positive charge has to be added.
    /// If both this and `anions` are empty, ionization is skipped entirely.
    pub cations: Vec<Cation>,
    /// Anion types drawn at random whenever negative charge has to be added.
    pub anions: Vec<Anion>,
    /// Net charge the system should carry once ions have been placed.
    pub target_charge: i32,
    /// Fixed seed for reproducible output; `None` draws fresh randomness on every call.
    pub rng_seed: Option<u64>,
}
73
74impl Default for SolvateConfig {
75 fn default() -> Self {
77 Self {
78 margin: 10.0,
79 water_spacing: 3.1,
80 vdw_cutoff: 2.4,
81 remove_existing: true,
82 cations: vec![Cation::Na],
83 anions: vec![Anion::Cl],
84 target_charge: 0,
85 rng_seed: None,
86 }
87 }
88}
89
90impl Cation {
91 pub fn element(&self) -> Element {
97 match self {
98 Cation::Na => Element::Na,
99 Cation::K => Element::K,
100 Cation::Mg => Element::Mg,
101 Cation::Ca => Element::Ca,
102 Cation::Li => Element::Li,
103 Cation::Zn => Element::Zn,
104 }
105 }
106
107 pub fn charge(&self) -> i32 {
113 match self {
114 Cation::Na | Cation::K | Cation::Li => 1,
115 Cation::Mg | Cation::Ca | Cation::Zn => 2,
116 }
117 }
118
119 pub fn name(&self) -> &'static str {
125 match self {
126 Cation::Na => "NA",
127 Cation::K => "K",
128 Cation::Mg => "MG",
129 Cation::Ca => "CA",
130 Cation::Li => "LI",
131 Cation::Zn => "ZN",
132 }
133 }
134}
135
136impl Anion {
137 pub fn element(&self) -> Element {
143 match self {
144 Anion::Cl => Element::Cl,
145 Anion::Br => Element::Br,
146 Anion::I => Element::I,
147 Anion::F => Element::F,
148 }
149 }
150
151 pub fn charge(&self) -> i32 {
157 -1
158 }
159
160 pub fn name(&self) -> &'static str {
166 match self {
167 Anion::Cl => "CL",
168 Anion::Br => "BR",
169 Anion::I => "I",
170 Anion::F => "F",
171 }
172 }
173}
174
175pub fn solvate_structure(structure: &mut Structure, config: &SolvateConfig) -> Result<(), Error> {
196 if config.remove_existing {
197 structure.retain_residues(|_chain_id, res| {
198 let is_water = res.standard_name == Some(StandardResidue::HOH);
199 let is_ion = res.category == ResidueCategory::Ion;
200 !is_water && !is_ion
201 });
202 structure.prune_empty_chains();
203 }
204
205 let solvent_chain_id = next_solvent_chain_id(structure);
206 let mut rng = build_rng(config);
207
208 let (min_bound, max_bound) = calculate_bounds(structure);
209 let size = max_bound - min_bound;
210
211 let box_dim = size
212 + Vector3::new(
213 config.margin * 2.0,
214 config.margin * 2.0,
215 config.margin * 2.0,
216 );
217
218 structure.box_vectors = Some([
219 [box_dim.x, 0.0, 0.0],
220 [0.0, box_dim.y, 0.0],
221 [0.0, 0.0, box_dim.z],
222 ]);
223
224 let target_origin = Vector3::new(config.margin, config.margin, config.margin);
225 let translation = target_origin - min_bound.coords;
226
227 translate_structure(structure, &translation);
228
229 let heavy_atoms: Vec<_> = structure
230 .par_atoms()
231 .filter(|a| a.element != Element::H)
232 .map(|a| (a.pos, ()))
233 .collect();
234 let grid = Grid::new(heavy_atoms, 4.0);
235
236 let mut solvent_chain = Chain::new(&solvent_chain_id);
237
238 let water_tmpl = db::get_template("HOH").ok_or(Error::MissingInternalTemplate {
239 res_name: "HOH".to_string(),
240 })?;
241 let water_name = water_tmpl.name();
242 let water_standard = water_tmpl.standard_name();
243
244 let tmpl_o_pos = water_tmpl
245 .heavy_atoms()
246 .find(|(n, _, _)| *n == "O")
247 .map(|(_, _, p)| p)
248 .unwrap_or(Point::origin());
249
250 let z_steps = (0..((box_dim.z / config.water_spacing).ceil() as usize)).collect::<Vec<_>>();
251 let base_seed = config.rng_seed.unwrap_or_else(rand::random);
252
253 let new_waters: Vec<Residue> = z_steps
254 .into_par_iter()
255 .enumerate()
256 .map(|(i, z_idx)| {
257 let mut local_rng = StdRng::seed_from_u64(base_seed.wrapping_add(i as u64));
258 let mut local_waters = Vec::new();
259 let z = (z_idx as f64 * config.water_spacing) + (config.water_spacing / 2.0);
260
261 if z >= box_dim.z {
262 return local_waters;
263 }
264
265 let mut y = config.water_spacing / 2.0;
266 while y < box_dim.y {
267 let mut x = config.water_spacing / 2.0;
268 while x < box_dim.x {
269 let candidate_pos = Point::new(x, y, z);
270
271 if grid
272 .neighbors(&candidate_pos, config.vdw_cutoff)
273 .exact()
274 .next()
275 .is_none()
276 {
277 let rotation = Rotation3::from_axis_angle(
278 &Vector3::y_axis(),
279 local_rng.random_range(0.0..std::f64::consts::TAU),
280 ) * Rotation3::from_axis_angle(
281 &Vector3::x_axis(),
282 local_rng.random_range(0.0..std::f64::consts::TAU),
283 );
284
285 let mut residue = Residue::new(
286 0,
287 None,
288 water_name,
289 Some(water_standard),
290 ResidueCategory::Standard,
291 );
292
293 let final_o_pos = candidate_pos;
294 residue.add_atom(Atom::new("O", Element::O, final_o_pos));
295
296 for (h_name, h_pos, _) in water_tmpl.hydrogens() {
297 let local_vec = h_pos - tmpl_o_pos;
298 let rotated_vec = rotation * local_vec;
299 residue.add_atom(Atom::new(
300 h_name,
301 Element::H,
302 final_o_pos + rotated_vec,
303 ));
304 }
305
306 local_waters.push(residue);
307 }
308 x += config.water_spacing;
309 }
310 y += config.water_spacing;
311 }
312 local_waters
313 })
314 .flatten()
315 .collect();
316
317 let mut water_positions = Vec::with_capacity(new_waters.len());
318 solvent_chain.reserve(new_waters.len());
319 let mut res_id_counter = 1;
320
321 for mut residue in new_waters {
322 residue.id = res_id_counter;
323 solvent_chain.add_residue(residue);
324 water_positions.push(res_id_counter);
325 res_id_counter += 1;
326 }
327
328 replace_with_ions(
329 structure,
330 &mut solvent_chain,
331 &mut water_positions,
332 config,
333 &mut rng,
334 )?;
335
336 if !solvent_chain.is_empty() {
337 structure.add_chain(solvent_chain);
338 }
339
340 Ok(())
341}
342
343fn calculate_bounds(structure: &Structure) -> (Point, Point) {
353 let mut min = Point::new(f64::MAX, f64::MAX, f64::MAX);
354 let mut max = Point::new(f64::MIN, f64::MIN, f64::MIN);
355 let mut count = 0;
356
357 for atom in structure.iter_atoms() {
358 min.x = min.x.min(atom.pos.x);
359 min.y = min.y.min(atom.pos.y);
360 min.z = min.z.min(atom.pos.z);
361 max.x = max.x.max(atom.pos.x);
362 max.y = max.y.max(atom.pos.y);
363 max.z = max.z.max(atom.pos.z);
364 count += 1;
365 }
366
367 if count == 0 {
368 return (Point::origin(), Point::origin());
369 }
370
371 (min, max)
372}
373
374fn translate_structure(structure: &mut Structure, vec: &Vector3<f64>) {
381 for atom in structure.iter_atoms_mut() {
382 atom.translate_by(vec);
383 }
384}
385
386fn calculate_solute_charge(structure: &Structure) -> i32 {
396 let mut charge = 0;
397 for chain in structure.iter_chains() {
398 for residue in chain.iter_residues() {
399 if let Some(tmpl) = db::get_template(&residue.name) {
400 charge += tmpl.charge();
401 } else if residue.category == ResidueCategory::Ion {
402 match residue.name.as_str() {
403 "NA" | "K" | "LI" => charge += 1,
404 "MG" | "CA" | "ZN" => charge += 2,
405 "CL" | "BR" | "I" | "F" => charge -= 1,
406 _ => {}
407 }
408 }
409 }
410 }
411 charge
412}
413
414fn replace_with_ions(
433 structure: &Structure,
434 solvent_chain: &mut Chain,
435 water_indices: &mut Vec<i32>,
436 config: &SolvateConfig,
437 rng: &mut impl Rng,
438) -> Result<(), Error> {
439 if config.cations.is_empty() && config.anions.is_empty() {
440 return Ok(());
441 }
442
443 let current_charge = calculate_solute_charge(structure);
444 let mut charge_diff = config.target_charge - current_charge;
445
446 water_indices.shuffle(rng);
447
448 let mut attempts = 0;
449 let max_attempts = water_indices.len();
450
451 while charge_diff != 0 && attempts < max_attempts {
452 if let Some(res_id) = water_indices.pop() {
453 let residue = solvent_chain.residue_mut(res_id, None).unwrap();
454 let pos = residue.atom("O").unwrap().pos;
455
456 if charge_diff < 0 {
457 if let Some(anion) = config.anions.choose(rng) {
458 *residue = create_anion_residue(res_id, *anion, pos);
459 charge_diff -= anion.charge();
460 } else {
461 break;
462 }
463 } else if let Some(cation) = config.cations.choose(rng) {
464 *residue = create_cation_residue(res_id, *cation, pos);
465 charge_diff -= cation.charge();
466 } else {
467 break;
468 }
469 }
470 attempts += 1;
471 }
472
473 if charge_diff != 0 {
474 if water_indices.is_empty() {
475 return Err(Error::BoxTooSmall);
476 }
477
478 return Err(Error::IonizationFailed {
479 details: format!(
480 "Could not reach target charge {}. Remaining diff: {}. Check if proper ion types are provided.",
481 config.target_charge, charge_diff
482 ),
483 });
484 }
485
486 Ok(())
487}
488
489fn create_cation_residue(id: i32, cation: Cation, pos: Point) -> Residue {
501 let mut res = Residue::new(id, None, cation.name(), None, ResidueCategory::Ion);
502 res.add_atom(Atom::new(cation.name(), cation.element(), pos));
503 res
504}
505
506fn create_anion_residue(id: i32, anion: Anion, pos: Point) -> Residue {
518 let mut res = Residue::new(id, None, anion.name(), None, ResidueCategory::Ion);
519 res.add_atom(Atom::new(anion.name(), anion.element(), pos));
520 res
521}
522
523fn build_rng(config: &SolvateConfig) -> StdRng {
533 if let Some(seed) = config.rng_seed {
534 StdRng::seed_from_u64(seed)
535 } else {
536 StdRng::from_os_rng()
537 }
538}
539
540fn next_solvent_chain_id(structure: &Structure) -> String {
550 const BASE_ID: &str = "W";
551 if structure.chain(BASE_ID).is_none() {
552 return BASE_ID.to_string();
553 }
554
555 let mut index = 1;
556 loop {
557 let candidate = format!("{}{}", BASE_ID, index);
558 if structure.chain(&candidate).is_none() {
559 return candidate;
560 }
561 index += 1;
562 }
563}
564
#[cfg(test)]
mod tests {
    // Unit tests for solvation: cleanup and repositioning, lattice filling, ionization, error
    // reporting and solvent chain naming.
    use super::*;
    use crate::model::{
        atom::Atom,
        chain::Chain,
        residue::Residue,
        structure::Structure,
        types::{Element, Point, ResidueCategory, StandardResidue},
    };

    #[test]
    fn removes_existing_solvent_and_repositions_solute() {
        let mut structure = Structure::new();

        let mut chain_a = Chain::new("A");
        let mut residue = Residue::new(
            1,
            None,
            "ALA",
            Some(StandardResidue::ALA),
            ResidueCategory::Standard,
        );
        residue.add_atom(Atom::new("CA", Element::C, Point::new(1.0, 2.0, 3.0)));
        residue.add_atom(Atom::new("CB", Element::C, Point::new(3.0, 4.0, 5.0)));
        chain_a.add_residue(residue);
        structure.add_chain(chain_a);

        let mut solvent_chain = Chain::new("W");
        let mut existing_water = Residue::new(
            999,
            None,
            "HOH",
            Some(StandardResidue::HOH),
            ResidueCategory::Standard,
        );
        existing_water.add_atom(Atom::new("O", Element::O, Point::new(20.0, 20.0, 20.0)));
        solvent_chain.add_residue(existing_water);
        structure.add_chain(solvent_chain);

        let mut ion_chain = Chain::new("I");
        let mut ion = Residue::new(1000, None, "NA", None, ResidueCategory::Ion);
        ion.add_atom(Atom::new("NA", Element::Na, Point::new(25.0, 25.0, 25.0)));
        ion_chain.add_residue(ion);
        structure.add_chain(ion_chain);

        let config = SolvateConfig {
            margin: 5.0,
            water_spacing: 6.0,
            vdw_cutoff: 1.5,
            remove_existing: true,
            cations: vec![],
            anions: vec![],
            target_charge: 0,
            rng_seed: Some(42),
        };

        solvate_structure(&mut structure, &config).expect("solvation should succeed");

        let solute_chain = structure.chain("A").expect("solute chain");
        let mut min_coords = (f64::MAX, f64::MAX, f64::MAX);
        for atom in solute_chain.iter_atoms() {
            min_coords.0 = min_coords.0.min(atom.pos.x);
            min_coords.1 = min_coords.1.min(atom.pos.y);
            min_coords.2 = min_coords.2.min(atom.pos.z);
        }

        assert!((min_coords.0 - config.margin).abs() < 1e-6);
        assert!((min_coords.1 - config.margin).abs() < 1e-6);
        assert!((min_coords.2 - config.margin).abs() < 1e-6);

        // Solute spans 2 units per axis, plus 2 * margin.
        let box_vectors = structure.box_vectors.expect("box vectors");
        assert!((box_vectors[0][0] - 12.0).abs() < 1e-6);
        assert!((box_vectors[1][1] - 12.0).abs() < 1e-6);
        assert!((box_vectors[2][2] - 12.0).abs() < 1e-6);

        let has_legacy_ids = structure
            .iter_chains()
            .flat_map(|chain| chain.iter_residues())
            .any(|res| res.id == 999 || res.id == 1000);
        assert!(!has_legacy_ids);

        let solvent_residues: Vec<_> = structure
            .iter_chains()
            .filter(|chain| chain.id.starts_with('W'))
            .flat_map(|chain| chain.iter_residues())
            .filter(|res| res.standard_name == Some(StandardResidue::HOH))
            .collect();
        assert!(!solvent_residues.is_empty());
    }

    #[test]
    fn populates_expected_number_of_waters_for_uniform_grid() {
        let mut structure = Structure::new();
        let mut chain = Chain::new("A");
        let mut residue = Residue::new(
            1,
            None,
            "GLY",
            Some(StandardResidue::GLY),
            ResidueCategory::Standard,
        );
        residue.add_atom(Atom::new("CA", Element::C, Point::origin()));
        chain.add_residue(residue);
        structure.add_chain(chain);

        let config = SolvateConfig {
            margin: 4.0,
            water_spacing: 4.0,
            vdw_cutoff: 1.0,
            remove_existing: true,
            cations: vec![],
            anions: vec![],
            target_charge: 0,
            rng_seed: Some(7),
        };

        solvate_structure(&mut structure, &config).expect("solvation should succeed");

        // An 8-unit box with spacing 4 gives lattice points at 2 and 6 on each axis: 2^3 = 8.
        let water_count = structure
            .iter_chains()
            .flat_map(|chain| chain.iter_residues())
            .filter(|res| res.standard_name == Some(StandardResidue::HOH))
            .count();

        assert_eq!(water_count, 8);
    }

    #[test]
    fn replaces_waters_with_anions_to_match_target_charge() {
        let lys_charge = db::get_template("LYS").expect("LYS template").charge();
        assert!(
            lys_charge > 0,
            "Test expects positively charged LYS template"
        );

        let mut structure = Structure::new();
        let mut chain = Chain::new("A");
        let mut residue = Residue::new(
            1,
            None,
            "LYS",
            Some(StandardResidue::LYS),
            ResidueCategory::Standard,
        );
        residue.add_atom(Atom::new("NZ", Element::N, Point::origin()));
        chain.add_residue(residue);
        structure.add_chain(chain);

        let config = SolvateConfig {
            margin: 4.0,
            water_spacing: 4.0,
            vdw_cutoff: 1.0,
            remove_existing: true,
            cations: vec![],
            anions: vec![Anion::Cl],
            target_charge: 0,
            rng_seed: Some(17),
        };

        solvate_structure(&mut structure, &config).expect("solvation should succeed");

        let ion_residues: Vec<_> = structure
            .iter_chains()
            .flat_map(|chain| chain.iter_residues())
            .filter(|res| res.category == ResidueCategory::Ion)
            .collect();

        assert_eq!(ion_residues.len() as i32, lys_charge);
        assert!(ion_residues.iter().all(|res| res.name == "CL"));
    }

    #[test]
    fn returns_box_too_small_when_insufficient_waters_for_target_charge() {
        let gly_charge = db::get_template("GLY").expect("GLY template").charge();
        assert_eq!(gly_charge, 0, "GLY should be neutral for this test");

        let mut structure = Structure::new();
        let mut chain = Chain::new("A");
        let mut residue = Residue::new(
            1,
            None,
            "GLY",
            Some(StandardResidue::GLY),
            ResidueCategory::Standard,
        );
        residue.add_atom(Atom::new("CA", Element::C, Point::origin()));
        chain.add_residue(residue);
        structure.add_chain(chain);

        // A 4-unit box with spacing 7 fits exactly one water, but +2 needs two Na ions.
        let config = SolvateConfig {
            margin: 2.0,
            water_spacing: 7.0,
            vdw_cutoff: 0.1,
            remove_existing: true,
            cations: vec![Cation::Na],
            anions: vec![],
            target_charge: 2,
            rng_seed: Some(5),
        };

        let result = solvate_structure(&mut structure, &config);
        assert!(matches!(result, Err(Error::BoxTooSmall)));
    }

    #[test]
    fn reports_ionization_failure_when_required_ion_type_is_missing() {
        let mut structure = Structure::new();
        let mut chain = Chain::new("A");
        let mut residue = Residue::new(
            1,
            None,
            "GLY",
            Some(StandardResidue::GLY),
            ResidueCategory::Standard,
        );
        residue.add_atom(Atom::new("CA", Element::C, Point::origin()));
        chain.add_residue(residue);
        structure.add_chain(chain);

        // A positive target needs cations, but only anions are offered. The box holds eight
        // waters, so the failure must be attributed to the ion types, not to the box size.
        let config = SolvateConfig {
            margin: 4.0,
            water_spacing: 4.0,
            vdw_cutoff: 1.0,
            remove_existing: true,
            cations: vec![],
            anions: vec![Anion::Cl],
            target_charge: 1,
            rng_seed: Some(11),
        };

        let result = solvate_structure(&mut structure, &config);
        assert!(matches!(result, Err(Error::IonizationFailed { .. })));
    }

    #[test]
    fn keeps_existing_solvent_chain_and_uses_next_free_id() {
        let mut structure = Structure::new();

        let mut chain_a = Chain::new("A");
        let mut residue = Residue::new(
            1,
            None,
            "GLY",
            Some(StandardResidue::GLY),
            ResidueCategory::Standard,
        );
        residue.add_atom(Atom::new("CA", Element::C, Point::origin()));
        chain_a.add_residue(residue);
        structure.add_chain(chain_a);

        let mut existing = Chain::new("W");
        let mut water = Residue::new(
            1,
            None,
            "HOH",
            Some(StandardResidue::HOH),
            ResidueCategory::Standard,
        );
        water.add_atom(Atom::new("O", Element::O, Point::new(8.0, 8.0, 8.0)));
        existing.add_residue(water);
        structure.add_chain(existing);

        let config = SolvateConfig {
            margin: 4.0,
            water_spacing: 4.0,
            vdw_cutoff: 1.0,
            remove_existing: false,
            cations: vec![],
            anions: vec![],
            target_charge: 0,
            rng_seed: Some(3),
        };

        solvate_structure(&mut structure, &config).expect("solvation should succeed");

        let preserved = structure
            .chain("W")
            .expect("pre-existing solvent chain is kept");
        assert_eq!(preserved.iter_residues().count(), 1);

        let added = structure
            .chain("W1")
            .expect("new solvent goes into the next free chain id");
        assert!(added.iter_residues().next().is_some());
        assert!(added
            .iter_residues()
            .all(|res| res.standard_name == Some(StandardResidue::HOH)));
    }
}