use std::f64::consts::PI;
use chemx_core::{Molecule, units::BOLTZMANN_HT};
use chemx_linalg::symmetric_eigh;
use crate::frequencies::FrequencyResult;
/// Boltzmann constant, J/K (exact SI value).
const KB_J: f64 = 1.380_649e-23;
/// Planck constant, J·s (exact SI value).
const H_J: f64 = 6.626_070_15e-34;
/// Speed of light in vacuum, cm/s — converts wavenumbers in cm⁻¹ to Hz.
const C_CM: f64 = 2.997_924_58e10;
/// Atomic mass constant, kg per amu (CODATA 2018).
const AMU_KG: f64 = 1.660_539_066_60e-27;
/// Bohr radius, m (CODATA 2018).
const BOHR_M: f64 = 5.291_772_109_03e-11;
/// Hartree energy, J per Hartree (CODATA 2018).
const EH_J: f64 = 4.359_744_722_207_1e-18;
/// Reference pressure for the translational partition function: 1 atm expressed in Pa.
const STD_P_PA: f64 = 101_325.0;
/// Result of an ideal-gas RRHO thermochemistry calculation (see `rrho_thermochemistry`).
///
/// Energies are in Hartree and entropy in Hartree/K; every `*_corr` field is a correction
/// to be added to the electronic energy, while `enthalpy` and `gibbs` are totals.
#[derive(Debug, Clone, PartialEq)]
pub struct ThermoResult {
    /// Temperature in kelvin.
    pub temperature: f64,
    /// Rotational symmetry number σ used in the rotational partition function.
    pub symmetry_number: u32,
    /// Harmonic zero-point vibrational energy, Hartree.
    pub zpe: f64,
    /// ZPE plus thermal vibrational, rotational and translational energy, Hartree.
    pub thermal_energy_corr: f64,
    /// `thermal_energy_corr + k_B·T` (the ideal-gas PV term), Hartree.
    pub enthalpy_corr: f64,
    /// Electronic energy plus `enthalpy_corr`.
    pub enthalpy: f64,
    /// Total entropy (vibrational + rotational + translational + electronic), Hartree/K.
    pub entropy: f64,
    /// `enthalpy − temperature · entropy`.
    pub gibbs: f64,
    /// Real vibrational frequencies (cm⁻¹) that entered the vibrational sums.
    pub vib_frequencies_cm1: Vec<f64>,
    /// Principal moments of inertia in amu·bohr² (the units `rrho_thermochemistry`
    /// converts to SI with `AMU_KG · BOHR_M²`).
    pub moments_of_inertia: Vec<f64>,
    /// Whether the molecule was classified as linear.
    pub is_linear: bool,
}
/// Ideal-gas rigid-rotor / harmonic-oscillator (RRHO) thermochemistry for one molecule.
///
/// Sums translational (Sackur–Tetrode at `STD_P_PA`), rigid-rotor rotational, harmonic
/// vibrational and electronic (spin degeneracy only) contributions.
///
/// # Arguments
/// * `molecule` – geometry; positions are converted to SI with `BOHR_M`, so they are
///   presumably in bohr (TODO confirm against `Molecule`).
/// * `freq_result` – harmonic frequencies in cm⁻¹. The list is assumed to still contain the
///   3 (atom), 5 (linear) or 6 (non-linear) translational/rotational modes; those are
///   removed as the modes of smallest magnitude, and remaining non-positive (imaginary)
///   modes are ignored.
/// * `electronic_energy` – added to `enthalpy_corr` to form `enthalpy` (presumably Hartree).
/// * `temperature` – kelvin.
/// * `symmetry_number` – rotational symmetry number σ; must be ≥ 1 (0 divides by zero).
/// * `multiplicity` – spin multiplicity; enters as `k_B ln(multiplicity)`, must be ≥ 1.
///
/// # Returns
/// A [`ThermoResult`] with energies in Hartree and entropy in Hartree/K.
pub fn rrho_thermochemistry(
    molecule: &Molecule,
    freq_result: &FrequencyResult,
    electronic_energy: f64,
    temperature: f64,
    symmetry_number: u32,
    multiplicity: u32,
) -> ThermoResult {
    let kt_j = KB_J * temperature;
    let kt_eh = BOLTZMANN_HT * temperature;
    let (moments, is_linear) = principal_moments(molecule);
    // A single atom has no rotational degrees of freedom. Treating it as a linear rotor
    // would take ln(0) of its zero moment of inertia and make S and G infinite.
    let is_atom = molecule.atoms.len() == 1;

    let total_mass_kg: f64 = molecule
        .atoms
        .iter()
        .map(|a| a.element.mass() * AMU_KG)
        .sum();
    let (s_trans_j, u_trans_j) = translational_contribution(total_mass_kg, kt_j);
    let (s_rot_j, u_rot_j) = if is_atom {
        (0.0, 0.0)
    } else {
        rotational_contribution(&moments, is_linear, symmetry_number, kt_j)
    };

    let n_trans_rot = if is_atom {
        3
    } else if is_linear {
        5
    } else {
        6
    };
    let all_cm1: Vec<f64> = freq_result.frequencies_cm1.iter().copied().collect();
    let vib_cm1 = select_vibrational_modes(&all_cm1, n_trans_rot);
    let (zpe_j, u_vib_thermal_j, s_vib_j) = vibrational_contribution(&vib_cm1, kt_j);

    let s_elec_j = KB_J * (multiplicity as f64).ln();

    let to_eh = |j: f64| j / EH_J;
    let zpe = to_eh(zpe_j);
    let thermal_energy_corr = zpe + to_eh(u_vib_thermal_j) + to_eh(u_rot_j) + to_eh(u_trans_j);
    // H = U + PV = U + k_B·T per ideal-gas molecule.
    let enthalpy_corr = thermal_energy_corr + kt_eh;
    let enthalpy = electronic_energy + enthalpy_corr;
    let entropy = to_eh(s_vib_j) + to_eh(s_rot_j) + to_eh(s_trans_j) + to_eh(s_elec_j);
    let gibbs = enthalpy - temperature * entropy;

    ThermoResult {
        temperature,
        symmetry_number,
        zpe,
        thermal_energy_corr,
        enthalpy_corr,
        enthalpy,
        entropy,
        gibbs,
        vib_frequencies_cm1: vib_cm1,
        moments_of_inertia: moments,
        is_linear,
    }
}

