Skip to main content

rgpot_core/
potential.rs

1// MIT License
2// Copyright 2023--present rgpot developers
3
4//! Callback-based potential dispatch.
5//!
6//! This module defines [`PotentialImpl`], an opaque handle that wraps a C
7//! function pointer callback together with a `void* user_data` and an optional
8//! destructor. This is the core abstraction that lets existing C++ potentials
9//! (LJ, CuH2, or any future implementation) plug into the Rust infrastructure
10//! without the Rust side knowing the concrete type.
11//!
12//! ## How it Works
13//!
14//! 1. The C++ side creates a potential object (e.g., `LJPot`).
15//! 2. A trampoline function with the [`PotentialCallback`] signature is
16//!    registered, casting `user_data` back to the concrete type and calling
17//!    `forceImpl`.
18//! 3. The Rust core dispatches through the function pointer, receiving results
19//!    via the [`rgpot_force_out_t`] output struct.
20//!
21//! ## Lifetime Contract
22//!
//! - Without a `free_fn`, `user_data` is only borrowed by `PotentialImpl`:
//!   the caller must keep the underlying object alive for the lifetime of the
//!   handle and release it afterwards.
//! - With a `free_fn`, ownership of `user_data` is transferred to
//!   `PotentialImpl`, which calls `free_fn(user_data)` once on drop. The call
//!   is skipped when `user_data` is null.
27//! - The handle is exposed to C as `rgpot_potential_t` — an opaque pointer
28//!   managed via `rgpot_potential_new` / `rgpot_potential_free`.
29
30use std::os::raw::c_void;
31
32use crate::status::rgpot_status_t;
33use crate::types::{rgpot_force_input_t, rgpot_force_out_t};
34
35/// Function pointer type for a potential energy calculation.
36///
37/// The callback receives:
38/// - `user_data`: opaque pointer to the C++ object (e.g. `LJPot*`)
39/// - `input`: the atomic configuration (DLPack tensors)
40/// - `output`: the buffer for results (callback sets `forces` tensor)
41///
42/// Returns `RGPOT_SUCCESS` on success, or an error status code.
43pub type PotentialCallback = unsafe extern "C" fn(
44    user_data: *mut c_void,
45    input: *const rgpot_force_input_t,
46    output: *mut rgpot_force_out_t,
47) -> rgpot_status_t;
48
/// Destructor used to release `user_data` when a [`PotentialImpl`] is dropped.
///
/// It receives exactly the pointer that was given to [`PotentialImpl::new`],
/// and is never invoked with a null pointer.
pub type FreeFn = unsafe extern "C" fn(user_data: *mut c_void);
51
52/// Opaque potential handle wrapping a callback + user data.
53pub struct PotentialImpl {
54    pub(crate) callback: PotentialCallback,
55    pub(crate) user_data: *mut c_void,
56    pub(crate) free_fn: Option<FreeFn>,
57}
58
59// PotentialImpl stores a raw pointer but we guarantee exclusive access
60// through the opaque handle pattern.
61unsafe impl Send for PotentialImpl {}
62
63impl PotentialImpl {
64    /// Create a new potential from a callback, user data, and optional destructor.
65    pub fn new(
66        callback: PotentialCallback,
67        user_data: *mut c_void,
68        free_fn: Option<FreeFn>,
69    ) -> Self {
70        Self {
71            callback,
72            user_data,
73            free_fn,
74        }
75    }
76
77    /// Invoke the underlying callback.
78    ///
79    /// # Safety
80    /// The caller must ensure `input` and `output` point to valid, properly
81    /// sized structures.
82    pub unsafe fn calculate(
83        &self,
84        input: *const rgpot_force_input_t,
85        output: *mut rgpot_force_out_t,
86    ) -> rgpot_status_t {
87        (self.callback)(self.user_data, input, output)
88    }
89}
90
91impl Drop for PotentialImpl {
92    fn drop(&mut self) {
93        if let Some(free) = self.free_fn {
94            if !self.user_data.is_null() {
95                unsafe { free(self.user_data) };
96            }
97        }
98    }
99}
100
101/// Opaque handle exposed to C as `rgpot_potential_t`.
102///
103/// This is a type alias used by cbindgen to generate a forward declaration.
104pub type rgpot_potential_t = PotentialImpl;
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{
        create_owned_f64_tensor, rgpot_tensor_cpu_f64_2d, rgpot_tensor_cpu_f64_matrix3,
        rgpot_tensor_cpu_i32_1d, rgpot_tensor_free,
    };
    use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};

    /// Backing storage for a test input of `n` atoms.
    ///
    /// All atoms sit at the origin with atomic number 1, inside a cubic box of
    /// side 10. The tensors built by [`TestIO::make_input`] wrap raw pointers
    /// into these buffers. A `TestIO` must therefore outlive every input built
    /// from it, and those tensor handles are released with
    /// [`TestIO::free_input`].
    struct TestIO {
        pos: Vec<f64>,
        atmnrs: Vec<i32>,
        box_: [f64; 9],
    }

    impl TestIO {
        fn new(n: usize) -> Self {
            Self {
                pos: vec![0.0; n * 3],
                atmnrs: vec![1; n],
                box_: [10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0],
            }
        }

        /// Build an `rgpot_force_input_t` whose tensors point into `self`.
        fn make_input(&mut self) -> rgpot_force_input_t {
            let n = self.atmnrs.len();
            rgpot_force_input_t {
                positions: unsafe {
                    rgpot_tensor_cpu_f64_2d(self.pos.as_mut_ptr(), n as i64, 3)
                },
                atomic_numbers: unsafe {
                    rgpot_tensor_cpu_i32_1d(self.atmnrs.as_mut_ptr(), n as i64)
                },
                box_matrix: unsafe {
                    rgpot_tensor_cpu_f64_matrix3(self.box_.as_mut_ptr())
                },
            }
        }

        /// Free the tensor handles created by [`TestIO::make_input`].
        ///
        /// # Safety
        /// `input` must come from `make_input` on this `TestIO` and must not
        /// have been freed already.
        unsafe fn free_input(&self, input: &rgpot_force_input_t) {
            unsafe {
                rgpot_tensor_free(input.positions);
                rgpot_tensor_free(input.atomic_numbers);
                rgpot_tensor_free(input.box_matrix);
            }
        }
    }

    /// Output buffer with no forces tensor and zeroed scalars.
    fn empty_output() -> rgpot_force_out_t {
        rgpot_force_out_t {
            forces: std::ptr::null_mut(),
            energy: 0.0,
            variance: 0.0,
        }
    }

    /// Mock callback: sets energy = n_atoms, creates owning forces tensor.
    unsafe extern "C" fn mock_callback(
        _user_data: *mut c_void,
        input: *const rgpot_force_input_t,
        output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        let inp = unsafe { &*input };
        let out = unsafe { &mut *output };
        let n = unsafe { inp.n_atoms() }.unwrap_or(0);
        out.energy = n as f64;
        out.variance = 0.0;
        out.forces = create_owned_f64_tensor(vec![0.0; n * 3], vec![n as i64, 3]);
        rgpot_status_t::RGPOT_SUCCESS
    }

    /// Callback that fails without touching `output`.
    unsafe extern "C" fn error_callback(
        _ud: *mut c_void,
        _input: *const rgpot_force_input_t,
        _output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        rgpot_status_t::RGPOT_INVALID_PARAMETER
    }

    /// Callback writing forces 1..=3n, energy -42 and variance 0.5.
    unsafe extern "C" fn writing_callback(
        _ud: *mut c_void,
        input: *const rgpot_force_input_t,
        output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        let inp = unsafe { &*input };
        let out = unsafe { &mut *output };
        let n = unsafe { inp.n_atoms() }.unwrap_or(0);
        let forces: Vec<f64> = (1..=(n * 3)).map(|i| i as f64).collect();
        out.forces = create_owned_f64_tensor(forces, vec![n as i64, 3]);
        out.energy = -42.0;
        out.variance = 0.5;
        rgpot_status_t::RGPOT_SUCCESS
    }

    #[test]
    fn test_potential_callback() {
        let pot = PotentialImpl::new(mock_callback, std::ptr::null_mut(), None);

        let mut io = TestIO::new(3);
        let input = io.make_input();
        let mut output = empty_output();

        let status = unsafe { pot.calculate(&input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(output.energy, 3.0);
        assert!(!output.forces.is_null());

        unsafe {
            rgpot_tensor_free(output.forces);
            io.free_input(&input);
        }
    }

    #[test]
    fn callback_error_status_is_returned() {
        let pot = PotentialImpl::new(error_callback, std::ptr::null_mut(), None);
        let mut io = TestIO::new(1);
        let input = io.make_input();
        let mut output = empty_output();

        let status = unsafe { pot.calculate(&input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_INVALID_PARAMETER);
        // The failing callback allocated nothing, so there is nothing to free.
        assert!(output.forces.is_null());

        unsafe { io.free_input(&input) };
    }

    #[test]
    fn callback_writes_forces_energy_variance() {
        let pot = PotentialImpl::new(writing_callback, std::ptr::null_mut(), None);
        let mut io = TestIO::new(2);
        let input = io.make_input();
        let mut output = empty_output();

        let status = unsafe { pot.calculate(&input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(output.energy, -42.0);
        assert_eq!(output.variance, 0.5);

        // Read forces from the DLPack tensor
        assert!(!output.forces.is_null());
        let ft = unsafe { &(*output.forces).dl_tensor };
        let forces = unsafe { std::slice::from_raw_parts(ft.data as *const f64, 6) };
        assert_eq!(forces, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        unsafe {
            rgpot_tensor_free(output.forces);
            io.free_input(&input);
        }
    }

    /// Value seen by `record_free`. 0 means "destructor not called yet".
    static FREED_VALUE: AtomicU64 = AtomicU64::new(0);

    /// Destructor that records the `u64` behind `ptr` instead of asserting on it.
    ///
    /// Unwinding out of an `extern "C"` function is not allowed (it aborts the
    /// process). A failing assert here would take down the whole test binary
    /// rather than fail just this test, so the check is done in the test body.
    unsafe extern "C" fn record_free(ptr: *mut c_void) {
        let val = unsafe { *(ptr as *const u64) };
        FREED_VALUE.store(val, Ordering::SeqCst);
    }

    #[test]
    fn drop_calls_free_fn_with_user_data() {
        FREED_VALUE.store(0, Ordering::SeqCst);
        let mut sentinel: u64 = 0xDEAD_BEEF;
        {
            let _pot = PotentialImpl::new(
                mock_callback,
                &mut sentinel as *mut u64 as *mut c_void,
                Some(record_free),
            );
            assert_eq!(FREED_VALUE.load(Ordering::SeqCst), 0);
        }
        assert_eq!(FREED_VALUE.load(Ordering::SeqCst), 0xDEAD_BEEF);
    }

    /// Call counter dedicated to `drop_skips_free_fn_when_user_data_is_null`,
    /// so parallel tests cannot interfere with it.
    static NULL_FREE_CALLS: AtomicU32 = AtomicU32::new(0);

    unsafe extern "C" fn count_null_free(_ptr: *mut c_void) {
        NULL_FREE_CALLS.fetch_add(1, Ordering::SeqCst);
    }

    #[test]
    fn drop_skips_free_fn_when_user_data_is_null() {
        {
            let _pot =
                PotentialImpl::new(mock_callback, std::ptr::null_mut(), Some(count_null_free));
        }
        assert_eq!(NULL_FREE_CALLS.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn drop_without_free_fn_is_safe() {
        let _pot = PotentialImpl::new(mock_callback, std::ptr::null_mut(), None);
    }

    #[test]
    fn user_data_is_passed_through() {
        static CALL_COUNT: AtomicU32 = AtomicU32::new(0);

        unsafe extern "C" fn count_cb(
            ud: *mut c_void,
            _input: *const rgpot_force_input_t,
            output: *mut rgpot_force_out_t,
        ) -> rgpot_status_t {
            let ctr = unsafe { &*(ud as *const AtomicU32) };
            ctr.fetch_add(1, Ordering::SeqCst);
            unsafe { (*output).energy = ctr.load(Ordering::SeqCst) as f64 };
            rgpot_status_t::RGPOT_SUCCESS
        }

        CALL_COUNT.store(0, Ordering::SeqCst);
        let pot = PotentialImpl::new(
            count_cb,
            &CALL_COUNT as *const _ as *mut c_void,
            None,
        );

        let mut io = TestIO::new(1);
        let input = io.make_input();
        let mut output = empty_output();

        let status = unsafe { pot.calculate(&input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(CALL_COUNT.load(Ordering::SeqCst), 1);
        assert_eq!(output.energy, 1.0);

        let status = unsafe { pot.calculate(&input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(CALL_COUNT.load(Ordering::SeqCst), 2);
        assert_eq!(output.energy, 2.0);

        unsafe { io.free_input(&input) };
    }

    #[test]
    fn potential_impl_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<PotentialImpl>();
    }
}