1use std::collections::HashSet;
8
9use chematic_core::{AtomIdx, BondOrder, Molecule};
10use chematic_ff::{
11 assign_dreiding_types, assign_mmff94_types, dreiding_angle, dreiding_bond_len, dreiding_vdw,
12 mmff94_angle_params, mmff94_bond_params, mmff94_charges_3d, mmff94_vdw_params,
13};
14
15use crate::coords::{Coords3D, Point3};
16
17#[derive(Debug, Clone, Copy)]
23#[derive(Default)]
24pub enum ForceField {
25 UFF,
27 #[default]
29 DREIDING,
30 MMFF94,
32}
33
34
35pub struct MinimizeConfig {
37 pub max_steps: usize,
39 pub step_size: f64,
41 pub convergence: f64,
43 pub force_field: ForceField,
45}
46
47impl Default for MinimizeConfig {
48 fn default() -> Self {
49 Self {
50 max_steps: 200,
51 step_size: 0.05,
52 convergence: 1e-4,
53 force_field: ForceField::DREIDING,
54 }
55 }
56}
57
58pub fn minimize(mol: &Molecule, coords: Coords3D) -> Coords3D {
60 minimize_with_config(mol, coords, &MinimizeConfig::default())
61}
62
63pub fn minimize_uff(mol: &Molecule, coords: Coords3D) -> Coords3D {
67 minimize(mol, coords)
68}
69
70pub fn minimize_dreiding(mol: &Molecule, coords: Coords3D) -> Coords3D {
82 minimize_dreiding_with_config(mol, coords, &MinimizeConfig::default())
83}
84
85pub fn minimize_mmff94(mol: &Molecule, coords: Coords3D) -> Coords3D {
97 let config = MinimizeConfig {
98 force_field: ForceField::MMFF94,
99 ..MinimizeConfig::default()
100 };
101 minimize_with_config(mol, coords, &config)
102}
103
104fn minimize_mmff94_with_config(
106 mol: &Molecule,
107 coords: Coords3D,
108 config: &MinimizeConfig,
109) -> Coords3D {
110 if mol.atom_count() <= 1 {
111 return coords;
112 }
113
114 let mmff94_types = match assign_mmff94_types(mol) {
116 Ok(types) => types,
117 Err(_) => return coords, };
119
120 let mut c = coords;
121 let delta = 1e-4;
122
123 fn partial_mmff94(
124 mol: &Molecule,
125 c: &mut Coords3D,
126 idx: AtomIdx,
127 delta: f64,
128 axis: impl Fn(&mut Point3, f64),
129 mmff94_types: &[chematic_ff::MMFF94Type],
130 ) -> f64 {
131 let orig = c.get(idx);
132 let mut p = orig;
133 axis(&mut p, delta);
134 c.set(idx, p);
135 let ep = total_energy_mmff94(mol, c, mmff94_types);
136 let mut p = orig;
137 axis(&mut p, -delta);
138 c.set(idx, p);
139 let em = total_energy_mmff94(mol, c, mmff94_types);
140 c.set(idx, orig);
141 (ep - em) / (2.0 * delta)
142 }
143
144 for _ in 0..config.max_steps {
145 let mut grad = vec![Point3::zero(); mol.atom_count()];
146 let mut max_grad = 0.0f64;
147
148 for i in 0..mol.atom_count() {
149 let idx = AtomIdx(i as u32);
150 grad[i].x = partial_mmff94(mol, &mut c, idx, delta, |p, d| p.x += d, &mmff94_types);
151 grad[i].y = partial_mmff94(mol, &mut c, idx, delta, |p, d| p.y += d, &mmff94_types);
152 grad[i].z = partial_mmff94(mol, &mut c, idx, delta, |p, d| p.z += d, &mmff94_types);
153
154 let gmax = grad[i].x.abs().max(grad[i].y.abs()).max(grad[i].z.abs());
155 if gmax > max_grad {
156 max_grad = gmax;
157 }
158 }
159
160 if max_grad < config.convergence {
161 break;
162 }
163
164 let scale = config.step_size / max_grad.max(1e-8);
165 for i in 0..mol.atom_count() {
166 let idx = AtomIdx(i as u32);
167 let p = c.get(idx);
168 c.set(
169 idx,
170 Point3::new(
171 p.x - scale * grad[i].x,
172 p.y - scale * grad[i].y,
173 p.z - scale * grad[i].z,
174 ),
175 );
176 }
177 }
178
179 c
180}
181
182pub fn minimize_dreiding_with_config(
184 mol: &Molecule,
185 coords: Coords3D,
186 config: &MinimizeConfig,
187) -> Coords3D {
188 if mol.atom_count() <= 1 {
189 return coords;
190 }
191
192 let dreiding_types = assign_dreiding_types(mol);
194
195 let mut c = coords;
196 let delta = 1e-4;
197
198 fn partial_dreiding(
199 mol: &Molecule,
200 c: &mut Coords3D,
201 idx: AtomIdx,
202 delta: f64,
203 axis: impl Fn(&mut Point3, f64),
204 dreiding_types: &[chematic_ff::DREIDINGType],
205 ) -> f64 {
206 let orig = c.get(idx);
207 let mut p = orig;
208 axis(&mut p, delta);
209 c.set(idx, p);
210 let ep = total_energy_dreiding(mol, c, dreiding_types);
211 let mut p = orig;
212 axis(&mut p, -delta);
213 c.set(idx, p);
214 let em = total_energy_dreiding(mol, c, dreiding_types);
215 c.set(idx, orig);
216 (ep - em) / (2.0 * delta)
217 }
218
219 for _ in 0..config.max_steps {
220 let mut grad = vec![Point3::zero(); mol.atom_count()];
221 let mut max_grad = 0.0f64;
222
223 for i in 0..mol.atom_count() {
224 let idx = AtomIdx(i as u32);
225 grad[i].x = partial_dreiding(mol, &mut c, idx, delta, |p, d| p.x += d, &dreiding_types);
226 grad[i].y = partial_dreiding(mol, &mut c, idx, delta, |p, d| p.y += d, &dreiding_types);
227 grad[i].z = partial_dreiding(mol, &mut c, idx, delta, |p, d| p.z += d, &dreiding_types);
228
229 let gmax = grad[i].x.abs().max(grad[i].y.abs()).max(grad[i].z.abs());
230 if gmax > max_grad {
231 max_grad = gmax;
232 }
233 }
234
235 if max_grad < config.convergence {
236 break;
237 }
238
239 let scale = config.step_size / max_grad.max(1e-8);
240 for i in 0..mol.atom_count() {
241 let idx = AtomIdx(i as u32);
242 let p = c.get(idx);
243 c.set(
244 idx,
245 Point3::new(
246 p.x - scale * grad[i].x,
247 p.y - scale * grad[i].y,
248 p.z - scale * grad[i].z,
249 ),
250 );
251 }
252 }
253
254 c
255}
256
257fn total_energy_dreiding(
258 mol: &Molecule,
259 coords: &Coords3D,
260 dreiding_types: &[chematic_ff::DREIDINGType],
261) -> f64 {
262 bond_energy_dreiding(mol, coords, dreiding_types)
263 + angle_energy_dreiding(mol, coords, dreiding_types)
264 + vdw_energy_dreiding(mol, coords, dreiding_types)
265}
266
267fn bond_energy_dreiding(
268 mol: &Molecule,
269 coords: &Coords3D,
270 dreiding_types: &[chematic_ff::DREIDINGType],
271) -> f64 {
272 let mut energy = 0.0;
273 let k = 700.0; for (_, bond) in mol.bonds() {
275 let a1 = bond.atom1;
276 let a2 = bond.atom2;
277 let r = coords.get(a1).distance(&coords.get(a2));
278 let t1 = dreiding_types[a1.0 as usize];
279 let t2 = dreiding_types[a2.0 as usize];
280 let r0 = dreiding_bond_len(t1, t2, bond.order);
281 let dr = r - r0;
282 energy += 0.5 * k * dr * dr;
283 }
284 energy
285}
286
287fn angle_energy_dreiding(
288 mol: &Molecule,
289 coords: &Coords3D,
290 dreiding_types: &[chematic_ff::DREIDINGType],
291) -> f64 {
292 let mut energy = 0.0;
293 let k = 100.0; for b_idx in 0..mol.atom_count() {
296 let b = AtomIdx(b_idx as u32);
297 let neighbors: Vec<AtomIdx> = mol.neighbors(b).map(|(nb, _)| nb).collect();
298
299 if neighbors.len() < 2 {
300 continue;
301 }
302
303 let theta0 = dreiding_angle(dreiding_types[b_idx]);
304
305 for (i, &a) in neighbors.iter().enumerate() {
306 for &c in &neighbors[i + 1..] {
307 let pb = coords.get(b);
308
309 let pa = coords.get(a);
310 let pc = coords.get(c);
311
312 let va = pa.sub(&pb);
313 let vc = pc.sub(&pb);
314
315 let na = va.norm();
316 let nc = vc.norm();
317
318 if na < 1e-10 || nc < 1e-10 {
319 continue;
320 }
321
322 let cos_theta = (va.dot(&vc) / (na * nc)).clamp(-1.0, 1.0);
323 let theta = cos_theta.acos();
324 let dtheta = theta - theta0;
325 energy += 0.5 * k * dtheta * dtheta;
326 }
327 }
328 }
329
330 energy
331}
332
333fn vdw_energy_dreiding(
334 mol: &Molecule,
335 coords: &Coords3D,
336 dreiding_types: &[chematic_ff::DREIDINGType],
337) -> f64 {
338 let n = mol.atom_count();
339 let cutoff = 8.0_f64;
340
341 let mut excluded: HashSet<(usize, usize)> = HashSet::new();
342
343 for (_, bond) in mol.bonds() {
344 let i = bond.atom1.0 as usize;
345 let j = bond.atom2.0 as usize;
346 excluded.insert((i.min(j), i.max(j)));
347 }
348
349 for b_idx in 0..n {
350 let b = AtomIdx(b_idx as u32);
351 let neighbors: Vec<usize> = mol.neighbors(b).map(|(nb, _)| nb.0 as usize).collect();
352 for ii in 0..neighbors.len() {
353 for jj in (ii + 1)..neighbors.len() {
354 let i = neighbors[ii];
355 let j = neighbors[jj];
356 excluded.insert((i.min(j), i.max(j)));
357 }
358 }
359 }
360
361 let mut energy = 0.0;
362 for i in 0..n {
363 for j in (i + 1)..n {
364 if excluded.contains(&(i, j)) {
365 continue;
366 }
367 let r = coords
368 .get(AtomIdx(i as u32))
369 .distance(&coords.get(AtomIdx(j as u32)));
370
371 if r < 0.01 || r >= cutoff {
372 continue;
373 }
374
375 let t_i = dreiding_types[i];
376 let t_j = dreiding_types[j];
377 let (r0_i, well_i) = dreiding_vdw(t_i);
378 let (r0_j, well_j) = dreiding_vdw(t_j);
379
380 let r0 = (r0_i + r0_j) / 2.0;
382 let well = (well_i * well_j).sqrt();
383
384 let ratio = r0 / r;
385 let ratio6 = ratio * ratio * ratio * ratio * ratio * ratio;
386 let ratio12 = ratio6 * ratio6;
387 energy += well * (ratio12 - 2.0 * ratio6);
388 }
389 }
390
391 energy
392}
393
394pub fn minimize_with_config(mol: &Molecule, coords: Coords3D, config: &MinimizeConfig) -> Coords3D {
396 if mol.atom_count() <= 1 {
397 return coords;
398 }
399
400 match config.force_field {
402 ForceField::MMFF94 => minimize_mmff94_with_config(mol, coords, config),
403 _ => {
404 minimize_generic_with_config(mol, coords, config)
406 }
407 }
408}
409
410fn minimize_generic_with_config(mol: &Molecule, coords: Coords3D, config: &MinimizeConfig) -> Coords3D {
411 if mol.atom_count() <= 1 {
412 return coords;
413 }
414
415 let mut c = coords;
416 let delta = 1e-4;
417
418 fn partial(
419 mol: &Molecule,
420 c: &mut Coords3D,
421 idx: AtomIdx,
422 delta: f64,
423 axis: impl Fn(&mut Point3, f64),
424 ) -> f64 {
425 let orig = c.get(idx);
426 let mut p = orig;
427 axis(&mut p, delta);
428 c.set(idx, p);
429 let ep = total_energy(mol, c);
430 let mut p = orig;
431 axis(&mut p, -delta);
432 c.set(idx, p);
433 let em = total_energy(mol, c);
434 c.set(idx, orig);
435 (ep - em) / (2.0 * delta)
436 }
437
438 for _ in 0..config.max_steps {
439 let mut grad = vec![Point3::zero(); mol.atom_count()];
440 let mut max_grad = 0.0f64;
441
442 for i in 0..mol.atom_count() {
443 let idx = AtomIdx(i as u32);
444 grad[i].x = partial(mol, &mut c, idx, delta, |p, d| p.x += d);
445 grad[i].y = partial(mol, &mut c, idx, delta, |p, d| p.y += d);
446 grad[i].z = partial(mol, &mut c, idx, delta, |p, d| p.z += d);
447
448 let gmax = grad[i].x.abs().max(grad[i].y.abs()).max(grad[i].z.abs());
449 if gmax > max_grad {
450 max_grad = gmax;
451 }
452 }
453
454 if max_grad < config.convergence {
455 break;
456 }
457
458 let scale = config.step_size / max_grad.max(1e-8);
459 for i in 0..mol.atom_count() {
460 let idx = AtomIdx(i as u32);
461 let p = c.get(idx);
462 c.set(
463 idx,
464 Point3::new(
465 p.x - scale * grad[i].x,
466 p.y - scale * grad[i].y,
467 p.z - scale * grad[i].z,
468 ),
469 );
470 }
471 }
472
473 c
474}
475
476fn total_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
481 bond_energy(mol, coords) + angle_energy(mol, coords) + vdw_energy(mol, coords)
482}
483
484fn ideal_bond_len(sym1: &str, sym2: &str, order: BondOrder) -> f64 {
491 let (a, b) = if sym1 <= sym2 {
492 (sym1, sym2)
493 } else {
494 (sym2, sym1)
495 };
496 match (a, b, order) {
497 ("C", "C", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.540,
499 ("C", "C", BondOrder::Double) => 1.340,
500 ("C", "C", BondOrder::Triple) => 1.204,
501 ("C", "C", BondOrder::Aromatic) => 1.395,
502 ("C", "H", _) => 1.090,
504 ("C", "N", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.469,
506 ("C", "N", BondOrder::Double) => 1.279,
507 ("C", "N", BondOrder::Triple) => 1.158,
508 ("C", "N", BondOrder::Aromatic) => 1.340,
509 ("C", "O", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.427,
511 ("C", "O", BondOrder::Double) => 1.217,
512 ("C", "O", BondOrder::Aromatic) => 1.355,
513 ("C", "S", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.819,
515 ("C", "S", BondOrder::Double) => 1.610,
516 ("C", "S", BondOrder::Aromatic) => 1.750,
517 ("C", "F", _) => 1.350,
519 ("C", "Cl", _) => 1.770,
521 ("Br", "C", _) => 1.940,
523 ("C", "I", _) => 2.140,
525 ("C", "P", _) => 1.840,
527 ("C", "Si", _) => 1.870,
529 ("H", "H", _) => 0.741,
531 ("H", "N", _) => 1.010,
533 ("H", "O", _) => 0.960,
535 ("H", "S", _) => 1.340,
537 ("H", "P", _) => 1.420,
539 ("N", "N", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.450,
541 ("N", "N", BondOrder::Double) => 1.250,
542 ("N", "N", BondOrder::Triple) => 1.100,
543 ("N", "N", BondOrder::Aromatic) => 1.350,
544 ("N", "O", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.400,
546 ("N", "O", BondOrder::Double) => 1.210,
547 ("N", "O", BondOrder::Aromatic) => 1.340,
548 ("O", "O", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 1.480,
550 ("O", "O", BondOrder::Double) => 1.210,
551 ("S", "S", BondOrder::Single | BondOrder::Up | BondOrder::Down) => 2.050,
553 ("S", "S", BondOrder::Double) => 1.890,
554 ("P", "P", _) => 2.280,
556 _ => match order {
558 BondOrder::Single | BondOrder::Up | BondOrder::Down => 1.54,
559 BondOrder::Double => 1.34,
560 BondOrder::Triple => 1.20,
561 BondOrder::Quadruple => 1.20,
562 BondOrder::Aromatic => 1.40,
563 BondOrder::Zero
564 | BondOrder::Dative
565 | BondOrder::QueryAny
566 | BondOrder::QuerySingleOrDouble
567 | BondOrder::QuerySingleOrAromatic => 1.54,
568 BondOrder::QueryDoubleOrAromatic => 1.40,
569 },
570 }
571}
572
573#[derive(Clone, Copy, PartialEq, Debug)]
575enum Hybridization {
576 SP, SP2, SP3, }
580
581fn atom_hybridization(mol: &Molecule, idx: AtomIdx) -> Hybridization {
582 if mol.atom(idx).aromatic {
583 return Hybridization::SP2;
584 }
585 let mut has_triple = false;
586 let mut has_double_or_aromatic = false;
587 for (_, bond_idx) in mol.neighbors(idx) {
588 match mol.bond(bond_idx).order {
589 BondOrder::Triple => has_triple = true,
590 BondOrder::Double | BondOrder::Aromatic => has_double_or_aromatic = true,
591 _ => {}
592 }
593 }
594 if has_triple {
595 Hybridization::SP
596 } else if has_double_or_aromatic {
597 Hybridization::SP2
598 } else {
599 Hybridization::SP3
600 }
601}
602
603fn ideal_angle_rad(sym: &str, hyb: Hybridization) -> f64 {
605 match hyb {
606 Hybridization::SP => 180.0_f64.to_radians(),
607 Hybridization::SP2 => 120.0_f64.to_radians(),
608 Hybridization::SP3 => match sym {
609 "O" | "Se" => 104.5_f64.to_radians(),
610 "N" => 107.0_f64.to_radians(),
611 "S" => 99.0_f64.to_radians(),
612 "P" => 93.0_f64.to_radians(),
613 _ => 109.47_f64.to_radians(),
614 },
615 }
616}
617
618fn uff_vdw_radius(sym: &str) -> f64 {
620 match sym {
621 "H" => 1.20,
622 "C" => 1.70,
623 "N" => 1.55,
624 "O" => 1.52,
625 "F" => 1.47,
626 "Si" => 2.10,
627 "P" => 1.80,
628 "S" => 1.80,
629 "Cl" => 1.75,
630 "Br" => 1.85,
631 "I" => 1.98,
632 "Se" => 1.90,
633 "Te" => 2.06,
634 _ => 1.70,
635 }
636}
637
638fn bond_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
643 let mut energy = 0.0;
644 for (_, bond) in mol.bonds() {
645 let a1 = bond.atom1;
646 let a2 = bond.atom2;
647 let r = coords.get(a1).distance(&coords.get(a2));
648 let sym1 = mol.atom(a1).element.symbol();
649 let sym2 = mol.atom(a2).element.symbol();
650 let r0 = ideal_bond_len(sym1, sym2, bond.order);
651 let dr = r - r0;
652 energy += 0.5 * 700.0 * dr * dr;
653 }
654 energy
655}
656
657fn angle_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
662 let mut energy = 0.0;
663
664 for b_idx in 0..mol.atom_count() {
665 let b = AtomIdx(b_idx as u32);
666 let neighbors: Vec<AtomIdx> = mol.neighbors(b).map(|(nb, _)| nb).collect();
667
668 if neighbors.len() < 2 {
669 continue;
670 }
671
672 let sym_b = mol.atom(b).element.symbol();
673 let hyb = atom_hybridization(mol, b);
674 let theta0 = ideal_angle_rad(sym_b, hyb);
675 let pb = coords.get(b);
676
677 for i in 0..neighbors.len() {
678 for j in (i + 1)..neighbors.len() {
679 let a = neighbors[i];
680 let c = neighbors[j];
681
682 let pa = coords.get(a);
683 let pc = coords.get(c);
684
685 let va = pa.sub(&pb);
686 let vc = pc.sub(&pb);
687
688 let na = va.norm();
689 let nc = vc.norm();
690
691 if na < 1e-10 || nc < 1e-10 {
692 continue;
693 }
694
695 let cos_theta = (va.dot(&vc) / (na * nc)).clamp(-1.0, 1.0);
696 let theta = cos_theta.acos();
697 let dtheta = theta - theta0;
698 energy += 0.5 * 100.0 * dtheta * dtheta;
699 }
700 }
701 }
702
703 energy
704}
705
706fn vdw_energy(mol: &Molecule, coords: &Coords3D) -> f64 {
711 let n = mol.atom_count();
712 let cutoff = 8.0_f64;
713
714 let mut excluded: HashSet<(usize, usize)> = HashSet::new();
715
716 for (_, bond) in mol.bonds() {
717 let i = bond.atom1.0 as usize;
718 let j = bond.atom2.0 as usize;
719 excluded.insert((i.min(j), i.max(j)));
720 }
721
722 for b_idx in 0..n {
723 let b = AtomIdx(b_idx as u32);
724 let neighbors: Vec<usize> = mol.neighbors(b).map(|(nb, _)| nb.0 as usize).collect();
725 for ii in 0..neighbors.len() {
726 for jj in (ii + 1)..neighbors.len() {
727 let i = neighbors[ii];
728 let j = neighbors[jj];
729 excluded.insert((i.min(j), i.max(j)));
730 }
731 }
732 }
733
734 let mut energy = 0.0;
735 for i in 0..n {
736 for j in (i + 1)..n {
737 if excluded.contains(&(i, j)) {
738 continue;
739 }
740 let r = coords
741 .get(AtomIdx(i as u32))
742 .distance(&coords.get(AtomIdx(j as u32)));
743
744 if r < 0.01 || r >= cutoff {
745 continue;
746 }
747
748 let sym_i = mol.atom(AtomIdx(i as u32)).element.symbol();
749 let sym_j = mol.atom(AtomIdx(j as u32)).element.symbol();
750 let r0 = uff_vdw_radius(sym_i) + uff_vdw_radius(sym_j);
751
752 let ratio = r0 / r;
753 let ratio6 = ratio * ratio * ratio * ratio * ratio * ratio;
754 let ratio12 = ratio6 * ratio6;
755 energy += 0.05 * ratio12;
756 }
757 }
758
759 energy
760}
761
762fn total_energy_mmff94(
767 mol: &Molecule,
768 coords: &Coords3D,
769 mmff94_types: &[chematic_ff::MMFF94Type],
770) -> f64 {
771 let bond_e = bond_energy_mmff94(mol, coords, mmff94_types);
772 let angle_e = angle_energy_mmff94(mol, coords, mmff94_types);
773 let vdw_e = vdw_energy_mmff94(mol, coords, mmff94_types);
774
775 let elec_e = electrostatic_energy_mmff94(mol, coords, mmff94_types).unwrap_or(0.0);
777
778 bond_e + angle_e + vdw_e + elec_e
779}
780
781fn bond_energy_mmff94(
782 mol: &Molecule,
783 coords: &Coords3D,
784 mmff94_types: &[chematic_ff::MMFF94Type],
785) -> f64 {
786 let mut energy = 0.0;
787
788 for (_, bond) in mol.bonds() {
789 let a1 = bond.atom1;
790 let a2 = bond.atom2;
791 let r = coords.get(a1).distance(&coords.get(a2));
792 let t1 = mmff94_types[a1.0 as usize];
793 let t2 = mmff94_types[a2.0 as usize];
794
795 if let Some(params) = mmff94_bond_params(t1, t2, bond.order) {
796 let dr = r - params.r0;
797 energy += 0.5 * params.kb * dr * dr;
798 }
799 }
800
801 energy
802}
803
804fn angle_energy_mmff94(
805 mol: &Molecule,
806 coords: &Coords3D,
807 mmff94_types: &[chematic_ff::MMFF94Type],
808) -> f64 {
809 let mut energy = 0.0;
810
811 for b_idx in 0..mol.atom_count() {
812 let b = AtomIdx(b_idx as u32);
813 let neighbors: Vec<AtomIdx> = mol.neighbors(b).map(|(nb, _)| nb).collect();
814
815 if neighbors.len() < 2 {
816 continue;
817 }
818
819 for (i, &a) in neighbors.iter().enumerate() {
820 for &c in &neighbors[i + 1..] {
821 let t1 = mmff94_types[a.0 as usize];
822 let t2 = mmff94_types[b_idx];
823 let t3 = mmff94_types[c.0 as usize];
824
825 if let Some(params) = mmff94_angle_params(t1, t2, t3) {
826 let pb = coords.get(b);
827 let pa = coords.get(a);
828 let pc = coords.get(c);
829
830 let va = pa.sub(&pb);
831 let vc = pc.sub(&pb);
832
833 let na = va.norm();
834 let nc = vc.norm();
835
836 if na < 1e-10 || nc < 1e-10 {
837 continue;
838 }
839
840 let cos_theta = (va.dot(&vc) / (na * nc)).clamp(-1.0, 1.0);
841 let theta = cos_theta.acos();
842 let dtheta = theta - params.theta0;
843 energy += 0.5 * params.ka * dtheta * dtheta;
844 }
845 }
846 }
847 }
848
849 energy
850}
851
852fn vdw_energy_mmff94(
853 mol: &Molecule,
854 coords: &Coords3D,
855 mmff94_types: &[chematic_ff::MMFF94Type],
856) -> f64 {
857 let n = mol.atom_count();
858 let cutoff = 8.0_f64;
859 let mut excluded: HashSet<(usize, usize)> = HashSet::new();
860
861 for (_, bond) in mol.bonds() {
862 let i = bond.atom1.0 as usize;
863 let j = bond.atom2.0 as usize;
864 excluded.insert((i.min(j), i.max(j)));
865 }
866
867 for b_idx in 0..n {
869 let b = AtomIdx(b_idx as u32);
870 let neighbors: Vec<usize> = mol.neighbors(b).map(|(nb, _)| nb.0 as usize).collect();
871 for &neighbor in &neighbors {
872 excluded.insert((b_idx.min(neighbor), b_idx.max(neighbor)));
873 }
874 }
875
876 let mut energy = 0.0;
877
878 for i in 0..n {
879 for j in (i + 1)..n {
880 if excluded.contains(&(i, j)) {
881 continue;
882 }
883
884 let ri = coords.get(AtomIdx(i as u32));
885 let rj = coords.get(AtomIdx(j as u32));
886 let d = ri.distance(&rj);
887
888 if d > cutoff {
889 continue;
890 }
891
892 let params_i = mmff94_vdw_params(mmff94_types[i]);
893 let params_j = mmff94_vdw_params(mmff94_types[j]);
894
895 let r_ij = (params_i.r_star * params_j.r_star).sqrt();
897 let eps_ij = (params_i.epsilon * params_j.epsilon).sqrt();
898
899 if d > 0.0 {
901 let r6 = (r_ij / d).powi(6);
902 energy += eps_ij * (r6 * r6 - 2.0 * r6);
903 }
904 }
905 }
906
907 energy
908}
909
910fn electrostatic_energy_mmff94(
913 mol: &Molecule,
914 coords: &Coords3D,
915 _mmff94_types: &[chematic_ff::MMFF94Type],
916) -> Result<f64, String> {
917 let coord_tuples: Vec<(f64, f64, f64)> = (0..mol.atom_count())
919 .map(|i| {
920 let p = coords.get(AtomIdx(i as u32));
921 (p.x, p.y, p.z)
922 })
923 .collect();
924
925 let charges = mmff94_charges_3d(mol, &coord_tuples)
927 .map_err(|e| format!("charge calculation failed: {}", e))?;
928
929 let n = mol.atom_count();
930 let mut energy = 0.0;
931
932 let mut excluded: HashSet<(usize, usize)> = HashSet::new();
934
935 for (_, bond) in mol.bonds() {
937 let i = bond.atom1.0 as usize;
938 let j = bond.atom2.0 as usize;
939 excluded.insert((i.min(j), i.max(j)));
940 }
941
942 for b_idx in 0..n {
944 let b = AtomIdx(b_idx as u32);
945 let neighbors: Vec<usize> = mol.neighbors(b).map(|(nb, _)| nb.0 as usize).collect();
946 for &neighbor in &neighbors {
947 excluded.insert((b_idx.min(neighbor), b_idx.max(neighbor)));
948 }
949 }
950
951 let dielectric = 4.0; let coulomb_const = 332.0; for i in 0..n {
955 for j in (i + 1)..n {
956 if excluded.contains(&(i, j)) {
958 continue;
959 }
960
961 let ri = coords.get(AtomIdx(i as u32));
962 let rj = coords.get(AtomIdx(j as u32));
963 let d = ri.distance(&rj);
964
965 if d > 0.01 {
966 let coulomb = coulomb_const * charges[i] * charges[j] / (d * dielectric);
968 energy += coulomb;
969 }
970 }
971 }
972
973 Ok(energy)
974}
975
976#[cfg(test)]
981mod tests {
982 use super::*;
983 use crate::dg::generate_coords;
984 use chematic_smiles::parse;
985
986 fn all_pairs_min_dist(coords: &Coords3D, n: usize) -> f64 {
987 let mut min_d = f64::MAX;
988 for i in 0..n {
989 for j in (i + 1)..n {
990 let d = coords
991 .get(AtomIdx(i as u32))
992 .distance(&coords.get(AtomIdx(j as u32)));
993 min_d = min_d.min(d);
994 }
995 }
996 min_d
997 }
998
999 #[test]
1000 fn test_single_atom_unchanged() {
1001 let mol = parse("O").unwrap();
1002 let coords = generate_coords(&mol);
1003 let orig = coords.get(AtomIdx(0));
1004 let result = minimize(&mol, coords);
1005 let after = result.get(AtomIdx(0));
1006 assert!((orig.x - after.x).abs() < 1e-10);
1007 }
1008
1009 #[test]
1010 fn test_zero_steps_unchanged() {
1011 let mol = parse("CC").unwrap();
1012 let coords = generate_coords(&mol);
1013 let config = MinimizeConfig {
1014 max_steps: 0,
1015 ..MinimizeConfig::default()
1016 };
1017 let before0 = coords.get(AtomIdx(0));
1018 let result = minimize_with_config(&mol, coords, &config);
1019 let after0 = result.get(AtomIdx(0));
1020 assert!((before0.x - after0.x).abs() < 1e-10);
1021 }
1022
1023 #[test]
1024 fn test_ethane_bond_after_minimize() {
1025 let mol = parse("CC").unwrap();
1026 let coords = generate_coords(&mol);
1027 let result = minimize(&mol, coords);
1028 let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
1029 assert!(
1030 d > 1.2 && d < 1.8,
1031 "C-C distance={d:.3}, expected 1.2-1.8 Å"
1032 );
1033 }
1034
1035 #[test]
1036 fn test_ethane_converges_to_uff_length() {
1037 let mol = parse("CC").unwrap();
1038 let coords = generate_coords(&mol);
1039 let result = minimize(&mol, coords);
1040 let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
1041 assert!(
1043 (d - 1.540).abs() < 0.05,
1044 "C-C distance={d:.4}, expected ~1.540"
1045 );
1046 }
1047
1048 #[test]
1049 fn test_propane_no_clash() {
1050 let mol = parse("CCC").unwrap();
1051 let coords = generate_coords(&mol);
1052 let result = minimize(&mol, coords);
1053 let min_d = all_pairs_min_dist(&result, mol.atom_count());
1054 assert!(min_d > 0.8, "atom clash: min distance={min_d:.3}");
1055 }
1056
1057 #[test]
1058 fn test_benzene_no_clash() {
1059 let mol = parse("c1ccccc1").unwrap();
1060 let coords = generate_coords(&mol);
1061 let result = minimize(&mol, coords);
1062 let min_d = all_pairs_min_dist(&result, mol.atom_count());
1063 assert!(
1064 min_d > 0.8,
1065 "atom clash in benzene: min distance={min_d:.3}"
1066 );
1067 }
1068
1069 #[test]
1070 fn test_disconnected_no_clash() {
1071 let mol = parse("CC.CC").unwrap();
1072 let coords = generate_coords(&mol);
1073 let result = minimize(&mol, coords);
1074 let min_d = all_pairs_min_dist(&result, mol.atom_count());
1075 assert!(
1076 min_d > 0.8,
1077 "atom clash in disconnected: min distance={min_d:.3}"
1078 );
1079 }
1080
1081 #[test]
1082 fn test_default_config_no_panic() {
1083 let mol = parse("CC(=O)O").unwrap();
1084 let coords = generate_coords(&mol);
1085 let result = minimize(&mol, coords);
1086 assert_eq!(result.atom_count(), mol.atom_count());
1087 }
1088
1089 #[test]
1090 fn test_acetic_acid_no_clash() {
1091 let mol = parse("CC(=O)O").unwrap();
1092 let coords = generate_coords(&mol);
1093 let result = minimize(&mol, coords);
1094 let min_d = all_pairs_min_dist(&result, mol.atom_count());
1095 assert!(min_d > 0.8, "clash in acetic acid: {min_d:.3}");
1096 }
1097
1098 #[test]
1099 fn test_minimize_idempotent() {
1100 let mol = parse("CCC").unwrap();
1101 let coords = generate_coords(&mol);
1102 let result1 = minimize(&mol, coords);
1103 let e1 = total_energy(&mol, &result1);
1104 let result2 = minimize(&mol, result1);
1105 let e2 = total_energy(&mol, &result2);
1106 assert!(e2 <= e1 + 1.0, "energy increased: e1={e1:.4}, e2={e2:.4}");
1107 }
1108
1109 #[test]
1110 fn test_naphthalene_no_overlap() {
1111 let mol = parse("c1ccc2ccccc2c1").unwrap();
1112 let coords = generate_coords(&mol);
1113 let result = minimize(&mol, coords);
1114 let min_d = all_pairs_min_dist(&result, mol.atom_count());
1115 assert!(min_d > 0.8, "overlap in naphthalene: {min_d:.3}");
1116 }
1117
1118 #[test]
1119 fn test_co_bond_double_shorter_than_single() {
1120 let mol = parse("CC(=O)O").unwrap();
1122 let coords = generate_coords(&mol);
1123 let result = minimize(&mol, coords);
1124 assert_eq!(result.atom_count(), 4);
1127 let min_d = all_pairs_min_dist(&result, 4);
1128 assert!(min_d > 0.5, "clash in CO test: {min_d:.3}");
1129 }
1130
1131 #[test]
1132 fn test_heteroatom_c_n_bond() {
1133 let mol = parse("CN").unwrap(); let coords = generate_coords(&mol);
1135 let result = minimize(&mol, coords);
1136 let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
1137 assert!(
1139 (d - 1.469).abs() < 0.1,
1140 "C-N distance={d:.4}, expected ~1.469"
1141 );
1142 }
1143
1144 #[test]
1145 fn test_acetylene_sp_hybridization() {
1146 let mol = parse("C#C").unwrap(); let coords = generate_coords(&mol);
1148 let result = minimize(&mol, coords);
1149 let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
1150 assert!(
1152 (d - 1.204).abs() < 0.05,
1153 "C≡C distance={d:.4}, expected ~1.204"
1154 );
1155 }
1156
1157 #[test]
1158 fn test_ideal_bond_len_cc_single() {
1159 assert!((ideal_bond_len("C", "C", BondOrder::Single) - 1.540).abs() < 1e-6);
1160 assert!((ideal_bond_len("C", "C", BondOrder::Double) - 1.340).abs() < 1e-6);
1161 assert!((ideal_bond_len("C", "C", BondOrder::Triple) - 1.204).abs() < 1e-6);
1162 assert!((ideal_bond_len("C", "C", BondOrder::Aromatic) - 1.395).abs() < 1e-6);
1163 }
1164
1165 #[test]
1166 fn test_ideal_bond_len_symmetry() {
1167 let bo = BondOrder::Single;
1169 assert_eq!(ideal_bond_len("C", "N", bo), ideal_bond_len("N", "C", bo));
1170 assert_eq!(ideal_bond_len("C", "O", bo), ideal_bond_len("O", "C", bo));
1171 assert_eq!(ideal_bond_len("Br", "C", bo), ideal_bond_len("C", "Br", bo));
1172 }
1173
1174 #[test]
1175 fn test_atom_hybridization_sp2_aromatic() {
1176 let mol = parse("c1ccccc1").unwrap();
1177 for i in 0..6 {
1178 assert_eq!(
1179 atom_hybridization(&mol, AtomIdx(i)),
1180 Hybridization::SP2,
1181 "benzene atom {i} should be SP2"
1182 );
1183 }
1184 }
1185
1186 #[test]
1187 fn test_atom_hybridization_sp_triple() {
1188 let mol = parse("C#C").unwrap();
1189 assert_eq!(atom_hybridization(&mol, AtomIdx(0)), Hybridization::SP);
1190 assert_eq!(atom_hybridization(&mol, AtomIdx(1)), Hybridization::SP);
1191 }
1192
1193 #[test]
1194 fn test_atom_hybridization_sp3_alkane() {
1195 let mol = parse("CCC").unwrap();
1196 for i in 0..3 {
1197 assert_eq!(
1198 atom_hybridization(&mol, AtomIdx(i)),
1199 Hybridization::SP3,
1200 "propane atom {i} should be SP3"
1201 );
1202 }
1203 }
1204
1205 #[test]
1206 fn test_minimize_dreiding_ethane_no_clash() {
1207 let mol = parse("CC").unwrap();
1208 let coords = generate_coords(&mol);
1209 let min_coords = minimize_dreiding(&mol, coords);
1210 let n = mol.atom_count();
1211 for i in 0..n {
1212 for j in (i + 1)..n {
1213 let d = min_coords
1214 .get(AtomIdx(i as u32))
1215 .distance(&min_coords.get(AtomIdx(j as u32)));
1216 assert!(
1217 d > 0.5,
1218 "atoms {i} and {j} clashed after DREIDING minimization (d={d:.3})"
1219 );
1220 }
1221 }
1222 }
1223
1224 #[test]
1225 fn test_minimize_dreiding_benzene_no_clash() {
1226 let mol = parse("c1ccccc1").unwrap();
1227 let coords = generate_coords(&mol);
1228 let min_coords = minimize_dreiding(&mol, coords);
1229 let n = mol.atom_count();
1230 for i in 0..n {
1231 for j in (i + 1)..n {
1232 let d = min_coords
1233 .get(AtomIdx(i as u32))
1234 .distance(&min_coords.get(AtomIdx(j as u32)));
1235 assert!(
1236 d > 0.5,
1237 "atoms {i} and {j} clashed after DREIDING minimization (d={d:.3})"
1238 );
1239 }
1240 }
1241 }
1242
1243 #[test]
1244 fn test_minimize_mmff94_ethane() {
1245 let mol = parse("CC").unwrap();
1246 let c = generate_coords(&mol);
1247 let result = minimize_mmff94(&mol, c);
1248 assert_eq!(result.atom_count(), 2);
1249 let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
1250 assert!(d > 1.4 && d < 1.7, "C-C should be ~1.54 Å, got {:.3}", d);
1251 }
1252
1253 #[test]
1254 fn test_minimize_mmff94_benzene() {
1255 let mol = parse("c1ccccc1").unwrap();
1256 let c = generate_coords(&mol);
1257 let result = minimize_mmff94(&mol, c);
1258 assert_eq!(result.atom_count(), 6);
1259 let min_d = all_pairs_min_dist(&result, 6);
1260 assert!(min_d > 1.2, "benzene clash: {min_d:.3}");
1261 }
1262
1263 #[test]
1264 fn test_minimize_mmff94_aspirin() {
1265 let mol = parse("CC(=O)Oc1ccccc1C(=O)O").unwrap();
1266 let c = generate_coords(&mol);
1267 let result = minimize_mmff94(&mol, c);
1268 assert_eq!(result.atom_count(), mol.atom_count());
1270 for i in 0..mol.atom_count() {
1271 let p = result.get(chematic_core::AtomIdx(i as u32));
1272 assert!(p.x.is_finite() && p.y.is_finite() && p.z.is_finite(),
1273 "aspirin atom {i} has invalid coords");
1274 }
1275 }
1276
1277 #[test]
1280 fn test_electrostatic_energy_methanol() {
1281 let mol = parse("CO").unwrap();
1283 let c = generate_coords(&mol);
1284 let mmff94_types = assign_mmff94_types(&mol).unwrap();
1285
1286 let elec_e = electrostatic_energy_mmff94(&mol, &c, &mmff94_types);
1288 assert!(elec_e.is_ok());
1289 assert!(elec_e.unwrap().is_finite());
1290 }
1291
1292 #[test]
1293 fn test_electrostatic_energy_carboxylic_acid() {
1294 let mol = parse("CC(=O)O").unwrap();
1296 let c = generate_coords(&mol);
1297 let mmff94_types = assign_mmff94_types(&mol).unwrap();
1298
1299 let elec_e = electrostatic_energy_mmff94(&mol, &c, &mmff94_types);
1300 assert!(elec_e.is_ok());
1301 let energy = elec_e.unwrap();
1302 assert!(energy.is_finite());
1303 }
1305
1306 #[test]
1307 fn test_mmff94_with_electrostatic_ethane() {
1308 let mol = parse("CC").unwrap();
1310 let c = generate_coords(&mol);
1311 let result = minimize_mmff94(&mol, c);
1312
1313 assert_eq!(result.atom_count(), 2);
1315 let d = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
1316 assert!(d > 1.4 && d < 1.7, "ethane C-C should be ~1.54 Å with electrostatic, got {:.3}", d);
1317 }
1318
1319 #[test]
1320 fn test_mmff94_minimization_includes_charge_effects() {
1321 let mol = parse("CCO").unwrap();
1323 let c = generate_coords(&mol);
1324
1325 let result = minimize_mmff94(&mol, c);
1327
1328 assert_eq!(result.atom_count(), 3);
1330 for i in 0..3 {
1331 let p = result.get(AtomIdx(i as u32));
1332 assert!(p.x.is_finite() && p.y.is_finite() && p.z.is_finite(),
1333 "atom {i} has invalid coordinate after minimization");
1334 }
1335
1336 let c_c = result.get(AtomIdx(0)).distance(&result.get(AtomIdx(1)));
1338 let c_o = result.get(AtomIdx(1)).distance(&result.get(AtomIdx(2)));
1339 assert!(c_c > 1.0, "C-C bond too short: {c_c:.3}");
1340 assert!(c_o > 1.0, "C-O bond too short: {c_o:.3}");
1341 }
1342
1343 #[test]
1344 fn test_mmff94_charges_3d_integration() {
1345 let mol = parse("c1ccccc1O").unwrap(); let c = generate_coords(&mol);
1348
1349 let result = minimize_mmff94(&mol, c);
1351 assert_eq!(result.atom_count(), mol.atom_count());
1352
1353 for i in 0..mol.atom_count() {
1355 let p = result.get(AtomIdx(i as u32));
1356 assert!(p.x.is_finite() && p.y.is_finite() && p.z.is_finite());
1357 }
1358 }
1359
1360 #[test]
1361 fn test_total_energy_mmff94_includes_electrostatic() {
1362 let mol = parse("CCN").unwrap(); let c = generate_coords(&mol);
1365 let mmff94_types = assign_mmff94_types(&mol).unwrap();
1366
1367 let total_e = total_energy_mmff94(&mol, &c, &mmff94_types);
1368 let bond_e = bond_energy_mmff94(&mol, &c, &mmff94_types);
1369 let angle_e = angle_energy_mmff94(&mol, &c, &mmff94_types);
1370 let vdw_e = vdw_energy_mmff94(&mol, &c, &mmff94_types);
1371
1372 let electrostatic_e = electrostatic_energy_mmff94(&mol, &c, &mmff94_types).unwrap_or(0.0);
1374 let expected = bond_e + angle_e + vdw_e + electrostatic_e;
1375
1376 assert!((total_e - expected).abs() < 1e-6,
1377 "total energy mismatch: got {}, expected {}", total_e, expected);
1378 }
1379}