use serde::{Deserialize, Serialize};
/// Tunable parameters read by [`grow_string`].
///
/// The struct is serializable through serde; field names are therefore part
/// of the serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GsmConfig {
/// Node budget: new nodes are only inserted while the string holds fewer
/// than this many, and the main loop stops once the count reaches it.
/// The string always starts with three nodes, whatever this value is.
pub max_nodes: usize,
/// Factor applied to the perpendicular gradient when interior nodes are
/// moved. Twice this value is also the neighbour-distance threshold used
/// by the "all nodes are close" termination test.
pub step_size: f64,
/// The string is reported as joined once the largest perpendicular
/// gradient norm over the interior nodes is below this value.
pub grad_tol: f64,
/// Upper bound on the number of passes of the main loop.
pub max_iter: usize,
/// Displacement handed to the central finite-difference gradient.
pub fd_step: f64,
}
impl Default for GsmConfig {
fn default() -> Self {
Self {
max_nodes: 11,
step_size: 0.1,
grad_tol: 0.5,
max_iter: 100,
fd_step: 0.005,
}
}
}
/// Result returned by [`grow_string`].
///
/// `nodes`, `energies` and `perp_grad_norms` are filled with one entry per
/// node, in the same order.
#[derive(Debug, Clone)]
pub struct GsmPath {
/// Coordinates of every node on the string; the first entry is a copy of
/// the reactant and the last a copy of the product.
pub nodes: Vec<Vec<f64>>,
/// Value of the energy function at each node.
pub energies: Vec<f64>,
/// Norm of the gradient component perpendicular to the string at each
/// node; the two endpoint entries are always `0.0`.
pub perp_grad_norms: Vec<f64>,
/// `true` when one of the convergence tests ended the loop, `false` when
/// it ended for any other reason.
pub joined: bool,
/// Iteration count reported by [`grow_string`].
pub iterations: usize,
}
/// Linearly blends two coordinate sets.
///
/// Each output coordinate is `r + t * (p - r)`, so `t = 0.0` reproduces
/// `reactant` and `t = 1.0` reproduces `product`. Coordinates are paired
/// positionally; the result is as long as the shorter of the two inputs.
pub fn interpolate_node(reactant: &[f64], product: &[f64], t: f64) -> Vec<f64> {
    let paired = reactant.len().min(product.len());
    let mut node = Vec::with_capacity(paired);
    for (&start, &end) in reactant.iter().zip(product) {
        node.push(start + t * (end - start));
    }
    node
}
/// Central finite-difference estimate of the gradient of `energy_fn` at `coords`.
///
/// For every coordinate `i` the energy is evaluated at `coords[i] + h` and
/// then at `coords[i] - h` (all other coordinates untouched), and the
/// component is `(E(+h) - E(-h)) / (2 * h)`. This costs `2 * coords.len()`
/// energy evaluations.
///
/// A single scratch copy of `coords` is displaced and restored in place
/// instead of cloning the whole vector twice per coordinate.
///
/// `h` is used as given; `h == 0.0` yields non-finite components.
fn numerical_gradient(coords: &[f64], energy_fn: &dyn Fn(&[f64]) -> f64, h: f64) -> Vec<f64> {
    let mut probe = coords.to_vec();
    coords
        .iter()
        .enumerate()
        .map(|(i, &centre)| {
            probe[i] = centre + h;
            let forward = energy_fn(&probe);
            probe[i] = centre - h;
            let backward = energy_fn(&probe);
            // Put the coordinate back so the next component starts from `coords`.
            probe[i] = centre;
            (forward - backward) / (2.0 * h)
        })
        .collect()
}
fn compute_tangent(nodes: &[Vec<f64>], idx: usize) -> Vec<f64> {
if idx == 0 {
let mut t: Vec<f64> = nodes[1].iter().zip(&nodes[0]).map(|(a, b)| a - b).collect();
normalize(&mut t);
t
} else if idx == nodes.len() - 1 {
let mut t: Vec<f64> = nodes[idx]
.iter()
.zip(&nodes[idx - 1])
.map(|(a, b)| a - b)
.collect();
normalize(&mut t);
t
} else {
let mut t: Vec<f64> = nodes[idx + 1]
.iter()
.zip(&nodes[idx - 1])
.map(|(a, b)| a - b)
.collect();
normalize(&mut t);
t
}
}
/// Removes from `gradient` its component along `tangent`.
///
/// Computes `g - (g · t) t` coordinate by coordinate. The result is
/// orthogonal to `tangent` only when `tangent` has unit length. Coordinates
/// are paired positionally, so the output is as long as the shorter input.
fn perpendicular_gradient(gradient: &[f64], tangent: &[f64]) -> Vec<f64> {
    let along: f64 = gradient
        .iter()
        .zip(tangent)
        .map(|(&g, &t)| g * t)
        .sum();

    let mut across = Vec::with_capacity(gradient.len().min(tangent.len()));
    for (&g, &t) in gradient.iter().zip(tangent) {
        across.push(g - along * t);
    }
    across
}
/// Scales `v` in place to unit Euclidean length.
///
/// Vectors whose length is not greater than `1e-15` (including the zero
/// vector) are left exactly as they are, which avoids dividing by a
/// vanishing norm.
fn normalize(v: &mut [f64]) {
    const SMALLEST_SCALABLE_NORM: f64 = 1e-15;

    let length = v.iter().map(|c| c * c).sum::<f64>().sqrt();
    if length > SMALLEST_SCALABLE_NORM {
        v.iter_mut().for_each(|c| *c /= length);
    }
}
/// Euclidean (L2) length of `v`: the square root of the sum of squares.
fn vec_norm(v: &[f64]) -> f64 {
    let sum_of_squares: f64 = v.iter().map(|c| c * c).sum();
    sum_of_squares.sqrt()
}
/// Euclidean distance between two coordinate sets.
///
/// Coordinates are paired positionally; anything beyond the length of the
/// shorter slice is ignored.
fn distance(a: &[f64], b: &[f64]) -> f64 {
    let squared: f64 = a
        .iter()
        .zip(b)
        .map(|(p, q)| {
            let gap = p - q;
            gap.powi(2)
        })
        .sum();
    squared.sqrt()
}
/// Grows and relaxes a string of nodes between `reactant` and `product`.
///
/// The string starts as `[reactant, midpoint, product]`. While fewer than
/// `config.max_iter` passes have run and the string holds fewer than
/// `config.max_nodes` nodes, each pass
///
/// 1. inserts a node halfway between the reactant and its neighbour and, if
///    the node budget still allows, one halfway between the product and its
///    neighbour;
/// 2. moves every interior node by `-config.step_size` times the component
///    of the finite-difference gradient that is perpendicular to the string;
/// 3. stops with `joined = true` when either
///    * the string has at least five nodes and no pair of adjacent nodes is
///      farther apart than `2 * config.step_size`, or
///    * the largest perpendicular-gradient norm over the interior nodes is
///      below `config.grad_tol`.
///
/// If the loop ends because the pass limit or the node budget is reached,
/// `joined` stays `false`.
///
/// # Arguments
///
/// * `reactant`, `product` - endpoint coordinates; they are copied into the
///   path and never moved. Both are expected to have the same length:
///   interior nodes are built by pairwise interpolation, which stops at the
///   shorter slice.
/// * `energy_fn` - energy of a coordinate set; called for every
///   finite-difference displacement and once per node for the final report.
/// * `config` - node budget, pass limit, step sizes and tolerance.
///
/// # Returns
///
/// A [`GsmPath`] with the final nodes, their energies, the
/// perpendicular-gradient norm at every node (`0.0` at both endpoints), the
/// `joined` flag and `iterations`, the number of passes that actually grew
/// and relaxed the string. A pass that would only have noticed the exhausted
/// node budget is not counted.
pub fn grow_string(
    reactant: &[f64],
    product: &[f64],
    energy_fn: &dyn Fn(&[f64]) -> f64,
    config: &GsmConfig,
) -> GsmPath {
    let mut nodes = vec![
        reactant.to_vec(),
        interpolate_node(reactant, product, 0.5),
        product.to_vec(),
    ];
    let mut joined = false;
    let mut iterations = 0;

    // NOTE(review): the loop ends as soon as the node budget is used up, so a
    // fully grown string receives no further relaxation passes and
    // `max_iter` only matters for large budgets. Confirm this is intended.
    while iterations < config.max_iter && nodes.len() < config.max_nodes {
        iterations += 1;

        // Grow from the reactant side.
        let near_reactant = interpolate_node(&nodes[0], &nodes[1], 0.5);
        nodes.insert(1, near_reactant);

        // Grow from the product side only if the first insertion left room.
        if nodes.len() < config.max_nodes {
            let last = nodes.len() - 1;
            let near_product = interpolate_node(&nodes[last - 1], &nodes[last], 0.5);
            nodes.insert(last, near_product);
        }

        // Relax interior nodes one after another; the tangent of node `i`
        // therefore already sees the updated position of node `i - 1`.
        let step = config.step_size;
        for i in 1..nodes.len() - 1 {
            let perp = node_perp_gradient(&nodes, i, energy_fn, config.fd_step);
            for (coord, p) in nodes[i].iter_mut().zip(&perp) {
                *coord -= step * p;
            }
        }

        // Geometric test: every pair of neighbours within two step lengths.
        let gap_limit = config.step_size * 2.0;
        let any_far = nodes
            .windows(2)
            .any(|pair| distance(&pair[0], &pair[1]) > gap_limit);
        if !any_far && nodes.len() >= 5 {
            joined = true;
            break;
        }

        // Gradient test: worst perpendicular force over the interior nodes.
        let max_perp = (1..nodes.len() - 1)
            .map(|i| vec_norm(&node_perp_gradient(&nodes, i, energy_fn, config.fd_step)))
            .fold(0.0_f64, |worst, norm| if norm > worst { norm } else { worst });
        if max_perp < config.grad_tol {
            joined = true;
            break;
        }
    }

    let energies: Vec<f64> = nodes.iter().map(|node| energy_fn(node)).collect();

    // Endpoints are fixed, so their perpendicular gradient is reported as zero.
    let last = nodes.len() - 1;
    let perp_grad_norms: Vec<f64> = (0..nodes.len())
        .map(|i| {
            if i == 0 || i == last {
                0.0
            } else {
                vec_norm(&node_perp_gradient(&nodes, i, energy_fn, config.fd_step))
            }
        })
        .collect();

    GsmPath {
        nodes,
        energies,
        perp_grad_norms,
        joined,
        iterations,
    }
}

