1#[cfg(feature = "extended")]
19use numpy::PyArray3;
20#[cfg(feature = "extended")]
21use numpy::ndarray::ShapeBuilder;
22use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArrayDyn};
23use pyo3::exceptions::PyRuntimeError;
24use pyo3::prelude::*;
25
26use crate::worker;
27use crate::worker_protocol::{Command, Payload};
28
29#[pymodule]
33fn cp2k_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
34 m.add_function(wrap_pyfunction!(init_cp2k, m)?)?;
35 m.add_function(wrap_pyfunction!(finalize_cp2k, m)?)?;
36 m.add_class::<PyForceEnv>()?;
37 Ok(())
38}
39
40fn worker_err(e: worker::WorkerError) -> PyErr {
43 PyRuntimeError::new_err(e.to_string())
44}
45
46pub fn ipc_call(py: Python, command: Command) -> PyResult<Payload> {
50 py.detach(|| worker::ipc_call(command).map_err(worker_err))
51}
52
53fn find_worker_binary_in_package(py: Python) -> Option<std::path::PathBuf> {
57 let dir = py
58 .import("cp2k_rs")
59 .and_then(|m| m.getattr("__file__"))
60 .and_then(|f| f.extract::<String>())
61 .ok()
62 .and_then(|file| {
63 std::path::Path::new(&file)
64 .parent()
65 .map(|p| p.to_path_buf())
66 })?;
67 let candidate = dir.join("cp2k_rs_worker");
68 candidate.exists().then_some(candidate)
69}
70
71#[pyfunction]
88#[pyo3(signature = (nproc=1, launcher_cmd=None, env=None, working_dir=None, connect_timeout=120.0))]
89pub fn init_cp2k(
90 py: Python,
91 nproc: u32,
92 launcher_cmd: Option<Vec<String>>,
93 env: Option<std::collections::HashMap<String, String>>,
94 working_dir: Option<String>,
95 connect_timeout: f64,
96) -> PyResult<()> {
97 let worker_bin = find_worker_binary_in_package(py)
99 .or_else(worker::find_worker_binary)
100 .ok_or_else(|| {
101 PyRuntimeError::new_err(
102 "cp2k_rs_worker binary not found. \
103 Set CP2K_WORKER_BIN or ensure the binary is on PATH.",
104 )
105 })?;
106
107 py.detach(|| {
109 worker::start_worker(
110 worker_bin,
111 Some(nproc),
112 launcher_cmd,
113 env,
114 working_dir,
115 connect_timeout,
116 )
117 .map_err(worker_err)
118 })
119}
120
121#[pyfunction]
123pub fn finalize_cp2k(py: Python) -> PyResult<()> {
124 py.detach(|| worker::stop_worker().map_err(worker_err))
125}
126
127#[pyclass]
131pub struct PyForceEnv;
132
133#[pymethods]
134impl PyForceEnv {
135 #[new]
144 fn new(py: Python, input_file: String, output_file: String) -> PyResult<Self> {
145 ipc_call(
146 py,
147 Command::InitForceEnv {
148 input: input_file,
149 output: output_file,
150 },
151 )?;
152 Ok(PyForceEnv)
153 }
154
155 fn calc_energy_force(&self, py: Python) -> PyResult<()> {
158 ipc_call(py, Command::CalcEnergyForce)?;
159 Ok(())
160 }
161
162 fn calc_energy(&self, py: Python) -> PyResult<()> {
163 ipc_call(py, Command::CalcEnergy)?;
164 Ok(())
165 }
166
167 fn get_natom(&self, py: Python) -> PyResult<usize> {
170 match ipc_call(py, Command::GetNatom)? {
171 Payload::UInt(n) => Ok(n as usize),
172 Payload::Int(n) if n >= 0 => Ok(n as usize),
173 p => Err(unexpected_payload("get_natom", &p)),
174 }
175 }
176
177 fn get_nparticle(&self, py: Python) -> PyResult<usize> {
178 match ipc_call(py, Command::GetNparticle)? {
179 Payload::UInt(n) => Ok(n as usize),
180 Payload::Int(n) if n >= 0 => Ok(n as usize),
181 p => Err(unexpected_payload("get_nparticle", &p)),
182 }
183 }
184
185 fn get_potential_energy(&self, py: Python) -> PyResult<f64> {
186 match ipc_call(py, Command::GetPotentialEnergy)? {
187 Payload::Float(e) => Ok(e),
188 p => Err(unexpected_payload("get_potential_energy", &p)),
189 }
190 }
191
192 fn get_positions<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
193 match ipc_call(py, Command::GetPositions)? {
194 Payload::Array1(v) => Ok(v.into_pyarray(py)),
195 p => Err(unexpected_payload("get_positions", &p)),
196 }
197 }
198
199 fn get_forces<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
200 match ipc_call(py, Command::GetForces)? {
201 Payload::Array1(v) => Ok(v.into_pyarray(py)),
202 p => Err(unexpected_payload("get_forces", &p)),
203 }
204 }
205
206 fn get_cell<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
207 match ipc_call(py, Command::GetCell)? {
208 Payload::Array2 { rows, cols, data } => {
209 let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols), data)
210 .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?;
211 Ok(arr.into_pyarray(py))
212 }
213 p => Err(unexpected_payload("get_cell", &p)),
214 }
215 }
216
217 fn get_qmmm_cell<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
218 match ipc_call(py, Command::GetQmmmCell)? {
219 Payload::Array2 { rows, cols, data } => {
220 let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols), data)
221 .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?;
222 Ok(arr.into_pyarray(py))
223 }
224 p => Err(unexpected_payload("get_qmmm_cell", &p)),
225 }
226 }
227
228 fn set_positions(&self, py: Python, positions: PyReadonlyArrayDyn<f64>) -> PyResult<()> {
231 let data: Vec<f64> = positions.as_array().iter().cloned().collect();
232 ipc_call(py, Command::SetPositions { data })?;
233 Ok(())
234 }
235
236 fn set_velocities(&self, py: Python, velocities: PyReadonlyArrayDyn<f64>) -> PyResult<()> {
237 let data: Vec<f64> = velocities.as_array().iter().cloned().collect();
238 ipc_call(py, Command::SetVelocities { data })?;
239 Ok(())
240 }
241
242 fn set_cell(&self, py: Python, cell: PyReadonlyArrayDyn<f64>) -> PyResult<()> {
243 let arr = cell.as_array();
244 if arr.shape() != [3, 3] {
245 return Err(PyRuntimeError::new_err("Cell must be a 3×3 array"));
246 }
247 let data: Vec<f64> = arr.iter().cloned().collect();
248 ipc_call(py, Command::SetCell { data })?;
249 Ok(())
250 }
251
252 fn get_mo_count(&self, py: Python) -> PyResult<i32> {
255 match ipc_call(py, Command::GetMoCount)? {
256 Payload::Int(n) => Ok(n as i32),
257 p => Err(unexpected_payload("get_mo_count", &p)),
258 }
259 }
260
261 #[cfg(feature = "extended")]
264 fn is_quickstep(&self, py: Python) -> PyResult<bool> {
265 match ipc_call(py, Command::IsQuickstep)? {
266 Payload::Bool(b) => Ok(b),
267 p => Err(unexpected_payload("is_quickstep", &p)),
268 }
269 }
270
271 #[cfg(feature = "extended")]
272 fn get_stress_tensor<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
273 match ipc_call(py, Command::GetStressTensor)? {
274 Payload::Array2 { rows, cols, data } => {
275 let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols), data)
276 .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?;
277 Ok(arr.into_pyarray(py))
278 }
279 p => Err(unexpected_payload("get_stress_tensor", &p)),
280 }
281 }
282
283 #[cfg(feature = "extended")]
284 fn get_virial_tensor<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
285 match ipc_call(py, Command::GetVirialTensor)? {
286 Payload::Array2 { rows, cols, data } => {
287 let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols), data)
288 .map_err(|e| PyRuntimeError::new_err(format!("{e}")))?;
289 Ok(arr.into_pyarray(py))
290 }
291 p => Err(unexpected_payload("get_virial_tensor", &p)),
292 }
293 }
294
295 #[cfg(feature = "extended")]
296 fn get_nmo(&self, py: Python, spin: i32) -> PyResult<usize> {
297 match ipc_call(py, Command::GetNmo { spin })? {
298 Payload::UInt(n) => Ok(n as usize),
299 Payload::Int(n) if n >= 0 => Ok(n as usize),
300 p => Err(unexpected_payload("get_nmo", &p)),
301 }
302 }
303
304 #[cfg(feature = "extended")]
305 fn get_eigenvalues<'py>(
306 &self,
307 py: Python<'py>,
308 spin: i32,
309 ) -> PyResult<Bound<'py, PyArray1<f64>>> {
310 match ipc_call(py, Command::GetEigenvalues { spin })? {
311 Payload::Array1(v) => Ok(v.into_pyarray(py)),
312 p => Err(unexpected_payload("get_eigenvalues", &p)),
313 }
314 }
315
316 #[cfg(feature = "extended")]
317 fn get_occupation_numbers<'py>(
318 &self,
319 py: Python<'py>,
320 spin: i32,
321 ) -> PyResult<Bound<'py, PyArray1<f64>>> {
322 match ipc_call(py, Command::GetOccupationNumbers { spin })? {
323 Payload::Array1(v) => Ok(v.into_pyarray(py)),
324 p => Err(unexpected_payload("get_occupation_numbers", &p)),
325 }
326 }
327
328 #[cfg(feature = "extended")]
329 fn get_homo_lumo(&self, py: Python, spin: i32) -> PyResult<(f64, f64, i32, i32)> {
330 match ipc_call(py, Command::GetHomoLumo { spin })? {
331 Payload::HomoLumo {
332 homo,
333 lumo,
334 homo_idx,
335 lumo_idx,
336 } => Ok((homo, lumo, homo_idx, lumo_idx)),
337 p => Err(unexpected_payload("get_homo_lumo", &p)),
338 }
339 }
340
341 #[cfg(feature = "extended")]
342 fn get_band_gap(&self, py: Python, spin: i32) -> PyResult<f64> {
343 match ipc_call(py, Command::GetHomoLumo { spin })? {
344 Payload::HomoLumo { homo, lumo, .. } => Ok((lumo - homo) * 27.211386245988),
345 p => Err(unexpected_payload("get_band_gap", &p)),
346 }
347 }
348
349 #[cfg(feature = "extended")]
350 fn get_mulliken_charges<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
351 match ipc_call(py, Command::GetMullikenCharges)? {
352 Payload::Array1(v) => Ok(v.into_pyarray(py)),
353 p => Err(unexpected_payload("get_mulliken_charges", &p)),
354 }
355 }
356
357 #[cfg(feature = "extended")]
358 fn get_hirshfeld_charges<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
359 match ipc_call(py, Command::GetHirshfeldCharges)? {
360 Payload::Array1(v) => Ok(v.into_pyarray(py)),
361 p => Err(unexpected_payload("get_hirshfeld_charges", &p)),
362 }
363 }
364
365 #[cfg(feature = "extended")]
366 fn get_dipole_moment<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
367 match ipc_call(py, Command::GetDipoleMoment)? {
368 Payload::Array1(v) => Ok(v.into_pyarray(py)),
369 p => Err(unexpected_payload("get_dipole_moment", &p)),
370 }
371 }
372
373 #[cfg(feature = "extended")]
374 fn get_scf_info(&self, py: Python) -> PyResult<(i32, bool, f64)> {
375 match ipc_call(py, Command::GetScfInfo)? {
376 Payload::ScfInfo {
377 nsteps,
378 converged,
379 energy_change,
380 } => Ok((nsteps, converged, energy_change)),
381 p => Err(unexpected_payload("get_scf_info", &p)),
382 }
383 }
384
385 #[cfg(feature = "extended")]
386 fn get_energy_components(&self, py: Python) -> PyResult<(f64, f64, f64, f64, f64)> {
387 match ipc_call(py, Command::GetEnergyComponents)? {
388 Payload::EnergyComponents {
389 e_kin,
390 e_hartree,
391 e_xc,
392 e_core,
393 e_total,
394 } => Ok((e_kin, e_hartree, e_xc, e_core, e_total)),
395 p => Err(unexpected_payload("get_energy_components", &p)),
396 }
397 }
398
399 #[cfg(feature = "extended")]
400 fn get_nelectron(&self, py: Python) -> PyResult<i32> {
401 match ipc_call(py, Command::GetNelectron)? {
402 Payload::Int(n) => Ok(n as i32),
403 p => Err(unexpected_payload("get_nelectron", &p)),
404 }
405 }
406
407 #[cfg(feature = "extended")]
408 fn get_fermi_energy(&self, py: Python) -> PyResult<f64> {
409 match ipc_call(py, Command::GetFermiEnergy)? {
410 Payload::Float(e) => Ok(e),
411 p => Err(unexpected_payload("get_fermi_energy", &p)),
412 }
413 }
414
415 #[cfg(feature = "extended")]
416 fn get_total_spin(&self, py: Python) -> PyResult<f64> {
417 match ipc_call(py, Command::GetTotalSpin)? {
418 Payload::Float(s) => Ok(s),
419 p => Err(unexpected_payload("get_total_spin", &p)),
420 }
421 }
422
423 #[cfg(feature = "extended")]
428 #[pyo3(signature = (spin = 1))]
429 fn get_grid_info(&self, py: Python, spin: i32) -> PyResult<Py<PyAny>> {
430 match ipc_call(py, Command::GetGridInfo { spin })? {
431 Payload::GridInfo { npts, origin, dh } => {
432 let dict = pyo3::types::PyDict::new(py);
433 dict.set_item("npts", npts.to_vec())?;
434 dict.set_item("origin", origin.to_vec())?;
435 let dh_list: Vec<Vec<f64>> = dh.iter().map(|row| row.to_vec()).collect();
436 dict.set_item("dh", dh_list)?;
437 Ok(dict.into_any().unbind())
438 }
439 p => Err(unexpected_payload("get_grid_info", &p)),
440 }
441 }
442
443 #[cfg(feature = "extended")]
448 #[pyo3(signature = (spin = 1))]
449 fn get_electron_density<'py>(
450 &self,
451 py: Python<'py>,
452 spin: i32,
453 ) -> PyResult<(Py<PyAny>, Bound<'py, PyArray3<f64>>)> {
454 let payload = ipc_call(py, Command::GetElectronDensity { spin })?;
455 match payload {
456 Payload::SharedArray3 {
457 shm_name,
458 dims,
459 byte_size,
460 } => {
461 let data = py.detach(|| {
463 worker::read_shared_array3(&shm_name, dims, byte_size).map_err(worker_err)
464 })?;
465
466 let info_payload = ipc_call(py, Command::GetGridInfo { spin })?;
468 let info_dict = match info_payload {
469 Payload::GridInfo { npts, origin, dh } => {
470 let dict = pyo3::types::PyDict::new(py);
471 dict.set_item("npts", npts.to_vec())?;
472 dict.set_item("origin", origin.to_vec())?;
473 let dh_list: Vec<Vec<f64>> = dh.iter().map(|row| row.to_vec()).collect();
474 dict.set_item("dh", dh_list)?;
475 dict.into_any().unbind()
476 }
477 _ => {
478 return Err(PyRuntimeError::new_err(
479 "Failed to get grid info after density retrieval",
480 ));
481 }
482 };
483
484 let arr =
486 numpy::ndarray::Array3::from_shape_vec((dims[0], dims[1], dims[2]).f(), data)
487 .map_err(|e| PyRuntimeError::new_err(format!("Array shape error: {e}")))?;
488 Ok((info_dict, arr.into_pyarray(py)))
489 }
490 p => Err(unexpected_payload("get_electron_density", &p)),
491 }
492 }
493
494 #[cfg(feature = "extended")]
498 #[pyo3(signature = (spin = 1))]
499 fn get_mo_coeff_info(&self, py: Python, spin: i32) -> PyResult<(usize, usize)> {
500 match ipc_call(py, Command::GetMoCoeffInfo { spin })? {
501 Payload::MoCoeffInfo { nao, nmo } => Ok((nao, nmo)),
502 p => Err(unexpected_payload("get_mo_coeff_info", &p)),
503 }
504 }
505
506 #[cfg(feature = "extended")]
511 #[pyo3(signature = (spin = 1))]
512 fn get_mo_coefficients<'py>(
513 &self,
514 py: Python<'py>,
515 spin: i32,
516 ) -> PyResult<Bound<'py, PyArray2<f64>>> {
517 let payload = ipc_call(py, Command::GetMoCoefficients { spin })?;
518 match payload {
519 Payload::SharedArray2 {
520 shm_name,
521 rows,
522 cols,
523 byte_size,
524 } => {
525 let data = py.detach(|| {
526 worker::read_shared_array2(&shm_name, byte_size).map_err(worker_err)
527 })?;
528
529 let arr = numpy::ndarray::Array2::from_shape_vec((rows, cols).f(), data)
530 .map_err(|e| PyRuntimeError::new_err(format!("Array shape error: {e}")))?;
531 Ok(arr.into_pyarray(py))
532 }
533 p => Err(unexpected_payload("get_mo_coefficients", &p)),
534 }
535 }
536}
537
538fn unexpected_payload(func: &str, payload: &Payload) -> PyErr {
541 PyRuntimeError::new_err(format!("{func}: unexpected payload variant {:?}", payload))
542}