1use std::collections::HashSet;
7use std::f64::consts::PI;
8
9use chematic_core::{AtomIdx, BondOrder, Molecule};
10
11use crate::coords::{Coords3D, Point3};
12
/// Tuning parameters for [`minimize_with_config`].
///
/// Start from [`MinimizeConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `MinimizeConfig { max_steps: 50, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct MinimizeConfig {
    /// Maximum number of descent iterations to run.
    pub max_steps: usize,
    /// Maximum displacement (Å) of any single atomic coordinate in one step.
    pub step_size: f64,
    /// Minimization stops once the largest absolute component of the energy
    /// gradient falls below this value.
    pub convergence: f64,
}

impl Default for MinimizeConfig {
    /// 200 steps, a 0.05 Å step size and a gradient tolerance of 1e-4.
    fn default() -> Self {
        Self {
            max_steps: 200,
            step_size: 0.05,
            convergence: 1e-4,
        }
    }
}
36
37pub fn minimize(mol: &Molecule, coords: Coords3D) -> Coords3D {
39 minimize_with_config(mol, coords, &MinimizeConfig::default())
40}
41
42pub fn minimize_with_config(mol: &Molecule, coords: Coords3D, config: &MinimizeConfig) -> Coords3D {
44 if mol.atom_count() <= 1 {
45 return coords;
46 }
47
48 let mut c = coords;
49 let delta = 1e-4;
50
51 for _ in 0..config.max_steps {
52 let mut grad = vec![Point3::zero(); mol.atom_count()];
53 let mut max_grad = 0.0f64;
54
55 for i in 0..mol.atom_count() {
56 let idx = AtomIdx(i as u32);
57 let orig = c.get(idx);
58
59 {
61 c.set(idx, Point3::new(orig.x + delta, orig.y, orig.z));
62 let ep = total_energy(mol, &c);
63 c.set(idx, Point3::new(orig.x - delta, orig.y, orig.z));
64 let em = total_energy(mol, &c);
65 c.set(idx, orig);
66 grad[i].x = (ep - em) / (2.0 * delta);
67 }
68
69 {
71 c.set(idx, Point3::new(orig.x, orig.y + delta, orig.z));
72 let ep = total_energy(mol, &c);
73 c.set(idx, Point3::new(orig.x, orig.y - delta, orig.z));
74 let em = total_energy(mol, &c);
75 c.set(idx, orig);
76 grad[i].y = (ep - em) / (2.0 * delta);
77 }
78
79 {
81 c.set(idx, Point3::new(orig.x, orig.y, orig.z + delta));
82 let ep = total_energy(mol, &c);
83 c.set(idx, Point3::new(orig.x, orig.y, orig.z - delta));
84 let em = total_energy(mol, &c);
85 c.set(idx, orig);
86 grad[i].z = (ep - em) / (2.0 * delta);
87 }
88
89 let gmax = grad[i].x.abs().max(grad[i].y.abs()).max(grad[i].z.abs());
90 if gmax > max_grad {
91 max_grad = gmax;
92 }
93 }
94
95 if max_grad < config.convergence {
96 break;
97 }
98
99 let scale = config.step_size / max_grad.max(1e-8);
101 for i in 0..mol.atom_count() {
102 let idx = AtomIdx(i as u32);
103 let p = c.get(idx);
104 c.set(
105 idx,
106 Point3::new(
107 p.x - scale * grad[i].x,
108 p.y - scale * grad[i].y,
109 p.z - scale * grad[i].z,
110 ),
111 );
112 }
113 }
114
115 c
116}
117
118fn total_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
123 bond_energy(mol, coords) + angle_energy(mol, coords) + vdw_energy(mol, coords)
124}
125
126fn ideal_bond_len_by_order(order: BondOrder) -> f64 {
132 match order {
133 BondOrder::Single | BondOrder::Up | BondOrder::Down => 1.54,
134 BondOrder::Double => 1.34,
135 BondOrder::Triple => 1.20,
136 BondOrder::Quadruple => 1.20,
137 BondOrder::Aromatic => 1.40,
138 }
139}
140
141fn bond_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
142 let mut energy = 0.0;
143 for (_, bond) in mol.bonds() {
144 let a1 = bond.atom1;
145 let a2 = bond.atom2;
146 let r = coords.get(a1).distance(&coords.get(a2));
147 let r0 = ideal_bond_len_by_order(bond.order);
148 let dr = r - r0;
149 energy += 0.5 * 700.0 * dr * dr;
150 }
151 energy
152}
153
154fn angle_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
159 let mut energy = 0.0;
160
161 for b_idx in 0..mol.atom_count() {
162 let b = AtomIdx(b_idx as u32);
163 let neighbors: Vec<AtomIdx> = mol.neighbors(b).map(|(nb, _)| nb).collect();
164 let deg = neighbors.len();
165
166 if deg < 2 {
167 continue;
168 }
169
170 let theta0 = match deg {
172 4 => 109.47_f64.to_radians(),
173 3 => 120.0_f64.to_radians(),
174 2 => PI, _ => 109.47_f64.to_radians(), };
177
178 let pb = coords.get(b);
179
180 for i in 0..neighbors.len() {
182 for j in (i + 1)..neighbors.len() {
183 let a = neighbors[i];
184 let c = neighbors[j];
185
186 let pa = coords.get(a);
187 let pc = coords.get(c);
188
189 let va = pa.sub(&pb);
190 let vc = pc.sub(&pb);
191
192 let na = va.norm();
193 let nc = vc.norm();
194
195 if na < 1e-10 || nc < 1e-10 {
196 continue;
197 }
198
199 let cos_theta = (va.dot(&vc) / (na * nc)).clamp(-1.0, 1.0);
200 let theta = cos_theta.acos();
201 let dtheta = theta - theta0;
202 energy += 0.5 * 100.0 * dtheta * dtheta;
203 }
204 }
205 }
206
207 energy
208}
209
210fn vdw_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
215 let n = mol.atom_count();
216 let cutoff = 5.0_f64;
217
218 let mut excluded: HashSet<(usize, usize)> = HashSet::new();
220
221 for (_, bond) in mol.bonds() {
223 let i = bond.atom1.0 as usize;
224 let j = bond.atom2.0 as usize;
225 let (lo, hi) = if i < j { (i, j) } else { (j, i) };
226 excluded.insert((lo, hi));
227 }
228
229 for b_idx in 0..n {
231 let b = AtomIdx(b_idx as u32);
232 let neighbors: Vec<usize> = mol.neighbors(b).map(|(nb, _)| nb.0 as usize).collect();
233 for ii in 0..neighbors.len() {
234 for jj in (ii + 1)..neighbors.len() {
235 let i = neighbors[ii];
236 let j = neighbors[jj];
237 let (lo, hi) = if i < j { (i, j) } else { (j, i) };
238 excluded.insert((lo, hi));
239 }
240 }
241 }
242
243 let mut energy = 0.0;
244 for i in 0..n {
245 for j in (i + 1)..n {
246 if excluded.contains(&(i, j)) {
247 continue;
248 }
249 let r = coords
250 .get(AtomIdx(i as u32))
251 .distance(&coords.get(AtomIdx(j as u32)));
252
253 if r < 0.01 {
254 continue;
255 }
256 if r >= cutoff {
257 continue;
258 }
259
260 let ratio = 2.0 / r;
261 let ratio6 = ratio * ratio * ratio * ratio * ratio * ratio;
262 let ratio12 = ratio6 * ratio6;
263 energy += 0.05 * ratio12;
264 }
265 }
266
267 energy
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dg::generate_coords;
    use chematic_smiles::parse;

    /// Smallest distance between any two of the first `n` atoms in `coords`.
    fn all_pairs_min_dist(coords: &Coords3D, n: usize) -> f64 {
        let mut min_d = f64::MAX;
        for i in 0..n {
            for j in (i + 1)..n {
                let d = coords
                    .get(AtomIdx(i as u32))
                    .distance(&coords.get(AtomIdx(j as u32)));
                min_d = min_d.min(d);
            }
        }
        min_d
    }

    /// Positions of the first `n` atoms, copied out of `coords`.
    fn positions(coords: &Coords3D, n: usize) -> Vec<Point3> {
        (0..n).map(|i| coords.get(AtomIdx(i as u32))).collect()
    }

    /// Asserts that every coordinate of every atom in `after` matches `before`.
    /// The comparisons are written so that a NaN coordinate fails the check.
    fn assert_unchanged(before: &[Point3], after: &Coords3D) {
        for (i, b) in before.iter().enumerate() {
            let a = after.get(AtomIdx(i as u32));
            assert!(
                (b.x - a.x).abs() < 1e-10
                    && (b.y - a.y).abs() < 1e-10
                    && (b.z - a.z).abs() < 1e-10,
                "atom {i} moved: ({:.6}, {:.6}, {:.6}) -> ({:.6}, {:.6}, {:.6})",
                b.x,
                b.y,
                b.z,
                a.x,
                a.y,
                a.z
            );
        }
    }

    #[test]
    fn test_single_atom_unchanged() {
        let mol = parse("O").unwrap();
        let coords = generate_coords(&mol);
        let before = positions(&coords, mol.atom_count());
        let result = minimize(&mol, coords);
        assert_unchanged(&before, &result);
    }

    #[test]
    fn test_zero_steps_unchanged() {
        let mol = parse("CC").unwrap();
        let coords = generate_coords(&mol);
        let config = MinimizeConfig {
            max_steps: 0,
            ..MinimizeConfig::default()
        };
        let before = positions(&coords, mol.atom_count());
        let result = minimize_with_config(&mol, coords, &config);
        assert_unchanged(&before, &result);
    }

    #[test]
    fn test_ethane_bond_after_minimize() {
        let mol = parse("CC").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
        assert!(d > 1.2 && d < 1.8, "C-C distance={d:.3}, expected 1.2-1.8 Å");
    }

    #[test]
    fn test_propane_no_clash() {
        let mol = parse("CCC").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "atom clash: min distance={min_d:.3}");
    }

    #[test]
    fn test_benzene_no_clash() {
        let mol = parse("c1ccccc1").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "atom clash in benzene: min distance={min_d:.3}");
    }

    #[test]
    fn test_disconnected_no_clash() {
        let mol = parse("CC.CC").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "atom clash in disconnected: min distance={min_d:.3}");
    }

    #[test]
    fn test_default_config_no_panic() {
        let mol = parse("CC(=O)O").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        assert_eq!(result.atom_count(), mol.atom_count());
    }

    #[test]
    fn test_acetic_acid_no_clash() {
        let mol = parse("CC(=O)O").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "clash in acetic acid: {min_d:.3}");
    }

    #[test]
    fn test_minimize_idempotent() {
        let mol = parse("CCC").unwrap();
        let coords = generate_coords(&mol);
        let result1 = minimize(&mol, coords);
        let e1 = total_energy(&mol, &result1);
        let result2 = minimize(&mol, result1);
        let e2 = total_energy(&mol, &result2);
        assert!(e2 <= e1 + 1.0, "energy increased: e1={e1:.4}, e2={e2:.4}");
    }

    #[test]
    fn test_naphthalene_no_overlap() {
        let mol = parse("c1ccc2ccccc2c1").unwrap();
        let coords = generate_coords(&mol);
        let result = minimize(&mol, coords);
        let min_d = all_pairs_min_dist(&result, mol.atom_count());
        assert!(min_d > 0.8, "overlap in naphthalene: {min_d:.3}");
    }
}