Skip to main content

rgpot_core/c_api/
potential.rs

1// MIT License
2// Copyright 2023--present rgpot developers
3
4//! C API for the potential handle lifecycle: create, calculate, free.
5//!
6//! These three functions form the core public interface for potential energy
7//! calculations. The typical usage pattern from C/C++ is:
8//!
9//! ```c
10//! // 1. Create a handle from a callback
11//! rgpot_potential_t *pot = rgpot_potential_new(my_callback, my_data, NULL);
12//!
13//! // 2. Prepare input/output
14//! rgpot_force_input_t  input  = rgpot_force_input_create(n, pos, types, box);
15//! rgpot_force_out_t    output = rgpot_force_out_create();
16//!
17//! // 3. Calculate
18//! rgpot_status_t s = rgpot_potential_calculate(pot, &input, &output);
19//! if (s != RGPOT_SUCCESS) { /* handle error */ }
20//! // output.forces is now a DLPack tensor — use it, then free:
21//! rgpot_tensor_free(output.forces);
22//!
23//! // 4. Clean up
24//! rgpot_force_input_free(&input);
25//! rgpot_potential_free(pot);
26//! ```
27
28use std::os::raw::c_void;
29
30use crate::potential::{PotentialCallback, PotentialImpl, rgpot_potential_t};
31use crate::status::{catch_unwind, rgpot_status_t, set_last_error};
32use crate::types::{rgpot_force_input_t, rgpot_force_out_t};
33
34/// Create a new potential handle from a callback function pointer.
35///
36/// - `callback`: the function that performs the force/energy calculation.
37/// - `user_data`: opaque pointer forwarded to every callback invocation
38///   (typically a pointer to the C++ potential object).
39/// - `free_fn`: optional destructor for `user_data`. Pass `NULL` if the
40///   caller manages the lifetime externally.
41///
42/// Returns a heap-allocated `rgpot_potential_t*`, or `NULL` on failure.
43/// The caller must eventually pass the returned pointer to
44/// `rgpot_potential_free`.
45#[no_mangle]
46pub unsafe extern "C" fn rgpot_potential_new(
47    callback: PotentialCallback,
48    user_data: *mut c_void,
49    free_fn: Option<unsafe extern "C" fn(*mut c_void)>,
50) -> *mut rgpot_potential_t {
51    let pot = PotentialImpl::new(callback, user_data, free_fn);
52    Box::into_raw(Box::new(pot))
53}
54
55/// Perform a force/energy calculation using the potential handle.
56///
57/// - `pot`: a valid handle obtained from `rgpot_potential_new`.
58/// - `input`: pointer to the input configuration (DLPack tensors).
59/// - `output`: pointer to the output struct. The callback sets `output->forces`
60///   to a callee-allocated DLPack tensor.
61///
62/// Returns `RGPOT_SUCCESS` on success, or an error status code.
63/// On error, call `rgpot_last_error()` for details.
64#[no_mangle]
65pub unsafe extern "C" fn rgpot_potential_calculate(
66    pot: *const rgpot_potential_t,
67    input: *const rgpot_force_input_t,
68    output: *mut rgpot_force_out_t,
69) -> rgpot_status_t {
70    catch_unwind(std::panic::AssertUnwindSafe(|| {
71        if pot.is_null() {
72            set_last_error("rgpot_potential_calculate: pot is NULL");
73            return rgpot_status_t::RGPOT_INVALID_PARAMETER;
74        }
75        if input.is_null() {
76            set_last_error("rgpot_potential_calculate: input is NULL");
77            return rgpot_status_t::RGPOT_INVALID_PARAMETER;
78        }
79        if output.is_null() {
80            set_last_error("rgpot_potential_calculate: output is NULL");
81            return rgpot_status_t::RGPOT_INVALID_PARAMETER;
82        }
83
84        let pot_ref = unsafe { &*pot };
85        unsafe { pot_ref.calculate(input, output) }
86    }))
87}
88
89/// Free a potential handle previously obtained from `rgpot_potential_new`.
90///
91/// If `pot` is `NULL`, this function is a no-op.
92/// After this call, `pot` must not be used again.
93#[no_mangle]
94pub unsafe extern "C" fn rgpot_potential_free(pot: *mut rgpot_potential_t) {
95    if !pot.is_null() {
96        drop(unsafe { Box::from_raw(pot) });
97    }
98}
99
#[cfg(test)]
mod tests {
    //! Unit tests for the handle lifecycle: creation, calculation through
    //! the C entry points, error propagation, and destruction.

