Skip to main content

chematic_3d/
minimize.rs

1//! Simplified force-field geometry minimization for molecular structures.
2//!
3//! Uses gradient descent with finite differences over three energy terms:
4//! bond stretching, angle bending, and VDW repulsion. Bond lengths and angles
5//! use element-specific UFF-derived parameters rather than bond-order-only values.
6
7use std::collections::HashSet;
8
9use chematic_core::{AtomIdx, BondIdx, BondOrder, Molecule};
10
11use crate::coords::{Coords3D, Point3};
12
13// ---------------------------------------------------------------------------
14// Public API
15// ---------------------------------------------------------------------------
16
/// Configuration for the minimization algorithm.
///
/// All fields are public, so a custom configuration can be written with
/// struct-update syntax on top of `MinimizeConfig::default()`.
#[derive(Debug, Clone, PartialEq)]
pub struct MinimizeConfig {
    /// Maximum number of gradient-descent steps.
    pub max_steps: usize,
    /// Base step size for coordinate updates. The gradient is normalized by
    /// its largest component before being scaled, so this value bounds how
    /// far any single coordinate moves in one step.
    pub step_size: f64,
    /// Convergence threshold: stop when max gradient component < this value.
    pub convergence: f64,
}
26
27impl Default for MinimizeConfig {
28    fn default() -> Self {
29        Self {
30            max_steps: 200,
31            step_size: 0.05,
32            convergence: 1e-4,
33        }
34    }
35}
36
37/// Minimize molecular geometry using default configuration.
38pub fn minimize(mol: &Molecule, coords: Coords3D) -> Coords3D {
39    minimize_with_config(mol, coords, &MinimizeConfig::default())
40}
41
42/// Minimize molecular geometry using the provided configuration.
43pub fn minimize_with_config(mol: &Molecule, coords: Coords3D, config: &MinimizeConfig) -> Coords3D {
44    if mol.atom_count() <= 1 {
45        return coords;
46    }
47
48    let mut c = coords;
49    let delta = 1e-4;
50
51    fn partial(
52        mol: &Molecule,
53        c: &mut Coords3D,
54        idx: AtomIdx,
55        delta: f64,
56        axis: impl Fn(&mut Point3, f64),
57    ) -> f64 {
58        let orig = c.get(idx);
59        let mut p = orig;
60        axis(&mut p, delta);
61        c.set(idx, p);
62        let ep = total_energy(mol, c);
63        let mut p = orig;
64        axis(&mut p, -delta);
65        c.set(idx, p);
66        let em = total_energy(mol, c);
67        c.set(idx, orig);
68        (ep - em) / (2.0 * delta)
69    }
70
71    for _ in 0..config.max_steps {
72        let mut grad = vec![Point3::zero(); mol.atom_count()];
73        let mut max_grad = 0.0f64;
74
75        for i in 0..mol.atom_count() {
76            let idx = AtomIdx(i as u32);
77            grad[i].x = partial(mol, &mut c, idx, delta, |p, d| p.x += d);
78            grad[i].y = partial(mol, &mut c, idx, delta, |p, d| p.y += d);
79            grad[i].z = partial(mol, &mut c, idx, delta, |p, d| p.z += d);
80
81            let gmax = grad[i].x.abs().max(grad[i].y.abs()).max(grad[i].z.abs());
82            if gmax > max_grad {
83                max_grad = gmax;
84            }
85        }
86
87        if max_grad < config.convergence {
88            break;
89        }
90
91        let scale = config.step_size / max_grad.max(1e-8);
92        for i in 0..mol.atom_count() {
93            let idx = AtomIdx(i as u32);
94            let p = c.get(idx);
95            c.set(
96                idx,
97                Point3::new(
98                    p.x - scale * grad[i].x,
99                    p.y - scale * grad[i].y,
100                    p.z - scale * grad[i].z,
101                ),
102            );
103        }
104    }
105
106    c
107}
108
109// ---------------------------------------------------------------------------
110// Total energy
111// ---------------------------------------------------------------------------
112
113fn total_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
114    bond_energy(mol, coords) + angle_energy(mol, coords) + vdw_energy(mol, coords)
115}
116
117// ---------------------------------------------------------------------------
118// UFF-derived element parameters
119// ---------------------------------------------------------------------------
120
121/// Ideal bond length (Å) by atom element pair and bond order.
122/// Canonical pair: (a, b) where a <= b lexicographically.
123fn ideal_bond_len(sym1: &str, sym2: &str, order: BondOrder) -> f64 {
124    let (a, b) = if sym1 <= sym2 { (sym1, sym2) } else { (sym2, sym1) };
125    match (a, b, order) {
126        // C–C
127        ("C", "C", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.540,
128        ("C", "C", BondOrder::Double) => 1.340,
129        ("C", "C", BondOrder::Triple) => 1.204,
130        ("C", "C", BondOrder::Aromatic) => 1.395,
131        // C–H
132        ("C", "H", _) => 1.090,
133        // C–N
134        ("C", "N", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.469,
135        ("C", "N", BondOrder::Double) => 1.279,
136        ("C", "N", BondOrder::Triple) => 1.158,
137        ("C", "N", BondOrder::Aromatic) => 1.340,
138        // C–O
139        ("C", "O", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.427,
140        ("C", "O", BondOrder::Double) => 1.217,
141        ("C", "O", BondOrder::Aromatic) => 1.355,
142        // C–S
143        ("C", "S", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.819,
144        ("C", "S", BondOrder::Double) => 1.610,
145        ("C", "S", BondOrder::Aromatic) => 1.750,
146        // C–F
147        ("C", "F", _) => 1.350,
148        // C–Cl ("C" < "Cl" since "C" == "C" and "" < "l")
149        ("C", "Cl", _) => 1.770,
150        // C–Br ("Br" < "C")
151        ("Br", "C", _) => 1.940,
152        // C–I
153        ("C", "I", _) => 2.140,
154        // C–P
155        ("C", "P", _) => 1.840,
156        // C–Si
157        ("C", "Si", _) => 1.870,
158        // H–H
159        ("H", "H", _) => 0.741,
160        // H–N
161        ("H", "N", _) => 1.010,
162        // H–O
163        ("H", "O", _) => 0.960,
164        // H–S
165        ("H", "S", _) => 1.340,
166        // H–P
167        ("H", "P", _) => 1.420,
168        // N–N
169        ("N", "N", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.450,
170        ("N", "N", BondOrder::Double) => 1.250,
171        ("N", "N", BondOrder::Triple) => 1.100,
172        ("N", "N", BondOrder::Aromatic) => 1.350,
173        // N–O
174        ("N", "O", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.400,
175        ("N", "O", BondOrder::Double) => 1.210,
176        ("N", "O", BondOrder::Aromatic) => 1.340,
177        // O–O
178        ("O", "O", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.480,
179        ("O", "O", BondOrder::Double) => 1.210,
180        // S–S
181        ("S", "S", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 2.050,
182        ("S", "S", BondOrder::Double) => 1.890,
183        // P–P
184        ("P", "P", _) => 2.280,
185        // fallback: order-based only
186        _ => match order {
187            BondOrder::Single | BondOrder::Up | BondOrder::Down => 1.54,
188            BondOrder::Double => 1.34,
189            BondOrder::Triple => 1.20,
190            BondOrder::Quadruple => 1.20,
191            BondOrder::Aromatic => 1.40,
192        },
193    }
194}
195
/// Atom hybridization inferred from bond orders and aromaticity.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Hybridization {
    /// Linear geometry, e.g. an atom carrying a triple bond.
    SP,
    /// Trigonal planar geometry (double bond or aromatic).
    SP2,
    /// Tetrahedral geometry.
    SP3,
}
203
204fn atom_hybridization(mol: &Molecule, idx: AtomIdx) -> Hybridization {
205    if mol.atom(idx).aromatic {
206        return Hybridization::SP2;
207    }
208    let mut has_triple = false;
209    let mut has_double_or_aromatic = false;
210    for (_, bond_idx) in mol.neighbors(idx) {
211        match mol.bond(bond_idx).order {
212            BondOrder::Triple => has_triple = true,
213            BondOrder::Double | BondOrder::Aromatic => has_double_or_aromatic = true,
214            _ => {}
215        }
216    }
217    if has_triple {
218        Hybridization::SP
219    } else if has_double_or_aromatic {
220        Hybridization::SP2
221    } else {
222        Hybridization::SP3
223    }
224}
225
226/// Ideal bond angle (radians) for a center atom given its hybridization.
227fn ideal_angle_rad(sym: &str, hyb: Hybridization) -> f64 {
228    match hyb {
229        Hybridization::SP => 180.0_f64.to_radians(),
230        Hybridization::SP2 => 120.0_f64.to_radians(),
231        Hybridization::SP3 => match sym {
232            "O" | "Se" => 104.5_f64.to_radians(),
233            "N" => 107.0_f64.to_radians(),
234            "S" => 99.0_f64.to_radians(),
235            "P" => 93.0_f64.to_radians(),
236            _ => 109.47_f64.to_radians(),
237        },
238    }
239}
240
/// VDW radius (Å) derived from UFF/Bondi values.
///
/// Symbols are matched exactly (case-sensitive); any element missing from the
/// table gets 1.70 Å, the same value as carbon.
fn uff_vdw_radius(sym: &str) -> f64 {
    // Radius used for every element the table does not list.
    const DEFAULT_RADIUS: f64 = 1.70;
    const RADII: [(&str, f64); 13] = [
        ("H", 1.20),
        ("C", 1.70),
        ("N", 1.55),
        ("O", 1.52),
        ("F", 1.47),
        ("Si", 2.10),
        ("P", 1.80),
        ("S", 1.80),
        ("Cl", 1.75),
        ("Br", 1.85),
        ("I", 1.98),
        ("Se", 1.90),
        ("Te", 2.06),
    ];

    RADII
        .iter()
        .find(|entry| entry.0 == sym)
        .map_or(DEFAULT_RADIUS, |entry| entry.1)
}
260
261// ---------------------------------------------------------------------------
262// Bond stretching energy
263// ---------------------------------------------------------------------------
264
265fn bond_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
266    let mut energy = 0.0;
267    for (_, bond) in mol.bonds() {
268        let a1 = bond.atom1;
269        let a2 = bond.atom2;
270        let r = coords.get(a1).distance(&coords.get(a2));
271        let sym1 = mol.atom(a1).element.symbol();
272        let sym2 = mol.atom(a2).element.symbol();
273        let r0 = ideal_bond_len(sym1, sym2, bond.order);
274        let dr = r - r0;
275        energy += 0.5 * 700.0 * dr * dr;
276    }
277    energy
278}
279
280// ---------------------------------------------------------------------------
281// Angle bending energy
282// ---------------------------------------------------------------------------
283
284fn angle_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
285    let mut energy = 0.0;
286
287    for b_idx in 0..mol.atom_count() {
288        let b = AtomIdx(b_idx as u32);
289        let neighbors: Vec<AtomIdx> = mol.neighbors(b).map(|(nb, _)| nb).collect();
290
291        if neighbors.len() < 2 {
292            continue;
293        }
294
295        let sym_b = mol.atom(b).element.symbol();
296        let hyb = atom_hybridization(mol, b);
297        let theta0 = ideal_angle_rad(sym_b, hyb);
298        let pb = coords.get(b);
299
300        for i in 0..neighbors.len() {
301            for j in (i + 1)..neighbors.len() {
302                let a = neighbors[i];
303                let c = neighbors[j];
304
305                let pa = coords.get(a);
306                let pc = coords.get(c);
307
308                let va = pa.sub(&pb);
309                let vc = pc.sub(&pb);
310
311                let na = va.norm();
312                let nc = vc.norm();
313
314                if na < 1e-10 || nc < 1e-10 {
315                    continue;
316                }
317
318                let cos_theta = (va.dot(&vc) / (na * nc)).clamp(-1.0, 1.0);
319                let theta = cos_theta.acos();
320                let dtheta = theta - theta0;
321                energy += 0.5 * 100.0 * dtheta * dtheta;
322            }
323        }
324    }
325
326    energy
327}
328
329// ---------------------------------------------------------------------------
330// VDW repulsion energy
331// ---------------------------------------------------------------------------
332
333fn vdw_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
334    let n = mol.atom_count();
335    let cutoff = 8.0_f64;
336
337    let mut excluded: HashSet<(usize, usize)> = HashSet::new();
338
339    for (_, bond) in mol.bonds() {
340        let i = bond.atom1.0 as usize;
341        let j = bond.atom2.0 as usize;
342        excluded.insert((i.min(j), i.max(j)));
343    }
344
345    for b_idx in 0..n {
346        let b = AtomIdx(b_idx as u32);
347        let neighbors: Vec<usize> = mol.neighbors(b).map(|(nb, _)| nb.0 as usize).collect();
348        for ii in 0..neighbors.len() {
349            for jj in (ii + 1)..neighbors.len() {
350                let i = neighbors[ii];
351                let j = neighbors[jj];
352                excluded.insert((i.min(j), i.max(j)));
353            }
354        }
355    }
356
357    let mut energy = 0.0;
358    for i in 0..n {
359        for j in (i + 1)..n {
360            if excluded.contains(&(i, j)) {
361                continue;
362            }
363            let r = coords
364                .get(AtomIdx(i as u32))
365                .distance(&coords.get(AtomIdx(j as u32)));
366
367            if r < 0.01 || r >= cutoff {
368                continue;
369            }
370
371            let sym_i = mol.atom(AtomIdx(i as u32)).element.symbol();
372            let sym_j = mol.atom(AtomIdx(j as u32)).element.symbol();
373            let r0 = uff_vdw_radius(sym_i) + uff_vdw_radius(sym_j);
374
375            let ratio = r0 / r;
376            let ratio6 = ratio * ratio * ratio * ratio * ratio * ratio;
377            let ratio12 = ratio6 * ratio6;
378            energy += 0.05 * ratio12;
379        }
380    }
381
382    energy
383}
384
385// ---------------------------------------------------------------------------
386// Tests
387// ---------------------------------------------------------------------------
388
// Unit tests. Starting geometries come from `crate::dg::generate_coords` and
// molecules from `chematic_smiles::parse`, so most tests check properties of
// the minimized structure (bond lengths, absence of clashes) rather than
// exact coordinates.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dg::generate_coords;
    use chematic_smiles::parse;

    /// Smallest distance between any two of the first `n` atoms of `coords`.
    /// Returns `f64::MAX` when there is no pair to compare.
    fn all_pairs_min_dist(coords: &Coords3D, n: usize) -> f64 {
        let mut min_d = f64::MAX;
        for i in 0..n {
            for j in (i + 1)..n {
                let d = coords
                    .get(AtomIdx(i as u32))
                    .distance(&coords.get(AtomIdx(j as u32)));
                min_d = min_d.min(d);
            }
        }
        min_d
    }

    /// Minimizing the molecule parsed from `O` must not move atom 0
    /// (only the x coordinate is compared).
    #[test]
    fn test_single_atom_unchanged() {
        let mol = parse("O").unwrap();
        let coords = generate_coords(&mol);
        let orig = coords.get(AtomIdx(0));
        let result = minimize(&mol, coords);
        let after = result.get(AtomIdx(0));
        assert!((orig.x - after.x).abs() < 1e-10);
    }

    /// With `max_steps: 0` the input coordinates must come back unchanged
    /// (only the x coordinate of atom 0 is compared).
    #[test]
    fn test_zero_steps_unchanged() {
        let mol = parse("CC").unwrap();
        let coords = generate_coords(&mol);
        let config = MinimizeConfig {
            max_steps: 0,
            ..MinimizeConfig::default()
        };
        let before0 = coords.get(AtomIdx(0));
        let result = minimize_with_config(&mol, coords, &config);
        let after0 = result.get(AtomIdx(0));
        assert!((before0.x - after0.x).abs() < 1e-10);
    }

    /// Loose sanity check: the distance between atoms 0 and 1 of `CC` stays
    /// within 1.2-1.8 Å after minimization.
    #[test]
    fn test_ethane_bond_after_minimize() {
        let mol = parse("CC").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
        assert!(d > 1.2 && d < 1.8, "C-C distance={d:.3}, expected 1.2-1.8 Å");
    }

    /// Tighter check: the same distance ends within 0.05 Å of the tabulated
    /// C-C single-bond length.
    #[test]
    fn test_ethane_converges_to_uff_length() {
        let mol = parse("CC").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
        // UFF C-C single bond is 1.540 Å; minimizer should get within 0.05 Å.
        assert!(
            (d - 1.540).abs() < 0.05,
            "C-C distance={d:.4}, expected ~1.540"
        );
    }

    /// No two atoms of minimized `CCC` may be closer than 0.8 Å.
    #[test]
    fn test_propane_no_clash() {
        let mol = parse("CCC").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "atom clash: min distance={min_d:.3}");
    }

    /// No two atoms of minimized benzene may be closer than 0.8 Å.
    #[test]
    fn test_benzene_no_clash() {
        let mol = parse("c1ccccc1").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "atom clash in benzene: min distance={min_d:.3}");
    }

    /// Same clash check for a SMILES with two disconnected fragments.
    #[test]
    fn test_disconnected_no_clash() {
        let mol = parse("CC.CC").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "atom clash in disconnected: min distance={min_d:.3}");
    }

    /// The default configuration runs to completion and keeps the atom count.
    #[test]
    fn test_default_config_no_panic() {
        let mol = parse("CC(=O)O").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        assert_eq!(result.atom_count(), mol.atom_count());
    }

    /// No two atoms of minimized acetic acid may be closer than 0.8 Å.
    #[test]
    fn test_acetic_acid_no_clash() {
        let mol = parse("CC(=O)O").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "clash in acetic acid: {min_d:.3}");
    }

    /// Minimizing an already minimized structure must not raise the energy
    /// by more than 1.0 (the assertion tolerates an increase up to that).
    #[test]
    fn test_minimize_idempotent() {
        let mol = parse("CCC").unwrap();
        let coords = generate_coords(&mol);
        let result1 = minimize(&mol, coords);
        let e1 = total_energy(&mol, &result1);
        let result2 = minimize(&mol, result1);
        let e2 = total_energy(&mol, &result2);
        assert!(e2 <= e1 + 1.0, "energy increased: e1={e1:.4}, e2={e2:.4}");
    }

    /// No two atoms of minimized naphthalene may be closer than 0.8 Å.
    #[test]
    fn test_naphthalene_no_overlap() {
        let mol = parse("c1ccc2ccccc2c1").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "overlap in naphthalene: {min_d:.3}");
    }

    /// NOTE(review): despite its name this test never compares the C=O and
    /// C-O distances; it only asserts the atom count and a 0.5 Å clash bound.
    #[test]
    fn test_co_bond_double_shorter_than_single() {
        // Acetic acid: C=O should be shorter than C-O
        let mol = parse("CC(=O)O").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        // Atom 1 is the carbonyl C, its bonds include C=O (double) and C-O (single).
        // Just check overall: minimized coords have no clash and atom count preserved.
        assert_eq!(result.atom_count(), 4);
        let min_d = all_pairs_min_dist(&result, 4);
        assert!(min_d > 0.5, "clash in CO test: {min_d:.3}");
    }

    /// The distance between atoms 0 and 1 of `CN` ends within 0.1 Å of the
    /// tabulated C-N single-bond length.
    #[test]
    fn test_heteroatom_c_n_bond() {
        let mol = parse("CN").unwrap(); // methylamine
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
        // C-N single bond UFF: 1.469 Å; expect within 0.1 Å.
        assert!(
            (d - 1.469).abs() < 0.1,
            "C-N distance={d:.4}, expected ~1.469"
        );
    }

    /// The distance between atoms 0 and 1 of `C#C` ends within 0.05 Å of the
    /// tabulated C-C triple-bond length.
    #[test]
    fn test_acetylene_sp_hybridization() {
        let mol = parse("C#C").unwrap(); // acetylene: C≡C
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
        // C≡C triple bond UFF: 1.204 Å; expect within 0.05 Å.
        assert!(
            (d - 1.204).abs() < 0.05,
            "C≡C distance={d:.4}, expected ~1.204"
        );
    }

    /// Table check for all four tabulated C-C bond orders (not just single,
    /// as the name suggests).
    #[test]
    fn test_ideal_bond_len_cc_single() {
        assert!((ideal_bond_len("C", "C", BondOrder::Single) - 1.540).abs() < 1e-6);
        assert!((ideal_bond_len("C", "C", BondOrder::Double) - 1.340).abs() < 1e-6);
        assert!((ideal_bond_len("C", "C", BondOrder::Triple) - 1.204).abs() < 1e-6);
        assert!((ideal_bond_len("C", "C", BondOrder::Aromatic) - 1.395).abs() < 1e-6);
    }

    /// `ideal_bond_len` must give the same value for either symbol order.
    #[test]
    fn test_ideal_bond_len_symmetry() {
        // Should be the same regardless of argument order.
        let bo = BondOrder::Single;
        assert_eq!(ideal_bond_len("C", "N", bo), ideal_bond_len("N", "C", bo));
        assert_eq!(ideal_bond_len("C", "O", bo), ideal_bond_len("O", "C", bo));
        assert_eq!(ideal_bond_len("Br", "C", bo), ideal_bond_len("C", "Br", bo));
    }

    /// Atoms 0-5 of benzene are classified SP2.
    #[test]
    fn test_atom_hybridization_sp2_aromatic() {
        let mol = parse("c1ccccc1").unwrap();
        for i in 0..6 {
            assert_eq!(
                atom_hybridization(&mol, AtomIdx(i)),
                Hybridization::SP2,
                "benzene atom {i} should be SP2"
            );
        }
    }

    /// Both atoms of `C#C` are classified SP.
    #[test]
    fn test_atom_hybridization_sp_triple() {
        let mol = parse("C#C").unwrap();
        assert_eq!(atom_hybridization(&mol, AtomIdx(0)), Hybridization::SP);
        assert_eq!(atom_hybridization(&mol, AtomIdx(1)), Hybridization::SP);
    }

    /// Atoms 0-2 of `CCC` are classified SP3.
    #[test]
    fn test_atom_hybridization_sp3_alkane() {
        let mol = parse("CCC").unwrap();
        for i in 0..3 {
            assert_eq!(
                atom_hybridization(&mol, AtomIdx(i)),
                Hybridization::SP3,
                "propane atom {i} should be SP3"
            );
        }
    }
}