use serde::{Deserialize, Serialize};
use crate::eht;
/// Electronic-structure method selected for an orbital-mesh computation.
///
/// The variant decides which solver supplies the orbital energies, HOMO index
/// and HOMO–LUMO gap in [`OrbitalMeshResult`]. In the `compute_*_mesh`
/// functions of this module, the volumetric grid and isosurface are always
/// evaluated from EHT orbital coefficients, whichever variant is chosen.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MeshMethod {
    /// Extended Hückel theory via `eht::solve_eht`.
    Eht,
    /// PM3 via `crate::pm3::solve_pm3`.
    Pm3,
    /// xTB via `crate::xtb::solve_xtb`.
    Xtb,
    /// HF-3c via `crate::hf::api::solve_hf3c` with a default `HfConfig`.
    Hf3c,
}
/// Output of [`compute_orbital_mesh`]: an orbital isosurface plus the
/// electronic-structure data reported by the selected method.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrbitalMeshResult {
    /// Isosurface extracted by `eht::marching_cubes` from `grid` at the requested isovalue.
    pub mesh: eht::IsosurfaceMesh,
    /// Orbital values sampled on a regular grid, always evaluated from EHT coefficients.
    pub grid: eht::VolumetricGrid,
    /// Method that supplied `orbital_energies`, `homo_index` and `gap`.
    pub method: MeshMethod,
    /// Index of the molecular orbital that was rendered (echoed from the request).
    pub mo_index: usize,
    /// Index of the HOMO within `orbital_energies`.
    /// For EHT it comes from the solver. For PM3/xTB/HF-3c it is derived from
    /// an electron count assuming two electrons per orbital.
    pub homo_index: usize,
    /// Orbital energies reported by the selected method's solver.
    /// NOTE(review): units are not visible here (eV vs. Hartree) — confirm per solver.
    pub orbital_energies: Vec<f64>,
    /// HOMO–LUMO gap reported by the selected method.
    /// For HF-3c it is computed locally and is 0.0 when no LUMO exists.
    pub gap: f64,
}
/// Builds an orbital isosurface using Extended Hückel theory end to end.
///
/// A single EHT solution supplies the MO coefficients for the grid and the
/// orbital energies, HOMO index and gap in the result.
///
/// * `mo_index` – orbital to render, indexing the EHT coefficient set.
/// * `spacing` / `padding` – grid step and box margin passed to
///   `eht::evaluate_orbital_on_grid` (presumably in the same length unit as
///   `positions` — confirm).
/// * `isovalue` – surface level passed to `eht::marching_cubes`.
///
/// # Errors
/// Propagates the error string returned by `eht::solve_eht`.
fn compute_eht_mesh(
    elements: &[u8],
    positions: &[[f64; 3]],
    mo_index: usize,
    spacing: f64,
    padding: f64,
    isovalue: f32,
) -> Result<OrbitalMeshResult, String> {
    let result = eht::solve_eht(elements, positions, None)?;
    let basis = eht::basis::build_basis(elements, positions);
    let grid = eht::evaluate_orbital_on_grid(
        &basis,
        &result.coefficients,
        mo_index,
        positions,
        spacing,
        padding,
    );
    let mesh = eht::marching_cubes(&grid, isovalue);
    Ok(OrbitalMeshResult {
        mesh,
        grid,
        method: MeshMethod::Eht,
        mo_index,
        homo_index: result.homo_index,
        // `result` is not used after this point, so move the energies
        // instead of cloning the whole vector.
        orbital_energies: result.energies,
        gap: result.gap,
    })
}
/// Builds an orbital isosurface and reports PM3 orbital energetics.
///
/// PM3 supplies the orbital energies, the HOMO–LUMO gap and the electron
/// count used for the HOMO index. The grid and mesh are evaluated from
/// *EHT* coefficients in the EHT basis, so `mo_index` selects an EHT
/// orbital, not a PM3 one.
/// NOTE(review): presumably PM3 coefficients are not exposed in a form usable
/// by `eht::evaluate_orbital_on_grid` — confirm.
///
/// # Errors
/// Propagates the error string from `crate::pm3::solve_pm3` or `eht::solve_eht`.
fn compute_pm3_mesh(
    elements: &[u8],
    positions: &[[f64; 3]],
    mo_index: usize,
    spacing: f64,
    padding: f64,
    isovalue: f32,
) -> Result<OrbitalMeshResult, String> {
    let pm3 = crate::pm3::solve_pm3(elements, positions)?;
    let eht_result = eht::solve_eht(elements, positions, None)?;
    let basis = eht::basis::build_basis(elements, positions);
    let grid = eht::evaluate_orbital_on_grid(
        &basis,
        &eht_result.coefficients,
        mo_index,
        positions,
        spacing,
        padding,
    );
    let mesh = eht::marching_cubes(&grid, isovalue);
    // Closed-shell filling (two electrons per orbital) puts the HOMO at
    // 0-based index n/2 - 1. With n == 1, n / 2 is 0 and a plain `- 1`
    // underflows `usize`. That panics in debug and wraps to usize::MAX in
    // release, so saturate at 0 instead. n == 0 still yields 0 as before.
    let homo_idx = (pm3.n_electrons / 2).saturating_sub(1);
    Ok(OrbitalMeshResult {
        mesh,
        grid,
        method: MeshMethod::Pm3,
        mo_index,
        homo_index: homo_idx,
        orbital_energies: pm3.orbital_energies,
        gap: pm3.gap,
    })
}
/// Builds an orbital isosurface and reports xTB orbital energetics.
///
/// xTB supplies the orbital energies, the HOMO–LUMO gap and the electron
/// count used for the HOMO index. The grid and mesh are evaluated from
/// *EHT* coefficients in the EHT basis, so `mo_index` selects an EHT
/// orbital, not an xTB one.
/// NOTE(review): presumably xTB coefficients are not exposed in a form usable
/// by `eht::evaluate_orbital_on_grid` — confirm.
///
/// # Errors
/// Propagates the error string from `crate::xtb::solve_xtb` or `eht::solve_eht`.
fn compute_xtb_mesh(
    elements: &[u8],
    positions: &[[f64; 3]],
    mo_index: usize,
    spacing: f64,
    padding: f64,
    isovalue: f32,
) -> Result<OrbitalMeshResult, String> {
    let xtb = crate::xtb::solve_xtb(elements, positions)?;
    let eht_result = eht::solve_eht(elements, positions, None)?;
    let basis = eht::basis::build_basis(elements, positions);
    let grid = eht::evaluate_orbital_on_grid(
        &basis,
        &eht_result.coefficients,
        mo_index,
        positions,
        spacing,
        padding,
    );
    let mesh = eht::marching_cubes(&grid, isovalue);
    // Closed-shell filling puts the HOMO at 0-based index n/2 - 1.
    // `saturating_sub` prevents the `usize` underflow the old `n / 2 - 1`
    // hit for a single electron (n / 2 == 0). n == 0 still yields 0.
    let homo_idx = (xtb.n_electrons / 2).saturating_sub(1);
    Ok(OrbitalMeshResult {
        mesh,
        grid,
        method: MeshMethod::Xtb,
        mo_index,
        homo_index: homo_idx,
        orbital_energies: xtb.orbital_energies,
        gap: xtb.gap,
    })
}
/// Builds an orbital isosurface and reports HF-3c orbital energetics.
///
/// `crate::hf::api::solve_hf3c` runs with `HfConfig::default()` and supplies
/// the orbital energies. The HOMO index and gap are derived locally. The grid
/// and mesh are evaluated from *EHT* coefficients, so `mo_index` selects an
/// EHT orbital.
///
/// # Errors
/// Propagates the error string from `solve_hf3c` or `eht::solve_eht`.
fn compute_hf3c_mesh(
    elements: &[u8],
    positions: &[[f64; 3]],
    mo_index: usize,
    spacing: f64,
    padding: f64,
    isovalue: f32,
) -> Result<OrbitalMeshResult, String> {
    let config = crate::hf::HfConfig::default();
    let hf = crate::hf::api::solve_hf3c(elements, positions, &config)?;
    let eht_result = eht::solve_eht(elements, positions, None)?;
    let basis = eht::basis::build_basis(elements, positions);
    let grid = eht::evaluate_orbital_on_grid(
        &basis,
        &eht_result.coefficients,
        mo_index,
        positions,
        spacing,
        padding,
    );
    let mesh = eht::marching_cubes(&grid, isovalue);
    // Electron count of the neutral molecule: the sum of atomic numbers.
    // NOTE(review): this ignores molecular charge and any core electrons the
    // HF module might replace with ECPs — confirm it matches what
    // `solve_hf3c` actually occupies.
    let n_electrons: usize = elements.iter().map(|&z| usize::from(z)).sum();
    // Closed-shell HOMO at n/2 - 1. Saturate so a single electron
    // (e.g. a lone H atom) does not underflow `usize`.
    let homo_idx = (n_electrons / 2).saturating_sub(1);
    // LUMO - HOMO when both orbitals exist, otherwise 0.0.
    let energies = &hf.orbital_energies;
    let gap = match (energies.get(homo_idx), energies.get(homo_idx + 1)) {
        (Some(&homo), Some(&lumo)) => lumo - homo,
        _ => 0.0,
    };
    Ok(OrbitalMeshResult {
        mesh,
        grid,
        method: MeshMethod::Hf3c,
        mo_index,
        homo_index: homo_idx,
        orbital_energies: hf.orbital_energies,
        gap,
    })
}
/// Computes an orbital isosurface mesh for a molecule using `method`.
///
/// * `elements` – atomic numbers, one per atom.
/// * `positions` – Cartesian coordinates, one `[x, y, z]` per atom, in the
///   same order as `elements`.
/// * `method` – solver whose orbital energies / HOMO index / gap are reported.
///   The grid itself is always evaluated from EHT coefficients.
/// * `mo_index` – orbital to render.
/// * `spacing` – grid step; must be positive and finite.
/// * `padding` – margin added around the molecule when sizing the grid.
/// * `isovalue` – surface level for marching cubes.
///
/// # Errors
/// Returns `Err` if `elements` and `positions` differ in length, if `spacing`
/// is not a positive finite number, or if the selected solver fails.
pub fn compute_orbital_mesh(
    elements: &[u8],
    positions: &[[f64; 3]],
    method: MeshMethod,
    mo_index: usize,
    spacing: f64,
    padding: f64,
    isovalue: f32,
) -> Result<OrbitalMeshResult, String> {
    // Every solver pairs elements[i] with positions[i]. Reject inconsistent
    // input here rather than relying on each solver to notice.
    if elements.len() != positions.len() {
        return Err(format!(
            "elements and positions length mismatch: {} vs {}",
            elements.len(),
            positions.len()
        ));
    }
    // A zero, negative or NaN step cannot describe a grid. Reject it before
    // it reaches the grid evaluator.
    if !(spacing.is_finite() && spacing > 0.0) {
        return Err(format!(
            "grid spacing must be a positive finite number, got {}",
            spacing
        ));
    }
    match method {
        MeshMethod::Eht => {
            compute_eht_mesh(elements, positions, mo_index, spacing, padding, isovalue)
        }
        MeshMethod::Pm3 => {
            compute_pm3_mesh(elements, positions, mo_index, spacing, padding, isovalue)
        }
        MeshMethod::Xtb => {
            compute_xtb_mesh(elements, positions, mo_index, spacing, padding, isovalue)
        }
        MeshMethod::Hf3c => {
            compute_hf3c_mesh(elements, positions, mo_index, spacing, padding, isovalue)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Water molecule: O at the origin, two H atoms in the yz-plane.
    /// NOTE(review): coordinates look like Ångström — confirm against solver units.
    fn water_positions() -> (Vec<u8>, Vec<[f64; 3]>) {
        (
            vec![8, 1, 1],
            vec![[0.0, 0.0, 0.0], [0.0, 0.757, 0.587], [0.0, -0.757, 0.587]],
        )
    }

    /// Runs `compute_orbital_mesh` on water with `method` and checks the basics.
    /// On failure, the solver's error message is printed instead of being
    /// swallowed by `assert!(result.is_ok())`.
    fn assert_water_mesh(method: MeshMethod) {
        let (elements, positions) = water_positions();
        let r = compute_orbital_mesh(&elements, &positions, method, 0, 0.4, 3.0, 0.02)
            .unwrap_or_else(|e| panic!("{:?} orbital mesh failed: {}", method, e));
        assert!(r.grid.num_points() > 0);
        assert_eq!(r.method, method);
        assert_eq!(r.mo_index, 0);
    }

    #[test]
    fn test_eht_mesh_water() {
        assert_water_mesh(MeshMethod::Eht);
    }

    #[test]
    fn test_pm3_mesh_water() {
        assert_water_mesh(MeshMethod::Pm3);
    }

    #[test]
    fn test_xtb_mesh_water() {
        assert_water_mesh(MeshMethod::Xtb);
    }

    #[test]
    fn test_hf3c_mesh_water() {
        assert_water_mesh(MeshMethod::Hf3c);
    }
}