use crate::forcefield::bounds_ff::ChiralSet;
use crate::graph::Molecule;
use nalgebra::{DMatrix, Vector3};
/// Minimum absolute triple-product volume required by `volume_test` for every
/// triple of unit vectors around a tetrahedral center (scaled by 0.25 when the
/// test runs in relaxed mode).
const MIN_TETRAHEDRAL_CHIRAL_VOL: f64 = 0.50;
/// Tolerance that `check_tetrahedral_centers` hands to `center_in_volume`;
/// plane offsets whose magnitude falls below it make the same-side test fail.
const TETRAHEDRAL_CENTERINVOLUME_TOL: f64 = 0.30;
/// NOTE(review): not referenced in the visible part of this file; presumably an
/// upper bound on the minimized energy per atom used by callers — confirm.
pub const MAX_MINIMIZED_E_PER_ATOM: f32 = 0.05;
/// A four-coordinate atom whose embedded geometry is validated by
/// `check_tetrahedral_centers`.
pub struct TetrahedralCenter {
    /// Index of the central atom (used as a row index into the coordinate matrix).
    pub center: usize,
    /// Indices of the four atoms bonded to `center`.
    pub neighbors: [usize; 4],
    /// Set by `identify_tetrahedral_centers` when the atom belongs to more than
    /// one ring with fewer than five members; selects the relaxed volume threshold.
    pub in_small_ring: bool,
}
/// Collects the atoms whose tetrahedral geometry should be checked after embedding.
///
/// An atom qualifies when its `element` is 6 or 7, it has exactly four
/// neighbors, it is a member of at least two rings returned by `find_sssr`,
/// and none of those rings has three members. The returned centers are in
/// ascending atom-index order.
pub fn identify_tetrahedral_centers(mol: &Molecule) -> Vec<TetrahedralCenter> {
    let atom_total = mol.graph.node_count();
    let sssr = find_sssr(mol);

    // Per-atom ring membership statistics.
    let mut rings_per_atom = vec![0usize; atom_total];
    let mut small_rings_per_atom = vec![0usize; atom_total];
    let mut touches_three_ring = vec![false; atom_total];
    for ring in &sssr {
        let size = ring.len();
        for &member in ring {
            rings_per_atom[member] += 1;
            touches_three_ring[member] |= size == 3;
            if size < 5 {
                small_rings_per_atom[member] += 1;
            }
        }
    }

    (0..atom_total)
        .filter_map(|idx| {
            let node = petgraph::graph::NodeIndex::new(idx);
            let element = mol.graph[node].element;
            if !(element == 6 || element == 7) {
                return None;
            }

            let adjacent: Vec<_> = mol.graph.neighbors(node).collect();
            let neighbors = match adjacent.as_slice() {
                [a, b, c, d] => [a.index(), b.index(), c.index(), d.index()],
                _ => return None,
            };

            // Only fused/bridged ring atoms are checked, and three-membered
            // rings are excluded outright.
            if rings_per_atom[idx] < 2 || touches_three_ring[idx] {
                return None;
            }

            Some(TetrahedralCenter {
                center: idx,
                neighbors,
                in_small_ring: small_rings_per_atom[idx] > 1,
            })
        })
        .collect()
}
/// Public entry point that forwards to the private `find_sssr` and returns its
/// ring list unchanged (each ring is a list of atom indices).
pub fn find_sssr_pub(mol: &Molecule) -> Vec<Vec<usize>> {
    find_sssr(mol)
}
/// Enumerates the small rings of `mol` as lists of atom indices.
///
/// The procedure is:
/// 1. Count connected components; if the cycle rank
///    (`edges + components - nodes`) is zero there are no rings.
/// 2. From every atom as BFS root, close a candidate ring over each bond
///    `(u, v)` by joining the BFS tree paths root→u and root→v. Candidates
///    with more than 8 atoms, or that revisit an atom, are discarded.
/// 3. Candidates are normalized (see `normalize_ring`), de-duplicated and
///    ordered by size.
/// 4. A candidate is dropped when its edge set equals the symmetric
///    difference of the edge sets of two strictly smaller candidates.
///
/// Returns the surviving rings, shortest first.
fn find_sssr(mol: &Molecule) -> Vec<Vec<usize>> {
    use std::collections::{HashSet, VecDeque};

    let n = mol.graph.node_count();
    if n == 0 {
        return vec![];
    }

    // Flood-fill to count connected components.
    let num_edges = mol.graph.edge_count();
    let mut visited = vec![false; n];
    let mut num_components = 0;
    for start in 0..n {
        if visited[start] {
            continue;
        }
        num_components += 1;
        let mut queue = VecDeque::new();
        queue.push_back(start);
        visited[start] = true;
        while let Some(curr) = queue.pop_front() {
            for nb in mol.graph.neighbors(petgraph::graph::NodeIndex::new(curr)) {
                if !visited[nb.index()] {
                    visited[nb.index()] = true;
                    queue.push_back(nb.index());
                }
            }
        }
    }

    // An acyclic graph has nothing to report.
    let cycle_rank = (num_edges + num_components).saturating_sub(n);
    if cycle_rank == 0 {
        return vec![];
    }

    let mut candidates: Vec<Vec<usize>> = Vec::new();
    for root in 0..n {
        // BFS tree rooted at `root`; `usize::MAX` marks "unreached"/"no parent".
        let mut dist = vec![usize::MAX; n];
        let mut parent = vec![usize::MAX; n];
        dist[root] = 0;
        let mut queue = VecDeque::new();
        queue.push_back(root);
        while let Some(curr) = queue.pop_front() {
            for nb in mol.graph.neighbors(petgraph::graph::NodeIndex::new(curr)) {
                let nb_idx = nb.index();
                if dist[nb_idx] == usize::MAX {
                    dist[nb_idx] = dist[curr] + 1;
                    parent[nb_idx] = curr;
                    queue.push_back(nb_idx);
                }
            }
        }

        for u in 0..n {
            for nb in mol.graph.neighbors(petgraph::graph::NodeIndex::new(u)) {
                let v = nb.index();
                // Visit each undirected bond once.
                if u >= v {
                    continue;
                }
                // Reachability must be tested BEFORE summing the distances:
                // bonds in another component have both distances at
                // `usize::MAX`, and adding them overflows (a panic in
                // overflow-checked builds).
                if dist[u] == usize::MAX || dist[v] == usize::MAX {
                    continue;
                }
                let ring_len = dist[u] + dist[v] + 1;
                if ring_len > 8 {
                    continue;
                }

                // Ring = root→u followed by v→root, without repeating the root.
                let mut ring = trace_path(&parent, root, u);
                let mut path_v_rev: Vec<usize> =
                    trace_path(&parent, root, v).into_iter().rev().collect();
                if !path_v_rev.is_empty()
                    && !ring.is_empty()
                    && path_v_rev.last() == ring.first()
                {
                    path_v_rev.pop();
                }
                ring.extend(path_v_rev);

                // Paths that share atoms beyond the root do not form a simple cycle.
                let mut seen = HashSet::new();
                let is_simple = ring.iter().all(|&x| seen.insert(x));
                if is_simple && ring.len() >= 3 {
                    candidates.push(normalize_ring(&ring));
                }
            }
        }
    }

    // Canonical order, unique, then shortest first (stable sort keeps the
    // lexicographic order within one size).
    candidates.sort();
    candidates.dedup();
    candidates.sort_by_key(|r| r.len());

    let edge_sets: Vec<HashSet<(usize, usize)>> =
        candidates.iter().map(|r| ring_edges(r).collect()).collect();

    let mut relevant = Vec::new();
    for (i, ring) in candidates.iter().enumerate() {
        // Drop rings that are exactly the XOR of two strictly smaller candidates.
        let mut is_xor_of_smaller = false;
        for j in 0..i {
            if candidates[j].len() >= ring.len() {
                continue;
            }
            for k in (j + 1)..i {
                if candidates[k].len() >= ring.len() {
                    continue;
                }
                let sym_diff: HashSet<(usize, usize)> = edge_sets[j]
                    .symmetric_difference(&edge_sets[k])
                    .copied()
                    .collect();
                if sym_diff == edge_sets[i] {
                    is_xor_of_smaller = true;
                    break;
                }
            }
            if is_xor_of_smaller {
                break;
            }
        }
        if !is_xor_of_smaller {
            relevant.push(ring.clone());
        }
    }
    relevant
}
/// Reconstructs the path from `root` to `target` by following `parent` links.
///
/// `parent[x]` is the predecessor of `x`, with `usize::MAX` meaning "none".
/// The result is ordered root-first. If the walk runs out of parents before
/// reaching `root`, the partial chain (without `root`) is returned.
fn trace_path(parent: &[usize], root: usize, target: usize) -> Vec<usize> {
    let mut chain = Vec::new();
    let mut node = target;
    loop {
        if node == root {
            chain.push(root);
            break;
        }
        if node == usize::MAX {
            // Dead end: `target` is not connected to `root` in this tree.
            break;
        }
        chain.push(node);
        node = parent[node];
    }
    chain.reverse();
    chain
}
/// Returns a canonical rotation/direction of `ring`.
///
/// The ring is rotated so that the first occurrence of its smallest index
/// comes first, then read in whichever direction (forward or backward) gives
/// the lexicographically smaller sequence. An empty ring maps to an empty vector.
fn normalize_ring(ring: &[usize]) -> Vec<usize> {
    let smallest = match ring.iter().min() {
        Some(&value) => value,
        None => return Vec::new(),
    };
    // First position holding the smallest index.
    let start = ring
        .iter()
        .position(|&value| value == smallest)
        .unwrap();

    let clockwise: Vec<usize> = ring[start..]
        .iter()
        .chain(ring[..start].iter())
        .copied()
        .collect();
    let counter: Vec<usize> = ring[..=start]
        .iter()
        .rev()
        .chain(ring[start + 1..].iter().rev())
        .copied()
        .collect();

    if clockwise <= counter {
        clockwise
    } else {
        counter
    }
}
/// Yields the bonds of `ring` as `(low, high)` index pairs, pairing every atom
/// with its successor and wrapping from the last atom back to the first.
fn ring_edges(ring: &[usize]) -> impl Iterator<Item = (usize, usize)> + '_ {
    let successors = ring.iter().cycle().skip(1);
    ring.iter().zip(successors).map(|(&from, &to)| {
        if from <= to {
            (from, to)
        } else {
            (to, from)
        }
    })
}
/// Checks that the four neighbors of `center` span enough volume around it.
///
/// Unit vectors from each neighbor to the center are formed (vectors shorter
/// than 1e-8 are left unnormalized), and for every triple of them the absolute
/// scalar triple product must not be below the threshold
/// `MIN_TETRAHEDRAL_CHIRAL_VOL`, or a quarter of it when `relaxed` is set.
/// A missing z column in `coords` is treated as z = 0.
fn volume_test(
    center: usize,
    neighbors: &[usize; 4],
    coords: &DMatrix<f64>,
    relaxed: bool,
) -> bool {
    const TRIPLES: [(usize, usize, usize); 4] = [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)];

    let has_z = coords.ncols().min(3) >= 3;
    let point = |row: usize| {
        Vector3::new(
            coords[(row, 0)],
            coords[(row, 1)],
            if has_z { coords[(row, 2)] } else { 0.0 },
        )
    };

    let origin = point(center);
    let mut unit = [Vector3::<f64>::zeros(); 4];
    for (slot, &nb) in unit.iter_mut().zip(neighbors.iter()) {
        let delta = origin - point(nb);
        let length = delta.norm();
        *slot = if length > 1e-8 { delta / length } else { delta };
    }

    let min_volume = MIN_TETRAHEDRAL_CHIRAL_VOL * if relaxed { 0.25 } else { 1.0 };
    TRIPLES.iter().all(|&(a, b, c)| {
        let volume = unit[a].cross(&unit[b]).dot(&unit[c]).abs();
        // Written as a negated `<` so a NaN volume is not treated as a failure.
        !(volume < min_volume)
    })
}
/// Tests whether `v4` and `p0` lie strictly on the same side of the plane
/// through `v1`, `v2`, `v3`.
///
/// The signed offsets are taken against the unnormalized plane normal
/// `(v2 - v1) × (v3 - v1)`; if either offset has magnitude below `tol` the
/// answer is `false`.
fn same_side(
    v1: &Vector3<f64>,
    v2: &Vector3<f64>,
    v3: &Vector3<f64>,
    v4: &Vector3<f64>,
    p0: &Vector3<f64>,
    tol: f64,
) -> bool {
    let plane_normal = (v2 - v1).cross(&(v3 - v1));
    let apex_offset = plane_normal.dot(&(v4 - v1));
    let probe_offset = plane_normal.dot(&(p0 - v1));

    let too_close = apex_offset.abs() < tol || probe_offset.abs() < tol;
    if too_close {
        false
    } else {
        !((apex_offset < 0.0) ^ (probe_offset < 0.0))
    }
}
/// Checks that `center` lies inside the tetrahedron spanned by its four
/// `neighbors`.
///
/// For each face (three consecutive neighbors, cyclically) the center must be
/// on the same side as the remaining neighbor according to `same_side` with
/// tolerance `tol`. A missing z column in `coords` is treated as z = 0.
fn center_in_volume(
    center: usize,
    neighbors: &[usize; 4],
    coords: &DMatrix<f64>,
    tol: f64,
) -> bool {
    let has_z = coords.ncols().min(3) >= 3;
    let fetch = |row: usize| -> Vector3<f64> {
        let z = if has_z { coords[(row, 2)] } else { 0.0 };
        Vector3::new(coords[(row, 0)], coords[(row, 1)], z)
    };

    let hub = fetch(center);
    let corners = [
        fetch(neighbors[0]),
        fetch(neighbors[1]),
        fetch(neighbors[2]),
        fetch(neighbors[3]),
    ];

    // Face k is (k, k+1, k+2); the opposite vertex is k+3 (indices mod 4).
    (0..4).all(|k| {
        same_side(
            &corners[k],
            &corners[(k + 1) % 4],
            &corners[(k + 2) % 4],
            &corners[(k + 3) % 4],
            &hub,
            tol,
        )
    })
}
/// Returns `true` when every center passes both the volume test (relaxed for
/// centers flagged `in_small_ring`) and the center-inside-tetrahedron test
/// with tolerance `TETRAHEDRAL_CENTERINVOLUME_TOL`. An empty list passes.
pub fn check_tetrahedral_centers(coords: &DMatrix<f64>, centers: &[TetrahedralCenter]) -> bool {
    centers.iter().all(|site| {
        volume_test(site.center, &site.neighbors, coords, site.in_small_ring)
            && center_in_volume(
                site.center,
                &site.neighbors,
                coords,
                TETRAHEDRAL_CENTERINVOLUME_TOL,
            )
    })
}
/// Verifies that the chiral volume of every set is compatible with its bounds.
///
/// A set fails when its volume is below a positive lower bound, or above a
/// negative upper bound, and in addition either reaches less than 80% of that
/// bound or has the opposite sign. Returns `true` if no set fails.
pub fn check_chiral_centers(coords: &DMatrix<f64>, chiral_sets: &[ChiralSet]) -> bool {
    chiral_sets.iter().all(|set| {
        let volume = crate::distgeom::calc_chiral_volume_f64(
            set.neighbors[0],
            set.neighbors[1],
            set.neighbors[2],
            set.neighbors[3],
            coords,
        );
        let lower = set.lower_vol as f64;
        let upper = set.upper_vol as f64;

        let misses_lower = lower > 0.0
            && volume < lower
            && (volume / lower < 0.8 || have_opposite_sign(volume, lower));
        let misses_upper = upper < 0.0
            && volume > upper
            && (volume / upper < 0.8 || have_opposite_sign(volume, upper));

        !(misses_lower || misses_upper)
    })
}
/// Returns `true` when exactly one of `a` and `b` is strictly negative.
fn have_opposite_sign(a: f64, b: f64) -> bool {
    let a_negative = a < 0.0;
    let b_negative = b < 0.0;
    a_negative ^ b_negative
}
/// Checks that SP2 atoms with exactly three neighbors are close to planar.
///
/// For each such atom the scalar triple product of the three bond vectors is
/// squared, scaled by `oop_k` and accumulated. The check passes when there are
/// no such atoms, or when the accumulated value does not exceed
/// `tolerance` times the number of atoms considered.
pub fn check_planarity(mol: &Molecule, coords: &DMatrix<f32>, oop_k: f32, tolerance: f32) -> bool {
    let position =
        |row: usize| Vector3::new(coords[(row, 0)], coords[(row, 1)], coords[(row, 2)]);

    let mut site_count = 0usize;
    let mut total_energy = 0.0f32;
    for idx in 0..mol.graph.node_count() {
        let node = petgraph::graph::NodeIndex::new(idx);
        if mol.graph[node].hybridization != crate::graph::Hybridization::SP2 {
            continue;
        }
        let adjacent: Vec<_> = mol.graph.neighbors(node).collect();
        if let [first, second, third] = adjacent.as_slice() {
            site_count += 1;
            let hub = position(idx);
            let arm_a = position(first.index()) - hub;
            let arm_b = position(second.index()) - hub;
            let arm_c = position(third.index()) - hub;
            // Zero for a perfectly planar arrangement.
            let vol = arm_a.dot(&arm_b.cross(&arm_c));
            total_energy += oop_k * vol * vol;
        }
    }

    site_count == 0 || total_energy <= site_count as f32 * tolerance
}
/// Rejects conformations in which a substituent on a double bond is collinear
/// with that bond.
///
/// For each double bond, both end atoms are examined in turn (source first).
/// Every other neighbor of the end atom is passed to `check_linearity`
/// together with the bond; when the end atom has exactly two neighbors, a
/// neighbor attached through a non-single bond is skipped. Returns `false` on
/// the first failing triple, `true` otherwise.
pub fn check_double_bond_geometry(mol: &Molecule, coords: &DMatrix<f64>) -> bool {
    use petgraph::visit::EdgeRef;

    for bond in mol.graph.edge_references() {
        if mol.graph[bond.id()].order != crate::graph::BondOrder::Double {
            continue;
        }

        // Same logic for both orientations of the bond.
        let orientations = [
            (bond.source(), bond.target()),
            (bond.target(), bond.source()),
        ];
        for (hub, partner) in orientations {
            let degree = mol.graph.neighbors(hub).count();
            if degree < 2 {
                continue;
            }
            for other in mol.graph.neighbors(hub) {
                if other == partner {
                    continue;
                }
                let exempt = degree == 2
                    && mol
                        .graph
                        .find_edge(hub, other)
                        .map_or(false, |eid| {
                            mol.graph[eid].order != crate::graph::BondOrder::Single
                        });
                if exempt {
                    continue;
                }
                if !check_linearity(other.index(), hub.index(), partner.index(), coords) {
                    return false;
                }
            }
        }
    }
    true
}
/// Returns `false` when atoms `a0`–`a1`–`a2` are (nearly) collinear with `a1`
/// in the middle, i.e. when the unit vectors from `a0` and from `a2` towards
/// `a1` have a dot product within 1e-3 of -1. Degenerate (shorter than 1e-8)
/// arms count as passing.
fn check_linearity(a0: usize, a1: usize, a2: usize, coords: &DMatrix<f64>) -> bool {
    let at = |row: usize| Vector3::new(coords[(row, 0)], coords[(row, 1)], coords[(row, 2)]);
    let start = at(a0);
    let middle = at(a1);
    let end = at(a2);

    let first_arm = middle - start;
    let first_len = first_arm.norm();
    if first_len < 1e-8 {
        return true;
    }
    let first_dir = first_arm / first_len;

    let second_arm = middle - end;
    let second_len = second_arm.norm();
    if second_len < 1e-8 {
        return true;
    }
    let second_dir = second_arm / second_len;

    first_dir.dot(&second_dir) + 1.0 >= 1e-3
}
/// Adds random z noise when the coordinates are essentially flat in z.
///
/// Nothing happens (and `false` is returned) when there are fewer than 4 rows
/// or fewer than 3 columns, when the larger of the x and y extents is below
/// 1e-8, or when the z extent is at least 1% of that x/y extent. Otherwise
/// each z value is shifted by `0.3 * (rng.next_double() - 0.5)`, row by row,
/// and `true` is returned.
pub fn perturb_if_planar(coords: &mut DMatrix<f64>, rng: &mut crate::distgeom::MinstdRand) -> bool {
    let rows = coords.nrows();
    if rows < 4 || coords.ncols() < 3 {
        return false;
    }

    // Extent (max - min) of one coordinate column.
    let extent = |col: usize| -> f64 {
        let (lo, hi) = (0..rows).map(|row| coords[(row, col)]).fold(
            (f64::INFINITY, f64::NEG_INFINITY),
            |(lo, hi), value| {
                (
                    if value < lo { value } else { lo },
                    if value > hi { value } else { hi },
                )
            },
        );
        hi - lo
    };

    let thickness = extent(2);
    let footprint = f64::max(0.0, extent(0)).max(extent(1));
    if footprint < 1e-8 {
        return false;
    }

    let is_flat = thickness < 0.01 * footprint;
    if is_flat {
        for row in 0..rows {
            coords[(row, 2)] += 0.3 * (rng.next_double() - 0.5);
        }
    }
    is_flat
}