#[cfg(feature = "extended")]
use numpy::PyArray3;
#[cfg(feature = "extended")]
use numpy::ndarray::ShapeBuilder;
use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArrayDyn};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use crate::worker;
use crate::worker_protocol::{Command, Payload};
/// Entry point of the `cp2k_rs` Python extension module.
///
/// Registers the two module-level functions (`init_cp2k`, `finalize_cp2k`)
/// and the `PyForceEnv` class on the module object.
#[pymodule]
fn cp2k_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let init_fn = wrap_pyfunction!(init_cp2k, m)?;
    m.add_function(init_fn)?;

    let finalize_fn = wrap_pyfunction!(finalize_cp2k, m)?;
    m.add_function(finalize_fn)?;

    m.add_class::<PyForceEnv>()
}
/// Converts a worker-side error into a Python `RuntimeError` whose message is
/// the error's `Display` text.
fn worker_err(e: worker::WorkerError) -> PyErr {
    let message = format!("{e}");
    PyRuntimeError::new_err(message)
}
pub fn ipc_call(py: Python, command: Command) -> PyResult<Payload> {
py.detach(|| worker::ipc_call(command).map_err(worker_err))
}
/// Looks for a `cp2k_rs_worker` file next to the imported `cp2k_rs` module.
///
/// The directory is derived from the module's `__file__` attribute. Returns
/// `None` when the module cannot be imported, `__file__` is missing or not a
/// string, the path has no parent directory, or no regular file named
/// `cp2k_rs_worker` exists there.
fn find_worker_binary_in_package(py: Python) -> Option<std::path::PathBuf> {
    let module_file = py
        .import("cp2k_rs")
        .and_then(|m| m.getattr("__file__"))
        .and_then(|f| f.extract::<String>())
        .ok()?;
    let candidate = std::path::Path::new(&module_file)
        .parent()?
        .join("cp2k_rs_worker");
    // `is_file` rather than `exists`: a directory that happens to carry the
    // worker's name must not be returned as the binary, otherwise the caller
    // never falls back to its other lookup and fails later at spawn time.
    candidate.is_file().then_some(candidate)
}
/// Locates the worker binary and starts the worker.
///
/// The binary is searched first next to the installed `cp2k_rs` package and
/// then via `worker::find_worker_binary`. All remaining arguments are passed
/// through unchanged to `worker::start_worker`, which is executed while
/// detached from the Python interpreter.
///
/// Raises `RuntimeError` when no worker binary is found or when
/// `worker::start_worker` reports an error.
#[pyfunction]
#[pyo3(signature = (nproc=1, launcher_cmd=None, env=None, working_dir=None, connect_timeout=120.0))]
pub fn init_cp2k(
    py: Python,
    nproc: u32,
    launcher_cmd: Option<Vec<String>>,
    env: Option<std::collections::HashMap<String, String>>,
    working_dir: Option<String>,
    connect_timeout: f64,
) -> PyResult<()> {
    let located = match find_worker_binary_in_package(py) {
        Some(path) => Some(path),
        None => worker::find_worker_binary(),
    };
    let Some(worker_bin) = located else {
        return Err(PyRuntimeError::new_err(
            "cp2k_rs_worker binary not found. \
             Set CP2K_WORKER_BIN or ensure the binary is on PATH.",
        ));
    };

    py.detach(move || {
        let started = worker::start_worker(
            worker_bin,
            Some(nproc),
            launcher_cmd,
            env,
            working_dir,
            connect_timeout,
        );
        started.map_err(worker_err)
    })
}
/// Stops the worker via `worker::stop_worker`, detached from the Python
/// interpreter for the duration of the call.
///
/// Raises `RuntimeError` when `worker::stop_worker` reports an error.
#[pyfunction]
pub fn finalize_cp2k(py: Python) -> PyResult<()> {
    py.detach(|| {
        let outcome = worker::stop_worker();
        outcome.map_err(worker_err)
    })
}
/// Python-visible handle whose methods forward commands to the worker.
///
/// The struct carries no fields: every method sends a `Command` through
/// `ipc_call` and converts the returned `Payload` into a Python value.
#[pyclass]
pub struct PyForceEnv;
#[pymethods]
impl PyForceEnv {
    /// Sends `Command::InitForceEnv` with the given input and output file
    /// names and returns the handle once the worker call succeeds.
    ///
    /// Raises `RuntimeError` when the worker call fails.
    #[new]
    fn new(py: Python, input_file: String, output_file: String) -> PyResult<Self> {
        ipc_call(
            py,
            Command::InitForceEnv {
                input: input_file,
                output: output_file,
            },
        )?;
        Ok(PyForceEnv)
    }

