Skip to main content

chematic_ff/
mmff94_minimizer.rs

1//! MMFF94 geometry minimizer using complete Halgren 1996 parameters.
2//!
3//! Provides full MMFF94 energy evaluation (bond + angle + torsion + vdW + electrostatic)
4//! and geometry optimization with two algorithms:
5//! - **Steepest descent** (`minimize_mmff94_full`) — robust, simple
6//! - **L-BFGS** (`minimize_mmff94_lbfgs`) — faster convergence, quasi-Newton
7//!
8//! ## Energy terms
9//! - **Bond**: cubic-corrected harmonic (Halgren MMFF.II eq. 1)
10//! - **Angle**: cubic-corrected harmonic (Halgren MMFF.III eq. 2)
11//! - **Torsion**: three-term Fourier (Halgren MMFF.IV)
12//! - **vdW**: buffered 14-7 potential with Slater-Kirkwood combining rule (Halgren MMFF.I eq. 2)
13//! - **Electrostatic**: Coulomb with δ buffer (Halgren MMFF.V eq. 14)
14
15use std::collections::VecDeque;
16
17use chematic_core::{AtomIdx, Molecule};
18
19use crate::mmff94_energy::{
20    mmff94_angle_energy, mmff94_bond_energy, mmff94_oop, mmff94_stbn, mmff94_torsion_energy,
21    mmff94_vdw_combined,
22};
23use crate::mmff94_numeric::{assign_mmff94_numeric_types, mmff94_charges_numeric, NumericTypeError};
24
25// ─── Public types ────────────────────────────────────────────────────────────
26
27/// Result of a geometry minimization run.
28#[derive(Debug, Clone)]
29pub struct MinimizeResult {
30    /// Final MMFF94 energy (kcal/mol).
31    pub energy: f64,
32    /// RMSD of atom positions vs initial geometry (Å).
33    pub rmsd: f64,
34    /// Whether the minimization converged before `max_iter`.
35    pub converged: bool,
36    /// Number of gradient steps performed.
37    pub iterations: usize,
38}
39
40/// Error from MMFF94 minimizer setup.
41#[derive(Debug)]
42pub enum MinimizerError {
43    TypeAssignment(NumericTypeError),
44}
45
46impl From<NumericTypeError> for MinimizerError {
47    fn from(e: NumericTypeError) -> Self {
48        MinimizerError::TypeAssignment(e)
49    }
50}
51
52impl std::fmt::Display for MinimizerError {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        match self {
55            MinimizerError::TypeAssignment(e) => write!(f, "MMFF94 type assignment failed: {}", e),
56        }
57    }
58}
59
60// ─── Public API ──────────────────────────────────────────────────────────────
61
62/// Per-term MMFF94 energy breakdown (kcal/mol). Includes all 7 Halgren 1996 energy terms.
63#[derive(Debug, Clone, Copy)]
64pub struct EnergyBreakdown {
65    pub bond: f64,
66    pub angle: f64,
67    /// Stretch-bend coupling (STRE-BEN, Halgren MMFF.V)
68    pub stretch_bend: f64,
69    pub torsion: f64,
70    /// Out-of-plane bending for sp2 atoms (Halgren MMFF.VI)
71    pub oop: f64,
72    pub vdw: f64,
73    pub electrostatic: f64,
74    pub total: f64,
75}
76
77/// Compute total MMFF94 energy for a given geometry (kcal/mol).
78///
79/// Includes bond, angle, torsion, vdW, and electrostatic terms.
80/// Does not modify coordinates.
81pub fn mmff94_total_energy(
82    mol: &Molecule,
83    coords: &[[f64; 3]],
84) -> Result<f64, MinimizerError> {
85    let types = assign_mmff94_numeric_types(mol)?;
86    let charges = mmff94_charges_numeric(mol).unwrap_or_else(|_| vec![0.0; mol.atom_count()]);
87    Ok(total_energy(mol, coords, &types, &charges))
88}
89
90/// Scan a torsion dihedral angle i-j-k-l from 0° to 360° in `steps` increments,
91/// returning (angle_deg, energy_kcal) pairs. Coordinates are not modified.
92///
93/// At each step the dihedral is set by rotating atoms past `k` about the j-k bond.
94pub fn mmff94_torsion_scan(
95    mol: &Molecule,
96    coords: &[[f64; 3]],
97    atom_i: usize,
98    atom_j: usize,
99    atom_k: usize,
100    atom_l: usize,
101    steps: usize,
102) -> Result<Vec<(f64, f64)>, MinimizerError> {
103    let types = assign_mmff94_numeric_types(mol)?;
104    let charges = mmff94_charges_numeric(mol).unwrap_or_else(|_| vec![0.0; mol.atom_count()]);
105    let n = mol.atom_count();
106    let steps = steps.max(2);
107
108    let mut results = Vec::with_capacity(steps);
109
110    // Collect atoms on the `l` side of the j-k bond (BFS from k, not crossing j)
111    let moving_atoms: Vec<usize> = {
112        let mut visited = vec![false; n];
113        visited[atom_j] = true;
114        let mut queue = std::collections::VecDeque::new();
115        queue.push_back(atom_k);
116        visited[atom_k] = true;
117        let mut group = Vec::new();
118        while let Some(cur) = queue.pop_front() {
119            group.push(cur);
120            for (nb, _) in mol.neighbors(AtomIdx(cur as u32)) {
121                let nbi = nb.0 as usize;
122                if !visited[nbi] {
123                    visited[nbi] = true;
124                    queue.push_back(nbi);
125                }
126            }
127        }
128        group
129    };
130
131    let mut work = coords.to_vec();
132
133    // Rotate the moving group in `steps` increments of 360°/steps
134    let step_rad = 2.0 * std::f64::consts::PI / steps as f64;
135
136    for step in 0..steps {
137        let angle_deg = step as f64 * 360.0 / steps as f64;
138
139        if step > 0 {
140            // Rotate moving_atoms by step_rad about the j→k axis
141            let j = work[atom_j];
142            let k = work[atom_k];
143            let axis = {
144                let d = [k[0] - j[0], k[1] - j[1], k[2] - j[2]];
145                let len = (d[0] * d[0] + d[1] * d[1] + d[2] * d[2]).sqrt();
146                if len < 1e-12 { [1.0, 0.0, 0.0] } else { [d[0]/len, d[1]/len, d[2]/len] }
147            };
148            let (sin_a, cos_a) = step_rad.sin_cos();
149            for &ai in &moving_atoms {
150                // Rodrigues' rotation about axis through j
151                let p = [work[ai][0] - j[0], work[ai][1] - j[1], work[ai][2] - j[2]];
152                let cross = [
153                    axis[1]*p[2] - axis[2]*p[1],
154                    axis[2]*p[0] - axis[0]*p[2],
155                    axis[0]*p[1] - axis[1]*p[0],
156                ];
157                let dot = axis[0]*p[0] + axis[1]*p[1] + axis[2]*p[2];
158                work[ai] = [
159                    j[0] + cos_a*p[0] + sin_a*cross[0] + (1.0-cos_a)*dot*axis[0],
160                    j[1] + cos_a*p[1] + sin_a*cross[1] + (1.0-cos_a)*dot*axis[1],
161                    j[2] + cos_a*p[2] + sin_a*cross[2] + (1.0-cos_a)*dot*axis[2],
162                ];
163                let _ = (atom_i, atom_l); // suppress unused warnings
164            }
165        }
166
167        let energy = total_energy(mol, &work, &types, &charges);
168        results.push((angle_deg, energy));
169    }
170
171    Ok(results)
172}
173
174/// Compute per-term MMFF94 energy breakdown for a given geometry.
175pub fn mmff94_energy_breakdown(
176    mol: &Molecule,
177    coords: &[[f64; 3]],
178) -> Result<EnergyBreakdown, MinimizerError> {
179    let types = assign_mmff94_numeric_types(mol)?;
180    let charges = mmff94_charges_numeric(mol).unwrap_or_else(|_| vec![0.0; mol.atom_count()]);
181    let b = bond_energy(mol, coords, &types);
182    let a = angle_energy(mol, coords, &types);
183    let sb = stretch_bend_energy(mol, coords, &types);
184    let t = torsion_energy(mol, coords, &types);
185    let o = oop_energy(mol, coords, &types);
186    let v = vdw_energy(mol, coords, &types);
187    let e = elec_energy(mol, coords, &charges);
188    Ok(EnergyBreakdown {
189        bond: b,
190        angle: a,
191        stretch_bend: sb,
192        torsion: t,
193        oop: o,
194        vdw: v,
195        electrostatic: e,
196        total: b + a + sb + t + o + v + e,
197    })
198}
199
200/// Minimize molecular geometry using the full MMFF94 force field.
201///
202/// Uses steepest descent with finite-difference gradients and the complete
203/// Halgren 1996 parameter tables (bond, angle, torsion, vdW, electrostatic).
204///
205/// # Arguments
206/// * `mol` — molecule graph (topology only, no coordinates)
207/// * `coords` — initial 3D coordinates `[[x, y, z]]` in Å; updated in place
208/// * `max_iter` — maximum gradient steps (200 typically sufficient)
209pub fn minimize_mmff94_full(
210    mol: &Molecule,
211    coords: &mut Vec<[f64; 3]>,
212    max_iter: usize,
213) -> Result<MinimizeResult, MinimizerError> {
214    if mol.atom_count() <= 1 {
215        return Ok(MinimizeResult {
216            energy: 0.0,
217            rmsd: 0.0,
218            converged: true,
219            iterations: 0,
220        });
221    }
222
223    let types = assign_mmff94_numeric_types(mol)?;
224    let charges = mmff94_charges_numeric(mol).unwrap_or_else(|_| vec![0.0; mol.atom_count()]);
225
226    let n = mol.atom_count();
227    let initial = coords.clone();
228    let convergence = 1e-4_f64;
229    let step_size = 0.05_f64;
230    let delta = 1e-4_f64;
231
232    let mut iters = 0usize;
233    let mut converged = false;
234
235    for _ in 0..max_iter {
236        iters += 1;
237        let grad = compute_gradient(mol, coords, &types, &charges, delta);
238        let max_g = grad.iter().flat_map(|v| v.iter()).map(|x| x.abs()).fold(0.0_f64, f64::max);
239
240        if max_g < convergence {
241            converged = true;
242            break;
243        }
244
245        let scale = step_size / max_g.max(1e-8);
246        for i in 0..n {
247            for axis in 0..3 {
248                coords[i][axis] -= scale * grad[i][axis];
249            }
250        }
251    }
252
253    let energy = total_energy(mol, coords, &types, &charges);
254
255    let rmsd = {
256        let sum: f64 = coords
257            .iter()
258            .zip(initial.iter())
259            .map(|(c, i0)| {
260                let dx = c[0] - i0[0];
261                let dy = c[1] - i0[1];
262                let dz = c[2] - i0[2];
263                dx * dx + dy * dy + dz * dz
264            })
265            .sum();
266        (sum / n as f64).sqrt()
267    };
268
269    Ok(MinimizeResult {
270        energy,
271        rmsd,
272        converged,
273        iterations: iters,
274    })
275}
276
277/// Minimize molecular geometry using L-BFGS (limited-memory quasi-Newton).
278///
279/// Typically converges in 2–5× fewer iterations than steepest descent for
280/// well-behaved energy surfaces. Falls back to a steepest-descent step when
281/// the curvature condition `y·s > 0` is not satisfied.
282///
283/// Uses finite-difference gradients (δ=1e-4 Å) and backtracking Armijo line search.
284pub fn minimize_mmff94_lbfgs(
285    mol: &Molecule,
286    coords: &mut Vec<[f64; 3]>,
287    max_iter: usize,
288) -> Result<MinimizeResult, MinimizerError> {
289    const M: usize = 5;            // L-BFGS history size
290    const DELTA: f64 = 1e-4;       // finite-difference step (Å)
291    const CONVERGENCE: f64 = 1e-4; // max |gradient| threshold
292    const C_ARMIJO: f64 = 1e-4;   // Armijo sufficient-decrease constant
293    const TAU: f64 = 0.5;          // Armijo backtracking factor
294
295    if mol.atom_count() <= 1 {
296        return Ok(MinimizeResult { energy: 0.0, rmsd: 0.0, converged: true, iterations: 0 });
297    }
298
299    let types = assign_mmff94_numeric_types(mol)?;
300    let charges = mmff94_charges_numeric(mol).unwrap_or_else(|_| vec![0.0; mol.atom_count()]);
301
302    let n = mol.atom_count();
303    let initial = coords.clone();
304
305    // Circular history buffer: (s_k = Δx, y_k = Δg, ρ_k = 1/(y·s))
306    let mut history: VecDeque<(Vec<[f64; 3]>, Vec<[f64; 3]>, f64)> = VecDeque::new();
307
308    let mut g = compute_gradient(mol, coords, &types, &charges, DELTA);
309    let mut f0 = total_energy(mol, coords, &types, &charges);
310
311    let mut iters = 0usize;
312    let mut converged = false;
313
314    for _ in 0..max_iter {
315        iters += 1;
316
317        // Convergence check
318        let max_g = g.iter().flat_map(|v| v.iter()).map(|x| x.abs()).fold(0.0_f64, f64::max);
319        if max_g < CONVERGENCE {
320            converged = true;
321            break;
322        }
323
324        // Two-loop L-BFGS recursion → search direction p
325        let p = lbfgs_direction(&g, &history);
326
327        // Armijo backtracking line search along p
328        let gp: f64 = g.iter().zip(p.iter()).map(|(gi, pi)| dot3(*gi, *pi)).sum();
329        let mut alpha = 1.0_f64;
330        let new_coords = loop {
331            let trial: Vec<[f64; 3]> = coords
332                .iter()
333                .zip(p.iter())
334                .map(|(c, pi)| [c[0] + alpha * pi[0], c[1] + alpha * pi[1], c[2] + alpha * pi[2]])
335                .collect();
336            let f_trial = total_energy(mol, &trial, &types, &charges);
337            if f_trial <= f0 + C_ARMIJO * alpha * gp {
338                break trial;
339            }
340            alpha *= TAU;
341            if alpha < 1e-12 {
342                // Line search failed — take a tiny steepest descent step
343                let scale = 0.01 / max_g.max(1e-8);
344                break coords
345                    .iter()
346                    .zip(g.iter())
347                    .map(|(c, gi)| [c[0] - scale * gi[0], c[1] - scale * gi[1], c[2] - scale * gi[2]])
348                    .collect();
349            }
350        };
351
352        // Compute new gradient
353        let g_new = compute_gradient(mol, &new_coords, &types, &charges, DELTA);
354        let f_new = total_energy(mol, &new_coords, &types, &charges);
355
356        // Compute s = x_new - x, y = g_new - g
357        let s: Vec<[f64; 3]> = new_coords
358            .iter()
359            .zip(coords.iter())
360            .map(|(xn, xo)| [xn[0] - xo[0], xn[1] - xo[1], xn[2] - xo[2]])
361            .collect();
362        let y: Vec<[f64; 3]> = g_new
363            .iter()
364            .zip(g.iter())
365            .map(|(gn, go)| [gn[0] - go[0], gn[1] - go[1], gn[2] - go[2]])
366            .collect();
367        let ys: f64 = y.iter().zip(s.iter()).map(|(yi, si)| dot3(*yi, *si)).sum();
368
369        // Only store if curvature condition holds
370        if ys > 1e-10 {
371            if history.len() >= M {
372                history.pop_front();
373            }
374            history.push_back((s, y, 1.0 / ys));
375        }
376
377        *coords = new_coords;
378        g = g_new;
379        f0 = f_new;
380    }
381
382    let rmsd = {
383        let sum: f64 = coords
384            .iter()
385            .zip(initial.iter())
386            .map(|(c, i0)| {
387                let dx = c[0] - i0[0]; let dy = c[1] - i0[1]; let dz = c[2] - i0[2];
388                dx * dx + dy * dy + dz * dz
389            })
390            .sum();
391        (sum / n as f64).sqrt()
392    };
393
394    Ok(MinimizeResult { energy: f0, rmsd, converged, iterations: iters })
395}
396
397/// L-BFGS two-loop recursion: compute search direction p = -H_k × g.
398fn lbfgs_direction(
399    g: &[[f64; 3]],
400    history: &VecDeque<(Vec<[f64; 3]>, Vec<[f64; 3]>, f64)>,
401) -> Vec<[f64; 3]> {
402    let n = g.len();
403    let m = history.len();
404
405    if m == 0 {
406        // No history: steepest descent direction
407        return g.iter().map(|gi| [-gi[0], -gi[1], -gi[2]]).collect();
408    }
409
410    let mut q: Vec<[f64; 3]> = g.to_vec();
411    let mut alphas = vec![0.0_f64; m];
412
413    // First loop (backward)
414    for i in (0..m).rev() {
415        let (s, y, rho) = &history[i];
416        let sq: f64 = s.iter().zip(q.iter()).map(|(si, qi)| dot3(*si, *qi)).sum();
417        alphas[i] = rho * sq;
418        let a = alphas[i];
419        for (qi, yi) in q.iter_mut().zip(y.iter()) {
420            qi[0] -= a * yi[0]; qi[1] -= a * yi[1]; qi[2] -= a * yi[2];
421        }
422    }
423
424    // Scale by γ = (s_{m-1}·y_{m-1}) / (y_{m-1}·y_{m-1})
425    let (s_last, y_last, _) = &history[m - 1];
426    let sy: f64 = s_last.iter().zip(y_last.iter()).map(|(si, yi)| dot3(*si, *yi)).sum();
427    let yy: f64 = y_last.iter().map(|yi| dot3(*yi, *yi)).sum();
428    let gamma = if yy > 1e-20 { sy / yy } else { 1.0 };
429    for qi in q.iter_mut() {
430        qi[0] *= gamma; qi[1] *= gamma; qi[2] *= gamma;
431    }
432
433    // Second loop (forward)
434    for i in 0..m {
435        let (s, y, rho) = &history[i];
436        let yr: f64 = y.iter().zip(q.iter()).map(|(yi, ri)| dot3(*yi, *ri)).sum();
437        let beta = rho * yr;
438        let diff = alphas[i] - beta;
439        for (qi, si) in q.iter_mut().zip(s.iter()) {
440            qi[0] += diff * si[0]; qi[1] += diff * si[1]; qi[2] += diff * si[2];
441        }
442    }
443
444    // p = -H_k g = -q
445    q.iter().map(|qi| [-qi[0], -qi[1], -qi[2]]).collect()
446}
447
448/// Compute finite-difference gradient: ∂E/∂x_i via central differences.
449fn compute_gradient(
450    mol: &Molecule,
451    coords: &[[f64; 3]],
452    types: &[u8],
453    charges: &[f64],
454    delta: f64,
455) -> Vec<[f64; 3]> {
456    let n = coords.len();
457    let mut grad = vec![[0.0_f64; 3]; n];
458    let mut work = coords.to_vec();
459    for i in 0..n {
460        for axis in 0..3 {
461            work[i][axis] += delta;
462            let ep = total_energy(mol, &work, types, charges);
463            work[i][axis] -= 2.0 * delta;
464            let em = total_energy(mol, &work, types, charges);
465            work[i][axis] += delta;
466            grad[i][axis] = (ep - em) / (2.0 * delta);
467        }
468    }
469    grad
470}
471
472// ─── Energy components ───────────────────────────────────────────────────────
473
474fn total_energy(
475    mol: &Molecule,
476    coords: &[[f64; 3]],
477    types: &[u8],
478    charges: &[f64],
479) -> f64 {
480    bond_energy(mol, coords, types)
481        + angle_energy(mol, coords, types)
482        + stretch_bend_energy(mol, coords, types)
483        + torsion_energy(mol, coords, types)
484        + oop_energy(mol, coords, types)
485        + vdw_energy(mol, coords, types)
486        + elec_energy(mol, coords, charges)
487}
488
489/// Stretch-bend coupling (Halgren MMFF.V eq. 4)
490/// E_sb = 2.51210 × (kba_ijk × Δr_ij + kba_kji × Δr_kj) × Δθ   [kcal/mol, Δθ in degrees]
491fn stretch_bend_energy(mol: &Molecule, coords: &[[f64; 3]], types: &[u8]) -> f64 {
492    const CONV: f64 = 2.51210; // md/Å → kcal/(mol·Å·deg)
493    const RAD_TO_DEG: f64 = 180.0 / std::f64::consts::PI;
494    const KB_CONV: f64 = 143.9325;
495    const CS: f64 = 2.0;
496    let mut energy = 0.0;
497    for j_idx in 0..mol.atom_count() {
498        let j = AtomIdx(j_idx as u32);
499        let neighbors: Vec<usize> = mol.neighbors(j).map(|(nb, _)| nb.0 as usize).collect();
500        if neighbors.len() < 2 {
501            continue;
502        }
503        let at = angle_type_for(types[j_idx]);
504        for (ii, &i) in neighbors.iter().enumerate() {
505            for &k in &neighbors[ii + 1..] {
506                if let Some((kba_ijk, kba_kji)) = mmff94_stbn(at, types[i], types[j_idx], types[k]) {
507                    // Δr_ij
508                    let r_ij = dist(coords[i], coords[j_idx]);
509                    let bt_ij = bond_type_for(types[i], types[j_idx]);
510                    let dr_ij = if let Some(p) = mmff94_bond_energy(bt_ij, types[i], types[j_idx]) {
511                        r_ij - p.r0
512                    } else { 0.0 };
513                    // Δr_kj
514                    let r_kj = dist(coords[k], coords[j_idx]);
515                    let bt_kj = bond_type_for(types[k], types[j_idx]);
516                    let dr_kj = if let Some(p) = mmff94_bond_energy(bt_kj, types[k], types[j_idx]) {
517                        r_kj - p.r0
518                    } else { 0.0 };
519                    // Δθ in degrees
520                    let cos_t = cos_angle(coords[i], coords[j_idx], coords[k]);
521                    if let Some(ap) = mmff94_angle_energy(at, types[i], types[j_idx], types[k]) {
522                        let dtheta = cos_t.acos() * RAD_TO_DEG - ap.theta0;
523                        energy += CONV * (kba_ijk * dr_ij + kba_kji * dr_kj) * dtheta;
524                    }
525                    let _ = (KB_CONV, CS); // suppress warnings
526                }
527            }
528        }
529    }
530    energy
531}
532
533/// Out-of-plane bending for trigonal sp2 centers (Halgren MMFF.VI eq. 6)
534/// E_oop = (0.043844 × koop / 2) × χ²  (χ in degrees: Wilson angle of out-of-plane distortion)
535fn oop_energy(mol: &Molecule, coords: &[[f64; 3]], types: &[u8]) -> f64 {
536    const CONV: f64 = 0.043844;
537    const RAD_TO_DEG: f64 = 180.0 / std::f64::consts::PI;
538    // sp2 atom types that can have OOP bending
539    const SP2_TYPES: &[u8] = &[
540        2, 3, 9, 10, 30, 37, 38, 39, 40, 41, 43, 45, 49, 54, 56, 57,
541        58, 59, 63, 64, 65, 66, 67, 76, 78, 79, 80, 81, 82,
542    ];
543    let mut energy = 0.0;
544    for j_idx in 0..mol.atom_count() {
545        if SP2_TYPES.binary_search(&types[j_idx]).is_err() {
546            continue;
547        }
548        let j = AtomIdx(j_idx as u32);
549        let neighbors: Vec<usize> = mol.neighbors(j).map(|(nb, _)| nb.0 as usize).collect();
550        if neighbors.len() != 3 {
551            continue; // OOP only for exactly 3 substituents (trigonal)
552        }
553        let [i, k, l] = [neighbors[0], neighbors[1], neighbors[2]];
554        if let Some(koop) = mmff94_oop(types[j_idx], types[i], types[k], types[l]) {
555            // Wilson out-of-plane angle: angle between j→l vector and plane (i,j,k)
556            let pj = coords[j_idx];
557            let pi = coords[i];
558            let pk = coords[k];
559            let pl = coords[l];
560            let rji = [pi[0]-pj[0], pi[1]-pj[1], pi[2]-pj[2]];
561            let rjk = [pk[0]-pj[0], pk[1]-pj[1], pk[2]-pj[2]];
562            let rjl = [pl[0]-pj[0], pl[1]-pj[1], pl[2]-pj[2]];
563            let n = cross(rji, rjk); // normal to ijk plane
564            let n_len = (n[0]*n[0]+n[1]*n[1]+n[2]*n[2]).sqrt();
565            let l_len = (rjl[0]*rjl[0]+rjl[1]*rjl[1]+rjl[2]*rjl[2]).sqrt();
566            if n_len < 1e-12 || l_len < 1e-12 {
567                continue;
568            }
569            let sin_chi = dot3(n, rjl) / (n_len * l_len);
570            let chi_deg = sin_chi.clamp(-1.0, 1.0).asin() * RAD_TO_DEG;
571            energy += (CONV * koop / 2.0) * chi_deg * chi_deg;
572        }
573    }
574    energy
575}
576
577/// Bond stretching: cubic-corrected harmonic (Halgren MMFF.II eq. 1)
578/// E = (143.9325 × kb / 2) × ΔR² × (1 − cs×ΔR + (7/12)×cs²×ΔR²)
579fn bond_energy(mol: &Molecule, coords: &[[f64; 3]], types: &[u8]) -> f64 {
580    const KB_CONV: f64 = 143.9325;
581    const CS: f64 = 2.0;
582    let mut energy = 0.0;
583    for (_, bond) in mol.bonds() {
584        let i = bond.atom1.0 as usize;
585        let j = bond.atom2.0 as usize;
586        let bt = bond_type_for(types[i], types[j]);
587        if let Some(p) = mmff94_bond_energy(bt, types[i], types[j]) {
588            let r = dist(coords[i], coords[j]);
589            let dr = r - p.r0;
590            let cubic = 1.0 - CS * dr + (7.0 / 12.0) * CS * CS * dr * dr;
591            energy += (KB_CONV * p.kb / 2.0) * dr * dr * cubic;
592        }
593    }
594    energy
595}
596
597/// Angle bending: cubic-corrected harmonic (Halgren MMFF.III eq. 2)
598/// E = (0.043844 × ka / 2) × Δθ² × (1 − 0.007×Δθ)   [Δθ in degrees]
599fn angle_energy(mol: &Molecule, coords: &[[f64; 3]], types: &[u8]) -> f64 {
600    const KA_CONV: f64 = 0.043844;
601    const RAD_TO_DEG: f64 = 180.0 / std::f64::consts::PI;
602    let mut energy = 0.0;
603    for j_idx in 0..mol.atom_count() {
604        let j = AtomIdx(j_idx as u32);
605        let neighbors: Vec<usize> = mol.neighbors(j).map(|(nb, _)| nb.0 as usize).collect();
606        if neighbors.len() < 2 {
607            continue;
608        }
609        let at = angle_type_for(types[j_idx]);
610        for (ii, &i) in neighbors.iter().enumerate() {
611            for &k in &neighbors[ii + 1..] {
612                if let Some(p) = mmff94_angle_energy(at, types[i], types[j_idx], types[k]) {
613                    let cos_t = cos_angle(coords[i], coords[j_idx], coords[k]);
614                    let theta_deg = cos_t.acos() * RAD_TO_DEG;
615                    let dt = theta_deg - p.theta0;
616                    let cubic = 1.0 - 0.007 * dt;
617                    energy += (KA_CONV * p.ka / 2.0) * dt * dt * cubic;
618                }
619            }
620        }
621    }
622    energy
623}
624
625/// Torsion: three-term Fourier (Halgren MMFF.IV)
626/// E = (v1/2)(1+cosφ) + (v2/2)(1-cos2φ) + (v3/2)(1+cos3φ)
627fn torsion_energy(mol: &Molecule, coords: &[[f64; 3]], types: &[u8]) -> f64 {
628    let mut energy = 0.0;
629    for (_, bond) in mol.bonds() {
630        let j = bond.atom1.0 as usize;
631        let k = bond.atom2.0 as usize;
632        let nbrs_j: Vec<usize> = mol.neighbors(bond.atom1).map(|(nb, _)| nb.0 as usize).collect();
633        let nbrs_k: Vec<usize> = mol.neighbors(bond.atom2).map(|(nb, _)| nb.0 as usize).collect();
634        let tt = torsion_type_for(types[j], types[k]);
635        for &i in &nbrs_j {
636            if i == k {
637                continue;
638            }
639            for &l in &nbrs_k {
640                if l == j {
641                    continue;
642                }
643                if let Some(p) = mmff94_torsion_energy(tt, types[i], types[j], types[k], types[l]) {
644                    let phi = dihedral(coords[i], coords[j], coords[k], coords[l]);
645                    energy += 0.5 * p.v1 * (1.0 + phi.cos())
646                        + 0.5 * p.v2 * (1.0 - (2.0 * phi).cos())
647                        + 0.5 * p.v3 * (1.0 + (3.0 * phi).cos());
648                }
649            }
650        }
651    }
652    energy
653}
654
655/// Van der Waals: buffered 14-7 (Halgren MMFF.I eq. 2)
656/// t = (1.07 × r*) / (r + 0.07 × r*)
657/// E = ε × t⁷ × (t⁷ − 2)
658fn vdw_energy(mol: &Molecule, coords: &[[f64; 3]], types: &[u8]) -> f64 {
659    let n = mol.atom_count();
660    let mut excl = std::collections::HashSet::new();
661    for (_, bond) in mol.bonds() {
662        let i = bond.atom1.0 as usize;
663        let j = bond.atom2.0 as usize;
664        excl.insert((i.min(j), i.max(j)));
665        for (nb_i, _) in mol.neighbors(bond.atom1) {
666            let ni = nb_i.0 as usize;
667            excl.insert((ni.min(j), ni.max(j)));
668        }
669        for (nb_j, _) in mol.neighbors(bond.atom2) {
670            let nj = nb_j.0 as usize;
671            excl.insert((i.min(nj), i.max(nj)));
672        }
673    }
674    let cutoff = 10.0_f64;
675    let mut energy = 0.0;
676    for i in 0..n {
677        for j in (i + 1)..n {
678            if excl.contains(&(i, j)) {
679                continue;
680            }
681            let r = dist(coords[i], coords[j]);
682            if r > cutoff {
683                continue;
684            }
685            if let Some((r_star, eps)) = mmff94_vdw_combined(types[i], types[j]) {
686                if r_star > 0.0 && eps > 0.0 && r > 0.01 {
687                    let t = (1.07 * r_star) / (r + 0.07 * r_star);
688                    let t7 = t.powi(7);
689                    energy += eps * t7 * (t7 - 2.0);
690                }
691            }
692        }
693    }
694    energy
695}
696
697/// Electrostatic: Coulomb with δ=0.05 Å buffer (Halgren MMFF.V eq. 14)
698/// E = 332.0716 × q_i × q_j / (D × (r + δ))   [D=1.0]
699fn elec_energy(mol: &Molecule, coords: &[[f64; 3]], charges: &[f64]) -> f64 {
700    const COULOMB: f64 = 332.0716;
701    const DELTA: f64 = 0.05;
702    let n = mol.atom_count();
703    let mut excl = std::collections::HashSet::new();
704    for (_, bond) in mol.bonds() {
705        let i = bond.atom1.0 as usize;
706        let j = bond.atom2.0 as usize;
707        excl.insert((i.min(j), i.max(j)));
708        for (nb_i, _) in mol.neighbors(bond.atom1) {
709            excl.insert(((nb_i.0 as usize).min(j), (nb_i.0 as usize).max(j)));
710        }
711        for (nb_j, _) in mol.neighbors(bond.atom2) {
712            excl.insert((i.min(nb_j.0 as usize), i.max(nb_j.0 as usize)));
713        }
714    }
715    // 1-4 pairs: scale by 0.75 (MMFF94 convention)
716    let mut one_four = std::collections::HashSet::new();
717    for (_, bond) in mol.bonds() {
718        let j = bond.atom1.0 as usize;
719        let k = bond.atom2.0 as usize;
720        for (nb_j, _) in mol.neighbors(bond.atom1) {
721            let i = nb_j.0 as usize;
722            if i == k {
723                continue;
724            }
725            for (nb_k, _) in mol.neighbors(bond.atom2) {
726                let l = nb_k.0 as usize;
727                if l == j {
728                    continue;
729                }
730                let key = (i.min(l), i.max(l));
731                if !excl.contains(&key) {
732                    one_four.insert(key);
733                }
734            }
735        }
736    }
737    let mut energy = 0.0;
738    for i in 0..n {
739        for j in (i + 1)..n {
740            if excl.contains(&(i, j)) {
741                continue;
742            }
743            let r = dist(coords[i], coords[j]);
744            let scale = if one_four.contains(&(i, j)) { 0.75 } else { 1.0 };
745            energy += scale * COULOMB * charges[i] * charges[j] / (r + DELTA);
746        }
747    }
748    energy
749}
750
751// ─── Geometry helpers ─────────────────────────────────────────────────────────
752
753#[inline]
754fn dist(a: [f64; 3], b: [f64; 3]) -> f64 {
755    let dx = a[0] - b[0];
756    let dy = a[1] - b[1];
757    let dz = a[2] - b[2];
758    (dx * dx + dy * dy + dz * dz).sqrt()
759}
760
761#[inline]
762fn cos_angle(a: [f64; 3], b: [f64; 3], c: [f64; 3]) -> f64 {
763    let ba = [a[0] - b[0], a[1] - b[1], a[2] - b[2]];
764    let bc = [c[0] - b[0], c[1] - b[1], c[2] - b[2]];
765    let dot_val = ba[0] * bc[0] + ba[1] * bc[1] + ba[2] * bc[2];
766    let na = (ba[0] * ba[0] + ba[1] * ba[1] + ba[2] * ba[2]).sqrt();
767    let nc = (bc[0] * bc[0] + bc[1] * bc[1] + bc[2] * bc[2]).sqrt();
768    if na < 1e-12 || nc < 1e-12 {
769        return 0.0;
770    }
771    (dot_val / (na * nc)).clamp(-1.0, 1.0)
772}
773
774/// Signed dihedral angle φ (radians) for the quartet i-j-k-l.
775#[inline]
776fn dihedral(i: [f64; 3], j: [f64; 3], k: [f64; 3], l: [f64; 3]) -> f64 {
777    let b1 = [j[0] - i[0], j[1] - i[1], j[2] - i[2]];
778    let b2 = [k[0] - j[0], k[1] - j[1], k[2] - j[2]];
779    let b3 = [l[0] - k[0], l[1] - k[1], l[2] - k[2]];
780    let n1 = cross(b1, b2);
781    let n2 = cross(b2, b3);
782    let m1 = cross(n1, b2);
783    let b2_len = (b2[0] * b2[0] + b2[1] * b2[1] + b2[2] * b2[2]).sqrt();
784    if b2_len < 1e-12 {
785        return 0.0;
786    }
787    let x = dot3(n1, n2);
788    let y = dot3(m1, n2) / b2_len;
789    y.atan2(x)
790}
791
792#[inline]
793fn cross(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
794    [
795        a[1] * b[2] - a[2] * b[1],
796        a[2] * b[0] - a[0] * b[2],
797        a[0] * b[1] - a[1] * b[0],
798    ]
799}
800
801#[inline]
802fn dot3(a: [f64; 3], b: [f64; 3]) -> f64 {
803    a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
804}
805
806// ─── Type classification helpers ──────────────────────────────────────────────
807
808/// Determine MMFF94 bond type: 1 if either atom is sp2/aromatic, else 0.
809fn bond_type_for(ti: u8, tj: u8) -> u8 {
810    const SP2: &[u8] = &[
811        2, 3, 9, 10, 37, 38, 39, 40, 41, 56, 57, 58, 59, 63, 64, 65, 66, 67, 78, 79, 80, 81, 82,
812    ];
813    if SP2.binary_search(&ti).is_ok() || SP2.binary_search(&tj).is_ok() {
814        1
815    } else {
816        0
817    }
818}
819
820/// Determine MMFF94 angle type (simplified: 0 for most organic angles).
821fn angle_type_for(_tj: u8) -> u8 {
822    0
823}
824
825/// Determine MMFF94 torsion type from central bond atom types.
826fn torsion_type_for(tj: u8, tk: u8) -> u8 {
827    const SP2: &[u8] = &[
828        2, 3, 9, 10, 37, 38, 39, 40, 41, 56, 57, 58, 59, 63, 64, 65, 66, 67, 78, 79, 80, 81, 82,
829    ];
830    match (
831        SP2.binary_search(&tj).is_ok(),
832        SP2.binary_search(&tk).is_ok(),
833    ) {
834        (false, false) => 0, // sp3-sp3
835        (true, false) | (false, true) => 1, // sp3-sp2
836        (true, true) => 2, // sp2-sp2
837    }
838}
839
840// ─── Tests ───────────────────────────────────────────────────────────────────
841
842#[cfg(test)]
843mod tests {
844    use super::*;
845    use chematic_core::molecule::MoleculeBuilder;
846    use chematic_core::{Atom, BondOrder, Element};
847
848    fn methane_mol() -> (Molecule, Vec<[f64; 3]>) {
849        let mut b = MoleculeBuilder::new();
850        let c = b.add_atom(Atom::new(Element::C));
851        let h1 = b.add_atom(Atom::new(Element::H));
852        let h2 = b.add_atom(Atom::new(Element::H));
853        let h3 = b.add_atom(Atom::new(Element::H));
854        let h4 = b.add_atom(Atom::new(Element::H));
855        b.add_bond(c, h1, BondOrder::Single).unwrap();
856        b.add_bond(c, h2, BondOrder::Single).unwrap();
857        b.add_bond(c, h3, BondOrder::Single).unwrap();
858        b.add_bond(c, h4, BondOrder::Single).unwrap();
859        let mol = b.build();
860        let coords = vec![
861            [0.0, 0.0, 0.0],
862            [0.630, 0.630, 0.630],
863            [-0.630, -0.630, 0.630],
864            [-0.630, 0.630, -0.630],
865            [0.630, -0.630, -0.630],
866        ];
867        (mol, coords)
868    }
869
870    fn butane_backbone() -> Molecule {
871        let mut b = MoleculeBuilder::new();
872        let c0 = b.add_atom(Atom::new(Element::C));
873        let c1 = b.add_atom(Atom::new(Element::C));
874        let c2 = b.add_atom(Atom::new(Element::C));
875        let c3 = b.add_atom(Atom::new(Element::C));
876        b.add_bond(c0, c1, BondOrder::Single).unwrap();
877        b.add_bond(c1, c2, BondOrder::Single).unwrap();
878        b.add_bond(c2, c3, BondOrder::Single).unwrap();
879        b.build()
880    }
881
882    #[test]
883    fn energy_is_finite_for_methane() {
884        let (mol, coords) = methane_mol();
885        let e = mmff94_total_energy(&mol, &coords).expect("energy");
886        assert!(e.is_finite(), "energy={}", e);
887    }
888
889    #[test]
890    fn torsion_differs_by_conformation() {
891        let mol = butane_backbone();
892        let types = assign_mmff94_numeric_types(&mol).expect("types");
893        // Gauche: ~60° central dihedral
894        let coords_gauche = vec![
895            [0.0, 0.0, 0.0_f64],
896            [1.508, 0.0, 0.0],
897            [2.016, 1.192, 0.688],
898            [3.524, 1.192, 0.688],
899        ];
900        // Anti: ~180° central dihedral
901        let coords_anti = vec![
902            [0.0, 0.0, 0.0_f64],
903            [1.508, 0.0, 0.0],
904            [3.016, 0.0, 0.0],
905            [4.524, 0.0, 0.0],
906        ];
907        let e_gauche = torsion_energy(&mol, &coords_gauche, &types);
908        let e_anti = torsion_energy(&mol, &coords_anti, &types);
909        assert!(e_gauche.is_finite());
910        assert!(e_anti.is_finite());
911        assert!(
912            (e_gauche - e_anti).abs() > 1e-6,
913            "torsion must differ: gauche={}, anti={}",
914            e_gauche,
915            e_anti
916        );
917    }
918
919    #[test]
920    fn vdw_more_repulsive_at_short_range() {
921        let mol = butane_backbone();
922        let types = assign_mmff94_numeric_types(&mol).expect("types");
923        // Atoms 0 and 3 are 1-4 (not excluded from vdW)
924        let coords_close = vec![
925            [0.0, 0.0, 0.0_f64],
926            [1.5, 0.0, 0.0],
927            [3.0, 0.0, 0.0],
928            [0.5, 0.0, 0.0], // atom 3 very close to atom 0
929        ];
930        let coords_far = vec![
931            [0.0, 0.0, 0.0_f64],
932            [1.5, 0.0, 0.0],
933            [3.0, 0.0, 0.0],
934            [8.0, 0.0, 0.0],
935        ];
936        let e_close = vdw_energy(&mol, &coords_close, &types);
937        let e_far = vdw_energy(&mol, &coords_far, &types);
938        assert!(e_close.is_finite());
939        assert!(e_far.is_finite());
940        assert!(e_close > e_far, "close={} should > far={}", e_close, e_far);
941    }
942
943    #[test]
944    fn dihedral_anti_is_pi() {
945        let i = [0.0_f64, 0.0, 0.0];
946        let j = [1.0, 0.0, 0.0];
947        let k = [2.0, 0.0, 1.0];
948        let l = [3.0, 0.0, 0.0];
949        let phi = dihedral(i, j, k, l);
950        assert!(phi.abs() > 2.5, "anti dihedral ≈ π: {}", phi);
951    }
952
953    #[test]
954    fn dihedral_syn_is_zero() {
955        let i = [0.0_f64, 1.0, 0.0];
956        let j = [0.0, 0.0, 0.0];
957        let k = [1.0, 0.0, 0.0];
958        let l = [1.0, 1.0, 0.0];
959        let phi = dihedral(i, j, k, l);
960        assert!(phi.abs() < 0.1, "syn dihedral ≈ 0: {}", phi);
961    }
962
963    #[test]
964    fn minimize_reduces_energy_for_methane() {
965        let (mol, _) = methane_mol();
966        let mut coords = vec![
967            [0.0, 0.0, 0.0_f64],
968            [1.5, 1.5, 1.5],
969            [-1.5, -1.5, 1.5],
970            [-1.5, 1.5, -1.5],
971            [1.5, -1.5, -1.5],
972        ];
973        let e_before = mmff94_total_energy(&mol, &coords).expect("energy before");
974        let result = minimize_mmff94_full(&mol, &mut coords, 300).expect("minimize");
975        assert!(
976            result.energy <= e_before,
977            "minimize should reduce energy: {} → {}",
978            e_before,
979            result.energy
980        );
981        assert!(result.energy.is_finite());
982        assert!(result.iterations > 0);
983    }
984
985    #[test]
986    fn lbfgs_reduces_energy_for_methane() {
987        let (mol, _) = methane_mol();
988        let mut coords = vec![
989            [0.0, 0.0, 0.0_f64],
990            [1.5, 1.5, 1.5],
991            [-1.5, -1.5, 1.5],
992            [-1.5, 1.5, -1.5],
993            [1.5, -1.5, -1.5],
994        ];
995        let e_before = mmff94_total_energy(&mol, &coords).expect("energy before");
996        let result = minimize_mmff94_lbfgs(&mol, &mut coords, 300).expect("lbfgs");
997        assert!(
998            result.energy <= e_before,
999            "L-BFGS should reduce energy: {} → {}",
1000            e_before,
1001            result.energy
1002        );
1003        assert!(result.energy.is_finite());
1004    }
1005
1006    #[test]
1007    fn lbfgs_converges_in_fewer_iters_than_sd() {
1008        let (mol, _) = methane_mol();
1009        // Moderately distorted — both should converge but L-BFGS faster
1010        let base_coords = vec![
1011            [0.0, 0.0, 0.0_f64],
1012            [1.2, 1.2, 1.2],
1013            [-1.2, -1.2, 1.2],
1014            [-1.2, 1.2, -1.2],
1015            [1.2, -1.2, -1.2],
1016        ];
1017        let mut coords_sd = base_coords.clone();
1018        let mut coords_lbfgs = base_coords;
1019        let sd = minimize_mmff94_full(&mol, &mut coords_sd, 500).expect("sd");
1020        let lb = minimize_mmff94_lbfgs(&mol, &mut coords_lbfgs, 500).expect("lbfgs");
1021        // Both should converge; L-BFGS should need ≤ SD iterations
1022        assert!(lb.iterations <= sd.iterations || lb.converged,
1023            "L-BFGS iters={} SD iters={}", lb.iterations, sd.iterations);
1024        assert!(lb.energy.is_finite());
1025    }
1026
1027    #[test]
1028    fn energy_breakdown_sums_to_total() {
1029        let (mol, coords) = methane_mol();
1030        let bd = mmff94_energy_breakdown(&mol, &coords).expect("breakdown");
1031        let sum = bd.bond + bd.angle + bd.stretch_bend + bd.torsion + bd.oop + bd.vdw + bd.electrostatic;
1032        assert!((sum - bd.total).abs() < 1e-10, "sum={} total={}", sum, bd.total);
1033        assert!(bd.total.is_finite());
1034    }
1035
1036    #[test]
1037    fn energy_breakdown_bond_term_positive_for_distorted() {
1038        let (mol, _) = methane_mol();
1039        // Very distorted C-H bonds → high bond energy
1040        let stretched = vec![
1041            [0.0, 0.0, 0.0_f64],
1042            [2.0, 2.0, 2.0],
1043            [-2.0, -2.0, 2.0],
1044            [-2.0, 2.0, -2.0],
1045            [2.0, -2.0, -2.0],
1046        ];
1047        let bd = mmff94_energy_breakdown(&mol, &stretched).expect("breakdown");
1048        assert!(bd.bond > 0.0, "stretched bond energy should be positive: {}", bd.bond);
1049    }
1050}