Skip to main content

cp2k_rs/
python.rs

//! Python bindings for CP2K – thin PyO3 wrapper around `crate::worker`.
//!
//! All blocking operations release the GIL via `py.detach(...)`, delegating
//! the actual work to the GIL-free functions in [`crate::worker`].
//!
//! # Typical usage
//!
//! ```python
//! import cp2k_rs
//!
//! cp2k_rs.init_cp2k(nproc=4)
//! fe = cp2k_rs.PyForceEnv("input.inp", "output.out")
//! fe.calc_energy_force()
//! e = fe.get_potential_energy()
//! cp2k_rs.finalize_cp2k()
//! ```
17
18#[cfg(feature = "extended")]
19use numpy::PyArray3;
20#[cfg(feature = "extended")]
21use numpy::ndarray::ShapeBuilder;
22use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArrayDyn};
23use pyo3::exceptions::PyRuntimeError;
24use pyo3::prelude::*;
25
26use crate::worker;
27use crate::worker_protocol::{Command, Payload};
28
29// ─── module registration ─────────────────────────────────────────────────────
30
31/// The Python module `cp2k_rs`.
32#[pymodule]
33fn cp2k_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
34    m.add_function(wrap_pyfunction!(init_cp2k, m)?)?;
35    m.add_function(wrap_pyfunction!(finalize_cp2k, m)?)?;
36    m.add_class::<PyForceEnv>()?;
37    Ok(())
38}
39
40// ─── error conversion ─────────────────────────────────────────────────────────
41
42fn worker_err(e: worker::WorkerError) -> PyErr {
43    PyRuntimeError::new_err(e.to_string())
44}
45
46// ─── IPC (GIL-releasing) ─────────────────────────────────────────────────────
47
48/// Send a command and receive the response. Releases the GIL while blocked.
49pub fn ipc_call(py: Python, command: Command) -> PyResult<Payload> {
50    py.detach(|| worker::ipc_call(command).map_err(worker_err))
51}
52
53// ─── binary discovery (Python-specific extension) ────────────────────────────
54
55/// Look for the worker binary next to the installed `cp2k_rs` Python package.
56fn find_worker_binary_in_package(py: Python) -> Option<std::path::PathBuf> {
57    let dir = py
58        .import("cp2k_rs")
59        .and_then(|m| m.getattr("__file__"))
60        .and_then(|f| f.extract::<String>())
61        .ok()
62        .and_then(|file| {
63            std::path::Path::new(&file)
64                .parent()
65                .map(|p| p.to_path_buf())
66        })?;
67    let candidate = dir.join("cp2k_rs_worker");
68    candidate.exists().then_some(candidate)
69}
70
71// ─── init / finalize ─────────────────────────────────────────────────────────
72
73/// Start the MPI worker and return once the socket is ready.
74///
75/// Parameters
76/// ----------
77/// nproc : int, optional
78///     Number of MPI ranks (default 1). Ignored when `launcher_cmd` is given.
79/// launcher_cmd : list[str], optional
80///     Complete launcher command prefix, e.g. ``["srun", "-n", "8"]``.
81/// env : dict[str, str], optional
82///     Extra environment variables forwarded to the worker.
83/// working_dir : str, optional
84///     Working directory for the worker process.
85/// connect_timeout : float, optional
86///     Seconds to wait for the worker to become ready (default 120).
87#[pyfunction]
88#[pyo3(signature = (nproc=1, launcher_cmd=None, env=None, working_dir=None, connect_timeout=120.0))]
89pub fn init_cp2k(
90    py: Python,
91    nproc: u32,
92    launcher_cmd: Option<Vec<String>>,
93    env: Option<std::collections::HashMap<String, String>>,
94    working_dir: Option<String>,
95    connect_timeout: f64,
96) -> PyResult<()> {
97    // Binary lookup: Python package path (requires GIL) first, then standard locations.
98    let worker_bin = find_worker_binary_in_package(py)
99        .or_else(worker::find_worker_binary)
100        .ok_or_else(|| {
101            PyRuntimeError::new_err(
102                "cp2k_rs_worker binary not found. \
103                 Set CP2K_WORKER_BIN or ensure the binary is on PATH.",
104            )
105        })?;
106
107    // Spawning and waiting for readiness: release the GIL.
108    py.detach(|| {
109        worker::start_worker(
110            worker_bin,
111            Some(nproc),
112            launcher_cmd,
113            env,
114            working_dir,
115            connect_timeout,
116        )
117        .map_err(worker_err)
118    })
119}
120
121/// Shut down the MPI worker and clean up resources.
122#[pyfunction]
123pub fn finalize_cp2k(py: Python) -> PyResult<()> {
124    py.detach(|| worker::stop_worker().map_err(worker_err))
125}
126
127// ─── PyForceEnv ──────────────────────────────────────────────────────────────
128
129/// Python wrapper around a CP2K force environment running inside the MPI worker.
130#[pyclass]
131pub struct PyForceEnv;
132
133#[pymethods]
134impl PyForceEnv {
135    /// Create a new force environment.
136    ///
137    /// Parameters
138    /// ----------
139    /// input_file : str
140    ///     Path to the CP2K input file.
141    /// output_file : str
142    ///     Path for CP2K output.
143    #[new]
144    fn new(py: Python, input_file: String, output_file: String) -> PyResult<Self> {
145        ipc_call(
146            py,
147            Command::InitForceEnv {
148                input: input_file,
149                output: output_file,
150            },
151        )?;
152        Ok(PyForceEnv)
153    }
154
155    // ── calculations ────────────────────────────────────────────────────────
156
157    fn calc_energy_force(&self, py: Python) -> PyResult<()> {
158        ipc_call(py, Command::CalcEnergyForce)?;
159        Ok(())
160    }
161
162    fn calc_energy(&self, py: Python) -> PyResult<()> {
163        ipc_call(py, Command::CalcEnergy)?;
164        Ok(())
165    }
166
167    // ── queries ─────────────────────────────────────────────────────────────
168
169    fn get_natom(&self, py: Python) -> PyResult<usize> {
170        match ipc_call(py, Command::GetNatom)? {
171            Payload::UInt(n) => Ok(n as usize),
172            Payload::Int(n) if n >= 0 => Ok(n as usize),
173            p => Err(unexpected_payload("get_natom", &p)),
174        }
175    }
176
177    fn get_nparticle(&self, py: Python) -> PyResult<usize> {
178        match ipc_call(py, Command::GetNparticle)? {
179            Payload::UInt(n) => Ok(n as usize),
180            Payload::Int(n) if n >= 0 => Ok(n as usize),
181            p => Err(unexpected_payload("get_nparticle", &p)),
182        }
183    }
184
185    fn get_potential_energy(&self, py: Python) -> PyResult<f64> {
186        match ipc_call(py, Command::GetPotentialEnergy)? {
187            Payload::Float(e) => Ok(e),
188            p => Err(unexpected_payload("get_potential_energy", &p)),
189        }
190    }
191
192    fn get_positions<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
193        match ipc_call(py, Command::GetPositions)? {
194            Payload::Array1(v) => Ok(v.into_pyarray(py)),
195            p => Err(unexpected_payload("get_positions", &p)),
196        }
197    }
198
199    fn get_forces<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
200        match ipc_call(py, Command::GetForces)? {
201            Payload::Array1(v) => Ok(v.into_pyarray(py)),
202            p => Err(unexpected_payload("get_forces", &p)),
203        }
204    }
205
206    fn get_cell<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
207        match ipc_call(py, Command::GetCell)? {
208            Payload::Array2 { rows, cols, data } => {
209                let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols), data)
210                    .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?;
211                Ok(arr.into_pyarray(py))
212            }
213            p => Err(unexpected_payload("get_cell", &p)),
214        }
215    }
216
217    fn get_qmmm_cell<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
218        match ipc_call(py, Command::GetQmmmCell)? {
219            Payload::Array2 { rows, cols, data } => {
220                let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols), data)
221                    .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?;
222                Ok(arr.into_pyarray(py))
223            }
224            p => Err(unexpected_payload("get_qmmm_cell", &p)),
225        }
226    }
227
228    // ── setters ─────────────────────────────────────────────────────────────
229
230    fn set_positions(&self, py: Python, positions: PyReadonlyArrayDyn<f64>) -> PyResult<()> {
231        let data: Vec<f64> = positions.as_array().iter().cloned().collect();
232        ipc_call(py, Command::SetPositions { data })?;
233        Ok(())
234    }
235
236    fn set_velocities(&self, py: Python, velocities: PyReadonlyArrayDyn<f64>) -> PyResult<()> {
237        let data: Vec<f64> = velocities.as_array().iter().cloned().collect();
238        ipc_call(py, Command::SetVelocities { data })?;
239        Ok(())
240    }
241
242    fn set_cell(&self, py: Python, cell: PyReadonlyArrayDyn<f64>) -> PyResult<()> {
243        let arr = cell.as_array();
244        if arr.shape() != [3, 3] {
245            return Err(PyRuntimeError::new_err("Cell must be a 3×3 array"));
246        }
247        let data: Vec<f64> = arr.iter().cloned().collect();
248        ipc_call(py, Command::SetCell { data })?;
249        Ok(())
250    }
251
252    // ── active space ─────────────────────────────────────────────────────────
253
254    fn get_mo_count(&self, py: Python) -> PyResult<i32> {
255        match ipc_call(py, Command::GetMoCount)? {
256            Payload::Int(n) => Ok(n as i32),
257            p => Err(unexpected_payload("get_mo_count", &p)),
258        }
259    }
260
261    // ── extended interface ───────────────────────────────────────────────────
262
263    #[cfg(feature = "extended")]
264    fn is_quickstep(&self, py: Python) -> PyResult<bool> {
265        match ipc_call(py, Command::IsQuickstep)? {
266            Payload::Bool(b) => Ok(b),
267            p => Err(unexpected_payload("is_quickstep", &p)),
268        }
269    }
270
271    #[cfg(feature = "extended")]
272    fn get_stress_tensor<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
273        match ipc_call(py, Command::GetStressTensor)? {
274            Payload::Array2 { rows, cols, data } => {
275                let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols), data)
276                    .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?;
277                Ok(arr.into_pyarray(py))
278            }
279            p => Err(unexpected_payload("get_stress_tensor", &p)),
280        }
281    }
282
283    #[cfg(feature = "extended")]
284    fn get_virial_tensor<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
285        match ipc_call(py, Command::GetVirialTensor)? {
286            Payload::Array2 { rows, cols, data } => {
287                let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols), data)
288                    .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?;
289                Ok(arr.into_pyarray(py))
290            }
291            p => Err(unexpected_payload("get_virial_tensor", &p)),
292        }
293    }
294
295    #[cfg(feature = "extended")]
296    fn get_nmo(&self, py: Python, spin: i32) -> PyResult<usize> {
297        match ipc_call(py, Command::GetNmo { spin })? {
298            Payload::UInt(n) => Ok(n as usize),
299            Payload::Int(n) if n >= 0 => Ok(n as usize),
300            p => Err(unexpected_payload("get_nmo", &p)),
301        }
302    }
303
304    #[cfg(feature = "extended")]
305    fn get_eigenvalues<'py>(
306        &self,
307        py: Python<'py>,
308        spin: i32,
309    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
310        match ipc_call(py, Command::GetEigenvalues { spin })? {
311            Payload::Array1(v) => Ok(v.into_pyarray(py)),
312            p => Err(unexpected_payload("get_eigenvalues", &p)),
313        }
314    }
315
316    #[cfg(feature = "extended")]
317    fn get_occupation_numbers<'py>(
318        &self,
319        py: Python<'py>,
320        spin: i32,
321    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
322        match ipc_call(py, Command::GetOccupationNumbers { spin })? {
323            Payload::Array1(v) => Ok(v.into_pyarray(py)),
324            p => Err(unexpected_payload("get_occupation_numbers", &p)),
325        }
326    }
327
328    #[cfg(feature = "extended")]
329    fn get_homo_lumo(&self, py: Python, spin: i32) -> PyResult<(f64, f64, i32, i32)> {
330        match ipc_call(py, Command::GetHomoLumo { spin })? {
331            Payload::HomoLumo {
332                homo,
333                lumo,
334                homo_idx,
335                lumo_idx,
336            } => Ok((homo, lumo, homo_idx, lumo_idx)),
337            p => Err(unexpected_payload("get_homo_lumo", &p)),
338        }
339    }
340
341    #[cfg(feature = "extended")]
342    fn get_band_gap(&self, py: Python, spin: i32) -> PyResult<f64> {
343        match ipc_call(py, Command::GetHomoLumo { spin })? {
344            Payload::HomoLumo { homo, lumo, .. } => Ok((lumo - homo) * 27.211386245988),
345            p => Err(unexpected_payload("get_band_gap", &p)),
346        }
347    }
348
349    #[cfg(feature = "extended")]
350    fn get_mulliken_charges<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
351        match ipc_call(py, Command::GetMullikenCharges)? {
352            Payload::Array1(v) => Ok(v.into_pyarray(py)),
353            p => Err(unexpected_payload("get_mulliken_charges", &p)),
354        }
355    }
356
357    #[cfg(feature = "extended")]
358    fn get_hirshfeld_charges<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
359        match ipc_call(py, Command::GetHirshfeldCharges)? {
360            Payload::Array1(v) => Ok(v.into_pyarray(py)),
361            p => Err(unexpected_payload("get_hirshfeld_charges", &p)),
362        }
363    }
364
365    #[cfg(feature = "extended")]
366    fn get_dipole_moment<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
367        match ipc_call(py, Command::GetDipoleMoment)? {
368            Payload::Array1(v) => Ok(v.into_pyarray(py)),
369            p => Err(unexpected_payload("get_dipole_moment", &p)),
370        }
371    }
372
373    #[cfg(feature = "extended")]
374    fn get_scf_info(&self, py: Python) -> PyResult<(i32, bool, f64)> {
375        match ipc_call(py, Command::GetScfInfo)? {
376            Payload::ScfInfo {
377                nsteps,
378                converged,
379                energy_change,
380            } => Ok((nsteps, converged, energy_change)),
381            p => Err(unexpected_payload("get_scf_info", &p)),
382        }
383    }
384
385    #[cfg(feature = "extended")]
386    fn get_energy_components(&self, py: Python) -> PyResult<(f64, f64, f64, f64, f64)> {
387        match ipc_call(py, Command::GetEnergyComponents)? {
388            Payload::EnergyComponents {
389                e_kin,
390                e_hartree,
391                e_xc,
392                e_core,
393                e_total,
394            } => Ok((e_kin, e_hartree, e_xc, e_core, e_total)),
395            p => Err(unexpected_payload("get_energy_components", &p)),
396        }
397    }
398
399    #[cfg(feature = "extended")]
400    fn get_nelectron(&self, py: Python) -> PyResult<i32> {
401        match ipc_call(py, Command::GetNelectron)? {
402            Payload::Int(n) => Ok(n as i32),
403            p => Err(unexpected_payload("get_nelectron", &p)),
404        }
405    }
406
407    #[cfg(feature = "extended")]
408    fn get_fermi_energy(&self, py: Python) -> PyResult<f64> {
409        match ipc_call(py, Command::GetFermiEnergy)? {
410            Payload::Float(e) => Ok(e),
411            p => Err(unexpected_payload("get_fermi_energy", &p)),
412        }
413    }
414
415    #[cfg(feature = "extended")]
416    fn get_total_spin(&self, py: Python) -> PyResult<f64> {
417        match ipc_call(py, Command::GetTotalSpin)? {
418            Payload::Float(s) => Ok(s),
419            p => Err(unexpected_payload("get_total_spin", &p)),
420        }
421    }
422
423    /// Get grid metadata for the electron density.
424    ///
425    /// Returns a dict with keys "npts" ([nx,ny,nz]), "origin" ([x,y,z] in Bohr),
426    /// and "dh" (3x3 cell increment matrix in Bohr).
427    #[cfg(feature = "extended")]
428    #[pyo3(signature = (spin = 1))]
429    fn get_grid_info(&self, py: Python, spin: i32) -> PyResult<Py<PyAny>> {
430        match ipc_call(py, Command::GetGridInfo { spin })? {
431            Payload::GridInfo { npts, origin, dh } => {
432                let dict = pyo3::types::PyDict::new(py);
433                dict.set_item("npts", npts.to_vec())?;
434                dict.set_item("origin", origin.to_vec())?;
435                let dh_list: Vec<Vec<f64>> = dh.iter().map(|row| row.to_vec()).collect();
436                dict.set_item("dh", dh_list)?;
437                Ok(dict.into_any().unbind())
438            }
439            p => Err(unexpected_payload("get_grid_info", &p)),
440        }
441    }
442
443    /// Get the full electron density on the realspace grid.
444    ///
445    /// Returns a tuple (grid_info_dict, density_array) where density_array is
446    /// a numpy array of shape (nx, ny, nz) in electrons/Bohr^3, Fortran order.
447    #[cfg(feature = "extended")]
448    #[pyo3(signature = (spin = 1))]
449    fn get_electron_density<'py>(
450        &self,
451        py: Python<'py>,
452        spin: i32,
453    ) -> PyResult<(Py<PyAny>, Bound<'py, PyArray3<f64>>)> {
454        let payload = ipc_call(py, Command::GetElectronDensity { spin })?;
455        match payload {
456            Payload::SharedArray3 {
457                shm_name,
458                dims,
459                byte_size,
460            } => {
461                // Read data from shared memory
462                let data = py.detach(|| {
463                    worker::read_shared_array3(&shm_name, dims, byte_size).map_err(worker_err)
464                })?;
465
466                // Build grid info dict (make a separate call for metadata)
467                let info_payload = ipc_call(py, Command::GetGridInfo { spin })?;
468                let info_dict = match info_payload {
469                    Payload::GridInfo { npts, origin, dh } => {
470                        let dict = pyo3::types::PyDict::new(py);
471                        dict.set_item("npts", npts.to_vec())?;
472                        dict.set_item("origin", origin.to_vec())?;
473                        let dh_list: Vec<Vec<f64>> = dh.iter().map(|row| row.to_vec()).collect();
474                        dict.set_item("dh", dh_list)?;
475                        dict.into_any().unbind()
476                    }
477                    _ => {
478                        return Err(PyRuntimeError::new_err(
479                            "Failed to get grid info after density retrieval",
480                        ));
481                    }
482                };
483
484                // Create numpy array in Fortran order
485                let arr =
486                    numpy::ndarray::Array3::from_shape_vec((dims[0], dims[1], dims[2]).f(), data)
487                        .map_err(|e| PyRuntimeError::new_err(format!("Array shape error: {e}")))?;
488                Ok((info_dict, arr.into_pyarray(py)))
489            }
490            p => Err(unexpected_payload("get_electron_density", &p)),
491        }
492    }
493
494    /// Get the dimensions of the MO coefficient matrix.
495    ///
496    /// Returns a tuple (nao, nmo) — number of atomic orbitals and molecular orbitals.
497    #[cfg(feature = "extended")]
498    #[pyo3(signature = (spin = 1))]
499    fn get_mo_coeff_info(&self, py: Python, spin: i32) -> PyResult<(usize, usize)> {
500        match ipc_call(py, Command::GetMoCoeffInfo { spin })? {
501            Payload::MoCoeffInfo { nao, nmo } => Ok((nao, nmo)),
502            p => Err(unexpected_payload("get_mo_coeff_info", &p)),
503        }
504    }
505
506    /// Get the full MO coefficient matrix.
507    ///
508    /// Returns a numpy array of shape (nao, nmo) where column j contains the j-th
509    /// molecular orbital expressed in the AO basis. Uses Fortran (column-major) order.
510    #[cfg(feature = "extended")]
511    #[pyo3(signature = (spin = 1))]
512    fn get_mo_coefficients<'py>(
513        &self,
514        py: Python<'py>,
515        spin: i32,
516    ) -> PyResult<Bound<'py, PyArray2<f64>>> {
517        let payload = ipc_call(py, Command::GetMoCoefficients { spin })?;
518        match payload {
519            Payload::SharedArray2 {
520                shm_name,
521                rows,
522                cols,
523                byte_size,
524            } => {
525                let data = py.detach(|| {
526                    worker::read_shared_array2(&shm_name, byte_size).map_err(worker_err)
527                })?;
528
529                let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols).f(), data)
530                    .map_err(|e| PyRuntimeError::new_err(format!("Array shape error: {e}")))?;
531                Ok(arr.into_pyarray(py))
532            }
533            p => Err(unexpected_payload("get_mo_coefficients", &p)),
534        }
535    }
536}
537
538// ─── helpers ─────────────────────────────────────────────────────────────────
539
540fn unexpected_payload(func: &str, payload: &Payload) -> PyErr {
541    PyRuntimeError::new_err(format!("{func}: unexpected payload variant {:?}", payload))
542}