#![allow(clippy::excessive_precision)]
use std::{f32::consts::TAU, fs, io, path::Path, time::Instant};
use bincode::{Decode, Encode};
use bio_files::gromacs;
use lin_alg::{
f32::{Mat3 as Mat3F32, Quaternion, Vec3},
f64::{Quaternion as QuaternionF64, Vec3 as Vec3F64},
};
use rand::{Rng, distr::Uniform, rngs::ThreadRng};
use rand_distr::{Distribution, Normal};
use crate::{
AtomDynamics, ComputationDevice, MdState, MolDynamics, NATIVE_TO_KCAL,
barostat::SimBox,
partial_charge_inference::{files::load_from_bytes_bincode, save},
sa_surface,
solvent::WaterMol,
thermostat::{GAS_CONST_R, KB_A2_PS2_PER_K_PER_AMU},
};
/// Bulk liquid-water density. Combined with the `1.0e24` (Å³ per cm³) factor in
/// `WATER_MOLS_PER_VOL`, this is in g/cm³.
const WATER_DENSITY: f32 = 0.997;
/// Molar mass of water (g/mol).
const MASS_WATER: f32 = 18.015_28;
/// Avogadro constant (1/mol).
const N_A: f32 = 6.022_140_76e23;
/// Water number density in molecules per Å³ (≈ 0.0333):
/// density · N_A / (molar mass · 1e24 Å³/cm³).
const WATER_MOLS_PER_VOL: f32 = WATER_DENSITY * N_A / (MASS_WATER * 1.0e24);
/// Minimum distance (Å) between a placed water oxygen and any solute atom.
const MIN_NONWATER_DIST: f32 = 1.7;
const MIN_NONWATER_DIST_SQ: f32 = MIN_NONWATER_DIST * MIN_NONWATER_DIST;
/// Minimum O–O distance (Å) between placed waters, applied both directly and via the
/// minimum periodic image.
const MIN_WATER_O_O_DIST: f32 = 1.7;
const MIN_WATER_O_O_DIST_SQ: f32 = MIN_WATER_O_O_DIST * MIN_WATER_O_O_DIST;
/// Stricter O–O distance (Å) used by `make_water_mols` when a periodic image is closer than the
/// direct separation (disabled by its `skip_pbc_filter` argument).
const PBC_MIN_WATER_O_O_DIST: f32 = 2.8;
const PBC_MIN_WATER_O_O_DIST_SQ: f32 = PBC_MIN_WATER_O_O_DIST * PBC_MIN_WATER_O_O_DIST;
/// Number of steps run by `MdState::md_on_water_only`.
const NUM_EQUILIBRATION_STEPS: usize = 200;
/// Step size passed to `MdState::step` during that equilibration.
/// NOTE(review): presumably ps (i.e. 0.5 fs) — confirm against `MdState::step`.
const DT_EQUILIBRATION: f32 = 0.0005;
/// Bincode-encoded `WaterInitTemplate` embedded at compile time; the default template tiled by
/// `make_water_mols` when no override is supplied. (The name suggests a 60 Å box — not verified
/// here.)
pub const WATER_TEMPLATE_60A: &[u8] =
    include_bytes!("../../param_data/water_60A.water_init_template");
/// A pre-generated water configuration, tiled by `make_water_mols` to fill a simulation cell.
///
/// The vectors are parallel: index `i` in each one refers to the same molecule.
/// Bincode encodes the fields in declaration order, so that order is part of the on-disk
/// template format. Do not reorder the fields.
#[derive(Encode, Decode)]
pub struct WaterInitTemplate {
    /// Oxygen positions.
    o_posits: Vec<Vec3>,
    /// First hydrogen positions.
    h0_posits: Vec<Vec3>,
    /// Second hydrogen positions.
    h1_posits: Vec<Vec3>,
    /// Oxygen velocities.
    o_velocities: Vec<Vec3>,
    /// First hydrogen velocities.
    h0_velocities: Vec<Vec3>,
    /// Second hydrogen velocities.
    h1_velocities: Vec<Vec3>,
    /// Axis-aligned region the template covers, as (low corner, high corner); its extent is the
    /// tiling period used by `make_water_mols`.
    bounds: (Vec3, Vec3),
}
impl WaterInitTemplate {
    /// Reads the file at `path` and decodes it as a bincode-encoded template.
    ///
    /// # Errors
    /// Propagates the I/O error from reading `path`, or the decoding error.
    pub fn load(path: &Path) -> io::Result<Self> {
        let raw = fs::read(path)?;
        Self::from_bytes(&raw)
    }