/// Sackur–Tetrode translational entropy (J/K) and thermal energy (J) per molecule of an
/// ideal gas with molecular mass `mass_kg` at pressure `STD_P_PA`.
fn translational_contribution(mass_kg: f64, kt_j: f64) -> (f64, f64) {
    // q_trans / N = (2π m k_B T / h²)^{3/2} · k_B T / P
    let q = (2.0 * PI * mass_kg * kt_j).powf(1.5) * kt_j / (H_J.powi(3) * STD_P_PA);
    (KB_J * (2.5 + q.ln()), 1.5 * kt_j)
}

/// Classical rigid-rotor entropy (J/K) and thermal energy (J) per molecule.
///
/// `moments` are principal moments in amu·bohr², assumed sorted ascending (index 2 is the
/// largest, used as the moment of a linear rotor).
fn rotational_contribution(
    moments: &[f64],
    is_linear: bool,
    symmetry_number: u32,
    kt_j: f64,
) -> (f64, f64) {
    let sigma = symmetry_number as f64;
    if is_linear {
        let i_si = moments[2] * AMU_KG * BOHR_M * BOHR_M;
        let q = 8.0 * PI * PI * i_si * kt_j / (H_J * H_J * sigma);
        (KB_J * (1.0 + q.ln()), kt_j)
    } else {
        let i_si: Vec<f64> = moments
            .iter()
            .map(|m| m * AMU_KG * BOHR_M * BOHR_M)
            .collect();
        let prod_i = i_si[0] * i_si[1] * i_si[2];
        let q = PI.sqrt() / sigma
            * (8.0 * PI * PI * kt_j / (H_J * H_J)).powf(1.5)
            * prod_i.sqrt();
        (KB_J * (1.5 + q.ln()), 1.5 * kt_j)
    }
}