    /// Sends `Command::CalcEnergyForce`; the reply payload is discarded.
    fn calc_energy_force(&self, py: Python) -> PyResult<()> {
        ipc_call(py, Command::CalcEnergyForce).map(|_| ())
    }

    /// Sends `Command::CalcEnergy`; the reply payload is discarded.
    fn calc_energy(&self, py: Python) -> PyResult<()> {
        ipc_call(py, Command::CalcEnergy).map(|_| ())
    }

    /// Returns the non-negative integer reply to `Command::GetNatom`.
    ///
    /// Raises `RuntimeError` for any other reply (including a negative
    /// integer) or when the value does not fit in `usize`.
    fn get_natom(&self, py: Python) -> PyResult<usize> {
        count_from_payload("get_natom", ipc_call(py, Command::GetNatom)?)
    }

    /// Returns the non-negative integer reply to `Command::GetNparticle`.
    ///
    /// Raises `RuntimeError` for any other reply (including a negative
    /// integer) or when the value does not fit in `usize`.
    fn get_nparticle(&self, py: Python) -> PyResult<usize> {
        count_from_payload("get_nparticle", ipc_call(py, Command::GetNparticle)?)
    }

    /// Returns the float reply to `Command::GetPotentialEnergy`.
    fn get_potential_energy(&self, py: Python) -> PyResult<f64> {
        match ipc_call(py, Command::GetPotentialEnergy)? {
            Payload::Float(e) => Ok(e),
            p => Err(unexpected_payload("get_potential_energy", &p)),
        }
    }

