1use std::os::raw::c_void;
29
30use crate::potential::{PotentialCallback, PotentialImpl, rgpot_potential_t};
31use crate::status::{catch_unwind, rgpot_status_t, set_last_error};
32use crate::types::{rgpot_force_input_t, rgpot_force_out_t};
33
34#[no_mangle]
46pub unsafe extern "C" fn rgpot_potential_new(
47 callback: PotentialCallback,
48 user_data: *mut c_void,
49 free_fn: Option<unsafe extern "C" fn(*mut c_void)>,
50) -> *mut rgpot_potential_t {
51 let pot = PotentialImpl::new(callback, user_data, free_fn);
52 Box::into_raw(Box::new(pot))
53}
54
55#[no_mangle]
65pub unsafe extern "C" fn rgpot_potential_calculate(
66 pot: *const rgpot_potential_t,
67 input: *const rgpot_force_input_t,
68 output: *mut rgpot_force_out_t,
69) -> rgpot_status_t {
70 catch_unwind(std::panic::AssertUnwindSafe(|| {
71 if pot.is_null() {
72 set_last_error("rgpot_potential_calculate: pot is NULL");
73 return rgpot_status_t::RGPOT_INVALID_PARAMETER;
74 }
75 if input.is_null() {
76 set_last_error("rgpot_potential_calculate: input is NULL");
77 return rgpot_status_t::RGPOT_INVALID_PARAMETER;
78 }
79 if output.is_null() {
80 set_last_error("rgpot_potential_calculate: output is NULL");
81 return rgpot_status_t::RGPOT_INVALID_PARAMETER;
82 }
83
84 let pot_ref = unsafe { &*pot };
85 unsafe { pot_ref.calculate(input, output) }
86 }))
87}
88
89#[no_mangle]
94pub unsafe extern "C" fn rgpot_potential_free(pot: *mut rgpot_potential_t) {
95 if !pot.is_null() {
96 drop(unsafe { Box::from_raw(pot) });
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{create_owned_f64_tensor, rgpot_tensor_free};
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    /// Callback that sets `energy` to the atom count, sets `variance` to zero,
    /// and returns a zero-filled `(n, 3)` force tensor.
    unsafe extern "C" fn sum_callback(
        _ud: *mut c_void,
        input: *const rgpot_force_input_t,
        output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        let inp = unsafe { &*input };
        let out = unsafe { &mut *output };
        let n = unsafe { inp.n_atoms() }.unwrap_or(0);
        out.energy = n as f64;
        out.variance = 0.0;
        out.forces = create_owned_f64_tensor(vec![0.0; n * 3], vec![n as i64, 3]);
        rgpot_status_t::RGPOT_SUCCESS
    }

    /// Callback that records an error message and reports `RGPOT_INTERNAL_ERROR`.
    unsafe extern "C" fn failing_callback(
        _ud: *mut c_void,
        _input: *const rgpot_force_input_t,
        _output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        crate::status::set_last_error("deliberately failed");
        rgpot_status_t::RGPOT_INTERNAL_ERROR
    }

    /// Callback that treats `ud` as an `AtomicU32`, increments it, and writes
    /// the new count into `energy`. It leaves `forces` untouched, so no tensor
    /// needs freeing.
    unsafe extern "C" fn counting_callback(
        ud: *mut c_void,
        _input: *const rgpot_force_input_t,
        output: *mut rgpot_force_out_t,
    ) -> rgpot_status_t {
        let counter = unsafe { &*(ud as *const AtomicU32) };
        counter.fetch_add(1, Ordering::SeqCst);
        let out = unsafe { &mut *output };
        out.energy = counter.load(Ordering::SeqCst) as f64;
        rgpot_status_t::RGPOT_SUCCESS
    }

    /// Returns an output struct with no force tensor and zeroed scalars.
    fn empty_output() -> rgpot_force_out_t {
        rgpot_force_out_t {
            forces: std::ptr::null_mut(),
            energy: 0.0,
            variance: 0.0,
        }
    }

    /// Builds a two-atom input together with the buffers it was created from.
    ///
    /// `rgpot_force_input_create` receives raw pointers to these buffers, and
    /// the input presumably refers to them rather than copying them. That is
    /// why callers keep the returned buffers bound for the input's lifetime.
    /// The buffers are heap-allocated `Vec`s so that moving them into the
    /// returned tuple leaves their data in place. Stack arrays would be
    /// *copied* on return, leaving the input pointing into this function's
    /// dead stack frame.
    fn make_test_input() -> (Vec<f64>, Vec<i32>, Vec<f64>, rgpot_force_input_t) {
        let mut pos = vec![0.0_f64, 0.0, 0.0, 1.0, 0.0, 0.0];
        let mut atmnrs = vec![1_i32, 1];
        let mut box_ = vec![10.0_f64, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0];
        let input = unsafe {
            crate::c_api::types::rgpot_force_input_create(
                atmnrs.len(),
                pos.as_mut_ptr(),
                atmnrs.as_mut_ptr(),
                box_.as_mut_ptr(),
            )
        };
        (pos, atmnrs, box_, input)
    }

    /// Frees the force tensor in `output` and the resources owned by `input`.
    /// Then resets `output.forces` to NULL so it cannot be freed twice.
    unsafe fn cleanup(input: &mut rgpot_force_input_t, output: &mut rgpot_force_out_t) {
        unsafe {
            rgpot_tensor_free(output.forces);
            crate::c_api::types::rgpot_force_input_free(input);
        }
        output.forces = std::ptr::null_mut();
    }

    /// A freshly created handle is never NULL.
    #[test]
    fn new_returns_non_null() {
        let pot = unsafe { rgpot_potential_new(sum_callback, std::ptr::null_mut(), None) };
        assert!(!pot.is_null());
        unsafe { rgpot_potential_free(pot) };
    }

    /// Freeing a NULL handle is accepted and does nothing.
    #[test]
    fn free_null_is_noop() {
        unsafe { rgpot_potential_free(std::ptr::null_mut()) };
    }

    /// A NULL handle is rejected before any callback runs.
    #[test]
    fn calculate_null_pot_returns_invalid_parameter() {
        let (_pos, _atmnrs, _box_, mut input) = make_test_input();
        let mut output = empty_output();

        let status = unsafe { rgpot_potential_calculate(std::ptr::null(), &input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_INVALID_PARAMETER);
        unsafe { crate::c_api::types::rgpot_force_input_free(&mut input) };
    }

    /// A NULL input pointer is rejected.
    #[test]
    fn calculate_null_input_returns_invalid_parameter() {
        let pot = unsafe { rgpot_potential_new(sum_callback, std::ptr::null_mut(), None) };
        let mut output = empty_output();

        let status = unsafe { rgpot_potential_calculate(pot, std::ptr::null(), &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_INVALID_PARAMETER);
        unsafe { rgpot_potential_free(pot) };
    }

    /// A NULL output pointer is rejected.
    #[test]
    fn calculate_null_output_returns_invalid_parameter() {
        let pot = unsafe { rgpot_potential_new(sum_callback, std::ptr::null_mut(), None) };
        let (_pos, _atmnrs, _box_, mut input) = make_test_input();

        let status = unsafe { rgpot_potential_calculate(pot, &input, std::ptr::null_mut()) };
        assert_eq!(status, rgpot_status_t::RGPOT_INVALID_PARAMETER);
        unsafe {
            crate::c_api::types::rgpot_force_input_free(&mut input);
            rgpot_potential_free(pot);
        }
    }

    /// Tests create, calculate and free end to end with a successful callback.
    #[test]
    fn full_lifecycle_new_calculate_free() {
        let pot = unsafe { rgpot_potential_new(sum_callback, std::ptr::null_mut(), None) };
        let (_pos, _atmnrs, _box_, mut input) = make_test_input();
        let mut output = empty_output();

        let status = unsafe { rgpot_potential_calculate(pot, &input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(output.energy, 2.0);
        assert!(!output.forces.is_null());

        unsafe {
            cleanup(&mut input, &mut output);
            rgpot_potential_free(pot);
        }
    }

    /// A failing callback's status and last-error message reach the caller.
    #[test]
    fn callback_error_propagates() {
        let pot = unsafe { rgpot_potential_new(failing_callback, std::ptr::null_mut(), None) };
        let (_pos, _atmnrs, _box_, mut input) = make_test_input();
        let mut output = empty_output();

        let status = unsafe { rgpot_potential_calculate(pot, &input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_INTERNAL_ERROR);

        let msg = unsafe { std::ffi::CStr::from_ptr(crate::status::rgpot_last_error()) };
        assert_eq!(msg.to_str().unwrap(), "deliberately failed");

        unsafe {
            crate::c_api::types::rgpot_force_input_free(&mut input);
            rgpot_potential_free(pot);
        }
    }

    /// The `user_data` pointer given at creation reaches the callback on every call.
    #[test]
    fn user_data_is_forwarded_to_callback() {
        let counter = AtomicU32::new(0);
        let pot = unsafe {
            rgpot_potential_new(
                counting_callback,
                &counter as *const _ as *mut c_void,
                None,
            )
        };

        let (_pos, _atmnrs, _box_, mut input) = make_test_input();
        let mut output = empty_output();

        let status = unsafe { rgpot_potential_calculate(pot, &input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert_eq!(output.energy, 1.0);

        let status = unsafe { rgpot_potential_calculate(pot, &input, &mut output) };
        assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        assert_eq!(output.energy, 2.0);

        unsafe {
            crate::c_api::types::rgpot_force_input_free(&mut input);
            rgpot_potential_free(pot);
        }
    }

    /// Set by `track_free`; only `free_fn_is_invoked_on_drop` reads or writes it.
    static FREE_CALLED: AtomicBool = AtomicBool::new(false);

    /// `free_fn` stand-in that only records that it was called.
    unsafe extern "C" fn track_free(_ptr: *mut c_void) {
        FREE_CALLED.store(true, Ordering::SeqCst);
    }

    /// `free_fn` runs when the handle is freed, and not before.
    #[test]
    fn free_fn_is_invoked_on_drop() {
        FREE_CALLED.store(false, Ordering::SeqCst);

        let mut dummy: u8 = 42;
        let pot = unsafe {
            rgpot_potential_new(
                sum_callback,
                &mut dummy as *mut u8 as *mut c_void,
                Some(track_free),
            )
        };

        assert!(!FREE_CALLED.load(Ordering::SeqCst));
        unsafe { rgpot_potential_free(pot) };
        assert!(FREE_CALLED.load(Ordering::SeqCst));
    }

    /// One handle can be reused across inputs with different atom counts.
    #[test]
    fn multiple_calculations_same_handle() {
        let pot = unsafe { rgpot_potential_new(sum_callback, std::ptr::null_mut(), None) };

        for n in 1..=5_usize {
            // These buffers live until the end of the iteration, after `cleanup`.
            let mut pos = vec![0.0_f64; n * 3];
            let mut atmnrs = vec![1_i32; n];
            let mut box_ = [10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0];

            let mut input = unsafe {
                crate::c_api::types::rgpot_force_input_create(
                    n,
                    pos.as_mut_ptr(),
                    atmnrs.as_mut_ptr(),
                    box_.as_mut_ptr(),
                )
            };
            let mut output = empty_output();

            let status = unsafe { rgpot_potential_calculate(pot, &input, &mut output) };
            assert_eq!(status, rgpot_status_t::RGPOT_SUCCESS);
            assert_eq!(output.energy, n as f64);

            unsafe { cleanup(&mut input, &mut output) };
        }

        unsafe { rgpot_potential_free(pot) };
    }
}