1#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
2use std::ffi::CStr;
3#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
4use std::os::raw::c_char;
5use std::ptr;
6
7use crate::c_api_utils::{DIFFSOL_BAD_ARG, DIFFSOL_ERR, DIFFSOL_OK, valid_f64_ptr};
8use crate::host_array::HostArray;
9#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
10use crate::jit_c::jit_backend_from_i32;
11use crate::linear_solver_type_c::{linear_solver_from_i32, linear_solver_to_i32};
12use crate::matrix_type_c::{matrix_type_from_i32, matrix_type_to_i32};
13use crate::ode::OdeWrapper;
14use crate::ode_solver_type_c::{ode_solver_from_i32, ode_solver_to_i32};
15use crate::scalar_type::ScalarType;
16use crate::solution_wrapper::SolutionWrapper;
17use crate::{c_error, c_invalid_arg};
18
19fn boxed_host_array(array: HostArray) -> *mut HostArray {
20 Box::into_raw(Box::new(array))
21}
22
23fn parse_ode_new_common_args(
24 matrix_type: i32,
25 linear_solver: i32,
26 ode_solver: i32,
27) -> Option<(
28 crate::matrix_type::MatrixType,
29 crate::linear_solver_type::LinearSolverType,
30 crate::ode_solver_type::OdeSolverType,
31)> {
32 let matrix_type = match matrix_type_from_i32(matrix_type) {
33 Some(value) => value,
34 None => {
35 c_invalid_arg!("invalid matrix_type");
36 return None;
37 }
38 };
39 let linear_solver = match linear_solver_from_i32(linear_solver) {
40 Some(value) => value,
41 None => {
42 c_invalid_arg!("invalid linear_solver");
43 return None;
44 }
45 };
46 let ode_solver = match ode_solver_from_i32(ode_solver) {
47 Some(value) => value,
48 None => {
49 c_invalid_arg!("invalid ode_solver");
50 return None;
51 }
52 };
53 Some((matrix_type, linear_solver, ode_solver))
54}
55
/// Validates and decodes the arguments of `diffsol_ode_new_jit`, except the JIT backend.
///
/// `code` is copied into an owned `String`. The function returns `None` after recording an
/// error when `code` is null (`c_invalid_arg!`), when it is not valid UTF-8 (`c_error!`), or
/// when any enum value fails to decode.
///
/// NOTE(review): this is a safe fn that dereferences `code`. Callers must pass either null or
/// a NUL-terminated string that stays valid for the duration of the call.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn parse_ode_new_jit_args(
    code: *const c_char,
    matrix_type: i32,
    linear_solver: i32,
    ode_solver: i32,
) -> Option<(
    String,
    crate::matrix_type::MatrixType,
    crate::linear_solver_type::LinearSolverType,
    crate::ode_solver_type::OdeSolverType,
)> {
    if code.is_null() {
        c_invalid_arg!("code is null");
        return None;
    }
    // SAFETY: non-null was checked above; NUL-termination and lifetime are the caller's duty.
    let decoded = unsafe { CStr::from_ptr(code) }.to_str().map(str::to_owned);
    let Ok(source) = decoded else {
        c_error!("code is not valid UTF-8");
        return None;
    };
    parse_ode_new_common_args(matrix_type, linear_solver, ode_solver)
        .map(|(matrix, linear, solver)| (source, matrix, linear, solver))
}
84
85#[unsafe(no_mangle)]
91pub unsafe extern "C" fn diffsol_host_array_list_free(list: *mut *mut HostArray, len: usize) {
92 if list.is_null() {
93 c_invalid_arg!("host array list is null");
94 return;
95 }
96 unsafe {
97 drop(Vec::from_raw_parts(list, len, len));
98 }
99}
100
/// Copies `len` `(usize, usize)` dependency pairs out of a flat C array.
///
/// The input is read as `2 * len` consecutive `usize` values `[a0, b0, a1, b1, ...]` rather
/// than as `(usize, usize)` tuples, because Rust does not guarantee the memory layout of
/// tuples. Behavior by case:
/// - `len == 0`: returns `Some(vec![])` and ignores the pointer, so null is allowed.
/// - `ptr` is null with a non-zero `len`, or `2 * len` overflows: returns `None`.
///
/// # Safety
///
/// When `len > 0` and `ptr` is non-null, `ptr` must be aligned and valid for reads of
/// `2 * len` `usize` values.
#[cfg(feature = "external")]
unsafe fn read_dep_pairs(ptr: *const usize, len: usize) -> Option<Vec<(usize, usize)>> {
    if len == 0 {
        return Some(Vec::new());
    }
    if ptr.is_null() {
        return None;
    }
    let flat_len = len.checked_mul(2)?;
    // SAFETY: `ptr` is non-null, and the caller guarantees `flat_len` readable `usize`s.
    let flat = unsafe { std::slice::from_raw_parts(ptr, flat_len) };
    Some(flat.chunks_exact(2).map(|pair| (pair[0], pair[1])).collect())
}

