use crate::error::{IntegrateError, IntegrateResult};
use crate::specialized::quantum::gaussian_integrals::build_overlap_matrix;
use crate::specialized::quantum::gaussian_integrals::GaussianBasis;
use crate::specialized::quantum::tdhf::eri::get_eri;
use crate::specialized::quantum::tdhf::scf::{jacobi_diag, HartreeFockSCF, ScfResult};
use scirs2_core::ndarray::{Array1, Array2};
/// Settings for [`CasidaSolver`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CasidaConfig {
    /// Number of excited states reported by [`CasidaSolver::solve`]; every
    /// output vector of [`CasidaResult`] is padded/truncated to this length.
    pub n_roots: usize,
}

impl Default for CasidaConfig {
    /// Requests the five lowest excitations.
    fn default() -> Self {
        Self { n_roots: 5 }
    }
}
/// Output of [`CasidaSolver::solve`]: one entry per requested root, sorted by
/// increasing excitation energy. All three vectors have the same length; when
/// fewer physical roots exist than were requested, the tail is zero-padded.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CasidaResult {
    /// Excitation energies ω_n, in the units of the SCF orbital energies
    /// (presumably Hartree — confirm). Padding entries are `0.0`.
    pub excitation_energies: Vec<f64>,
    /// Oscillator strengths f_n = (2/3) ω_n |μ_0n|². Padding entries are `0.0`.
    pub oscillator_strengths: Vec<f64>,
    /// Transition dipole vectors (x, y, z) for each root. Padding entries are zero.
    pub transition_dipoles: Vec<[f64; 3]>,
}
/// Solver for singlet excitations of a closed-shell Hartree–Fock reference in
/// the Tamm–Dancoff approximation (CIS): only the Casida `A` matrix is built
/// and diagonalised (no `B` matrix).
#[derive(Debug, Clone)]
pub struct CasidaSolver {
    config: CasidaConfig,
}

impl CasidaSolver {
    /// Creates a solver that reports `config.n_roots` excitations.
    pub fn new(config: CasidaConfig) -> Self {
        Self { config }
    }

    /// Computes the lowest singlet excitation energies, oscillator strengths
    /// and transition dipoles of a closed-shell reference.
    ///
    /// `n_electrons` must be even. The first `n_electrons / 2` MOs are treated
    /// as occupied and the remaining `basis.len() - n_electrons / 2` as virtual.
    /// Only strictly positive eigenvalues of `A` are reported, in ascending
    /// order. Every output vector has exactly `config.n_roots` entries, and
    /// missing roots are padded with zeros.
    ///
    /// # Errors
    /// - [`IntegrateError::InvalidInput`] for an odd electron count, when no
    ///   virtual orbitals exist, or when the excitation space is empty.
    /// - Errors from [`Self::build_a_matrix`] and
    ///   `HartreeFockSCF::diagonalize_symmetric` are propagated.
    pub fn solve(
        &self,
        scf_result: &ScfResult,
        basis: &[GaussianBasis],
        eri: &[f64],
        n_electrons: usize,
    ) -> IntegrateResult<CasidaResult> {
        // Closed-shell theory: every occupied spatial orbital holds two electrons,
        // so an odd count would be silently truncated by `n_electrons / 2`.
        if n_electrons % 2 != 0 {
            return Err(IntegrateError::InvalidInput(format!(
                "Casida calculation requires an even (closed-shell) electron count, got {n_electrons}"
            )));
        }
        let n_basis = basis.len();
        let n_occ = n_electrons / 2;
        let n_virt = n_basis.saturating_sub(n_occ);
        if n_virt == 0 {
            return Err(IntegrateError::InvalidInput(
                "No virtual orbitals available for Casida calculation".to_string(),
            ));
        }
        let n_singles = n_occ * n_virt;
        if n_singles == 0 {
            return Err(IntegrateError::InvalidInput(
                "Zero single-excitation space".to_string(),
            ));
        }

        let overlap = build_overlap_matrix(basis);
        let dipole_ao = build_dipole_ao(basis, &overlap);
        let a_mat = Self::build_a_matrix(scf_result, eri, n_basis, n_occ)?;
        let (eigs, vecs) = HartreeFockSCF::diagonalize_symmetric(&a_mat)?;

        // Non-positive eigenvalues (numerical noise or an unstable reference)
        // are discarded. NaN fails `> 0.0`, so `total_cmp` only sees finite/inf values.
        let mut roots: Vec<(f64, usize)> = eigs
            .iter()
            .enumerate()
            .filter(|&(_, &e)| e > 0.0)
            .map(|(idx, &e)| (e, idx))
            .collect();
        roots.sort_by(|lhs, rhs| lhs.0.total_cmp(&rhs.0));

        let n_roots_out = self.config.n_roots;
        let mut excitation_energies = Vec::with_capacity(n_roots_out);
        let mut oscillator_strengths = Vec::with_capacity(n_roots_out);
        let mut transition_dipoles = Vec::with_capacity(n_roots_out);
        for &(omega, root_idx) in roots.iter().take(n_roots_out.min(n_singles)) {
            let x_vec: Vec<f64> = (0..n_singles).map(|s| vecs[[s, root_idx]]).collect();
            // The A matrix is expressed in spin-adapted singlet functions
            // (|Φ_iα^aα> + |Φ_iβ^aβ>)/√2, whose dipole coupling to the closed-shell
            // reference is √2·<i|r|a>. This assumes the eigenvector X has unit norm.
            let tdip = compute_transition_dipole(
                &x_vec,
                n_occ,
                n_virt,
                n_basis,
                &dipole_ao,
                scf_result,
            )
            .map(|component| std::f64::consts::SQRT_2 * component);
            let mu_sq = tdip[0].powi(2) + tdip[1].powi(2) + tdip[2].powi(2);
            let osc = (2.0 / 3.0) * omega * mu_sq;
            excitation_energies.push(omega);
            oscillator_strengths.push(osc.max(0.0));
            transition_dipoles.push(tdip);
        }

        // Pad so callers always receive `n_roots` entries.
        excitation_energies.resize(n_roots_out, 0.0);
        oscillator_strengths.resize(n_roots_out, 0.0);
        transition_dipoles.resize(n_roots_out, [0.0; 3]);

        Ok(CasidaResult {
            excitation_energies,
            oscillator_strengths,
            transition_dipoles,
        })
    }

