#![allow(
unused_imports,
unused_variables,
dead_code,
clippy::unnecessary_cast,
clippy::needless_range_loop,
clippy::manual_repeat_n,
clippy::manual_str_repeat,
clippy::manual_is_multiple_of,
clippy::redundant_field_names,
clippy::useless_vec,
clippy::single_range_in_vec_init
)]
use serde::Deserialize;
use std::fs;
/// One atom of a reference molecule, deserialized from the GDB20 reference fixture JSON
/// (`tests/fixtures/gdb20_reference_1k.json`, produced by
/// `scripts/generate_gdb20_reference.py`).
#[derive(Deserialize)]
struct RefAtom {
/// Atomic number. `1` is treated as hydrogen by the diagnostics (excluded from the
/// "heavy atom" counts); `0` and `1` both get `explicit_h = 1` in `build_mol_from_ref`.
element: u8,
/// Reference x coordinate, used to compute reference interatomic distances
/// (units not stated here — presumably Å; TODO confirm against the generator script).
x: f32,
/// Reference y coordinate (see `x`).
y: f32,
/// Reference z coordinate (see `x`).
z: f32,
/// Formal charge, copied unchanged into `sci_form::graph::Atom::formal_charge`.
formal_charge: i8,
/// Hybridization label. `"SP"`, `"SP2"`, `"SP3"`, `"SP3D"` and `"SP3D2"` are recognised
/// by `build_mol_from_ref`; any other string becomes `Hybridization::Unknown`.
hybridization: String,
}
/// One bond of a reference molecule, deserialized from the fixture JSON.
#[derive(Deserialize)]
struct RefBond {
/// Index of the first atom in the owning `RefMolecule::atoms` list.
start: usize,
/// Index of the second atom in the owning `RefMolecule::atoms` list.
end: usize,
/// Bond order label. `"DOUBLE"`, `"TRIPLE"` and `"AROMATIC"` are recognised by
/// `build_mol_from_ref`; every other string (including `"SINGLE"`) becomes a single bond.
order: String,
}
/// A reference torsion term, deserialized from the fixture JSON and converted to an
/// `M6TorsionContrib` by `build_csd_torsions`.
#[derive(Deserialize)]
struct RefTorsion {
/// Atom indices of the dihedral; only the first four are used (as `i`, `j`, `k`, `l`).
/// Records with fewer than four entries are dropped.
atoms: Vec<usize>,
/// Torsion coefficients; only the first six are used. Records with fewer than six
/// entries are dropped.
v: Vec<f64>,
/// Per-term signs; only the first six are used (converted to `f64`). Records with fewer
/// than six entries are dropped.
signs: Vec<i32>,
}
/// A full reference molecule from the fixture JSON: topology, reference geometry and
/// torsion terms. The fixture file is a JSON array of these.
#[derive(Deserialize)]
struct RefMolecule {
/// SMILES string; passed to `Molecule::new` and printed in the diagnosis output.
smiles: String,
/// Atoms in fixture order; bond and torsion indices refer to positions in this list.
atoms: Vec<RefAtom>,
/// Bonds between entries of `atoms`.
bonds: Vec<RefBond>,
/// Reference torsion terms fed to the conformer generator.
torsions: Vec<RefTorsion>,
}
/// Rebuilds a `sci_form` molecular graph from a deserialized reference molecule.
///
/// Atoms are added in fixture order, so fixture atom `i` maps to the `i`-th node added,
/// and bond endpoints are resolved through that mapping. Reference coordinates are NOT
/// copied: every position starts at the origin, partial charge starts at `0.0`, and
/// chirality / bond stereo are left unspecified.
///
/// # Panics
/// Panics if a bond refers to an atom index outside `ref_mol.atoms`.
fn build_mol_from_ref(ref_mol: &RefMolecule) -> sci_form::graph::Molecule {
    use sci_form::graph::{
        Atom, Bond, BondOrder, BondStereo, ChiralType, Hybridization, Molecule,
    };

    let mut graph = Molecule::new(&ref_mol.smiles);

    let nodes: Vec<_> = ref_mol
        .atoms
        .iter()
        .map(|src| {
            let hybridization = match src.hybridization.as_str() {
                "SP" => Hybridization::SP,
                "SP2" => Hybridization::SP2,
                "SP3" => Hybridization::SP3,
                "SP3D" => Hybridization::SP3D,
                "SP3D2" => Hybridization::SP3D2,
                _ => Hybridization::Unknown,
            };
            // Elements 0 and 1 are flagged with one explicit H; everything else with none.
            let explicit_h = if matches!(src.element, 0 | 1) { 1 } else { 0 };
            graph.add_atom(Atom {
                element: src.element,
                position: nalgebra::Vector3::zeros(),
                charge: 0.0,
                formal_charge: src.formal_charge,
                hybridization,
                chiral_tag: ChiralType::Unspecified,
                explicit_h,
            })
        })
        .collect();

    for link in &ref_mol.bonds {
        // Unrecognised labels (including "SINGLE") fall back to a single bond.
        let order = match link.order.as_str() {
            "DOUBLE" => BondOrder::Double,
            "TRIPLE" => BondOrder::Triple,
            "AROMATIC" => BondOrder::Aromatic,
            _ => BondOrder::Single,
        };
        let bond = Bond {
            order,
            stereo: BondStereo::None,
        };
        graph.add_bond(nodes[link.start], nodes[link.end], bond);
    }

    graph
}
/// Converts reference torsion records into `M6TorsionContrib` force-field terms.
///
/// A record is skipped (not an error) when it has fewer than 4 atom indices or fewer
/// than 6 `v` / `signs` entries; entries beyond those counts are ignored. The first four
/// atom indices become `i`, `j`, `k`, `l`, and the integer signs are widened to `f64`.
fn build_csd_torsions(
    ref_torsions: &[RefTorsion],
) -> Vec<sci_form::forcefield::etkdg_3d::M6TorsionContrib> {
    use sci_form::forcefield::etkdg_3d::M6TorsionContrib;

    let mut contribs = Vec::new();
    for torsion in ref_torsions {
        let complete = torsion.atoms.len() >= 4
            && torsion.v.len() >= 6
            && torsion.signs.len() >= 6;
        if !complete {
            continue;
        }

        let signs: [f64; 6] = std::array::from_fn(|idx| f64::from(torsion.signs[idx]));
        let v: [f64; 6] = std::array::from_fn(|idx| torsion.v[idx]);

        contribs.push(M6TorsionContrib {
            i: torsion.atoms[0],
            j: torsion.atoms[1],
            k: torsion.atoms[2],
            l: torsion.atoms[3],
            signs,
            v,
        });
    }
    contribs
}
/// One atom-pair comparison: `(a, b, reference distance, generated distance, |difference|)`.
type PairDiff = (usize, usize, f64, f64, f64);

