Skip to main content

oxiphysics_python/
md_api.rs

1#![allow(clippy::needless_range_loop)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! Molecular Dynamics (MD) simulation API for Python interop.
6//!
7//! Provides a simple N-body MD simulation with Lennard-Jones pair potentials,
8//! optional velocity-rescaling thermostat, and periodic boundary conditions.
9//! All types are `no-lifetime`, serialization-friendly, and carry comprehensive
10//! tests.
11
12#![allow(missing_docs)]
13
14use serde::{Deserialize, Serialize};
15
16// ---------------------------------------------------------------------------
17// Configuration
18// ---------------------------------------------------------------------------
19
/// Configuration for an MD simulation.
///
/// Plain-data struct (derives `Serialize`/`Deserialize`) consumed by
/// `PyMdSimulation::new`. `PyMdConfig::argon_reduced` (also the `Default`)
/// provides a ready-made reduced-unit setup.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyMdConfig {
    /// Side length of the cubic simulation box.
    ///
    /// Used both to wrap positions back into the box after each drift and
    /// for the minimum-image convention when evaluating pair forces.
    pub box_size: f64,
    /// Lennard-Jones epsilon (energy well depth, J or reduced units).
    pub lj_epsilon: f64,
    /// Lennard-Jones sigma (particle diameter in m or reduced units).
    pub lj_sigma: f64,
    /// Cut-off radius for pair interactions (typically 2.5 * sigma).
    ///
    /// The pair potential is shifted so it is zero at this distance. The
    /// minimum-image convention assumes `cutoff <= box_size / 2`.
    pub cutoff: f64,
    /// Particle mass (kg or reduced units). Shared by every atom; a
    /// non-positive value makes the integrator ignore forces entirely.
    pub particle_mass: f64,
    /// Target temperature for thermostat (K or reduced units). `None` = NVE.
    pub target_temperature: Option<f64>,
    /// Thermostat relaxation time.
    ///
    /// NOTE(review): `PyMdSimulation::step` rescales velocities to the target
    /// temperature instantly and does not read this value.
    pub thermostat_tau: f64,
}
38
39impl PyMdConfig {
40    /// Argon-like reduced-unit configuration (ε=1, σ=1, box=10σ).
41    pub fn argon_reduced() -> Self {
42        Self {
43            box_size: 10.0,
44            lj_epsilon: 1.0,
45            lj_sigma: 1.0,
46            cutoff: 2.5,
47            particle_mass: 1.0,
48            target_temperature: Some(1.2),
49            thermostat_tau: 0.1,
50        }
51    }
52}
53
54impl Default for PyMdConfig {
55    fn default() -> Self {
56        Self::argon_reduced()
57    }
58}
59
60// ---------------------------------------------------------------------------
61// Atom
62// ---------------------------------------------------------------------------
63
/// A single atom in the MD simulation.
///
/// The mass is not stored per atom; it comes from
/// `PyMdConfig::particle_mass` and is passed into `kinetic_energy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyMdAtom {
    /// Position `[x, y, z]`.
    pub position: [f64; 3],
    /// Velocity `[vx, vy, vz]`.
    pub velocity: [f64; 3],
    /// Accumulated force `[fx, fy, fz]` (zeroed and recomputed on each force
    /// evaluation).
    pub force: [f64; 3],
    /// Atom type identifier (e.g. element number or user label). Stored only;
    /// the force field in this module treats all types identically.
    pub atom_type: u32,
}
76
77impl PyMdAtom {
78    /// Create a new atom at `position` with zero velocity.
79    pub fn new(position: [f64; 3], atom_type: u32) -> Self {
80        Self {
81            position,
82            velocity: [0.0; 3],
83            force: [0.0; 3],
84            atom_type,
85        }
86    }
87
88    /// Kinetic energy of this atom: 0.5 * m * v².
89    pub fn kinetic_energy(&self, mass: f64) -> f64 {
90        let v2 = self.velocity[0].powi(2) + self.velocity[1].powi(2) + self.velocity[2].powi(2);
91        0.5 * mass * v2
92    }
93}
94
95// ---------------------------------------------------------------------------
96// PyMdSimulation
97// ---------------------------------------------------------------------------
98
/// Molecular Dynamics simulation (NVE/NVT) with periodic boundary conditions.
///
/// Uses a velocity-Verlet integrator and a truncated-shifted Lennard-Jones
/// pair potential. Supports an optional velocity-rescaling thermostat.
/// All atoms share the mass and LJ parameters from the stored `PyMdConfig`.
#[derive(Debug, Clone)]
pub struct PyMdSimulation {
    /// All atoms, in insertion order (the index returned by `add_atom`).
    atoms: Vec<PyMdAtom>,
    /// Simulation configuration.
    config: PyMdConfig,
    /// Total simulation time accumulated (sum of every `dt` passed to `step`).
    time: f64,
    /// Number of completed steps.
    step_count: u64,
    /// Most recently computed total potential energy.
    potential_energy: f64,
    /// Whether the velocity-rescaling thermostat is applied in `step`
    /// (only when `config.target_temperature` is also `Some`).
    thermostat_active: bool,
}
118
119impl PyMdSimulation {
120    /// Create a new empty MD simulation from the given configuration.
121    pub fn new(config: PyMdConfig) -> Self {
122        Self {
123            atoms: Vec::new(),
124            config,
125            time: 0.0,
126            step_count: 0,
127            potential_energy: 0.0,
128            thermostat_active: true,
129        }
130    }
131
132    /// Add an atom at `position` with the given type index.
133    ///
134    /// Returns the index of the newly added atom.
135    pub fn add_atom(&mut self, position: [f64; 3], atom_type: u32) -> usize {
136        let idx = self.atoms.len();
137        self.atoms.push(PyMdAtom::new(position, atom_type));
138        idx
139    }
140
141    /// Set the velocity of atom `i`. No-op if `i` is out of bounds.
142    pub fn set_velocity(&mut self, i: usize, vel: [f64; 3]) {
143        if let Some(atom) = self.atoms.get_mut(i) {
144            atom.velocity = vel;
145        }
146    }
147
148    /// Get the position of atom `i`, or `None` if out of bounds.
149    pub fn position(&self, i: usize) -> Option<[f64; 3]> {
150        self.atoms.get(i).map(|a| a.position)
151    }
152
153    /// Get the velocity of atom `i`, or `None` if out of bounds.
154    pub fn velocity(&self, i: usize) -> Option<[f64; 3]> {
155        self.atoms.get(i).map(|a| a.velocity)
156    }
157
158    /// Number of atoms in the simulation.
159    pub fn atom_count(&self) -> usize {
160        self.atoms.len()
161    }
162
163    /// Accumulated simulation time.
164    pub fn time(&self) -> f64 {
165        self.time
166    }
167
168    /// Number of completed steps.
169    pub fn step_count(&self) -> u64 {
170        self.step_count
171    }
172
173    /// Enable or disable the velocity-rescaling thermostat.
174    pub fn set_thermostat(&mut self, active: bool) {
175        self.thermostat_active = active;
176    }
177
178    /// Whether the thermostat is active.
179    pub fn thermostat_active(&self) -> bool {
180        self.thermostat_active
181    }
182
183    /// Set the target temperature for the thermostat.
184    pub fn set_target_temperature(&mut self, t: f64) {
185        self.config.target_temperature = Some(t.max(0.0));
186    }
187
188    /// Total potential energy from the last step.
189    pub fn potential_energy(&self) -> f64 {
190        self.potential_energy
191    }
192
193    /// Total kinetic energy summed over all atoms.
194    pub fn kinetic_energy(&self) -> f64 {
195        self.atoms
196            .iter()
197            .map(|a| a.kinetic_energy(self.config.particle_mass))
198            .sum()
199    }
200
201    /// Total energy (kinetic + potential).
202    pub fn total_energy(&self) -> f64 {
203        self.kinetic_energy() + self.potential_energy
204    }
205
206    /// Instantaneous temperature from equipartition: T = 2*KE / (3*N*k_B).
207    ///
208    /// In reduced units k_B = 1, so T = 2*KE / (3*N).
209    pub fn temperature(&self) -> f64 {
210        let n = self.atoms.len();
211        if n == 0 {
212            return 0.0;
213        }
214        let ke = self.kinetic_energy();
215        2.0 * ke / (3.0 * n as f64)
216    }
217
218    /// Advance the simulation by `dt` using velocity Verlet integration.
219    ///
220    /// Steps:
221    /// 1. Half-kick velocities: v += 0.5 * f/m * dt
222    /// 2. Update positions: x += v * dt (with PBC wrap)
223    /// 3. Recompute forces from LJ pair potential
224    /// 4. Half-kick velocities again
225    /// 5. Optionally rescale velocities to match target temperature
226    pub fn step(&mut self, dt: f64) {
227        let n = self.atoms.len();
228        if n == 0 {
229            self.time += dt;
230            self.step_count += 1;
231            return;
232        }
233        let m = self.config.particle_mass;
234        let inv_m = if m > 0.0 { 1.0 / m } else { 0.0 };
235
236        // Half-kick
237        for atom in &mut self.atoms {
238            for k in 0..3 {
239                atom.velocity[k] += 0.5 * atom.force[k] * inv_m * dt;
240            }
241        }
242
243        // Update positions + PBC wrap
244        let box_size = self.config.box_size;
245        for atom in &mut self.atoms {
246            for k in 0..3 {
247                atom.position[k] += atom.velocity[k] * dt;
248                atom.position[k] = wrap_pbc(atom.position[k], box_size);
249            }
250        }
251
252        // Recompute forces
253        self.compute_forces();
254
255        // Second half-kick
256        for atom in &mut self.atoms {
257            for k in 0..3 {
258                atom.velocity[k] += 0.5 * atom.force[k] * inv_m * dt;
259            }
260        }
261
262        // Thermostat (velocity rescaling)
263        if self.thermostat_active
264            && let Some(t_target) = self.config.target_temperature
265        {
266            let t_curr = self.temperature();
267            if t_curr > 1e-15 {
268                let scale = (t_target / t_curr).sqrt();
269                for atom in &mut self.atoms {
270                    for k in 0..3 {
271                        atom.velocity[k] *= scale;
272                    }
273                }
274            }
275        }
276
277        self.time += dt;
278        self.step_count += 1;
279    }
280
281    /// Advance the simulation by `dt` for `steps` steps.
282    pub fn run(&mut self, dt: f64, steps: u64) {
283        for _ in 0..steps {
284            self.step(dt);
285        }
286    }
287
288    /// Return all atom positions as a flat `Vec`f64` of `\[x,y,z\]` triples.
289    pub fn all_positions(&self) -> Vec<f64> {
290        self.atoms
291            .iter()
292            .flat_map(|a| a.position.iter().copied())
293            .collect()
294    }
295
296    /// Return all atom velocities as a flat `Vec<f64>` of `[vx,vy,vz]` triples.
297    pub fn all_velocities(&self) -> Vec<f64> {
298        self.atoms
299            .iter()
300            .flat_map(|a| a.velocity.iter().copied())
301            .collect()
302    }
303
304    // -----------------------------------------------------------------------
305    // Private helpers
306    // -----------------------------------------------------------------------
307
308    /// Recompute all pair forces using the truncated-shifted Lennard-Jones potential.
309    fn compute_forces(&mut self) {
310        let n = self.atoms.len();
311        // Zero forces and potential
312        for atom in &mut self.atoms {
313            atom.force = [0.0; 3];
314        }
315        let mut u_total = 0.0f64;
316
317        let eps = self.config.lj_epsilon;
318        let sig = self.config.lj_sigma;
319        let rc = self.config.cutoff;
320        let rc2 = rc * rc;
321        let box_size = self.config.box_size;
322
323        // Compute potential at cut-off for shift
324        let sig_rc2 = (sig / rc).powi(2);
325        let sig_rc6 = sig_rc2 * sig_rc2 * sig_rc2;
326        let u_shift = 4.0 * eps * sig_rc6 * (sig_rc6 - 1.0);
327
328        // Collect positions to avoid double-borrow
329        let positions: Vec<[f64; 3]> = self.atoms.iter().map(|a| a.position).collect();
330
331        for i in 0..n {
332            for j in (i + 1)..n {
333                let mut dr = [0.0f64; 3];
334                for k in 0..3 {
335                    let mut d = positions[j][k] - positions[i][k];
336                    // Minimum image convention
337                    if d > 0.5 * box_size {
338                        d -= box_size;
339                    } else if d < -0.5 * box_size {
340                        d += box_size;
341                    }
342                    dr[k] = d;
343                }
344                let r2 = dr[0] * dr[0] + dr[1] * dr[1] + dr[2] * dr[2];
345                if r2 >= rc2 || r2 < 1e-20 {
346                    continue;
347                }
348                let sig2_r2 = (sig * sig) / r2;
349                let sig6_r6 = sig2_r2 * sig2_r2 * sig2_r2;
350                let sig12_r12 = sig6_r6 * sig6_r6;
351                // Force magnitude: -dU/dr * (1/r)
352                let f_mag = 24.0 * eps / r2 * (2.0 * sig12_r12 - sig6_r6);
353                for k in 0..3 {
354                    let fk = f_mag * dr[k];
355                    self.atoms[i].force[k] -= fk;
356                    self.atoms[j].force[k] += fk;
357                }
358                // Truncated-shifted potential
359                let u_pair = 4.0 * eps * sig6_r6 * (sig6_r6 - 1.0) - u_shift;
360                u_total += u_pair;
361            }
362        }
363        self.potential_energy = u_total;
364    }
365}
366
/// Wrap coordinate `x` into `[0, box_size)` using periodic boundary conditions.
///
/// Returns `x` unchanged when `x` is not finite, or when `box_size` is not a
/// positive, finite length (there is no period to apply).
///
/// `x - L * floor(x / L)` can round to exactly `L` (e.g. for a tiny negative
/// `x` such as `-1e-17`), or to a tiny negative value. Either result lies
/// outside the documented half-open interval. Any such result is snapped to
/// `0.0`, which is within rounding error of the exact answer modulo `L`.
fn wrap_pbc(x: f64, box_size: f64) -> f64 {
    if !(box_size > 0.0 && box_size.is_finite()) || !x.is_finite() {
        return x;
    }
    let wrapped = x - box_size * (x / box_size).floor();
    if (0.0..box_size).contains(&wrapped) {
        wrapped
    } else {
        0.0
    }
}
374
375// ---------------------------------------------------------------------------
376// Tests
377// ---------------------------------------------------------------------------
378
#[cfg(test)]
mod tests {
    use super::*;