    /// Returns the 1-D array reply to `Command::GetPositions` as a NumPy array.
    fn get_positions<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
        array1_from_payload(py, "get_positions", ipc_call(py, Command::GetPositions)?)
    }

    /// Returns the 1-D array reply to `Command::GetForces` as a NumPy array.
    fn get_forces<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
        array1_from_payload(py, "get_forces", ipc_call(py, Command::GetForces)?)
    }

    /// Returns the 2-D array reply to `Command::GetCell` as a NumPy array.
    fn get_cell<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
        array2_from_payload(py, "get_cell", ipc_call(py, Command::GetCell)?)
    }

    /// Returns the 2-D array reply to `Command::GetQmmmCell` as a NumPy array.
    fn get_qmmm_cell<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
        array2_from_payload(py, "get_qmmm_cell", ipc_call(py, Command::GetQmmmCell)?)
    }

    /// Flattens `positions` (any dimensionality, logical iteration order) and
    /// sends it with `Command::SetPositions`.
    fn set_positions(&self, py: Python, positions: PyReadonlyArrayDyn<f64>) -> PyResult<()> {
        let data: Vec<f64> = positions.as_array().iter().copied().collect();
        ipc_call(py, Command::SetPositions { data })?;
        Ok(())
    }

    /// Flattens `velocities` (any dimensionality, logical iteration order)
    /// and sends it with `Command::SetVelocities`.
    fn set_velocities(&self, py: Python, velocities: PyReadonlyArrayDyn<f64>) -> PyResult<()> {
        let data: Vec<f64> = velocities.as_array().iter().copied().collect();
        ipc_call(py, Command::SetVelocities { data })?;
        Ok(())
    }

    /// Sends a 3×3 `cell`, flattened in logical iteration order, with
    /// `Command::SetCell`.
    ///
    /// Raises `RuntimeError` when `cell` does not have shape `(3, 3)`.
    fn set_cell(&self, py: Python, cell: PyReadonlyArrayDyn<f64>) -> PyResult<()> {
        let arr = cell.as_array();
        if arr.shape() != [3, 3] {
            return Err(PyRuntimeError::new_err("Cell must be a 3×3 array"));
        }
        let data: Vec<f64> = arr.iter().copied().collect();
        ipc_call(py, Command::SetCell { data })?;
        Ok(())
    }

    /// Returns the integer reply to `Command::GetMoCount`.
    ///
    /// Raises `RuntimeError` for any other reply or when the value does not
    /// fit in a 32-bit signed integer.
    fn get_mo_count(&self, py: Python) -> PyResult<i32> {
        i32_from_payload("get_mo_count", ipc_call(py, Command::GetMoCount)?)
    }

    /// Returns the boolean reply to `Command::IsQuickstep`.
    #[cfg(feature = "extended")]
    fn is_quickstep(&self, py: Python) -> PyResult<bool> {
        match ipc_call(py, Command::IsQuickstep)? {
            Payload::Bool(b) => Ok(b),
            p => Err(unexpected_payload("is_quickstep", &p)),
        }
    }

    /// Returns the 2-D array reply to `Command::GetStressTensor` as a NumPy array.
    #[cfg(feature = "extended")]
    fn get_stress_tensor<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
        array2_from_payload(
            py,
            "get_stress_tensor",
            ipc_call(py, Command::GetStressTensor)?,
        )
    }

    /// Returns the 2-D array reply to `Command::GetVirialTensor` as a NumPy array.
    #[cfg(feature = "extended")]
    fn get_virial_tensor<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
        array2_from_payload(
            py,
            "get_virial_tensor",
            ipc_call(py, Command::GetVirialTensor)?,
        )
    }

    /// Returns the non-negative integer reply to `Command::GetNmo` for `spin`.
    ///
    /// Raises `RuntimeError` for any other reply (including a negative
    /// integer) or when the value does not fit in `usize`.
    #[cfg(feature = "extended")]
    fn get_nmo(&self, py: Python, spin: i32) -> PyResult<usize> {
        count_from_payload("get_nmo", ipc_call(py, Command::GetNmo { spin })?)
    }

    /// Returns the 1-D array reply to `Command::GetEigenvalues` for `spin`.
    #[cfg(feature = "extended")]
    fn get_eigenvalues<'py>(
        &self,
        py: Python<'py>,
        spin: i32,
    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
        array1_from_payload(
            py,
            "get_eigenvalues",
            ipc_call(py, Command::GetEigenvalues { spin })?,
        )
    }

    /// Returns the 1-D array reply to `Command::GetOccupationNumbers` for `spin`.
    #[cfg(feature = "extended")]
    fn get_occupation_numbers<'py>(
        &self,
        py: Python<'py>,
        spin: i32,
    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
        array1_from_payload(
            py,
            "get_occupation_numbers",
            ipc_call(py, Command::GetOccupationNumbers { spin })?,
        )
    }

    /// Returns `(homo, lumo, homo_idx, lumo_idx)` from the reply to
    /// `Command::GetHomoLumo` for `spin`.
    #[cfg(feature = "extended")]
    fn get_homo_lumo(&self, py: Python, spin: i32) -> PyResult<(f64, f64, i32, i32)> {
        match ipc_call(py, Command::GetHomoLumo { spin })? {
            Payload::HomoLumo {
                homo,
                lumo,
                homo_idx,
                lumo_idx,
            } => Ok((homo, lumo, homo_idx, lumo_idx)),
            p => Err(unexpected_payload("get_homo_lumo", &p)),
        }
    }

    /// Returns `(lumo - homo)` from the reply to `Command::GetHomoLumo` for
    /// `spin`, multiplied by a fixed conversion factor.
    #[cfg(feature = "extended")]
    fn get_band_gap(&self, py: Python, spin: i32) -> PyResult<f64> {
        // NOTE(review): 27.2114 is presumably the Hartree -> eV factor, which
        // would mean the worker reports orbital energies in Hartree — confirm.
        // The value is kept exactly as before so results do not change.
        const ENERGY_CONVERSION: f64 = 27.2114;
        match ipc_call(py, Command::GetHomoLumo { spin })? {
            Payload::HomoLumo { homo, lumo, .. } => Ok((lumo - homo) * ENERGY_CONVERSION),
            p => Err(unexpected_payload("get_band_gap", &p)),
        }
    }

    /// Returns the 1-D array reply to `Command::GetMullikenCharges`.
    #[cfg(feature = "extended")]
    fn get_mulliken_charges<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
        array1_from_payload(
            py,
            "get_mulliken_charges",
            ipc_call(py, Command::GetMullikenCharges)?,
        )
    }

    /// Returns the 1-D array reply to `Command::GetHirshfeldCharges`.
    #[cfg(feature = "extended")]
    fn get_hirshfeld_charges<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
        array1_from_payload(
            py,
            "get_hirshfeld_charges",
            ipc_call(py, Command::GetHirshfeldCharges)?,
        )
    }

    /// Returns the 1-D array reply to `Command::GetDipoleMoment`.
    #[cfg(feature = "extended")]
    fn get_dipole_moment<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
        array1_from_payload(
            py,
            "get_dipole_moment",
            ipc_call(py, Command::GetDipoleMoment)?,
        )
    }

    /// Returns `(nsteps, converged, energy_change)` from the reply to
    /// `Command::GetScfInfo`.
    #[cfg(feature = "extended")]
    fn get_scf_info(&self, py: Python) -> PyResult<(i32, bool, f64)> {
        match ipc_call(py, Command::GetScfInfo)? {
            Payload::ScfInfo {
                nsteps,
                converged,
                energy_change,
            } => Ok((nsteps, converged, energy_change)),
            p => Err(unexpected_payload("get_scf_info", &p)),
        }
    }

    /// Returns `(e_kin, e_hartree, e_xc, e_core, e_total)` from the reply to
    /// `Command::GetEnergyComponents`.
    #[cfg(feature = "extended")]
    fn get_energy_components(&self, py: Python) -> PyResult<(f64, f64, f64, f64, f64)> {
        match ipc_call(py, Command::GetEnergyComponents)? {
            Payload::EnergyComponents {
                e_kin,
                e_hartree,
                e_xc,
                e_core,
                e_total,
            } => Ok((e_kin, e_hartree, e_xc, e_core, e_total)),
            p => Err(unexpected_payload("get_energy_components", &p)),
        }
    }

    /// Returns the integer reply to `Command::GetNelectron`.
    ///
    /// Raises `RuntimeError` for any other reply or when the value does not
    /// fit in a 32-bit signed integer.
    #[cfg(feature = "extended")]
    fn get_nelectron(&self, py: Python) -> PyResult<i32> {
        i32_from_payload("get_nelectron", ipc_call(py, Command::GetNelectron)?)
    }

    /// Returns the float reply to `Command::GetFermiEnergy`.
    #[cfg(feature = "extended")]
    fn get_fermi_energy(&self, py: Python) -> PyResult<f64> {
        match ipc_call(py, Command::GetFermiEnergy)? {
            Payload::Float(e) => Ok(e),
            p => Err(unexpected_payload("get_fermi_energy", &p)),
        }
    }

    /// Returns the float reply to `Command::GetTotalSpin`.
    #[cfg(feature = "extended")]
    fn get_total_spin(&self, py: Python) -> PyResult<f64> {
        match ipc_call(py, Command::GetTotalSpin)? {
            Payload::Float(s) => Ok(s),
            p => Err(unexpected_payload("get_total_spin", &p)),
        }
    }

    /// Returns the reply to `Command::GetGridInfo` for `spin` as a dict with
    /// the keys `"npts"`, `"origin"` and `"dh"` (`"dh"` is a list of rows).
    #[cfg(feature = "extended")]
    #[pyo3(signature = (spin = 1))]
    fn get_grid_info(&self, py: Python, spin: i32) -> PyResult<Py<PyAny>> {
        grid_info_to_dict(
            py,
            "get_grid_info",
            ipc_call(py, Command::GetGridInfo { spin })?,
        )
    }

    /// Returns `(grid_info, density)` for `spin`.
    ///
    /// The worker answers `Command::GetElectronDensity` with a shared-memory
    /// descriptor; the data is copied out with `worker::read_shared_array3`
    /// while detached from the interpreter. The grid description is then
    /// requested with `Command::GetGridInfo` and returned as the same dict
    /// that `get_grid_info` produces. The density is built as a 3-D array of
    /// shape `dims` in column-major (Fortran) layout.
    ///
    /// Raises `RuntimeError` on an unexpected reply, a failed shared-memory
    /// read, or when the data length does not match `dims`.
    #[cfg(feature = "extended")]
    #[pyo3(signature = (spin = 1))]
    fn get_electron_density<'py>(
        &self,
        py: Python<'py>,
        spin: i32,
    ) -> PyResult<(Py<PyAny>, Bound<'py, PyArray3<f64>>)> {
        match ipc_call(py, Command::GetElectronDensity { spin })? {
            Payload::SharedArray3 {
                shm_name,
                dims,
                byte_size,
            } => {
                // Copy the data out before issuing the next worker command.
                let data = py.detach(|| {
                    worker::read_shared_array3(&shm_name, dims, byte_size).map_err(worker_err)
                })?;
                let info_dict = match ipc_call(py, Command::GetGridInfo { spin })? {
                    info @ Payload::GridInfo { .. } => {
                        grid_info_to_dict(py, "get_electron_density", info)?
                    }
                    _ => {
                        return Err(PyRuntimeError::new_err(
                            "Failed to get grid info after density retrieval",
                        ));
                    }
                };
                let arr =
                    numpy::ndarray::Array3::from_shape_vec((dims[0], dims[1], dims[2]).f(), data)
                        .map_err(|e| PyRuntimeError::new_err(format!("Array shape error: {e}")))?;
                Ok((info_dict, arr.into_pyarray(py)))
            }
            p => Err(unexpected_payload("get_electron_density", &p)),
        }
    }

    /// Returns `(nao, nmo)` from the reply to `Command::GetMoCoeffInfo` for `spin`.
    #[cfg(feature = "extended")]
    #[pyo3(signature = (spin = 1))]
    fn get_mo_coeff_info(&self, py: Python, spin: i32) -> PyResult<(usize, usize)> {
        match ipc_call(py, Command::GetMoCoeffInfo { spin })? {
            Payload::MoCoeffInfo { nao, nmo } => Ok((nao, nmo)),
            p => Err(unexpected_payload("get_mo_coeff_info", &p)),
        }
    }

    /// Returns the MO coefficient matrix for `spin` as a `(rows, cols)` array
    /// in column-major (Fortran) layout.
    ///
    /// The worker answers `Command::GetMoCoefficients` with a shared-memory
    /// descriptor; the data is copied out with `worker::read_shared_array2`
    /// while detached from the interpreter.
    ///
    /// Raises `RuntimeError` on an unexpected reply, a failed shared-memory
    /// read, or when the data length does not match `rows * cols`.
    #[cfg(feature = "extended")]
    #[pyo3(signature = (spin = 1))]
    fn get_mo_coefficients<'py>(
        &self,
        py: Python<'py>,
        spin: i32,
    ) -> PyResult<Bound<'py, PyArray2<f64>>> {
        match ipc_call(py, Command::GetMoCoefficients { spin })? {
            Payload::SharedArray2 {
                shm_name,
                rows,
                cols,
                byte_size,
            } => {
                let data = py.detach(|| {
                    worker::read_shared_array2(&shm_name, byte_size).map_err(worker_err)
                })?;
                let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols).f(), data)
                    .map_err(|e| PyRuntimeError::new_err(format!("Array shape error: {e}")))?;
                Ok(arr.into_pyarray(py))
            }
            p => Err(unexpected_payload("get_mo_coefficients", &p)),
        }
    }
}

