use serde::{Deserialize, Serialize};
/// A single predicted scalar (J) coupling between two hydrogen atoms.
///
/// Values of this type are produced by `predict_j_couplings` and the
/// ensemble-averaging functions in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JCoupling {
/// Graph node index of the first hydrogen of the pair.
pub h1_index: usize,
/// Graph node index of the second hydrogen of the pair.
pub h2_index: usize,
/// Coupling constant in Hz.
pub j_hz: f64,
/// Number of bonds separating the two hydrogens (2, 3, 4 or 5 in this file).
pub n_bonds: usize,
/// Text label for the coupling pathway, e.g. `"geminal_2J"` or
/// `"vicinal_3J_H-C-C-H"`.
pub coupling_type: String,
}
// Default Karplus coefficients for J(phi) = A*cos^2(phi) + B*cos(phi) + C.
// `get_karplus_params` returns them for any X/Y element pair it has no
// dedicated entry for; its H-C-C-H entry carries the same numeric values.

/// Default coefficient of the cos²φ term.
const KARPLUS_A: f64 = 7.76;
/// Default coefficient of the cos φ term.
const KARPLUS_B: f64 = -1.10;
/// Default constant (angle-independent) term.
const KARPLUS_C: f64 = 1.40;
/// Coefficients of a Karplus-type relation `J(φ) = a·cos²φ + b·cosφ + c`.
#[derive(Debug, Clone, Copy)]
pub struct KarplusParams {
    /// Coefficient of the cos²φ term.
    pub a: f64,
    /// Coefficient of the cos φ term.
    pub b: f64,
    /// Constant term, independent of the angle.
    pub c: f64,
}