    /// Builds the singlet Tamm–Dancoff (CIS) `A` matrix of a closed-shell
    /// reference in the occupied→virtual single-excitation space:
    ///
    /// `A_{ia,jb} = δ_ij δ_ab (ε_a − ε_i) + 2 (ia|jb) − (ij|ab)`
    ///
    /// Rows and columns use the compound index `i * n_virt + (a - n_occ)`, with
    /// `i < n_occ` and `n_occ <= a < n_basis`. MO integrals come from `mo_eri`,
    /// which assumes the AO ERIs are in chemists' notation (confirm). Only the
    /// upper triangle is evaluated. The lower triangle is mirrored, which is valid
    /// because the MO coefficients are real, giving `(ia|jb) = (jb|ia)` and
    /// `(ij|ab) = (ji|ba)`. Each element still costs two O(n_basis⁴)
    /// transformations.
    ///
    /// # Errors
    /// [`IntegrateError::InvalidInput`] if the SCF result holds fewer than
    /// `n_basis` orbital energies or MO coefficient rows/columns.
    pub fn build_a_matrix(
        scf: &ScfResult,
        eri: &[f64],
        n_basis: usize,
        n_occ: usize,
    ) -> IntegrateResult<Array2<f64>> {
        let c = &scf.mo_coefficients;
        let eps = &scf.orbital_energies;
        // Indices up to n_basis - 1 are used below; fail cleanly instead of panicking.
        if c.nrows() < n_basis || c.ncols() < n_basis || eps.len() < n_basis {
            return Err(IntegrateError::InvalidInput(format!(
                "SCF result too small for {n_basis} basis functions: MO coefficients are {}x{}, {} orbital energies",
                c.nrows(),
                c.ncols(),
                eps.len()
            )));
        }
        let n_virt = n_basis.saturating_sub(n_occ);
        let n_singles = n_occ * n_virt;

        // (occupied, virtual) MO pairs, in row-index order.
        let excitations: Vec<(usize, usize)> = (0..n_occ)
            .flat_map(|i| (n_occ..n_basis).map(move |a| (i, a)))
            .collect();

        let mut matrix = Array2::zeros((n_singles, n_singles));
        for (row, &(i, a_mo)) in excitations.iter().enumerate() {
            for (col, &(j, b_mo)) in excitations.iter().enumerate().skip(row) {
                let iajb = mo_eri(eri, n_basis, c, i, a_mo, j, b_mo);
                let ijab = mo_eri(eri, n_basis, c, i, j, a_mo, b_mo);
                // Factor 2 on (ia|jb): singlet spin adaptation of the exchange-type term.
                let mut value = 2.0 * iajb - ijab;
                if row == col {
                    value += eps[a_mo] - eps[i];
                }
                matrix[[row, col]] = value;
                matrix[[col, row]] = value;
            }
        }
        Ok(matrix)
    }
}
/// Transforms one two-electron integral from the AO basis to the MO basis:
/// `Σ_{μνλσ} C_{μp} C_{νq} C_{λr} C_{σs} · get_eri(μ, ν, λ, σ)`.
///
/// If `get_eri` returns AO integrals in chemists' notation `(μν|λσ)` (assumed —
/// confirm against the `eri` module), the result is `(pq|rs)`. AO coefficients
/// with magnitude below `1e-14` are screened out. One call performs up to
/// `n_basis⁴` ERI lookups.
fn mo_eri(
    eri: &[f64],
    n_basis: usize,
    c: &Array2<f64>,
    p: usize,
    q: usize,
    r: usize,
    s: usize,
) -> f64 {
    /// Yields `(ao, C[ao, orbital])` in ascending AO order, dropping coefficients
    /// whose magnitude is below the screening threshold.
    fn significant(
        coeffs: &Array2<f64>,
        n_ao: usize,
        orbital: usize,
    ) -> impl Iterator<Item = (usize, f64)> + '_ {
        (0..n_ao)
            .map(move |ao| (ao, coeffs[[ao, orbital]]))
            // Written so that NaN coefficients are kept: only `|C| < 1e-14` is skipped.
            .filter(|&(_, coeff)| coeff.abs() >= 1e-14 || coeff.is_nan())
    }

    let mut total = 0.0;
    for (mu, c_p) in significant(c, n_basis, p) {
        for (nu, c_q) in significant(c, n_basis, q) {
            let pair_pq = c_p * c_q;
            for (lam, c_r) in significant(c, n_basis, r) {
                for (sig, c_s) in significant(c, n_basis, s) {
                    total += pair_pq * get_eri(eri, n_basis, mu, nu, lam, sig) * c_r * c_s;
                }
            }
        }
    }
    total
}
/// Builds approximate AO dipole (position) matrices, one per Cartesian axis.
///
/// Element `[axis][[mu, nu]]` is the midpoint of the two basis-function centres
/// along `axis`, multiplied by the overlap `S_{mu nu}`. NOTE(review): for
/// primitive s-type Gaussians the exact `<mu|r|nu>` uses the exponent-weighted
/// product centre, so this midpoint form is exact only when the two exponents
/// are equal. Confirm whether `GaussianBasis` exposes exponents so this can be
/// made exact.
fn build_dipole_ao(basis: &[GaussianBasis], overlap: &Array2<f64>) -> [Array2<f64>; 3] {
    let n = basis.len();
    std::array::from_fn(|axis| {
        Array2::from_shape_fn((n, n), |(mu, nu)| {
            let midpoint = 0.5 * (basis[mu].center[axis] + basis[nu].center[axis]);
            midpoint * overlap[[mu, nu]]
        })
    })
}
/// Contracts an excitation vector with MO-basis dipole integrals:
/// `Σ_{i,a} X_ia · Σ_{μν} C_{μi} R_{μν} C_{νa}` for each Cartesian component.
///
/// `x_vec` is indexed as `i * n_virt + (a - n_occ)`, with occupied MOs
/// `0..n_occ` and virtual MOs `n_occ..n_occ + n_virt`. Amplitudes with
/// magnitude below `1e-15` are skipped. No spin-adaptation factor is applied
/// here.
fn compute_transition_dipole(
    x_vec: &[f64],
    n_occ: usize,
    n_virt: usize,
    n_basis: usize,
    dipole_ao: &[Array2<f64>; 3],
    scf: &ScfResult,
) -> [f64; 3] {
    let coeffs = &scf.mo_coefficients;

    // <occ| R |vir> for one AO component matrix R.
    let project = |r_ao: &Array2<f64>, occ: usize, vir: usize| -> f64 {
        let mut acc = 0.0;
        for mu in 0..n_basis {
            for nu in 0..n_basis {
                acc += coeffs[[mu, occ]] * r_ao[[mu, nu]] * coeffs[[nu, vir]];
            }
        }
        acc
    };

    // Amplitude-weighted sum over all occupied→virtual pairs for one component.
    let component = |r_ao: &Array2<f64>| -> f64 {
        (0..n_occ)
            .flat_map(|occ| (0..n_virt).map(move |v| (occ, v)))
            .filter_map(|(occ, v)| {
                let amp = x_vec[occ * n_virt + v];
                if amp.abs() < 1e-15 {
                    None
                } else {
                    Some(amp * project(r_ao, occ, n_occ + v))
                }
            })
            .fold(0.0, |sum, term| sum + term)
    };

    [
        component(&dipole_ao[0]),
        component(&dipole_ao[1]),
        component(&dipole_ao[2]),
    ]
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::specialized::quantum::gaussian_integrals::normalized_s_gto;
    use crate::specialized::quantum::tdhf::eri::build_eri_tensor;
    use crate::specialized::quantum::tdhf::scf::{HartreeFockSCF, ScfConfig};

    /// H2 at bond length 1.4: one normalized s-GTO (exponent 1.0) and a unit
    /// nuclear charge on each centre.
    fn minimal_h2() -> (Vec<GaussianBasis>, Vec<(f64, [f64; 3])>) {
        let nuclei = [[0.0_f64, 0.0, 0.0], [0.0_f64, 0.0, 1.4]];
        let basis = nuclei.iter().map(|&at| normalized_s_gto(at, 1.0)).collect();
        let charges = nuclei.iter().map(|&at| (1.0_f64, at)).collect();
        (basis, charges)
    }

    /// Default-configured two-electron RHF reference for [`minimal_h2`],
    /// returned together with its basis.
    fn h2_reference() -> (Vec<GaussianBasis>, ScfResult) {
        let (basis, charges) = minimal_h2();
        let scf = HartreeFockSCF::new(ScfConfig::default())
            .run(&basis, &charges, 2)
            .unwrap();
        (basis, scf)
    }

    /// Runs the Casida solver on H2, requesting `n_roots` excitations.
    fn solve_h2(n_roots: usize) -> CasidaResult {
        let (basis, scf) = h2_reference();
        let eri = build_eri_tensor(&basis);
        CasidaSolver::new(CasidaConfig { n_roots })
            .solve(&scf, &basis, &eri, 2)
            .unwrap()
    }

    #[test]
    fn test_casida_positive_excitation_energies() {
        let result = solve_h2(1);
        for e in result.excitation_energies.iter().copied() {
            assert!(
                e >= 0.0,
                "Excitation energy should be non-negative, got {e}"
            );
        }
    }

    #[test]
    fn test_casida_n_roots_respected() {
        let basis = vec![
            normalized_s_gto([0.0, 0.0, 0.0], 1.5),
            normalized_s_gto([0.0, 0.0, 1.4], 0.8),
            normalized_s_gto([0.0, 0.0, 2.8], 1.2),
        ];
        let charges = vec![
            (1.0_f64, [0.0_f64, 0.0, 0.0]),
            (1.0_f64, [0.0_f64, 0.0, 1.4]),
        ];
        let scf = HartreeFockSCF::new(ScfConfig::default())
            .run(&basis, &charges, 2)
            .unwrap();
        let eri = build_eri_tensor(&basis);
        let result = CasidaSolver::new(CasidaConfig { n_roots: 3 })
            .solve(&scf, &basis, &eri, 2)
            .unwrap();
        let n_energies = result.excitation_energies.len();
        assert_eq!(n_energies, 3, "Expected 3 roots, got {}", n_energies);
        assert_eq!(result.oscillator_strengths.len(), 3);
        assert_eq!(result.transition_dipoles.len(), 3);
    }

    #[test]
    fn test_casida_oscillator_strength_nonneg() {
        let result = solve_h2(2);
        for f in result.oscillator_strengths.iter().copied() {
            assert!(
                f >= 0.0,
                "Oscillator strength must be non-negative, got {f}"
            );
        }
    }

    #[test]
    fn test_casida_a_matrix_diagonal_positive() {
        let (basis, scf) = h2_reference();
        let eri = build_eri_tensor(&basis);
        let n_basis = basis.len();
        let n_occ = 1;
        let a = CasidaSolver::build_a_matrix(&scf, &eri, n_basis, n_occ).unwrap();
        let n_singles = n_occ * (n_basis - n_occ);
        for i in 0..n_singles {
            let diag = a[[i, i]];
            assert!(diag >= -1e-8, "A diagonal[{i}] = {}", diag);
        }
    }
}