/// Compares every unordered atom pair `(a, b)` with `a < b` over `n` atoms.
///
/// `ref_dist` / `our_dist` return the reference and generated distance for a pair.
/// Pairs are produced in row-major order (`a` ascending, then `b` ascending), so any sum
/// over the result accumulates in the same order as a nested `a`/`b` loop.
fn pair_distance_diffs(
    n: usize,
    ref_dist: impl Fn(usize, usize) -> f64,
    our_dist: impl Fn(usize, usize) -> f64,
) -> Vec<PairDiff> {
    let mut diffs = Vec::with_capacity(n * n.saturating_sub(1) / 2);
    for a in 0..n {
        for b in (a + 1)..n {
            let dr = ref_dist(a, b);
            let du = our_dist(a, b);
            diffs.push((a, b, dr, du, (dr - du).abs()));
        }
    }
    diffs
}

/// Root-mean-square of the reference-vs-generated distance differences of `pairs`.
///
/// Returns `None` for an empty input so the caller decides how to report it, instead of
/// silently producing `NaN` from `0.0 / 0.0`.
fn distance_rmsd<'a>(pairs: impl IntoIterator<Item = &'a PairDiff>) -> Option<f64> {
    let (sq_sum, count) = pairs
        .into_iter()
        .fold((0.0f64, 0usize), |(sum, count), &(_, _, dr, du, _)| {
            (sum + (dr - du).powi(2), count + 1)
        });
    (count > 0).then(|| (sq_sum / count as f64).sqrt())
}

