1use std::os::raw::c_void;
31
32use crate::status::rgpot_status_t;
33use crate::types::{rgpot_force_input_t, rgpot_force_out_t};
34
35pub type PotentialCallback = unsafe extern "C" fn(
44 user_data: *mut c_void,
45 input: *const rgpot_force_input_t,
46 output: *mut rgpot_force_out_t,
47) -> rgpot_status_t;
48
49pub type FreeFn = unsafe extern "C" fn(*mut c_void);
51
52pub struct PotentialImpl {
54 pub(crate) callback: PotentialCallback,
55 pub(crate) user_data: *mut c_void,
56 pub(crate) free_fn: Option<FreeFn>,
57}
58
59unsafe impl Send for PotentialImpl {}
62
63impl PotentialImpl {
64 pub fn new(
66 callback: PotentialCallback,
67 user_data: *mut c_void,
68 free_fn: Option<FreeFn>,
69 ) -> Self {
70 Self {
71 callback,
72 user_data,
73 free_fn,
74 }
75 }
76
77 pub unsafe fn calculate(
83 &self,
84 input: *const rgpot_force_input_t,
85 output: *mut rgpot_force_out_t,
86 ) -> rgpot_status_t {
87 (self.callback)(self.user_data, input, output)
88 }
89}
90
91impl Drop for PotentialImpl {
92 fn drop(&mut self) {
93 if let Some(free) = self.free_fn {
94 if !self.user_data.is_null() {
95 unsafe { free(self.user_data) };
96 }
97 }
98 }
99}
100
101pub type rgpot_potential_t = PotentialImpl;
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{
        create_owned_f64_tensor, rgpot_tensor_cpu_f64_2d, rgpot_tensor_cpu_f64_matrix3,
        rgpot_tensor_cpu_i32_1d, rgpot_tensor_free,
    };
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    /// Backing storage for a test input: positions, atomic numbers and box matrix.
    ///
    /// `make_input` builds its tensors from raw pointers into these buffers.
    /// The `TestIO` must therefore outlive every input it produced.
    struct TestIO {
        pos: Vec<f64>,
        atmnrs: Vec<i32>,
        box_: [f64; 9],
    }

    impl TestIO {
        /// Creates `n` atoms at the origin, each with atomic number 1.
        /// The 3x3 box matrix is diagonal with 10.0 entries.
        fn new(n: usize) -> Self {
            Self {
                pos: vec![0.0; n * 3],
                atmnrs: vec![1; n],
                box_: [10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0],
            }
        }

        /// Builds an `rgpot_force_input_t` whose tensors view this struct's
        /// buffers: positions `(n, 3)`, atomic numbers `(n)`, box `(3, 3)`.
        fn make_input(&mut self) -> rgpot_force_input_t {
            let n = self.atmnrs.len();
            rgpot_force_input_t {
                positions: unsafe {
                    rgpot_tensor_cpu_f64_2d(self.pos.as_mut_ptr(), n as i64, 3)
                },
                atomic_numbers: unsafe {
                    rgpot_tensor_cpu_i32_1d(self.atmnrs.as_mut_ptr(), n as i64)
                },
                box_matrix: unsafe { rgpot_tensor_cpu_f64_matrix3(self.box_.as_mut_ptr()) },
            }
        }

        /// Releases the three tensor handles created by `make_input`.
        ///
        /// # Safety
        ///
        /// `input` must have come from `make_input` and must not have been
        /// freed already.
        unsafe fn free_input(&self, input: &rgpot_force_input_t) {
            unsafe {
                rgpot_tensor_free(input.positions);
                rgpot_tensor_free(input.atomic_numbers);
                rgpot_tensor_free(input.box_matrix);
            }
        }
    }

    /// Output slot with null forces and zeroed energy and variance.
    fn empty_output() -> rgpot_force_out_t {
        rgpot_force_out_t {
            forces: std::ptr::null_mut(),
            energy: 0.0,
            variance: 0.0,
        }
    }

    /// Reports energy = number of atoms and zero forces shaped `(n, 3)`.
    unsafe extern "C" fn mock_callback(
        _user_data: *mut c_void,
        input: *const rgpot_force_input_t,
        output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        let inp = unsafe { &*input };
        let out = unsafe { &mut *output };
        let n = unsafe { inp.n_atoms() }.unwrap_or(0);
        out.energy = n as f64;
        out.variance = 0.0;
        out.forces = create_owned_f64_tensor(vec![0.0; n * 3], vec![n as i64, 3]);
        rgpot_status_t::RGPOT_SUCCESS
    }

    /// Always fails without touching the output.
    unsafe extern "C" fn error_callback(
        _ud: *mut c_void,
        _input: *const rgpot_force_input_t,
        _output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        rgpot_status_t::RGPOT_INVALID_PARAMETER
    }

    /// Writes forces `1..=3n`, energy -42.0 and variance 0.5.
    unsafe extern "C" fn writing_callback(
        _ud: *mut c_void,
        input: *const rgpot_force_input_t,
        output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        let inp = unsafe { &*input };
        let out = unsafe { &mut *output };
        let n = unsafe { inp.n_atoms() }.unwrap_or(0);
        let forces: Vec<f64> = (1..=(n * 3)).map(|i| i as f64).collect();
        out.forces = create_owned_f64_tensor(forces, vec![n as i64, 3]);
        out.energy = -42.0;
        out.variance = 0.5;
        rgpot_status_t::RGPOT_SUCCESS
    }

    #[test]
    fn test_potential_callback() {
        let pot = PotentialImpl::new(mock_callback, std::ptr::null_mut(), None);

        let mut io = TestIO::new(3);
        let input = io.make_input();
        let mut output = empty_output();

        let status = unsafe { pot.calculate(&input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(output.energy, 3.0);

        unsafe {
            rgpot_tensor_free(output.forces);
            io.free_input(&input);
        }
    }

    #[test]
    fn callback_error_status_is_returned() {
        let pot = PotentialImpl::new(error_callback, std::ptr::null_mut(), None);
        let mut io = TestIO::new(1);
        let input = io.make_input();
        let mut output = empty_output();

        let status = unsafe { pot.calculate(&input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_INVALID_PARAMETER);
        // `calculate` must not write anything on its own behalf.
        assert!(output.forces.is_null());
        assert_eq!(output.energy, 0.0);
        assert_eq!(output.variance, 0.0);

        unsafe { io.free_input(&input) };
    }

    #[test]
    fn callback_writes_forces_energy_variance() {
        let pot = PotentialImpl::new(writing_callback, std::ptr::null_mut(), None);
        let mut io = TestIO::new(2);
        let input = io.make_input();
        let mut output = empty_output();

        let status = unsafe { pot.calculate(&input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(output.energy, -42.0);
        assert_eq!(output.variance, 0.5);

        assert!(!output.forces.is_null());
        let ft = unsafe { &(*output.forces).dl_tensor };
        let forces = unsafe { std::slice::from_raw_parts(ft.data as *const f64, 6) };
        assert_eq!(forces, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        unsafe {
            rgpot_tensor_free(output.forces);
            io.free_input(&input);
        }
    }

    const SENTINEL: u64 = 0xDEAD_BEEF;

    static DROP_CALLED: AtomicBool = AtomicBool::new(false);
    static DROP_SAW_SENTINEL: AtomicBool = AtomicBool::new(false);

    /// Records that it ran and whether `ptr` pointed at `SENTINEL`.
    ///
    /// It deliberately does not assert. A panic cannot unwind out of an
    /// `extern "C"` function, so a failing assert would abort the whole test
    /// binary instead of failing one test. The checks live in the caller.
    unsafe extern "C" fn track_drop(ptr: *mut c_void) {
        let val = unsafe { *(ptr as *const u64) };
        DROP_SAW_SENTINEL.store(val == SENTINEL, Ordering::SeqCst);
        DROP_CALLED.store(true, Ordering::SeqCst);
    }

    #[test]
    fn drop_calls_free_fn_with_user_data() {
        DROP_CALLED.store(false, Ordering::SeqCst);
        DROP_SAW_SENTINEL.store(false, Ordering::SeqCst);
        let mut sentinel: u64 = SENTINEL;
        {
            let _pot = PotentialImpl::new(
                mock_callback,
                &mut sentinel as *mut u64 as *mut c_void,
                Some(track_drop),
            );
            assert!(!DROP_CALLED.load(Ordering::SeqCst));
        }
        assert!(DROP_CALLED.load(Ordering::SeqCst));
        assert!(DROP_SAW_SENTINEL.load(Ordering::SeqCst));
    }

    /// Flag owned solely by `drop_skips_free_fn_when_user_data_is_null`, so
    /// parallel tests cannot interfere with it.
    static NULL_FREE_CALLED: AtomicBool = AtomicBool::new(false);

    /// Records that it ran without dereferencing `_ptr`.
    /// If the null guard regresses, the test fails cleanly instead of hitting UB.
    unsafe extern "C" fn track_null_free(_ptr: *mut c_void) {
        NULL_FREE_CALLED.store(true, Ordering::SeqCst);
    }

    #[test]
    fn drop_skips_free_fn_when_user_data_is_null() {
        drop(PotentialImpl::new(
            mock_callback,
            std::ptr::null_mut(),
            Some(track_null_free),
        ));
        assert!(!NULL_FREE_CALLED.load(Ordering::SeqCst));
    }

    #[test]
    fn drop_without_free_fn_is_safe() {
        let _pot = PotentialImpl::new(mock_callback, std::ptr::null_mut(), None);
    }

    #[test]
    fn user_data_is_passed_through() {
        static CALL_COUNT: AtomicU32 = AtomicU32::new(0);

        /// Increments the counter behind `ud` and reports the new count as energy.
        unsafe extern "C" fn count_cb(
            ud: *mut c_void,
            _input: *const rgpot_force_input_t,
            output: *mut rgpot_force_out_t,
        ) -> rgpot_status_t {
            let ctr = unsafe { &*(ud as *const AtomicU32) };
            ctr.fetch_add(1, Ordering::SeqCst);
            unsafe { (*output).energy = ctr.load(Ordering::SeqCst) as f64 };
            rgpot_status_t::RGPOT_SUCCESS
        }

        CALL_COUNT.store(0, Ordering::SeqCst);
        let pot = PotentialImpl::new(count_cb, &CALL_COUNT as *const _ as *mut c_void, None);

        let mut io = TestIO::new(1);
        let input = io.make_input();
        let mut output = empty_output();

        let status = unsafe { pot.calculate(&input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(CALL_COUNT.load(Ordering::SeqCst), 1);
        assert_eq!(output.energy, 1.0);

        let status = unsafe { pot.calculate(&input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(CALL_COUNT.load(Ordering::SeqCst), 2);
        assert_eq!(output.energy, 2.0);

        unsafe { io.free_input(&input) };
    }

    #[test]
    fn potential_impl_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<PotentialImpl>();
    }
}