// bio_forge/ops/solvate.rs

1//! Constructs solvent boxes around solute structures and optionally neutralizes charge.
2//!
3//! The solvation pipeline packs waters on a configurable grid, recenters the solute, sets
4//! orthorhombic box vectors, and replaces selected waters with ions to reach a desired net
5//! charge. All randomization respects deterministic seeds for reproducibility.
6
7use crate::db;
8use crate::model::{
9    atom::Atom,
10    chain::Chain,
11    residue::Residue,
12    structure::Structure,
13    types::{Element, Point, ResidueCategory, StandardResidue},
14};
15use crate::ops::error::Error;
16use nalgebra::{Rotation3, Vector3};
17use rand::rngs::StdRng;
18use rand::seq::{IndexedRandom, SliceRandom};
19use rand::{Rng, SeedableRng};
20use std::collections::HashMap;
21
/// Cation species that may be swapped in for water molecules during neutralization.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Cation {
    /// Sodium, Na⁺.
    Na,
    /// Potassium, K⁺.
    K,
    /// Magnesium, Mg²⁺.
    Mg,
    /// Calcium, Ca²⁺.
    Ca,
    /// Lithium, Li⁺.
    Li,
    /// Zinc, Zn²⁺.
    Zn,
}
38
/// Anion species that may be swapped in for water molecules during neutralization.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Anion {
    /// Chloride, Cl⁻.
    Cl,
    /// Bromide, Br⁻.
    Br,
    /// Iodide, I⁻.
    I,
    /// Fluoride, F⁻.
    F,
}
51
52/// Configuration parameters controlling solvent placement and ionization.
53#[derive(Debug, Clone)]
54pub struct SolvateConfig {
55    /// Margin (Å) added in every direction around the solute before packing solvent.
56    pub margin: f64,
57    /// Distance (Å) between candidate water grid points.
58    pub water_spacing: f64,
59    /// Minimum separation (Å) between new waters and existing heavy atoms.
60    pub vdw_cutoff: f64,
61    /// Whether to remove pre-existing solvent/ions before generating the new box.
62    pub remove_existing: bool,
63    /// Cation species available for ionic substitution.
64    pub cations: Vec<Cation>,
65    /// Anion species available for ionic substitution.
66    pub anions: Vec<Anion>,
67    /// Target total charge after solvating (solute + ions + water).
68    pub target_charge: i32,
69    /// Optional RNG seed for deterministic solvent orientation.
70    pub rng_seed: Option<u64>,
71}
72
73impl Default for SolvateConfig {
74    /// Produces a rectangular water box with 10 Å padding and physiological NaCl by default.
75    fn default() -> Self {
76        Self {
77            margin: 10.0,
78            water_spacing: 3.1,
79            vdw_cutoff: 2.4,
80            remove_existing: true,
81            cations: vec![Cation::Na],
82            anions: vec![Anion::Cl],
83            target_charge: 0,
84            rng_seed: None,
85        }
86    }
87}
88
89impl Cation {
90    /// Returns the elemental identity associated with the cation.
91    ///
92    /// # Returns
93    ///
94    /// Matching [`Element`] variant for the ion.
95    pub fn element(&self) -> Element {
96        match self {
97            Cation::Na => Element::Na,
98            Cation::K => Element::K,
99            Cation::Mg => Element::Mg,
100            Cation::Ca => Element::Ca,
101            Cation::Li => Element::Li,
102            Cation::Zn => Element::Zn,
103        }
104    }
105
106    /// Reports the integer charge for the cation.
107    ///
108    /// # Returns
109    ///
110    /// `1` for monovalent ions, `2` for divalent ones.
111    pub fn charge(&self) -> i32 {
112        match self {
113            Cation::Na | Cation::K | Cation::Li => 1,
114            Cation::Mg | Cation::Ca | Cation::Zn => 2,
115        }
116    }
117
118    /// Provides the residue name used when instantiating ion residues.
119    ///
120    /// # Returns
121    ///
122    /// Uppercase residue/atom name recognized by biomolecular formats.
123    pub fn name(&self) -> &'static str {
124        match self {
125            Cation::Na => "NA",
126            Cation::K => "K",
127            Cation::Mg => "MG",
128            Cation::Ca => "CA",
129            Cation::Li => "LI",
130            Cation::Zn => "ZN",
131        }
132    }
133}
134
135impl Anion {
136    /// Returns the elemental identity associated with the anion.
137    ///
138    /// # Returns
139    ///
140    /// Matching [`Element`] for the ion.
141    pub fn element(&self) -> Element {
142        match self {
143            Anion::Cl => Element::Cl,
144            Anion::Br => Element::Br,
145            Anion::I => Element::I,
146            Anion::F => Element::F,
147        }
148    }
149
150    /// Reports the integer charge for the anion.
151    ///
152    /// # Returns
153    ///
154    /// Always returns `-1` since only monovalent anions are supported.
155    pub fn charge(&self) -> i32 {
156        -1
157    }
158
159    /// Provides the residue name used when instantiating the anion residue.
160    ///
161    /// # Returns
162    ///
163    /// Uppercase residue code recognized by biomolecular formats.
164    pub fn name(&self) -> &'static str {
165        match self {
166            Anion::Cl => "CL",
167            Anion::Br => "BR",
168            Anion::I => "I",
169            Anion::F => "F",
170        }
171    }
172}
173
174/// Builds a solvent box, translates the solute to the padded origin, and inserts ions.
175///
176/// The function removes existing solvent when requested, computes an orthorhombic box from
177/// the solute bounds plus margins, packs waters on a regular grid while randomizing orientation,
178/// and finally replaces selected waters with ions to reach the target charge.
179///
180/// # Arguments
181///
182/// * `structure` - Mutable structure containing the solute atoms to surround with solvent.
183/// * `config` - Parameters controlling padding, spacing, ion species, and RNG seeding.
184///
185/// # Returns
186///
187/// `Ok(())` when solvent and ions are generated successfully.
188///
189/// # Errors
190///
191/// Returns [`Error::MissingInternalTemplate`] if the water template is absent,
192/// [`Error::BoxTooSmall`] when insufficient waters remain for ion swapping, or
193/// [`Error::IonizationFailed`] when the requested charge cannot be achieved.
194pub fn solvate_structure(structure: &mut Structure, config: &SolvateConfig) -> Result<(), Error> {
195    if config.remove_existing {
196        structure.retain_residues(|_chain_id, res| {
197            let is_water = res.standard_name == Some(StandardResidue::HOH);
198            let is_ion = res.category == ResidueCategory::Ion;
199            !is_water && !is_ion
200        });
201        structure.prune_empty_chains();
202    }
203
204    let solvent_chain_id = next_solvent_chain_id(structure);
205    let mut rng = build_rng(config);
206
207    let (min_bound, max_bound) = calculate_bounds(structure);
208    let size = max_bound - min_bound;
209
210    let box_dim = size
211        + Vector3::new(
212            config.margin * 2.0,
213            config.margin * 2.0,
214            config.margin * 2.0,
215        );
216
217    structure.box_vectors = Some([
218        [box_dim.x, 0.0, 0.0],
219        [0.0, box_dim.y, 0.0],
220        [0.0, 0.0, box_dim.z],
221    ]);
222
223    let target_origin = Vector3::new(config.margin, config.margin, config.margin);
224    let translation = target_origin - min_bound.coords;
225
226    translate_structure(structure, &translation);
227
228    let grid = SpatialGrid::new(structure, 4.0);
229
230    let mut solvent_chain = Chain::new(&solvent_chain_id);
231    let mut water_positions = Vec::new();
232
233    let water_tmpl = db::get_template("HOH").ok_or(Error::MissingInternalTemplate {
234        res_name: "HOH".to_string(),
235    })?;
236    let water_name = water_tmpl.name();
237    let water_standard = water_tmpl.standard_name();
238
239    let tmpl_o_pos = water_tmpl
240        .heavy_atoms()
241        .find(|(n, _, _)| *n == "O")
242        .map(|(_, _, p)| p)
243        .unwrap_or(Point::origin());
244
245    let mut z = config.water_spacing / 2.0;
246    let mut res_id_counter = 1;
247
248    while z < box_dim.z {
249        let mut y = config.water_spacing / 2.0;
250        while y < box_dim.y {
251            let mut x = config.water_spacing / 2.0;
252            while x < box_dim.x {
253                let candidate_pos = Point::new(x, y, z);
254
255                if !grid.has_clash(&candidate_pos, config.vdw_cutoff) {
256                    let rotation = Rotation3::from_axis_angle(
257                        &Vector3::y_axis(),
258                        rng.random_range(0.0..std::f64::consts::TAU),
259                    ) * Rotation3::from_axis_angle(
260                        &Vector3::x_axis(),
261                        rng.random_range(0.0..std::f64::consts::TAU),
262                    );
263
264                    let mut residue = Residue::new(
265                        res_id_counter,
266                        None,
267                        water_name,
268                        Some(water_standard),
269                        ResidueCategory::Standard,
270                    );
271
272                    let final_o_pos = candidate_pos;
273                    residue.add_atom(Atom::new("O", Element::O, final_o_pos));
274
275                    for (h_name, h_pos, _) in water_tmpl.hydrogens() {
276                        let local_vec = h_pos - tmpl_o_pos;
277                        let rotated_vec = rotation * local_vec;
278                        residue.add_atom(Atom::new(h_name, Element::H, final_o_pos + rotated_vec));
279                    }
280
281                    solvent_chain.add_residue(residue);
282                    water_positions.push(res_id_counter);
283                    res_id_counter += 1;
284                }
285
286                x += config.water_spacing;
287            }
288            y += config.water_spacing;
289        }
290        z += config.water_spacing;
291    }
292
293    replace_with_ions(
294        structure,
295        &mut solvent_chain,
296        &mut water_positions,
297        config,
298        &mut rng,
299    )?;
300
301    if !solvent_chain.is_empty() {
302        structure.add_chain(solvent_chain);
303    }
304
305    Ok(())
306}
307
308/// Computes axis-aligned bounding box for all atoms in the structure.
309///
310/// # Arguments
311///
312/// * `structure` - Structure whose atoms will be scanned.
313///
314/// # Returns
315///
316/// Tuple of `(min_point, max_point)` representing the bounding box.
317fn calculate_bounds(structure: &Structure) -> (Point, Point) {
318    let mut min = Point::new(f64::MAX, f64::MAX, f64::MAX);
319    let mut max = Point::new(f64::MIN, f64::MIN, f64::MIN);
320    let mut count = 0;
321
322    for atom in structure.iter_atoms() {
323        min.x = min.x.min(atom.pos.x);
324        min.y = min.y.min(atom.pos.y);
325        min.z = min.z.min(atom.pos.z);
326        max.x = max.x.max(atom.pos.x);
327        max.y = max.y.max(atom.pos.y);
328        max.z = max.z.max(atom.pos.z);
329        count += 1;
330    }
331
332    if count == 0 {
333        return (Point::origin(), Point::origin());
334    }
335
336    (min, max)
337}
338
339/// Translates every atom in the structure by the provided vector.
340///
341/// # Arguments
342///
343/// * `structure` - Structure to move.
344/// * `vec` - Translation vector applied to each atom.
345fn translate_structure(structure: &mut Structure, vec: &Vector3<f64>) {
346    for atom in structure.iter_atoms_mut() {
347        atom.translate_by(vec);
348    }
349}
350
351/// Estimates the current solute charge using template charges and known ions.
352///
353/// # Arguments
354///
355/// * `structure` - Structure whose charge should be measured.
356///
357/// # Returns
358///
359/// Integer charge accumulated from templates and residue labels.
360fn calculate_solute_charge(structure: &Structure) -> i32 {
361    let mut charge = 0;
362    for chain in structure.iter_chains() {
363        for residue in chain.iter_residues() {
364            if let Some(tmpl) = db::get_template(&residue.name) {
365                charge += tmpl.charge();
366            } else if residue.category == ResidueCategory::Ion {
367                match residue.name.as_str() {
368                    "NA" | "K" | "LI" => charge += 1,
369                    "MG" | "CA" | "ZN" => charge += 2,
370                    "CL" | "BR" | "I" | "F" => charge -= 1,
371                    _ => {}
372                }
373            }
374        }
375    }
376    charge
377}
378
379/// Replaces selected waters with ions to reach the requested total charge.
380///
381/// # Arguments
382///
383/// * `structure` - Current solute (used for charge estimation).
384/// * `solvent_chain` - Chain containing newly created solvent residues.
385/// * `water_indices` - Residue IDs that can be substituted with ions.
386/// * `config` - Solvation configuration specifying ion species and target charge.
387/// * `rng` - Random number generator for stochastic selection.
388///
389/// # Returns
390///
391/// `Ok(())` when the charge target is hit or ions are not requested.
392///
393/// # Errors
394///
395/// Returns [`Error::BoxTooSmall`] if no waters remain to swap or
396/// [`Error::IonizationFailed`] when charge neutrality cannot be achieved.
397fn replace_with_ions(
398    structure: &Structure,
399    solvent_chain: &mut Chain,
400    water_indices: &mut Vec<i32>,
401    config: &SolvateConfig,
402    rng: &mut impl Rng,
403) -> Result<(), Error> {
404    if config.cations.is_empty() && config.anions.is_empty() {
405        return Ok(());
406    }
407
408    let current_charge = calculate_solute_charge(structure);
409    let mut charge_diff = config.target_charge - current_charge;
410
411    water_indices.shuffle(rng);
412
413    let mut attempts = 0;
414    let max_attempts = water_indices.len();
415
416    while charge_diff != 0 && attempts < max_attempts {
417        if let Some(res_id) = water_indices.pop() {
418            let residue = solvent_chain.residue_mut(res_id, None).unwrap();
419            let pos = residue.atom("O").unwrap().pos;
420
421            if charge_diff < 0 {
422                if let Some(anion) = config.anions.choose(rng) {
423                    *residue = create_anion_residue(res_id, *anion, pos);
424                    charge_diff -= anion.charge();
425                } else {
426                    break;
427                }
428            } else if let Some(cation) = config.cations.choose(rng) {
429                *residue = create_cation_residue(res_id, *cation, pos);
430                charge_diff -= cation.charge();
431            } else {
432                break;
433            }
434        }
435        attempts += 1;
436    }
437
438    if charge_diff != 0 {
439        if water_indices.is_empty() {
440            return Err(Error::BoxTooSmall);
441        }
442
443        return Err(Error::IonizationFailed {
444            details: format!(
445                "Could not reach target charge {}. Remaining diff: {}. Check if proper ion types are provided.",
446                config.target_charge, charge_diff
447            ),
448        });
449    }
450
451    Ok(())
452}
453
454/// Creates a single-ion residue for the provided cation at a given position.
455///
456/// # Arguments
457///
458/// * `id` - Residue identifier to assign.
459/// * `cation` - Ion species to instantiate.
460/// * `pos` - Coordinates where the ion will be placed.
461///
462/// # Returns
463///
464/// A residue labeled as [`ResidueCategory::Ion`].
465fn create_cation_residue(id: i32, cation: Cation, pos: Point) -> Residue {
466    let mut res = Residue::new(id, None, cation.name(), None, ResidueCategory::Ion);
467    res.add_atom(Atom::new(cation.name(), cation.element(), pos));
468    res
469}
470
471/// Creates a single-ion residue for the provided anion at a given position.
472///
473/// # Arguments
474///
475/// * `id` - Residue identifier.
476/// * `anion` - Ion species to instantiate.
477/// * `pos` - Coordinates where the ion is placed.
478///
479/// # Returns
480///
481/// A residue labeled as [`ResidueCategory::Ion`].
482fn create_anion_residue(id: i32, anion: Anion, pos: Point) -> Residue {
483    let mut res = Residue::new(id, None, anion.name(), None, ResidueCategory::Ion);
484    res.add_atom(Atom::new(anion.name(), anion.element(), pos));
485    res
486}
487
488/// Builds a seeded or OS-random generator for solvent placement.
489///
490/// # Arguments
491///
492/// * `config` - Solvation configuration containing an optional seed.
493///
494/// # Returns
495///
496/// Deterministic RNG when a seed is given; otherwise an OS-random generator.
497fn build_rng(config: &SolvateConfig) -> StdRng {
498    if let Some(seed) = config.rng_seed {
499        StdRng::seed_from_u64(seed)
500    } else {
501        StdRng::from_os_rng()
502    }
503}
504
505/// Derives the next available solvent chain identifier (W, W1, W2, ...).
506///
507/// # Arguments
508///
509/// * `structure` - Structure used to check for existing chain IDs.
510///
511/// # Returns
512///
513/// Unique chain ID for newly inserted solvent.
514fn next_solvent_chain_id(structure: &Structure) -> String {
515    const BASE_ID: &str = "W";
516    if structure.chain(BASE_ID).is_none() {
517        return BASE_ID.to_string();
518    }
519
520    let mut index = 1;
521    loop {
522        let candidate = format!("{}{}", BASE_ID, index);
523        if structure.chain(&candidate).is_none() {
524            return candidate;
525        }
526        index += 1;
527    }
528}
529
530/// Sparse spatial hash used to reject candidate water positions via clash checks.
531struct SpatialGrid {
532    cell_size: f64,
533    cells: HashMap<(isize, isize, isize), Vec<Point>>,
534}
535
536impl SpatialGrid {
537    /// Builds the grid from existing heavy atoms in the structure.
538    ///
539    /// # Arguments
540    ///
541    /// * `structure` - Structure providing the atom positions.
542    /// * `cell_size` - Edge length (Å) for each spatial bin.
543    fn new(structure: &Structure, cell_size: f64) -> Self {
544        let mut cells: HashMap<(isize, isize, isize), Vec<Point>> = HashMap::new();
545
546        for atom in structure.iter_atoms() {
547            if atom.element == Element::H {
548                continue;
549            }
550
551            let idx = Self::get_index(atom.pos, cell_size);
552            cells.entry(idx).or_default().push(atom.pos);
553        }
554
555        Self { cell_size, cells }
556    }
557
558    /// Computes the cell index for a position.
559    ///
560    /// # Arguments
561    ///
562    /// * `pos` - Coordinate being binned.
563    /// * `size` - Cell size used when building the grid.
564    fn get_index(pos: Point, size: f64) -> (isize, isize, isize) {
565        (
566            (pos.x / size).floor() as isize,
567            (pos.y / size).floor() as isize,
568            (pos.z / size).floor() as isize,
569        )
570    }
571
572    /// Tests whether placing a water at the given point would violate the cutoff.
573    ///
574    /// # Arguments
575    ///
576    /// * `pos` - Candidate water position.
577    /// * `cutoff` - Minimum allowed heavy-atom distance.
578    ///
579    /// # Returns
580    ///
581    /// `true` when any stored heavy atom is closer than the cutoff.
582    fn has_clash(&self, pos: &Point, cutoff: f64) -> bool {
583        let center_idx = Self::get_index(*pos, self.cell_size);
584        let cutoff_sq = cutoff * cutoff;
585
586        for dx in -1..=1 {
587            for dy in -1..=1 {
588                for dz in -1..=1 {
589                    let idx = (center_idx.0 + dx, center_idx.1 + dy, center_idx.2 + dz);
590                    if let Some(atoms) = self.cells.get(&idx) {
591                        for atom_pos in atoms {
592                            if nalgebra::distance_squared(pos, atom_pos) < cutoff_sq {
593                                return true;
594                            }
595                        }
596                    }
597                }
598            }
599        }
600        false
601    }
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{
        atom::Atom,
        chain::Chain,
        residue::Residue,
        structure::Structure,
        types::{Element, Point, ResidueCategory, StandardResidue},
    };

    /// Wraps one standard residue (ID 1) holding `atoms` in chain `A` of a fresh structure.
    fn single_residue_structure(
        name: &'static str,
        standard: StandardResidue,
        atoms: Vec<(&'static str, Element, Point)>,
    ) -> Structure {
        let mut residue = Residue::new(1, None, name, Some(standard), ResidueCategory::Standard);
        for (atom_name, element, pos) in atoms {
            residue.add_atom(Atom::new(atom_name, element, pos));
        }

        let mut chain = Chain::new("A");
        chain.add_residue(residue);

        let mut structure = Structure::new();
        structure.add_chain(chain);
        structure
    }

    #[test]
    fn removes_existing_solvent_and_repositions_solute() {
        let mut structure = single_residue_structure(
            "ALA",
            StandardResidue::ALA,
            vec![
                ("CA", Element::C, Point::new(1.0, 2.0, 3.0)),
                ("CB", Element::C, Point::new(3.0, 4.0, 5.0)),
            ],
        );

        let mut old_water = Residue::new(
            999,
            None,
            "HOH",
            Some(StandardResidue::HOH),
            ResidueCategory::Standard,
        );
        old_water.add_atom(Atom::new("O", Element::O, Point::new(20.0, 20.0, 20.0)));
        let mut water_chain = Chain::new("W");
        water_chain.add_residue(old_water);
        structure.add_chain(water_chain);

        let mut old_sodium = Residue::new(1000, None, "NA", None, ResidueCategory::Ion);
        old_sodium.add_atom(Atom::new("NA", Element::Na, Point::new(25.0, 25.0, 25.0)));
        let mut ion_chain = Chain::new("I");
        ion_chain.add_residue(old_sodium);
        structure.add_chain(ion_chain);

        let config = SolvateConfig {
            margin: 5.0,
            water_spacing: 6.0,
            vdw_cutoff: 1.5,
            remove_existing: true,
            cations: vec![],
            anions: vec![],
            target_charge: 0,
            rng_seed: Some(42),
        };

        solvate_structure(&mut structure, &config).expect("solvation should succeed");

        // The solute's lowest corner must land exactly on (margin, margin, margin).
        let solute = structure.chain("A").expect("solute chain");
        let mut lowest = [f64::MAX; 3];
        for atom in solute.iter_atoms() {
            lowest[0] = lowest[0].min(atom.pos.x);
            lowest[1] = lowest[1].min(atom.pos.y);
            lowest[2] = lowest[2].min(atom.pos.z);
        }
        for coord in lowest {
            assert!((coord - config.margin).abs() < 1e-6);
        }

        // Solute extent is 2 Å on every axis; adding 2 × 5 Å of margin gives a 12 Å cube.
        let box_vectors = structure.box_vectors.expect("box vectors");
        for (axis, row) in box_vectors.iter().enumerate() {
            assert!((row[axis] - 12.0).abs() < 1e-6);
        }

        // The pre-existing water (999) and ion (1000) must be gone.
        let has_legacy_ids = structure
            .iter_chains()
            .flat_map(|chain| chain.iter_residues())
            .any(|res| res.id == 999 || res.id == 1000);
        assert!(!has_legacy_ids);

        let new_waters: Vec<_> = structure
            .iter_chains()
            .filter(|chain| chain.id.starts_with('W'))
            .flat_map(|chain| chain.iter_residues())
            .filter(|res| res.standard_name == Some(StandardResidue::HOH))
            .collect();
        assert!(!new_waters.is_empty());
    }

    #[test]
    fn populates_expected_number_of_waters_for_uniform_grid() {
        let mut structure = single_residue_structure(
            "GLY",
            StandardResidue::GLY,
            vec![("CA", Element::C, Point::origin())],
        );

        let config = SolvateConfig {
            margin: 4.0,
            water_spacing: 4.0,
            vdw_cutoff: 1.0,
            remove_existing: true,
            cations: vec![],
            anions: vec![],
            target_charge: 0,
            rng_seed: Some(7),
        };

        solvate_structure(&mut structure, &config).expect("solvation should succeed");

        // An 8 Å cube sampled at 2 Å and 6 Å per axis yields 2³ sites. The solute sits at
        // (4, 4, 4), about 3.46 Å from each site, so none of them clash with a 1 Å cutoff.
        let water_count = structure
            .iter_chains()
            .flat_map(|chain| chain.iter_residues())
            .filter(|res| res.standard_name == Some(StandardResidue::HOH))
            .count();

        assert_eq!(water_count, 8);
    }

    #[test]
    fn replaces_waters_with_anions_to_match_target_charge() {
        let lys_charge = db::get_template("LYS").expect("LYS template").charge();
        assert!(
            lys_charge > 0,
            "Test expects positively charged LYS template"
        );

        let mut structure = single_residue_structure(
            "LYS",
            StandardResidue::LYS,
            vec![("NZ", Element::N, Point::origin())],
        );

        let config = SolvateConfig {
            margin: 4.0,
            water_spacing: 4.0,
            vdw_cutoff: 1.0,
            remove_existing: true,
            cations: vec![],
            anions: vec![Anion::Cl],
            target_charge: 0,
            rng_seed: Some(17),
        };

        solvate_structure(&mut structure, &config).expect("solvation should succeed");

        // Each chloride removes one unit of charge, so exactly `lys_charge` must be placed.
        let ions: Vec<_> = structure
            .iter_chains()
            .flat_map(|chain| chain.iter_residues())
            .filter(|res| res.category == ResidueCategory::Ion)
            .collect();

        assert_eq!(ions.len() as i32, lys_charge);
        assert!(ions.iter().all(|res| res.name == "CL"));
    }

    #[test]
    fn returns_box_too_small_when_insufficient_waters_for_target_charge() {
        let gly_charge = db::get_template("GLY").expect("GLY template").charge();
        assert_eq!(gly_charge, 0, "GLY should be neutral for this test");

        let mut structure = single_residue_structure(
            "GLY",
            StandardResidue::GLY,
            vec![("CA", Element::C, Point::origin())],
        );

        // A 4 Å box with 7 Å spacing holds a single site (3.5 Å per axis). That allows only
        // one Na⁺, but a +2 target needs two.
        let config = SolvateConfig {
            margin: 2.0,
            water_spacing: 7.0,
            vdw_cutoff: 0.1,
            remove_existing: true,
            cations: vec![Cation::Na],
            anions: vec![],
            target_charge: 2,
            rng_seed: Some(5),
        };

        let result = solvate_structure(&mut structure, &config);
        assert!(matches!(result, Err(Error::BoxTooSmall)));
    }
}