Skip to main content

oxiphysics_gpu/kernels/md_force/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5#![allow(clippy::ptr_arg)]
6use super::types::{
7    ForceBuffer, HarmonicAngle, HarmonicBond, LjPotential, NeighborList, VirialStressTensorKernel,
8    VirialTensor,
9};
10use crate::compute::ComputeKernel;
11
12#[cfg(test)]
13use super::types::*;
14
15/// Compute the Lennard-Jones potential energy and scalar force magnitude at
16/// separation `r`.
17///
18/// Returns `(energy, force_magnitude)` where:
19/// * `energy = 4·ε·[(σ/r)^12 − (σ/r)^6]`
20/// * `force_magnitude = 24·ε·[2(σ/r)^12 − (σ/r)^6] / r` (positive = repulsive)
21///
22/// The caller is responsible for applying the cutoff.
23pub fn compute_lj_force(r: f64, lj: &LjPotential) -> (f64, f64) {
24    if r < 1e-30 {
25        return (f64::INFINITY, f64::INFINITY);
26    }
27    let sr = lj.sigma / r;
28    let sr6 = sr.powi(6);
29    let sr12 = sr6 * sr6;
30    let energy = 4.0 * lj.epsilon * (sr12 - sr6);
31    let force_mag = 24.0 * lj.epsilon * (2.0 * sr12 - sr6) / r;
32    (energy, force_mag)
33}
34/// Compute shifted LJ energy at distance `r` with cutoff `rc`.
35///
36/// V_shifted(r) = V(r) - V(rc) for r < rc, 0 otherwise.
37pub fn compute_lj_shifted_energy(r: f64, lj: &LjPotential, cutoff: f64) -> f64 {
38    if r >= cutoff {
39        return 0.0;
40    }
41    let (e_r, _) = compute_lj_force(r, lj);
42    let (e_c, _) = compute_lj_force(cutoff, lj);
43    e_r - e_c
44}
/// Coulomb interaction between charges `qi` and `qj` at separation `r`.
///
/// Returns `(energy, force_magnitude)` with
/// * `energy = k_e·qi·qj / r`
/// * `force_magnitude = k_e·qi·qj / r²` (`−dV/dr`: positive for like charges,
///   i.e. repulsive; negative for opposite charges).
///
/// `k_e` is the Coulomb constant in the caller's units. Separations below
/// `1e-30` return `(f64::INFINITY, f64::INFINITY)`. No cutoff is applied.
pub fn compute_coulomb_force(r: f64, qi: f64, qj: f64, k_e: f64) -> (f64, f64) {
    if r < 1e-30 {
        return (f64::INFINITY, f64::INFINITY);
    }
    let coupling = k_e * qi * qj;
    (coupling / r, coupling / (r * r))
}
56/// Compute LJ forces using a neighbor list.
57#[allow(clippy::needless_range_loop)]
58pub fn compute_lj_forces_neighborlist(
59    positions: &[[f64; 3]],
60    lj: &LjPotential,
61    nlist: &NeighborList,
62    buffer: &mut ForceBuffer,
63) {
64    let cutoff2 = nlist.cutoff * nlist.cutoff;
65    buffer.clear();
66    let n = positions.len();
67    for i in 0..n {
68        for &j in &nlist.neighbors[i] {
69            if j <= i {
70                continue;
71            }
72            let dx = [
73                positions[i][0] - positions[j][0],
74                positions[i][1] - positions[j][1],
75                positions[i][2] - positions[j][2],
76            ];
77            let r2 = dx[0] * dx[0] + dx[1] * dx[1] + dx[2] * dx[2];
78            if r2 >= cutoff2 || r2 < 1e-30 {
79                continue;
80            }
81            let r2_inv = 1.0 / r2;
82            let sr2 = lj.sigma * lj.sigma * r2_inv;
83            let sr6 = sr2 * sr2 * sr2;
84            let sr12 = sr6 * sr6;
85            let energy = 4.0 * lj.epsilon * (sr12 - sr6);
86            let f_mag = 24.0 * lj.epsilon * (2.0 * sr12 - sr6) * r2_inv;
87            let f_ij = [f_mag * dx[0], f_mag * dx[1], f_mag * dx[2]];
88            buffer.add_pair(i, j, f_ij, energy, dx);
89        }
90    }
91}
92/// Compute Coulomb forces using a neighbor list.
93#[allow(clippy::needless_range_loop)]
94pub fn compute_coulomb_forces_neighborlist(
95    positions: &[[f64; 3]],
96    charges: &[f64],
97    k_e: f64,
98    nlist: &NeighborList,
99    buffer: &mut ForceBuffer,
100) {
101    let cutoff2 = nlist.cutoff * nlist.cutoff;
102    let n = positions.len();
103    for i in 0..n {
104        for &j in &nlist.neighbors[i] {
105            if j <= i {
106                continue;
107            }
108            let dx = [
109                positions[i][0] - positions[j][0],
110                positions[i][1] - positions[j][1],
111                positions[i][2] - positions[j][2],
112            ];
113            let r2 = dx[0] * dx[0] + dx[1] * dx[1] + dx[2] * dx[2];
114            if r2 >= cutoff2 || r2 < 1e-30 {
115                continue;
116            }
117            let r = r2.sqrt();
118            let qi = charges[i];
119            let qj = charges[j];
120            let energy = k_e * qi * qj / r;
121            let f_mag = k_e * qi * qj / (r2 * r);
122            let f_ij = [f_mag * dx[0], f_mag * dx[1], f_mag * dx[2]];
123            buffer.add_pair(i, j, f_ij, energy, dx);
124        }
125    }
126}
127/// Compute Lennard-Jones forces for all particle pairs within `cutoff`.
128///
129/// Returns a `Vec<[f64;3]>` of forces, one per particle.
130/// Interactions beyond `cutoff` are ignored.
131#[allow(clippy::needless_range_loop)]
132pub fn compute_all_lj_forces(
133    positions: &[[f64; 3]],
134    _masses: &[f64],
135    lj: &LjPotential,
136    cutoff: f64,
137) -> Vec<[f64; 3]> {
138    let n = positions.len();
139    let cutoff2 = cutoff * cutoff;
140    let mut forces = vec![[0.0f64; 3]; n];
141    for i in 0..n {
142        for j in (i + 1)..n {
143            let dx = positions[i][0] - positions[j][0];
144            let dy = positions[i][1] - positions[j][1];
145            let dz = positions[i][2] - positions[j][2];
146            let r2 = dx * dx + dy * dy + dz * dz;
147            if r2 >= cutoff2 || r2 < 1e-30 {
148                continue;
149            }
150            let sr = lj.sigma / r2.sqrt();
151            let sr6 = sr.powi(6);
152            let sr12 = sr6 * sr6;
153            let f_mag = 24.0 * lj.epsilon * (2.0 * sr12 - sr6) / r2;
154            forces[i][0] += f_mag * dx;
155            forces[i][1] += f_mag * dy;
156            forces[i][2] += f_mag * dz;
157            forces[j][0] -= f_mag * dx;
158            forces[j][1] -= f_mag * dy;
159            forces[j][2] -= f_mag * dz;
160        }
161    }
162    forces
163}
/// Direct O(N²) Coulomb forces over all particle pairs closer than `cutoff`.
///
/// For every pair `i < j` with `1e-30 ≤ r² < cutoff²`, the force on `i` is
/// `k_e·q_i·q_j / r³ · (x_i − x_j)` and the force on `j` is its negation, so
/// like charges repel. Pairs at or beyond the cutoff contribute nothing; no
/// shifting or smoothing is applied.
///
/// Returns one force vector per particle (same length as `positions`).
///
/// # Panics
/// Panics if an in-range pair refers to an index missing from `charges`.
#[allow(clippy::needless_range_loop)]
pub fn compute_all_coulomb_forces(
    positions: &[[f64; 3]],
    charges: &[f64],
    k_e: f64,
    cutoff: f64,
) -> Vec<[f64; 3]> {
    let count = positions.len();
    let rc2 = cutoff * cutoff;
    let mut out = vec![[0.0f64; 3]; count];
    for a in 0..count {
        for b in (a + 1)..count {
            let (pa, pb) = (positions[a], positions[b]);
            let d = [pa[0] - pb[0], pa[1] - pb[1], pa[2] - pb[2]];
            let r2 = d[0] * d[0] + d[1] * d[1] + d[2] * d[2];
            if r2 >= rc2 || r2 < 1e-30 {
                continue;
            }
            let scale = k_e * charges[a] * charges[b] / (r2 * r2.sqrt());
            for k in 0..3 {
                out[a][k] += scale * d[k];
                out[b][k] -= scale * d[k];
            }
        }
    }
    out
}
198/// Complementary error function approximation (Abramowitz & Stegun 7.1.26).
199pub(super) fn erfc_approx(x: f64) -> f64 {
200    if x < 0.0 {
201        return 2.0 - erfc_approx(-x);
202    }
203    let t = 1.0 / (1.0 + 0.3275911 * x);
204    let poly = t
205        * (0.254829592
206            + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
207    poly * (-x * x).exp()
208}
/// Ewald self-interaction correction, `E_self = −(α/√π)·Σ_i q_i²`.
///
/// `alpha` is the Ewald splitting parameter. No Coulomb constant is applied to
/// the result.
pub fn ewald_self_energy(charges: &[f64], alpha: f64) -> f64 {
    let prefactor = alpha / std::f64::consts::PI.sqrt();
    let sum_sq: f64 = charges.iter().map(|q| q * q).sum();
    -prefactor * sum_sq
}
/// Crude placeholder for the PPPM long-range (mesh) energy.
///
/// Returns the mean squared mesh charge, `Σ_m ρ_m² / (nx·ny·nz)`, summed over
/// every entry of `charge_mesh`. No FFT or Green's function is applied, so the
/// value is only a non-negative proxy, not a physical energy.
///
/// Returns `0.0` if any grid dimension is zero. The length of `charge_mesh` is
/// not checked against `nx·ny·nz`; all of its entries are summed.
pub fn pppm_mesh_energy_estimate(charge_mesh: &[f64], nx: usize, ny: usize, nz: usize) -> f64 {
    if nx == 0 || ny == 0 || nz == 0 {
        return 0.0;
    }
    let sum_sq: f64 = charge_mesh.iter().map(|&q| q * q).sum();
    // Form the cell count in f64: a usize product can overflow for huge
    // dimensions (panic in debug builds, silent wrap in release).
    let cells = (nx as f64) * (ny as f64) * (nz as f64);
    sum_sq / cells
}
227/// Compute the full virial stress tensor from positions using LJ potential.
228///
229/// Convenience wrapper around `VirialStressTensorKernel`.
230pub fn compute_virial_stress_tensor(
231    positions: &[[f64; 3]],
232    lj: &LjPotential,
233    cutoff: f64,
234) -> VirialTensor {
235    let n = positions.len();
236    let flat_pos: Vec<f64> = positions.iter().flat_map(|p| p.iter().copied()).collect();
237    let params = vec![lj.epsilon, lj.sigma, cutoff];
238    let mut outputs = vec![Vec::new()];
239    VirialStressTensorKernel.execute(&[&flat_pos, &params], &mut outputs, n);
240    if outputs[0].len() < 6 {
241        return VirialTensor::zero();
242    }
243    let mut c = [0.0f64; 6];
244    c.copy_from_slice(&outputs[0][..6]);
245    VirialTensor { components: c }
246}
247/// Compute harmonic bond forces and accumulate into a force buffer.
248///
249/// For each bond `(i, j)` with spring constant `k` and rest length `r0`:
250/// `F_i = -k*(r - r0)*r̂_ij`,  `F_j = +k*(r - r0)*r̂_ij`
251///
252/// Returns `(forces, total_bond_energy)`.
253pub fn compute_bond_forces(positions: &[[f64; 3]], bonds: &[HarmonicBond]) -> (Vec<[f64; 3]>, f64) {
254    let n = positions.len();
255    let mut forces = vec![[0.0f64; 3]; n];
256    let mut total_energy = 0.0f64;
257    for bond in bonds {
258        let i = bond.atom_i;
259        let j = bond.atom_j;
260        if i >= n || j >= n {
261            continue;
262        }
263        let dx = positions[j][0] - positions[i][0];
264        let dy = positions[j][1] - positions[i][1];
265        let dz = positions[j][2] - positions[i][2];
266        let r = (dx * dx + dy * dy + dz * dz).sqrt();
267        if r < 1e-30 {
268            continue;
269        }
270        let delta = r - bond.r0;
271        let energy = 0.5 * bond.k * delta * delta;
272        total_energy += energy;
273        let mag = bond.k * delta / r;
274        forces[i][0] += mag * dx;
275        forces[i][1] += mag * dy;
276        forces[i][2] += mag * dz;
277        forces[j][0] -= mag * dx;
278        forces[j][1] -= mag * dy;
279        forces[j][2] -= mag * dz;
280    }
281    (forces, total_energy)
282}
283/// Compute harmonic angle forces (CPU mock).
284///
285/// The angle θ at vertex `j` (between vectors `r_ij` and `r_kj`) is:
286/// `cos θ = (r_ij · r_kj) / (|r_ij| |r_kj|)`
287///
288/// Forces follow from the gradient of the harmonic angle potential.
289///
290/// Returns `(forces, total_angle_energy)`.
291pub fn compute_angle_forces(
292    positions: &[[f64; 3]],
293    angles: &[HarmonicAngle],
294) -> (Vec<[f64; 3]>, f64) {
295    let n = positions.len();
296    let mut forces = vec![[0.0f64; 3]; n];
297    let mut total_energy = 0.0f64;
298    for angle in angles {
299        let i = angle.atom_i;
300        let j = angle.atom_j;
301        let k = angle.atom_k;
302        if i >= n || j >= n || k >= n {
303            continue;
304        }
305        let rji = [
306            positions[i][0] - positions[j][0],
307            positions[i][1] - positions[j][1],
308            positions[i][2] - positions[j][2],
309        ];
310        let rjk = [
311            positions[k][0] - positions[j][0],
312            positions[k][1] - positions[j][1],
313            positions[k][2] - positions[j][2],
314        ];
315        let len_ji = (rji[0] * rji[0] + rji[1] * rji[1] + rji[2] * rji[2]).sqrt();
316        let len_jk = (rjk[0] * rjk[0] + rjk[1] * rjk[1] + rjk[2] * rjk[2]).sqrt();
317        if len_ji < 1e-30 || len_jk < 1e-30 {
318            continue;
319        }
320        let cos_theta = (rji[0] * rjk[0] + rji[1] * rjk[1] + rji[2] * rjk[2]) / (len_ji * len_jk);
321        let cos_theta = cos_theta.clamp(-1.0, 1.0);
322        let theta = cos_theta.acos();
323        let delta = theta - angle.theta0;
324        total_energy += 0.5 * angle.k_theta * delta * delta;
325        let sin_theta = (1.0 - cos_theta * cos_theta).sqrt().max(1e-12);
326        let d_prefactor = -angle.k_theta * delta / sin_theta;
327        for dim in 0..3 {
328            let d_cos_d_ri =
329                rjk[dim] / (len_ji * len_jk) - cos_theta * rji[dim] / (len_ji * len_ji);
330            let d_cos_d_rk =
331                rji[dim] / (len_ji * len_jk) - cos_theta * rjk[dim] / (len_jk * len_jk);
332            let fi = d_prefactor * d_cos_d_ri;
333            let fk = d_prefactor * d_cos_d_rk;
334            forces[i][dim] += fi;
335            forces[k][dim] += fk;
336            forces[j][dim] -= fi + fk;
337        }
338    }
339    (forces, total_energy)
340}
/// Instantaneous kinetic temperature `T = 2·KE / (N_dof·k_B)`.
///
/// `2·KE = Σ m_i·|v_i|²` is summed over the pairs formed by zipping
/// `masses` with `velocities`; any extra entries in the longer slice are
/// ignored. `N_dof = 3·N` with `N = velocities.len()`. No degrees of freedom
/// are removed for centre-of-mass motion or constraints.
///
/// Returns `0.0` when there are no velocities or when `k_boltzmann < 1e-30`.
///
/// # Arguments
/// * `velocities`  - Per-particle velocity vectors.
/// * `masses`      - Per-particle masses.
/// * `k_boltzmann` - Boltzmann constant in simulation units.
pub fn kinetic_temperature(velocities: &[[f64; 3]], masses: &[f64], k_boltzmann: f64) -> f64 {
    let particle_count = velocities.len();
    if particle_count == 0 || k_boltzmann < 1e-30 {
        return 0.0;
    }
    let twice_ke: f64 = masses
        .iter()
        .zip(velocities)
        .map(|(&m, v)| m * (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]))
        .sum();
    let dof = (3 * particle_count) as f64;
    twice_ke / (dof * k_boltzmann)
}
363/// Rescale all velocities to match the target temperature.
364///
365/// Applies the velocity-rescaling thermostat:
366/// `v_i ← v_i * sqrt(T_target / T_current)`
367///
368/// Does nothing if the current temperature is below a floor value.
369pub fn temperature_scale(
370    velocities: &mut Vec<[f64; 3]>,
371    masses: &[f64],
372    t_target: f64,
373    k_boltzmann: f64,
374) {
375    let t_current = kinetic_temperature(velocities, masses, k_boltzmann);
376    if t_current < 1e-30 || t_target < 0.0 {
377        return;
378    }
379    let scale = (t_target / t_current).sqrt();
380    for v in velocities.iter_mut() {
381        v[0] *= scale;
382        v[1] *= scale;
383        v[2] *= scale;
384    }
385}
386#[cfg(test)]
387mod tests {
388    use super::*;
389    #[test]
390    fn test_md_lj_force_repulsive_at_short_range() {
391        let sigma = 1.0_f64;
392        let epsilon = 1.0_f64;
393        let cutoff = 5.0_f64;
394        let r = 0.8_f64 * sigma;
395        let positions = vec![0.0, 0.0, 0.0, r, 0.0, 0.0];
396        let params = vec![epsilon, sigma, cutoff];
397        let mut outputs = vec![Vec::new(), Vec::new()];
398        LennardJonesKernel.execute(&[&positions, &params], &mut outputs, 2);
399        let fx0 = outputs[0][0];
400        let fx1 = outputs[0][3];
401        assert!(
402            fx0 < 0.0,
403            "at r < r_min, force on atom 0 should be negative (repulsive), got {fx0}"
404        );
405        assert!(
406            fx1 > 0.0,
407            "at r < r_min, force on atom 1 should be positive (repulsive), got {fx1}"
408        );
409        assert!(
410            (fx0 + fx1).abs() < 1e-10,
411            "forces should sum to zero (Newton III), got {fx0} + {fx1} = {}",
412            fx0 + fx1
413        );
414    }
415    #[test]
416    fn lj_kernel_correct_force_known_separation() {
417        let sigma = 1.0;
418        let epsilon = 1.0;
419        let cutoff = 3.0;
420        let positions = vec![0.0, 0.0, 0.0, sigma, 0.0, 0.0];
421        let params = vec![epsilon, sigma, cutoff];
422        let mut outputs = vec![Vec::new(), Vec::new()];
423        LennardJonesKernel.execute(&[&positions, &params], &mut outputs, 2);
424        let fx0 = outputs[0][0];
425        assert!(
426            (fx0 - (-24.0)).abs() < 1e-10,
427            "expected fx0 ~ -24.0, got {fx0}"
428        );
429        let pe = outputs[1][0];
430        assert!(pe.abs() < 1e-10, "expected PE ~ 0, got {pe}");
431    }
432    #[test]
433    fn lj_kernel_force_zero_beyond_cutoff() {
434        let sigma = 1.0;
435        let epsilon = 1.0;
436        let cutoff = 2.5;
437        let positions = vec![0.0, 0.0, 0.0, 3.0, 0.0, 0.0];
438        let params = vec![epsilon, sigma, cutoff];
439        let mut outputs = vec![Vec::new(), Vec::new()];
440        LennardJonesKernel.execute(&[&positions, &params], &mut outputs, 2);
441        for &f in &outputs[0] {
442            assert!(f.abs() < 1e-15, "expected zero force, got {f}");
443        }
444        assert!(outputs[1][0].abs() < 1e-15);
445    }
446    #[test]
447    fn lj_minimum_at_r_min() {
448        let lj = LjPotential::new(1.0, 1.0);
449        let r_min = lj.r_min();
450        let (energy, force_mag) = compute_lj_force(r_min, &lj);
451        assert!(
452            (energy - (-lj.epsilon)).abs() < 1e-10,
453            "energy at r_min should be -epsilon={}, got {energy}",
454            -lj.epsilon
455        );
456        assert!(
457            force_mag.abs() < 1e-10,
458            "force at r_min should be 0, got {force_mag}"
459        );
460    }
461    #[test]
462    fn compute_all_lj_forces_newtons_third_law() {
463        let lj = LjPotential::new(1.0, 1.0);
464        let positions = vec![[0.0, 0.0, 0.0], [1.2, 0.0, 0.0], [0.6, 1.0, 0.0]];
465        let masses = vec![1.0; 3];
466        let forces = compute_all_lj_forces(&positions, &masses, &lj, 5.0);
467        assert_eq!(forces.len(), 3);
468        for k in 0..3 {
469            let total: f64 = forces.iter().map(|f| f[k]).sum();
470            assert!(
471                total.abs() < 1e-10,
472                "total force component {k} should be 0, got {total}"
473            );
474        }
475    }
476    #[test]
477    fn compute_all_lj_forces_repulsive_at_short_range() {
478        let sigma = 1.0;
479        let lj = LjPotential::new(1.0, sigma);
480        let positions = vec![[0.0, 0.0, 0.0], [0.9 * sigma, 0.0, 0.0]];
481        let masses = vec![1.0; 2];
482        let forces = compute_all_lj_forces(&positions, &masses, &lj, 5.0);
483        assert!(
484            forces[0][0] < 0.0,
485            "repulsive: force[0].x should be < 0, got {}",
486            forces[0][0]
487        );
488        assert!(
489            forces[1][0] > 0.0,
490            "repulsive: force[1].x should be > 0, got {}",
491            forces[1][0]
492        );
493    }
494    #[test]
495    fn pair_force_kernel_new() {
496        let lj = LjPotential::new(2.0, 0.5);
497        let kern = PairForceKernel::new(lj, 3.0, true);
498        assert!((kern.lj.epsilon - 2.0).abs() < 1e-15);
499        assert!((kern.lj.sigma - 0.5).abs() < 1e-15);
500        assert!((kern.cutoff - 3.0).abs() < 1e-15);
501        assert!(kern.shift);
502        // Test shifted evaluation: at the cutoff boundary, energy should be zero (shifted)
503        let (e_at_cut, _f_at_cut) = kern.evaluate(kern.cutoff - 1e-10);
504        assert!(
505            e_at_cut.abs() < 0.1,
506            "energy near cutoff should be small when shifted"
507        );
508    }
509    #[test]
510    fn test_coulomb_potential() {
511        let cp = CoulombPotential::new(1.0);
512        let (e, f) = cp.compute(1.0, 1.0, 1.0);
513        assert!((e - 1.0).abs() < 1e-10);
514        assert!((f - 1.0).abs() < 1e-10);
515        let (e2, f2) = cp.compute(1.0, -1.0, 1.0);
516        assert!((e2 - (-1.0)).abs() < 1e-10);
517        assert!((f2 - (-1.0)).abs() < 1e-10);
518    }
519    #[test]
520    fn test_coulomb_force_function() {
521        let (e, f) = compute_coulomb_force(2.0, 1.0, 1.0, 1.0);
522        assert!((e - 0.5).abs() < 1e-10);
523        assert!((f - 0.25).abs() < 1e-10);
524    }
525    #[test]
526    fn test_coulomb_kernel_newton_iii() {
527        let positions = vec![0.0, 0.0, 0.0, 2.0, 0.0, 0.0];
528        let charges = vec![1.0, -1.0];
529        let params = vec![1.0, 10.0];
530        let mut outputs = vec![Vec::new(), Vec::new()];
531        CoulombKernel.execute(&[&positions, &charges, &params], &mut outputs, 2);
532        for k in 0..3 {
533            let total = outputs[0][k] + outputs[0][3 + k];
534            assert!(
535                total.abs() < 1e-10,
536                "forces should sum to zero in dim {k}, got {total}"
537            );
538        }
539        assert!(
540            outputs[0][0] > 0.0,
541            "particle 0 should be attracted toward +x, got {}",
542            outputs[0][0]
543        );
544    }
545    #[test]
546    fn test_coulomb_kernel_same_charge_repulsive() {
547        let positions = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
548        let charges = vec![1.0, 1.0];
549        let params = vec![1.0, 10.0];
550        let mut outputs = vec![Vec::new(), Vec::new()];
551        CoulombKernel.execute(&[&positions, &charges, &params], &mut outputs, 2);
552        assert!(
553            outputs[0][0] < 0.0,
554            "particle 0 should be repelled in -x, got {}",
555            outputs[0][0]
556        );
557    }
558    #[test]
559    fn test_coulomb_kernel_beyond_cutoff() {
560        let positions = vec![0.0, 0.0, 0.0, 5.0, 0.0, 0.0];
561        let charges = vec![1.0, 1.0];
562        let params = vec![1.0, 3.0];
563        let mut outputs = vec![Vec::new(), Vec::new()];
564        CoulombKernel.execute(&[&positions, &charges, &params], &mut outputs, 2);
565        for &f in &outputs[0] {
566            assert!(
567                f.abs() < 1e-15,
568                "expected zero force beyond cutoff, got {f}"
569            );
570        }
571    }
572    #[test]
573    fn test_compute_all_coulomb_forces_newton_iii() {
574        let positions = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 1.0, 0.0]];
575        let charges = vec![1.0, -1.0, 0.5];
576        let forces = compute_all_coulomb_forces(&positions, &charges, 1.0, 10.0);
577        for k in 0..3 {
578            let total: f64 = forces.iter().map(|f| f[k]).sum();
579            assert!(
580                total.abs() < 1e-10,
581                "total Coulomb force component {k} should be 0, got {total}"
582            );
583        }
584    }
585    #[test]
586    fn test_lj_shifted_energy() {
587        let lj = LjPotential::new(1.0, 1.0);
588        let cutoff = 2.5;
589        let e_at_cutoff = compute_lj_shifted_energy(cutoff, &lj, cutoff);
590        assert!(
591            e_at_cutoff.abs() < 1e-10,
592            "shifted energy at cutoff should be 0, got {e_at_cutoff}"
593        );
594        let e_beyond = compute_lj_shifted_energy(3.0, &lj, cutoff);
595        assert!(e_beyond.abs() < 1e-15);
596    }
597    #[test]
598    fn test_lj_well_depth() {
599        let lj = LjPotential::new(2.5, 1.0);
600        assert!((lj.well_depth() - (-2.5)).abs() < 1e-15);
601    }
602    #[test]
603    fn test_pair_force_kernel_evaluate() {
604        let lj = LjPotential::new(1.0, 1.0);
605        let kern = PairForceKernel::new(lj, 3.0, false);
606        let (e, f) = kern.evaluate(1.0);
607        assert!(e.abs() < 1e-10);
608        assert!((f - 24.0).abs() < 1e-10);
609        let (e2, f2) = kern.evaluate(5.0);
610        assert!(e2.abs() < 1e-15);
611        assert!(f2.abs() < 1e-15);
612    }
613    #[test]
614    fn test_cutoff_scheme_hard() {
615        let scheme = CutoffScheme::Hard { cutoff: 2.5 };
616        assert!((scheme.cutoff_distance() - 2.5).abs() < 1e-15);
617        assert!((scheme.switch_value(1.0) - 1.0).abs() < 1e-15);
618        assert!((scheme.switch_value(3.0) - 0.0).abs() < 1e-15);
619    }
620    #[test]
621    fn test_cutoff_scheme_switched() {
622        let scheme = CutoffScheme::Switched {
623            r_switch: 2.0,
624            r_cutoff: 3.0,
625        };
626        assert!((scheme.cutoff_distance() - 3.0).abs() < 1e-15);
627        assert!((scheme.switch_value(1.5) - 1.0).abs() < 1e-15);
628        assert!((scheme.switch_value(3.5) - 0.0).abs() < 1e-15);
629        assert!((scheme.switch_value(2.5) - 0.5).abs() < 1e-10);
630        let v1 = scheme.switch_value(2.2);
631        let v2 = scheme.switch_value(2.8);
632        assert!(
633            v1 > v2,
634            "switch should decrease: v(2.2)={}, v(2.8)={}",
635            v1,
636            v2
637        );
638    }
639    #[test]
640    fn test_neighbor_list_brute_force() {
641        let positions = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]];
642        let nlist = NeighborList::build_brute_force(&positions, 2.0, 0.5);
643        assert_eq!(nlist.num_particles(), 3);
644        assert!(nlist.neighbors[0].contains(&1));
645        assert!(nlist.neighbors[1].contains(&0));
646        assert!(!nlist.neighbors[0].contains(&2));
647        assert!(!nlist.neighbors[2].contains(&0));
648    }
649    #[test]
650    fn test_neighbor_list_num_pairs() {
651        let positions = vec![[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [1.0, 0.0, 0.0]];
652        let nlist = NeighborList::build_brute_force(&positions, 2.0, 0.0);
653        assert_eq!(nlist.num_pairs(), 3);
654    }
655    #[test]
656    fn test_neighbor_list_needs_rebuild() {
657        let nlist = NeighborList {
658            neighbors: vec![],
659            cutoff: 2.5,
660            skin: 0.4,
661        };
662        assert!(!nlist.needs_rebuild(0.1));
663        assert!(nlist.needs_rebuild(0.3));
664    }
665    #[test]
666    fn test_force_buffer_basic() {
667        let mut buf = ForceBuffer::new(3);
668        assert_eq!(buf.forces.len(), 3);
669        assert_eq!(buf.total_energy(), 0.0);
670        buf.add_pair(0, 1, [1.0, 0.0, 0.0], 2.0, [1.0, 0.0, 0.0]);
671        assert!((buf.forces[0][0] - 1.0).abs() < 1e-15);
672        assert!((buf.forces[1][0] - (-1.0)).abs() < 1e-15);
673        assert!((buf.energies[0] - 1.0).abs() < 1e-15);
674        assert!((buf.energies[1] - 1.0).abs() < 1e-15);
675        assert!((buf.total_energy() - 2.0).abs() < 1e-15);
676    }
677    #[allow(clippy::needless_range_loop)]
678    #[test]
679    fn test_force_buffer_total_force_zero() {
680        let mut buf = ForceBuffer::new(3);
681        buf.add_pair(0, 1, [3.0, -1.0, 2.0], 1.0, [1.0, 0.0, 0.0]);
682        buf.add_pair(1, 2, [-1.0, 2.0, 0.5], 0.5, [0.0, 1.0, 0.0]);
683        let total = buf.total_force();
684        for k in 0..3 {
685            assert!(
686                total[k].abs() < 1e-10,
687                "total force[{k}] should be 0, got {}",
688                total[k]
689            );
690        }
691    }
692    #[test]
693    fn test_force_buffer_clear() {
694        let mut buf = ForceBuffer::new(2);
695        buf.add_pair(0, 1, [1.0, 2.0, 3.0], 5.0, [1.0, 0.0, 0.0]);
696        buf.clear();
697        assert!((buf.total_energy() - 0.0).abs() < 1e-15);
698        for f in &buf.forces {
699            for &c in f {
700                assert!(c.abs() < 1e-15);
701            }
702        }
703    }
704    #[test]
705    fn test_force_buffer_reduce() {
706        let mut main_buf = ForceBuffer::new(2);
707        main_buf.add_pair(0, 1, [1.0, 0.0, 0.0], 2.0, [1.0, 0.0, 0.0]);
708        let mut other = ForceBuffer::new(2);
709        other.add_pair(0, 1, [0.5, 0.0, 0.0], 1.0, [1.0, 0.0, 0.0]);
710        main_buf.reduce_from(&[other]);
711        assert!((main_buf.forces[0][0] - 1.5).abs() < 1e-15);
712        assert!((main_buf.total_energy() - 3.0).abs() < 1e-15);
713    }
714    #[test]
715    fn test_force_buffer_virial() {
716        let mut buf = ForceBuffer::new(2);
717        buf.add_pair(0, 1, [2.0, 0.0, 0.0], 1.0, [3.0, 0.0, 0.0]);
718        assert!((buf.virial[0][0] - 3.0).abs() < 1e-15);
719        assert!((buf.virial[1][0] - 3.0).abs() < 1e-15);
720        assert!((buf.total_virial() - 6.0).abs() < 1e-15);
721    }
722    #[allow(clippy::needless_range_loop)]
723    #[test]
724    fn test_lj_forces_neighborlist() {
725        let lj = LjPotential::new(1.0, 1.0);
726        let positions = vec![[0.0, 0.0, 0.0], [1.2, 0.0, 0.0], [5.0, 0.0, 0.0]];
727        let nlist = NeighborList::build_brute_force(&positions, 2.5, 0.0);
728        let mut buf = ForceBuffer::new(3);
729        compute_lj_forces_neighborlist(&positions, &lj, &nlist, &mut buf);
730        let total = buf.total_force();
731        for k in 0..3 {
732            assert!(total[k].abs() < 1e-10, "total[{k}] = {}", total[k]);
733        }
734        for k in 0..3 {
735            assert!(buf.forces[2][k].abs() < 1e-15);
736        }
737    }
738    #[test]
739    fn test_coulomb_forces_neighborlist() {
740        let positions = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]];
741        let charges = vec![1.0, -1.0];
742        let nlist = NeighborList::build_brute_force(&positions, 5.0, 0.0);
743        let mut buf = ForceBuffer::new(2);
744        compute_coulomb_forces_neighborlist(&positions, &charges, 1.0, &nlist, &mut buf);
745        assert!(buf.forces[0][0] > 0.0, "should attract toward +x");
746        assert!(buf.forces[1][0] < 0.0, "should attract toward -x");
747        let total = buf.total_force();
748        assert!(total[0].abs() < 1e-10);
749    }
750    #[allow(clippy::needless_range_loop)]
751    #[test]
752    fn test_lj_forces_neighborlist_matches_brute_force() {
753        let lj = LjPotential::new(1.0, 1.0);
754        let positions = vec![[0.0, 0.0, 0.0], [1.1, 0.0, 0.0], [0.5, 1.0, 0.0]];
755        let cutoff = 5.0;
756        let masses = vec![1.0; 3];
757        let forces_bf = compute_all_lj_forces(&positions, &masses, &lj, cutoff);
758        let nlist = NeighborList::build_brute_force(&positions, cutoff, 0.0);
759        let mut buf = ForceBuffer::new(3);
760        compute_lj_forces_neighborlist(&positions, &lj, &nlist, &mut buf);
761        for i in 0..3 {
762            for k in 0..3 {
763                assert!(
764                    (buf.forces[i][k] - forces_bf[i][k]).abs() < 1e-10,
765                    "mismatch at particle {i}, dim {k}: nlist={}, brute={}",
766                    buf.forces[i][k],
767                    forces_bf[i][k]
768                );
769            }
770        }
771    }
772    #[test]
773    fn test_erfc_approx_at_zero() {
774        let result = erfc_approx(0.0);
775        assert!((result - 1.0).abs() < 1e-4, "erfc(0) ~ 1, got {result}");
776    }
777    #[test]
778    fn test_erfc_approx_large_arg() {
779        let result = erfc_approx(5.0);
780        assert!(result < 1e-10, "erfc(5) ~ 0, got {result}");
781    }
782    #[test]
783    fn test_ewald_self_energy() {
784        let charges = vec![1.0, -1.0];
785        let alpha = 0.5;
786        let se = ewald_self_energy(&charges, alpha);
787        let expected = -2.0 * alpha / std::f64::consts::PI.sqrt();
788        assert!((se - expected).abs() < 1e-10);
789    }
790    #[test]
791    fn test_ewald_real_space_kernel_newton_iii() {
792        let pos = vec![0.0, 0.0, 0.0, 2.0, 0.0, 0.0];
793        let charges = vec![1.0, -1.0];
794        let params = vec![0.5, 5.0, 20.0];
795        let mut outputs = vec![Vec::new(), Vec::new()];
796        EwaldRealSpaceKernel.execute(&[&pos, &charges, &params], &mut outputs, 2);
797        assert_eq!(outputs[0].len(), 6);
798        let total_fx = outputs[0][0] + outputs[0][3];
799        assert!(
800            total_fx.abs() < 1e-10,
801            "Ewald Newton III violated: {total_fx}"
802        );
803    }
804    #[test]
805    fn test_ewald_params_accuracy() {
806        let p = EwaldParams::new(0.5, 6.0, 100.0, 20.0);
807        let acc = p.real_space_accuracy();
808        assert!(acc < 0.01, "erfc(3) should be small, got {acc}");
809    }
810    #[test]
811    fn test_pppm_grid_spacing() {
812        let grid = PppmGrid::new(32, 32, 32, 10.0, 2);
813        assert!((grid.dx() - 10.0 / 32.0).abs() < 1e-12);
814        assert_eq!(grid.total_points(), 32768);
815    }
816    #[test]
817    fn test_pppm_charge_assign_single_particle() {
818        let pos = vec![0.5, 0.5, 0.5];
819        let charges = vec![1.0];
820        let grid_params = vec![4.0, 4.0, 4.0, 4.0];
821        let mut outputs = vec![Vec::new()];
822        PppmChargeAssignKernel.execute(&[&pos, &charges, &grid_params], &mut outputs, 1);
823        assert_eq!(outputs[0].len(), 64);
824        let total: f64 = outputs[0].iter().sum();
825        assert!(
826            (total - 1.0).abs() < 1e-10,
827            "total charge on mesh = {total}"
828        );
829    }
830    #[test]
831    fn test_pppm_charge_assign_conservation() {
832        let pos = vec![1.0, 2.0, 3.0, 5.0, 5.0, 5.0];
833        let charges = vec![2.0, -1.5];
834        let grid_params = vec![8.0, 8.0, 8.0, 8.0];
835        let mut outputs = vec![Vec::new()];
836        PppmChargeAssignKernel.execute(&[&pos, &charges, &grid_params], &mut outputs, 2);
837        let total: f64 = outputs[0].iter().sum();
838        assert!(
839            (total - 0.5).abs() < 1e-10,
840            "net charge should be 0.5, got {total}"
841        );
842    }
843    #[test]
844    fn test_pppm_mesh_energy_estimate_positive() {
845        let mesh = vec![1.0, -1.0, 2.0, 0.5];
846        let e = pppm_mesh_energy_estimate(&mesh, 2, 2, 1);
847        assert!(e >= 0.0, "mesh energy should be non-negative");
848    }
849    #[test]
850    fn test_nlist_update_kernel_no_rebuild() {
851        let pos = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
852        let ref_pos = pos.clone();
853        let params = vec![2.5, 0.4];
854        let mut outputs = vec![Vec::new(), Vec::new()];
855        NlistUpdateKernel.execute(&[&pos, &ref_pos, &params], &mut outputs, 2);
856        assert!(
857            (outputs[1][0] - 0.0).abs() < 1e-10,
858            "status should be Valid (0)"
859        );
860    }
861    #[test]
862    fn test_nlist_update_kernel_rebuild() {
863        let pos = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
864        let ref_pos = vec![0.0, 0.0, 0.0, 5.0, 0.0, 0.0];
865        let params = vec![2.5, 0.4];
866        let mut outputs = vec![Vec::new(), Vec::new()];
867        NlistUpdateKernel.execute(&[&pos, &ref_pos, &params], &mut outputs, 2);
868        assert!(
869            (outputs[1][0] - 1.0).abs() < 1e-10,
870            "status should be Rebuilt (1)"
871        );
872    }
873    #[test]
874    fn test_nlist_update_pairs_found() {
875        let pos = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 5.0, 0.0, 0.0];
876        let ref_pos = vec![100.0; 9];
877        let params = vec![2.5, 0.4];
878        let mut outputs = vec![Vec::new(), Vec::new()];
879        NlistUpdateKernel.execute(&[&pos, &ref_pos, &params], &mut outputs, 3);
880        let num_pairs = outputs[1][1] as usize;
881        assert_eq!(num_pairs, 1, "only 1 pair should be found, got {num_pairs}");
882    }
883    #[test]
884    fn test_pair_energy_accumulate_basic() {
885        let sigma = 1.0;
886        let pos = vec![0.0, 0.0, 0.0, sigma, 0.0, 0.0];
887        let pairs = vec![0.0, 1.0];
888        let params = vec![1.0, sigma, 5.0];
889        let mut outputs = vec![Vec::new(), Vec::new()];
890        PairEnergyAccumulateKernel.execute(&[&pos, &pairs, &params], &mut outputs, 2);
891        let total = outputs[1][0];
892        assert!(
893            total.abs() < 1e-10,
894            "energy at r=sigma should be 0, got {total}"
895        );
896    }
897    #[test]
898    fn test_pair_energy_accumulate_split_equally() {
899        let pos = vec![0.0, 0.0, 0.0, 0.9, 0.0, 0.0];
900        let pairs = vec![0.0, 1.0];
901        let params = vec![1.0, 1.0, 5.0];
902        let mut outputs = vec![Vec::new(), Vec::new()];
903        PairEnergyAccumulateKernel.execute(&[&pos, &pairs, &params], &mut outputs, 2);
904        let e0 = outputs[0][0];
905        let e1 = outputs[0][1];
906        assert!((e0 - e1).abs() < 1e-12, "energy should be split equally");
907        assert!(
908            outputs[1][0] > 0.0,
909            "total energy should be positive at r < r_min"
910        );
911    }
912    #[test]
913    fn test_pair_energy_beyond_cutoff_zero() {
914        let pos = vec![0.0, 0.0, 0.0, 10.0, 0.0, 0.0];
915        let pairs = vec![0.0, 1.0];
916        let params = vec![1.0, 1.0, 2.5];
917        let mut outputs = vec![Vec::new(), Vec::new()];
918        PairEnergyAccumulateKernel.execute(&[&pos, &pairs, &params], &mut outputs, 2);
919        assert!(
920            outputs[1][0].abs() < 1e-15,
921            "energy beyond cutoff should be 0"
922        );
923    }
924    #[test]
925    fn test_virial_tensor_trace() {
926        let vt = VirialTensor {
927            components: [1.0, 2.0, 3.0, 0.5, 0.2, 0.1],
928        };
929        assert!((vt.trace() - 6.0).abs() < 1e-12);
930    }
931    #[test]
932    fn test_virial_tensor_pressure() {
933        let vt = VirialTensor {
934            components: [-3.0, -3.0, -3.0, 0.0, 0.0, 0.0],
935        };
936        let p = vt.pressure_contribution(1.0);
937        assert!((p - 3.0).abs() < 1e-12);
938    }
939    #[test]
940    fn test_virial_tensor_add() {
941        let a = VirialTensor {
942            components: [1.0, 2.0, 3.0, 0.0, 0.0, 0.0],
943        };
944        let b = VirialTensor {
945            components: [4.0, 5.0, 6.0, 0.0, 0.0, 0.0],
946        };
947        let c = a.add(&b);
948        assert!((c.components[0] - 5.0).abs() < 1e-12);
949        assert!((c.trace() - 21.0).abs() < 1e-12);
950    }
951    #[test]
952    fn test_compute_virial_stress_tensor_symmetric() {
953        let lj = LjPotential::new(1.0, 1.0);
954        let positions = vec![[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]];
955        let vt = compute_virial_stress_tensor(&positions, &lj, 5.0);
956        assert!(vt.components[1].abs() < 1e-10, "Wyy should be 0");
957        assert!(vt.components[2].abs() < 1e-10, "Wzz should be 0");
958    }
959    #[test]
960    fn test_virial_kernel_newton_iii_check() {
961        let lj = LjPotential::new(1.0, 1.0);
962        let positions = [[0.0, 0.0, 0.0], [1.1, 0.0, 0.0], [0.5, 1.0, 0.0]];
963        let flat_pos: Vec<f64> = positions.iter().flat_map(|p| p.iter().copied()).collect();
964        let params = [lj.epsilon, lj.sigma, 5.0f64];
965        let mut outputs = vec![Vec::new()];
966        VirialStressTensorKernel.execute(&[&flat_pos, &params], &mut outputs, 3);
967        assert_eq!(
968            outputs[0].len(),
969            6,
970            "virial tensor should have 6 components"
971        );
972    }
973    #[allow(clippy::needless_range_loop)]
974    #[test]
975    fn test_bond_force_equilibrium_no_force() {
976        let r0 = 1.5_f64;
977        let positions = vec![[0.0, 0.0, 0.0], [r0, 0.0, 0.0]];
978        let bonds = vec![HarmonicBond::new(0, 1, 100.0, r0)];
979        let (forces, energy) = compute_bond_forces(&positions, &bonds);
980        assert_eq!(forces.len(), 2);
981        for dim in 0..3 {
982            assert!(
983                forces[0][dim].abs() < 1e-10,
984                "force at equilibrium should be 0"
985            );
986            assert!(
987                forces[1][dim].abs() < 1e-10,
988                "force at equilibrium should be 0"
989            );
990        }
991        assert!(
992            energy.abs() < 1e-10,
993            "energy at equilibrium should be 0, got {energy}"
994        );
995    }
996    #[allow(clippy::needless_range_loop)]
997    #[test]
998    fn test_bond_force_compressed() {
999        let r0 = 2.0_f64;
1000        let r = 1.0_f64;
1001        let k = 50.0_f64;
1002        let positions = vec![[0.0, 0.0, 0.0], [r, 0.0, 0.0]];
1003        let bonds = vec![HarmonicBond::new(0, 1, k, r0)];
1004        let (forces, energy) = compute_bond_forces(&positions, &bonds);
1005        assert!(forces[0][0] < 0.0, "atom 0 should be pushed away from bond");
1006        assert!(forces[1][0] > 0.0, "atom 1 should be pushed away from bond");
1007        for dim in 0..3 {
1008            assert!(
1009                (forces[0][dim] + forces[1][dim]).abs() < 1e-10,
1010                "Newton III violated at dim {dim}"
1011            );
1012        }
1013        let expected_e = 0.5 * k * (r - r0).powi(2);
1014        assert!(
1015            (energy - expected_e).abs() < 1e-10,
1016            "energy mismatch: {energy} vs {expected_e}"
1017        );
1018    }
1019    #[test]
1020    fn test_bond_force_kernel_executes() {
1021        let positions_flat = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
1022        let bond_data = vec![0.0, 1.0, 100.0, 1.0];
1023        let mut outputs = vec![Vec::new(), Vec::new()];
1024        BondForceKernel.execute(&[&positions_flat, &bond_data], &mut outputs, 2);
1025        assert_eq!(outputs[0].len(), 6, "forces should have 6 components (3*2)");
1026        assert_eq!(outputs[1].len(), 1, "energies should have 1 component");
1027        for &f in &outputs[0] {
1028            assert!(f.abs() < 1e-10, "force at equilibrium should be 0, got {f}");
1029        }
1030    }
1031    #[test]
1032    fn test_bond_force_kernel_stretched() {
1033        let positions_flat = vec![0.0, 0.0, 0.0, 2.0, 0.0, 0.0];
1034        let bond_data = vec![0.0, 1.0, 10.0, 1.0];
1035        let mut outputs = vec![Vec::new(), Vec::new()];
1036        BondForceKernel.execute(&[&positions_flat, &bond_data], &mut outputs, 2);
1037        let fx0 = outputs[0][0];
1038        let fx1 = outputs[0][3];
1039        assert!(
1040            fx0 > 0.0,
1041            "atom 0 should be pulled toward atom 1 (positive x)"
1042        );
1043        assert!(
1044            fx1 < 0.0,
1045            "atom 1 should be pulled toward atom 0 (negative x)"
1046        );
1047        assert!(
1048            (fx0 + fx1).abs() < 1e-10,
1049            "Newton III: forces should cancel"
1050        );
1051    }
1052    #[test]
1053    fn test_angle_force_at_equilibrium() {
1054        let positions = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
1055        let angles = vec![HarmonicAngle::new(0, 1, 2, 50.0, std::f64::consts::PI)];
1056        let (forces, energy) = compute_angle_forces(&positions, &angles);
1057        assert_eq!(forces.len(), 3);
1058        assert!(
1059            energy.abs() < 1e-8,
1060            "energy at equilibrium angle should be ~0, got {energy}"
1061        );
1062    }
1063    #[test]
1064    fn test_angle_force_finite_at_90_degrees() {
1065        let positions = vec![[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
1066        let theta0 = std::f64::consts::PI / 2.0;
1067        let angles = vec![HarmonicAngle::new(0, 1, 2, 100.0, theta0)];
1068        let (forces, energy) = compute_angle_forces(&positions, &angles);
1069        assert_eq!(forces.len(), 3);
1070        assert!(energy.is_finite(), "angle energy should be finite");
1071        for f in &forces {
1072            for &c in f {
1073                assert!(c.is_finite(), "angle force component should be finite: {c}");
1074            }
1075        }
1076        assert!(
1077            energy.abs() < 1e-8,
1078            "at equilibrium angle energy should be ~0, got {energy}"
1079        );
1080    }
1081    #[test]
1082    fn test_angle_force_kernel_executes() {
1083        let positions_flat = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
1084        let theta0 = std::f64::consts::PI / 2.0;
1085        let angle_data = vec![0.0, 1.0, 2.0, 100.0, theta0];
1086        let mut outputs = vec![Vec::new(), Vec::new()];
1087        AngleForceKernel.execute(&[&positions_flat, &angle_data], &mut outputs, 3);
1088        assert_eq!(outputs[0].len(), 9, "forces should have 9 components (3*3)");
1089        assert_eq!(outputs[1].len(), 1, "energies should have 1 element");
1090        for &f in &outputs[0] {
1091            assert!(f.is_finite(), "angle force not finite: {f}");
1092        }
1093    }
1094    #[test]
1095    fn test_kinetic_temperature_basic() {
1096        let velocities = vec![[3.0, 0.0, 0.0]];
1097        let masses = vec![1.0];
1098        let kb = 1.0;
1099        let t = kinetic_temperature(&velocities, &masses, kb);
1100        assert!((t - 3.0).abs() < 1e-10, "expected T=3, got {t}");
1101    }
1102    #[test]
1103    fn test_kinetic_temperature_zero_velocity() {
1104        let velocities = vec![[0.0; 3]; 5];
1105        let masses = vec![1.0; 5];
1106        let t = kinetic_temperature(&velocities, &masses, 1.0);
1107        assert!(
1108            t.abs() < 1e-15,
1109            "temperature of zero-velocity system should be 0"
1110        );
1111    }
1112    #[test]
1113    fn test_temperature_scale_reaches_target() {
1114        let mut velocities = vec![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
1115        let masses = vec![1.0, 1.0];
1116        let kb = 1.0;
1117        let t_before = kinetic_temperature(&velocities, &masses, kb);
1118        assert!(t_before > 0.0);
1119        let t_target = t_before * 4.0;
1120        temperature_scale(&mut velocities, &masses, t_target, kb);
1121        let t_after = kinetic_temperature(&velocities, &masses, kb);
1122        assert!(
1123            (t_after - t_target).abs() < 1e-8,
1124            "after scaling: expected T={t_target}, got T={t_after}"
1125        );
1126    }
1127    #[test]
1128    fn test_temperature_scale_kernel_rescales() {
1129        let vel_flat = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
1130        let masses = vec![1.0, 1.0];
1131        let kb = 1.0;
1132        let t_target = 2.0 / 3.0;
1133        let params = vec![t_target, kb];
1134        let mut outputs = vec![Vec::new(), Vec::new()];
1135        TemperatureScaleKernel.execute(&[&vel_flat, &masses, &params], &mut outputs, 2);
1136        assert_eq!(outputs[0].len(), 6);
1137        assert_eq!(outputs[1].len(), 2);
1138        let t_before = outputs[1][0];
1139        let t_after = outputs[1][1];
1140        assert!(t_before > 0.0, "t_before should be positive");
1141        assert!(
1142            (t_after - t_target).abs() < 1e-8,
1143            "t_after should be target {t_target}, got {t_after}"
1144        );
1145    }
1146    #[test]
1147    fn test_temperature_scale_kernel_outputs_finite() {
1148        let vel_flat = vec![2.0, 1.0, 0.5, 0.3, 0.7, 1.2, 0.1, 0.4, 0.9];
1149        let masses = vec![1.0, 2.0, 0.5];
1150        let params = vec![300.0, 1.0];
1151        let mut outputs = vec![Vec::new(), Vec::new()];
1152        TemperatureScaleKernel.execute(&[&vel_flat, &masses, &params], &mut outputs, 3);
1153        for &v in &outputs[0] {
1154            assert!(v.is_finite(), "scaled velocity not finite: {v}");
1155        }
1156        for &t in &outputs[1] {
1157            assert!(t.is_finite(), "temperature not finite: {t}");
1158        }
1159    }
1160}