use super::string::{grow_string, GsmConfig, GsmPath};
use serde::{Deserialize, Serialize};
/// Outcome of a transition-state search, as assembled by `find_transition_state`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GsmResult {
/// Coordinates of the transition-state estimate after saddle refinement.
/// These may differ from `path_coords[ts_node_index]`, which is the unrefined node.
pub ts_coords: Vec<f64>,
/// Energy evaluated at `ts_coords`.
pub ts_energy: f64,
/// Forward barrier: `ts_energy` minus the energy of the reactant geometry.
pub activation_energy: f64,
/// Reverse barrier: `ts_energy` minus the energy of the product geometry.
pub reverse_barrier: f64,
/// Energies of the path nodes, copied from the grown path.
pub path_energies: Vec<f64>,
/// Coordinates of the path nodes, one inner vector per node.
pub path_coords: Vec<Vec<f64>>,
/// Index into the path of the node selected as the transition-state candidate
/// (the highest-energy interior node).
pub ts_node_index: usize,
/// Number of nodes on the path.
pub n_nodes: usize,
/// Estimated number of energy-function calls. This is computed from the
/// iteration, node and coordinate counts; it is not a measured count.
pub energy_evaluations: usize,
}
/// Picks the highest-energy interior node of `path` as the transition-state candidate.
///
/// The first and last nodes are never considered. Ties keep the earliest
/// node, and NaN energies are skipped because they never compare greater.
///
/// # Returns
/// `(index, energy)` of the selected node. When the path has no interior node
/// (fewer than three nodes), or no interior energy exceeds negative infinity,
/// the energy is `f64::NEG_INFINITY` and the index is `1` clamped to the last
/// valid position (`0` for an empty or single-node path).
fn find_ts_candidate(path: &GsmPath) -> (usize, f64) {
    let n_nodes = path.nodes.len();

    // Clamping keeps the index in bounds for a single-node path and avoids the
    // `len() - 1` underflow on an empty one.
    let fallback = (n_nodes.saturating_sub(1).min(1), f64::NEG_INFINITY);
    if n_nodes < 3 {
        return fallback;
    }

    path.energies[1..n_nodes - 1]
        .iter()
        .enumerate()
        .fold(fallback, |best, (offset, &energy)| {
            // Strict comparison: the first maximum wins.
            if energy > best.1 {
                (offset + 1, energy)
            } else {
                best
            }
        })
}
/// Relaxes a transition-state guess by steepest descent restricted to the
/// directions perpendicular to the local path tangent.
///
/// The tangent is `neighbors.1 - neighbors.0`, normalized unless its length is
/// at most `1e-15`. On every step the gradient is estimated with central finite
/// differences, its component along the tangent is removed, and the
/// coordinates move against what remains. Iteration stops after `max_steps`
/// steps, once the perpendicular gradient norm drops below `0.01`, or as soon
/// as that norm is not finite.
///
/// # Arguments
/// * `ts_coords` - starting geometry; it is copied, never modified.
/// * `energy_fn` - energy as a function of the coordinates.
/// * `neighbors` - `(previous, next)` geometries defining the tangent. They are
///   expected to have the same length as `ts_coords`; components missing from
///   the tangent are treated as zero.
/// * `max_steps` - maximum number of descent steps.
/// * `step_size` - multiplier applied to the perpendicular gradient.
/// * `fd_step` - half-width of the central finite-difference stencil.
///
/// # Returns
/// `(coords, energy)`: the relaxed geometry and `energy_fn` evaluated at it.
pub fn refine_saddle(
    ts_coords: &[f64],
    energy_fn: &dyn Fn(&[f64]) -> f64,
    neighbors: (&[f64], &[f64]),
    max_steps: usize,
    step_size: f64,
    fd_step: f64,
) -> (Vec<f64>, f64) {
    // Perpendicular-gradient norm below which the geometry counts as converged.
    const PERP_GRAD_TOL: f64 = 0.01;
    // Tangents no longer than this are left unnormalized to avoid dividing by ~0.
    const MIN_TANGENT_NORM: f64 = 1e-15;

    let mut coords = ts_coords.to_vec();

    // The neighbors never move during refinement, so the tangent is built once.
    let mut tangent: Vec<f64> = neighbors
        .1
        .iter()
        .zip(neighbors.0)
        .map(|(next, prev)| next - prev)
        .collect();
    let tangent_norm = tangent.iter().map(|t| t * t).sum::<f64>().sqrt();
    if tangent_norm > MIN_TANGENT_NORM {
        for t in &mut tangent {
            *t /= tangent_norm;
        }
    }

    let mut grad = vec![0.0; coords.len()];
    for _ in 0..max_steps {
        // Central differences. Each coordinate is perturbed in place and then
        // restored to its saved value, so no per-component copies are needed.
        for (i, g) in grad.iter_mut().enumerate() {
            let x0 = coords[i];
            coords[i] = x0 + fd_step;
            let e_plus = energy_fn(&coords);
            coords[i] = x0 - fd_step;
            let e_minus = energy_fn(&coords);
            coords[i] = x0;
            *g = (e_plus - e_minus) / (2.0 * fd_step);
        }

        // Remove the tangential component; `grad` then holds the perpendicular part.
        let along: f64 = grad.iter().zip(&tangent).map(|(g, t)| g * t).sum();
        for (g, t) in grad.iter_mut().zip(&tangent) {
            *g -= along * t;
        }

        let perp_norm = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
        // A non-finite gradient (e.g. `fd_step == 0` or a NaN energy) would
        // poison every coordinate, so keep the last usable geometry instead.
        if !perp_norm.is_finite() || perp_norm < PERP_GRAD_TOL {
            break;
        }

        for (x, g) in coords.iter_mut().zip(&grad) {
            *x -= step_size * g;
        }
    }

    let energy = energy_fn(&coords);
    (coords, energy)
}
/// Locates a transition-state estimate between `reactant` and `product`.
///
/// The steps are:
/// 1. Build a path with `grow_string`.
/// 2. Select the highest-energy interior node as the candidate.
/// 3. Relax the candidate with `refine_saddle`, using the tangent defined by
///    its two path neighbours (or by the endpoints when the path has fewer
///    than three nodes).
/// 4. Compute the forward and reverse barriers against the endpoint energies.
///
/// # Arguments
/// * `reactant` / `product` - endpoint geometries.
/// * `energy_fn` - energy as a function of the coordinates.
/// * `config` - string-growth settings; refinement uses a tenth of
///   `config.step_size` and `config.fd_step` unchanged.
///
/// # Returns
/// A [`GsmResult`] holding the refined geometry, the barriers, the full path,
/// and an estimate of the number of energy evaluations. The estimate assumes
/// every refinement step was used.
///
/// # Panics
/// Panics if `grow_string` returns a path with no nodes.
pub fn find_transition_state(
    reactant: &[f64],
    product: &[f64],
    energy_fn: &dyn Fn(&[f64]) -> f64,
    config: &GsmConfig,
) -> GsmResult {
    // Refinement iteration cap. The cost estimate below uses the same constant
    // so the two cannot drift apart.
    const REFINE_MAX_STEPS: usize = 20;
    // Refinement steps are this fraction of the string-growth step size.
    const REFINE_STEP_SCALE: f64 = 0.1;

    let path = grow_string(reactant, product, energy_fn, config);
    let n_nodes = path.nodes.len();
    assert!(n_nodes > 0, "grow_string returned a path with no nodes");

    let (candidate_idx, _candidate_energy) = find_ts_candidate(&path);
    // Keep the index usable even for a degenerate single-node path.
    let ts_idx = candidate_idx.min(n_nodes - 1);

    let (prev, next) = if n_nodes >= 3 {
        let prev_idx = ts_idx.saturating_sub(1);
        let next_idx = (ts_idx + 1).min(n_nodes - 1);
        (&path.nodes[prev_idx][..], &path.nodes[next_idx][..])
    } else {
        // No interior neighbours: the endpoints define the tangent.
        (reactant, product)
    };

    let (ts_coords, ts_energy) = refine_saddle(
        &path.nodes[ts_idx],
        energy_fn,
        (prev, next),
        REFINE_MAX_STEPS,
        config.step_size * REFINE_STEP_SCALE,
        config.fd_step,
    );

    let e_r = energy_fn(reactant);
    let e_p = energy_fn(product);

    // Estimate: two evaluations per coordinate for each central-difference
    // gradient, over all string iterations and the full refinement budget.
    let n_coords = reactant.len();
    let energy_evaluations =
        path.iterations * n_nodes * 2 * n_coords + REFINE_MAX_STEPS * 2 * n_coords;

    GsmResult {
        ts_coords,
        ts_energy,
        activation_energy: ts_energy - e_r,
        reverse_barrier: ts_energy - e_p,
        // `path` is owned and no longer borrowed, so its buffers move into the result.
        path_energies: path.energies,
        path_coords: path.nodes,
        ts_node_index: ts_idx,
        n_nodes,
        energy_evaluations,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Quartic double well in the first coordinate, `(x^2 - 1)^2`; the other
    /// coordinates do not contribute.
    fn double_well(coords: &[f64]) -> f64 {
        let x = coords[0];
        let offset = x * x - 1.0;
        offset.powi(2)
    }

    /// Endpoints used by the double-well tests: the two minima at x = -1 and x = +1.
    fn double_well_endpoints() -> (Vec<f64>, Vec<f64>) {
        (vec![-1.0, 0.0, 0.0], vec![1.0, 0.0, 0.0])
    }

    /// The search on the double well reports a non-negative TS energy and a
    /// positive forward barrier.
    #[test]
    fn test_find_ts_double_well() {
        let (reactant, product) = double_well_endpoints();
        let config = GsmConfig {
            step_size: 0.01,
            max_iter: 30,
            max_nodes: 9,
            ..Default::default()
        };

        let GsmResult {
            ts_energy,
            activation_energy,
            ..
        } = find_transition_state(&reactant, &product, &double_well, &config);

        assert!(ts_energy >= 0.0, "TS energy should be positive: {}", ts_energy);
        assert!(
            activation_energy > 0.0,
            "Barrier should be positive: {}",
            activation_energy
        );
    }

    /// The first and last path nodes stay close to the supplied endpoints.
    #[test]
    fn test_gsm_path_endpoints() {
        let (reactant, product) = double_well_endpoints();
        let config = GsmConfig {
            max_iter: 10,
            max_nodes: 7,
            ..Default::default()
        };

        let result = find_transition_state(&reactant, &product, &double_well, &config);

        let first_x = result.path_coords[0][0];
        let last_x = result.path_coords.last().unwrap()[0];
        let d_r = (first_x - reactant[0]).abs();
        let d_p = (last_x - product[0]).abs();

        assert!(d_r < 0.5, "First node should be near reactant: d={}", d_r);
        assert!(d_p < 0.5, "Last node should be near product: d={}", d_p);
    }

    /// The per-node vectors and the reported node count agree with each other.
    #[test]
    fn test_gsm_result_consistency() {
        /// Parabola in the first coordinate centred at x = 1.
        fn single_well(c: &[f64]) -> f64 {
            (c[0] - 1.0).powi(2)
        }

        let reactant = vec![0.0, 0.0, 0.0];
        let product = vec![2.0, 0.0, 0.0];
        let config = GsmConfig {
            max_iter: 10,
            max_nodes: 5,
            ..Default::default()
        };

        let result = find_transition_state(&reactant, &product, &single_well, &config);
        let expected_len = result.n_nodes;

        assert_eq!(result.path_energies.len(), expected_len);
        assert_eq!(result.path_coords.len(), expected_len);
        assert!(result.ts_node_index < expected_len);
    }
}