    use super::*;
    use crate::c_api::types::{rgpot_force_input_create, rgpot_force_input_free};
    use crate::tensor::{create_owned_f64_tensor, rgpot_tensor_free};
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    /// A trivial callback that sets energy = n_atoms, allocates a zeroed
    /// `n x 3` force tensor, and returns success.
    unsafe extern "C" fn sum_callback(
        _ud: *mut c_void,
        input: *const rgpot_force_input_t,
        output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        // SAFETY: the tests only invoke this callback through
        // `rgpot_potential_calculate`, which rejects NULL pointers first.
        let inp = unsafe { &*input };
        let out = unsafe { &mut *output };
        let n = unsafe { inp.n_atoms() }.unwrap_or(0);
        out.energy = n as f64;
        out.variance = 0.0;
        out.forces = create_owned_f64_tensor(vec![0.0; n * 3], vec![n as i64, 3]);
        rgpot_status_t::RGPOT_SUCCESS
    }

    /// A callback that records an error message and always fails.
    unsafe extern "C" fn failing_callback(
        _ud: *mut c_void,
        _input: *const rgpot_force_input_t,
        _output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        crate::status::set_last_error("deliberately failed");
        rgpot_status_t::RGPOT_INTERNAL_ERROR
    }

    /// A callback that treats `user_data` as an `AtomicU32` counter,
    /// increments it, and reports the new count as the energy.
    unsafe extern "C" fn counting_callback(
        ud: *mut c_void,
        _input: *const rgpot_force_input_t,
        output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        // SAFETY: the test that installs this callback passes a pointer to a
        // live `AtomicU32` as `user_data`.
        let counter = unsafe { &*(ud as *const AtomicU32) };
        counter.fetch_add(1, Ordering::SeqCst);
        let out = unsafe { &mut *output };
        out.energy = counter.load(Ordering::SeqCst) as f64;
        rgpot_status_t::RGPOT_SUCCESS
    }

    /// Builds an output struct with a null `forces` tensor and zeroed scalars.
    fn empty_output() -> rgpot_force_out_t {
        rgpot_force_out_t {
            forces: std::ptr::null_mut(),
            energy: 0.0,
            variance: 0.0,
        }
    }

    /// Builds a two-atom input together with the buffers it was created from.
    ///
    /// The buffers are heap-allocated `Vec`s rather than stack arrays. A
    /// stack array would be copied to a new address when returned, so any
    /// pointer the input retains into it (if `rgpot_force_input_create`
    /// borrows rather than copies, which is not visible from this file)
    /// would dangle. Moving a `Vec` leaves its heap storage in place, so the
    /// pointers stay valid while the caller keeps the returned vectors alive.
    fn make_test_input() -> (Vec<f64>, Vec<i32>, Vec<f64>, rgpot_force_input_t) {
        let mut pos = vec![0.0_f64, 0.0, 0.0, 1.0, 0.0, 0.0];
        let mut atmnrs = vec![1_i32, 1];
        let mut box_ = vec![10.0_f64, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0];
        // SAFETY: each buffer holds the element count implied by `n = 2`
        // (6 positions, 2 atomic numbers, 9 cell entries).
        let input = unsafe {
            rgpot_force_input_create(
                2,
                pos.as_mut_ptr(),
                atmnrs.as_mut_ptr(),
                box_.as_mut_ptr(),
            )
        };
        (pos, atmnrs, box_, input)
    }

    /// Frees the force tensor and the input, then nulls `output.forces` so a
    /// second cleanup cannot free the same tensor again.
    unsafe fn cleanup(input: &mut rgpot_force_input_t, output: &mut rgpot_force_out_t) {
        unsafe {
            rgpot_tensor_free(output.forces);
            rgpot_force_input_free(input);
        }
        output.forces = std::ptr::null_mut();
    }

    // --- rgpot_potential_new / rgpot_potential_free ---

    #[test]
    fn new_returns_non_null() {
        let pot = unsafe { rgpot_potential_new(sum_callback, std::ptr::null_mut(), None) };
        assert!(!pot.is_null());
        unsafe { rgpot_potential_free(pot) };
    }

    #[test]
    fn free_null_is_noop() {
        unsafe { rgpot_potential_free(std::ptr::null_mut()) };
    }

    // --- rgpot_potential_calculate: null argument handling ---