    fn default_sim() -> PyMdSimulation {
        PyMdSimulation::new(PyMdConfig::default())
    }

    #[test]
    fn test_md_creation_empty() {
        let sim = default_sim();
        assert_eq!(sim.atom_count(), 0);
        assert!((sim.time()).abs() < 1e-15);
        assert_eq!(sim.step_count(), 0);
    }

    #[test]
    fn test_md_add_atom() {
        let mut sim = default_sim();
        let idx = sim.add_atom([1.0, 2.0, 3.0], 0);
        assert_eq!(idx, 0);
        assert_eq!(sim.atom_count(), 1);
        let pos = sim.position(0).unwrap();
        assert!((pos[0] - 1.0).abs() < 1e-12);
        assert!((pos[1] - 2.0).abs() < 1e-12);
        assert!((pos[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_md_add_multiple_atoms() {
        let mut sim = default_sim();
        sim.add_atom([0.0, 0.0, 0.0], 0);
        sim.add_atom([1.0, 0.0, 0.0], 1);
        sim.add_atom([2.0, 0.0, 0.0], 0);
        assert_eq!(sim.atom_count(), 3);
    }

    #[test]
    fn test_md_set_velocity() {
        let mut sim = default_sim();
        sim.add_atom([0.0; 3], 0);
        sim.set_velocity(0, [1.0, 2.0, 3.0]);
        let vel = sim.velocity(0).unwrap();
        assert!((vel[0] - 1.0).abs() < 1e-12);
        assert!((vel[1] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_md_kinetic_energy_zero_at_rest() {
        let mut sim = default_sim();
        sim.add_atom([0.0; 3], 0);
        assert!((sim.kinetic_energy()).abs() < 1e-15);
    }

    #[test]
    fn test_md_kinetic_energy_nonzero_with_velocity() {
        let mut sim = default_sim();
        sim.add_atom([0.0; 3], 0);
        sim.set_velocity(0, [1.0, 0.0, 0.0]);
        // KE = 0.5 * 1.0 * 1.0^2 = 0.5
        assert!((sim.kinetic_energy() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_md_temperature_zero_at_rest() {
        let mut sim = default_sim();
        sim.add_atom([0.0; 3], 0);
        assert!((sim.temperature()).abs() < 1e-12);
    }

    #[test]
    fn test_md_step_advances_time() {
        let mut sim = default_sim();
        sim.add_atom([5.0, 5.0, 5.0], 0);
        sim.step(0.01);
        assert!((sim.time() - 0.01).abs() < 1e-15);
        assert_eq!(sim.step_count(), 1);
    }

    #[test]
    fn test_md_step_empty_no_panic() {
        let mut sim = default_sim();
        sim.step(0.01);
        assert!((sim.time() - 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_md_pbc_wrap() {
        let box_size = 10.0;
        let x = wrap_pbc(-0.5, box_size);
        assert!(x >= 0.0 && x < box_size, "wrapped value = {}", x);
        let x2 = wrap_pbc(10.5, box_size);
        assert!(x2 >= 0.0 && x2 < box_size, "wrapped value = {}", x2);
    }

    #[test]
    fn test_md_thermostat_rescales_temperature() {
        let cfg = PyMdConfig {
            target_temperature: Some(1.2),
            thermostat_tau: 0.1,
            ..PyMdConfig::default()
        };
        let mut sim = PyMdSimulation::new(cfg);
        // Add atoms with velocity → high initial temperature
        sim.add_atom([2.0, 2.0, 2.0], 0);
        sim.add_atom([8.0, 8.0, 8.0], 0);
        sim.set_velocity(0, [5.0, 0.0, 0.0]);
        sim.set_velocity(1, [-5.0, 0.0, 0.0]);
        // After step with thermostat, temperature should converge towards 1.2
        sim.step(0.001);
        let t = sim.temperature();
        assert!((t - 1.2).abs() < 0.1, "temp after rescale = {}", t);
    }

    #[test]
    fn test_md_thermostat_toggle() {
        let mut sim = default_sim();
        sim.set_thermostat(false);
        assert!(!sim.thermostat_active());
        sim.set_thermostat(true);
        assert!(sim.thermostat_active());
    }

    #[test]
    fn test_md_run_multi_step() {
        let mut sim = default_sim();
        sim.add_atom([5.0, 5.0, 5.0], 0);
        sim.run(0.001, 10);
        assert_eq!(sim.step_count(), 10);
        assert!((sim.time() - 0.01).abs() < 1e-12);
    }

    #[test]
    fn test_md_all_positions_length() {
        let mut sim = default_sim();
        sim.add_atom([1.0, 0.0, 0.0], 0);
        sim.add_atom([2.0, 0.0, 0.0], 0);
        assert_eq!(sim.all_positions().len(), 6);
    }

    #[test]
    fn test_md_all_velocities_length() {
        let mut sim = default_sim();
        sim.add_atom([0.0; 3], 0);
        sim.add_atom([1.0, 0.0, 0.0], 0);
        assert_eq!(sim.all_velocities().len(), 6);
    }

    #[test]
    fn test_md_lj_repulsion_separates_overlapping_atoms() {
        let cfg = PyMdConfig {
            target_temperature: None,
            ..PyMdConfig::default()
        };
        let mut sim = PyMdSimulation::new(cfg);
        // r = 0.95σ lies inside the repulsive core (r < 2^(1/6)σ). It is also
        // close enough to the minimum that a few small steps stay well-behaved.
        sim.add_atom([5.0, 5.0, 5.0], 0);
        sim.add_atom([5.95, 5.0, 5.0], 0);
        let sep0 = sim.position(1).unwrap()[0] - sim.position(0).unwrap()[0];
        for _ in 0..5 {
            sim.step(0.001);
        }
        let sep1 = sim.position(1).unwrap()[0] - sim.position(0).unwrap()[0];
        assert!(
            sep1 > sep0,
            "LJ repulsion should push atoms apart: {sep0} -> {sep1}"
        );
        assert!(
            sim.potential_energy() > 0.0,
            "a pair inside the repulsive core has positive energy"
        );
        assert!(sim.total_energy().is_finite());
    }

    #[test]
    fn test_md_config_argon_defaults() {
        let cfg = PyMdConfig::argon_reduced();
        assert!((cfg.lj_sigma - 1.0).abs() < 1e-12);
        assert!((cfg.lj_epsilon - 1.0).abs() < 1e-12);
        assert!((cfg.cutoff - 2.5).abs() < 1e-12);
    }

    #[test]
    fn test_md_set_target_temperature() {
        let mut sim = default_sim();
        sim.set_target_temperature(2.0);
        assert!(sim.config.target_temperature.is_some());
        assert!((sim.config.target_temperature.unwrap() - 2.0).abs() < 1e-12);
    }
}
570
571// ---------------------------------------------------------------------------
572// AtomSet: typed collection with per-element properties
573// ---------------------------------------------------------------------------
574
/// Atom type descriptor: element name, mass, charge and LJ parameters.
///
/// NOTE(review): the presets mix unit conventions. `argon()` uses reduced
/// units (m = ε = σ = 1), while the ion presets use amu-like masses
/// (22.99, 35.45) and Å-like sigmas. Confirm the intended units before
/// combining them in one system.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct AtomTypeDesc {
    /// Element symbol or name (e.g. "Ar", "Na+").
    pub name: String,
    /// Mass in reduced units (or amu).
    pub mass: f64,
    /// Partial charge in reduced units (or elementary charge).
    pub charge: f64,
    /// Lennard-Jones epsilon for this type.
    pub lj_epsilon: f64,
    /// Lennard-Jones sigma for this type.
    pub lj_sigma: f64,
}
590
591impl AtomTypeDesc {
592    /// Argon-like atom (ε=1, σ=1, neutral).
593    pub fn argon() -> Self {
594        Self {
595            name: "Ar".into(),
596            mass: 1.0,
597            charge: 0.0,
598            lj_epsilon: 1.0,
599            lj_sigma: 1.0,
600        }
601    }
602
603    /// Sodium ion (positive charge).
604    pub fn sodium_ion() -> Self {
605        Self {
606            name: "Na+".into(),
607            mass: 22.99,
608            charge: 1.0,
609            lj_epsilon: 0.35,
610            lj_sigma: 2.35,
611        }
612    }
613
614    /// Chloride ion (negative charge).
615    pub fn chloride_ion() -> Self {
616        Self {
617            name: "Cl-".into(),
618            mass: 35.45,
619            charge: -1.0,
620            lj_epsilon: 0.71,
621            lj_sigma: 4.40,
622        }
623    }
624}
625
/// A typed atom set that groups atoms by species.
///
/// Supports heterogeneous systems with multiple atom types, per-atom charges,
/// and retrieval of all positions / velocities by type.
///
/// Per-atom data lives in parallel vectors (`positions`, `velocities`,
/// `forces`, `type_indices`) indexed by atom number. `add_atom` keeps them
/// the same length. The fields are public, so code that edits them directly
/// must preserve that invariant; `kinetic_energy` indexes `velocities` for
/// every position index.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct AtomSet {
    /// Registered atom type descriptors.
    pub atom_types: Vec<AtomTypeDesc>,
    /// Per-atom position `[x, y, z]` (not wrapped into the box automatically).
    pub positions: Vec<[f64; 3]>,
    /// Per-atom velocity `[vx, vy, vz]`.
    pub velocities: Vec<[f64; 3]>,
    /// Per-atom force (accumulated each step) `[fx, fy, fz]`.
    pub forces: Vec<[f64; 3]>,
    /// Per-atom type index (into `atom_types`). Out-of-range indices are
    /// tolerated: `mass` then falls back to 1.0 and `charge` to 0.0.
    pub type_indices: Vec<usize>,
    /// Simulation box size (cubic).
    pub box_size: f64,
}
646
647impl AtomSet {
648    /// Create an empty `AtomSet` with the given box size.
649    pub fn new(box_size: f64) -> Self {
650        Self {
651            atom_types: Vec::new(),
652            positions: Vec::new(),
653            velocities: Vec::new(),
654            forces: Vec::new(),
655            type_indices: Vec::new(),
656            box_size,
657        }
658    }
659
660    /// Register an atom type and return its index.
661    pub fn register_type(&mut self, desc: AtomTypeDesc) -> usize {
662        let idx = self.atom_types.len();
663        self.atom_types.push(desc);
664        idx
665    }
666
667    /// Add an atom at `position` with the given type index. Returns atom index.
668    pub fn add_atom(&mut self, position: [f64; 3], type_idx: usize) -> usize {
669        let idx = self.positions.len();
670        self.positions.push(position);
671        self.velocities.push([0.0; 3]);
672        self.forces.push([0.0; 3]);
673        self.type_indices.push(type_idx);
674        idx
675    }
676
677    /// Number of atoms.
678    pub fn len(&self) -> usize {
679        self.positions.len()
680    }
681
682    /// Whether the atom set is empty.
683    pub fn is_empty(&self) -> bool {
684        self.positions.is_empty()
685    }
686
687    /// Mass of atom `i`.
688    pub fn mass(&self, i: usize) -> f64 {
689        self.type_indices
690            .get(i)
691            .and_then(|&ti| self.atom_types.get(ti))
692            .map(|t| t.mass)
693            .unwrap_or(1.0)
694    }
695
696    /// Charge of atom `i`.
697    pub fn charge(&self, i: usize) -> f64 {
698        self.type_indices
699            .get(i)
700            .and_then(|&ti| self.atom_types.get(ti))
701            .map(|t| t.charge)
702            .unwrap_or(0.0)
703    }
704
705    /// Return all positions of atoms with type index `type_idx`.
706    pub fn positions_of_type(&self, type_idx: usize) -> Vec<[f64; 3]> {
707        self.positions
708            .iter()
709            .zip(self.type_indices.iter())
710            .filter(|(_, ti)| **ti == type_idx)
711            .map(|(&pos, _)| pos)
712            .collect()
713    }
714
715    /// Net charge of the system (sum of all partial charges).
716    pub fn net_charge(&self) -> f64 {
717        (0..self.len()).map(|i| self.charge(i)).sum()
718    }
719
720    /// Total kinetic energy.
721    pub fn kinetic_energy(&self) -> f64 {
722        (0..self.len())
723            .map(|i| {
724                let m = self.mass(i);
725                let v = self.velocities[i];
726                0.5 * m * (v[0] * v[0] + v[1] * v[1] + v[2] * v[2])
727            })
728            .sum()
729    }
730
731    /// Temperature from equipartition (reduced units k_B = 1).
732    pub fn temperature(&self) -> f64 {
733        let n = self.len();
734        if n == 0 {
735            return 0.0;
736        }
737        2.0 * self.kinetic_energy() / (3.0 * n as f64)
738    }
739}
740
741// ---------------------------------------------------------------------------
742// Ewald summation (real-space component)
743// ---------------------------------------------------------------------------
744
745/// Compute the real-space component of the Ewald sum for electrostatics.
746///
747/// Uses the complementary error function (erfc) for the short-range part.
748/// Returns the electrostatic potential energy in reduced units.
749///
750/// # Arguments
751/// * `set` - the atom set with charges and positions
752/// * `alpha` - Ewald convergence parameter (larger α → more work in reciprocal space)
753/// * `r_cut` - real-space cutoff radius
754#[allow(dead_code)]
755pub fn ewald_real_space_energy(set: &AtomSet, alpha: f64, r_cut: f64) -> f64 {
756    let n = set.len();
757    let box_size = set.box_size;
758    let rc2 = r_cut * r_cut;
759    let mut energy = 0.0f64;
760
761    for i in 0..n {
762        let qi = set.charge(i);
763        if qi == 0.0 {
764            continue;
765        }
766        for j in (i + 1)..n {
767            let qj = set.charge(j);
768            if qj == 0.0 {
769                continue;
770            }
771            // Minimum image
772            let mut dr = [0.0f64; 3];
773            for k in 0..3 {
774                let mut d = set.positions[j][k] - set.positions[i][k];
775                if d > 0.5 * box_size {
776                    d -= box_size;
777                } else if d < -0.5 * box_size {
778                    d += box_size;
779                }
780                dr[k] = d;
781            }
782            let r2 = dr[0] * dr[0] + dr[1] * dr[1] + dr[2] * dr[2];
783            if r2 >= rc2 || r2 < 1e-20 {
784                continue;
785            }
786            let r = r2.sqrt();
787            // erfc approximation: erfc(x) ≈ 1 - erf(x)
788            let erfc_val = erfc_approx(alpha * r);
789            energy += qi * qj * erfc_val / r;
790        }
791    }
792    energy
793}
794
/// Complementary error function via Abramowitz & Stegun formula 7.1.26.
///
/// For `x >= 0`:
/// `erfc(x) ≈ t·(a1 + t·(a2 + t·(a3 + t·(a4 + t·a5))))·exp(-x²)`,
/// with `t = 1 / (1 + p·x)`. The published absolute error bound is 1.5e-7.
/// Negative arguments use the reflection `erfc(-x) = 2 - erfc(x)`.
fn erfc_approx(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const A: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];

    let z = if x < 0.0 { -x } else { x };
    let t = 1.0 / (1.0 + P * z);
    // Horner's scheme, starting from the highest-order coefficient.
    let nested = A.iter().rev().fold(0.0, |acc, &a| a + t * acc);
    let tail = t * nested * (-z * z).exp();
    if x < 0.0 { 2.0 - tail } else { tail }
}
806
807// ---------------------------------------------------------------------------
808// NVT Ensemble (Nosé-Hoover thermostat)
809// ---------------------------------------------------------------------------
810
/// Nosé-Hoover thermostat state for the NVT ensemble.
///
/// The NH thermostat couples a fictitious degree of freedom (coordinate `eta`,
/// rate `xi`) to the particles' kinetic energy. This is a single thermostat
/// variable, not a Nosé-Hoover chain. `half_step_xi` advances `xi` from the
/// kinetic-energy imbalance and returns the velocity scale factor
/// `exp(-xi * dt / 2)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct NoseHooverThermostat {
    /// Target temperature T* (reduced units, k_B = 1).
    pub target_temperature: f64,
    /// Thermostat mass Q. `new` sets `Q = N_f * T * τ²` (1.0 when N_f = 0),
    /// floored at 1e-10.
    pub thermostat_mass: f64,
    /// Thermostat rate / friction coefficient ξ; `full_step_eta` integrates it into `eta`.
    pub xi: f64,
    /// Thermostat "position" η, advanced by `eta += xi * dt`. Not read by the
    /// integration code shown here; presumably kept for diagnostics.
    pub eta: f64,
    /// Number of degrees of freedom (3N for monatomic system).
    pub n_dof: usize,
}
830
831impl NoseHooverThermostat {
832    /// Create a new Nosé-Hoover thermostat.
833    ///
834    /// `n_atoms` is the number of atoms; `tau` is the relaxation time.
835    pub fn new(target_temperature: f64, n_atoms: usize, tau: f64) -> Self {
836        let n_dof = 3 * n_atoms;
837        // Q = n_dof * kB * T * tau^2 (reduced units: kB=1); use 1.0 when n_dof=0 to avoid blow-up
838        let thermostat_mass = if n_dof > 0 {
839            (n_dof as f64) * target_temperature * tau * tau
840        } else {
841            1.0
842        };
843        Self {
844            target_temperature,
845            thermostat_mass: thermostat_mass.max(1e-10),
846            xi: 0.0,
847            eta: 0.0,
848            n_dof,
849        }
850    }
851
852    /// Half-step update of thermostat momentum ξ from the kinetic energy.
853    ///
854    /// Returns the scaling factor to apply to velocities.
855    pub fn half_step_xi(&mut self, kinetic_energy: f64, dt: f64) -> f64 {
856        let g = self.n_dof as f64;
857        let t = self.target_temperature;
858        // dξ/dt = (2*KE - g*kB*T) / Q
859        let dxi_dt = (2.0 * kinetic_energy - g * t) / self.thermostat_mass;
860        self.xi += 0.5 * dxi_dt * dt;
861        // Velocity scaling factor: exp(-ξ * dt/2)
862        (-self.xi * 0.5 * dt).exp()
863    }
864
865    /// Full-step update of η (Nosé-Hoover extended coordinate).
866    pub fn full_step_eta(&mut self, dt: f64) {
867        self.eta += self.xi * dt;
868    }
869
870    /// Whether the thermostat is warm (xi is non-zero).
871    pub fn is_active(&self) -> bool {
872        self.xi.abs() > 1e-15
873    }
874}
875
/// NVT ensemble simulation using the Nosé-Hoover thermostat.
///
/// Wraps a `PyMdSimulation` and applies Nosé-Hoover velocity scaling each step.
/// Both fields are public, so callers may inspect or adjust the inner
/// simulation and the thermostat state directly.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct PyNvtSimulation {
    /// Underlying MD simulation (integrator, atoms, LJ forces).
    pub md: PyMdSimulation,
    /// Nosé-Hoover thermostat state.
    pub thermostat: NoseHooverThermostat,
}
887
888impl PyNvtSimulation {
889    /// Create an NVT simulation from an MD config with given thermostat time τ.
890    pub fn new(config: PyMdConfig, tau: f64) -> Self {
891        let t_target = config.target_temperature.unwrap_or(1.2);
892        let md = PyMdSimulation::new(config);
893        let thermostat = NoseHooverThermostat::new(t_target, 0, tau);
894        Self { md, thermostat }
895    }
896
897    /// Add an atom. Returns atom index.
898    pub fn add_atom(&mut self, position: [f64; 3], atom_type: u32) -> usize {
899        let idx = self.md.add_atom(position, atom_type);
900        // Update thermostat DOF count
901        self.thermostat.n_dof = 3 * self.md.atom_count();
902        idx
903    }
904
905    /// Advance by one step using Nosé-Hoover velocity rescaling.
906    pub fn step(&mut self, dt: f64) {
907        let ke = self.md.kinetic_energy();
908        let scale = self.thermostat.half_step_xi(ke, dt);
909        // Rescale velocities
910        for atom in &mut self.md.atoms {
911            for k in 0..3 {
912                atom.velocity[k] *= scale;
913            }
914        }
915        self.md.step(dt);
916        // Second half-step
917        let ke2 = self.md.kinetic_energy();
918        let _scale2 = self.thermostat.half_step_xi(ke2, dt);
919        self.thermostat.full_step_eta(dt);
920    }
921
922    /// Temperature from instantaneous kinetic energy.
923    pub fn temperature(&self) -> f64 {
924        self.md.temperature()
925    }
926
927    /// Step count.
928    pub fn step_count(&self) -> u64 {
929        self.md.step_count()
930    }
931}
932
933// ---------------------------------------------------------------------------
934// Additional tests for new MD API
935// ---------------------------------------------------------------------------
936
937#[cfg(test)]
938mod nvt_tests {
939
940    use crate::PyMdConfig;
941    use crate::md_api::AtomSet;
942    use crate::md_api::AtomTypeDesc;
943    use crate::md_api::NoseHooverThermostat;
944    use crate::md_api::PyNvtSimulation;
945    use crate::md_api::erfc_approx;
946    use crate::md_api::ewald_real_space_energy;
947
948    #[test]
949    fn test_atom_type_desc_argon() {
950        let at = AtomTypeDesc::argon();
951        assert_eq!(at.name, "Ar");
952        assert!((at.mass - 1.0).abs() < 1e-12);
953        assert!((at.charge).abs() < 1e-12);
954    }
955
956    #[test]
957    fn test_atom_type_desc_ions() {
958        let na = AtomTypeDesc::sodium_ion();
959        let cl = AtomTypeDesc::chloride_ion();
960        assert!((na.charge - 1.0).abs() < 1e-12);
961        assert!((cl.charge + 1.0).abs() < 1e-12);
962    }
963
964    #[test]
965    fn test_atom_set_creation() {
966        let mut set = AtomSet::new(10.0);
967        let ti = set.register_type(AtomTypeDesc::argon());
968        set.add_atom([1.0, 2.0, 3.0], ti);
969        set.add_atom([4.0, 5.0, 6.0], ti);
970        assert_eq!(set.len(), 2);
971        assert!(!set.is_empty());
972    }
973
974    #[test]
975    fn test_atom_set_mass() {
976        let mut set = AtomSet::new(10.0);
977        let ti = set.register_type(AtomTypeDesc::argon());
978        set.add_atom([0.0; 3], ti);
979        assert!((set.mass(0) - 1.0).abs() < 1e-12);
980    }
981
982    #[test]
983    fn test_atom_set_charge() {
984        let mut set = AtomSet::new(10.0);
985        let ti_na = set.register_type(AtomTypeDesc::sodium_ion());
986        let ti_cl = set.register_type(AtomTypeDesc::chloride_ion());
987        set.add_atom([0.0; 3], ti_na);
988        set.add_atom([5.0, 0.0, 0.0], ti_cl);
989        assert!((set.charge(0) - 1.0).abs() < 1e-12);
990        assert!((set.charge(1) + 1.0).abs() < 1e-12);
991    }
992
993    #[test]
994    fn test_atom_set_net_charge_neutral() {
995        let mut set = AtomSet::new(10.0);
996        let ti_na = set.register_type(AtomTypeDesc::sodium_ion());
997        let ti_cl = set.register_type(AtomTypeDesc::chloride_ion());
998        set.add_atom([1.0, 0.0, 0.0], ti_na);
999        set.add_atom([5.0, 0.0, 0.0], ti_cl);
1000        let q = set.net_charge();
1001        assert!(q.abs() < 1e-10, "net charge should be ~0: {}", q);
1002    }
1003
1004    #[test]
1005    fn test_atom_set_positions_of_type() {
1006        let mut set = AtomSet::new(10.0);
1007        let ti_a = set.register_type(AtomTypeDesc::argon());
1008        let ti_b = set.register_type(AtomTypeDesc::sodium_ion());
1009        set.add_atom([1.0, 0.0, 0.0], ti_a);
1010        set.add_atom([2.0, 0.0, 0.0], ti_a);
1011        set.add_atom([3.0, 0.0, 0.0], ti_b);
1012        let pos_a = set.positions_of_type(ti_a);
1013        assert_eq!(pos_a.len(), 2);
1014        let pos_b = set.positions_of_type(ti_b);
1015        assert_eq!(pos_b.len(), 1);
1016    }
1017
1018    #[test]
1019    fn test_atom_set_kinetic_energy_zero_at_rest() {
1020        let mut set = AtomSet::new(10.0);
1021        let ti = set.register_type(AtomTypeDesc::argon());
1022        set.add_atom([0.0; 3], ti);
1023        assert!((set.kinetic_energy()).abs() < 1e-15);
1024    }
1025
1026    #[test]
1027    fn test_atom_set_temperature_zero_at_rest() {
1028        let mut set = AtomSet::new(10.0);
1029        let ti = set.register_type(AtomTypeDesc::argon());
1030        set.add_atom([0.0; 3], ti);
1031        assert!((set.temperature()).abs() < 1e-12);
1032    }
1033
1034    #[test]
1035    fn test_ewald_real_space_neutral_system() {
1036        let mut set = AtomSet::new(20.0);
1037        let ti_na = set.register_type(AtomTypeDesc::sodium_ion());
1038        let ti_cl = set.register_type(AtomTypeDesc::chloride_ion());
1039        set.add_atom([5.0, 5.0, 5.0], ti_na);
1040        set.add_atom([5.5, 5.0, 5.0], ti_cl);
1041        let e = ewald_real_space_energy(&set, 0.5, 3.0);
1042        // Na+ and Cl- attract each other → negative energy
1043        assert!(
1044            e < 0.0,
1045            "Ewald real energy should be negative for Na+/Cl- pair: {}",
1046            e
1047        );
1048    }
1049
1050    #[test]
1051    fn test_ewald_no_energy_no_charges() {
1052        let mut set = AtomSet::new(10.0);
1053        let ti = set.register_type(AtomTypeDesc::argon()); // zero charge
1054        set.add_atom([1.0, 0.0, 0.0], ti);
1055        set.add_atom([2.0, 0.0, 0.0], ti);
1056        let e = ewald_real_space_energy(&set, 0.5, 3.0);
1057        assert!(e.abs() < 1e-15, "zero charges → zero Ewald energy");
1058    }
1059
1060    #[test]
1061    fn test_ewald_erfc_at_zero() {
1062        // erfc(0) = 1.0
1063        let v = erfc_approx(0.0);
1064        assert!((v - 1.0).abs() < 1e-5, "erfc(0) ≈ 1.0, got {}", v);
1065    }
1066
1067    #[test]
1068    fn test_nose_hoover_creation() {
1069        let nh = NoseHooverThermostat::new(1.2, 10, 0.1);
1070        assert_eq!(nh.n_dof, 30);
1071        assert!((nh.target_temperature - 1.2).abs() < 1e-12);
1072        assert!(!nh.is_active());
1073    }
1074
1075    #[test]
1076    fn test_nose_hoover_half_step() {
1077        let mut nh = NoseHooverThermostat::new(1.0, 1, 0.1);
1078        // KE = 5.0, g*T = 3*1.0 = 3.0 → dξ/dt > 0
1079        let scale = nh.half_step_xi(5.0, 0.01);
1080        // scale < 1 because thermostat cools down (xi > 0)
1081        assert!(
1082            scale < 1.0,
1083            "scale should be < 1 for KE > target: {}",
1084            scale
1085        );
1086    }
1087
1088    #[test]
1089    fn test_nose_hoover_full_step_eta() {
1090        let mut nh = NoseHooverThermostat::new(1.0, 1, 0.1);
1091        nh.xi = 2.0;
1092        nh.full_step_eta(0.01);
1093        assert!(
1094            (nh.eta - 0.02).abs() < 1e-12,
1095            "eta should advance: {}",
1096            nh.eta
1097        );
1098    }
1099
1100    #[test]
1101    fn test_nvt_sim_creation() {
1102        let sim = PyNvtSimulation::new(PyMdConfig::argon_reduced(), 0.1);
1103        assert_eq!(sim.step_count(), 0);
1104    }
1105
1106    #[test]
1107    fn test_nvt_sim_add_atom() {
1108        let mut sim = PyNvtSimulation::new(PyMdConfig::argon_reduced(), 0.1);
1109        sim.add_atom([5.0, 5.0, 5.0], 0);
1110        assert_eq!(sim.md.atom_count(), 1);
1111        assert_eq!(sim.thermostat.n_dof, 3);
1112    }
1113
1114    #[test]
1115    fn test_nvt_sim_step() {
1116        let mut sim = PyNvtSimulation::new(PyMdConfig::argon_reduced(), 0.1);
1117        sim.add_atom([5.0, 5.0, 5.0], 0);
1118        sim.add_atom([6.0, 5.0, 5.0], 0);
1119        sim.md.set_velocity(0, [1.0, 0.0, 0.0]);
1120        sim.md.set_velocity(1, [-1.0, 0.0, 0.0]);
1121        sim.step(0.001);
1122        assert_eq!(sim.step_count(), 1);
1123    }
1124
1125    #[test]
1126    fn test_nvt_temperature_nonzero() {
1127        let mut sim = PyNvtSimulation::new(PyMdConfig::argon_reduced(), 0.1);
1128        sim.add_atom([5.0, 5.0, 5.0], 0);
1129        sim.md.set_velocity(0, [1.0, 1.0, 1.0]);
1130        assert!(sim.temperature() > 0.0);
1131    }
1132}