    /// Decodes a template from bincode-encoded `bytes` (e.g. [`WATER_TEMPLATE_60A`]).
    ///
    /// # Errors
    /// Returns the error produced by `load_from_bytes_bincode`.
    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
        load_from_bytes_bincode(bytes)
    }

    /// Builds a template from `water` and saves it (bincode) to `path`.
    ///
    /// Molecules are stored in ascending order of their oxygen's squared distance from the
    /// centre of `bounds`. The sort is stable, so ties keep their input order.
    ///
    /// # Errors
    /// Returns the error produced by `save`.
    pub fn create_and_save(
        water: &[WaterMol],
        bounds: (Vec3, Vec3),
        path: &Path,
    ) -> io::Result<()> {
        let center = (bounds.1 + bounds.0) / 2.;
        let dist_sq = |mol: &WaterMol| (mol.o.posit - center).magnitude_squared();

        // Sort borrowed references; the molecules themselves are only read.
        let mut ordered: Vec<&WaterMol> = water.iter().collect();
        ordered.sort_by(|a, b| dist_sq(*a).total_cmp(&dist_sq(*b)));

        let template = Self {
            o_posits: ordered.iter().map(|m| m.o.posit).collect(),
            h0_posits: ordered.iter().map(|m| m.h0.posit).collect(),
            h1_posits: ordered.iter().map(|m| m.h1.posit).collect(),
            o_velocities: ordered.iter().map(|m| m.o.vel).collect(),
            h0_velocities: ordered.iter().map(|m| m.h0.vel).collect(),
            h1_velocities: ordered.iter().map(|m| m.h1.vel).collect(),
            bounds,
        };
        save(path, &template)
    }

    /// Copies this template into the equivalent `bio_files` GROMACS solvation template.
    pub fn to_gromacs(&self) -> gromacs::solvate::WaterInitTemplate {
        let Self {
            o_posits,
            h0_posits,
            h1_posits,
            o_velocities,
            h0_velocities,
            h1_velocities,
            bounds,
        } = self;

        gromacs::solvate::WaterInitTemplate {
            o_posits: o_posits.clone(),
            h0_posits: h0_posits.clone(),
            h1_posits: h1_posits.clone(),
            o_velocities: o_velocities.clone(),
            h0_velocities: h0_velocities.clone(),
            h1_velocities: h1_velocities.clone(),
            bounds: *bounds,
        }
    }
}
/// Estimates how many water molecules fill `cell` at bulk density (`WATER_MOLS_PER_VOL`).
///
/// The volume occupied by `solute_atoms` is subtracted first, as computed by
/// `sa_surface::vol_take_up_by_atoms`. The free and total volumes (in units of 1,000 Å³) and
/// the cell dimensions are printed. A non-positive free volume yields 0.
fn n_water_mols(cell: &SimBox, solute_atoms: &[AtomDynamics]) -> usize {
    let cell_volume = cell.volume();
    let mol_volume = sa_surface::vol_take_up_by_atoms(solute_atoms);
    let free_vol = cell_volume - mol_volume;

    let dims = format!(
        "{:.2} x {:.2} x {:.2}",
        (cell.bounds_high.x - cell.bounds_low.x).abs(),
        (cell.bounds_high.y - cell.bounds_low.y).abs(),
        (cell.bounds_high.z - cell.bounds_low.z).abs()
    );

    println!(
        "Solvent-free vol: {:.2} Cell vol: {:.2} (ų / 1,000). Dims: {dims} Å",
        free_vol / 1_000.,
        cell_volume / 1_000.
    );

    // The solute can (approximately) fill or exceed the cell. Clamp explicitly so that case
    // clearly maps to zero waters instead of relying on float→int cast saturation.
    (WATER_MOLS_PER_VOL * free_vol.max(0.0)).round() as usize
}
pub fn make_water_mols(
cell: &SimBox,
atoms: &[AtomDynamics],
specify_num_water: Option<usize>,
template_override: Option<&WaterInitTemplate>,
skip_pbc_filter: bool,
) -> Vec<WaterMol> {
println!("Initializing solvent molecules...");
let start = Instant::now();
let default_template;
let template: &WaterInitTemplate = match template_override {
Some(t) => t,
None => {
default_template = load_from_bytes_bincode(WATER_TEMPLATE_60A).unwrap();
&default_template
}
};
let n_mols = specify_num_water.unwrap_or_else(|| n_water_mols(cell, atoms));
let mut result = Vec::with_capacity(n_mols);
if n_mols == 0 {
println!("Complete in {} ms.", start.elapsed().as_millis());
return result;
}
let atom_posits: Vec<_> = atoms.iter().map(|a| a.posit).collect();
let template_size = template.bounds.1 - template.bounds.0;
let template_ctr = (template.bounds.0 + template.bounds.1) / 2.;
let cell_ctr = (cell.bounds_low + cell.bounds_high) / 2.;
let base_offset = cell_ctr - template_ctr;
let cell_size = cell.bounds_high - cell.bounds_low;
let half_x = (cell_size.x / (2.0 * template_size.x)).ceil() as i32 + 1;
let half_y = (cell_size.y / (2.0 * template_size.y)).ceil() as i32 + 1;
let half_z = (cell_size.z / (2.0 * template_size.z)).ceil() as i32 + 1;
let mut loops_used = 0;
'tiles: for ix in -half_x..=half_x {
for iy in -half_y..=half_y {
for iz in -half_z..=half_z {
let tile_offset = base_offset
+ Vec3::new(
ix as f32 * template_size.x,
iy as f32 * template_size.y,
iz as f32 * template_size.z,
);
'mol: for i in 0..template.o_posits.len() {
let o_posit = template.o_posits[i] + tile_offset;
let h0_posit = template.h0_posits[i] + tile_offset;
let h1_posit = template.h1_posits[i] + tile_offset;
loops_used += 1;
if !cell.contains(o_posit) {
continue;
}
for &atom_p in &atom_posits {
if (atom_p - o_posit).magnitude_squared() < MIN_NONWATER_DIST_SQ {
continue 'mol;
}
}
for w in &result {
let diff = w.o.posit - o_posit;
let direct_sq = diff.magnitude_squared();
if direct_sq < MIN_WATER_O_O_DIST_SQ {
continue 'mol;
}
let min_image_sq = cell.min_image(diff).magnitude_squared();
if min_image_sq < MIN_WATER_O_O_DIST_SQ {
continue 'mol;
}
if !skip_pbc_filter {
if min_image_sq < PBC_MIN_WATER_O_O_DIST_SQ && min_image_sq < direct_sq
{
continue 'mol;
}
}
}
let mut mol = WaterMol::new(
Vec3::new_zero(),
Vec3::new_zero(),
Quaternion::new_identity(),
);
mol.o.posit = o_posit;
mol.h0.posit = h0_posit;
mol.h1.posit = h1_posit;
mol.o.vel = template.o_velocities[i];
mol.h0.vel = template.h0_velocities[i];
mol.h1.vel = template.h1_velocities[i];
result.push(mol);
if result.len() == n_mols {
break 'tiles;
}
}
}
}
}
let elapsed = start.elapsed().as_millis();
println!(
"Added {} / {n_mols} solvent mols in {elapsed} ms. Used {loops_used} loops",
result.len()
);
result
}
pub(crate) fn pack_custom_solvent(
bounds_low: Vec3,
bounds_high: Vec3,
existing_posits: &[Vec3F64], mols_solvent: &[(MolDynamics, usize)],
) -> Vec<MolDynamics> {
const MIN_ATOM_DIST_SQ: f64 = 1.4 * 1.4; const WALL_MARGIN: f64 = 0.6; const MAX_ROT_ATTEMPTS: usize = 200;
let mut rng = rand::rng();
let lo = Vec3F64::new(
bounds_low.x as f64,
bounds_low.y as f64,
bounds_low.z as f64,
);
let hi = Vec3F64::new(
bounds_high.x as f64,
bounds_high.y as f64,
bounds_high.z as f64,
);
let box_size = hi - lo;
let box_ctr = (lo + hi) * 0.5;
let mut placed_posits: Vec<Vec3F64> = existing_posits.to_vec();
let mut result: Vec<MolDynamics> = Vec::new();
for (mol, count) in mols_solvent {
let count = *count;
if count == 0 {
continue;
}
let template_world: Vec<Vec3F64> = if let Some(ap) = &mol.atom_posits {
ap.clone()
} else {
mol.atoms.iter().map(|a| a.posit).collect()
};
let n_atoms = template_world.len();
if n_atoms == 0 {
continue;
}
let centroid = template_world
.iter()
.fold(Vec3F64::new(0., 0., 0.), |s, &p| s + p)
* (1.0 / n_atoms as f64);
let local: Vec<Vec3F64> = template_world.iter().map(|&p| p - centroid).collect();
let bounding_r: f64 = local.iter().map(|p| p.magnitude()).fold(0.0_f64, f64::max);
let search_sq = (bounding_r * 2.0 + 2.0).powi(2);
let safe_margin = bounding_r + WALL_MARGIN;
let inner_lo = lo + Vec3F64::new(safe_margin, safe_margin, safe_margin);
let inner_hi = hi - Vec3F64::new(safe_margin, safe_margin, safe_margin);
if inner_lo.x >= inner_hi.x || inner_lo.y >= inner_hi.y || inner_lo.z >= inner_hi.z {
eprintln!(
"pack_custom_solvent: box too small for molecule \
(bounding_r={:.1} Å, need >{:.1} Å per side); skipping {} copies.",
bounding_r,
2.0 * safe_margin,
count
);
continue;
}
let inner_size = inner_hi - inner_lo;
let naive_n = (count as f64).cbrt().ceil() as usize;
let scale = (box_size.x / inner_size.x)
.max(box_size.y / inner_size.y)
.max(box_size.z / inner_size.z);
let n = ((naive_n as f64 * scale).ceil() as usize).max(3);
let (sx, sy, sz) = (
box_size.x / n as f64,
box_size.y / n as f64,
box_size.z / n as f64,
);
let (hx, hy, hz) = (
box_size.x * 0.5 - WALL_MARGIN,
box_size.y * 0.5 - WALL_MARGIN,
box_size.z * 0.5 - WALL_MARGIN,
);
let mut grid: Vec<Vec3F64> = (0..n)
.flat_map(|ix| {
(0..n).flat_map(move |iy| {
(0..n).map(move |iz| {
Vec3F64::new(
lo.x + (ix as f64 + 0.5) * sx,
lo.y + (iy as f64 + 0.5) * sy,
lo.z + (iz as f64 + 0.5) * sz,
)
})
})
})
.filter(|c| {
c.x >= inner_lo.x
&& c.x <= inner_hi.x
&& c.y >= inner_lo.y
&& c.y <= inner_hi.y
&& c.z >= inner_lo.z
&& c.z <= inner_hi.z
})
.collect();
for copy_i in 0..count {
let best_cell_idx = if placed_posits.is_empty() {
0
} else {
grid.iter()
.enumerate()
.map(|(i, &cell_ctr)| {
let min_dsq = placed_posits
.iter()
.map(|&p| (cell_ctr - p).magnitude_squared())
.fold(f64::MAX, f64::min);
(i, min_dsq)
})
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0)
};
let world_ctr = grid.remove(best_cell_idx);
let mut best_min_sq = f64::NEG_INFINITY;
let mut best_posits: Vec<Vec3F64> = Vec::new();
for _ in 0..MAX_ROT_ATTEMPTS {
let (w, x, y, z): (f64, f64, f64, f64) =
(rng.random(), rng.random(), rng.random(), rng.random());
let rot = QuaternionF64::new(w, x, y, z).to_normalized();
let new_posits: Vec<Vec3F64> = local
.iter()
.map(|&l| rot.rotate_vec(l) + world_ctr)
.collect();
if !new_posits.iter().all(|p| {
let dp = *p - box_ctr;
dp.x.abs() <= hx && dp.y.abs() <= hy && dp.z.abs() <= hz
}) {
continue;
}
let mut min_sq = f64::MAX;
'check: for &np in &new_posits {
for &pp in &placed_posits {
if (pp - world_ctr).magnitude_squared() > search_sq {
continue;
}
let dsq = (np - pp).magnitude_squared();
if dsq < min_sq {
min_sq = dsq;
if min_sq < MIN_ATOM_DIST_SQ {
break 'check;
}
}
}
}
if min_sq > best_min_sq {
best_min_sq = min_sq;
best_posits = new_posits;
}
if best_min_sq >= MIN_ATOM_DIST_SQ {
break; }
}
if best_posits.is_empty() {
best_posits = local.iter().map(|&l| l + world_ctr).collect();
}
if best_min_sq < MIN_ATOM_DIST_SQ {
eprintln!(
"pack_custom_solvent: copy {copy_i}: best min atom dist {:.2} Å — \
placing with soft overlap (energy minimiser will resolve).",
best_min_sq.max(0.0).sqrt()
);
}
placed_posits.extend_from_slice(&best_posits);
let mut mol_copy = mol.clone();
mol_copy.atom_posits = Some(best_posits);
result.push(mol_copy);
if grid.is_empty() && copy_i + 1 < count {
eprintln!(
"pack_custom_solvent: grid cells exhausted after {} / {} copies; \
box may be too small for this many solvent molecules.",
copy_i + 1,
count
);
break;
}
}
}
result
}
/// Builds a regular grid of randomly oriented water molecules that fills `cell` at bulk density,
/// with thermal velocities at `temperature_tgt`. This is used during template preparation.
///
/// Grid points are visited with x varying fastest. The grid uses `n_x = n_y ≈ ∛n_mols`, and
/// `n_z` is large enough that the grid has at least `n_mols` points; indices wrap modulo the
/// grid size. A point is skipped if it lies within `MIN_WATER_O_O_DIST` of an already-placed
/// oxygen. At most `3 · n_mols` points are visited. If `zero_com_drift` is set, the net COM
/// velocity is removed.
#[allow(unused)]
pub fn make_water_mols_grid(
    cell: &SimBox,
    temperature_tgt: f32,
    zero_com_drift: bool,
) -> Vec<WaterMol> {
    println!("Initializing a solvent grid, as part of template preparation...");
    let mut rng = rand::rng();
    let distro = Uniform::<f32>::new(0.0, 1.0).unwrap();

    let n_mols = n_water_mols(cell, &[]);
    let mut result: Vec<WaterMol> = Vec::with_capacity(n_mols);

    // Near-cubic grid: equal counts along x and y; z absorbs the remainder.
    let per_side = (n_mols as f32).cbrt().round().max(1.0) as usize;
    let (n_x, n_y) = (per_side, per_side);
    let n_z = n_mols.div_ceil(n_x * n_y);
    let per_layer = n_x * n_y;

    let extent = cell.bounds_high - cell.bounds_low;
    let step = Vec3::new(
        extent.x / n_x as f32,
        extent.y / n_y as f32,
        extent.z / n_z as f32,
    );

    const VISITS_PER_MOL: usize = 3;
    let mut loops_used = 0;
    for i in 0..n_mols * VISITS_PER_MOL {
        let (a, b, c) = (i % n_x, (i / n_x) % n_y, (i / per_layer) % n_z);
        let posit = Vec3::new(
            cell.bounds_low.x + (a as f32 + 0.5) * step.x,
            cell.bounds_low.y + (b as f32 + 0.5) * step.y,
            cell.bounds_low.z + (c as f32 + 0.5) * step.z,
        );

        let occupied = result
            .iter()
            .any(|w| (w.o.posit - posit).magnitude_squared() < MIN_WATER_O_O_DIST_SQ);
        if occupied {
            loops_used += 1;
            continue;
        }

        result.push(WaterMol::new(
            posit,
            Vec3::new_zero(),
            random_quaternion(&mut rng, distro),
        ));
        if result.len() == n_mols {
            break;
        }
        loops_used += 1;
    }

    init_velocities(&mut result, temperature_tgt, zero_com_drift, &mut rng);

    println!(
        "Added {} / {n_mols} solvent mols. Used {loops_used} loops",
        result.len()
    );
    result
}
/// Draws rigid-body thermal velocities for every water molecule, then rescales all atom
/// velocities so the measured temperature matches `t_target`.
///
/// Per molecule:
/// - A centre-of-mass velocity is sampled with each component ~ N(0, kT / M).
/// - An angular momentum is sampled in the principal-axis frame of the molecule's inertia
///   tensor, with the component along axis k ~ N(0, kT · I_k).
/// - Atom velocities are set to `v_com + ω × r`.
///
/// If `zero_com_drift` is set, the system's net COM velocity is removed afterwards.
///
/// The final rescale is skipped when the measured temperature is zero or not finite (e.g.
/// `t_target == 0`, or `mols` is empty). Dividing by it would otherwise turn every velocity
/// into NaN.
///
/// # Panics
/// Via `unwrap`, if a normal distribution rejects its standard deviation. This happens, for
/// example, when a negative `t_target` makes the square root NaN.
fn init_velocities(
    mols: &mut [WaterMol],
    t_target: f32,
    zero_com_drift: bool,
    rng: &mut ThreadRng,
) {
    let kt = KB_A2_PS2_PER_K_PER_AMU * t_target;

    for m in mols.iter_mut() {
        // Centre of mass and total mass of this molecule.
        let (r_com, m_tot) = {
            let mut r = Vec3::new_zero();
            let mut m_tot = 0.0;
            for a in [&m.o, &m.h0, &m.h1] {
                r += a.posit * a.mass;
                m_tot += a.mass;
            }
            (r / m_tot, m_tot)
        };

        // Atom positions relative to the COM.
        let r_0 = m.o.posit - r_com;
        let r_h0 = m.h0.posit - r_com;
        let r_h1 = m.h1.posit - r_com;

        // Translational part.
        let sigma_v = (kt / m_tot).sqrt();
        let n = Normal::new(0.0, sigma_v).unwrap();
        let v_com = Vec3::new(n.sample(rng), n.sample(rng), n.sample(rng));

        // Point-mass contribution to the inertia tensor about the COM.
        let inertia = |r: Vec3, mass: f32| {
            let r2 = r.dot(r);
            [
                [
                    mass * (r2 - r.x * r.x),
                    -mass * r.x * r.y,
                    -mass * r.x * r.z,
                ],
                [
                    -mass * r.y * r.x,
                    mass * (r2 - r.y * r.y),
                    -mass * r.y * r.z,
                ],
                [
                    -mass * r.z * r.x,
                    -mass * r.z * r.y,
                    mass * (r2 - r.z * r.z),
                ],
            ]
        };
        let add_tensor = |acc: &mut [[f32; 3]; 3], rhs: [[f32; 3]; 3]| {
            for (acc_row, rhs_row) in acc.iter_mut().zip(rhs) {
                for (x, y) in acc_row.iter_mut().zip(rhs_row) {
                    *x += y;
                }
            }
        };

        let mut inertia_arr = inertia(r_0, m.o.mass);
        add_tensor(&mut inertia_arr, inertia(r_h0, m.h0.mass));
        add_tensor(&mut inertia_arr, inertia(r_h1, m.h1.mass));
        let inertia_tensor = Mat3F32::from_arr(inertia_arr);

        // Rotational part: sample L in the principal frame, rotate to world, solve I·ω = L.
        let (eigvecs, eigvals) = inertia_tensor.eigen_vecs_vals();
        let l_principal = Vec3::new(
            Normal::new(0.0, (kt * eigvals.x.max(0.0)).sqrt())
                .unwrap()
                .sample(rng),
            Normal::new(0.0, (kt * eigvals.y.max(0.0)).sqrt())
                .unwrap()
                .sample(rng),
            Normal::new(0.0, (kt * eigvals.z.max(0.0)).sqrt())
                .unwrap()
                .sample(rng),
        );
        let l_world = eigvecs * l_principal;
        let omega = inertia_tensor.solve_system(l_world);

        m.o.vel = v_com + omega.cross(r_0);
        m.h0.vel = v_com + omega.cross(r_h0);
        m.h1.vel = v_com + omega.cross(r_h1);
    }

    if zero_com_drift {
        remove_com_velocity(mols);
    }

    let (ke_raw, dof) = _kinetic_energy_and_dof(mols, zero_com_drift);
    let temperature_meas = (2.0 * ke_raw) / (dof as f32 * GAS_CONST_R as f32);

    // Nothing meaningful to rescale. Avoid NaN poisoning from 0/0 or x/0.
    if !temperature_meas.is_finite() || temperature_meas <= 0.0 {
        return;
    }

    let lambda = (t_target / temperature_meas).sqrt();
    for a in atoms_mut(mols) {
        if a.mass > 0.0 {
            a.vel *= lambda;
        }
    }
}
/// Returns `(kinetic energy, degrees of freedom)` for `mols`.
///
/// The kinetic energy is ½ Σ m·v² over every atom, converted with `NATIVE_TO_KCAL`. It is
/// presumably in kcal/mol, matching its division by `GAS_CONST_R` in `init_velocities`.
///
/// Each water counts as a rigid body with 3 translational + 3 rotational DOF. That is exactly
/// what `init_velocities` samples (COM velocity plus angular velocity). Counting only 3 per
/// molecule would read twice the true temperature, and the subsequent rescale would halve it.
/// When `zero_com_drift` is set, the 3 global translational DOF removed with the net momentum
/// are subtracted.
fn _kinetic_energy_and_dof(mols: &[WaterMol], zero_com_drift: bool) -> (f32, usize) {
    // Rigid water: 3 translational + 3 rotational degrees of freedom.
    const DOF_PER_RIGID_WATER: usize = 6;

    let mut ke = 0.0_f64;
    for w in mols {
        for a in [&w.o, &w.h0, &w.h1] {
            ke += (a.mass * a.vel.magnitude_squared()) as f64;
        }
    }

    let mut dof = mols.len() * DOF_PER_RIGID_WATER;
    if zero_com_drift {
        dof = dof.saturating_sub(3);
    }

    (ke as f32 * 0.5 * NATIVE_TO_KCAL, dof)
}
/// Yields a mutable reference to every atom of every molecule in `mols`, in the order O, H0, H1
/// per molecule.
fn atoms_mut(mols: &mut [WaterMol]) -> impl Iterator<Item = &mut AtomDynamics> {
    mols.iter_mut()
        .flat_map(|mol| [&mut mol.o, &mut mol.h0, &mut mol.h1])
}
/// Subtracts the mass-weighted mean (centre-of-mass) velocity of all atoms in `mols` from every
/// atom, leaving the system with zero net momentum.
#[allow(unused)]
fn remove_com_velocity(mols: &mut [WaterMol]) {
    let (momentum, total_mass) = atoms_mut(mols).fold(
        (Vec3::new_zero(), 0.0),
        |(p, m), atom| (p + atom.vel * atom.mass, m + atom.mass),
    );
    let v_com = momentum / total_mass;

    atoms_mut(mols).for_each(|atom| atom.vel -= v_com);
}
#[allow(unused)]
fn random_quaternion(rng: &mut ThreadRng, distro: Uniform<f32>) -> Quaternion {
let (u1, u2, u3) = (rng.sample(distro), rng.sample(distro), rng.sample(distro));
let sqrt1_minus_u1 = (1.0 - u1).sqrt();
let sqrt_u1 = u1.sqrt();
let (theta1, theta2) = (TAU * u2, TAU * u3);
Quaternion::new(
sqrt1_minus_u1 * theta1.sin(),
sqrt1_minus_u1 * theta1.cos(),
sqrt_u1 * theta2.sin(),
sqrt_u1 * theta2.cos(),
)
.to_normalized()
}
impl MdState {
    /// Equilibrates the solvent on its own.
    ///
    /// Runs `NUM_EQUILIBRATION_STEPS` steps of size `DT_EQUILIBRATION`, with every atom in
    /// `self.atoms` temporarily marked static and `solvent_only_sim_at_init` set. Afterwards it
    /// restores each atom's previous `static_` flag, clears `solvent_only_sim_at_init`, and
    /// resets `step_count` to 0.
    pub fn md_on_water_only(&mut self, dev: &ComputationDevice) {
        println!("Initializing solvent H bond networks...");
        let start = Instant::now();

        self.solvent_only_sim_at_init = true;

        // Freeze every entry of `self.atoms` (presumably the non-water atoms), remembering the
        // original flags so they can be restored afterwards.
        let saved_static: Vec<_> = self
            .atoms
            .iter_mut()
            .map(|atom| std::mem::replace(&mut atom.static_, true))
            .collect();

        for _ in 0..NUM_EQUILIBRATION_STEPS {
            self.step(dev, DT_EQUILIBRATION, None);
        }

        for (idx, atom) in self.atoms.iter_mut().enumerate() {
            atom.static_ = saved_static[idx];
        }

        self.solvent_only_sim_at_init = false;
        self.step_count = 0;

        let elapsed = start.elapsed().as_millis();
        println!("Water H bond networks complete in {elapsed} ms");
    }
}