    #[test]
    fn calculate_null_pot_returns_invalid_parameter() {
        let (_pos, _atmnrs, _box_, mut input) = make_test_input();
        let mut output = empty_output();

        let status = unsafe { rgpot_potential_calculate(std::ptr::null(), &input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_INVALID_PARAMETER);
        unsafe { rgpot_force_input_free(&mut input) };
    }

    #[test]
    fn calculate_null_input_returns_invalid_parameter() {
        let pot = unsafe { rgpot_potential_new(sum_callback, std::ptr::null_mut(), None) };
        let mut output = empty_output();

        let status = unsafe { rgpot_potential_calculate(pot, std::ptr::null(), &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_INVALID_PARAMETER);
        unsafe { rgpot_potential_free(pot) };
    }

    #[test]
    fn calculate_null_output_returns_invalid_parameter() {
        let pot = unsafe { rgpot_potential_new(sum_callback, std::ptr::null_mut(), None) };
        let (_pos, _atmnrs, _box_, mut input) = make_test_input();

        let status = unsafe { rgpot_potential_calculate(pot, &input, std::ptr::null_mut()) };
        assert_eq!(status, rgpot_status_t::RGPOT_INVALID_PARAMETER);
        unsafe {
            rgpot_force_input_free(&mut input);
            rgpot_potential_free(pot);
        }
    }

    // --- rgpot_potential_calculate: success path ---

    #[test]
    fn full_lifecycle_new_calculate_free() {
        let pot = unsafe { rgpot_potential_new(sum_callback, std::ptr::null_mut(), None) };
        let (_pos, _atmnrs, _box_, mut input) = make_test_input();
        let mut output = empty_output();

        let status = unsafe { rgpot_potential_calculate(pot, &input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(output.energy, 2.0); // n_atoms = 2
        assert!(!output.forces.is_null());

        unsafe {
            cleanup(&mut input, &mut output);
            rgpot_potential_free(pot);
        }
    }

    // --- Error propagation from callback ---

    #[test]
    fn callback_error_propagates() {
        let pot = unsafe { rgpot_potential_new(failing_callback, std::ptr::null_mut(), None) };
        let (_pos, _atmnrs, _box_, mut input) = make_test_input();
        let mut output = empty_output();

        let status = unsafe { rgpot_potential_calculate(pot, &input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_INTERNAL_ERROR);

        let msg = unsafe { std::ffi::CStr::from_ptr(crate::status::rgpot_last_error()) };
        assert_eq!(msg.to_str().unwrap(), "deliberately failed");

        unsafe {
            rgpot_force_input_free(&mut input);
            rgpot_potential_free(pot);
        }
    }

    // --- user_data passthrough ---

    #[test]
    fn user_data_is_forwarded_to_callback() {
        let counter = AtomicU32::new(0);
        let pot = unsafe {
            rgpot_potential_new(
                counting_callback,
                &counter as *const _ as *mut c_void,
                None,
            )
        };

        let (_pos, _atmnrs, _box_, mut input) = make_test_input();
        let mut output = empty_output();

        unsafe { rgpot_potential_calculate(pot, &input, &mut output) };
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert_eq!(output.energy, 1.0);

        unsafe { rgpot_potential_calculate(pot, &input, &mut output) };
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        assert_eq!(output.energy, 2.0);

        unsafe {
            rgpot_force_input_free(&mut input);
            rgpot_potential_free(pot);
        }
    }

    // --- free_fn is called on drop ---

    /// Set by `track_free`; used only by `free_fn_is_invoked_on_drop`.
    static FREE_CALLED: AtomicBool = AtomicBool::new(false);

    /// Destructor stub that records that it ran.
    unsafe extern "C" fn track_free(_ptr: *mut c_void) {
        FREE_CALLED.store(true, Ordering::SeqCst);
    }

    #[test]
    fn free_fn_is_invoked_on_drop() {
        FREE_CALLED.store(false, Ordering::SeqCst);

        let mut dummy: u8 = 42;
        let pot = unsafe {
            rgpot_potential_new(
                sum_callback,
                &mut dummy as *mut u8 as *mut c_void,
                Some(track_free),
            )
        };

        assert!(!FREE_CALLED.load(Ordering::SeqCst));
        unsafe { rgpot_potential_free(pot) };
        assert!(FREE_CALLED.load(Ordering::SeqCst));
    }

    // --- Multiple sequential calculations on the same handle ---

    #[test]
    fn multiple_calculations_same_handle() {
        let pot = unsafe { rgpot_potential_new(sum_callback, std::ptr::null_mut(), None) };

        for n in 1..=5_usize {
            // These buffers live for the whole loop iteration, so they
            // outlive the input created from them.
            let mut pos = vec![0.0_f64; n * 3];
            let mut atmnrs = vec![1_i32; n];
            let mut box_ = [10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0];

            let mut input = unsafe {
                rgpot_force_input_create(
                    n,
                    pos.as_mut_ptr(),
                    atmnrs.as_mut_ptr(),
                    box_.as_mut_ptr(),
                )
            };
            let mut output = empty_output();

            let status = unsafe { rgpot_potential_calculate(pot, &input, &mut output) };
            assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
            assert_eq!(output.energy, n as f64);

            unsafe { cleanup(&mut input, &mut output) };
        }

        unsafe { rgpot_potential_free(pot) };
    }
}