use serde::{Deserialize, Serialize};
/// Coefficients of the Gasteiger–Marsili electronegativity polynomial
/// `χ(q) = a + b·q + c·q²`, where `q` is the partial charge on the atom.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GasteigerParams {
    /// Constant term, i.e. the electronegativity of the neutral atom `χ(0)`.
    pub a: f64,
    /// Coefficient of the linear term in `q`.
    pub b: f64,
    /// Coefficient of the quadratic term in `q`.
    pub c: f64,
}

impl GasteigerParams {
    /// Evaluates the electronegativity polynomial at partial charge `q`.
    ///
    /// `chi(1.0)` yields the cation electronegativity `χ⁺ = a + b + c`.
    #[inline]
    pub fn chi(&self, q: f64) -> f64 {
        self.a + self.b * q + self.c * q * q
    }
}
/// Looks up the Gasteiger–Marsili electronegativity coefficients for an element.
///
/// `z` is the atomic number. Returns `None` for elements that have no entry in
/// the table (for example helium, Z = 2). Exactly one coefficient set is stored
/// per element, so the atom's hybridization state is not taken into account.
///
/// NOTE(review): the provenance of each coefficient set is not recorded here;
/// confirm the less common elements against the reference table before
/// relying on them.
pub fn get_gasteiger_params(z: u8) -> Option<GasteigerParams> {
    // (a, b, c) for χ(q) = a + b·q + c·q², keyed by atomic number.
    let (a, b, c) = match z {
        // Period 1
        1 => (7.17, 6.24, -0.56), // H
        // Period 2
        3 => (3.01, 2.78, 0.30),   // Li
        4 => (4.90, 4.52, 0.55),   // Be
        5 => (5.98, 8.46, 1.70),   // B
        6 => (7.98, 9.18, 1.88),   // C
        7 => (11.54, 10.82, 1.36), // N
        8 => (14.18, 12.92, 1.39), // O
        9 => (14.66, 13.85, 2.31), // F
        // Period 3
        11 => (2.84, 3.00, 0.40),  // Na
        12 => (3.75, 3.90, 0.50),  // Mg
        13 => (5.47, 5.10, 0.65),  // Al
        14 => (7.30, 6.56, 0.68),  // Si
        15 => (8.90, 8.24, 0.96),  // P
        16 => (10.14, 9.13, 1.38), // S
        17 => (11.00, 9.69, 1.35), // Cl
        // Period 4
        19 => (2.42, 2.60, 0.35),  // K
        20 => (3.23, 3.40, 0.45),  // Ca
        31 => (5.20, 5.50, 0.70),  // Ga
        32 => (6.90, 6.80, 0.80),  // Ge
        33 => (8.30, 7.80, 0.90),  // As
        34 => (9.50, 8.70, 1.10),  // Se
        35 => (10.08, 8.47, 1.16), // Br
        // Period 5
        53 => (9.90, 7.96, 0.96), // I
        _ => return None,
    };
    Some(GasteigerParams { a, b, c })
}
/// Output of a Gasteiger–Marsili partial-charge calculation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChargeResult {
    /// Partial charge of each atom, in the same order as the input `elements`.
    pub charges: Vec<f64>,
    /// Number of relaxation iterations actually performed.
    pub iterations: usize,
    /// Sum of all entries in `charges`.
    pub total_charge: f64,
    /// `true` if iteration stopped because the largest per-atom charge change
    /// fell below the convergence threshold; `false` if the iteration budget
    /// was exhausted first.
    pub converged: bool,
}
/// Tuning knobs for the iterative Gasteiger–Marsili charge relaxation.
#[derive(Debug, Clone, PartialEq)]
pub struct GasteigerConfig {
    /// Maximum number of relaxation iterations to run.
    pub max_iter: usize,
    /// Damping factor applied on the first iteration; it is halved after
    /// every iteration.
    pub initial_damping: f64,
    /// Iteration stops early once the largest per-atom charge change within
    /// one iteration falls below this value.
    pub convergence_threshold: f64,
}

