use crate::distgeom::{
calculate_bounds_matrix_opts, check_chiral_centers, check_double_bond_geometry,
check_tetrahedral_centers, compute_initial_coords_rdkit, identify_chiral_sets,
identify_tetrahedral_centers, pick_rdkit_distances, triangle_smooth_tol, MinstdRand,
MAX_MINIMIZED_E_PER_ATOM,
};
use crate::forcefield::bounds_ff::minimize_bfgs_rdkit;
use crate::forcefield::etkdg_3d::{build_etkdg_3d_ff_with_torsions, minimize_etkdg_3d_bfgs};
use crate::graph::Molecule;
use nalgebra::DMatrix;
/// Basin threshold for embedded starting coordinates: atom pairs whose `upper - lower` bound gap
/// exceeds this are left out of the bounds distance-violation energy.
const BASIN_THRESH: f32 = 5.0_f32;
/// Force tolerance handed to both BFGS minimizers (bounds force field and ETKDG 3D force field).
const FORCE_TOL: f32 = 1.0e-3_f32;
/// Planarity-check energy allowed per improper center before an attempt is rejected.
const PLANARITY_TOLERANCE: f32 = 1.0_f32;
/// Energies at or below this are treated as already minimized, so minimization is skipped.
const ERROR_TOL: f64 = 1.0e-5_f64;
/// A requested separation between two atoms.
///
/// `generate_3d_conformer_restrained` uses it to tighten the distance-geometry bounds matrix
/// before embedding.
#[derive(Debug, Clone, PartialEq)]
pub struct DistanceRestraint {
    /// Index of one atom of the pair; the order of `atom_i` and `atom_j` does not matter.
    pub atom_i: usize,
    /// Index of the other atom of the pair.
    pub atom_j: usize,
    /// Desired separation, in the same units as the bounds matrix.
    pub target_distance: f64,
    /// Stiffness of the restraint; larger values give a narrower allowed window
    /// (half-width `0.1 / sqrt(force_constant)`, never below 0.05).
    pub force_constant: f64,
}
/// Tightens `bounds` so that every pair named in `restraints` is confined to a window around its
/// target distance.
///
/// Layout of `bounds`:
/// - lower bounds sit below the diagonal, at `bounds[(larger, smaller)]`;
/// - upper bounds sit above it, at `bounds[(smaller, larger)]`.
///
/// This is the same layout `compute_total_bounds_energy_f64` reads.
///
/// The window half-width is `0.1 / sqrt(force_constant)`, never narrower than 0.05. Bounds are
/// only ever tightened:
/// - a lower bound is raised;
/// - an upper bound is lowered;
/// - an upper bound of exactly `0.0` is treated as unset and always replaced.
///
/// Restraints that cannot describe a finite, meaningful window are skipped instead of being
/// written into the matrix:
/// - self-pairs (`atom_i == atom_j`), which would otherwise overwrite the diagonal;
/// - negative or non-finite targets;
/// - force constants that make the window infinite (e.g. zero), which would otherwise store an
///   infinite upper bound.
///
/// # Panics
///
/// Panics if an atom index lies outside `bounds`.
fn apply_restraints_to_bounds(bounds: &mut DMatrix<f64>, restraints: &[DistanceRestraint]) {
    for r in restraints {
        if r.atom_i == r.atom_j || !(r.target_distance.is_finite() && r.target_distance >= 0.0) {
            continue;
        }
        let window = (0.1 / r.force_constant.sqrt()).max(0.05);
        if !window.is_finite() {
            // e.g. force_constant == 0: an unbounded window carries no information.
            continue;
        }
        let larger = r.atom_i.max(r.atom_j);
        let smaller = r.atom_i.min(r.atom_j);
        let new_lb = (r.target_distance - window).max(0.0);
        let new_ub = r.target_distance + window;
        if new_lb > bounds[(larger, smaller)] {
            bounds[(larger, smaller)] = new_lb;
        }
        let current_ub = bounds[(smaller, larger)];
        if new_ub < current_ub || current_ub == 0.0 {
            bounds[(smaller, larger)] = new_ub;
        }
    }
}
/// Places a molecule that has at most two atoms, without running the embedding loop.
///
/// Every coordinate starts at the origin. With exactly two atoms, both are put on the x axis,
/// symmetric about the origin. Their separation is:
/// - the midpoint of the `[lower, upper]` bounds when the lower bound is positive;
/// - the upper bound alone when only it is positive;
/// - `1.0` when neither bound is positive.
///
/// Negative (or NaN) bounds are clamped to zero, and the upper bound is never taken below the
/// lower one.
fn build_trivial_conformer(bounds: &DMatrix<f64>) -> DMatrix<f32> {
    let atom_count = bounds.nrows();
    let mut placed = DMatrix::from_element(atom_count, 3, 0.0f32);
    if atom_count != 2 {
        return placed;
    }

    let lower = bounds[(1, 0)].max(0.0);
    let upper = bounds[(0, 1)].max(lower);
    // `upper >= lower >= 0`, so a positive lower bound implies a positive upper bound.
    let separation_f64 = match (lower > 0.0, upper > 0.0) {
        (true, _) => 0.5 * (lower + upper),
        (false, true) => upper,
        (false, false) => 1.0,
    };
    let separation = separation_f64 as f32;

    let half = 0.5 * separation;
    placed[(0, 0)] = -half;
    placed[(1, 0)] = half;
    placed
}
/// Generates a single 3D conformer for `mol` while honouring extra pairwise distance restraints.
///
/// The restraints are folded into the distance-geometry bounds matrix before triangle smoothing,
/// through `apply_restraints_to_bounds`. Smoothing falls back in three steps:
/// 1. strict bounds;
/// 2. bounds from `calculate_bounds_matrix_opts(mol, false)`;
/// 3. a 0.05 smoothing tolerance.
///
/// Each attempt (at most `10 * n`):
/// - embeds starting coordinates. Normally these come from picked distances; after repeated
///   embedding failures they come from a random box instead.
/// - minimizes the bounds force field, in 4D when chiral sets exist.
/// - refines with the ETKDG 3D force field, using torsion terms from
///   `crate::smarts::match_experimental_torsions`.
/// - is rejected if the energy, tetrahedral, chirality, planarity or double-bond check fails.
///
/// Molecules with one or two atoms skip embedding and get a trivial placement.
///
/// # Errors
///
/// Returns `Err` in these cases:
/// - the molecule is empty;
/// - a restraint names an atom index that is not in the molecule;
/// - no attempt passes every check.
pub fn generate_3d_conformer_restrained(
    mol: &Molecule,
    seed: u64,
    restraints: &[DistanceRestraint],
) -> Result<DMatrix<f32>, String> {
    /// Copies the first three columns of `m` into a row-major `[x0, y0, z0, x1, ...]` vector.
    fn flatten_xyz(m: &DMatrix<f64>) -> Vec<f64> {
        (0..m.nrows())
            .flat_map(|a| [m[(a, 0)], m[(a, 1)], m[(a, 2)]])
            .collect()
    }

    let n = mol.graph.node_count();
    if n == 0 {
        return Err("Empty molecule".to_string());
    }
    // An out-of-range index would otherwise panic when it is written into the bounds matrix
    // (one row per atom); report it to the caller instead.
    if let Some(bad) = restraints.iter().find(|r| r.atom_i >= n || r.atom_j >= n) {
        return Err(format!(
            "Restraint references atoms ({}, {}) but the molecule has only {} atoms",
            bad.atom_i, bad.atom_j, n
        ));
    }
    let csd_torsions = crate::smarts::match_experimental_torsions(mol);

    let bounds = {
        let mut strict = calculate_bounds_matrix_opts(mol, true);
        apply_restraints_to_bounds(&mut strict, restraints);
        if triangle_smooth_tol(&mut strict, 0.0) {
            strict
        } else {
            // Restraints are deterministic, so apply them once to the relaxed matrix and keep a
            // pristine copy for the soft-tolerance fallback (smoothing mutates its input).
            let mut relaxed = calculate_bounds_matrix_opts(mol, false);
            apply_restraints_to_bounds(&mut relaxed, restraints);
            let mut attempt = relaxed.clone();
            if triangle_smooth_tol(&mut attempt, 0.0) {
                attempt
            } else {
                triangle_smooth_tol(&mut relaxed, 0.05);
                relaxed
            }
        }
    };
    if n <= 2 {
        return Ok(build_trivial_conformer(&bounds));
    }

    let chiral_sets = identify_chiral_sets(mol);
    let tetrahedral_centers = identify_tetrahedral_centers(mol);
    let use_4d = !chiral_sets.is_empty();
    let embed_dim = if use_4d { 4 } else { 3 };
    let max_iterations = 10 * n;
    let mut rng = MinstdRand::new(seed as u32);
    let mut consecutive_embed_fails = 0u32;
    let embed_fail_threshold = if n > 100 {
        (n as u32 / 8).max(10)
    } else {
        (n as u32 / 4).max(20)
    };
    let mut random_coord_attempts = 0u32;
    let max_random_coord_attempts = if n > 100 { 80u32 } else { 150u32 };
    let bfgs_restart_limit = if n > 100 { 20 } else { 50 };
    let mut energy_check_failures = 0u32;
    let energy_relax_threshold = (max_iterations as f64 * 0.3) as u32;

    // Re-runs the bounds-force-field BFGS until it reports convergence (returns 0) or the
    // restart budget is exhausted.
    let relax = |coords: &mut DMatrix<f64>, max_iters, bt, weight_4d, weight_chiral| {
        let mut need_more = 1;
        let mut restarts = 0;
        while need_more != 0 && restarts < bfgs_restart_limit {
            need_more = minimize_bfgs_rdkit(
                &mut *coords,
                &bounds,
                &chiral_sets,
                max_iters,
                FORCE_TOL as f64,
                bt,
                weight_4d,
                weight_chiral,
            );
            restarts += 1;
        }
    };

    for _ in 0..max_iterations {
        let use_random_coords = consecutive_embed_fails >= embed_fail_threshold
            && random_coord_attempts < max_random_coord_attempts;
        let (mut coords, basin_thresh) = if use_random_coords {
            random_coord_attempts += 1;
            let box_size = 2.0 * (n as f64).cbrt().max(2.5);
            let mut c = DMatrix::from_element(n, embed_dim, 0.0f64);
            for i in 0..n {
                for d in 0..embed_dim {
                    c[(i, d)] = box_size * (rng.next_double() - 0.5);
                }
            }
            // Random starts include every pair in the bounds energy.
            (c, 1e8f64)
        } else {
            let dists = pick_rdkit_distances(&mut rng, &bounds);
            match compute_initial_coords_rdkit(&mut rng, &dists, embed_dim) {
                Some(c) => {
                    consecutive_embed_fails = 0;
                    (c, BASIN_THRESH as f64)
                }
                None => {
                    consecutive_embed_fails += 1;
                    continue;
                }
            }
        };

        let bt = basin_thresh as f32;
        let initial_energy =
            compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
        if initial_energy > ERROR_TOL {
            relax(&mut coords, 400, bt, 0.1, 1.0);
        }

        let energy = compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
        // After many energy rejections, loosen the per-atom acceptance threshold.
        let effective_e_thresh = if energy_check_failures >= energy_relax_threshold {
            MAX_MINIMIZED_E_PER_ATOM as f64 * 2.5
        } else {
            MAX_MINIMIZED_E_PER_ATOM as f64
        };
        if energy / n as f64 >= effective_e_thresh {
            energy_check_failures += 1;
            continue;
        }
        if !check_tetrahedral_centers(&coords, &tetrahedral_centers) {
            continue;
        }
        if !chiral_sets.is_empty() && !check_chiral_centers(&coords, &chiral_sets) {
            continue;
        }

        if use_4d {
            // Second pass: weight the fourth-dimension term up and the chirality term down.
            let energy_4d =
                compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 1.0, 0.2);
            if energy_4d > ERROR_TOL {
                relax(&mut coords, 200, bt, 1.0, 0.2);
            }
        }

        let coords3d = coords.columns(0, 3).into_owned();
        let ff = build_etkdg_3d_ff_with_torsions(mol, &coords3d, &bounds, &csd_torsions);
        let e3d =
            crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(&flatten_xyz(&coords3d), n, mol, &ff);
        let refined = if e3d > ERROR_TOL {
            minimize_etkdg_3d_bfgs(mol, &coords3d, &ff, 300, FORCE_TOL)
        } else {
            coords3d
        };

        let n_improper_atoms = ff.inversion_contribs.len() / 3;
        let planarity_limit = n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64;
        let planarity_energy =
            crate::forcefield::etkdg_3d::planarity_check_energy_f64(&flatten_xyz(&refined), n, &ff);
        if planarity_energy > planarity_limit {
            continue;
        }
        if !check_double_bond_geometry(mol, &refined) {
            continue;
        }
        return Ok(refined.map(|v| v as f32));
    }
    Err(format!(
        "Failed to generate restrained conformer after {} iterations",
        max_iterations
    ))
}
/// Parses `smiles` and generates one 3D conformer for the resulting molecule.
///
/// # Errors
///
/// Propagates the SMILES parse error (converted to `String`) or the conformer-generation error.
pub fn generate_3d_conformer_from_smiles(smiles: &str, seed: u64) -> Result<DMatrix<f32>, String> {
    Molecule::from_smiles(smiles)
        .map_err(String::from)
        .and_then(|parsed| generate_3d_conformer(&parsed, seed))
}
/// Generates one 3D conformer for `mol`.
///
/// The experimental torsion terms come from `crate::smarts::match_experimental_torsions`.
pub fn generate_3d_conformer(mol: &Molecule, seed: u64) -> Result<DMatrix<f32>, String> {
    generate_3d_conformer_with_torsions(
        mol,
        seed,
        &crate::smarts::match_experimental_torsions(mol),
    )
}
/// Runs `generate_3d_conformer_with_torsions` for `num_seeds` derived seeds and returns the
/// conformer with the lowest ETKDG 3D force-field energy.
///
/// Seeds are `seed`, `seed + 1000`, `seed + 2000`, ... using wrapping arithmetic. When
/// `num_seeds <= 1` (including 0), this is exactly one call with `seed`.
///
/// Ties keep the earliest seed. A candidate whose energy evaluates to NaN is ranked last, so it
/// can never displace a conformer with a real energy.
///
/// # Errors
///
/// Returns the error from the last seed when every seed fails.
pub fn generate_3d_conformer_best_of_k(
    mol: &Molecule,
    seed: u64,
    csd_torsions: &[crate::forcefield::etkdg_3d::M6TorsionContrib],
    num_seeds: usize,
) -> Result<DMatrix<f32>, String> {
    let n = mol.graph.node_count();
    // An empty molecule fails identically for every seed; delegate once and skip building bounds.
    if num_seeds <= 1 || n == 0 {
        return generate_3d_conformer_with_torsions(mol, seed, csd_torsions);
    }

    // Bounds used only to build the scoring force field for each candidate.
    let bounds = {
        let mut b = calculate_bounds_matrix_opts(mol, true);
        if triangle_smooth_tol(&mut b, 0.0) {
            b
        } else {
            let raw2 = calculate_bounds_matrix_opts(mol, false);
            let mut b2 = raw2.clone();
            if triangle_smooth_tol(&mut b2, 0.0) {
                b2
            } else {
                let mut b3 = raw2;
                triangle_smooth_tol(&mut b3, 0.05);
                b3
            }
        }
    };

    let mut best: Option<(DMatrix<f32>, f64)> = None;
    let mut last_err = String::new();
    for k in 0..num_seeds {
        let s = seed.wrapping_add((k as u64).wrapping_mul(1000));
        match generate_3d_conformer_with_torsions(mol, s, csd_torsions) {
            Ok(coords) => {
                let coords_f64 = coords.map(|v| v as f64);
                let ff = build_etkdg_3d_ff_with_torsions(mol, &coords_f64, &bounds, csd_torsions);
                let flat: Vec<f64> = (0..n)
                    .flat_map(|a| [coords_f64[(a, 0)], coords_f64[(a, 1)], coords_f64[(a, 2)]])
                    .collect();
                let energy = crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(&flat, n, mol, &ff);
                // NaN compares false against everything; map it to +inf so it ranks last.
                let energy = if energy.is_nan() { f64::INFINITY } else { energy };
                let is_better = match &best {
                    None => true,
                    Some((_, best_e)) => energy < *best_e,
                };
                if is_better {
                    best = Some((coords, energy));
                }
            }
            Err(e) => last_err = e,
        }
    }
    best.map(|(coords, _)| coords).ok_or(last_err)
}
pub fn generate_3d_conformer_with_torsions(
mol: &Molecule,
seed: u64,
csd_torsions: &[crate::forcefield::etkdg_3d::M6TorsionContrib],
) -> Result<DMatrix<f32>, String> {
let n = mol.graph.node_count();
if n == 0 {
return Err("Empty molecule".to_string());
}
let bounds = {
let raw = calculate_bounds_matrix_opts(mol, true);
let mut b = raw;
if triangle_smooth_tol(&mut b, 0.0) {
b
} else {
#[cfg(test)]
eprintln!(" [FALLBACK] strict smoothing failed, retrying without set15");
let raw2 = calculate_bounds_matrix_opts(mol, false);
let mut b2 = raw2.clone();
if triangle_smooth_tol(&mut b2, 0.0) {
b2
} else {
#[cfg(test)]
eprintln!(" [FALLBACK] second smoothing also failed, using soft smooth");
let mut b3 = raw2;
triangle_smooth_tol(&mut b3, 0.05);
b3
}
}
};
if n <= 2 {
return Ok(build_trivial_conformer(&bounds));
}
let chiral_sets = identify_chiral_sets(mol);
let tetrahedral_centers = identify_tetrahedral_centers(mol);
let max_iterations = 10 * n;
let mut rng = MinstdRand::new(seed as u32);
let use_4d = !chiral_sets.is_empty();
let embed_dim = if use_4d { 4 } else { 3 };
let mut consecutive_embed_fails = 0u32;
let embed_fail_threshold = if n > 100 {
(n as u32 / 8).max(10)
} else {
(n as u32 / 4).max(20)
};
let mut random_coord_attempts = 0u32;
let max_random_coord_attempts = if n > 100 { 80u32 } else { 150u32 };
let bfgs_restart_limit = if n > 100 { 20 } else { 50 };
let mut energy_check_failures = 0u32;
let energy_relax_threshold = (max_iterations as f64 * 0.3) as u32;
for _iter in 0..max_iterations {
let _log_attempts = std::env::var("LOG_ATTEMPTS").is_ok();
let use_random_coords = consecutive_embed_fails >= embed_fail_threshold
&& random_coord_attempts < max_random_coord_attempts;
let (mut coords, basin_thresh) = if use_random_coords {
random_coord_attempts += 1;
let box_size = 2.0 * (n as f64).cbrt().max(2.5);
let mut c = DMatrix::from_element(n, embed_dim, 0.0f64);
for i in 0..n {
for d in 0..embed_dim {
c[(i, d)] = box_size * (rng.next_double() - 0.5);
}
}
(c, 1e8f64)
} else {
let dists = pick_rdkit_distances(&mut rng, &bounds);
let coords_opt = compute_initial_coords_rdkit(&mut rng, &dists, embed_dim);
match coords_opt {
Some(c) => {
consecutive_embed_fails = 0;
(c, BASIN_THRESH as f64)
}
None => {
consecutive_embed_fails += 1;
if random_coord_attempts >= max_random_coord_attempts {
break;
}
if _log_attempts && consecutive_embed_fails == embed_fail_threshold {
eprintln!(
" attempt {} → switching to random coords after {} failures",
_iter, embed_fail_threshold
);
} else if _log_attempts {
eprintln!(" attempt {} → embedding failed", _iter);
}
continue;
}
}
};
{
let bt = basin_thresh as f32;
let initial_energy =
compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
if initial_energy > ERROR_TOL {
let mut need_more = 1;
let mut restarts = 0;
while need_more != 0 && restarts < bfgs_restart_limit {
need_more = minimize_bfgs_rdkit(
&mut coords,
&bounds,
&chiral_sets,
400,
FORCE_TOL as f64,
bt,
0.1,
1.0,
);
restarts += 1;
}
}
}
let bt = basin_thresh as f32;
let energy = compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 0.1, 1.0);
let effective_e_thresh = if energy_check_failures >= energy_relax_threshold {
MAX_MINIMIZED_E_PER_ATOM as f64 * 2.5 } else {
MAX_MINIMIZED_E_PER_ATOM as f64
};
if energy / n as f64 >= effective_e_thresh {
energy_check_failures += 1;
if _log_attempts {
eprintln!(
" attempt {} → energy check failed: {:.6}/atom",
_iter,
energy / n as f64
);
}
continue;
}
if !check_tetrahedral_centers(&coords, &tetrahedral_centers) {
if _log_attempts {
eprintln!(" attempt {} → tetrahedral check failed", _iter);
}
continue;
}
if !chiral_sets.is_empty() && !check_chiral_centers(&coords, &chiral_sets) {
if _log_attempts {
eprintln!(" attempt {} → chiral check failed", _iter);
}
continue;
}
if use_4d {
let energy2 =
compute_total_bounds_energy_f64(&coords, &bounds, &chiral_sets, bt, 1.0, 0.2);
if energy2 > ERROR_TOL {
let mut need_more = 1;
let mut restarts = 0;
while need_more != 0 && restarts < bfgs_restart_limit {
need_more = minimize_bfgs_rdkit(
&mut coords,
&bounds,
&chiral_sets,
200,
FORCE_TOL as f64,
bt,
1.0,
0.2,
);
restarts += 1;
}
}
}
let coords3d = coords.columns(0, 3).into_owned();
let ff = build_etkdg_3d_ff_with_torsions(mol, &coords3d, &bounds, csd_torsions);
let e3d = crate::forcefield::etkdg_3d::etkdg_3d_energy_f64(
&{
let n = mol.graph.node_count();
let mut flat = vec![0.0f64; n * 3];
for a in 0..n {
flat[a * 3] = coords3d[(a, 0)];
flat[a * 3 + 1] = coords3d[(a, 1)];
flat[a * 3 + 2] = coords3d[(a, 2)];
}
flat
},
mol.graph.node_count(),
mol,
&ff,
);
let refined = if e3d > ERROR_TOL {
minimize_etkdg_3d_bfgs(mol, &coords3d, &ff, 300, FORCE_TOL)
} else {
coords3d
};
{
let n_improper_atoms = ff.inversion_contribs.len() / 3;
let flat_f64: Vec<f64> = {
let nr = refined.nrows();
let mut flat = vec![0.0f64; nr * 3];
for a in 0..nr {
flat[a * 3] = refined[(a, 0)];
flat[a * 3 + 1] = refined[(a, 1)];
flat[a * 3 + 2] = refined[(a, 2)];
}
flat
};
let planarity_energy =
crate::forcefield::etkdg_3d::planarity_check_energy_f64(&flat_f64, n, &ff);
if planarity_energy > n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64 {
if _log_attempts {
eprintln!(
" attempt {} → planarity check failed (energy={:.4} > threshold={:.4})",
_iter,
planarity_energy,
n_improper_atoms as f64 * PLANARITY_TOLERANCE as f64
);
}
continue;
}
}
if !check_double_bond_geometry(mol, &refined) {
if _log_attempts {
eprintln!(" attempt {} → double bond check failed", _iter);
}
continue;
}
if _log_attempts {
eprintln!(" attempt {} → SUCCESS", _iter);
}
let refined_f32 = refined.map(|v| v as f32);
return Ok(refined_f32);
}
Err(format!(
"Failed to generate valid conformer after {} iterations",
max_iterations
))
}
/// Total energy of the bounds force field for `coords`.
///
/// The energy sums three kinds of term:
/// - **Distance violations.** Every pair `i > j` whose bound gap `upper - lower` is at most
///   `basin_thresh` contributes a term. With `d2` the squared distance (over all coordinate
///   columns), the term is `(d2 / ub² - 1)²` when `d2 > ub²` and `(2·lb² / (lb² + d2) - 1)²`
///   when `d2 < lb²`. Upper bounds are read from `bounds[(j, i)]` and lower bounds from
///   `bounds[(i, j)]`.
/// - **Chirality.** When `chiral_sets` is non-empty, `weight_chiral` times the chiral violation
///   energy is added.
/// - **Fourth dimension.** When `coords` has four columns, `weight_4d · w²` is added for every
///   atom's fourth coordinate `w`.
pub fn compute_total_bounds_energy_f64(
    coords: &DMatrix<f64>,
    bounds: &DMatrix<f64>,
    chiral_sets: &[crate::forcefield::bounds_ff::ChiralSet],
    basin_thresh: f32,
    weight_4d: f32,
    weight_chiral: f32,
) -> f64 {
    let num_atoms = coords.nrows();
    let dim = coords.ncols();
    let basin = f64::from(basin_thresh);
    let mut total = 0.0f64;

    for i in 1..num_atoms {
        for j in 0..i {
            let lower = bounds[(i, j)];
            let upper = bounds[(j, i)];
            if upper - lower > basin {
                // Loosely bounded pairs are outside the basin and ignored.
                continue;
            }
            let dist_sq = (0..dim).fold(0.0f64, |acc, d| {
                let delta = coords[(i, d)] - coords[(j, d)];
                acc + delta * delta
            });
            let upper_sq = upper * upper;
            let lower_sq = lower * lower;
            let violation = if dist_sq > upper_sq {
                dist_sq / upper_sq - 1.0
            } else if dist_sq < lower_sq {
                2.0 * lower_sq / (lower_sq + dist_sq) - 1.0
            } else {
                0.0
            };
            if violation > 0.0 {
                total += violation * violation;
            }
        }
    }

    if !chiral_sets.is_empty() {
        // Row-major flattening: atom i occupies flat[i * dim .. (i + 1) * dim].
        let flat: Vec<f64> = (0..num_atoms)
            .flat_map(|i| (0..dim).map(move |d| coords[(i, d)]))
            .collect();
        total += f64::from(weight_chiral)
            * crate::forcefield::bounds_ff::chiral_violation_energy_f64(&flat, dim, chiral_sets);
    }

    if dim == 4 {
        let w4 = f64::from(weight_4d);
        for i in 0..num_atoms {
            let fourth = coords[(i, 3)];
            total += w4 * fourth * fourth;
        }
    }
    total
}
#[cfg(test)]
mod tests {
    /// "F" and "Cl" must come back as two-atom results (heavy atom plus hydrogen, per the
    /// assertion below) with two distinct positions more than 0.5 apart.
    #[test]
    fn embed_handles_hydrogen_halides() {
        for smiles in ["F", "Cl"] {
            let result = crate::embed(smiles, 42);
            assert!(
                result.error.is_none(),
                "{smiles} embed failed: {:?}",
                result.error
            );
            assert_eq!(result.num_atoms, 2, "{smiles} should expand to a diatomic");
            assert_eq!(
                result.coords.len(),
                6,
                "{smiles} should return 2 x 3 coordinates"
            );
            let dx = result.coords[3] - result.coords[0];
            let dy = result.coords[4] - result.coords[1];
            let dz = result.coords[5] - result.coords[2];
            let distance = (dx * dx + dy * dy + dz * dz).sqrt();
            assert!(
                distance > 0.5,
                "{smiles} interatomic distance should exceed 0.5, got {distance}"
            );
        }
    }
}