Skip to main content

chematic_3d/
minimize.rs

1//! Simplified force-field geometry minimization for molecular structures.
2//!
3//! Uses gradient descent with finite differences over three energy terms:
4//! bond stretching, angle bending, and VDW repulsion.
5
6use std::collections::HashSet;
7use std::f64::consts::PI;
8
9use chematic_core::{AtomIdx, BondOrder, Molecule};
10
11use crate::coords::{Coords3D, Point3};
12
13// ---------------------------------------------------------------------------
14// Public API
15// ---------------------------------------------------------------------------
16
/// Configuration for the minimization algorithm.
///
/// Start from [`Default::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `MinimizeConfig { max_steps: 500, ..MinimizeConfig::default() }`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MinimizeConfig {
    /// Maximum number of gradient-descent steps (`0` returns the input unchanged).
    pub max_steps: usize,
    /// Step length (Å): within one step, the atom coordinate with the largest
    /// gradient component moves by at most this distance (the gradient is
    /// normalised by its largest component).
    pub step_size: f64,
    /// Convergence threshold: stop when the largest absolute gradient
    /// component falls below this value.
    pub convergence: f64,
}

impl Default for MinimizeConfig {
    /// 200 steps, 0.05 Å step length, convergence threshold `1e-4`.
    fn default() -> Self {
        Self {
            max_steps: 200,
            step_size: 0.05,
            convergence: 1e-4,
        }
    }
}
36
37/// Minimize molecular geometry using default configuration.
38pub fn minimize(mol: &Molecule, coords: Coords3D) -> Coords3D {
39    minimize_with_config(mol, coords, &MinimizeConfig::default())
40}
41
42/// Minimize molecular geometry using the provided configuration.
43pub fn minimize_with_config(mol: &Molecule, coords: Coords3D, config: &MinimizeConfig) -> Coords3D {
44    if mol.atom_count() <= 1 {
45        return coords;
46    }
47
48    let mut c = coords;
49    let delta = 1e-4;
50
51    for _ in 0..config.max_steps {
52        let mut grad = vec![Point3::zero(); mol.atom_count()];
53        let mut max_grad = 0.0f64;
54
55        for i in 0..mol.atom_count() {
56            let idx = AtomIdx(i as u32);
57            let orig = c.get(idx);
58
59            // x gradient
60            {
61                c.set(idx, Point3::new(orig.x + delta, orig.y, orig.z));
62                let ep = total_energy(mol, &c);
63                c.set(idx, Point3::new(orig.x - delta, orig.y, orig.z));
64                let em = total_energy(mol, &c);
65                c.set(idx, orig);
66                grad[i].x = (ep - em) / (2.0 * delta);
67            }
68
69            // y gradient
70            {
71                c.set(idx, Point3::new(orig.x, orig.y + delta, orig.z));
72                let ep = total_energy(mol, &c);
73                c.set(idx, Point3::new(orig.x, orig.y - delta, orig.z));
74                let em = total_energy(mol, &c);
75                c.set(idx, orig);
76                grad[i].y = (ep - em) / (2.0 * delta);
77            }
78
79            // z gradient
80            {
81                c.set(idx, Point3::new(orig.x, orig.y, orig.z + delta));
82                let ep = total_energy(mol, &c);
83                c.set(idx, Point3::new(orig.x, orig.y, orig.z - delta));
84                let em = total_energy(mol, &c);
85                c.set(idx, orig);
86                grad[i].z = (ep - em) / (2.0 * delta);
87            }
88
89            let gmax = grad[i].x.abs().max(grad[i].y.abs()).max(grad[i].z.abs());
90            if gmax > max_grad {
91                max_grad = gmax;
92            }
93        }
94
95        if max_grad < config.convergence {
96            break;
97        }
98
99        // Update coordinates: scale step so the largest gradient component moves step_size
100        let scale = config.step_size / max_grad.max(1e-8);
101        for i in 0..mol.atom_count() {
102            let idx = AtomIdx(i as u32);
103            let p = c.get(idx);
104            c.set(
105                idx,
106                Point3::new(
107                    p.x - scale * grad[i].x,
108                    p.y - scale * grad[i].y,
109                    p.z - scale * grad[i].z,
110                ),
111            );
112        }
113    }
114
115    c
116}
117
118// ---------------------------------------------------------------------------
119// Total energy
120// ---------------------------------------------------------------------------
121
122fn total_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
123    bond_energy(mol, coords) + angle_energy(mol, coords) + vdw_energy(mol, coords)
124}
125
126// ---------------------------------------------------------------------------
127// Bond stretching energy
128// ---------------------------------------------------------------------------
129
130/// Ideal bond length (Å) based solely on bond order — used for the minimizer.
131fn ideal_bond_len_by_order(order: BondOrder) -> f64 {
132    match order {
133        BondOrder::Single | BondOrder::Up | BondOrder::Down => 1.54,
134        BondOrder::Double => 1.34,
135        BondOrder::Triple => 1.20,
136        BondOrder::Quadruple => 1.20,
137        BondOrder::Aromatic => 1.40,
138    }
139}
140
141fn bond_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
142    let mut energy = 0.0;
143    for (_, bond) in mol.bonds() {
144        let a1 = bond.atom1;
145        let a2 = bond.atom2;
146        let r = coords.get(a1).distance(&coords.get(a2));
147        let r0 = ideal_bond_len_by_order(bond.order);
148        let dr = r - r0;
149        energy += 0.5 * 700.0 * dr * dr;
150    }
151    energy
152}
153
154// ---------------------------------------------------------------------------
155// Angle bending energy
156// ---------------------------------------------------------------------------
157
158fn angle_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
159    let mut energy = 0.0;
160
161    for b_idx in 0..mol.atom_count() {
162        let b = AtomIdx(b_idx as u32);
163        let neighbors: Vec<AtomIdx> = mol.neighbors(b).map(|(nb, _)| nb).collect();
164        let deg = neighbors.len();
165
166        if deg < 2 {
167            continue;
168        }
169
170        // Ideal angle based on degree
171        let theta0 = match deg {
172            4 => 109.47_f64.to_radians(),
173            3 => 120.0_f64.to_radians(),
174            2 => PI, // 180°
175            _ => 109.47_f64.to_radians(), // > 4 neighbors: use tetrahedral
176        };
177
178        let pb = coords.get(b);
179
180        // Iterate unique pairs of neighbors (i < j) to avoid double counting
181        for i in 0..neighbors.len() {
182            for j in (i + 1)..neighbors.len() {
183                let a = neighbors[i];
184                let c = neighbors[j];
185
186                let pa = coords.get(a);
187                let pc = coords.get(c);
188
189                let va = pa.sub(&pb);
190                let vc = pc.sub(&pb);
191
192                let na = va.norm();
193                let nc = vc.norm();
194
195                if na < 1e-10 || nc < 1e-10 {
196                    continue;
197                }
198
199                let cos_theta = (va.dot(&vc) / (na * nc)).clamp(-1.0, 1.0);
200                let theta = cos_theta.acos();
201                let dtheta = theta - theta0;
202                energy += 0.5 * 100.0 * dtheta * dtheta;
203            }
204        }
205    }
206
207    energy
208}
209
210// ---------------------------------------------------------------------------
211// VDW repulsion energy
212// ---------------------------------------------------------------------------
213
214fn vdw_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
215    let n = mol.atom_count();
216    let cutoff = 5.0_f64;
217
218    // Build exclusion set: directly bonded pairs + 1-3 pairs (share a common neighbor)
219    let mut excluded: HashSet<(usize, usize)> = HashSet::new();
220
221    // 1-2 exclusions (directly bonded)
222    for (_, bond) in mol.bonds() {
223        let i = bond.atom1.0 as usize;
224        let j = bond.atom2.0 as usize;
225        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
226        excluded.insert((lo, hi));
227    }
228
229    // 1-3 exclusions (atoms sharing a common bonded neighbor)
230    for b_idx in 0..n {
231        let b = AtomIdx(b_idx as u32);
232        let neighbors: Vec<usize> = mol.neighbors(b).map(|(nb, _)| nb.0 as usize).collect();
233        for ii in 0..neighbors.len() {
234            for jj in (ii + 1)..neighbors.len() {
235                let i = neighbors[ii];
236                let j = neighbors[jj];
237                let (lo, hi) = if i < j { (i, j) } else { (j, i) };
238                excluded.insert((lo, hi));
239            }
240        }
241    }
242
243    let mut energy = 0.0;
244    for i in 0..n {
245        for j in (i + 1)..n {
246            if excluded.contains(&(i, j)) {
247                continue;
248            }
249            let r = coords
250                .get(AtomIdx(i as u32))
251                .distance(&coords.get(AtomIdx(j as u32)));
252
253            if r < 0.01 {
254                continue;
255            }
256            if r >= cutoff {
257                continue;
258            }
259
260            let ratio = 2.0 / r;
261            let ratio6 = ratio * ratio * ratio * ratio * ratio * ratio;
262            let ratio12 = ratio6 * ratio6;
263            energy += 0.05 * ratio12;
264        }
265    }
266
267    energy
268}
269
270// ---------------------------------------------------------------------------
271// Tests
272// ---------------------------------------------------------------------------
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dg::generate_coords;
    use chematic_smiles::parse;

    /// Smallest distance between any two of the first `n` atoms.
    fn all_pairs_min_dist(coords: &Coords3D, n: usize) -> f64 {
        let mut min_d = f64::MAX;
        for i in 0..n {
            let p_i = coords.get(AtomIdx(i as u32));
            for j in (i + 1)..n {
                min_d = min_d.min(p_i.distance(&coords.get(AtomIdx(j as u32))));
            }
        }
        min_d
    }

    /// Snapshot of the first `n` atom positions.
    fn positions(coords: &Coords3D, n: usize) -> Vec<Point3> {
        (0..n).map(|i| coords.get(AtomIdx(i as u32))).collect()
    }

    #[test]
    fn test_single_atom_unchanged() {
        let mol = parse("O").unwrap();
        let coords = generate_coords(&mol);
        let orig = coords.get(AtomIdx(0));
        let result = minimize(&mol, coords);
        let after = result.get(AtomIdx(0));
        // Compare the full 3D position, not just the x coordinate.
        let moved = orig.distance(&after);
        assert!(moved < 1e-10, "single atom moved by {moved:e}");
    }

    #[test]
    fn test_zero_steps_unchanged() {
        let mol = parse("CC").unwrap();
        let coords = generate_coords(&mol);
        let config = MinimizeConfig {
            max_steps: 0,
            ..MinimizeConfig::default()
        };
        let before = positions(&coords, mol.atom_count());
        let result = minimize_with_config(&mol, coords, &config);
        // With zero steps every atom must stay put in all three coordinates.
        for (i, p) in before.iter().enumerate() {
            let moved = p.distance(&result.get(AtomIdx(i as u32)));
            assert!(moved < 1e-10, "atom {i} moved by {moved:e} with max_steps = 0");
        }
    }

    #[test]
    fn test_ethane_bond_after_minimize() {
        let mol = parse("CC").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
        assert!(d > 1.2 && d < 1.8, "C-C distance={d:.3}, expected 1.2-1.8 Å");
    }

    #[test]
    fn test_propane_no_clash() {
        let mol = parse("CCC").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "atom clash: min distance={min_d:.3}");
    }

    #[test]
    fn test_benzene_no_clash() {
        let mol = parse("c1ccccc1").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "atom clash in benzene: min distance={min_d:.3}");
    }

    #[test]
    fn test_disconnected_no_clash() {
        let mol = parse("CC.CC").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "atom clash in disconnected: min distance={min_d:.3}");
    }

    #[test]
    fn test_default_config_no_panic() {
        let mol = parse("CC(=O)O").unwrap(); // acetic acid
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        assert_eq!(result.atom_count(), mol.atom_count());
    }

    #[test]
    fn test_acetic_acid_no_clash() {
        let mol = parse("CC(=O)O").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "clash in acetic acid: {min_d:.3}");
    }

    #[test]
    fn test_minimize_idempotent() {
        let mol = parse("CCC").unwrap();
        let coords = generate_coords(&mol);
        let result1 = minimize(&mol, coords);
        let e1 = total_energy(&mol, &result1);
        let result2 = minimize(&mol, result1);
        let e2 = total_energy(&mol, &result2);
        // Second minimization shouldn't increase energy significantly
        assert!(e2 <= e1 + 1.0, "energy increased: e1={e1:.4}, e2={e2:.4}");
    }

    #[test]
    fn test_naphthalene_no_overlap() {
        let mol = parse("c1ccc2ccccc2c1").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "overlap in naphthalene: {min_d:.3}");
    }
}