/// Converts a `UInt` or non-negative `Int` payload into `usize`.
///
/// Uses a checked conversion so a value that does not fit is reported as a
/// `RuntimeError` instead of being silently truncated by an `as` cast.
fn count_from_payload(func: &str, payload: Payload) -> PyResult<usize> {
    let out_of_range = |shown: String| {
        PyRuntimeError::new_err(format!("{func}: value {shown} does not fit in usize"))
    };
    match payload {
        Payload::UInt(n) => usize::try_from(n).map_err(|_| out_of_range(n.to_string())),
        Payload::Int(n) if n >= 0 => usize::try_from(n).map_err(|_| out_of_range(n.to_string())),
        p => Err(unexpected_payload(func, &p)),
    }
}

/// Converts an `Int` payload into `i32`, rejecting out-of-range values
/// instead of wrapping them.
fn i32_from_payload(func: &str, payload: Payload) -> PyResult<i32> {
    match payload {
        Payload::Int(n) => i32::try_from(n).map_err(|_| {
            PyRuntimeError::new_err(format!(
                "{func}: value {n} does not fit in a 32-bit signed integer"
            ))
        }),
        p => Err(unexpected_payload(func, &p)),
    }
}

/// Converts an `Array1` payload into a 1-D NumPy array.
fn array1_from_payload<'py>(
    py: Python<'py>,
    func: &str,
    payload: Payload,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    match payload {
        Payload::Array1(v) => Ok(v.into_pyarray(py)),
        p => Err(unexpected_payload(func, &p)),
    }
}

