use serde::{Deserialize, Serialize};
use crate::dynamics::{atomic_mass_amu, MdBackend};
/// Per-atom dynamic state used by [`LiveMolecularSystem`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicAtom {
/// Element identifier handed to `atomic_mass_amu`; the tests use `1` for
/// hydrogen, so this is presumably the atomic number.
pub element: u8,
/// Atomic mass as returned by `atomic_mass_amu` (amu, going by that name).
pub mass: f64,
/// Partial charge; `from_conformer` initialises it to `0.0` and nothing else
/// in this file reads or writes it.
pub charge: f64,
/// Cartesian position `[x, y, z]` (presumably angstrom, per the unit constants).
pub position: [f64; 3],
/// Cartesian velocity `[vx, vy, vz]`; multiplied by a time step in fs to
/// advance `position`.
pub velocity: [f64; 3],
/// Cartesian force `[fx, fy, fz]`; filled from the negated backend output in
/// `compute_forces`.
pub force: [f64; 3],
}
/// Mutable molecular-dynamics state: per-atom records plus mirrored flat
/// buffers and cached thermodynamic bookkeeping.
///
/// The same positions/velocities/forces are stored twice: once inside
/// `atoms` and once in the `*_flat` vectors (three consecutive values per
/// atom). `sync_flat_from_atoms` and `sync_atoms_from_flat` copy between the
/// two representations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LiveMolecularSystem {
/// Per-atom state; the length of this vector defines the atom count.
pub atoms: Vec<DynamicAtom>,
/// Energy returned by the last successful force evaluation in `verlet_step`.
pub potential_energy: f64,
/// Cached kinetic energy from the last call that recomputed it.
pub kinetic_energy: f64,
/// Accumulated simulation time; incremented by `dt_fs` on every Verlet step.
pub time_fs: f64,
/// Number of completed Verlet steps.
pub step: usize,
/// Cached instantaneous temperature from the last call that recomputed it.
pub temperature_k: f64,
/// Backend selector passed to `compute_backend_energy_and_gradients`.
pub backend: MdBackend,
/// Flat positions, `[x0, y0, z0, x1, ...]`.
pub positions_flat: Vec<f64>,
/// Flat velocities, same layout as `positions_flat`.
pub velocities_flat: Vec<f64>,
/// Flat forces, same layout as `positions_flat`.
pub forces_flat: Vec<f64>,
/// SMILES string forwarded to the backend on every force evaluation.
pub smiles: String,
/// Bond list copied from the constructor as `(index, index, label)` tuples;
/// not read anywhere else in this file. The tests use the label `"SINGLE"`.
pub bonds: Vec<(usize, usize, String)>,
}
/// Converts amu * angstrom^2 / fs^2 to kcal/mol.
///
/// 1 amu * angstrom^2 / fs^2 equals 1e4 kJ/mol, and 1e4 / 4.184 = 2390.057...
const AMU_ANGFS2_TO_KCAL_MOL: f64 = 2_390.057_361_533_49;
/// Molar gas constant in kcal / (mol * K).
const R_GAS_KCAL_MOLK: f64 = 0.001_987_204_258_640_83;
impl LiveMolecularSystem {
pub fn from_conformer(
smiles: &str,
elements: &[u8],
coords: &[f64],
bonds: &[(usize, usize, String)],
backend: MdBackend,
) -> Self {
let n = elements.len();
let mut atoms = Vec::with_capacity(n);
for i in 0..n {
atoms.push(DynamicAtom {
element: elements[i],
mass: atomic_mass_amu(elements[i]),
charge: 0.0,
position: [coords[3 * i], coords[3 * i + 1], coords[3 * i + 2]],
velocity: [0.0; 3],
force: [0.0; 3],
});
}
Self {
atoms,
potential_energy: 0.0,
kinetic_energy: 0.0,
time_fs: 0.0,
step: 0,
temperature_k: 0.0,
backend,
positions_flat: coords.to_vec(),
velocities_flat: vec![0.0; n * 3],
forces_flat: vec![0.0; n * 3],
smiles: smiles.to_string(),
bonds: bonds.to_vec(),
}
}
/// Returns the number of atoms, i.e. the length of `atoms`.
pub fn n_atoms(&self) -> usize {
self.atoms.len()
}
/// Copies per-atom positions, velocities and forces into the flat buffers.
///
/// Each flat buffer is first resized to `3 * atoms.len()` (new slots are
/// zero-filled), then atom `i` is written to entries `3 * i .. 3 * i + 3`.
pub fn sync_flat_from_atoms(&mut self) {
    let flat_len = self.atoms.len() * 3;
    self.positions_flat.resize(flat_len, 0.0);
    self.velocities_flat.resize(flat_len, 0.0);
    self.forces_flat.resize(flat_len, 0.0);

    // After the resize every buffer splits into exactly one triple per atom.
    let triples = self
        .positions_flat
        .chunks_exact_mut(3)
        .zip(self.velocities_flat.chunks_exact_mut(3))
        .zip(self.forces_flat.chunks_exact_mut(3));

    for (atom, ((pos, vel), frc)) in self.atoms.iter().zip(triples) {
        pos.copy_from_slice(&atom.position);
        vel.copy_from_slice(&atom.velocity);
        frc.copy_from_slice(&atom.force);
    }
}
/// Copies the flat buffers back into the per-atom records.
///
/// Atom `i` receives entries `3 * i .. 3 * i + 3` of each flat buffer.
///
/// # Panics
///
/// Panics if any flat buffer holds fewer than `3 * atoms.len()` values.
pub fn sync_atoms_from_flat(&mut self) {
    for (idx, atom) in self.atoms.iter_mut().enumerate() {
        let start = 3 * idx;
        let end = start + 3;
        atom.position
            .copy_from_slice(&self.positions_flat[start..end]);
        atom.velocity
            .copy_from_slice(&self.velocities_flat[start..end]);
        atom.force.copy_from_slice(&self.forces_flat[start..end]);
    }
}
/// Draws random velocities for a target temperature and updates the cached
/// kinetic energy, temperature and flat buffers.
///
/// Steps:
/// 1. Each velocity component is sampled from a zero-mean Gaussian with
///    standard deviation `sqrt(R * T / m)` (Box-Muller, cosine branch),
///    using a `StdRng` seeded with `seed`, so results are reproducible.
/// 2. The mass-weighted mean velocity is subtracted so the total momentum
///    is zero.
/// 3. Velocities are rescaled so the temperature reported by
///    `kinetic_energy_and_temperature` equals `temperature_k` (skipped when
///    the sampled temperature is effectively zero, e.g. for a single atom).
///
/// A non-positive or non-finite `temperature_k` cannot be sampled (the
/// standard deviation would be NaN or infinite), so in that case all
/// velocities are set to zero instead.
pub fn initialize_velocities(&mut self, temperature_k: f64, seed: u64) {
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    if !temperature_k.is_finite() || temperature_k <= 0.0 {
        // Invalid target: leave the system at rest rather than filling it
        // with NaN/inf velocities.
        for atom in &mut self.atoms {
            atom.velocity = [0.0; 3];
        }
        let (ke, temp) = self.kinetic_energy_and_temperature();
        self.kinetic_energy = ke;
        self.temperature_k = temp;
        self.sync_flat_from_atoms();
        return;
    }

    let mut rng = StdRng::seed_from_u64(seed);
    for atom in &mut self.atoms {
        let sigma =
            (R_GAS_KCAL_MOLK * temperature_k / (atom.mass * AMU_ANGFS2_TO_KCAL_MOL)).sqrt();
        for v in &mut atom.velocity {
            // `1 - u` maps [0, 1) onto (0, 1]; the floor keeps `ln` finite.
            let u1 = (1.0 - rng.gen::<f64>()).max(1e-12);
            let u2 = rng.gen::<f64>();
            *v = sigma * (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
        }
    }

    // Remove centre-of-mass drift.
    let mut com_v = [0.0f64; 3];
    let mut total_mass = 0.0;
    for atom in &self.atoms {
        for (acc, v) in com_v.iter_mut().zip(&atom.velocity) {
            *acc += atom.mass * *v;
        }
        total_mass += atom.mass;
    }
    if total_mass > 0.0 {
        for c in &mut com_v {
            *c /= total_mass;
        }
        for atom in &mut self.atoms {
            for (v, c) in atom.velocity.iter_mut().zip(&com_v) {
                *v -= *c;
            }
        }
    }

    // Rescale so the instantaneous temperature matches the target exactly.
    let (_, sampled_temp) = self.kinetic_energy_and_temperature();
    if sampled_temp > 1e-10 {
        let scale = (temperature_k / sampled_temp).sqrt();
        for atom in &mut self.atoms {
            for v in &mut atom.velocity {
                *v *= scale;
            }
        }
    }

    let (ke, temp) = self.kinetic_energy_and_temperature();
    self.kinetic_energy = ke;
    self.temperature_k = temp;
    self.sync_flat_from_atoms();
}
/// Computes the kinetic energy and instantaneous temperature from the
/// per-atom velocities.
///
/// Returns `(kinetic_energy, temperature)`. The kinetic energy is
/// `sum(0.5 * m * |v|^2)` scaled by `AMU_ANGFS2_TO_KCAL_MOL`; the temperature
/// is `2 * KE / (dof * R)` with `dof = max(3 * n - 6, 1)`, so an empty or
/// tiny system never divides by zero. This reads `atoms` only and does not
/// touch the cached `kinetic_energy` / `temperature_k` fields.
pub fn kinetic_energy_and_temperature(&self) -> (f64, f64) {
    let kinetic = self.atoms.iter().fold(0.0_f64, |acc, atom| {
        let [vx, vy, vz] = atom.velocity;
        let speed_sq = vx.powi(2) + vy.powi(2) + vz.powi(2);
        acc + 0.5 * atom.mass * speed_sq * AMU_ANGFS2_TO_KCAL_MOL
    });

    let dof = (3 * self.atoms.len()).saturating_sub(6).max(1) as f64;
    let temperature = 2.0 * kinetic / (dof * R_GAS_KCAL_MOLK);

    (kinetic, temperature)
}
/// Evaluates the backend at the current flat positions and stores the
/// resulting forces.
///
/// The backend call receives `positions_flat` and writes into `forces_flat`;
/// going by its name it returns an energy and fills in gradients, which are
/// then negated in place to obtain forces and copied into each atom's
/// `force`. Positions are taken from `positions_flat`, not from `atoms`, so
/// callers must sync the flat buffers first.
///
/// Returns the energy reported by the backend. `potential_energy` is not
/// updated here.
///
/// # Errors
///
/// Propagates the backend's error string unchanged.
///
/// # Panics
///
/// Panics if `forces_flat` holds fewer than `3 * atoms.len()` values after
/// the backend call.
pub fn compute_forces(&mut self) -> Result<f64, String> {
    let atomic_numbers: Vec<u8> = self.atoms.iter().map(|atom| atom.element).collect();

    let potential = crate::dynamics::compute_backend_energy_and_gradients(
        self.backend,
        &self.smiles,
        &atomic_numbers,
        &self.positions_flat,
        &mut self.forces_flat,
    )?;

    // Flip the sign of the backend output in place.
    self.forces_flat.iter_mut().for_each(|g| *g = -*g);

    for (idx, atom) in self.atoms.iter_mut().enumerate() {
        let start = 3 * idx;
        atom.force
            .copy_from_slice(&self.forces_flat[start..start + 3]);
    }

    Ok(potential)
}
pub fn verlet_step(&mut self, dt_fs: f64) -> Result<(), String> {
let n = self.n_atoms();
for atom in &mut self.atoms {
let inv_m = 1.0 / (atom.mass * AMU_ANGFS2_TO_KCAL_MOL);
for d in 0..3 {
atom.velocity[d] += 0.5 * dt_fs * atom.force[d] * inv_m;
}
}
for atom in &mut self.atoms {
for d in 0..3 {
atom.position[d] += dt_fs * atom.velocity[d];
}
}
self.sync_flat_from_atoms();
let pe = self.compute_forces()?;
self.potential_energy = pe;
for atom in &mut self.atoms {
let inv_m = 1.0 / (atom.mass * AMU_ANGFS2_TO_KCAL_MOL);
for d in 0..3 {
atom.velocity[d] += 0.5 * dt_fs * atom.force[d] * inv_m;
}
}
let (ke, temp) = self.kinetic_energy_and_temperature();
self.kinetic_energy = ke;
self.temperature_k = temp;
self.time_fs += dt_fs;
self.step += 1;
self.sync_flat_from_atoms();
for i in 0..n {
if !self.positions_flat[3 * i].is_finite() {
return Err("simulation diverged: non-finite coordinates".to_string());
}
}
Ok(())
}
/// Runs `substeps` consecutive Verlet steps of length `dt_fs`.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `verlet_step`; steps
/// completed before the failure are not rolled back.
pub fn integrate(&mut self, dt_fs: f64, substeps: usize) -> Result<(), String> {
    (0..substeps).try_for_each(|_| self.verlet_step(dt_fs))
}
/// Applies one Berendsen weak-coupling velocity rescale.
///
/// The scale factor is
/// `lambda = sqrt(1 + (dt_fs / tau_fs) * (target_temp_k / T - 1))`, limited
/// to the range `[0.5, 2.0]`, where `T` is the cached `temperature_k`. All
/// velocities are multiplied by `lambda`, then the cached kinetic energy,
/// temperature and flat buffers are refreshed.
///
/// The call is a no-op when the cached temperature is below `1e-10` (there
/// is nothing to rescale) or when the scale factor cannot be computed at
/// all (NaN, e.g. from a NaN argument or `0 * inf`).
pub fn berendsen_thermostat(&mut self, target_temp_k: f64, tau_fs: f64, dt_fs: f64) {
    if self.temperature_k < 1e-10 {
        return;
    }

    let lambda_sq = 1.0 + (dt_fs / tau_fs) * (target_temp_k / self.temperature_k - 1.0);
    if lambda_sq.is_nan() {
        return;
    }
    // Clamp before the square root: a negative argument (strong coupling
    // with a much lower target) would give NaN, which `clamp` passes through.
    // sqrt(0.25) = 0.5 and sqrt(4.0) = 2.0, so the limits on lambda are
    // unchanged.
    let lambda = lambda_sq.clamp(0.25, 4.0).sqrt();

    for atom in &mut self.atoms {
        for v in &mut atom.velocity {
            *v *= lambda;
        }
    }

    let (ke, temp) = self.kinetic_energy_and_temperature();
    self.kinetic_energy = ke;
    self.temperature_k = temp;
    self.sync_flat_from_atoms();
}
/// Applies one Nose-Hoover-style thermostat update.
///
/// The caller owns the thermostat state: `v_xi` is advanced by
/// `(KE - KE_target) / thermostat_mass * dt_fs`, `xi` by `v_xi * dt_fs`, and
/// every velocity is multiplied by `exp(-v_xi * dt_fs)`. `KE` is the cached
/// `kinetic_energy` field (not recomputed on entry), and
/// `KE_target = 0.5 * dof * R * target_temp_k` with `dof = max(3 * n - 6, 1)`.
/// The cached kinetic energy, temperature and flat buffers are refreshed
/// afterwards.
pub fn nose_hoover_thermostat(
    &mut self,
    target_temp_k: f64,
    thermostat_mass: f64,
    xi: &mut f64,
    v_xi: &mut f64,
    dt_fs: f64,
) {
    let dof = (3 * self.n_atoms()).saturating_sub(6).max(1) as f64;
    let target_ke = 0.5 * dof * R_GAS_KCAL_MOLK * target_temp_k;

    // Advance the thermostat variables first; the velocity scaling uses the
    // updated `v_xi`.
    let drive = (self.kinetic_energy - target_ke) / thermostat_mass;
    *v_xi += drive * dt_fs;
    *xi += *v_xi * dt_fs;

    let damping = (-*v_xi * dt_fs).exp();
    self.atoms
        .iter_mut()
        .flat_map(|atom| atom.velocity.iter_mut())
        .for_each(|v| *v *= damping);

    let (kinetic, temperature) = self.kinetic_energy_and_temperature();
    self.kinetic_energy = kinetic;
    self.temperature_k = temperature;
    self.sync_flat_from_atoms();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing copied coordinates.
    const TOL: f64 = 1e-12;

    /// Asserts that `actual` is within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    /// Two-atom fixture: element 1 twice, 0.74 apart along x, one bond.
    fn h2_system() -> LiveMolecularSystem {
        let elements = [1u8, 1];
        let coords = [0.0, 0.0, 0.0, 0.74, 0.0, 0.0];
        let bonds = [(0usize, 1usize, "SINGLE".to_string())];
        LiveMolecularSystem::from_conformer("", &elements, &coords, &bonds, MdBackend::Uff)
    }

    #[test]
    fn from_conformer_creates_atoms() {
        let system = h2_system();
        assert_eq!(system.n_atoms(), 2);
        assert_eq!(system.atoms[0].element, 1);
        assert_eq!(system.atoms[1].element, 1);
        let first_mass = system.atoms[0].mass;
        assert!(first_mass > 0.9 && first_mass < 1.1);
    }

    #[test]
    fn sync_flat_roundtrip() {
        let mut system = h2_system();

        // Atoms -> flat.
        system.atoms[0].position = [1.0, 2.0, 3.0];
        system.sync_flat_from_atoms();
        assert_close(system.positions_flat[0], 1.0);
        assert_close(system.positions_flat[1], 2.0);
        assert_close(system.positions_flat[2], 3.0);

        // Flat -> atoms.
        system.positions_flat[3] = 5.0;
        system.sync_atoms_from_flat();
        assert_close(system.atoms[1].position[0], 5.0);
    }

    #[test]
    fn initialize_velocities_nonzero() {
        let mut system = h2_system();
        system.initialize_velocities(300.0, 42);
        let v_mag: f64 = system
            .atoms
            .iter()
            .flat_map(|atom| atom.velocity.iter())
            .map(|v| v * v)
            .sum();
        assert!(v_mag > 0.0, "velocities should be nonzero after init");
    }

    #[test]
    fn kinetic_energy_zero_for_stationary() {
        let (ke, temp) = h2_system().kinetic_energy_and_temperature();
        assert!(ke.abs() < 1e-15, "KE should be zero for stationary atoms");
        assert!(temp.abs() < 1e-15);
    }

    #[test]
    fn berendsen_thermostat_scales_velocities() {
        let mut system = h2_system();
        system.initialize_velocities(300.0, 42);
        let ke_before = system.kinetic_energy_and_temperature().0;
        // Target above the current temperature, so velocities must grow.
        system.berendsen_thermostat(600.0, 100.0, 1.0);
        let ke_after = system.kinetic_energy_and_temperature().0;
        assert!(ke_after > ke_before);
    }

    #[test]
    fn n_atoms_matches_flat_buffer_length() {
        let system = h2_system();
        let expected_len = 3 * system.n_atoms();
        assert_eq!(system.positions_flat.len(), expected_len);
        assert_eq!(system.velocities_flat.len(), expected_len);
        assert_eq!(system.forces_flat.len(), expected_len);
    }
}