/// Diagnostic (not a pass/fail check): embeds each reference molecule with its reference
/// torsions and prints a breakdown for the first few whose interatomic-distance RMSD
/// against the reference geometry is at or above the failure threshold.
///
/// `GDB20_LIMIT` (env var) optionally caps how many fixture molecules are scanned.
#[test]
fn test_diagnose_failures() {
    /// Number of failing molecules to print a detailed breakdown for.
    const MAX_CASES: usize = 5;
    /// Distance RMSD below which a molecule counts as a success and is skipped.
    const RMSD_FAIL_THRESHOLD: f64 = 0.5;

    let ref_data =
        sci_form::fixture_io::read_text_fixture("tests/fixtures/gdb20_reference_1k.json")
            .expect("Run scripts/generate_gdb20_reference.py first");
    let ref_mols: Vec<RefMolecule> = serde_json::from_str(&ref_data).unwrap();
    let limit = std::env::var("GDB20_LIMIT")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(ref_mols.len());
    let ref_mols = &ref_mols[..limit.min(ref_mols.len())];
    println!("\n=== FAILURE DIAGNOSIS ===");

    let mut cases_analyzed = 0;
    let mut embed_failures = 0usize;
    for ref_mol in ref_mols {
        let mol = build_mol_from_ref(ref_mol);
        let csd_torsions = build_csd_torsions(&ref_mol.torsions);
        let result =
            sci_form::conformer::generate_3d_conformer_with_torsions(&mol, 42, &csd_torsions);
        let coords = match result {
            Ok(c) => c,
            Err(_) => {
                // Not part of the RMSD diagnosis, but counted so failures are not invisible.
                embed_failures += 1;
                continue;
            }
        };

        let atoms = &ref_mol.atoms;
        let n = atoms.len();
        // Computed once, in loop order, before sorting: the overall RMSD below sums in
        // exactly the same order as the original nested loop did.
        let mut pair_diffs = pair_distance_diffs(
            n,
            |a, b| {
                ((atoms[a].x - atoms[b].x).powi(2)
                    + (atoms[a].y - atoms[b].y).powi(2)
                    + (atoms[a].z - atoms[b].z).powi(2))
                .sqrt() as f64
            },
            |a, b| {
                ((coords[(a, 0)] - coords[(b, 0)]).powi(2)
                    + (coords[(a, 1)] - coords[(b, 1)]).powi(2)
                    + (coords[(a, 2)] - coords[(b, 2)]).powi(2))
                .sqrt() as f64
            },
        );

        // A molecule with no pairs counts as a perfect match. A NaN RMSD (e.g. NaN
        // coordinates) fails the `<` test and is therefore reported as a failure.
        let rmsd = distance_rmsd(&pair_diffs).unwrap_or(0.0);
        if rmsd < RMSD_FAIL_THRESHOLD {
            continue;
        }
        cases_analyzed += 1;

        println!("\n--- {} (RMSD={:.3}) ---", ref_mol.smiles, rmsd);
        println!(
            " Heavy atoms: {}",
            atoms.iter().filter(|a| a.element != 1).count()
        );
        println!(" CSD torsions: {}", ref_mol.torsions.len());

        // Largest deviation first. `total_cmp` instead of `partial_cmp().unwrap()` so NaN
        // distances cannot panic the diagnostic; it orders non-NaN values identically.
        pair_diffs.sort_by(|a, b| b.4.total_cmp(&a.4));
        println!(" Top 10 worst atom pairs:");
        for &(a, b, dr, du, diff) in pair_diffs.iter().take(10) {
            let ea = atoms[a].element;
            let eb = atoms[b].element;
            let ha = &atoms[a].hybridization;
            let hb = &atoms[b].hybridization;
            println!(
                " ({:2},{:2}) e{}({})-e{}({}): ref={:.3} ours={:.3} Δ={:.3}",
                a, b, ea, ha, eb, hb, dr, du, diff
            );
        }

        let heavy_pairs = pair_diffs
            .iter()
            .filter(|&&(a, b, ..)| atoms[a].element != 1 && atoms[b].element != 1);
        match distance_rmsd(heavy_pairs) {
            Some(heavy_rmsd) => println!(" Heavy-only RMSD: {:.3}", heavy_rmsd),
            // Fewer than two heavy atoms: previously printed NaN (0.0 / 0.0).
            None => println!(" Heavy-only RMSD: n/a (fewer than 2 heavy atoms)"),
        }

        // Mean |Δdistance| over all pairs each atom takes part in.
        let mut atom_error = vec![0.0f64; n];
        let mut atom_count = vec![0u32; n];
        for &(a, b, _, _, diff) in &pair_diffs {
            for idx in [a, b] {
                atom_error[idx] += diff;
                atom_count[idx] += 1;
            }
        }
        let mut avg_err: Vec<(usize, f64)> = atom_error
            .iter()
            .zip(&atom_count)
            .enumerate()
            .map(|(i, (&err, &count))| {
                (i, if count > 0 { err / count as f64 } else { 0.0 })
            })
            .collect();
        avg_err.sort_by(|a, b| b.1.total_cmp(&a.1));
        println!(" Atoms with highest avg error:");
        for &(idx, err) in avg_err.iter().take(8) {
            let e = atoms[idx].element;
            let h = &atoms[idx].hybridization;
            println!(" atom {:2} e{}({}): avg_err={:.3}", idx, e, h, err);
        }

        // Stop as soon as enough cases are printed, instead of embedding further
        // molecules until a (never-printed) extra failure is found.
        if cases_analyzed >= MAX_CASES {
            break;
        }
    }

    println!(
        "\nAnalyzed {} failing molecule(s); skipped {} whose embedding returned an error.",
        cases_analyzed, embed_failures
    );
}