/// Converts an `Array2` payload into a `(rows, cols)` NumPy array.
///
/// Raises `RuntimeError` when the data length does not match `rows * cols`.
fn array2_from_payload<'py>(
    py: Python<'py>,
    func: &str,
    payload: Payload,
) -> PyResult<Bound<'py, PyArray2<f64>>> {
    match payload {
        Payload::Array2 { rows, cols, data } => {
            let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols), data)
                .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
            Ok(arr.into_pyarray(py))
        }
        p => Err(unexpected_payload(func, &p)),
    }
}

/// Converts a `GridInfo` payload into a Python dict with the keys `"npts"`,
/// `"origin"` and `"dh"` (`"dh"` as a list of rows).
#[cfg(feature = "extended")]
fn grid_info_to_dict(py: Python<'_>, func: &str, payload: Payload) -> PyResult<Py<PyAny>> {
    match payload {
        Payload::GridInfo { npts, origin, dh } => {
            let dict = pyo3::types::PyDict::new(py);
            dict.set_item("npts", npts.to_vec())?;
            dict.set_item("origin", origin.to_vec())?;
            let dh_rows: Vec<Vec<f64>> = dh.iter().map(|row| row.to_vec()).collect();
            dict.set_item("dh", dh_rows)?;
            Ok(dict.into_any().unbind())
        }
        p => Err(unexpected_payload(func, &p)),
    }
}
/// Builds the `RuntimeError` raised when the worker answers `func` with a
/// payload variant the caller does not handle; the message embeds the
/// payload's `Debug` representation.
fn unexpected_payload(func: &str, payload: &Payload) -> PyErr {
    let message = format!("{func}: unexpected payload variant {payload:?}");
    PyRuntimeError::new_err(message)
}