geometric_pyo3/
engine.rs

1//! Engine corresponds to `geometric.engine.Engine` class in geomeTRIC.
2
3use crate::interface::PyGeomDriver;
4use pyo3::prelude::*;
5use pyo3::types::{PyDict, PyList};
6use pyo3::PyTypeInfo;
7
8/// Mixin class to be mult-inherited together with `geometric.engine.Engine`.
9#[pyclass(subclass)]
10pub struct EngineMixin {
11    driver: Option<PyGeomDriver>,
12}
13
14#[pymethods]
15impl EngineMixin {
16    /// Initialize the EngineMixin class.
17    ///
18    /// Though this function does not do anything, it is intended to be
19    /// inherited by `geometric.engine.Engine`'s initializer. So input
20    /// `_molecule` is actually gracefully initialized.
21    ///
22    /// Please note that `driver` is not initialized here. It should be set
23    /// using the `set_driver` method manually.
24    #[new]
25    pub fn new(_molecule: PyObject) -> PyResult<Self> {
26        Ok(EngineMixin { driver: None })
27    }
28
29    /// Set the driver for the engine.
30    ///
31    /// This driver is used to calculate the energy and gradient of the
32    /// system. This function must be called before using the engine.
33    pub fn set_driver(&mut self, driver: &PyGeomDriver) {
34        self.driver = Some(driver.clone());
35    }
36
37    /// Inherits `geometric.engine.Engine`'s `calc_new` method.
38    pub fn calc_new(&mut self, coords: Vec<f64>, dirname: &str) -> PyResult<PyObject> {
39        // Compute the energy and gradient using the driver.
40        let mut driver = self.driver.as_mut().unwrap().pointer.lock().unwrap();
41        let result = driver.calc_new(&coords, dirname);
42        // Convert the result to a Python object.
43        // Note: that gradient must be converted to numpy flattened array (natom * 3),
44        // list or 2-d array are both incorrect here.
45        Python::with_gil(|py| {
46            let numpy = py.import("numpy")?;
47            let energy = result.energy;
48            let gradient = numpy.call_method1("array", (PyList::new(py, result.gradient)?,))?;
49            let dict = PyDict::new(py);
50            dict.set_item("energy", energy)?;
51            dict.set_item("gradient", gradient)?;
52            Ok(dict.into())
53        })
54    }
55}
56
57/// Get the PyO3 usable geomeTRIC engine class.
58pub fn get_pyo3_engine_cls() -> PyResult<PyObject> {
59    Python::with_gil(|py| {
60        // get the type of base class `geometric.engine.Engine`
61        let base_type = py.import("geometric.engine")?.getattr("Engine")?;
62        // get the type of `EngineMixin` class
63        let engine_mixin_type = EngineMixin::type_object(py);
64
65        // execute and return the following code in Python:
66        // ```python
67        // PyO3Engine = type('PyO3Engine', (EngineMixin, Engine), {})
68        // ```
69        let locals = PyDict::new(py);
70        locals.set_item("Engine", base_type)?;
71        locals.set_item("EngineMixin", engine_mixin_type)?;
72        let pyo3_engine_type =
73            py.eval(c"type('PyO3Engine', (EngineMixin, Engine), {})", None, Some(&locals))?;
74        Ok(pyo3_engine_type.into())
75    })
76}
77
78/// Initialize a geomeTRIC molecule into Python object.
79///
80/// # Arguments
81///
82/// - `elem`: A slice of strings representing the element types.
83/// - `xyzs`: A list of vectors representing the coordinates of the atoms. Each
84///   vector represents one molecule, where its length is (natom * 3), with
85///   dimension of coordinate (3) to be contiguous.
86pub fn init_pyo3_molecule(elem: &[&str], xyzs: &[Vec<f64>]) -> PyResult<PyObject> {
87    Python::with_gil(|py| {
88        // Import the geometric Python module.
89        let molecule_cls = py.import("geometric.molecule")?.getattr("Molecule")?;
90
91        // Create a new instance of the Molecule class
92        let molecule_instance = molecule_cls.call0()?;
93
94        // xyzs must be converted into numpy array of shape (natom, 3), where 1-D array
95        // or python list are both incorrect.
96        let numpy = py.import("numpy")?;
97        let xyzs = xyzs
98            .iter()
99            .map(|xyz| {
100                let arr = numpy.call_method1("array", (PyList::new(py, xyz)?,))?;
101                let arr = arr.call_method1("reshape", (-1, 3))?;
102                Ok(arr)
103            })
104            .collect::<PyResult<Vec<_>>>()?;
105
106        // Set the attributes
107        molecule_instance.setattr("elem", elem)?;
108        molecule_instance.setattr("xyzs", xyzs)?;
109        Ok(molecule_instance.into())
110    })
111}
112
113/// Call `geometric.molecule.build_topology` function to build the topology.
114pub fn molecule_build_topology(
115    molecule: &PyObject,
116    kwargs: Option<&Bound<'_, PyDict>>,
117) -> PyResult<()> {
118    Python::with_gil(|py| {
119        molecule.call_method(py, "build_topology", (), kwargs)?;
120        Ok(())
121    })
122}