/// Creates an ODE backed by externally linked model functions.
///
/// Each `*_deps_ptr` / `*_deps_len` pair describes `len` dependency pairs, stored as a flat
/// array of `2 * len` `usize` values. A zero length means "no dependencies", and the pointer
/// may then be null.
///
/// On success, returns a heap-allocated `OdeWrapper`; release it with `diffsol_ode_free`.
/// Returns null after recording an error when:
/// - an enum value is invalid,
/// - a dependency pointer is null while its length is non-zero, or
/// - `OdeWrapper::new_external` fails.
///
/// # Safety
///
/// Every non-null dependency pointer with a non-zero length must be aligned and valid for
/// reads of `2 * len` `usize` values.
#[cfg(feature = "external")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn diffsol_ode_new_external(
    matrix_type: i32,
    linear_solver: i32,
    ode_solver: i32,
    rhs_state_deps_ptr: *const usize,
    rhs_state_deps_len: usize,
    rhs_input_deps_ptr: *const usize,
    rhs_input_deps_len: usize,
    mass_state_deps_ptr: *const usize,
    mass_state_deps_len: usize,
) -> *mut OdeWrapper {
    let Some((matrix_type, linear_solver, ode_solver)) =
        parse_ode_new_common_args(matrix_type, linear_solver, ode_solver)
    else {
        return ptr::null_mut();
    };

    // SAFETY (all three reads): pointer validity is forwarded from this function's contract.
    let Some(rhs_state_deps) =
        (unsafe { read_dep_pairs(rhs_state_deps_ptr, rhs_state_deps_len) })
    else {
        c_invalid_arg!("invalid rhs_state_deps");
        return ptr::null_mut();
    };
    let Some(rhs_input_deps) =
        (unsafe { read_dep_pairs(rhs_input_deps_ptr, rhs_input_deps_len) })
    else {
        c_invalid_arg!("invalid rhs_input_deps");
        return ptr::null_mut();
    };
    let Some(mass_state_deps) =
        (unsafe { read_dep_pairs(mass_state_deps_ptr, mass_state_deps_len) })
    else {
        c_invalid_arg!("invalid mass_state_deps");
        return ptr::null_mut();
    };

    let scalar_type = ScalarType::F64;
    match OdeWrapper::new_external(
        rhs_state_deps,
        rhs_input_deps,
        mass_state_deps,
        scalar_type,
        matrix_type,
        linear_solver,
        ode_solver,
    ) {
        Ok(ode) => Box::into_raw(Box::new(ode)),
        Err(err) => {
            c_error!(&format!("{}", err));
            ptr::null_mut()
        }
    }
}
179
/// JIT-compiles DiffSL `code` with the selected backend and wraps it as an ODE.
///
/// On success, returns a heap-allocated `OdeWrapper`; release it with `diffsol_ode_free`.
/// Returns null after recording an error when:
/// - `code` is null or not UTF-8,
/// - any enum value, including `jit_backend`, is invalid, or
/// - `OdeWrapper::new_jit` fails.
///
/// # Safety
///
/// `code` must be null or point to a NUL-terminated string that stays valid for the call.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn diffsol_ode_new_jit(
    code: *const c_char,
    jit_backend: i32,
    matrix_type: i32,
    linear_solver: i32,
    ode_solver: i32,
) -> *mut OdeWrapper {
    let parsed = parse_ode_new_jit_args(code, matrix_type, linear_solver, ode_solver);
    let Some((source, matrix, linear, solver)) = parsed else {
        return ptr::null_mut();
    };
    let Some(backend) = jit_backend_from_i32(jit_backend) else {
        c_invalid_arg!("invalid jit_backend_type");
        return ptr::null_mut();
    };
    let built = OdeWrapper::new_jit(&source, backend, ScalarType::F64, matrix, linear, solver);
    match built {
        Err(err) => {
            c_error!(&err.to_string());
            ptr::null_mut()
        }
        Ok(ode) => Box::into_raw(Box::new(ode)),
    }
}
223
224#[unsafe(no_mangle)]
230pub unsafe extern "C" fn diffsol_ode_free(ode: *mut OdeWrapper) {
231 if ode.is_null() {
232 c_invalid_arg!("ode is null");
233 return;
234 }
235 unsafe {
236 drop(Box::from_raw(ode));
237 }
238}
239
240#[unsafe(no_mangle)]
247pub unsafe extern "C" fn diffsol_ode_get_ic_options(
248 ode: *const OdeWrapper,
249 out_options: *mut *mut crate::initial_condition_options::InitialConditionSolverOptions,
250) -> i32 {
251 if ode.is_null() || out_options.is_null() {
252 return c_invalid_arg!("invalid arguments to diffsol_ode_get_ic_options");
253 }
254 let ode = unsafe { &*ode };
255 let options = ode.get_ic_options();
256 let boxed = Box::new(options);
257 unsafe {
258 *out_options = Box::into_raw(boxed);
259 }
260 DIFFSOL_OK
261}
262
263#[unsafe(no_mangle)]
270pub unsafe extern "C" fn diffsol_ode_get_options(
271 ode: *const OdeWrapper,
272 out_options: *mut *mut crate::ode_options::OdeSolverOptions,
273) -> i32 {
274 if ode.is_null() || out_options.is_null() {
275 return c_invalid_arg!("invalid arguments to diffsol_ode_get_options");
276 }
277 let ode = unsafe { &*ode };
278 let options = ode.get_options();
279 let boxed = Box::new(options);
280 unsafe {
281 *out_options = Box::into_raw(boxed);
282 }
283 DIFFSOL_OK
284}
285
286#[unsafe(no_mangle)]
293pub unsafe extern "C" fn diffsol_ode_y0(
294 ode: *mut OdeWrapper,
295 params_ptr: *const f64,
296 params_len: usize,
297 out_array: *mut *mut HostArray,
298) -> i32 {
299 if ode.is_null() || out_array.is_null() || !valid_f64_ptr(params_ptr, params_len) {
300 c_invalid_arg!("invalid arguments to diffsol_ode_y0");
301 return DIFFSOL_BAD_ARG;
302 }
303 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
304 let ode = unsafe { &mut *ode };
305 match ode.y0(params) {
306 Ok(array) => {
307 let boxed = boxed_host_array(array);
308 unsafe {
309 *out_array = boxed;
310 }
311 DIFFSOL_OK
312 }
313 Err(err) => {
314 c_error!(&format!("{}", err));
315 DIFFSOL_ERR
316 }
317 }
318}
319
320#[unsafe(no_mangle)]
327pub unsafe extern "C" fn diffsol_ode_rhs(
328 ode: *mut OdeWrapper,
329 params_ptr: *const f64,
330 params_len: usize,
331 t: f64,
332 y_ptr: *const f64,
333 y_len: usize,
334 out_array: *mut *mut HostArray,
335) -> i32 {
336 if ode.is_null()
337 || out_array.is_null()
338 || !valid_f64_ptr(params_ptr, params_len)
339 || !valid_f64_ptr(y_ptr, y_len)
340 {
341 c_invalid_arg!("invalid arguments to diffsol_ode_rhs");
342 return DIFFSOL_BAD_ARG;
343 }
344 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
345 let y = HostArray::new_vector(y_ptr as *mut u8, y_len, ScalarType::F64);
346 let ode = unsafe { &mut *ode };
347 match ode.rhs(params, t, y) {
348 Ok(array) => {
349 let boxed = boxed_host_array(array);
350 unsafe {
351 *out_array = boxed;
352 }
353 DIFFSOL_OK
354 }
355 Err(err) => {
356 c_error!(&format!("{}", err));
357 DIFFSOL_ERR
358 }
359 }
360}
361
362#[unsafe(no_mangle)]
370pub unsafe extern "C" fn diffsol_ode_rhs_jac_mul(
371 ode: *mut OdeWrapper,
372 params_ptr: *const f64,
373 params_len: usize,
374 t: f64,
375 y_ptr: *const f64,
376 y_len: usize,
377 v_ptr: *const f64,
378 v_len: usize,
379 out_array: *mut *mut HostArray,
380) -> i32 {
381 if ode.is_null()
382 || out_array.is_null()
383 || !valid_f64_ptr(params_ptr, params_len)
384 || !valid_f64_ptr(y_ptr, y_len)
385 || !valid_f64_ptr(v_ptr, v_len)
386 {
387 c_invalid_arg!("invalid arguments to diffsol_ode_rhs_jac_mul");
388 return DIFFSOL_BAD_ARG;
389 }
390 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
391 let y = HostArray::new_vector(y_ptr as *mut u8, y_len, ScalarType::F64);
392 let v = HostArray::new_vector(v_ptr as *mut u8, v_len, ScalarType::F64);
393 let ode = unsafe { &mut *ode };
394 match ode.rhs_jac_mul(params, t, y, v) {
395 Ok(array) => {
396 let boxed = boxed_host_array(array);
397 unsafe {
398 *out_array = boxed;
399 }
400 DIFFSOL_OK
401 }
402 Err(err) => {
403 c_error!(&format!("{}", err));
404 DIFFSOL_ERR
405 }
406 }
407}
408
409#[unsafe(no_mangle)]
416pub unsafe extern "C" fn diffsol_ode_solve(
417 ode: *mut OdeWrapper,
418 params_ptr: *const f64,
419 params_len: usize,
420 final_time: f64,
421 out_solution: *mut *mut SolutionWrapper,
422) -> i32 {
423 if ode.is_null() || out_solution.is_null() || !valid_f64_ptr(params_ptr, params_len) {
424 c_invalid_arg!("invalid arguments to diffsol_ode_solve");
425 return DIFFSOL_BAD_ARG;
426 }
427 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
428 let ode = unsafe { &mut *ode };
429 match ode.solve(params, final_time) {
430 Ok(new_solution) => {
431 unsafe {
432 *out_solution = Box::into_raw(Box::new(new_solution));
433 }
434 DIFFSOL_OK
435 }
436 Err(err) => {
437 c_error!(&format!("{}", err));
438 DIFFSOL_ERR
439 }
440 }
441}
442
443#[unsafe(no_mangle)]
450pub unsafe extern "C" fn diffsol_ode_solve_hybrid(
451 ode: *mut OdeWrapper,
452 params_ptr: *const f64,
453 params_len: usize,
454 final_time: f64,
455 out_solution: *mut *mut SolutionWrapper,
456) -> i32 {
457 if ode.is_null() || out_solution.is_null() || !valid_f64_ptr(params_ptr, params_len) {
458 c_invalid_arg!("invalid arguments to diffsol_ode_solve_hybrid");
459 return DIFFSOL_BAD_ARG;
460 }
461 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
462 let ode = unsafe { &mut *ode };
463 match ode.solve_hybrid(params, final_time) {
464 Ok(new_solution) => {
465 unsafe {
466 *out_solution = Box::into_raw(Box::new(new_solution));
467 }
468 DIFFSOL_OK
469 }
470 Err(err) => {
471 c_error!(&format!("{}", err));
472 DIFFSOL_ERR
473 }
474 }
475}
476
477#[unsafe(no_mangle)]
484pub unsafe extern "C" fn diffsol_ode_solve_dense(
485 ode: *mut OdeWrapper,
486 params_ptr: *const f64,
487 params_len: usize,
488 t_eval_ptr: *const f64,
489 t_eval_len: usize,
490 out_solution: *mut *mut SolutionWrapper,
491) -> i32 {
492 if ode.is_null()
493 || out_solution.is_null()
494 || !valid_f64_ptr(params_ptr, params_len)
495 || !valid_f64_ptr(t_eval_ptr, t_eval_len)
496 {
497 c_invalid_arg!("invalid arguments to diffsol_ode_solve_dense");
498 return DIFFSOL_BAD_ARG;
499 }
500 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
501 let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
502 let ode = unsafe { &mut *ode };
503 match ode.solve_dense(params, t_eval) {
504 Ok(new_solution) => {
505 unsafe {
506 *out_solution = Box::into_raw(Box::new(new_solution));
507 }
508 DIFFSOL_OK
509 }
510 Err(err) => {
511 c_error!(&format!("{}", err));
512 DIFFSOL_ERR
513 }
514 }
515}
516
517#[unsafe(no_mangle)]
524pub unsafe extern "C" fn diffsol_ode_solve_hybrid_dense(
525 ode: *mut OdeWrapper,
526 params_ptr: *const f64,
527 params_len: usize,
528 t_eval_ptr: *const f64,
529 t_eval_len: usize,
530 out_solution: *mut *mut SolutionWrapper,
531) -> i32 {
532 if ode.is_null()
533 || out_solution.is_null()
534 || !valid_f64_ptr(params_ptr, params_len)
535 || !valid_f64_ptr(t_eval_ptr, t_eval_len)
536 {
537 c_invalid_arg!("invalid arguments to diffsol_ode_solve_hybrid_dense");
538 return DIFFSOL_BAD_ARG;
539 }
540 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
541 let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
542 let ode = unsafe { &mut *ode };
543 match ode.solve_hybrid_dense(params, t_eval) {
544 Ok(new_solution) => {
545 unsafe {
546 *out_solution = Box::into_raw(Box::new(new_solution));
547 }
548 DIFFSOL_OK
549 }
550 Err(err) => {
551 c_error!(&format!("{}", err));
552 DIFFSOL_ERR
553 }
554 }
555}
556
557#[unsafe(no_mangle)]
564pub unsafe extern "C" fn diffsol_ode_solve_fwd_sens(
565 ode: *mut OdeWrapper,
566 params_ptr: *const f64,
567 params_len: usize,
568 t_eval_ptr: *const f64,
569 t_eval_len: usize,
570 out_solution: *mut *mut SolutionWrapper,
571) -> i32 {
572 if ode.is_null()
573 || out_solution.is_null()
574 || !valid_f64_ptr(params_ptr, params_len)
575 || !valid_f64_ptr(t_eval_ptr, t_eval_len)
576 {
577 c_invalid_arg!("invalid arguments to diffsol_ode_solve_fwd_sens");
578 return DIFFSOL_BAD_ARG;
579 }
580 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
581 let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
582 let ode = unsafe { &mut *ode };
583 match ode.solve_fwd_sens(params, t_eval) {
584 Ok(new_solution) => {
585 unsafe {
586 *out_solution = Box::into_raw(Box::new(new_solution));
587 }
588 DIFFSOL_OK
589 }
590 Err(err) => {
591 c_error!(&format!("{}", err));
592 DIFFSOL_ERR
593 }
594 }
595}
596
597#[unsafe(no_mangle)]
604pub unsafe extern "C" fn diffsol_ode_solve_hybrid_fwd_sens(
605 ode: *mut OdeWrapper,
606 params_ptr: *const f64,
607 params_len: usize,
608 t_eval_ptr: *const f64,
609 t_eval_len: usize,
610 out_solution: *mut *mut SolutionWrapper,
611) -> i32 {
612 if ode.is_null()
613 || out_solution.is_null()
614 || !valid_f64_ptr(params_ptr, params_len)
615 || !valid_f64_ptr(t_eval_ptr, t_eval_len)
616 {
617 c_invalid_arg!("invalid arguments to diffsol_ode_solve_hybrid_fwd_sens");
618 return DIFFSOL_BAD_ARG;
619 }
620 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
621 let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
622 let ode = unsafe { &mut *ode };
623 match ode.solve_hybrid_fwd_sens(params, t_eval) {
624 Ok(new_solution) => {
625 unsafe {
626 *out_solution = Box::into_raw(Box::new(new_solution));
627 }
628 DIFFSOL_OK
629 }
630 Err(err) => {
631 c_error!(&format!("{}", err));
632 DIFFSOL_ERR
633 }
634 }
635}
636
637#[unsafe(no_mangle)]
645pub unsafe extern "C" fn diffsol_ode_solve_sum_squares_adj(
646 ode: *mut OdeWrapper,
647 params_ptr: *const f64,
648 params_len: usize,
649 data_ptr: *const f64,
650 data_rows: usize,
651 data_cols: usize,
652 data_row_stride: usize,
653 data_col_stride: usize,
654 t_eval_ptr: *const f64,
655 t_eval_len: usize,
656 out_value: *mut f64,
657 out_sens: *mut *mut HostArray,
658) -> i32 {
659 if ode.is_null()
660 || out_value.is_null()
661 || out_sens.is_null()
662 || data_ptr.is_null()
663 || !valid_f64_ptr(params_ptr, params_len)
664 || !valid_f64_ptr(t_eval_ptr, t_eval_len)
665 {
666 c_invalid_arg!("invalid arguments to diffsol_ode_solve_sum_squares_adj");
667 return DIFFSOL_BAD_ARG;
668 }
669 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
670 let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
671 let data = HostArray::new_col_major(
672 data_ptr as *mut u8,
673 data_rows,
674 data_cols,
675 data_row_stride as isize,
676 data_col_stride as isize,
677 ScalarType::F64,
678 );
679 let ode = unsafe { &mut *ode };
680 match ode.solve_sum_squares_adj(params, data, t_eval) {
681 Ok((value, sens)) => {
682 let sens_boxed = boxed_host_array(sens);
683 unsafe {
684 *out_value = value;
685 *out_sens = sens_boxed;
686 }
687 DIFFSOL_OK
688 }
689 Err(err) => {
690 c_error!(&format!("{}", err));
691 DIFFSOL_ERR
692 }
693 }
694}
695
696#[unsafe(no_mangle)]
701pub unsafe extern "C" fn diffsol_ode_get_matrix_type(ode: *const OdeWrapper) -> i32 {
702 if ode.is_null() {
703 c_invalid_arg!("ode is null");
704 return -1;
705 }
706 let ode = unsafe { &*ode };
707 match ode.get_matrix_type() {
708 Ok(value) => matrix_type_to_i32(value),
709 Err(err) => {
710 c_error!(&format!("{}", err));
711 -1
712 }
713 }
714}
715
716#[unsafe(no_mangle)]
721pub unsafe extern "C" fn diffsol_ode_get_ode_solver(ode: *const OdeWrapper) -> i32 {
722 if ode.is_null() {
723 c_invalid_arg!("ode is null");
724 return -1;
725 }
726 let ode = unsafe { &*ode };
727 match ode.get_ode_solver() {
728 Ok(value) => ode_solver_to_i32(value),
729 Err(err) => {
730 c_error!(&format!("{}", err));
731 -1
732 }
733 }
734}
735
736#[unsafe(no_mangle)]
741pub unsafe extern "C" fn diffsol_ode_set_ode_solver(ode: *mut OdeWrapper, value: i32) -> i32 {
742 if ode.is_null() {
743 c_invalid_arg!("ode is null");
744 return DIFFSOL_BAD_ARG;
745 }
746 let value = match ode_solver_from_i32(value) {
747 Some(v) => v,
748 None => {
749 c_invalid_arg!("invalid ode_solver");
750 return DIFFSOL_BAD_ARG;
751 }
752 };
753 let ode = unsafe { &mut *ode };
754 match ode.set_ode_solver(value) {
755 Ok(()) => DIFFSOL_OK,
756 Err(err) => c_error!(&format!("{}", err)),
757 }
758}
759
760#[unsafe(no_mangle)]
765pub unsafe extern "C" fn diffsol_ode_get_linear_solver(ode: *const OdeWrapper) -> i32 {
766 if ode.is_null() {
767 c_invalid_arg!("ode is null");
768 return -1;
769 }
770 let ode = unsafe { &*ode };
771 match ode.get_linear_solver() {
772 Ok(value) => linear_solver_to_i32(value),
773 Err(err) => {
774 c_error!(&format!("{}", err));
775 -1
776 }
777 }
778}
779
780#[unsafe(no_mangle)]
785pub unsafe extern "C" fn diffsol_ode_set_linear_solver(ode: *mut OdeWrapper, value: i32) -> i32 {
786 if ode.is_null() {
787 c_invalid_arg!("ode is null");
788 return DIFFSOL_BAD_ARG;
789 }
790 let value = match linear_solver_from_i32(value) {
791 Some(v) => v,
792 None => {
793 c_invalid_arg!("invalid linear_solver");
794 return DIFFSOL_BAD_ARG;
795 }
796 };
797 let ode = unsafe { &mut *ode };
798 match ode.set_linear_solver(value) {
799 Ok(()) => DIFFSOL_OK,
800 Err(err) => c_error!(&format!("{}", err)),
801 }
802}
803
804#[unsafe(no_mangle)]
810pub unsafe extern "C" fn diffsol_ode_get_rtol(ode: *const OdeWrapper, out_value: *mut f64) -> i32 {
811 if ode.is_null() || out_value.is_null() {
812 c_invalid_arg!("invalid arguments to diffsol_ode_get_rtol");
813 return DIFFSOL_BAD_ARG;
814 }
815 let ode = unsafe { &*ode };
816 match ode.get_rtol() {
817 Ok(value) => {
818 unsafe {
819 *out_value = value;
820 }
821 DIFFSOL_OK
822 }
823 Err(err) => c_error!(&format!("{}", err)),
824 }
825}
826
827#[unsafe(no_mangle)]
832pub unsafe extern "C" fn diffsol_ode_set_rtol(ode: *mut OdeWrapper, value: f64) -> i32 {
833 if ode.is_null() {
834 c_invalid_arg!("ode is null");
835 return DIFFSOL_BAD_ARG;
836 }
837 let ode = unsafe { &mut *ode };
838 match ode.set_rtol(value) {
839 Ok(()) => DIFFSOL_OK,
840 Err(err) => c_error!(&format!("{}", err)),
841 }
842}
843
844#[unsafe(no_mangle)]
850pub unsafe extern "C" fn diffsol_ode_get_atol(ode: *const OdeWrapper, out_value: *mut f64) -> i32 {
851 if ode.is_null() || out_value.is_null() {
852 c_invalid_arg!("invalid arguments to diffsol_ode_get_atol");
853 return DIFFSOL_BAD_ARG;
854 }
855 let ode = unsafe { &*ode };
856 match ode.get_atol() {
857 Ok(value) => {
858 unsafe {
859 *out_value = value;
860 }
861 DIFFSOL_OK
862 }
863 Err(err) => c_error!(&format!("{}", err)),
864 }
865}
866
867#[unsafe(no_mangle)]
872pub unsafe extern "C" fn diffsol_ode_set_atol(ode: *mut OdeWrapper, value: f64) -> i32 {
873 if ode.is_null() {
874 c_invalid_arg!("ode is null");
875 return DIFFSOL_BAD_ARG;
876 }
877 let ode = unsafe { &mut *ode };
878 match ode.set_atol(value) {
879 Ok(()) => DIFFSOL_OK,
880 Err(err) => c_error!(&format!("{}", err)),
881 }
882}
883
884#[cfg(all(test, feature = "diffsl-external-f64"))]
885mod tests {
886 use std::ptr;
887
888 use crate::initial_condition_options::InitialConditionSolverOptions;
889 use crate::linear_solver_type::LinearSolverType;
890 use crate::linear_solver_type_c::{
891 diffsol_linear_solver_type_count, diffsol_linear_solver_type_is_valid,
892 diffsol_linear_solver_type_name, linear_solver_to_i32,
893 };
894 use crate::matrix_type::MatrixType;
895 use crate::ode_options::OdeSolverOptions;
896 use crate::ode_options_c::{
897 diffsol_ode_options_free, diffsol_ode_options_get_max_nonlinear_solver_iterations,
898 diffsol_ode_options_get_min_timestep,
899 diffsol_ode_options_set_max_nonlinear_solver_iterations,
900 diffsol_ode_options_set_min_timestep,
901 };
902 use crate::ode_solver_type::OdeSolverType;
903 use crate::ode_solver_type_c::{
904 diffsol_ode_solver_type_count, diffsol_ode_solver_type_is_valid,
905 diffsol_ode_solver_type_name, ode_solver_to_i32,
906 };
907 use crate::scalar_type::ScalarType;
908 use crate::scalar_type_c::{
909 diffsol_scalar_type_count, diffsol_scalar_type_is_valid, diffsol_scalar_type_name,
910 scalar_type_to_i32,
911 };
912 use crate::solution_wrapper_c::{
913 diffsol_solution_wrapper_get_sens, diffsol_solution_wrapper_get_ts,
914 diffsol_solution_wrapper_get_ys,
915 };
916 use crate::test_support::{
917 ASSERT_TOL, LOGISTIC_X0, assert_close, assert_last_error_contains, c_string,
918 clear_last_error, ffi_free_solution, ffi_read_host_array_list_matrices,
919 ffi_read_host_array_matrix, ffi_read_host_array_vector, find_time_window,
920 logistic_integral, logistic_state, logistic_state_dr, mass_state_deps, rhs_input_deps,
921 rhs_state_deps,
922 };
923 use crate::{
924 initial_condition_options_c::{
925 diffsol_ic_options_free, diffsol_ic_options_get_max_linesearch_iterations,
926 diffsol_ic_options_get_use_linesearch,
927 diffsol_ic_options_set_max_linesearch_iterations,
928 diffsol_ic_options_set_use_linesearch,
929 },
930 matrix_type_c::{
931 diffsol_matrix_type_count, diffsol_matrix_type_is_valid, diffsol_matrix_type_name,
932 matrix_type_to_i32,
933 },
934 };
935
936 use super::*;
937
938 unsafe fn make_ode_ptr(
939 matrix_type: i32,
940 linear_solver: i32,
941 ode_solver: i32,
942 ) -> *mut OdeWrapper {
943 let rhs_state_deps = rhs_state_deps();
944 let rhs_input_deps = rhs_input_deps();
945 let mass_state_deps = mass_state_deps();
946 unsafe {
947 diffsol_ode_new_external(
948 matrix_type,
949 linear_solver,
950 ode_solver,
951 rhs_state_deps.as_ptr() as *const usize,
952 rhs_state_deps.len(),
953 rhs_input_deps.as_ptr() as *const usize,
954 rhs_input_deps.len(),
955 mass_state_deps.as_ptr() as *const usize,
956 mass_state_deps.len(),
957 )
958 }
959 }
960
961 #[test]
962 fn c_api_reports_enum_metadata() {
963 clear_last_error();
964 unsafe {
965 assert_eq!(diffsol_matrix_type_count(), 3);
966 assert_eq!(diffsol_ode_solver_type_count(), 4);
967 assert_eq!(diffsol_linear_solver_type_count(), 3);
968 assert_eq!(diffsol_scalar_type_count(), 2);
969
970 assert_eq!(
971 c_string(diffsol_matrix_type_name(matrix_type_to_i32(
972 MatrixType::NalgebraDense
973 ))),
974 "nalgebra_dense"
975 );
976 assert_eq!(
977 c_string(diffsol_ode_solver_type_name(ode_solver_to_i32(
978 OdeSolverType::Bdf
979 ))),
980 "bdf"
981 );
982 assert_eq!(
983 c_string(diffsol_linear_solver_type_name(linear_solver_to_i32(
984 LinearSolverType::Default
985 ))),
986 "default"
987 );
988 assert_eq!(
989 c_string(diffsol_scalar_type_name(scalar_type_to_i32(
990 ScalarType::F64
991 ))),
992 "f64"
993 );
994 }
995 }
996
997 #[test]
998 fn c_api_invalid_enums_set_last_error() {
999 clear_last_error();
1000 unsafe {
1001 assert_eq!(diffsol_matrix_type_is_valid(99), 0);
1002 assert_last_error_contains("invalid matrix_type");
1003 clear_last_error();
1004
1005 assert_eq!(diffsol_ode_solver_type_is_valid(99), 0);
1006 assert_last_error_contains("invalid ode_solver_type");
1007 clear_last_error();
1008
1009 assert_eq!(diffsol_linear_solver_type_is_valid(99), 0);
1010 assert_last_error_contains("invalid linear_solver_type");
1011 clear_last_error();
1012
1013 assert_eq!(diffsol_scalar_type_is_valid(99), 0);
1014 assert_last_error_contains("invalid scalar_type");
1015 }
1016 }
1017
1018 #[test]
1019 fn c_api_rejects_invalid_ode_arguments() {
1020 clear_last_error();
1021 unsafe {
1022 let mut out_array = ptr::null_mut();
1023 let status = diffsol_ode_y0(ptr::null_mut(), ptr::null(), 0, &mut out_array);
1024 assert_eq!(status, DIFFSOL_BAD_ARG);
1025 assert!(out_array.is_null());
1026 assert_last_error_contains("invalid arguments to diffsol_ode_y0");
1027 clear_last_error();
1028
1029 let ode = make_ode_ptr(
1030 99,
1031 linear_solver_to_i32(LinearSolverType::Default),
1032 ode_solver_to_i32(OdeSolverType::Bdf),
1033 );
1034 assert!(ode.is_null());
1035 assert_last_error_contains("invalid matrix_type");
1036 }
1037 }
1038
1039 #[test]
1040 fn c_api_full_lifecycle_matches_external_logistic_model() {
1041 clear_last_error();
1042 unsafe {
1043 let ode = make_ode_ptr(
1044 matrix_type_to_i32(MatrixType::NalgebraDense),
1045 linear_solver_to_i32(LinearSolverType::Default),
1046 ode_solver_to_i32(OdeSolverType::Bdf),
1047 );
1048 assert!(!ode.is_null());
1049
1050 assert_eq!(
1051 diffsol_ode_get_matrix_type(ode),
1052 matrix_type_to_i32(MatrixType::NalgebraDense)
1053 );
1054 assert_eq!(
1055 diffsol_ode_get_ode_solver(ode),
1056 ode_solver_to_i32(OdeSolverType::Bdf)
1057 );
1058 assert_eq!(
1059 diffsol_ode_get_linear_solver(ode),
1060 linear_solver_to_i32(LinearSolverType::Default)
1061 );
1062
1063 assert_eq!(
1064 diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
1065 DIFFSOL_OK
1066 );
1067 assert_eq!(
1068 diffsol_ode_get_ode_solver(ode),
1069 ode_solver_to_i32(OdeSolverType::Tsit45)
1070 );
1071 assert_eq!(
1072 diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Bdf)),
1073 DIFFSOL_OK
1074 );
1075
1076 assert_eq!(diffsol_ode_set_rtol(ode, 1e-8), DIFFSOL_OK);
1077 assert_eq!(diffsol_ode_set_atol(ode, 1e-8), DIFFSOL_OK);
1078 let mut rtol = 0.0;
1079 let mut atol = 0.0;
1080 assert_eq!(diffsol_ode_get_rtol(ode, &mut rtol), DIFFSOL_OK);
1081 assert_eq!(diffsol_ode_get_atol(ode, &mut atol), DIFFSOL_OK);
1082 assert_close(rtol, 1e-8, ASSERT_TOL, "rtol roundtrip");
1083 assert_close(atol, 1e-8, ASSERT_TOL, "atol roundtrip");
1084
1085 let mut ic_options: *mut InitialConditionSolverOptions = ptr::null_mut();
1086 assert_eq!(diffsol_ode_get_ic_options(ode, &mut ic_options), DIFFSOL_OK);
1087 assert!(!ic_options.is_null());
1088 let mut use_linesearch = 0;
1089 let mut max_linesearch_iterations = 0usize;
1090 assert_eq!(
1091 diffsol_ic_options_get_use_linesearch(ic_options, &mut use_linesearch),
1092 DIFFSOL_OK
1093 );
1094 assert_eq!(
1095 diffsol_ic_options_set_use_linesearch(ic_options, 1),
1096 DIFFSOL_OK
1097 );
1098 assert_eq!(
1099 diffsol_ic_options_get_use_linesearch(ic_options, &mut use_linesearch),
1100 DIFFSOL_OK
1101 );
1102 assert_eq!(use_linesearch, 1);
1103 assert_eq!(
1104 diffsol_ic_options_set_max_linesearch_iterations(ic_options, 23),
1105 DIFFSOL_OK
1106 );
1107 assert_eq!(
1108 diffsol_ic_options_get_max_linesearch_iterations(
1109 ic_options,
1110 &mut max_linesearch_iterations
1111 ),
1112 DIFFSOL_OK
1113 );
1114 assert_eq!(max_linesearch_iterations, 23);
1115 diffsol_ic_options_free(ic_options);
1116
1117 let mut ode_options: *mut OdeSolverOptions = ptr::null_mut();
1118 assert_eq!(diffsol_ode_get_options(ode, &mut ode_options), DIFFSOL_OK);
1119 assert!(!ode_options.is_null());
1120 let mut max_nonlinear_iterations = 0usize;
1121 let mut min_timestep = 0.0;
1122 assert_eq!(
1123 diffsol_ode_options_set_max_nonlinear_solver_iterations(ode_options, 17),
1124 DIFFSOL_OK
1125 );
1126 assert_eq!(
1127 diffsol_ode_options_get_max_nonlinear_solver_iterations(
1128 ode_options,
1129 &mut max_nonlinear_iterations
1130 ),
1131 DIFFSOL_OK
1132 );
1133 assert_eq!(max_nonlinear_iterations, 17);
1134 assert_eq!(
1135 diffsol_ode_options_set_min_timestep(ode_options, 1e-4),
1136 DIFFSOL_OK
1137 );
1138 assert_eq!(
1139 diffsol_ode_options_get_min_timestep(ode_options, &mut min_timestep),
1140 DIFFSOL_OK
1141 );
1142 assert_close(min_timestep, 1e-4, ASSERT_TOL, "min_timestep roundtrip");
1143 diffsol_ode_options_free(ode_options);
1144
1145 let params = [2.0f64];
1146 let y = [0.25f64];
1147 let v = [3.0f64];
1148
1149 let mut y0_ptr = ptr::null_mut();
1150 assert_eq!(
1151 diffsol_ode_y0(ode, params.as_ptr(), params.len(), &mut y0_ptr),
1152 DIFFSOL_OK
1153 );
1154 assert_eq!(ffi_read_host_array_vector(y0_ptr), vec![LOGISTIC_X0]);
1155
1156 let mut rhs_ptr = ptr::null_mut();
1157 assert_eq!(
1158 diffsol_ode_rhs(
1159 ode,
1160 params.as_ptr(),
1161 params.len(),
1162 0.0,
1163 y.as_ptr(),
1164 y.len(),
1165 &mut rhs_ptr,
1166 ),
1167 DIFFSOL_OK
1168 );
1169 assert_close(
1170 ffi_read_host_array_vector(rhs_ptr)[0],
1171 0.375,
1172 ASSERT_TOL,
1173 "ffi rhs",
1174 );
1175
1176 let mut rhs_jac_mul_ptr = ptr::null_mut();
1177 assert_eq!(
1178 diffsol_ode_rhs_jac_mul(
1179 ode,
1180 params.as_ptr(),
1181 params.len(),
1182 0.0,
1183 y.as_ptr(),
1184 y.len(),
1185 v.as_ptr(),
1186 v.len(),
1187 &mut rhs_jac_mul_ptr,
1188 ),
1189 DIFFSOL_OK
1190 );
1191 assert_close(
1192 ffi_read_host_array_vector(rhs_jac_mul_ptr)[0],
1193 3.0,
1194 ASSERT_TOL,
1195 "ffi rhs_jac_mul",
1196 );
1197
1198 let mut solve_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1199 assert_eq!(
1200 diffsol_ode_solve(
1201 ode,
1202 params.as_ptr(),
1203 params.len(),
1204 1e-9,
1205 &mut solve_solution_ptr
1206 ),
1207 DIFFSOL_OK
1208 );
1209 assert!(!solve_solution_ptr.is_null());
1210
1211 let mut solve_ys_ptr = ptr::null_mut();
1212 let mut solve_ts_ptr = ptr::null_mut();
1213 assert_eq!(
1214 diffsol_solution_wrapper_get_ys(solve_solution_ptr, &mut solve_ys_ptr),
1215 DIFFSOL_OK
1216 );
1217 assert_eq!(
1218 diffsol_solution_wrapper_get_ts(solve_solution_ptr, &mut solve_ts_ptr),
1219 DIFFSOL_OK
1220 );
1221 let (solve_rows, solve_cols, solve_ys) = ffi_read_host_array_matrix(solve_ys_ptr);
1222 let solve_ts = ffi_read_host_array_vector(solve_ts_ptr);
1223 assert_eq!(solve_rows, 1);
1224 assert_eq!(solve_cols, solve_ts.len());
1225 assert!(!solve_ts.is_empty());
1226 assert_close(
1227 *solve_ts.last().unwrap(),
1228 1e-9,
1229 ASSERT_TOL,
1230 "ffi solve final time",
1231 );
1232 assert_close(
1233 *solve_ys.last().unwrap(),
1234 logistic_state(LOGISTIC_X0, 2.0, 1e-9),
1235 ASSERT_TOL,
1236 "ffi solve final value",
1237 );
1238 ffi_free_solution(solve_solution_ptr);
1239
1240 let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1241 assert_eq!(
1242 diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
1243 DIFFSOL_OK
1244 );
1245
1246 let t_eval = [0.25f64, 0.5f64, 1.0f64];
1247 assert_eq!(
1248 diffsol_ode_solve_dense(
1249 ode,
1250 params.as_ptr(),
1251 params.len(),
1252 t_eval.as_ptr(),
1253 t_eval.len(),
1254 &mut solution_ptr,
1255 ),
1256 DIFFSOL_OK
1257 );
1258 let mut ys_ptr = ptr::null_mut();
1259 let mut ts_ptr = ptr::null_mut();
1260 assert_eq!(
1261 diffsol_solution_wrapper_get_ys(solution_ptr, &mut ys_ptr),
1262 DIFFSOL_OK
1263 );
1264 assert_eq!(
1265 diffsol_solution_wrapper_get_ts(solution_ptr, &mut ts_ptr),
1266 DIFFSOL_OK
1267 );
1268 let (rows, cols, ys) = ffi_read_host_array_matrix(ys_ptr);
1269 let ts = ffi_read_host_array_vector(ts_ptr);
1270 assert_eq!(rows, 1);
1271 assert_eq!(cols, ts.len());
1272 let start = find_time_window(&ts, &t_eval, ASSERT_TOL);
1273 for (i, &t) in t_eval.iter().enumerate() {
1274 assert_close(ts[start + i], t, ASSERT_TOL, "ffi solution time");
1275 assert_close(
1276 ys[start + i],
1277 logistic_state(0.1, 2.0, t),
1278 5e-4,
1279 "ffi solution value",
1280 );
1281 }
1282 assert_eq!(
1283 diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Bdf)),
1284 DIFFSOL_OK
1285 );
1286
1287 let hybrid_t_eval = [0.5f64, 1.0, 1.25, 1.5, 2.0];
1288 let hybrid_ode = make_ode_ptr(
1289 matrix_type_to_i32(MatrixType::NalgebraDense),
1290 linear_solver_to_i32(LinearSolverType::Default),
1291 ode_solver_to_i32(OdeSolverType::Bdf),
1292 );
1293 assert!(!hybrid_ode.is_null());
1294 let mut hybrid_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1295 assert_eq!(
1296 diffsol_ode_solve_hybrid_dense(
1297 hybrid_ode,
1298 params.as_ptr(),
1299 params.len(),
1300 hybrid_t_eval.as_ptr(),
1301 hybrid_t_eval.len(),
1302 &mut hybrid_solution_ptr,
1303 ),
1304 DIFFSOL_OK
1305 );
1306 let mut hybrid_ys_ptr = ptr::null_mut();
1307 let mut hybrid_ts_ptr = ptr::null_mut();
1308 assert_eq!(
1309 diffsol_solution_wrapper_get_ys(hybrid_solution_ptr, &mut hybrid_ys_ptr),
1310 DIFFSOL_OK
1311 );
1312 assert_eq!(
1313 diffsol_solution_wrapper_get_ts(hybrid_solution_ptr, &mut hybrid_ts_ptr),
1314 DIFFSOL_OK
1315 );
1316 let (hybrid_rows, hybrid_cols, hybrid_ys) = ffi_read_host_array_matrix(hybrid_ys_ptr);
1317 let hybrid_ts = ffi_read_host_array_vector(hybrid_ts_ptr);
1318 assert_eq!(hybrid_rows, 1);
1319 assert_eq!(hybrid_cols, hybrid_t_eval.len());
1320 assert_eq!(hybrid_ts, hybrid_t_eval);
1321 assert_close(
1322 hybrid_ys[0],
1323 logistic_state(LOGISTIC_X0, 2.0, hybrid_t_eval[0]),
1324 5e-4,
1325 "ffi hybrid dense pre-root value",
1326 );
1327 assert_close(
1328 hybrid_ys[1],
1329 logistic_state(LOGISTIC_X0, 2.0, hybrid_t_eval[1]),
1330 5e-4,
1331 "ffi hybrid dense near-root value",
1332 );
1333 for (i, value) in hybrid_ys.iter().enumerate().skip(2) {
1334 assert_close(
1335 *value,
1336 1.0,
1337 5e-4,
1338 &format!("ffi hybrid dense post-root value[{i}]"),
1339 );
1340 }
1341 ffi_free_solution(hybrid_solution_ptr);
1342 diffsol_ode_free(hybrid_ode);
1343
1344 let analysis_ode = make_ode_ptr(
1345 matrix_type_to_i32(MatrixType::NalgebraDense),
1346 linear_solver_to_i32(LinearSolverType::Default),
1347 ode_solver_to_i32(OdeSolverType::Bdf),
1348 );
1349 assert!(!analysis_ode.is_null());
1350
1351 let mut sens_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1352 assert_eq!(
1353 diffsol_ode_solve_fwd_sens(
1354 analysis_ode,
1355 params.as_ptr(),
1356 params.len(),
1357 t_eval.as_ptr(),
1358 t_eval.len(),
1359 &mut sens_solution_ptr,
1360 ),
1361 DIFFSOL_OK
1362 );
1363 let mut sens_list = ptr::null_mut();
1364 let mut sens_len = 0usize;
1365 assert_eq!(
1366 diffsol_solution_wrapper_get_sens(sens_solution_ptr, &mut sens_list, &mut sens_len),
1367 DIFFSOL_OK
1368 );
1369 let sens_values = ffi_read_host_array_list_matrices(sens_list, sens_len);
1370 assert_eq!(sens_values.len(), 1);
1371 assert_eq!(sens_values[0].0, 1);
1372 assert_eq!(sens_values[0].1, t_eval.len());
1373 for (i, (&value, &t)) in sens_values[0].2.iter().zip(t_eval.iter()).enumerate() {
1374 assert_close(
1375 value,
1376 logistic_state_dr(LOGISTIC_X0, 2.0, t),
1377 ASSERT_TOL,
1378 &format!("ffi sensitivity[{i}]"),
1379 );
1380 }
1381
1382 let adjoint_t_eval = [0.0f64, 0.25f64, 0.5f64, 1.0f64];
1383 let adjoint_data: Vec<f64> = adjoint_t_eval
1384 .iter()
1385 .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
1386 .collect();
1387 let mut objective = 0.0;
1388 let mut adjoint_grad_ptr = ptr::null_mut();
1389 assert_eq!(
1390 diffsol_ode_solve_sum_squares_adj(
1391 analysis_ode,
1392 params.as_ptr(),
1393 params.len(),
1394 adjoint_data.as_ptr(),
1395 1,
1396 adjoint_t_eval.len(),
1397 1,
1398 1,
1399 adjoint_t_eval.as_ptr(),
1400 adjoint_t_eval.len(),
1401 &mut objective,
1402 &mut adjoint_grad_ptr,
1403 ),
1404 DIFFSOL_OK
1405 );
1406 assert_close(objective, 0.0, ASSERT_TOL, "ffi adjoint objective");
1407 let grad = ffi_read_host_array_vector(adjoint_grad_ptr);
1408 assert_eq!(grad.len(), 1);
1409 assert_close(grad[0], 0.0, ASSERT_TOL, "ffi adjoint gradient");
1410
1411 ffi_free_solution(sens_solution_ptr);
1412 diffsol_ode_free(analysis_ode);
1413 ffi_free_solution(solution_ptr);
1414 diffsol_ode_free(ode);
1415 }
1416 }
1417}
1418
/// C-API tests for ODEs compiled from DiffSL source via `diffsol_ode_new_jit`.
///
/// Each test loops over `available_jit_backends()`. Forward-sensitivity and
/// adjoint checks are compiled only with the `diffsl-llvm` feature and are run
/// only on the LLVM backend.
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod jit_tests {
    use std::ffi::{CStr, CString};
    use std::ptr;

    use crate::error_c::{diffsol_error_code, diffsol_last_error_message};
    use crate::initial_condition_options_c::diffsol_ic_options_free;
    use crate::jit::JitBackendType;
    use crate::jit_c::jit_backend_to_i32;
    use crate::linear_solver_type::LinearSolverType;
    use crate::linear_solver_type_c::linear_solver_to_i32;
    use crate::matrix_type::MatrixType;
    use crate::matrix_type_c::matrix_type_to_i32;
    use crate::ode_options_c::diffsol_ode_options_free;
    use crate::ode_solver_type::OdeSolverType;
    use crate::ode_solver_type_c::ode_solver_to_i32;
    #[cfg(feature = "diffsl-llvm")]
    use crate::solution_wrapper_c::diffsol_solution_wrapper_get_sens;
    use crate::solution_wrapper_c::{
        diffsol_solution_wrapper_get_ts, diffsol_solution_wrapper_get_ys,
    };
    #[cfg(feature = "diffsl-llvm")]
    use crate::test_support::ffi_read_host_array_list_matrices;
    use crate::test_support::{
        ASSERT_TOL, LOGISTIC_X0, assert_close, available_jit_backends, clear_last_error,
        ffi_free_solution, ffi_read_host_array_matrix, ffi_read_host_array_vector,
        find_time_window, hybrid_logistic_diffsl_code, hybrid_logistic_state,
        logistic_diffsl_code_cstring, logistic_state,
    };
    #[cfg(feature = "diffsl-llvm")]
    use crate::test_support::{hybrid_logistic_state_dr, logistic_integral, logistic_state_dr};

    use super::*;

    /// Builds a JIT ODE wrapper from the standard logistic DiffSL model.
    ///
    /// # Safety
    ///
    /// Calls into the C API. The returned pointer may be null on failure; a
    /// non-null result is owned by the caller and must be released with
    /// `diffsol_ode_free`.
    unsafe fn make_ode_ptr(
        jit_backend: JitBackendType,
        matrix_type: i32,
        linear_solver: i32,
        ode_solver: i32,
    ) -> *mut OdeWrapper {
        // `code` must stay alive until the call returns: only its raw pointer
        // is passed down.
        let code = logistic_diffsl_code_cstring();
        unsafe {
            make_ode_ptr_with_code(
                jit_backend,
                code.as_ptr(),
                matrix_type,
                linear_solver,
                ode_solver,
            )
        }
    }

    /// Builds a JIT ODE wrapper from arbitrary DiffSL source.
    ///
    /// # Safety
    ///
    /// `code` must be null or point to a NUL-terminated string that stays valid
    /// for the duration of the call. A non-null result must be released with
    /// `diffsol_ode_free`.
    unsafe fn make_ode_ptr_with_code(
        jit_backend: JitBackendType,
        code: *const std::os::raw::c_char,
        matrix_type: i32,
        linear_solver: i32,
        ode_solver: i32,
    ) -> *mut OdeWrapper {
        unsafe {
            diffsol_ode_new_jit(
                code,
                jit_backend_to_i32(jit_backend),
                matrix_type,
                linear_solver,
                ode_solver,
            )
        }
    }

    /// Returns an owned copy of the C API's last error message.
    ///
    /// Asserts that `diffsol_error_code()` is 1 and that the message pointer is
    /// non-null, so it should only be called right after a failing C API call.
    ///
    /// # Safety
    ///
    /// Reads the C API's error state; the message is copied into a `String`
    /// before returning, so no borrowed pointer escapes.
    unsafe fn last_error_message() -> String {
        let ptr = unsafe { diffsol_last_error_message() };
        assert_eq!(unsafe { diffsol_error_code() }, 1);
        assert!(!ptr.is_null());
        unsafe { CStr::from_ptr(ptr) }.to_str().unwrap().to_owned()
    }

    /// Exercises the JIT ODE lifecycle through the C API and compares every
    /// result against the analytic logistic solution:
    /// - construction and configuration getters;
    /// - `y0`, `rhs` and `rhs_jac_mul` evaluation;
    /// - a dense Tsit45 solve;
    /// - on LLVM only, forward sensitivities and an adjoint gradient.
    #[test]
    fn c_api_full_lifecycle_matches_jit_logistic_model() {
        clear_last_error();
        for jit_backend in available_jit_backends() {
            unsafe {
                let ode = make_ode_ptr(
                    jit_backend,
                    matrix_type_to_i32(MatrixType::NalgebraDense),
                    linear_solver_to_i32(LinearSolverType::Default),
                    ode_solver_to_i32(OdeSolverType::Bdf),
                );
                assert!(!ode.is_null());

                // The getters must echo the configuration used at construction.
                assert_eq!(
                    diffsol_ode_get_matrix_type(ode),
                    matrix_type_to_i32(MatrixType::NalgebraDense)
                );
                assert_eq!(
                    diffsol_ode_get_ode_solver(ode),
                    ode_solver_to_i32(OdeSolverType::Bdf)
                );
                assert_eq!(
                    diffsol_ode_get_linear_solver(ode),
                    linear_solver_to_i32(LinearSolverType::Default)
                );

                let params = [2.0f64];
                let y = [0.25f64];
                let v = [3.0f64];

                let mut y0_ptr = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_y0(ode, params.as_ptr(), params.len(), &mut y0_ptr),
                    DIFFSOL_OK
                );
                assert_eq!(ffi_read_host_array_vector(y0_ptr), vec![LOGISTIC_X0]);

                // Expected values are consistent with a logistic rhs r*y*(1-y),
                // r = params[0]: 2*0.25*0.75 = 0.375 and r*(1-2y)*v = 3.0.
                let mut rhs_ptr = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_rhs(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        0.0,
                        y.as_ptr(),
                        y.len(),
                        &mut rhs_ptr,
                    ),
                    DIFFSOL_OK
                );
                assert_close(
                    ffi_read_host_array_vector(rhs_ptr)[0],
                    0.375,
                    ASSERT_TOL,
                    "jit ffi rhs",
                );

                let mut rhs_jac_mul_ptr = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_rhs_jac_mul(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        0.0,
                        y.as_ptr(),
                        y.len(),
                        v.as_ptr(),
                        v.len(),
                        &mut rhs_jac_mul_ptr,
                    ),
                    DIFFSOL_OK
                );
                assert_close(
                    ffi_read_host_array_vector(rhs_jac_mul_ptr)[0],
                    3.0,
                    ASSERT_TOL,
                    "jit ffi rhs_jac_mul",
                );

                // Dense solve with Tsit45; the solver is switched back to BDF
                // afterwards.
                let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                let t_eval = [0.25f64, 0.5f64, 1.0f64];
                assert_eq!(
                    diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
                    DIFFSOL_OK
                );
                assert_eq!(
                    diffsol_ode_solve_dense(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        &mut solution_ptr,
                    ),
                    DIFFSOL_OK
                );
                let mut ys_ptr = ptr::null_mut();
                let mut ts_ptr = ptr::null_mut();
                assert_eq!(
                    diffsol_solution_wrapper_get_ys(solution_ptr, &mut ys_ptr),
                    DIFFSOL_OK
                );
                assert_eq!(
                    diffsol_solution_wrapper_get_ts(solution_ptr, &mut ts_ptr),
                    DIFFSOL_OK
                );
                let (rows, cols, ys) = ffi_read_host_array_matrix(ys_ptr);
                let ts = ffi_read_host_array_vector(ts_ptr);
                assert_eq!(rows, 1);
                assert_eq!(cols, ts.len());
                // Locate where `t_eval` starts within the returned time grid
                // before comparing point-by-point.
                let start = find_time_window(&ts, &t_eval, ASSERT_TOL);
                for (i, &t) in t_eval.iter().enumerate() {
                    assert_close(ts[start + i], t, ASSERT_TOL, "jit ffi solution time");
                    assert_close(
                        ys[start + i],
                        logistic_state(LOGISTIC_X0, 2.0, t),
                        5e-4,
                        "jit ffi solution value",
                    );
                }
                assert_eq!(
                    diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Bdf)),
                    DIFFSOL_OK
                );

                // Sensitivity and adjoint solves are LLVM-only. Run them on the
                // LLVM iteration only, instead of repeating identical LLVM work
                // for every available backend.
                #[cfg(feature = "diffsl-llvm")]
                if matches!(jit_backend, JitBackendType::Llvm) {
                    let analysis_ode = make_ode_ptr(
                        jit_backend,
                        matrix_type_to_i32(MatrixType::NalgebraDense),
                        linear_solver_to_i32(LinearSolverType::Default),
                        ode_solver_to_i32(OdeSolverType::Bdf),
                    );
                    assert!(!analysis_ode.is_null());

                    let mut sens_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                    assert_eq!(
                        diffsol_ode_solve_fwd_sens(
                            analysis_ode,
                            params.as_ptr(),
                            params.len(),
                            t_eval.as_ptr(),
                            t_eval.len(),
                            &mut sens_solution_ptr,
                        ),
                        DIFFSOL_OK
                    );
                    let mut sens_list = ptr::null_mut();
                    let mut sens_len = 0usize;
                    assert_eq!(
                        diffsol_solution_wrapper_get_sens(
                            sens_solution_ptr,
                            &mut sens_list,
                            &mut sens_len
                        ),
                        DIFFSOL_OK
                    );
                    let sens_values = ffi_read_host_array_list_matrices(sens_list, sens_len);
                    assert_eq!(sens_values.len(), 1);
                    assert_eq!(sens_values[0].0, 1);
                    assert_eq!(sens_values[0].1, t_eval.len());
                    for (i, (&value, &t)) in sens_values[0].2.iter().zip(t_eval.iter()).enumerate()
                    {
                        assert_close(
                            value,
                            logistic_state_dr(LOGISTIC_X0, 2.0, t),
                            ASSERT_TOL,
                            &format!("jit ffi sensitivity[{i}]"),
                        );
                    }

                    // Data taken from `logistic_integral`; the test expects a zero
                    // objective for this data.
                    let adjoint_t_eval = [0.0f64, 0.25f64, 0.5f64, 1.0f64];
                    let adjoint_data: Vec<f64> = adjoint_t_eval
                        .iter()
                        .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
                        .collect();
                    let mut objective = 0.0;
                    let mut adjoint_grad_ptr = ptr::null_mut();
                    assert_eq!(
                        diffsol_ode_solve_sum_squares_adj(
                            analysis_ode,
                            params.as_ptr(),
                            params.len(),
                            adjoint_data.as_ptr(),
                            1,
                            adjoint_t_eval.len(),
                            1,
                            1,
                            adjoint_t_eval.as_ptr(),
                            adjoint_t_eval.len(),
                            &mut objective,
                            &mut adjoint_grad_ptr,
                        ),
                        DIFFSOL_OK
                    );
                    assert_close(objective, 0.0, ASSERT_TOL, "jit ffi adjoint objective");
                    let grad = ffi_read_host_array_vector(adjoint_grad_ptr);
                    assert_eq!(grad.len(), 1);
                    assert!(
                        grad[0].is_finite(),
                        "jit ffi adjoint gradient should be finite"
                    );

                    ffi_free_solution(sens_solution_ptr);
                    diffsol_ode_free(analysis_ode);
                }
                ffi_free_solution(solution_ptr);
                diffsol_ode_free(ode);
            }
        }
    }

    /// Checks that `diffsol_ode_new_jit` and the ODE accessors reject each kind
    /// of bad input with a non-OK result and a descriptive error message:
    /// - null pointers;
    /// - invalid UTF-8;
    /// - out-of-range enum codes;
    /// - unparsable DiffSL.
    #[test]
    fn c_api_rejects_invalid_jit_arguments() {
        unsafe {
            // Known-good arguments; each case below replaces exactly one of them.
            let valid_backend = jit_backend_to_i32(available_jit_backends()[0]);
            let valid_matrix_type = matrix_type_to_i32(MatrixType::NalgebraDense);
            let valid_linear_solver = linear_solver_to_i32(LinearSolverType::Default);
            let valid_ode_solver = ode_solver_to_i32(OdeSolverType::Bdf);

            clear_last_error();
            assert!(
                diffsol_ode_new_jit(
                    ptr::null(),
                    valid_backend,
                    valid_matrix_type,
                    valid_linear_solver,
                    valid_ode_solver,
                )
                .is_null()
            );
            assert!(last_error_message().contains("code is null"));

            clear_last_error();
            // A lone 0xff byte is never valid UTF-8.
            let invalid_utf8 = CString::from_vec_with_nul(vec![0xff, 0]).unwrap();
            assert!(
                diffsol_ode_new_jit(
                    invalid_utf8.as_ptr(),
                    valid_backend,
                    valid_matrix_type,
                    valid_linear_solver,
                    valid_ode_solver,
                )
                .is_null()
            );
            assert!(last_error_message().contains("valid UTF-8"));

            clear_last_error();
            let code = logistic_diffsl_code_cstring();
            assert!(
                diffsol_ode_new_jit(
                    code.as_ptr(),
                    99,
                    valid_matrix_type,
                    valid_linear_solver,
                    valid_ode_solver,
                )
                .is_null()
            );
            assert!(last_error_message().contains("invalid jit_backend_type"));

            clear_last_error();
            assert!(
                diffsol_ode_new_jit(
                    code.as_ptr(),
                    valid_backend,
                    99,
                    valid_linear_solver,
                    valid_ode_solver,
                )
                .is_null()
            );
            assert!(last_error_message().contains("invalid matrix_type"));

            clear_last_error();
            assert!(
                diffsol_ode_new_jit(
                    code.as_ptr(),
                    valid_backend,
                    valid_matrix_type,
                    99,
                    valid_ode_solver,
                )
                .is_null()
            );
            assert!(last_error_message().contains("invalid linear_solver"));

            clear_last_error();
            assert!(
                diffsol_ode_new_jit(
                    code.as_ptr(),
                    valid_backend,
                    valid_matrix_type,
                    valid_linear_solver,
                    99,
                )
                .is_null()
            );
            assert!(last_error_message().contains("invalid ode_solver"));

            // Syntactically invalid DiffSL must fail compilation and record an error.
            clear_last_error();
            let invalid_code = CString::new("not valid diffsl").unwrap();
            assert!(
                diffsol_ode_new_jit(
                    invalid_code.as_ptr(),
                    valid_backend,
                    valid_matrix_type,
                    valid_linear_solver,
                    valid_ode_solver,
                )
                .is_null()
            );
            assert_ne!(diffsol_error_code(), 0);

            // Accessors must reject a null ODE handle.
            let mut ic_options = ptr::null_mut();
            assert_eq!(
                diffsol_ode_get_ic_options(ptr::null_mut(), &mut ic_options),
                DIFFSOL_BAD_ARG
            );
            let mut ode_options = ptr::null_mut();
            assert_eq!(
                diffsol_ode_get_options(ptr::null_mut(), &mut ode_options),
                DIFFSOL_BAD_ARG
            );

            let mut out_array = ptr::null_mut();
            assert_eq!(
                diffsol_ode_y0(ptr::null_mut(), ptr::null(), 0, &mut out_array),
                DIFFSOL_BAD_ARG
            );
            assert_eq!(
                diffsol_ode_rhs(
                    ptr::null_mut(),
                    ptr::null(),
                    0,
                    0.0,
                    ptr::null(),
                    0,
                    &mut out_array,
                ),
                DIFFSOL_BAD_ARG
            );
            assert_eq!(
                diffsol_ode_rhs_jac_mul(
                    ptr::null_mut(),
                    ptr::null(),
                    0,
                    0.0,
                    ptr::null(),
                    0,
                    ptr::null(),
                    0,
                    &mut out_array,
                ),
                DIFFSOL_BAD_ARG
            );

            // Freeing null is reported as an error rather than silently ignored.
            clear_last_error();
            diffsol_ode_free(ptr::null_mut());
            assert!(last_error_message().contains("ode is null"));

            clear_last_error();
            diffsol_host_array_list_free(ptr::null_mut(), 0);
            assert!(last_error_message().contains("host array list is null"));
        }
    }

    /// Covers the runtime branches of the C wrappers for each JIT backend:
    /// - option getters/setters;
    /// - successful solves;
    /// - `DIFFSOL_ERR` on a parameter-count mismatch (empty `params`);
    /// - `DIFFSOL_BAD_ARG` on null handles, null outputs and invalid enum codes.
    #[test]
    fn c_api_jit_wrapper_branches_cover_runtime_success_and_errors() {
        for jit_backend in available_jit_backends() {
            unsafe {
                let ode = make_ode_ptr(
                    jit_backend,
                    matrix_type_to_i32(MatrixType::NalgebraDense),
                    linear_solver_to_i32(LinearSolverType::Default),
                    ode_solver_to_i32(OdeSolverType::Bdf),
                );
                assert!(!ode.is_null());

                let mut ic_options = ptr::null_mut();
                let mut ode_options = ptr::null_mut();
                assert_eq!(diffsol_ode_get_ic_options(ode, &mut ic_options), DIFFSOL_OK);
                assert_eq!(diffsol_ode_get_options(ode, &mut ode_options), DIFFSOL_OK);
                diffsol_ic_options_free(ic_options);
                diffsol_ode_options_free(ode_options);

                // Tolerances: check the defaults, then round-trip an update.
                let mut out_value = 0.0;
                assert_eq!(diffsol_ode_get_rtol(ode, &mut out_value), DIFFSOL_OK);
                assert_close(out_value, 1e-6, ASSERT_TOL, "jit ffi default rtol");
                assert_eq!(diffsol_ode_set_rtol(ode, 1e-4), DIFFSOL_OK);
                assert_eq!(diffsol_ode_get_rtol(ode, &mut out_value), DIFFSOL_OK);
                assert_close(out_value, 1e-4, ASSERT_TOL, "jit ffi updated rtol");

                assert_eq!(diffsol_ode_get_atol(ode, &mut out_value), DIFFSOL_OK);
                assert_close(out_value, 1e-6, ASSERT_TOL, "jit ffi default atol");
                assert_eq!(diffsol_ode_set_atol(ode, 1e-5), DIFFSOL_OK);
                assert_eq!(diffsol_ode_get_atol(ode, &mut out_value), DIFFSOL_OK);
                assert_close(out_value, 1e-5, ASSERT_TOL, "jit ffi updated atol");

                assert_eq!(
                    diffsol_ode_set_linear_solver(ode, linear_solver_to_i32(LinearSolverType::Lu)),
                    DIFFSOL_OK
                );
                assert_eq!(
                    diffsol_ode_get_linear_solver(ode),
                    linear_solver_to_i32(LinearSolverType::Lu)
                );
                assert_eq!(
                    diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
                    DIFFSOL_OK
                );
                assert_eq!(
                    diffsol_ode_get_ode_solver(ode),
                    ode_solver_to_i32(OdeSolverType::Tsit45)
                );
                assert_eq!(
                    diffsol_ode_get_matrix_type(ode),
                    matrix_type_to_i32(MatrixType::NalgebraDense)
                );

                // Success paths.
                let params = [2.0f64];
                let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_solve(ode, params.as_ptr(), params.len(), 1.0, &mut solution_ptr),
                    DIFFSOL_OK
                );
                ffi_free_solution(solution_ptr);

                let t_eval = [0.25f64, 0.5f64, 1.0f64];
                let mut dense_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_solve_dense(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        &mut dense_solution_ptr,
                    ),
                    DIFFSOL_OK
                );
                ffi_free_solution(dense_solution_ptr);

                // Runtime errors: the model takes one parameter, so an empty
                // parameter slice must be rejected by every entry point.
                let no_params: [f64; 0] = [];
                let y = [0.25f64];
                let v = [3.0f64];
                let mut out_array = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_y0(ode, no_params.as_ptr(), no_params.len(), &mut out_array),
                    DIFFSOL_ERR
                );
                assert_eq!(
                    diffsol_ode_rhs(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        0.0,
                        y.as_ptr(),
                        y.len(),
                        &mut out_array,
                    ),
                    DIFFSOL_ERR
                );
                assert_eq!(
                    diffsol_ode_rhs_jac_mul(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        0.0,
                        y.as_ptr(),
                        y.len(),
                        v.as_ptr(),
                        v.len(),
                        &mut out_array,
                    ),
                    DIFFSOL_ERR
                );

                let mut err_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_solve(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        1.0,
                        &mut err_solution_ptr,
                    ),
                    DIFFSOL_ERR
                );
                assert_eq!(
                    diffsol_ode_solve_hybrid(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        1.0,
                        &mut err_solution_ptr,
                    ),
                    DIFFSOL_ERR
                );
                assert_eq!(
                    diffsol_ode_solve_dense(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        &mut err_solution_ptr,
                    ),
                    DIFFSOL_ERR
                );
                assert_eq!(
                    diffsol_ode_solve_hybrid_dense(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        &mut err_solution_ptr,
                    ),
                    DIFFSOL_ERR
                );

                #[cfg(feature = "diffsl-llvm")]
                if matches!(jit_backend, JitBackendType::Llvm) {
                    assert_eq!(
                        diffsol_ode_solve_fwd_sens(
                            ode,
                            no_params.as_ptr(),
                            no_params.len(),
                            t_eval.as_ptr(),
                            t_eval.len(),
                            &mut err_solution_ptr,
                        ),
                        DIFFSOL_ERR
                    );
                    assert_eq!(
                        diffsol_ode_solve_hybrid_fwd_sens(
                            ode,
                            no_params.as_ptr(),
                            no_params.len(),
                            t_eval.as_ptr(),
                            t_eval.len(),
                            &mut err_solution_ptr,
                        ),
                        DIFFSOL_ERR
                    );

                    let adjoint_data: Vec<f64> = t_eval
                        .iter()
                        .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
                        .collect();
                    let mut objective = 0.0;
                    let mut sens_ptr = ptr::null_mut();
                    assert_eq!(
                        diffsol_ode_solve_sum_squares_adj(
                            ode,
                            no_params.as_ptr(),
                            no_params.len(),
                            adjoint_data.as_ptr(),
                            1,
                            t_eval.len(),
                            1,
                            1,
                            t_eval.as_ptr(),
                            t_eval.len(),
                            &mut objective,
                            &mut sens_ptr,
                        ),
                        DIFFSOL_ERR
                    );
                }

                // Argument validation: null handles, null outputs, bad enum codes.
                assert_eq!(diffsol_ode_get_matrix_type(ptr::null()), -1);
                assert_eq!(diffsol_ode_get_ode_solver(ptr::null()), -1);
                assert_eq!(diffsol_ode_get_linear_solver(ptr::null()), -1);
                assert_eq!(
                    diffsol_ode_set_ode_solver(ptr::null_mut(), 0),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(
                    diffsol_ode_set_linear_solver(ptr::null_mut(), 0),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(diffsol_ode_set_ode_solver(ode, 99), DIFFSOL_BAD_ARG);
                assert_eq!(diffsol_ode_set_linear_solver(ode, 99), DIFFSOL_BAD_ARG);
                assert_eq!(
                    diffsol_ode_get_rtol(ptr::null(), &mut out_value),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(diffsol_ode_get_rtol(ode, ptr::null_mut()), DIFFSOL_BAD_ARG);
                assert_eq!(diffsol_ode_set_rtol(ptr::null_mut(), 1e-3), DIFFSOL_BAD_ARG);
                assert_eq!(
                    diffsol_ode_get_atol(ptr::null(), &mut out_value),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(diffsol_ode_get_atol(ode, ptr::null_mut()), DIFFSOL_BAD_ARG);
                assert_eq!(diffsol_ode_set_atol(ptr::null_mut(), 1e-3), DIFFSOL_BAD_ARG);
                assert_eq!(
                    diffsol_ode_solve(ode, params.as_ptr(), params.len(), 1.0, ptr::null_mut()),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(
                    diffsol_ode_solve_hybrid(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        1.0,
                        ptr::null_mut(),
                    ),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(
                    diffsol_ode_solve_dense(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        ptr::null_mut(),
                    ),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(
                    diffsol_ode_solve_hybrid_dense(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        ptr::null_mut(),
                    ),
                    DIFFSOL_BAD_ARG
                );
                #[cfg(feature = "diffsl-llvm")]
                if matches!(jit_backend, JitBackendType::Llvm) {
                    assert_eq!(
                        diffsol_ode_solve_fwd_sens(
                            ode,
                            params.as_ptr(),
                            params.len(),
                            t_eval.as_ptr(),
                            t_eval.len(),
                            ptr::null_mut(),
                        ),
                        DIFFSOL_BAD_ARG
                    );
                    assert_eq!(
                        diffsol_ode_solve_hybrid_fwd_sens(
                            ode,
                            params.as_ptr(),
                            params.len(),
                            t_eval.as_ptr(),
                            t_eval.len(),
                            ptr::null_mut(),
                        ),
                        DIFFSOL_BAD_ARG
                    );
                    let mut objective = 0.0;
                    let mut sens_ptr = ptr::null_mut();
                    // Null objective output.
                    assert_eq!(
                        diffsol_ode_solve_sum_squares_adj(
                            ode,
                            params.as_ptr(),
                            params.len(),
                            t_eval.as_ptr(),
                            1,
                            t_eval.len(),
                            1,
                            1,
                            t_eval.as_ptr(),
                            t_eval.len(),
                            ptr::null_mut(),
                            &mut sens_ptr,
                        ),
                        DIFFSOL_BAD_ARG
                    );
                    // Null gradient output.
                    assert_eq!(
                        diffsol_ode_solve_sum_squares_adj(
                            ode,
                            params.as_ptr(),
                            params.len(),
                            t_eval.as_ptr(),
                            1,
                            t_eval.len(),
                            1,
                            1,
                            t_eval.as_ptr(),
                            t_eval.len(),
                            &mut objective,
                            ptr::null_mut(),
                        ),
                        DIFFSOL_BAD_ARG
                    );
                }

                diffsol_ode_free(ode);
            }
        }
    }

    /// Solves the hybrid logistic DiffSL model through the JIT path and
    /// compares the final state against `hybrid_logistic_state`. On LLVM it
    /// also checks hybrid forward sensitivities against
    /// `hybrid_logistic_state_dr`.
    #[test]
    fn c_api_hybrid_jit_solver_paths_match_expected_values() {
        for jit_backend in available_jit_backends() {
            unsafe {
                let code = CString::new(hybrid_logistic_diffsl_code()).unwrap();
                let ode = make_ode_ptr_with_code(
                    jit_backend,
                    code.as_ptr(),
                    matrix_type_to_i32(MatrixType::NalgebraDense),
                    linear_solver_to_i32(LinearSolverType::Default),
                    ode_solver_to_i32(OdeSolverType::Bdf),
                );
                assert!(!ode.is_null());

                let params = [2.0f64];
                let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_solve_hybrid(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        2.0,
                        &mut solution_ptr
                    ),
                    DIFFSOL_OK
                );
                let mut ys_ptr = ptr::null_mut();
                let mut ts_ptr = ptr::null_mut();
                assert_eq!(
                    diffsol_solution_wrapper_get_ys(solution_ptr, &mut ys_ptr),
                    DIFFSOL_OK
                );
                assert_eq!(
                    diffsol_solution_wrapper_get_ts(solution_ptr, &mut ts_ptr),
                    DIFFSOL_OK
                );
                let (_rows, cols, ys) = ffi_read_host_array_matrix(ys_ptr);
                let ts = ffi_read_host_array_vector(ts_ptr);
                assert!(cols >= 1);
                // Every output column must have a matching time point.
                assert_eq!(cols, ts.len());
                assert_close(*ts.last().unwrap(), 2.0, 5e-4, "jit hybrid solve time");
                assert_close(
                    *ys.last().unwrap(),
                    hybrid_logistic_state(2.0, 2.0),
                    5e-4,
                    "jit hybrid solve value",
                );
                ffi_free_solution(solution_ptr);

                #[cfg(feature = "diffsl-llvm")]
                if matches!(jit_backend, JitBackendType::Llvm) {
                    let t_eval = [0.25f64, 0.5f64, 1.0f64];
                    let mut sens_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                    assert_eq!(
                        diffsol_ode_solve_hybrid_fwd_sens(
                            ode,
                            params.as_ptr(),
                            params.len(),
                            t_eval.as_ptr(),
                            t_eval.len(),
                            &mut sens_solution_ptr,
                        ),
                        DIFFSOL_OK
                    );
                    let mut sens_list = ptr::null_mut();
                    let mut sens_len = 0usize;
                    assert_eq!(
                        diffsol_solution_wrapper_get_sens(
                            sens_solution_ptr,
                            &mut sens_list,
                            &mut sens_len
                        ),
                        DIFFSOL_OK
                    );
                    let sens_values = ffi_read_host_array_list_matrices(sens_list, sens_len);
                    // One sensitivity matrix per parameter, one column per
                    // requested time. Checking the column count first means a
                    // short result cannot be silently truncated by `zip` below.
                    assert_eq!(sens_values.len(), params.len());
                    assert_eq!(sens_values[0].1, t_eval.len());
                    for (i, (&value, &t)) in sens_values[0].2.iter().zip(t_eval.iter()).enumerate()
                    {
                        assert_close(
                            value,
                            hybrid_logistic_state_dr(2.0, t),
                            5e-4,
                            &format!("jit hybrid sensitivity[{i}]"),
                        );
                    }
                    ffi_free_solution(sens_solution_ptr);
                }

                diffsol_ode_free(ode);
            }
        }
    }
}