// geometric_pyo3/interface.rs
//! Interface that electronic structure codes implement so that geomeTRIC can
//! drive them through the PyO3 binding.
2
3use std::mem::transmute;
4use std::sync::{Arc, Mutex};
5
6use pyo3::prelude::*;
7
/// Gradient output from the electronic structure code.
///
/// Returned by [`GeomDriverAPI::calc_new`].
#[derive(Debug, Clone, PartialEq)]
pub struct GradOutput {
    /// The energy of the system (a scalar).
    pub energy: f64,
    /// The gradient of the system, flattened to length `natom * 3`, with the
    /// coordinate dimension (3) contiguous, i.e. `[x0, y0, z0, x1, y1, z1, ...]`.
    pub gradient: Vec<f64>,
}
17
18/// Trait API to be implemented in electronic structure code for geomeTRIC PyO3
19/// binding.
20pub trait GeomDriverAPI: Send {
21 /// Calculate the energy and gradient of the system.
22 ///
23 /// This trait corresponds to the `calc_new` method in the `Engine` class in
24 /// geomeTRIC.
25 ///
26 /// # Arguments
27 ///
28 /// - `coords` - The coordinates of the system, flattened (natom * 3), with
29 /// dimension of coordinate (3) to be contiguous.
30 /// - `dirname` - The directory to run the calculation in. Can be set to
31 /// dummy if directory is not required for gradient computation.
32 ///
33 /// # Returns
34 ///
35 /// A `GradOutput` struct containing the energy and gradient of the system.
36 fn calc_new(&mut self, coords: &[f64], dirname: &str) -> GradOutput;
37}
38
39/// Python wrapper for the `GeomDriverAPI` trait implementations.
40///
41/// `GeomDriverAPI` is defined as rust trait, which is not directly usable in
42/// Python. This makes the glue between the rust trait and the python class.
43///
44/// # Safety
45///
46/// This struct is marked as `unsafe` because it uses `transmute` to convert
47/// local lifetime to static lifetime.
48///
49/// If your class (to be implemented with trait `GeomDriverAPI`) have lifetime
50/// parameters, then this lifetime will be transmuted to static lifetime when it
51/// is converted to python object. As long as you don't disturb the lifetime of
52/// reference, this transmute should be safe.
53#[pyclass]
54#[derive(Clone)]
55pub struct PyGeomDriver {
56 pub pointer: Arc<Mutex<dyn GeomDriverAPI>>,
57}
58
59impl<T> From<T> for PyGeomDriver
60where
61 T: GeomDriverAPI,
62{
63 fn from(driver: T) -> Self {
64 let a: Arc<Mutex<dyn GeomDriverAPI>> = Arc::new(Mutex::new(driver));
65 // Safety not checked, and should be provided by the caller.
66 // This will convert local lifetime (of `T`) to static lifetime (`'static`) for
67 // python calls.
68 unsafe { transmute(a) }
69 }
70}