use crate::ffi;
#[cfg(feature = "extended")]
use crate::ffi_extended;
use ndarray::{Array1, Array2};
use std::ffi::CString;
use std::os::raw::{c_double, c_int};
use thiserror::Error;
#[cfg(feature = "mpi")]
use mpi::{ffi::MPI_Comm_c2f, raw::AsRaw, topology::SimpleCommunicator};
/// Errors produced by the safe CP2K wrapper in this module.
#[derive(Error, Debug)]
pub enum CP2KError {
    /// A CP2K call reported failure (e.g. a negative status or count) or a returned buffer
    /// could not be shaped into an array; the string says which operation failed.
    #[error("CP2K FFI error: {0}")]
    FFIError(String),
    /// A Rust string handed to CP2K contained an interior NUL byte, so it could not be
    /// converted into a C string (`CString::new` failed).
    #[error("Null byte in string: {0}")]
    NulError(#[from] std::ffi::NulError),
    /// A byte buffer could not be decoded as UTF-8 (`String::from_utf8` failed).
    #[error("UTF-8 conversion error: {0}")]
    Utf8Error(#[from] std::string::FromUtf8Error),
    /// An argument was rejected as invalid; the string describes the problem.
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),
    /// CP2K has not been initialized.
    /// NOTE(review): not constructed in this file — presumably raised by a library
    /// init/finalize wrapper elsewhere; confirm.
    #[error("CP2K not initialized")]
    NotInitialized,
    /// CP2K initialization failed; the string carries the details.
    /// NOTE(review): not constructed in this file — confirm where it is produced.
    #[error("CP2K initialization error: {0}")]
    InitializationError(String),
    /// CP2K finalization failed; the string carries the details.
    /// NOTE(review): not constructed in this file — confirm where it is produced.
    #[error("CP2K finalization error: {0}")]
    FinalizationError(String),
}
/// Result alias used by every fallible operation in this module.
pub type CP2KResult<T> = Result<T, CP2KError>;
/// Owning handle to a CP2K force environment.
///
/// The environment is released with `cp2k_destroy_force_env` when the handle is dropped.
/// For that reason the type is deliberately not `Clone`: a copy would destroy the same
/// environment twice.
///
/// NOTE(review): the handle is a plain integer, so `ForceEnv` is automatically `Send` and
/// `Sync`. Nothing here establishes that CP2K tolerates calls on one environment from
/// several threads — confirm before sharing it across threads.
#[derive(Debug)]
pub struct ForceEnv {
    /// CP2K's integer identifier for this environment.
    id: ffi::force_env_t,
}
impl ForceEnv {
pub fn new(input_file: &str, output_file: &str) -> CP2KResult<Self> {
let input_c = CString::new(input_file)?;
let output_c = CString::new(output_file)?;
let mut id: ffi::force_env_t = 0;
#[cfg(feature = "mpi")]
{
let world = SimpleCommunicator::world();
let fortran_comm = unsafe { MPI_Comm_c2f(world.as_raw()) };
unsafe {
ffi::cp2k_create_force_env_comm(
&mut id as *mut _,
input_c.as_ptr(),
output_c.as_ptr(),
fortran_comm,
);
}
}
#[cfg(not(feature = "mpi"))]
unsafe {
ffi::cp2k_create_force_env(&mut id as *mut _, input_c.as_ptr(), output_c.as_ptr());
}
Ok(ForceEnv { id })
}
#[cfg(feature = "mpi")]
pub fn new_with_mpi(
input_file: &str,
output_file: &str,
comm: &mpi::topology::SimpleCommunicator,
) -> CP2KResult<Self> {
let input_c = CString::new(input_file)?;
let output_c = CString::new(output_file)?;
let mut id: ffi::force_env_t = 0;
let raw_comm = comm.as_raw();
let fortran_comm = unsafe { MPI_Comm_c2f(raw_comm) };
unsafe {
ffi::cp2k_create_force_env_comm(
&mut id as *mut _,
input_c.as_ptr(),
output_c.as_ptr(),
fortran_comm,
);
}
Ok(ForceEnv { id })
}
pub fn set_positions(&mut self, positions: &[f64]) -> CP2KResult<()> {
unsafe {
ffi::cp2k_set_positions(self.id, positions.as_ptr(), positions.len() as c_int);
}
Ok(())
}
pub fn set_velocities(&mut self, velocities: &[f64]) -> CP2KResult<()> {
unsafe {
ffi::cp2k_set_velocities(self.id, velocities.as_ptr(), velocities.len() as c_int);
}
Ok(())
}
pub fn set_cell(&mut self, cell: &[[f64; 3]; 3]) -> CP2KResult<()> {
unsafe {
ffi::cp2k_set_cell(self.id, &cell[0][0] as *const _);
}
Ok(())
}
pub fn get_natom(&self) -> CP2KResult<usize> {
let mut natom: c_int = 0;
unsafe {
ffi::cp2k_get_natom(self.id, &mut natom as *mut _);
}
Ok(natom as usize)
}
pub fn get_nparticle(&self) -> CP2KResult<usize> {
let mut nparticle: c_int = 0;
unsafe {
ffi::cp2k_get_nparticle(self.id, &mut nparticle as *mut _);
}
Ok(nparticle as usize)
}
pub fn get_positions(&self) -> CP2KResult<Array1<f64>> {
let nparticle = self.get_nparticle()?;
let n_el = nparticle * 3;
let mut positions = vec![0.0; n_el];
unsafe {
ffi::cp2k_get_positions(self.id, positions.as_mut_ptr(), n_el as c_int);
}
Ok(Array1::from(positions))
}
pub fn get_forces(&self) -> CP2KResult<Array1<f64>> {
let nparticle = self.get_nparticle()?;
let n_el = nparticle * 3;
let mut forces = vec![0.0; n_el];
unsafe {
ffi::cp2k_get_forces(self.id, forces.as_mut_ptr(), n_el as c_int);
}
Ok(Array1::from(forces))
}
pub fn get_potential_energy(&self) -> CP2KResult<f64> {
let mut energy: c_double = 0.0;
unsafe {
ffi::cp2k_get_potential_energy(self.id, &mut energy as *mut _);
}
Ok(energy)
}
pub fn get_cell(&self) -> CP2KResult<Array2<f64>> {
let mut cell = [[0.0; 3]; 3];
unsafe {
ffi::cp2k_get_cell(self.id, &mut cell[0][0] as *mut _);
}
let flat_cell: Vec<f64> = cell.iter().flatten().copied().collect();
Array2::from_shape_vec((3, 3), flat_cell)
.map_err(|e| CP2KError::FFIError(format!("Array shape error: {e}")))
}
pub fn get_qmmm_cell(&self) -> CP2KResult<Array2<f64>> {
let mut cell = [[0.0; 3]; 3];
unsafe {
ffi::cp2k_get_qmmm_cell(self.id, &mut cell[0][0] as *mut _);
}
let flat_cell: Vec<f64> = cell.iter().flatten().copied().collect();
Array2::from_shape_vec((3, 3), flat_cell)
.map_err(|e| CP2KError::FFIError(format!("Array shape error: {e}")))
}
pub fn calc_energy_force(&mut self) -> CP2KResult<()> {
unsafe {
ffi::cp2k_calc_energy_force(self.id);
}
Ok(())
}
pub fn calc_energy(&mut self) -> CP2KResult<()> {
unsafe {
ffi::cp2k_calc_energy(self.id);
}
Ok(())
}
pub fn get_result(&self, description: &str, n_el: usize) -> CP2KResult<Array1<f64>> {
let desc_c = CString::new(description)?;
let mut result = vec![0.0; n_el];
unsafe {
ffi::cp2k_get_result(self.id, desc_c.as_ptr(), result.as_mut_ptr(), n_el as c_int);
}
Ok(Array1::from(result))
}
pub fn get_mo_count(&self) -> CP2KResult<i32> {
let count = unsafe { ffi::cp2k_active_space_get_mo_count(self.id) };
if count < 0 {
return Err(CP2KError::FFIError("Failed to get MO count".into()));
}
Ok(count)
}
pub fn get_fock_sub(&self) -> CP2KResult<Array2<f64>> {
let mo_count = self.get_mo_count()? as usize;
let buf_len = mo_count * mo_count;
let mut buf = vec![0.0; buf_len];
let nelem = unsafe {
ffi::cp2k_active_space_get_fock_sub(self.id, buf.as_mut_ptr(), buf_len as i64)
};
if nelem < 0 {
return Err(CP2KError::FFIError("Failed to get Fock submatrix".into()));
}
Array2::from_shape_vec((mo_count, mo_count), buf)
.map_err(|e| CP2KError::FFIError(format!("Array shape error: {e}")))
}
pub fn get_eri_nze_count(&self) -> CP2KResult<usize> {
let count = unsafe { ffi::cp2k_active_space_get_eri_nze_count(self.id) };
if count < 0 {
return Err(CP2KError::FFIError(
"Failed to get ERI non-zero element count".into(),
));
}
Ok(count as usize)
}
pub fn get_eri(&self) -> CP2KResult<(Vec<[i32; 4]>, Vec<f64>)> {
let nze_count = self.get_eri_nze_count()?;
let buf_coords_len = 4 * nze_count;
let mut buf_coords = vec![0i32; buf_coords_len];
let mut buf_values = vec![0.0; nze_count];
let nelem = unsafe {
ffi::cp2k_active_space_get_eri(
self.id,
buf_coords.as_mut_ptr(),
buf_coords_len as i64,
buf_values.as_mut_ptr(),
nze_count as i64,
)
};
if nelem < 0 {
return Err(CP2KError::FFIError("Failed to get ERI matrix".into()));
}
let mut coords = Vec::with_capacity(nze_count);
for i in 0..nze_count {
let idx = 4 * i;
coords.push([
buf_coords[idx],
buf_coords[idx + 1],
buf_coords[idx + 2],
buf_coords[idx + 3],
]);
}
Ok((coords, buf_values))
}
#[cfg(feature = "extended")]
pub fn is_quickstep(&self) -> bool {
unsafe { ffi_extended::cp2k_is_qs_env(self.id) != 0 }
}
#[cfg(feature = "extended")]
pub fn get_stress_tensor(&self) -> CP2KResult<Array2<f64>> {
let mut stress = [[0.0; 3]; 3];
unsafe {
ffi_extended::cp2k_get_stress_tensor(self.id, &mut stress[0][0] as *mut _);
}
let flat_stress: Vec<f64> = stress.iter().flatten().copied().collect();
Array2::from_shape_vec((3, 3), flat_stress)
.map_err(|e| CP2KError::FFIError(format!("Array shape error: {}", e)))
}
#[cfg(feature = "extended")]
pub fn get_virial_tensor(&self) -> CP2KResult<Array2<f64>> {
let mut virial = [[0.0; 3]; 3];
unsafe {
ffi_extended::cp2k_get_virial_tensor(self.id, &mut virial[0][0] as *mut _);
}
let flat_virial: Vec<f64> = virial.iter().flatten().copied().collect();
Array2::from_shape_vec((3, 3), flat_virial)
.map_err(|e| CP2KError::FFIError(format!("Array shape error: {}", e)))
}
#[cfg(feature = "extended")]
pub fn get_nmo(&self, spin: i32) -> CP2KResult<usize> {
let nmo = unsafe { ffi_extended::cp2k_get_nmo(self.id, spin as c_int) };
if nmo < 0 {
return Err(CP2KError::FFIError(format!(
"Failed to get number of MOs for spin {}",
spin
)));
}
Ok(nmo as usize)
}
#[cfg(feature = "extended")]
pub fn get_eigenvalues(&self, spin: i32) -> CP2KResult<Array1<f64>> {
let nmo = self.get_nmo(spin)?;
let mut eigenvalues = vec![0.0; nmo];
let n = unsafe {
ffi_extended::cp2k_get_eigenvalues(
self.id,
spin as c_int,
eigenvalues.as_mut_ptr(),
nmo as c_int,
)
};
if n < 0 {
return Err(CP2KError::FFIError(format!(
"Failed to get eigenvalues for spin {}",
spin
)));
}
eigenvalues.truncate(n as usize);
Ok(Array1::from(eigenvalues))
}
#[cfg(feature = "extended")]
pub fn get_occupation_numbers(&self, spin: i32) -> CP2KResult<Array1<f64>> {
let nmo = self.get_nmo(spin)?;
let mut occupations = vec![0.0; nmo];
let n = unsafe {
ffi_extended::cp2k_get_occupation_numbers(
self.id,
spin as c_int,
occupations.as_mut_ptr(),
nmo as c_int,
)
};
if n < 0 {
return Err(CP2KError::FFIError(format!(
"Failed to get occupation numbers for spin {}",
spin
)));
}
occupations.truncate(n as usize);
Ok(Array1::from(occupations))
}
#[cfg(feature = "extended")]
pub fn get_homo_lumo(&self, spin: i32) -> CP2KResult<(f64, f64, i32, i32)> {
let mut homo_energy: c_double = 0.0;
let mut lumo_energy: c_double = 0.0;
let mut homo_index: c_int = 0;
let mut lumo_index: c_int = 0;
let result = unsafe {
ffi_extended::cp2k_get_homo_lumo(
self.id,
spin as c_int,
&mut homo_energy as *mut _,
&mut lumo_energy as *mut _,
&mut homo_index as *mut _,
&mut lumo_index as *mut _,
)
};
if result < 0 {
return Err(CP2KError::FFIError(format!(
"Failed to get HOMO/LUMO for spin {}",
spin
)));
}
Ok((homo_energy, lumo_energy, homo_index, lumo_index))
}
#[cfg(feature = "extended")]
pub fn get_band_gap(&self, spin: i32) -> CP2KResult<f64> {
let (homo, lumo, _, _) = self.get_homo_lumo(spin)?;
const HARTREE_TO_EV: f64 = 27.211386245988;
Ok((lumo - homo) * HARTREE_TO_EV)
}
#[cfg(feature = "extended")]
pub fn get_mulliken_charges(&self) -> CP2KResult<Array1<f64>> {
let natom = self.get_natom()?;
let mut charges = vec![0.0; natom];
let result = unsafe {
ffi_extended::cp2k_get_mulliken_charges(self.id, charges.as_mut_ptr(), natom as c_int)
};
if result < 0 {
return Err(CP2KError::FFIError(
"Failed to get Mulliken charges".to_string(),
));
}
Ok(Array1::from(charges))
}
#[cfg(feature = "extended")]
pub fn get_dipole_moment(&self) -> CP2KResult<Array1<f64>> {
let mut dipole = [0.0; 3];
let result = unsafe { ffi_extended::cp2k_get_dipole_moment(self.id, dipole.as_mut_ptr()) };
if result < 0 {
return Err(CP2KError::FFIError(
"Failed to get dipole moment".to_string(),
));
}
Ok(Array1::from(dipole.to_vec()))
}
#[cfg(feature = "extended")]
pub fn get_scf_info(&self) -> CP2KResult<(i32, bool, f64)> {
let mut niter: c_int = 0;
let mut converged: c_int = 0;
let mut energy_change: c_double = 0.0;
let result = unsafe {
ffi_extended::cp2k_get_scf_info(
self.id,
&mut niter as *mut _,
&mut converged as *mut _,
&mut energy_change as *mut _,
)
};
if result < 0 {
return Err(CP2KError::FFIError("Failed to get SCF info".to_string()));
}
Ok((niter, converged != 0, energy_change))
}
#[cfg(feature = "extended")]
pub fn get_energy_components(&self) -> CP2KResult<(f64, f64, f64, f64, f64)> {
let mut e_kinetic = 0.0;
let mut e_hartree = 0.0;
let mut e_xc = 0.0;
let mut e_core = 0.0;
let mut e_total = 0.0;
let result = unsafe {
ffi_extended::cp2k_get_energy_components(
self.id,
&mut e_kinetic as *mut _,
&mut e_hartree as *mut _,
&mut e_xc as *mut _,
&mut e_core as *mut _,
&mut e_total as *mut _,
)
};
if result < 0 {
return Err(CP2KError::FFIError(
"Failed to get energy components".to_string(),
));
}
Ok((e_kinetic, e_hartree, e_xc, e_core, e_total))
}
#[cfg(feature = "extended")]
pub fn get_nelectron(&self) -> CP2KResult<i32> {
let mut nelectron: i32 = 0;
let result = unsafe { ffi_extended::cp2k_get_nelectron(self.id, &mut nelectron as *mut _) };
if result < 0 {
return Err(CP2KError::FFIError(
"Failed to get number of electrons".to_string(),
));
}
Ok(nelectron)
}
#[cfg(feature = "extended")]
pub fn get_fermi_energy(&self) -> CP2KResult<f64> {
let mut e_fermi = 0.0;
let result =
unsafe { ffi_extended::cp2k_get_fermi_energy(self.id, &mut e_fermi as *mut _) };
if result < 0 {
return Err(CP2KError::FFIError(
"Failed to get Fermi energy (not applicable for this system)".to_string(),
));
}
Ok(e_fermi)
}
#[cfg(feature = "extended")]
pub fn get_hirshfeld_charges(&self) -> CP2KResult<Array1<f64>> {
let natom = self.get_natom()? as i32;
let mut charges = vec![0.0; natom as usize];
let result = unsafe {
ffi_extended::cp2k_get_hirshfeld_charges(self.id, charges.as_mut_ptr(), natom)
};
if result < 0 {
return Err(CP2KError::FFIError(
"Failed to get Hirshfeld charges".to_string(),
));
}
Ok(Array1::from(charges))
}
#[cfg(feature = "extended")]
pub fn get_total_spin(&self) -> CP2KResult<f64> {
let mut total_spin = 0.0;
let result =
unsafe { ffi_extended::cp2k_get_total_spin(self.id, &mut total_spin as *mut _) };
if result < 0 {
return Err(CP2KError::FFIError("Failed to get total spin".to_string()));
}
Ok(total_spin)
}
}
impl Drop for ForceEnv {
    /// Hands the environment handle back to CP2K through `cp2k_destroy_force_env`.
    fn drop(&mut self) {
        let env_id = self.id;
        // SAFETY: `id` is a private field, so it can only hold a handle created in this
        // module. `ForceEnv` is not `Clone`, so this runs once per handle.
        unsafe { ffi::cp2k_destroy_force_env(env_id) };
    }
}