1use crate::db;
8use crate::model::{
9 atom::Atom,
10 chain::Chain,
11 residue::Residue,
12 structure::Structure,
13 types::{Element, Point, ResidueCategory, StandardResidue},
14};
15use crate::ops::error::Error;
16use nalgebra::{Rotation3, Vector3};
17use rand::rngs::StdRng;
18use rand::seq::{IndexedRandom, SliceRandom};
19use rand::{Rng, SeedableRng};
20use std::collections::HashMap;
21
/// Positively charged ion species that solvation can substitute for water molecules.
///
/// The element, formal charge and residue/atom label of each species are provided by
/// the `element`, `charge` and `name` methods of this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Cation {
    /// Sodium ion (charge +1, label `"NA"`).
    Na,
    /// Potassium ion (charge +1, label `"K"`).
    K,
    /// Magnesium ion (charge +2, label `"MG"`).
    Mg,
    /// Calcium ion (charge +2, label `"CA"`).
    Ca,
    /// Lithium ion (charge +1, label `"LI"`).
    Li,
    /// Zinc ion (charge +2, label `"ZN"`).
    Zn,
}
38
/// Negatively charged ion species that solvation can substitute for water molecules.
///
/// Every species carries a charge of -1; the element and residue/atom label are
/// provided by the `element` and `name` methods of this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Anion {
    /// Chloride ion (label `"CL"`).
    Cl,
    /// Bromide ion (label `"BR"`).
    Br,
    /// Iodide ion (label `"I"`).
    I,
    /// Fluoride ion (label `"F"`).
    F,
}
51
/// Settings that control how `solvate_structure` builds the solvent box and places ions.
///
/// Lengths are expressed in the same unit as the atom coordinates
/// (presumably Angstroms -- TODO confirm against the coordinate model).
#[derive(Debug, Clone)]
pub struct SolvateConfig {
    /// Padding added on every side of the solute's bounding box; the solute is shifted
    /// so that its minimum corner ends up at `(margin, margin, margin)`.
    pub margin: f64,
    /// Step of the cubic lattice on which candidate water oxygens are placed. The first
    /// lattice point along each axis sits at half a step from the box origin.
    pub water_spacing: f64,
    /// A candidate water is skipped when its oxygen would be closer than this distance
    /// to any non-hydrogen atom already present in the structure.
    pub vdw_cutoff: f64,
    /// When `true`, existing `HOH` residues and residues categorised as ions are removed
    /// (and emptied chains pruned) before solvating.
    pub remove_existing: bool,
    /// Cation species to pick from at random when positive charge must be added.
    pub cations: Vec<Cation>,
    /// Anion species to pick from at random when negative charge must be added.
    pub anions: Vec<Anion>,
    /// Net charge the system should have once waters have been replaced by ions.
    pub target_charge: i32,
    /// Seed for the random number generator; `None` seeds it from the operating system.
    pub rng_seed: Option<u64>,
}
72
73impl Default for SolvateConfig {
74 fn default() -> Self {
76 Self {
77 margin: 10.0,
78 water_spacing: 3.1,
79 vdw_cutoff: 2.4,
80 remove_existing: true,
81 cations: vec![Cation::Na],
82 anions: vec![Anion::Cl],
83 target_charge: 0,
84 rng_seed: None,
85 }
86 }
87}
88
89impl Cation {
90 pub fn element(&self) -> Element {
96 match self {
97 Cation::Na => Element::Na,
98 Cation::K => Element::K,
99 Cation::Mg => Element::Mg,
100 Cation::Ca => Element::Ca,
101 Cation::Li => Element::Li,
102 Cation::Zn => Element::Zn,
103 }
104 }
105
106 pub fn charge(&self) -> i32 {
112 match self {
113 Cation::Na | Cation::K | Cation::Li => 1,
114 Cation::Mg | Cation::Ca | Cation::Zn => 2,
115 }
116 }
117
118 pub fn name(&self) -> &'static str {
124 match self {
125 Cation::Na => "NA",
126 Cation::K => "K",
127 Cation::Mg => "MG",
128 Cation::Ca => "CA",
129 Cation::Li => "LI",
130 Cation::Zn => "ZN",
131 }
132 }
133}
134
135impl Anion {
136 pub fn element(&self) -> Element {
142 match self {
143 Anion::Cl => Element::Cl,
144 Anion::Br => Element::Br,
145 Anion::I => Element::I,
146 Anion::F => Element::F,
147 }
148 }
149
150 pub fn charge(&self) -> i32 {
156 -1
157 }
158
159 pub fn name(&self) -> &'static str {
165 match self {
166 Anion::Cl => "CL",
167 Anion::Br => "BR",
168 Anion::I => "I",
169 Anion::F => "F",
170 }
171 }
172}
173
174pub fn solvate_structure(structure: &mut Structure, config: &SolvateConfig) -> Result<(), Error> {
195 if config.remove_existing {
196 structure.retain_residues(|_chain_id, res| {
197 let is_water = res.standard_name == Some(StandardResidue::HOH);
198 let is_ion = res.category == ResidueCategory::Ion;
199 !is_water && !is_ion
200 });
201 structure.prune_empty_chains();
202 }
203
204 let solvent_chain_id = next_solvent_chain_id(structure);
205 let mut rng = build_rng(config);
206
207 let (min_bound, max_bound) = calculate_bounds(structure);
208 let size = max_bound - min_bound;
209
210 let box_dim = size
211 + Vector3::new(
212 config.margin * 2.0,
213 config.margin * 2.0,
214 config.margin * 2.0,
215 );
216
217 structure.box_vectors = Some([
218 [box_dim.x, 0.0, 0.0],
219 [0.0, box_dim.y, 0.0],
220 [0.0, 0.0, box_dim.z],
221 ]);
222
223 let target_origin = Vector3::new(config.margin, config.margin, config.margin);
224 let translation = target_origin - min_bound.coords;
225
226 translate_structure(structure, &translation);
227
228 let grid = SpatialGrid::new(structure, 4.0);
229
230 let mut solvent_chain = Chain::new(&solvent_chain_id);
231 let mut water_positions = Vec::new();
232
233 let water_tmpl = db::get_template("HOH").ok_or(Error::MissingInternalTemplate {
234 res_name: "HOH".to_string(),
235 })?;
236 let water_name = water_tmpl.name();
237 let water_standard = water_tmpl.standard_name();
238
239 let tmpl_o_pos = water_tmpl
240 .heavy_atoms()
241 .find(|(n, _, _)| *n == "O")
242 .map(|(_, _, p)| p)
243 .unwrap_or(Point::origin());
244
245 let mut z = config.water_spacing / 2.0;
246 let mut res_id_counter = 1;
247
248 while z < box_dim.z {
249 let mut y = config.water_spacing / 2.0;
250 while y < box_dim.y {
251 let mut x = config.water_spacing / 2.0;
252 while x < box_dim.x {
253 let candidate_pos = Point::new(x, y, z);
254
255 if !grid.has_clash(&candidate_pos, config.vdw_cutoff) {
256 let rotation = Rotation3::from_axis_angle(
257 &Vector3::y_axis(),
258 rng.random_range(0.0..std::f64::consts::TAU),
259 ) * Rotation3::from_axis_angle(
260 &Vector3::x_axis(),
261 rng.random_range(0.0..std::f64::consts::TAU),
262 );
263
264 let mut residue = Residue::new(
265 res_id_counter,
266 None,
267 water_name,
268 Some(water_standard),
269 ResidueCategory::Standard,
270 );
271
272 let final_o_pos = candidate_pos;
273 residue.add_atom(Atom::new("O", Element::O, final_o_pos));
274
275 for (h_name, h_pos, _) in water_tmpl.hydrogens() {
276 let local_vec = h_pos - tmpl_o_pos;
277 let rotated_vec = rotation * local_vec;
278 residue.add_atom(Atom::new(h_name, Element::H, final_o_pos + rotated_vec));
279 }
280
281 solvent_chain.add_residue(residue);
282 water_positions.push(res_id_counter);
283 res_id_counter += 1;
284 }
285
286 x += config.water_spacing;
287 }
288 y += config.water_spacing;
289 }
290 z += config.water_spacing;
291 }
292
293 replace_with_ions(
294 structure,
295 &mut solvent_chain,
296 &mut water_positions,
297 config,
298 &mut rng,
299 )?;
300
301 if !solvent_chain.is_empty() {
302 structure.add_chain(solvent_chain);
303 }
304
305 Ok(())
306}
307
308fn calculate_bounds(structure: &Structure) -> (Point, Point) {
318 let mut min = Point::new(f64::MAX, f64::MAX, f64::MAX);
319 let mut max = Point::new(f64::MIN, f64::MIN, f64::MIN);
320 let mut count = 0;
321
322 for atom in structure.iter_atoms() {
323 min.x = min.x.min(atom.pos.x);
324 min.y = min.y.min(atom.pos.y);
325 min.z = min.z.min(atom.pos.z);
326 max.x = max.x.max(atom.pos.x);
327 max.y = max.y.max(atom.pos.y);
328 max.z = max.z.max(atom.pos.z);
329 count += 1;
330 }
331
332 if count == 0 {
333 return (Point::origin(), Point::origin());
334 }
335
336 (min, max)
337}
338
339fn translate_structure(structure: &mut Structure, vec: &Vector3<f64>) {
346 for atom in structure.iter_atoms_mut() {
347 atom.translate_by(vec);
348 }
349}
350
351fn calculate_solute_charge(structure: &Structure) -> i32 {
361 let mut charge = 0;
362 for chain in structure.iter_chains() {
363 for residue in chain.iter_residues() {
364 if let Some(tmpl) = db::get_template(&residue.name) {
365 charge += tmpl.charge();
366 } else if residue.category == ResidueCategory::Ion {
367 match residue.name.as_str() {
368 "NA" | "K" | "LI" => charge += 1,
369 "MG" | "CA" | "ZN" => charge += 2,
370 "CL" | "BR" | "I" | "F" => charge -= 1,
371 _ => {}
372 }
373 }
374 }
375 }
376 charge
377}
378
379fn replace_with_ions(
398 structure: &Structure,
399 solvent_chain: &mut Chain,
400 water_indices: &mut Vec<i32>,
401 config: &SolvateConfig,
402 rng: &mut impl Rng,
403) -> Result<(), Error> {
404 if config.cations.is_empty() && config.anions.is_empty() {
405 return Ok(());
406 }
407
408 let current_charge = calculate_solute_charge(structure);
409 let mut charge_diff = config.target_charge - current_charge;
410
411 water_indices.shuffle(rng);
412
413 let mut attempts = 0;
414 let max_attempts = water_indices.len();
415
416 while charge_diff != 0 && attempts < max_attempts {
417 if let Some(res_id) = water_indices.pop() {
418 let residue = solvent_chain.residue_mut(res_id, None).unwrap();
419 let pos = residue.atom("O").unwrap().pos;
420
421 if charge_diff < 0 {
422 if let Some(anion) = config.anions.choose(rng) {
423 *residue = create_anion_residue(res_id, *anion, pos);
424 charge_diff -= anion.charge();
425 } else {
426 break;
427 }
428 } else if let Some(cation) = config.cations.choose(rng) {
429 *residue = create_cation_residue(res_id, *cation, pos);
430 charge_diff -= cation.charge();
431 } else {
432 break;
433 }
434 }
435 attempts += 1;
436 }
437
438 if charge_diff != 0 {
439 if water_indices.is_empty() {
440 return Err(Error::BoxTooSmall);
441 }
442
443 return Err(Error::IonizationFailed {
444 details: format!(
445 "Could not reach target charge {}. Remaining diff: {}. Check if proper ion types are provided.",
446 config.target_charge, charge_diff
447 ),
448 });
449 }
450
451 Ok(())
452}
453
454fn create_cation_residue(id: i32, cation: Cation, pos: Point) -> Residue {
466 let mut res = Residue::new(id, None, cation.name(), None, ResidueCategory::Ion);
467 res.add_atom(Atom::new(cation.name(), cation.element(), pos));
468 res
469}
470
471fn create_anion_residue(id: i32, anion: Anion, pos: Point) -> Residue {
483 let mut res = Residue::new(id, None, anion.name(), None, ResidueCategory::Ion);
484 res.add_atom(Atom::new(anion.name(), anion.element(), pos));
485 res
486}
487
488fn build_rng(config: &SolvateConfig) -> StdRng {
498 if let Some(seed) = config.rng_seed {
499 StdRng::seed_from_u64(seed)
500 } else {
501 StdRng::from_os_rng()
502 }
503}
504
505fn next_solvent_chain_id(structure: &Structure) -> String {
515 const BASE_ID: &str = "W";
516 if structure.chain(BASE_ID).is_none() {
517 return BASE_ID.to_string();
518 }
519
520 let mut index = 1;
521 loop {
522 let candidate = format!("{}{}", BASE_ID, index);
523 if structure.chain(&candidate).is_none() {
524 return candidate;
525 }
526 index += 1;
527 }
528}
529
530struct SpatialGrid {
532 cell_size: f64,
533 cells: HashMap<(isize, isize, isize), Vec<Point>>,
534}
535
536impl SpatialGrid {
537 fn new(structure: &Structure, cell_size: f64) -> Self {
544 let mut cells: HashMap<(isize, isize, isize), Vec<Point>> = HashMap::new();
545
546 for atom in structure.iter_atoms() {
547 if atom.element == Element::H {
548 continue;
549 }
550
551 let idx = Self::get_index(atom.pos, cell_size);
552 cells.entry(idx).or_default().push(atom.pos);
553 }
554
555 Self { cell_size, cells }
556 }
557
558 fn get_index(pos: Point, size: f64) -> (isize, isize, isize) {
565 (
566 (pos.x / size).floor() as isize,
567 (pos.y / size).floor() as isize,
568 (pos.z / size).floor() as isize,
569 )
570 }
571
572 fn has_clash(&self, pos: &Point, cutoff: f64) -> bool {
583 let center_idx = Self::get_index(*pos, self.cell_size);
584 let cutoff_sq = cutoff * cutoff;
585
586 for dx in -1..=1 {
587 for dy in -1..=1 {
588 for dz in -1..=1 {
589 let idx = (center_idx.0 + dx, center_idx.1 + dy, center_idx.2 + dz);
590 if let Some(atoms) = self.cells.get(&idx) {
591 for atom_pos in atoms {
592 if nalgebra::distance_squared(pos, atom_pos) < cutoff_sq {
593 return true;
594 }
595 }
596 }
597 }
598 }
599 }
600 false
601 }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{
        atom::Atom,
        chain::Chain,
        residue::Residue,
        structure::Structure,
        types::{Element, Point, ResidueCategory, StandardResidue},
    };

    /// Builds a structure holding chain `"A"` with one standard residue (id 1) made of
    /// the given atoms.
    fn single_residue_solute(
        res_name: &'static str,
        standard: StandardResidue,
        atoms: Vec<Atom>,
    ) -> Structure {
        let mut residue = Residue::new(
            1,
            None,
            res_name,
            Some(standard),
            ResidueCategory::Standard,
        );
        for atom in atoms {
            residue.add_atom(atom);
        }

        let mut chain = Chain::new("A");
        chain.add_residue(residue);

        let mut structure = Structure::new();
        structure.add_chain(chain);
        structure
    }

    /// Pre-existing water and ions are stripped, the solute is moved to the margin,
    /// the box is sized from the solute extent, and fresh waters are added.
    #[test]
    fn removes_existing_solvent_and_repositions_solute() {
        let mut structure = single_residue_solute(
            "ALA",
            StandardResidue::ALA,
            vec![
                Atom::new("CA", Element::C, Point::new(1.0, 2.0, 3.0)),
                Atom::new("CB", Element::C, Point::new(3.0, 4.0, 5.0)),
            ],
        );

        let mut old_water = Residue::new(
            999,
            None,
            "HOH",
            Some(StandardResidue::HOH),
            ResidueCategory::Standard,
        );
        old_water.add_atom(Atom::new("O", Element::O, Point::new(20.0, 20.0, 20.0)));
        let mut old_water_chain = Chain::new("W");
        old_water_chain.add_residue(old_water);
        structure.add_chain(old_water_chain);

        let mut old_ion = Residue::new(1000, None, "NA", None, ResidueCategory::Ion);
        old_ion.add_atom(Atom::new("NA", Element::Na, Point::new(25.0, 25.0, 25.0)));
        let mut old_ion_chain = Chain::new("I");
        old_ion_chain.add_residue(old_ion);
        structure.add_chain(old_ion_chain);

        let config = SolvateConfig {
            margin: 5.0,
            water_spacing: 6.0,
            vdw_cutoff: 1.5,
            remove_existing: true,
            cations: vec![],
            anions: vec![],
            target_charge: 0,
            rng_seed: Some(42),
        };

        solvate_structure(&mut structure, &config).expect("solvation should succeed");

        // The solute's minimum corner must land exactly on the margin.
        let solute = structure.chain("A").expect("solute chain");
        let mut lowest = [f64::MAX; 3];
        for atom in solute.iter_atoms() {
            lowest[0] = lowest[0].min(atom.pos.x);
            lowest[1] = lowest[1].min(atom.pos.y);
            lowest[2] = lowest[2].min(atom.pos.z);
        }
        for &coord in &lowest {
            assert!((coord - config.margin).abs() < 1e-6);
        }

        // Solute extent is 2.0 per axis, plus 5.0 of margin on both sides.
        let box_vectors = structure.box_vectors.expect("box vectors");
        for (axis, row) in box_vectors.iter().enumerate() {
            assert!((row[axis] - 12.0).abs() < 1e-6);
        }

        let legacy_residue_survived = structure
            .iter_chains()
            .flat_map(|chain| chain.iter_residues())
            .any(|res| matches!(res.id, 999 | 1000));
        assert!(!legacy_residue_survived);

        let new_water_count = structure
            .iter_chains()
            .filter(|chain| chain.id.starts_with('W'))
            .flat_map(|chain| chain.iter_residues())
            .filter(|res| res.standard_name == Some(StandardResidue::HOH))
            .count();
        assert!(new_water_count > 0);
    }

    /// An 8x8x8 box with a 4.0 lattice holds two lattice points per axis, none of
    /// which clash with the single solute atom.
    #[test]
    fn populates_expected_number_of_waters_for_uniform_grid() {
        let mut structure = single_residue_solute(
            "GLY",
            StandardResidue::GLY,
            vec![Atom::new("CA", Element::C, Point::origin())],
        );

        let config = SolvateConfig {
            margin: 4.0,
            water_spacing: 4.0,
            vdw_cutoff: 1.0,
            remove_existing: true,
            cations: vec![],
            anions: vec![],
            target_charge: 0,
            rng_seed: Some(7),
        };

        solvate_structure(&mut structure, &config).expect("solvation should succeed");

        let water_count = structure
            .iter_chains()
            .flat_map(|chain| chain.iter_residues())
            .filter(|res| res.standard_name == Some(StandardResidue::HOH))
            .count();

        assert_eq!(water_count, 8);
    }

    /// A positively charged solute is neutralised with exactly as many chlorides as
    /// its charge.
    #[test]
    fn replaces_waters_with_anions_to_match_target_charge() {
        let lys_charge = db::get_template("LYS").expect("LYS template").charge();
        assert!(
            lys_charge > 0,
            "Test expects positively charged LYS template"
        );

        let mut structure = single_residue_solute(
            "LYS",
            StandardResidue::LYS,
            vec![Atom::new("NZ", Element::N, Point::origin())],
        );

        let config = SolvateConfig {
            margin: 4.0,
            water_spacing: 4.0,
            vdw_cutoff: 1.0,
            remove_existing: true,
            cations: vec![],
            anions: vec![Anion::Cl],
            target_charge: 0,
            rng_seed: Some(17),
        };

        solvate_structure(&mut structure, &config).expect("solvation should succeed");

        let ion_names: Vec<_> = structure
            .iter_chains()
            .flat_map(|chain| chain.iter_residues())
            .filter(|res| res.category == ResidueCategory::Ion)
            .map(|res| &res.name)
            .collect();

        assert_eq!(ion_names.len() as i32, lys_charge);
        assert!(ion_names.iter().all(|name| *name == "CL"));
    }

    /// With room for a single water, a target charge of +2 cannot be reached using
    /// monovalent cations.
    #[test]
    fn returns_box_too_small_when_insufficient_waters_for_target_charge() {
        let gly_charge = db::get_template("GLY").expect("GLY template").charge();
        assert_eq!(gly_charge, 0, "GLY should be neutral for this test");

        let mut structure = single_residue_solute(
            "GLY",
            StandardResidue::GLY,
            vec![Atom::new("CA", Element::C, Point::origin())],
        );

        let config = SolvateConfig {
            margin: 2.0,
            water_spacing: 7.0,
            vdw_cutoff: 0.1,
            remove_existing: true,
            cations: vec![Cation::Na],
            anions: vec![],
            target_charge: 2,
            rng_seed: Some(5),
        };

        let outcome = solvate_structure(&mut structure, &config);
        assert!(matches!(outcome, Err(Error::BoxTooSmall)));
    }
}