impl KarplusParams {
    /// Evaluates `a·cos²φ + b·cosφ + c` for a dihedral angle given in radians.
    ///
    /// The result is in the same unit as the coefficients; callers in this
    /// file store it as a coupling constant in Hz.
    pub fn evaluate(&self, phi_rad: f64) -> f64 {
        let Self { a, b, c } = *self;
        let cos_phi = f64::cos(phi_rad);
        let quadratic = a * cos_phi * cos_phi;
        let linear = b * cos_phi;
        quadratic + linear + c
    }
}
/// Selects Karplus coefficients for an H–X–Y–H pathway from the atomic
/// numbers of the two central atoms X and Y.
///
/// The lookup does not depend on the order of `x_elem` and `y_elem`. Pairs
/// without a dedicated entry — including C–C, whose coefficients equal the
/// defaults — receive `KARPLUS_A`, `KARPLUS_B` and `KARPLUS_C`.
fn get_karplus_params(x_elem: u8, y_elem: u8) -> KarplusParams {
    // The table is symmetric, so look the pair up in ascending order.
    let ordered = (x_elem.min(y_elem), x_elem.max(y_elem));
    let (a, b, c) = match ordered {
        // C–N
        (6, 7) => (6.40, -1.40, 1.90),
        // C–O
        (6, 8) => (5.80, -1.20, 1.50),
        // C–S
        (6, 16) => (6.00, -1.00, 1.30),
        // C–C and every other pair share the default coefficients.
        _ => (KARPLUS_A, KARPLUS_B, KARPLUS_C),
    };
    KarplusParams { a, b, c }
}
/// Signed dihedral angle, in radians, defined by the four points
/// `p1`–`p2`–`p3`–`p4` (rotation about the `p2`→`p3` axis).
///
/// The angle is obtained with `atan2`, so it lies in `[-π, π]`. The sine-like
/// component is negated before the call, which fixes the sign convention.
fn dihedral_angle(p1: &[f64; 3], p2: &[f64; 3], p3: &[f64; 3], p4: &[f64; 3]) -> f64 {
    // Vector pointing from `from` to `to`.
    let bond = |from: &[f64; 3], to: &[f64; 3]| {
        [to[0] - from[0], to[1] - from[1], to[2] - from[2]]
    };
    let (b1, b2, b3) = (bond(p1, p2), bond(p2, p3), bond(p3, p4));

    // Normals of the planes (p1, p2, p3) and (p2, p3, p4).
    let n1 = cross(&b1, &b2);
    let n2 = cross(&b2, &b3);
    // Vector perpendicular to both `n1` and the central bond direction.
    let in_plane = cross(&n1, &normalize(&b2));

    let cos_part = dot(&n1, &n2);
    let sin_part = dot(&in_plane, &n2);
    f64::atan2(-sin_part, cos_part)
}
/// Cross product `a × b` of two 3-vectors.
fn cross(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = *a;
    let [bx, by, bz] = *b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
/// Dot product `a · b` of two 3-vectors.
fn dot(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    // Seed with the first product and add the remaining ones left to right.
    a[1..]
        .iter()
        .zip(&b[1..])
        .fold(a[0] * b[0], |acc, (x, y)| acc + x * y)
}
/// Scales `v` to unit length.
///
/// Vectors shorter than `1e-10` are treated as degenerate and mapped to the
/// zero vector instead of being divided by a near-zero length.
fn normalize(v: &[f64; 3]) -> [f64; 3] {
    let [x, y, z] = *v;
    let norm = (x * x + y * y + z * z).sqrt();
    if norm < 1e-10 {
        [0.0; 3]
    } else {
        [x / norm, y / norm, z / norm]
    }
}
pub fn predict_j_couplings(mol: &crate::graph::Molecule, positions: &[[f64; 3]]) -> Vec<JCoupling> {
let n = mol.graph.node_count();
let has_3d = positions.len() == n;
let mut couplings = Vec::new();
let h_atoms: Vec<usize> = (0..n)
.filter(|&i| mol.graph[petgraph::graph::NodeIndex::new(i)].element == 1)
.collect();
for i in 0..h_atoms.len() {
for j in (i + 1)..h_atoms.len() {
let h1 = h_atoms[i];
let h2 = h_atoms[j];
let h1_idx = petgraph::graph::NodeIndex::new(h1);
let h2_idx = petgraph::graph::NodeIndex::new(h2);
let parent1: Vec<petgraph::graph::NodeIndex> = mol
.graph
.neighbors(h1_idx)
.filter(|n| mol.graph[*n].element != 1)
.collect();
let parent2: Vec<petgraph::graph::NodeIndex> = mol
.graph
.neighbors(h2_idx)
.filter(|n| mol.graph[*n].element != 1)
.collect();
if parent1.is_empty() || parent2.is_empty() {
continue;
}
let p1 = parent1[0];
let p2 = parent2[0];
if p1 == p2 {
let j_hz: f64 = -12.0; couplings.push(JCoupling {
h1_index: h1,
h2_index: h2,
j_hz: j_hz.abs(), n_bonds: 2,
coupling_type: "geminal_2J".to_string(),
});
} else if mol.graph.find_edge(p1, p2).is_some() {
let x_elem = mol.graph[p1].element;
let y_elem = mol.graph[p2].element;
let params = get_karplus_params(x_elem, y_elem);
let j_hz = if has_3d {
let phi = dihedral_angle(
&positions[h1],
&positions[p1.index()],
&positions[p2.index()],
&positions[h2],
);
params.evaluate(phi)
} else {
7.0
};
couplings.push(JCoupling {
h1_index: h1,
h2_index: h2,
j_hz,
n_bonds: 3,
coupling_type: format!(
"vicinal_3J_H-{}-{}-H",
element_symbol(x_elem),
element_symbol(y_elem)
),
});
} else {
let p1_neighbors: Vec<petgraph::graph::NodeIndex> = mol
.graph
.neighbors(p1)
.filter(|&nb| nb != h1_idx && mol.graph[nb].element != 1)
.collect();
for &mid in &p1_neighbors {
if mol.graph.find_edge(mid, p2).is_some() {
let mid_elem = mol.graph[mid].element;
let is_sp2_mid =
mol.graph[mid].hybridization == crate::graph::Hybridization::SP2;
let is_sp2_p1 =
mol.graph[p1].hybridization == crate::graph::Hybridization::SP2;
let is_sp2_p2 =
mol.graph[p2].hybridization == crate::graph::Hybridization::SP2;
let j_hz = if is_sp2_mid && (is_sp2_p1 || is_sp2_p2) {
2.0 } else {
1.0 };
couplings.push(JCoupling {
h1_index: h1,
h2_index: h2,
j_hz,
n_bonds: 4,
coupling_type: format!(
"long_range_4J_H-{}-{}-{}-H",
element_symbol(mol.graph[p1].element),
element_symbol(mid_elem),
element_symbol(mol.graph[p2].element)
),
});
break; }
}
if couplings
.last()
.is_none_or(|c| c.n_bonds != 4 || c.h1_index != h1 || c.h2_index != h2)
{
'five_bond: for &mid1 in &p1_neighbors {
let mid1_neighbors: Vec<petgraph::graph::NodeIndex> = mol
.graph
.neighbors(mid1)
.filter(|&nb| nb != p1 && nb != h1_idx && mol.graph[nb].element != 1)
.collect();
for &mid2 in &mid1_neighbors {
if mol.graph.find_edge(mid2, p2).is_some() {
let is_aromatic_path = [p1, mid1, mid2, p2].iter().all(|&n| {
mol.graph[n].hybridization == crate::graph::Hybridization::SP2
});
let j_hz = if is_aromatic_path {
0.7 } else {
0.3 };
couplings.push(JCoupling {
h1_index: h1,
h2_index: h2,
j_hz,
n_bonds: 5,
coupling_type: format!(
"long_range_5J_H-{}-{}-{}-{}-H",
element_symbol(mol.graph[p1].element),
element_symbol(mol.graph[mid1].element),
element_symbol(mol.graph[mid2].element),
element_symbol(mol.graph[p2].element)
),
});
break 'five_bond;
}
}
}
}
}
}
}
couplings
}
/// Maps an atomic number to the element symbol used in coupling labels.
///
/// Only C, N, O and S are recognised; every other atomic number yields `"X"`.
fn element_symbol(z: u8) -> &'static str {
    const KNOWN: [(u8, &str); 4] = [(6, "C"), (7, "N"), (8, "O"), (16, "S")];
    KNOWN
        .iter()
        .find(|&&(number, _)| number == z)
        .map_or("X", |&(_, symbol)| symbol)
}
/// Boltzmann-weighted average of predicted J-couplings over a set of conformers.
///
/// Each conformer's couplings are computed with `predict_j_couplings` and
/// weighted by `exp(-(E - E_min) / (k·T))`, with energies in kcal/mol and the
/// temperature in kelvin. Atom indices, bond counts and labels of the result
/// are taken from the first conformer; only `j_hz` is averaged.
///
/// Fallbacks:
/// * no conformers: returns an empty vector;
/// * weights that cannot be formed meaningfully: returns the unweighted
///   prediction for the first conformer. This covers a mismatch between the
///   number of conformers and energies, a temperature that is NaN, zero or
///   negative, and a weight sum that is non-finite (e.g. NaN or infinite
///   energies) or smaller than `1e-30`.
///
/// A conformer whose coupling list is shorter than the first conformer's
/// contributes nothing to the missing entries, while normalisation still uses
/// the full weight sum.
pub fn ensemble_averaged_j_couplings(
    mol: &crate::graph::Molecule,
    conformer_positions: &[Vec<[f64; 3]>],
    energies_kcal: &[f64],
    temperature_k: f64,
) -> Vec<JCoupling> {
    // Boltzmann constant on a molar basis, kcal/(mol·K).
    const KB_KCAL: f64 = 0.001987204;
    // Below this total weight the ensemble is treated as degenerate.
    const MIN_WEIGHT_SUM: f64 = 1e-30;

    let Some(first_conformer) = conformer_positions.first() else {
        return Vec::new();
    };

    // T = 0 would give beta = inf and 0 * inf = NaN for the lowest-energy
    // conformer; a negative T would invert the populations. Neither describes
    // a usable ensemble, so treat them like missing energies.
    let temperature_invalid = temperature_k.is_nan() || temperature_k <= 0.0;
    if conformer_positions.len() != energies_kcal.len() || temperature_invalid {
        return predict_j_couplings(mol, first_conformer);
    }

    let beta = 1.0 / (KB_KCAL * temperature_k);
    // Shift by the minimum energy so the largest weight is exp(0) = 1.
    let e_min = energies_kcal.iter().copied().fold(f64::INFINITY, f64::min);
    let weights: Vec<f64> = energies_kcal
        .iter()
        .map(|&e| (-(e - e_min) * beta).exp())
        .collect();
    let weight_sum: f64 = weights.iter().sum();
    // `NaN < x` is false, so non-finite sums need their own check; otherwise
    // NaN would propagate into every averaged coupling.
    if !weight_sum.is_finite() || weight_sum < MIN_WEIGHT_SUM {
        return predict_j_couplings(mol, first_conformer);
    }

    let per_conformer: Vec<Vec<JCoupling>> = conformer_positions
        .iter()
        .map(|pos| predict_j_couplings(mol, pos))
        .collect();

    // The first conformer provides the template (indices, labels, bond counts).
    let mut averaged = per_conformer.first().cloned().unwrap_or_default();
    for (k, coupling) in averaged.iter_mut().enumerate() {
        let weighted_j = per_conformer
            .iter()
            .zip(&weights)
            .filter_map(|(couplings, &w)| couplings.get(k).map(|c| c.j_hz * w))
            .fold(0.0, |acc, term| acc + term);
        coupling.j_hz = weighted_j / weight_sum;
    }
    averaged
}
/// Parallel variant of the Boltzmann-weighted J-coupling average: the
/// per-conformer predictions are computed with rayon's `par_iter`.
///
/// Each conformer's couplings are computed with `predict_j_couplings` and
/// weighted by `exp(-(E - E_min) / (k·T))`, with energies in kcal/mol and the
/// temperature in kelvin. Atom indices, bond counts and labels of the result
/// are taken from the first conformer; only `j_hz` is averaged.
///
/// Fallbacks:
/// * no conformers: returns an empty vector;
/// * weights that cannot be formed meaningfully: returns the unweighted
///   prediction for the first conformer. This covers a mismatch between the
///   number of conformers and energies, a temperature that is NaN, zero or
///   negative, and a weight sum that is non-finite (e.g. NaN or infinite
///   energies) or smaller than `1e-30`.
///
/// A conformer whose coupling list is shorter than the first conformer's
/// contributes nothing to the missing entries, while normalisation still uses
/// the full weight sum.
#[cfg(feature = "parallel")]
pub fn ensemble_averaged_j_couplings_parallel(
    mol: &crate::graph::Molecule,
    conformer_positions: &[Vec<[f64; 3]>],
    energies_kcal: &[f64],
    temperature_k: f64,
) -> Vec<JCoupling> {
    use rayon::prelude::*;

    // Boltzmann constant on a molar basis, kcal/(mol·K).
    const KB_KCAL: f64 = 0.001987204;
    // Below this total weight the ensemble is treated as degenerate.
    const MIN_WEIGHT_SUM: f64 = 1e-30;

    let Some(first_conformer) = conformer_positions.first() else {
        return Vec::new();
    };

    // T = 0 would give beta = inf and 0 * inf = NaN for the lowest-energy
    // conformer; a negative T would invert the populations. Neither describes
    // a usable ensemble, so treat them like missing energies.
    let temperature_invalid = temperature_k.is_nan() || temperature_k <= 0.0;
    if conformer_positions.len() != energies_kcal.len() || temperature_invalid {
        return predict_j_couplings(mol, first_conformer);
    }

    let beta = 1.0 / (KB_KCAL * temperature_k);
    // Shift by the minimum energy so the largest weight is exp(0) = 1.
    let e_min = energies_kcal.iter().copied().fold(f64::INFINITY, f64::min);
    let weights: Vec<f64> = energies_kcal
        .iter()
        .map(|&e| (-(e - e_min) * beta).exp())
        .collect();
    let weight_sum: f64 = weights.iter().sum();
    // `NaN < x` is false, so non-finite sums need their own check; otherwise
    // NaN would propagate into every averaged coupling.
    if !weight_sum.is_finite() || weight_sum < MIN_WEIGHT_SUM {
        return predict_j_couplings(mol, first_conformer);
    }

    // Predictions are independent per conformer, so they run in parallel;
    // `collect` keeps them in input order, matching `weights`.
    let per_conformer: Vec<Vec<JCoupling>> = conformer_positions
        .par_iter()
        .map(|pos| predict_j_couplings(mol, pos))
        .collect();

    // The first conformer provides the template (indices, labels, bond counts).
    let mut averaged = per_conformer.first().cloned().unwrap_or_default();
    for (k, coupling) in averaged.iter_mut().enumerate() {
        let weighted_j = per_conformer
            .iter()
            .zip(&weights)
            .filter_map(|(couplings, &w)| couplings.get(k).map(|c| c.j_hz * w))
            .fold(0.0, |acc, term| acc + term);
        coupling.j_hz = weighted_j / weight_sum;
    }
    averaged
}
// Unit tests for the Karplus relation, the dihedral helper and the
// graph-based coupling predictors.
#[cfg(test)]
mod tests {
use super::*;
// Evaluates the Karplus relation with the module's default coefficients
// (KARPLUS_A, KARPLUS_B, KARPLUS_C).
fn karplus_3j(phi_rad: f64) -> f64 {
let cos_phi = phi_rad.cos();
KARPLUS_A * cos_phi * cos_phi + KARPLUS_B * cos_phi + KARPLUS_C
}
// The default curve should be large at 0° and 180° and small at 90°.
#[test]
fn test_karplus_equation_values() {
let j_0 = karplus_3j(0.0);
assert!(
j_0 > 8.0 && j_0 < 10.0,
"³J(0°) = {} Hz, expected ~9 Hz",
j_0
);
let j_90 = karplus_3j(std::f64::consts::FRAC_PI_2);
assert!(
j_90 > 0.0 && j_90 < 3.0,
"³J(90°) = {} Hz, expected ~1.4 Hz",
j_90
);
let j_180 = karplus_3j(std::f64::consts::PI);
assert!(
j_180 > 6.0 && j_180 < 12.0,
"³J(180°) = {} Hz, expected ~10 Hz",
j_180
);
}
// Four coplanar points (all z = 0): the dihedral must be planar.
// NOTE(review): the assertion accepts either 0 or ±π, so it does not
// distinguish a syn from an anti arrangement.
#[test]
fn test_dihedral_angle_basic() {
let p1 = [1.0, 0.0, 0.0];
let p2 = [0.0, 0.0, 0.0];
let p3 = [0.0, 1.0, 0.0];
let p4 = [-1.0, 1.0, 0.0];
let angle = dihedral_angle(&p1, &p2, &p3, &p4);
assert!(angle.abs() < 0.1 || (angle.abs() - std::f64::consts::PI).abs() < 0.1);
}
// Ethane with an empty position slice: the slice length cannot match the
// atom count of a non-empty molecule, so vicinal couplings take the
// geometry-free path.
// NOTE(review): relies on `from_smiles` producing explicit hydrogen nodes
// — confirm against the graph module.
#[test]
fn test_ethane_j_couplings() {
let mol = crate::graph::Molecule::from_smiles("CC").unwrap();
let couplings = predict_j_couplings(&mol, &[]);
assert!(
!couplings.is_empty(),
"Ethane should have J-coupling predictions"
);
let vicinal: Vec<&JCoupling> = couplings.iter().filter(|c| c.n_bonds == 3).collect();
assert!(
!vicinal.is_empty(),
"Ethane should have ³J vicinal couplings"
);
for c in &vicinal {
assert!(
c.coupling_type.contains("vicinal_3J"),
"Coupling type should be vicinal_3J, got {}",
c.coupling_type
);
}
}
// The C–C and C–N parameter sets must give different couplings at φ = 0.
#[test]
fn test_karplus_pathway_specific() {
let cc_params = get_karplus_params(6, 6);
let cn_params = get_karplus_params(6, 7);
let j_cc = cc_params.evaluate(0.0);
let j_cn = cn_params.evaluate(0.0);
assert!(
(j_cc - j_cn).abs() > 0.1,
"H-C-C-H and H-C-N-H should have different J at φ=0: {} vs {}",
j_cc,
j_cn
);
}
// Single conformer with every atom placed at the origin and zero energy:
// only checks that averaging returns a non-empty result.
#[test]
fn test_ensemble_averaging() {
let mol = crate::graph::Molecule::from_smiles("CC").unwrap();
let n = mol.graph.node_count();
let positions = vec![[0.0, 0.0, 0.0]; n];
let result = ensemble_averaged_j_couplings(&mol, &[positions], &[0.0], 298.15);
assert!(!result.is_empty());
}
// All hydrogens of methane share one parent, so every coupling must be
// geminal. NOTE(review): the loop passes vacuously if no couplings are
// returned.
#[test]
fn test_methane_j_couplings() {
let mol = crate::graph::Molecule::from_smiles("C").unwrap();
let couplings = predict_j_couplings(&mol, &[]);
for c in &couplings {
assert_eq!(c.n_bonds, 2, "Methane H-H should be 2-bond (geminal)");
}
}
}