/// Removes the `n_trans_rot` modes of smallest magnitude (translations/rotations) and any
/// non-positive (imaginary or zero) frequencies, keeping the survivors in input order.
///
/// Selecting by magnitude instead of position matters when an imaginary mode (negative
/// value) sorts ahead of the near-zero trans/rot modes: a positional skip would then let a
/// trans/rot residual through as a spurious, entropy-dominating "vibration".
fn select_vibrational_modes(freqs_cm1: &[f64], n_trans_rot: usize) -> Vec<f64> {
    let mut by_magnitude: Vec<usize> = (0..freqs_cm1.len()).collect();
    // Stable sort: among equal magnitudes, earlier modes are treated as trans/rot first.
    by_magnitude.sort_by(|&a, &b| freqs_cm1[a].abs().total_cmp(&freqs_cm1[b].abs()));
    let mut is_trans_rot = vec![false; freqs_cm1.len()];
    for &idx in by_magnitude.iter().take(n_trans_rot) {
        is_trans_rot[idx] = true;
    }
    freqs_cm1
        .iter()
        .zip(&is_trans_rot)
        .filter(|&(&f, &trans_rot)| !trans_rot && f > 0.0)
        .map(|(&f, _)| f)
        .collect()
}

/// Harmonic-oscillator zero-point energy (J), thermal energy above the ZPE (J) and entropy
/// (J/K) per molecule, summed over the positive frequencies `vib_cm1` (cm⁻¹).
fn vibrational_contribution(vib_cm1: &[f64], kt_j: f64) -> (f64, f64, f64) {
    let mut zpe_j = 0.0;
    let mut u_thermal_j = 0.0;
    let mut s_j = 0.0;
    for &nu_cm1 in vib_cm1 {
        let nu_hz = nu_cm1 * C_CM;
        let quantum_j = H_J * nu_hz;
        let x = quantum_j / kt_j;
        // exp_m1 keeps e^x − 1 and 1 − e^{−x} accurate for soft modes (small x).
        let em1 = x.exp_m1();
        zpe_j += 0.5 * quantum_j;
        u_thermal_j += quantum_j / em1;
        s_j += KB_J * (x / em1 - (-(-x).exp_m1()).ln());
    }
    (zpe_j, u_thermal_j, s_j)
}
/// Principal moments of inertia about the centre of mass, sorted ascending, plus a
/// linearity flag.
///
/// Masses come from `element.mass()` and lengths from `atom.position`, so the moments are
/// in those units. `rrho_thermochemistry` converts them to SI with `AMU_KG` and `BOHR_M`,
/// so that means amu·bohr² (confirm the position units against `Molecule`).
///
/// Negative eigenvalues caused by round-off are clamped to zero. The molecule is reported
/// as linear when its smallest moment is below 1e-4 in these units. A single atom (all
/// moments zero) is therefore also reported as linear.
fn principal_moments(molecule: &Molecule) -> (Vec<f64>, bool) {
    let masses: Vec<f64> = molecule.atoms.iter().map(|a| a.element.mass()).collect();
    let total_mass: f64 = masses.iter().sum();

    let mut com = [0.0f64; 3];
    for (&m, atom) in masses.iter().zip(molecule.atoms.iter()) {
        for (k, c) in com.iter_mut().enumerate() {
            *c += m * atom.position[k];
        }
    }
    for c in &mut com {
        *c /= total_mass;
    }

    // Inertia tensor I_ab = Σ m (δ_ab r² − r_a r_b), r measured from the centre of mass.
    let mut imat = [[0.0f64; 3]; 3];
    for (&m, atom) in masses.iter().zip(molecule.atoms.iter()) {
        let r = [
            atom.position[0] - com[0],
            atom.position[1] - com[1],
            atom.position[2] - com[2],
        ];
        let r2 = r[0] * r[0] + r[1] * r[1] + r[2] * r[2];
        for a in 0..3 {
            for b in 0..3 {
                let delta = if a == b { 1.0 } else { 0.0 };
                imat[a][b] += m * (delta * r2 - r[a] * r[b]);
            }
        }
    }

    let flat: Vec<f64> = imat.iter().flat_map(|row| row.iter().copied()).collect();
    let imat_fmat = chemx_linalg::mat_from_row_major(3, &flat);
    let eigh = symmetric_eigh(&imat_fmat);
    let mut moments: Vec<f64> = eigh.values.iter().map(|&v| v.max(0.0)).collect();
    // Callers use moments[0] (linearity test) and moments[2] (linear-rotor moment), which
    // assumes ascending order. Enforce it here instead of relying on the eigensolver.
    moments.sort_by(|a, b| a.total_cmp(b));
    let is_linear = moments[0] < 1e-4;
    (moments, is_linear)
}