use nalgebra::{DMatrix, DVector, SymmetricEigen};
use serde::{Deserialize, Serialize};
use super::params::{analyze_eht_support, EhtSupport};
/// Outcome of an Extended Hückel calculation as returned by [`solve_eht`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EhtResult {
    /// Molecular-orbital energies, sorted in ascending order.
    pub energies: Vec<f64>,
    /// MO coefficients stored row by row: `coefficients[i][j]` is the
    /// coefficient of basis function `i` in molecular orbital `j`
    /// (molecular orbitals are the columns of the eigenvector matrix).
    pub coefficients: Vec<Vec<f64>>,
    /// Total valence-electron count used to fill the orbitals.
    pub n_electrons: usize,
    /// Index into `energies` of the highest occupied molecular orbital.
    /// For an odd electron count this is the singly occupied orbital.
    pub homo_index: usize,
    /// Index into `energies` of the lowest unoccupied molecular orbital;
    /// falls back to `homo_index` when no unoccupied orbital exists.
    pub lumo_index: usize,
    /// `energies[homo_index]`.
    pub homo_energy: f64,
    /// `energies[lumo_index]`.
    pub lumo_energy: f64,
    /// HOMO-LUMO gap, computed as `lumo_energy - homo_energy`.
    pub gap: f64,
    /// Parameter-coverage metadata for the input elements, as reported by
    /// `analyze_eht_support`.
    pub support: EhtSupport,
}
/// Solves the generalized symmetric eigenproblem `H C = S C E` by Löwdin
/// (symmetric) orthogonalization.
///
/// The overlap matrix is diagonalized to build `S^{-1/2}`. The transformed
/// Hamiltonian `H' = S^{-1/2} H S^{-1/2}` is then diagonalized as an ordinary
/// symmetric problem. Its eigenvectors are back-transformed with `C = S^{-1/2} C'`.
///
/// # Returns
///
/// `(energies, c)` where `energies` is sorted in ascending order and column `j`
/// of `c` is the eigenvector belonging to `energies[j]`. On the well-conditioned
/// subspace the columns satisfy `Cᵀ S C = I`.
///
/// Overlap eigenvalues at or below `1e-10` are treated as linear dependencies.
/// Their contribution to `S^{-1/2}` is zeroed, so each such direction yields
/// an eigenpair with (near-)zero energy and a (near-)zero coefficient column
/// instead of blowing up numerically.
///
/// NaN eigenvalues do not cause a panic; they are ordered according to
/// [`f64::total_cmp`].
///
/// # Panics
///
/// Panics if `h` and `s` are not square matrices of the same dimension.
pub fn solve_generalized_eigenproblem(
    h: &DMatrix<f64>,
    s: &DMatrix<f64>,
) -> (DVector<f64>, DMatrix<f64>) {
    // Overlap eigenvalues at or below this are treated as linearly dependent.
    const OVERLAP_EIGEN_TOL: f64 = 1e-10;

    let n = h.nrows();
    assert!(
        h.ncols() == n && s.nrows() == n && s.ncols() == n,
        "H ({}x{}) and S ({}x{}) must be square matrices of the same size",
        h.nrows(),
        h.ncols(),
        s.nrows(),
        s.ncols()
    );

    // S^{-1/2} = V diag(1/sqrt(lambda)) V^T, with near-singular directions dropped.
    let s_eigen = SymmetricEigen::new(s.clone());
    let inv_sqrt_vals = s_eigen.eigenvalues.map(|val| {
        if val > OVERLAP_EIGEN_TOL {
            1.0 / val.sqrt()
        } else {
            0.0
        }
    });
    let s_vecs = &s_eigen.eigenvectors;
    let s_inv_sqrt = s_vecs * DMatrix::from_diagonal(&inv_sqrt_vals) * s_vecs.transpose();

    // Ordinary symmetric eigenproblem in the orthogonalized basis.
    let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
    let h_eigen = SymmetricEigen::new(h_prime);
    let energies = h_eigen.eigenvalues;
    // Back-transform to the original (non-orthogonal) basis.
    let c = &s_inv_sqrt * h_eigen.eigenvectors;

    // nalgebra does not sort eigenvalues. `total_cmp` gives a total order, so a NaN
    // cannot panic the sort the way `partial_cmp(..).unwrap()` would.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| energies[a].total_cmp(&energies[b]));

    let sorted_energies = DVector::from_iterator(n, order.iter().map(|&i| energies[i]));
    let sorted_c = c.select_columns(order.iter());
    (sorted_energies, sorted_c)
}
/// Sums the valence electrons of every atom in `elements` (atomic numbers).
///
/// Atomic numbers without an entry in the table contribute zero.
fn count_valence_electrons(elements: &[u8]) -> usize {
    /// Valence-electron count for a single atomic number, grouped by count.
    fn valence_of(atomic_number: u8) -> usize {
        match atomic_number {
            1 => 1,
            5 | 21 | 39 => 3,
            6 | 14 | 22 | 40 | 72 => 4,
            7 | 15 | 23 | 41 | 73 => 5,
            8 | 16 | 24 | 42 | 74 => 6,
            9 | 17 | 35 | 53 | 25 | 43 | 75 => 7,
            26 | 44 | 76 => 8,
            27 | 45 | 77 => 9,
            28 | 46 | 78 => 10,
            29 | 47 | 79 => 11,
            30 | 48 | 80 => 12,
            _ => 0,
        }
    }

    elements
        .iter()
        .fold(0, |total, &atomic_number| total + valence_of(atomic_number))
}
/// Runs an Extended Hückel calculation for the molecule described by
/// `elements` (atomic numbers) and `positions` (one `[x, y, z]` per atom).
///
/// The function builds the valence basis, the overlap matrix `S` and the
/// Hamiltonian `H`, then solves `H C = S C E`. `k` is forwarded unchanged to
/// `build_hamiltonian`. Orbitals are filled two electrons at a time from the
/// bottom, so an odd electron count leaves the HOMO singly occupied.
///
/// # Errors
///
/// Returns `Err` in three cases:
/// - `elements` and `positions` differ in length.
/// - `analyze_eht_support` reports unsupported elements. The message is its
///   warnings joined by spaces.
/// - The generated basis contains no orbitals.
pub fn solve_eht(
    elements: &[u8],
    positions: &[[f64; 3]],
    k: Option<f64>,
) -> Result<EhtResult, String> {
    use super::basis::build_basis;
    use super::hamiltonian::build_hamiltonian;
    use super::overlap::build_overlap_matrix;

    if elements.len() != positions.len() {
        return Err("Element and position arrays must have equal length".to_string());
    }
    let support = analyze_eht_support(elements);
    if !support.unsupported_elements.is_empty() {
        return Err(support.warnings.join(" "));
    }
    let basis = build_basis(elements, positions);
    if basis.is_empty() {
        return Err("No valence orbitals found for given elements".to_string());
    }

    let s = build_overlap_matrix(&basis);
    let h = build_hamiltonian(&basis, &s, k);
    let (energies, c) = solve_generalized_eigenproblem(&h, &s);

    let n_electrons = count_valence_electrons(elements);
    // Non-zero: an empty basis was rejected above.
    let n_orbitals = basis.len();
    // An odd electron count rounds up, so the singly occupied orbital counts as occupied.
    let n_occupied = n_electrons.div_ceil(2);
    // HOMO is the last occupied orbital. With more electron pairs than orbitals,
    // every orbital is filled and the HOMO is the top one, not orbital 0.
    // With no electrons there is no occupied orbital; index 0 is kept as the fallback.
    let homo_idx = n_occupied.clamp(1, n_orbitals) - 1;
    // LUMO is the first empty orbital, or the HOMO itself when none is empty.
    let lumo_idx = if n_occupied < n_orbitals {
        n_occupied
    } else {
        homo_idx
    };
    let homo_energy = energies[homo_idx];
    let lumo_energy = energies[lumo_idx];

    // Row-major copy of C: coefficients[i][j] is the weight of basis function i
    // in molecular orbital j (MOs are the columns of C).
    let coefficients: Vec<Vec<f64>> = (0..n_orbitals)
        .map(|row| (0..n_orbitals).map(|col| c[(row, col)]).collect())
        .collect();

    Ok(EhtResult {
        energies: energies.iter().copied().collect(),
        coefficients,
        n_electrons,
        homo_index: homo_idx,
        lumo_index: lumo_idx,
        homo_energy,
        lumo_energy,
        gap: lumo_energy - homo_energy,
        support,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared geometries reused by several tests.
    const H2_ELEMENTS: [u8; 2] = [1, 1];
    const H2_POSITIONS: [[f64; 3]; 2] = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
    const WATER_ELEMENTS: [u8; 3] = [8, 1, 1];
    const WATER_POSITIONS: [[f64; 3]; 3] =
        [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];

    /// Solves EHT for the shared H2 geometry with default parameters.
    fn solve_h2() -> EhtResult {
        solve_eht(&H2_ELEMENTS, &H2_POSITIONS, None).unwrap()
    }

    /// Solves EHT for the shared water geometry with default parameters.
    fn solve_water() -> EhtResult {
        solve_eht(&WATER_ELEMENTS, &WATER_POSITIONS, None).unwrap()
    }

    #[test]
    fn test_h2_two_eigenvalues() {
        let result = solve_h2();
        assert_eq!(result.energies.len(), 2);
        assert!(result.energies[0] < result.energies[1]);
        assert_eq!(result.homo_index, 0);
        assert_eq!(result.lumo_index, 1);
        assert!(result.gap > 0.0, "H2 HOMO-LUMO gap should be positive");
    }

    #[test]
    fn test_h2_energies_sorted() {
        let energies = solve_h2().energies;
        for (lower, pair) in energies.windows(2).enumerate() {
            assert!(
                pair[1] >= pair[0],
                "Energies not sorted: E[{}]={} < E[{}]={}",
                lower + 1,
                pair[1],
                lower,
                pair[0]
            );
        }
    }

    #[test]
    fn test_h2_coefficients_shape() {
        let coefficients = solve_h2().coefficients;
        assert_eq!(coefficients.len(), 2);
        assert_eq!(coefficients[0].len(), 2);
    }

    #[test]
    fn test_h2o_six_orbitals() {
        let result = solve_water();
        assert_eq!(result.energies.len(), 6);
        assert_eq!(result.n_electrons, 8);
        assert_eq!(result.homo_index, 3);
        assert_eq!(result.lumo_index, 4);
    }

    #[test]
    fn test_h2o_gap_positive() {
        let gap = solve_water().gap;
        assert!(gap > 0.0, "H2O HOMO-LUMO gap = {} should be > 0", gap);
    }

    #[test]
    fn test_lowdin_preserves_orthogonality() {
        use super::super::basis::build_basis;
        use super::super::hamiltonian::build_hamiltonian;
        use super::super::overlap::build_overlap_matrix;

        let basis = build_basis(&WATER_ELEMENTS, &WATER_POSITIONS);
        let s = build_overlap_matrix(&basis);
        let h = build_hamiltonian(&basis, &s, None);
        let (_, c) = solve_generalized_eigenproblem(&h, &s);

        // The MO coefficients must be S-orthonormal: C^T S C == I.
        let gram = c.transpose() * &s * &c;
        let dim = gram.nrows();
        for (i, j) in (0..dim).flat_map(|i| (0..dim).map(move |j| (i, j))) {
            let expected = if i == j { 1.0 } else { 0.0 };
            let actual = gram[(i, j)];
            assert!(
                (actual - expected).abs() < 1e-8,
                "C^T S C[{},{}] = {}, expected {}",
                i,
                j,
                actual,
                expected,
            );
        }
    }

    #[test]
    fn test_error_mismatched_arrays() {
        let positions = [[0.0, 0.0, 0.0]];
        assert!(solve_eht(&H2_ELEMENTS, &positions, None).is_err());
    }

    #[test]
    fn test_valence_electron_count() {
        assert_eq!(count_valence_electrons(&[1, 1]), 2);
        assert_eq!(count_valence_electrons(&[8, 1, 1]), 8);
        assert_eq!(count_valence_electrons(&[6, 1, 1, 1, 1]), 8);
        assert_eq!(count_valence_electrons(&[7, 1, 1, 1]), 8);
    }

    #[test]
    fn test_valence_electron_count_transition_metals() {
        // (atomic number, expected valence electrons)
        let table: [(u8, usize); 18] = [
            (21, 3),
            (22, 4),
            (26, 8),
            (28, 10),
            (29, 11),
            (30, 12),
            (39, 3),
            (46, 10),
            (47, 11),
            (48, 12),
            (72, 4),
            (73, 5),
            (74, 6),
            (76, 8),
            (77, 9),
            (78, 10),
            (79, 11),
            (80, 12),
        ];
        for &(z, expected) in table.iter() {
            assert_eq!(count_valence_electrons(&[z]), expected);
        }
    }

    #[test]
    fn test_cisplatin_has_even_electron_count() {
        let cisplatin = [78u8, 17, 17, 7, 7, 1, 1, 1, 1, 1, 1];
        assert_eq!(count_valence_electrons(&cisplatin), 40);
    }

    #[test]
    fn test_transition_metal_support_metadata() {
        let result = solve_eht(&[26u8], &[[0.0, 0.0, 0.0]], None).unwrap();
        assert!(result.support.has_transition_metals);
        assert_eq!(result.support.provisional_elements, vec![26]);
        assert!(!result.support.warnings.is_empty());
    }

    #[test]
    fn test_unsupported_element_reports_capability_error() {
        let error = solve_eht(&[118u8], &[[0.0, 0.0, 0.0]], None).unwrap_err();
        assert!(error.contains("No EHT parameters are available"));
    }
}