impl Default for GasteigerConfig {
    /// Six iterations, initial damping 0.5, convergence threshold `1e-10`.
    fn default() -> Self {
        Self {
            max_iter: 6,
            initial_damping: 0.5,
            convergence_threshold: 1e-10,
        }
    }
}
pub fn gasteiger_marsili_charges(
elements: &[u8],
bonds: &[(usize, usize)],
formal_charges: &[i8],
max_iter: usize,
) -> Result<ChargeResult, String> {
let config = GasteigerConfig {
max_iter,
..Default::default()
};
gasteiger_marsili_charges_configured(elements, bonds, formal_charges, &config)
}
/// Computes Gasteiger–Marsili partial charges with an explicit [`GasteigerConfig`].
///
/// Charges start at the formal charges and are relaxed iteratively. On every
/// iteration, each bond `(i, j)` moves charge from the currently less
/// electronegative atom (the donor) to the more electronegative one:
///
/// ```text
/// Δq_i = damping · (χ_j − χ_i) / χ⁺_donor,   Δq_j = −Δq_i
/// ```
///
/// Here `χ(q) = a + b·q + c·q²` is evaluated at the atom's current charge. The
/// donor's cation electronegativity `χ⁺` is `χ(+1)`, except for hydrogen, where
/// the conventional fixed value 20.02 is used. `damping` starts at
/// `config.initial_damping` and is halved after every iteration.
///
/// Charge is exchanged pairwise along bonds, so the total charge stays equal to
/// the sum of the formal charges, up to floating-point rounding.
///
/// # Errors
///
/// Returns `Err` when:
/// - `formal_charges.len() != elements.len()`;
/// - an element has no parameters in [`get_gasteiger_params`];
/// - a bond references an atom index `>= elements.len()`.
///
/// Bonds are validated before iterating, so an out-of-range bond is reported
/// even when `config.max_iter == 0`.
pub fn gasteiger_marsili_charges_configured(
    elements: &[u8],
    bonds: &[(usize, usize)],
    formal_charges: &[i8],
    config: &GasteigerConfig,
) -> Result<ChargeResult, String> {
    // The Gasteiger–Marsili scheme fixes hydrogen's cation electronegativity at
    // 20.02 instead of deriving it from a + b + c.
    const HYDROGEN_CATION_CHI: f64 = 20.02;
    // Donor χ⁺ values this close to zero would blow up the update; such bonds
    // transfer no charge.
    const MIN_DIVISOR: f64 = 1e-6;

    let n = elements.len();
    if formal_charges.len() != n {
        return Err(format!(
            "formal_charges length {} != elements length {}",
            formal_charges.len(),
            n
        ));
    }
    let params: Vec<GasteigerParams> = elements
        .iter()
        .map(|&z| {
            get_gasteiger_params(z)
                .ok_or_else(|| format!("No Gasteiger parameters for element Z={}", z))
        })
        .collect::<Result<Vec<_>, _>>()?;
    // Validate the topology once, before any iteration, so the check does not
    // depend on `max_iter` and is not repeated every pass.
    if let Some(&(i, j)) = bonds.iter().find(|&&(i, j)| i >= n || j >= n) {
        return Err(format!(
            "Bond ({}, {}) references atom outside range 0..{}",
            i, j, n
        ));
    }
    // χ⁺ depends only on the element, so compute it once per atom.
    let cation_chi: Vec<f64> = elements
        .iter()
        .zip(&params)
        .map(|(&z, p)| if z == 1 { HYDROGEN_CATION_CHI } else { p.chi(1.0) })
        .collect();

    let mut charges: Vec<f64> = formal_charges.iter().map(|&fc| f64::from(fc)).collect();
    let mut delta_q = vec![0.0f64; n];
    let mut damping = config.initial_damping;
    let mut actual_iters = 0;
    let mut did_converge = false;
    for _ in 0..config.max_iter {
        actual_iters += 1;
        let chi: Vec<f64> = params
            .iter()
            .zip(&charges)
            .map(|(p, &q)| p.chi(q))
            .collect();
        delta_q.fill(0.0);
        for &(i, j) in bonds {
            let chi_diff = chi[j] - chi[i];
            // Normalize by χ⁺ of the donor, i.e. the less electronegative atom:
            // `i` when χ_j > χ_i, otherwise `j`.
            let donor = if chi_diff > 0.0 { i } else { j };
            let divisor = cation_chi[donor].abs();
            let dq = if divisor < MIN_DIVISOR {
                0.0
            } else {
                damping * chi_diff / divisor
            };
            // Positive dq means atom i gives up electron density to atom j.
            delta_q[i] += dq;
            delta_q[j] -= dq;
        }
        let mut max_delta = 0.0f64;
        for (q, &d) in charges.iter_mut().zip(&delta_q) {
            *q += d;
            max_delta = max_delta.max(d.abs());
        }
        damping *= 0.5;
        if max_delta < config.convergence_threshold {
            did_converge = true;
            break;
        }
    }
    let total_charge: f64 = charges.iter().sum();
    Ok(ChargeResult {
        charges,
        iterations: actual_iters,
        total_charge,
        converged: did_converge,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the solver with a 6-iteration budget and unwraps the result.
    fn charges_for(elements: &[u8], bonds: &[(usize, usize)], formal: &[i8]) -> ChargeResult {
        gasteiger_marsili_charges(elements, bonds, formal, 6).unwrap()
    }

    /// Neutral-atom electronegativity χ(0) for atomic number `z`.
    fn neutral_chi(z: u8) -> f64 {
        get_gasteiger_params(z).unwrap().chi(0.0)
    }

    #[test]
    fn test_h2_symmetric_charges() {
        let res = charges_for(&[1, 1], &[(0, 1)], &[0, 0]);
        assert_eq!(res.charges.len(), 2);
        assert!((res.charges[0] - res.charges[1]).abs() < 1e-6);
        assert!(res.charges[0].abs() < 0.01);
    }

    #[test]
    fn test_water_oxygen_negative() {
        let res = charges_for(&[8, 1, 1], &[(0, 1), (0, 2)], &[0, 0, 0]);
        let (q_o, q_h1, q_h2) = (res.charges[0], res.charges[1], res.charges[2]);
        assert!(q_o < -0.1, "O charge should be negative: {}", q_o);
        assert!(q_h1 > 0.0);
        assert!(q_h2 > 0.0);
        assert!(res.total_charge.abs() < 1e-6);
    }

    #[test]
    fn test_methane_symmetric() {
        let res = charges_for(
            &[6, 1, 1, 1, 1],
            &[(0, 1), (0, 2), (0, 3), (0, 4)],
            &[0, 0, 0, 0, 0],
        );
        let hydrogens = &res.charges[1..];
        let reference = hydrogens[0];
        assert!(
            hydrogens.iter().all(|&q| (q - reference).abs() < 1e-10),
            "H charges should be equal"
        );
        assert!(res.total_charge.abs() < 1e-6);
    }

    #[test]
    fn test_co2_carbon_positive() {
        let res = charges_for(&[6, 8, 8], &[(0, 1), (0, 2)], &[0, 0, 0]);
        let q_c = res.charges[0];
        assert!(q_c > 0.1, "C in CO₂ should be positive: {}", q_c);
        assert!(res.charges[1] < -0.05);
        assert!((res.charges[1] - res.charges[2]).abs() < 1e-10);
    }

    #[test]
    fn test_hf_fluorine_negative() {
        let res = charges_for(&[1, 9], &[(0, 1)], &[0, 0]);
        let (q_h, q_f) = (res.charges[0], res.charges[1]);
        assert!(q_f < -0.1, "F should be very negative: {}", q_f);
        assert!(q_h > 0.1);
        assert!(res.total_charge.abs() < 1e-6);
    }

    #[test]
    fn test_unsupported_element() {
        // Helium (Z = 2) has no entry in the parameter table.
        assert!(gasteiger_marsili_charges(&[2], &[], &[0], 6).is_err());
    }

    #[test]
    fn test_formal_charge_preserved() {
        // Ammonium: the +1 formal charge on N must survive redistribution.
        let res = charges_for(
            &[7, 1, 1, 1, 1],
            &[(0, 1), (0, 2), (0, 3), (0, 4)],
            &[1, 0, 0, 0, 0],
        );
        assert!(
            (res.total_charge - 1.0).abs() < 0.1,
            "Total charge should be ~+1: {}",
            res.total_charge
        );
    }

    #[test]
    fn test_electronegativity_order() {
        // Neutral-atom electronegativities must be strictly F > O > N > C > H.
        let descending = [9, 8, 7, 6, 1].map(neutral_chi);
        assert!(descending.windows(2).all(|pair| pair[0] > pair[1]));
    }
}