geometric_pyo3/
interface.rs

1//! Interface that electronic structure codes should implement.
2
3use std::mem::transmute;
4use std::sync::{Arc, Mutex};
5
6use pyo3::prelude::*;
7
/// Energy and gradient returned by the electronic structure code.
///
/// - `energy`: The energy of the system, scalar.
/// - `gradient`: The gradient of the system, flattened (natom * 3), with the
///   coordinate dimension (3) contiguous.
#[derive(Debug, Clone, PartialEq)]
pub struct GradOutput {
    /// Energy of the system (scalar).
    pub energy: f64,
    /// Flattened gradient of length natom * 3; the three components belonging
    /// to one atom are adjacent in memory.
    pub gradient: Vec<f64>,
}
17
18/// Trait API to be implemented in electronic structure code for geomeTRIC PyO3
19/// binding.
20pub trait GeomDriverAPI: Send {
21    /// Calculate the energy and gradient of the system.
22    ///
23    /// This trait corresponds to the `calc_new` method in the `Engine` class in
24    /// geomeTRIC.
25    ///
26    /// # Arguments
27    ///
28    /// - `coords` - The coordinates of the system, flattened (natom * 3), with
29    ///   dimension of coordinate (3) to be contiguous.
30    /// - `dirname` - The directory to run the calculation in. Can be set to
31    ///   dummy if directory is not required for gradient computation.
32    ///
33    /// # Returns
34    ///
35    /// A `GradOutput` struct containing the energy and gradient of the system.
36    fn calc_new(&mut self, coords: &[f64], dirname: &str) -> GradOutput;
37}
38
39/// Python wrapper for the `GeomDriverAPI` trait implementations.
40///
41/// `GeomDriverAPI` is defined as rust trait, which is not directly usable in
42/// Python. This makes the glue between the rust trait and the python class.
43///
44/// # Safety
45///
46/// This struct is marked as `unsafe` because it uses `transmute` to convert
47/// local lifetime to static lifetime.
48///
49/// If your class (to be implemented with trait `GeomDriverAPI`) have lifetime
50/// parameters, then this lifetime will be transmuted to static lifetime when it
51/// is converted to python object. As long as you don't disturb the lifetime of
52/// reference, this transmute should be safe.
53#[pyclass]
54#[derive(Clone)]
55pub struct PyGeomDriver {
56    pub pointer: Arc<Mutex<dyn GeomDriverAPI>>,
57}
58
59impl<T> From<T> for PyGeomDriver
60where
61    T: GeomDriverAPI,
62{
63    fn from(driver: T) -> Self {
64        let a: Arc<Mutex<dyn GeomDriverAPI>> = Arc::new(Mutex::new(driver));
65        // Safety not checked, and should be provided by the caller.
66        // This will convert local lifetime (of `T`) to static lifetime (`'static`) for
67        // python calls.
68        unsafe { transmute(a) }
69    }
70}