use nalgebra::DMatrix;
use petgraph::visit::EdgeRef;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::collections::{HashSet, VecDeque};
/// A single bond around which part of the molecule can be rotated.
///
/// Produced by [`find_rotatable_bonds`]; consumed by the torsion optimisers in
/// this module.
#[derive(Debug, Clone, PartialEq)]
pub struct RotatableBond {
    /// Atom indices `[a, b, c, d]` defining the torsion: `b`–`c` is the bond
    /// being rotated, `a` is a neighbour of `b` and `d` a neighbour of `c`.
    /// The atoms in `mobile_atoms` lie on the `c` side of the bond.
    pub dihedral: [usize; 4],
    /// Indices of the atoms that move when the bond is rotated. The two bond
    /// atoms themselves are never included.
    pub mobile_atoms: Vec<usize>,
    /// Candidate torsion angles in radians, chosen from the hybridization of
    /// the two bond atoms.
    pub preferred_angles: Vec<f32>,
}
/// Enumerates the rotatable bonds of `mol`.
///
/// An edge is reported when all of the following hold:
/// * its bond order is `Single`,
/// * both end atoms have at least two neighbours,
/// * `crate::graph::min_path_excluding2(mol, u, v, u, v, 7)` returns `None`
///   (NOTE(review): presumably this rejects bonds that are part of a small
///   ring — confirm against the helper's definition).
///
/// For each accepted bond `u`–`v` the atoms on the `v` side are collected by
/// breadth-first search. If that side holds more atoms than the rest of the
/// molecule (excluding `u` and `v`), the bond is reported reversed so that the
/// smaller `u` side becomes the mobile one.
///
/// Bonds whose end atoms have no neighbour other than each other (possible
/// only with parallel edges or self-loops) are skipped instead of panicking.
pub fn find_rotatable_bonds(mol: &crate::graph::Molecule) -> Vec<RotatableBond> {
    /// Breadth-first collection of every atom reachable from `root`'s
    /// neighbours without passing through `root` or `blocked`. Atoms are
    /// returned in visiting order, each at most once.
    fn collect_side(
        mol: &crate::graph::Molecule,
        root: petgraph::graph::NodeIndex,
        blocked: petgraph::graph::NodeIndex,
    ) -> Vec<usize> {
        let mut visited: HashSet<usize> = HashSet::new();
        visited.insert(root.index());
        visited.insert(blocked.index());

        let mut queue: VecDeque<petgraph::graph::NodeIndex> = VecDeque::new();
        for nb in mol.graph.neighbors(root) {
            // `insert` returns false for `blocked`, `root` and repeats.
            if visited.insert(nb.index()) {
                queue.push_back(nb);
            }
        }

        let mut side = Vec::new();
        while let Some(curr) = queue.pop_front() {
            side.push(curr.index());
            for nb in mol.graph.neighbors(curr) {
                if visited.insert(nb.index()) {
                    queue.push_back(nb);
                }
            }
        }
        side
    }

    let n = mol.graph.node_count();
    let mut bonds = Vec::new();

    for edge in mol.graph.edge_references() {
        let u = edge.source();
        let v = edge.target();

        if edge.weight().order != crate::graph::BondOrder::Single {
            continue;
        }
        // Terminal atoms cannot anchor a dihedral.
        if mol.graph.neighbors(u).count() < 2 || mol.graph.neighbors(v).count() < 2 {
            continue;
        }
        if crate::graph::min_path_excluding2(mol, u, v, u, v, 7).is_some() {
            continue;
        }

        // Reference atoms completing the dihedral a-u-v-d.
        let first_a = mol.graph.neighbors(u).find(|&x| x != v);
        let first_d = mol.graph.neighbors(v).find(|&x| x != u);
        let (a, d) = match (first_a, first_d) {
            (Some(a), Some(d)) => (a, d),
            _ => continue,
        };

        let v_side = collect_side(mol, v, u);
        // Everything that is neither on the v side nor one of the bond atoms.
        let other_count = n.saturating_sub(v_side.len() + 2);

        let bond = if v_side.len() > other_count {
            // The v side is the heavier one: rotate the u side instead and
            // report the dihedral reversed so the mobile side stays last.
            RotatableBond {
                dihedral: [d.index(), v.index(), u.index(), a.index()],
                mobile_atoms: collect_side(mol, u, v),
                preferred_angles: get_preferred_angles(mol, v, u),
            }
        } else {
            RotatableBond {
                dihedral: [a.index(), u.index(), v.index(), d.index()],
                mobile_atoms: v_side,
                preferred_angles: get_preferred_angles(mol, u, v),
            }
        };
        bonds.push(bond);
    }

    bonds
}
/// Returns the candidate torsion angles (radians) for the bond `u`–`v`,
/// selected from the hybridization of the two atoms:
///
/// * SP3–SP3: three angles, π/3, π and 5π/3,
/// * SP2–SP2: 0 and π,
/// * SP2–SP3 (either order): six angles in steps of π/3 starting at 0,
/// * anything else: twelve angles in steps of π/6 starting at 0.
fn get_preferred_angles(
    mol: &crate::graph::Molecule,
    u: petgraph::graph::NodeIndex,
    v: petgraph::graph::NodeIndex,
) -> Vec<f32> {
    use crate::graph::Hybridization::*;
    use std::f32::consts::PI;

    let ends = (mol.graph[u].hybridization, mol.graph[v].hybridization);
    match ends {
        (SP3, SP3) => [PI / 3.0, PI, 5.0 * PI / 3.0].to_vec(),
        (SP2, SP2) => [0.0, PI].to_vec(),
        (SP2, SP3) | (SP3, SP2) => [
            0.0,
            PI / 3.0,
            2.0 * PI / 3.0,
            PI,
            4.0 * PI / 3.0,
            5.0 * PI / 3.0,
        ]
        .to_vec(),
        _ => {
            let mut grid = Vec::with_capacity(12);
            for step in 0..12 {
                grid.push(step as f32 * PI / 6.0);
            }
            grid
        }
    }
}
/// Computes the signed dihedral angle (radians) defined by atoms `i`, `j`,
/// `k`, `l`, whose coordinates are rows of `coords` (columns 0..3 = x, y, z).
///
/// The result is `atan2(m1·n2, n1·n2)` where `n1` and `n2` are the unit
/// normals of the planes (i, j, k) and (j, k, l) and `m1 = n1 × b̂2`, so it
/// lies in `[-π, π]`.
///
/// When the torsion is undefined — `j` and `k` coincide, or three consecutive
/// atoms are collinear so that a plane normal vanishes — `0.0` is returned
/// instead of NaN. Callers feed this value straight into rotation deltas, and
/// a NaN there would poison every coordinate it touches.
///
/// # Panics
/// Panics if any index is out of range for `coords`.
pub fn compute_dihedral(coords: &DMatrix<f32>, i: usize, j: usize, k: usize, l: usize) -> f32 {
    // Vectors shorter than this are treated as zero.
    const DEGENERATE_EPS: f32 = 1e-6;

    let bond = |from: usize, to: usize| {
        nalgebra::Vector3::new(
            coords[(to, 0)] - coords[(from, 0)],
            coords[(to, 1)] - coords[(from, 1)],
            coords[(to, 2)] - coords[(from, 2)],
        )
    };
    let b1 = bond(i, j);
    let b2 = bond(j, k);
    let b3 = bond(k, l);

    let c1 = b1.cross(&b2);
    let c2 = b2.cross(&b3);
    let c1_len = c1.norm();
    let c2_len = c2.norm();
    let b2_len = b2.norm();

    // Written as `!(x > eps)` so that NaN lengths are rejected as well.
    if !(c1_len > DEGENERATE_EPS && c2_len > DEGENERATE_EPS && b2_len > DEGENERATE_EPS) {
        return 0.0;
    }

    let n1 = c1 / c1_len;
    let n2 = c2 / c2_len;
    let m1 = n1.cross(&(b2 / b2_len));

    let x = n1.dot(&n2);
    let y = m1.dot(&n2);
    y.atan2(x)
}
/// Rotates the atoms listed in `mobile` by `angle` radians about the axis that
/// runs from atom `j` to atom `k`, modifying `coords` in place.
///
/// Each atom is rotated with Rodrigues' formula relative to atom `j`:
/// `v' = v·cosθ + (k̂ × v)·sinθ + k̂·(k̂·v)·(1 − cosθ)`.
///
/// The call is a no-op when
/// * `angle` is smaller than `1e-8` in magnitude or is not finite, or
/// * the axis is shorter than `1e-8` or its length is not finite.
///
/// The finiteness checks matter: a NaN angle compares false against every
/// threshold and would otherwise overwrite all mobile atoms with NaN.
///
/// # Panics
/// Panics if `j`, `k` or any entry of `mobile` is out of range for `coords`.
pub fn rotate_atoms(coords: &mut DMatrix<f32>, mobile: &[usize], j: usize, k: usize, angle: f32) {
    if !angle.is_finite() || angle.abs() < 1e-8 {
        return;
    }

    let axis = nalgebra::Vector3::new(
        coords[(k, 0)] - coords[(j, 0)],
        coords[(k, 1)] - coords[(j, 1)],
        coords[(k, 2)] - coords[(j, 2)],
    );
    let axis_len = axis.norm();
    if !axis_len.is_finite() || axis_len < 1e-8 {
        return;
    }
    let axis = axis / axis_len;

    let cos_a = angle.cos();
    let sin_a = angle.sin();
    let one_minus_cos = 1.0 - cos_a;

    // Pivot: every mobile atom is expressed relative to atom j.
    let px = coords[(j, 0)];
    let py = coords[(j, 1)];
    let pz = coords[(j, 2)];

    for &idx in mobile {
        let vx = coords[(idx, 0)] - px;
        let vy = coords[(idx, 1)] - py;
        let vz = coords[(idx, 2)] - pz;

        // axis · v
        let dot = axis[0] * vx + axis[1] * vy + axis[2] * vz;
        // axis × v
        let cx = axis[1] * vz - axis[2] * vy;
        let cy = axis[2] * vx - axis[0] * vz;
        let cz = axis[0] * vy - axis[1] * vx;

        coords[(idx, 0)] = px + vx * cos_a + cx * sin_a + axis[0] * dot * one_minus_cos;
        coords[(idx, 1)] = py + vy * cos_a + cy * sin_a + axis[1] * dot * one_minus_cos;
        coords[(idx, 2)] = pz + vz * cos_a + cz * sin_a + axis[2] * dot * one_minus_cos;
    }
}
/// Moves every rotatable torsion of `mol` to its nearest preferred angle.
///
/// For each rotatable bond the current dihedral is measured, the preferred
/// angle with the smallest wrapped difference is selected (the first one wins
/// on ties), and the mobile atoms are rotated by that difference — unless it
/// is at most 0.05 rad, in which case the bond is left untouched.
///
/// Returns the number of rotatable bonds found.
pub fn snap_torsions_to_preferred(
    coords: &mut DMatrix<f32>,
    mol: &crate::graph::Molecule,
) -> usize {
    use std::f32::consts::PI;

    let bonds = find_rotatable_bonds(mol);

    for bond in bonds.iter() {
        let [i, j, k, l] = bond.dihedral;
        let phi = compute_dihedral(coords, i, j, k, l);

        // (smallest |delta| so far, signed delta that achieved it)
        let (_, turn) = bond.preferred_angles.iter().fold(
            (f32::MAX, 0.0f32),
            |(smallest, chosen), &target| {
                // Wrap the difference into [-π, π).
                let wrapped = (target - phi + PI).rem_euclid(2.0 * PI) - PI;
                if wrapped.abs() < smallest {
                    (wrapped.abs(), wrapped)
                } else {
                    (smallest, chosen)
                }
            },
        );

        if turn.abs() > 0.05 {
            rotate_atoms(coords, &bond.mobile_atoms, j, k, turn);
        }
    }

    bonds.len()
}
/// Greedy torsion optimisation against the full force-field energy.
///
/// For `passes` sweeps over all rotatable bonds, each bond is tried at every
/// one of its preferred angles; the angle giving the lowest
/// `super::energy::calculate_total_energy` is kept if it beats the energy of
/// the unrotated state, otherwise the bond is left as it was.
///
/// Trial rotations are undone by restoring a snapshot of the mobile atoms
/// rather than by applying the inverse rotation. This keeps rejected trials
/// free of accumulated rounding error and guarantees that a non-finite trial
/// cannot leak into `coords`.
///
/// Returns the number of rotatable bonds (0 when there are none).
pub fn optimize_torsions_greedy(
    coords: &mut DMatrix<f32>,
    mol: &crate::graph::Molecule,
    bounds: &DMatrix<f64>,
    passes: usize,
) -> usize {
    use std::f32::consts::PI;

    let rotatable = find_rotatable_bonds(mol);
    let num_rotatable = rotatable.len();
    if rotatable.is_empty() {
        return 0;
    }

    // Chirality and van der Waals terms are switched off for this search.
    let params = super::energy::FFParams {
        kb: 300.0,
        k_theta: 200.0,
        k_omega: 10.0,
        k_oop: 20.0,
        k_bounds: 100.0,
        k_chiral: 0.0,
        k_vdw: 0.0,
    };

    // Reused buffer holding the pre-trial positions of the mobile atoms.
    let mut saved: Vec<[f32; 3]> = Vec::new();

    for _pass in 0..passes {
        for rb in &rotatable {
            let [a, b, c, d] = rb.dihedral;
            let current_angle = compute_dihedral(coords, a, b, c, d);
            let mut best_energy =
                super::energy::calculate_total_energy(coords, mol, &params, bounds);
            let mut best_rotation = 0.0f32;

            saved.clear();
            saved.extend(
                rb.mobile_atoms
                    .iter()
                    .map(|&i| [coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]]),
            );

            for &target_angle in &rb.preferred_angles {
                // Wrap the required rotation into [-π, π).
                let delta = (target_angle - current_angle + PI).rem_euclid(2.0 * PI) - PI;

                rotate_atoms(coords, &rb.mobile_atoms, b, c, delta);
                let e = super::energy::calculate_total_energy(coords, mol, &params, bounds);
                if e < best_energy {
                    best_energy = e;
                    best_rotation = delta;
                }

                // Exact undo of the trial.
                for (&i, xyz) in rb.mobile_atoms.iter().zip(&saved) {
                    coords[(i, 0)] = xyz[0];
                    coords[(i, 1)] = xyz[1];
                    coords[(i, 2)] = xyz[2];
                }
            }

            if best_rotation.abs() > 1e-6 {
                rotate_atoms(coords, &rb.mobile_atoms, b, c, best_rotation);
            }
        }
    }

    num_rotatable
}
/// Greedy torsion optimisation against the distance-bounds violation energy.
///
/// For `passes` sweeps over all rotatable bonds, each bond is tried at every
/// one of its preferred angles; the angle giving the lowest
/// `super::bounds_ff::bounds_violation_energy` is kept if it beats the energy
/// of the unrotated state, otherwise the bond is left as it was.
///
/// Trial rotations are undone by restoring a snapshot of the mobile atoms
/// rather than by applying the inverse rotation. This keeps rejected trials
/// free of accumulated rounding error and guarantees that a non-finite trial
/// cannot leak into `coords`.
///
/// Returns the number of rotatable bonds (0 when there are none).
pub fn optimize_torsions_bounds(
    coords: &mut DMatrix<f32>,
    mol: &crate::graph::Molecule,
    bounds: &DMatrix<f64>,
    passes: usize,
) -> usize {
    use std::f32::consts::PI;

    let rotatable = find_rotatable_bonds(mol);
    let num_rotatable = rotatable.len();
    if rotatable.is_empty() {
        return 0;
    }

    // Reused buffer holding the pre-trial positions of the mobile atoms.
    let mut saved: Vec<[f32; 3]> = Vec::new();

    for _pass in 0..passes {
        for rb in &rotatable {
            let [a, b, c, d] = rb.dihedral;
            let current_angle = compute_dihedral(coords, a, b, c, d);
            let mut best_energy = super::bounds_ff::bounds_violation_energy(coords, bounds);
            let mut best_rotation = 0.0f32;

            saved.clear();
            saved.extend(
                rb.mobile_atoms
                    .iter()
                    .map(|&i| [coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]]),
            );

            for &target_angle in &rb.preferred_angles {
                // Wrap the required rotation into [-π, π).
                let delta = (target_angle - current_angle + PI).rem_euclid(2.0 * PI) - PI;

                rotate_atoms(coords, &rb.mobile_atoms, b, c, delta);
                let e = super::bounds_ff::bounds_violation_energy(coords, bounds);
                if e < best_energy {
                    best_energy = e;
                    best_rotation = delta;
                }

                // Exact undo of the trial.
                for (&i, xyz) in rb.mobile_atoms.iter().zip(&saved) {
                    coords[(i, 0)] = xyz[0];
                    coords[(i, 1)] = xyz[1];
                    coords[(i, 2)] = xyz[2];
                }
            }

            if best_rotation.abs() > 1e-6 {
                rotate_atoms(coords, &rb.mobile_atoms, b, c, best_rotation);
            }
        }
    }

    num_rotatable
}
/// Metropolis Monte Carlo search over torsion angles, scored by the
/// distance-bounds violation energy.
///
/// Each of the `n_steps` steps picks a random rotatable bond and a random
/// target angle in `[-π, π)`, rotates the bond's mobile atoms to that angle and
/// accepts the move if the energy does not increase, or otherwise with
/// probability `exp(-ΔE / T)`. `temperature` is clamped to at least `1e-6`.
/// The random stream is seeded from `seed`, so runs are reproducible.
///
/// Rejected moves are undone by restoring a snapshot of the mobile atoms
/// rather than by applying the inverse rotation, so `coords` always matches
/// the state `current_energy` was computed for and a non-finite trial cannot
/// leak into `coords`.
///
/// Returns the number of rotatable bonds.
pub fn optimize_torsions_monte_carlo_bounds(
    coords: &mut DMatrix<f32>,
    mol: &crate::graph::Molecule,
    bounds: &DMatrix<f64>,
    seed: u64,
    n_steps: usize,
    temperature: f32,
) -> usize {
    use std::f32::consts::PI;

    let rotatable = find_rotatable_bonds(mol);
    let num_rotatable = rotatable.len();
    if rotatable.is_empty() || n_steps == 0 {
        return num_rotatable;
    }

    let mut rng = StdRng::seed_from_u64(seed);
    let temp = temperature.max(1e-6);
    let two_pi = 2.0 * PI;
    let mut current_energy = super::bounds_ff::bounds_violation_energy(coords, bounds);

    // Reused buffer holding the pre-move positions of the mobile atoms.
    let mut saved: Vec<[f32; 3]> = Vec::new();

    for _ in 0..n_steps {
        let rb = &rotatable[rng.gen_range(0..rotatable.len())];
        let [a, b, c, d] = rb.dihedral;
        let current_angle = compute_dihedral(coords, a, b, c, d);

        let target = rng.gen_range(-PI..PI);
        // Wrap the required rotation into [-π, π).
        let delta = (target - current_angle + PI).rem_euclid(two_pi) - PI;

        saved.clear();
        saved.extend(
            rb.mobile_atoms
                .iter()
                .map(|&i| [coords[(i, 0)], coords[(i, 1)], coords[(i, 2)]]),
        );

        rotate_atoms(coords, &rb.mobile_atoms, b, c, delta);
        let trial_energy = super::bounds_ff::bounds_violation_energy(coords, bounds);
        let d_e = trial_energy - current_energy;

        // Metropolis criterion; a NaN `d_e` fails both tests and is rejected.
        let accept = if d_e <= 0.0 {
            true
        } else {
            let p_accept = (-d_e / temp).exp();
            rng.gen::<f32>() < p_accept
        };

        if accept {
            current_energy = trial_energy;
        } else {
            // Exact undo of the rejected move.
            for (&i, xyz) in rb.mobile_atoms.iter().zip(&saved) {
                coords[(i, 0)] = xyz[0];
                coords[(i, 1)] = xyz[1];
                coords[(i, 2)] = xyz[2];
            }
        }
    }

    num_rotatable
}
/// Exhaustive grid search over the first `max_rotors` rotatable bonds.
///
/// Every searched rotor is set to each of twelve angles (steps of π/6 starting
/// at 0), giving `12^n` combinations for `n` rotors. Each combination is built
/// from the input geometry, scored with the UFF force field from
/// `super::builder::build_uff_force_field`, and kept if its energy is finite.
/// With the `parallel` feature the combinations are evaluated with rayon.
///
/// `coords` is a flat `[x0, y0, z0, x1, …]` array with one triple per atom of
/// the molecule parsed from `smiles`.
///
/// Returns `(flat coordinates, energy)` pairs sorted by ascending energy. When
/// there is nothing to rotate, the input geometry is returned as the only
/// entry; its energy is taken from `crate::compute_uff_energy`, falling back
/// to `0.0` if that call fails.
///
/// # Errors
/// * the SMILES string cannot be parsed,
/// * `coords.len()` is not three times the atom count,
/// * `12^n` does not fit in `usize` (too many rotors requested).
pub fn systematic_rotor_search(
    smiles: &str,
    coords: &[f64],
    max_rotors: usize,
) -> Result<Vec<(Vec<f64>, f64)>, String> {
    use std::f32::consts::PI;

    // Number of grid points per rotor.
    const N_ANGLES: usize = 12;

    let mol = crate::graph::Molecule::from_smiles(smiles)?;
    let n_atoms = mol.graph.node_count();
    if coords.len() != n_atoms * 3 {
        return Err("coords length mismatch".to_string());
    }

    let rotatable = find_rotatable_bonds(&mol);
    let n_rot = rotatable.len().min(max_rotors);
    if n_rot == 0 {
        let e = crate::compute_uff_energy(smiles, coords).unwrap_or(0.0);
        return Ok(vec![(coords.to_vec(), e)]);
    }

    // `pow` would panic (debug) or silently wrap (release) on overflow and
    // make the search enumerate a meaningless number of combinations.
    let total: usize = N_ANGLES.checked_pow(n_rot as u32).ok_or_else(|| {
        format!(
            "systematic search over {} rotors is too large: {}^{} overflows usize",
            n_rot, N_ANGLES, n_rot
        )
    })?;

    let ff = super::builder::build_uff_force_field(&mol);
    let angles: Vec<f32> = (0..N_ANGLES).map(|i| i as f32 * PI / 6.0).collect();
    let base_matrix = flat_to_matrix_internal(coords, n_atoms);

    // `combo_idx` is read as a base-12 number: digit r selects the angle of
    // rotor r.
    let eval_combo = |combo_idx: usize| -> Option<(Vec<f64>, f64)> {
        let mut matrix = base_matrix.clone();
        let mut idx = combo_idx;
        for rb in rotatable.iter().take(n_rot) {
            let angle_idx = idx % N_ANGLES;
            idx /= N_ANGLES;

            let [a, b, c, d] = rb.dihedral;
            // Measured on the working copy, so earlier rotors are accounted for.
            let current = compute_dihedral(&matrix, a, b, c, d);
            let delta = (angles[angle_idx] - current + PI).rem_euclid(2.0 * PI) - PI;
            rotate_atoms(&mut matrix, &rb.mobile_atoms, b, c, delta);
        }

        let flat = matrix_to_flat(&matrix, n_atoms);
        let mut grad = vec![0.0f64; n_atoms * 3];
        let energy = ff.compute_system_energy_and_gradients(&flat, &mut grad);
        if energy.is_finite() {
            Some((flat, energy))
        } else {
            None
        }
    };

    #[cfg(feature = "parallel")]
    let mut results: Vec<(Vec<f64>, f64)> = {
        use rayon::prelude::*;
        (0..total).into_par_iter().filter_map(eval_combo).collect()
    };
    #[cfg(not(feature = "parallel"))]
    let mut results: Vec<(Vec<f64>, f64)> = (0..total).filter_map(eval_combo).collect();

    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
    Ok(results)
}
/// Simulated-annealing search over torsion angles, scored with the UFF force
/// field from `super::builder::build_uff_force_field`.
///
/// Each of the `n_steps` steps picks a random rotatable bond and a random
/// target angle in `[-π, π)`, rotates to it, and accepts the move with the
/// Metropolis criterion at the current temperature (moves with a non-finite
/// energy are always rejected). The temperature starts at `t_start` and is
/// multiplied by a constant factor after every step so that it reaches `t_end`
/// on the last one. Both temperatures are clamped to at least `1e-6`, which
/// keeps the schedule finite for zero, negative or NaN inputs.
///
/// Rejected moves are undone by restoring a snapshot of the mobile atoms
/// rather than by applying the inverse rotation, so the working geometry
/// always matches the state `current_energy` was computed for.
///
/// `coords` is a flat `[x0, y0, z0, x1, …]` array with one triple per atom of
/// the molecule parsed from `smiles`. The random stream is seeded from `seed`.
///
/// Returns `(flat coordinates, energy)` pairs sorted by ascending energy: one
/// snapshot of the current state every `max(n_steps / 50, 1)` steps plus the
/// best state seen. When the molecule has no rotatable bond, the input
/// geometry is returned as the only entry; its energy is taken from
/// `crate::compute_uff_energy`, falling back to `0.0` if that call fails.
///
/// # Errors
/// * the SMILES string cannot be parsed,
/// * `coords.len()` is not three times the atom count.
pub fn simulated_annealing_torsion_search(
    smiles: &str,
    coords: &[f64],
    n_steps: usize,
    t_start: f64,
    t_end: f64,
    seed: u64,
) -> Result<Vec<(Vec<f64>, f64)>, String> {
    use std::f32::consts::PI;

    // Same floor as the Monte Carlo bounds optimiser in this module.
    const MIN_TEMPERATURE: f64 = 1e-6;

    let mol = crate::graph::Molecule::from_smiles(smiles)?;
    let n_atoms = mol.graph.node_count();
    if coords.len() != n_atoms * 3 {
        return Err("coords length mismatch".to_string());
    }

    let rotatable = find_rotatable_bonds(&mol);
    if rotatable.is_empty() {
        let e = crate::compute_uff_energy(smiles, coords).unwrap_or(0.0);
        return Ok(vec![(coords.to_vec(), e)]);
    }

    let ff = super::builder::build_uff_force_field(&mol);
    let mut rng = StdRng::seed_from_u64(seed);

    let mut matrix = flat_to_matrix_internal(coords, n_atoms);
    let flat = matrix_to_flat(&matrix, n_atoms);
    let mut grad = vec![0.0f64; n_atoms * 3];
    let mut current_energy = ff.compute_system_energy_and_gradients(&flat, &mut grad);
    let mut best_coords = flat;
    let mut best_energy = current_energy;

    // `f64::max` returns the non-NaN operand, so NaN inputs are clamped too.
    let t_start = t_start.max(MIN_TEMPERATURE);
    let t_end = t_end.max(MIN_TEMPERATURE);
    // Geometric cooling: t_start * rate^(n_steps - 1) == t_end.
    let cooling_rate = if n_steps > 1 {
        (t_end / t_start).powf(1.0 / (n_steps - 1) as f64)
    } else {
        1.0
    };

    let mut collected = Vec::new();
    let collect_interval = (n_steps / 50).max(1);
    let mut temp = t_start;

    // Reused buffer holding the pre-move positions of the mobile atoms.
    let mut saved: Vec<[f32; 3]> = Vec::new();

    for step in 0..n_steps {
        let rb = &rotatable[rng.gen_range(0..rotatable.len())];
        let [a, b, c, d] = rb.dihedral;
        let current_angle = compute_dihedral(&matrix, a, b, c, d);

        let target: f32 = rng.gen_range(-PI..PI);
        // Wrap the required rotation into [-π, π).
        let delta = (target - current_angle + PI).rem_euclid(2.0 * PI) - PI;

        saved.clear();
        saved.extend(
            rb.mobile_atoms
                .iter()
                .map(|&i| [matrix[(i, 0)], matrix[(i, 1)], matrix[(i, 2)]]),
        );

        rotate_atoms(&mut matrix, &rb.mobile_atoms, b, c, delta);
        let trial_flat = matrix_to_flat(&matrix, n_atoms);
        let trial_energy = ff.compute_system_energy_and_gradients(&trial_flat, &mut grad);
        let d_e = trial_energy - current_energy;

        let accept = if d_e <= 0.0 {
            true
        } else {
            let p = (-d_e / temp).exp();
            rng.gen::<f64>() < p
        };

        if accept && trial_energy.is_finite() {
            current_energy = trial_energy;
            if current_energy < best_energy {
                best_energy = current_energy;
                best_coords = trial_flat;
            }
        } else {
            // Exact undo of the rejected move.
            for (&i, xyz) in rb.mobile_atoms.iter().zip(&saved) {
                matrix[(i, 0)] = xyz[0];
                matrix[(i, 1)] = xyz[1];
                matrix[(i, 2)] = xyz[2];
            }
        }

        if step % collect_interval == 0 {
            let snap = matrix_to_flat(&matrix, n_atoms);
            collected.push((snap, current_energy));
        }
        temp *= cooling_rate;
    }

    collected.push((best_coords, best_energy));
    collected.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
    Ok(collected)
}
/// Generates conformers by torsion sampling and reduces them to a diverse set.
///
/// Molecules with at most four rotatable bonds are sampled with
/// [`systematic_rotor_search`] (capped at four rotors); larger ones with
/// [`simulated_annealing_torsion_search`] (500 steps, temperature 5.0 → 0.1,
/// seeded with `seed`). If more than one conformer results, they are clustered
/// with `crate::clustering::butina_cluster` at `rmsd_cutoff` and only the
/// cluster centroids are returned, each as `(flat coordinates, energy)`.
///
/// # Errors
/// Propagates SMILES parsing failures and any error from the chosen search.
pub fn torsional_sampling_diverse(
    smiles: &str,
    coords: &[f64],
    rmsd_cutoff: f64,
    seed: u64,
) -> Result<Vec<(Vec<f64>, f64)>, String> {
    let molecule = crate::graph::Molecule::from_smiles(smiles)?;
    let rotor_count = find_rotatable_bonds(&molecule).len();

    let pool = match rotor_count {
        0..=4 => systematic_rotor_search(smiles, coords, 4)?,
        _ => simulated_annealing_torsion_search(smiles, coords, 500, 5.0, 0.1, seed)?,
    };

    // Nothing to cluster with fewer than two conformers.
    if pool.len() < 2 {
        return Ok(pool);
    }

    let mut geometries: Vec<Vec<f64>> = Vec::new();
    for entry in pool.iter() {
        geometries.push(entry.0.clone());
    }
    let clusters = crate::clustering::butina_cluster(&geometries, rmsd_cutoff);

    let mut picked: Vec<(Vec<f64>, f64)> = Vec::new();
    for &centroid in clusters.centroid_indices.iter() {
        picked.push(pool[centroid].clone());
    }
    Ok(picked)
}
/// Converts a flat `[x0, y0, z0, x1, …]` slice of `f64` into an
/// `n_atoms × 3` matrix of `f32` (one row per atom).
///
/// Panics if `coords` holds fewer than `3 * n_atoms` values.
fn flat_to_matrix_internal(coords: &[f64], n_atoms: usize) -> DMatrix<f32> {
    let mut out = DMatrix::<f32>::zeros(n_atoms, 3);
    for atom in 0..n_atoms {
        let base = 3 * atom;
        for axis in 0..3 {
            out[(atom, axis)] = coords[base + axis] as f32;
        }
    }
    out
}
/// Converts the first `n_atoms` rows of an `f32` coordinate matrix into a flat
/// `[x0, y0, z0, x1, …]` vector of `f64`.
///
/// Panics if `matrix` has fewer than `n_atoms` rows or fewer than 3 columns.
fn matrix_to_flat(matrix: &DMatrix<f32>, n_atoms: usize) -> Vec<f64> {
    let mut out: Vec<f64> = Vec::with_capacity(n_atoms * 3);
    for atom in 0..n_atoms {
        for axis in 0..3 {
            out.push(f64::from(matrix[(atom, axis)]));
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the Monte Carlo bounds optimiser finds at least one
    /// rotatable bond in butane and leaves every coordinate finite.
    #[test]
    fn test_monte_carlo_torsion_optimizer_runs_for_butane() {
        const BUTANE: &str = "CCCC";

        let molecule = crate::graph::Molecule::from_smiles(BUTANE).unwrap();
        let embedded = crate::embed(BUTANE, 42);
        assert!(embedded.error.is_none());

        let raw_bounds = crate::distgeom::calculate_bounds_matrix(&molecule);
        let bounds = crate::distgeom::smooth_bounds_matrix(raw_bounds);

        // The parent module's private converter is visible through `super::*`.
        let atom_count = molecule.graph.node_count();
        let mut positions = flat_to_matrix_internal(&embedded.coords, atom_count);

        let rotors =
            optimize_torsions_monte_carlo_bounds(&mut positions, &molecule, &bounds, 123, 64, 0.3);

        assert!(rotors >= 1);
        assert!(positions.iter().all(|value| value.is_finite()));
    }
}