/// Finite-difference gradient at `nodes[idx]` with its component along the
/// string tangent at that node removed.
fn node_perp_gradient(
    nodes: &[Vec<f64>],
    idx: usize,
    energy_fn: &dyn Fn(&[f64]) -> f64,
    fd_step: f64,
) -> Vec<f64> {
    let tangent = compute_tangent(nodes, idx);
    let gradient = numerical_gradient(&nodes[idx], energy_fn, fd_step);
    perpendicular_gradient(&gradient, &tangent)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `t = 0.5` lands exactly halfway between the endpoints.
    #[test]
    fn test_interpolate_midpoint() {
        let origin = vec![0.0; 3];
        let far = vec![2.0; 3];
        assert_eq!(interpolate_node(&origin, &far, 0.5), vec![1.0; 3]);
    }

    /// `t = 0.0` and `t = 1.0` reproduce the reactant and the product.
    #[test]
    fn test_interpolate_endpoints() {
        let reactant = vec![1.0, 2.0, 3.0];
        let product = vec![4.0, 5.0, 6.0];
        assert_eq!(interpolate_node(&reactant, &product, 0.0), reactant);
        assert_eq!(interpolate_node(&reactant, &product, 1.0), product);
    }

    /// With the tangent along x, only the y component of the gradient survives.
    #[test]
    fn test_perpendicular_gradient() {
        let gradient = vec![1.0, 1.0, 0.0];
        let tangent = vec![1.0, 0.0, 0.0];
        let perp = perpendicular_gradient(&gradient, &tangent);
        assert!(perp[0].abs() < 1e-10);
        assert!((perp[1] - 1.0).abs() < 1e-10);
    }

    /// Smoke test on a double well in x with minima at x = -1 and x = +1.
    #[test]
    fn test_grow_string_basic() {
        let double_well = |coords: &[f64]| -> f64 { (coords[0] * coords[0] - 1.0).powi(2) };
        let config = GsmConfig {
            max_nodes: 7,
            max_iter: 20,
            step_size: 0.01,
            ..Default::default()
        };
        let left_minimum = vec![-1.0, 0.0, 0.0];
        let right_minimum = vec![1.0, 0.0, 0.0];

        let path = grow_string(&left_minimum, &right_minimum, &double_well, &config);

        assert!(path.nodes.len() >= 3);
        assert_eq!(path.energies.len(), path.nodes.len());
    }
}