1#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
2use std::ffi::CStr;
3#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
4use std::os::raw::c_char;
5use std::ptr;
6
7use crate::c_api_utils::{valid_f64_ptr, DIFFSOL_BAD_ARG, DIFFSOL_ERR, DIFFSOL_OK};
8use crate::host_array::HostArray;
9#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
10use crate::jit_c::jit_backend_from_i32;
11use crate::linear_solver_type_c::{linear_solver_from_i32, linear_solver_to_i32};
12use crate::matrix_type_c::{matrix_type_from_i32, matrix_type_to_i32};
13use crate::ode::OdeWrapper;
14use crate::ode_solver_type_c::{ode_solver_from_i32, ode_solver_to_i32};
15use crate::scalar_type::ScalarType;
16use crate::solution_wrapper::SolutionWrapper;
17use crate::{c_error, c_invalid_arg};
18
19fn boxed_host_array(array: HostArray) -> *mut HostArray {
20 Box::into_raw(Box::new(array))
21}
22
23fn parse_ode_new_common_args(
24 matrix_type: i32,
25 linear_solver: i32,
26 ode_solver: i32,
27) -> Option<(
28 crate::matrix_type::MatrixType,
29 crate::linear_solver_type::LinearSolverType,
30 crate::ode_solver_type::OdeSolverType,
31)> {
32 let matrix_type = match matrix_type_from_i32(matrix_type) {
33 Some(value) => value,
34 None => {
35 c_invalid_arg!("invalid matrix_type");
36 return None;
37 }
38 };
39 let linear_solver = match linear_solver_from_i32(linear_solver) {
40 Some(value) => value,
41 None => {
42 c_invalid_arg!("invalid linear_solver");
43 return None;
44 }
45 };
46 let ode_solver = match ode_solver_from_i32(ode_solver) {
47 Some(value) => value,
48 None => {
49 c_invalid_arg!("invalid ode_solver");
50 return None;
51 }
52 };
53 Some((matrix_type, linear_solver, ode_solver))
54}
55
/// Validates the arguments of the JIT constructor.
///
/// `code` must be a non-null, NUL-terminated C string containing valid UTF-8; it
/// is copied into an owned `String`. The enum codes are then decoded by
/// [`parse_ode_new_common_args`]. Any failure is recorded via `c_invalid_arg!` /
/// `c_error!` and yields `None`.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn parse_ode_new_jit_args(
    code: *const c_char,
    matrix_type: i32,
    linear_solver: i32,
    ode_solver: i32,
) -> Option<(
    String,
    crate::matrix_type::MatrixType,
    crate::linear_solver_type::LinearSolverType,
    crate::ode_solver_type::OdeSolverType,
)> {
    if code.is_null() {
        c_invalid_arg!("code is null");
        return None;
    }
    // SAFETY: `code` is non-null; the caller must supply a NUL-terminated string.
    let raw_source = unsafe { CStr::from_ptr(code) };
    let Ok(source) = raw_source.to_str() else {
        c_error!("code is not valid UTF-8");
        return None;
    };
    let source = source.to_owned();
    parse_ode_new_common_args(matrix_type, linear_solver, ode_solver)
        .map(|(matrix, linear, solver)| (source, matrix, linear, solver))
}
84
85#[unsafe(no_mangle)]
91pub unsafe extern "C" fn diffsol_host_array_list_free(list: *mut *mut HostArray, len: usize) {
92 if list.is_null() {
93 c_invalid_arg!("host array list is null");
94 return;
95 }
96 unsafe {
97 drop(Box::from_raw(std::ptr::slice_from_raw_parts_mut(list, len)));
98 }
99}
100
/// Creates an `OdeWrapper` through `OdeWrapper::new_external` and returns it as
/// an owning raw pointer.
///
/// Each `*_deps_ptr` / `*_deps_len` pair describes a list of index pairs stored as
/// a flat buffer of `2 * len` `usize` values (`[a0, b0, a1, b1, ...]`). A null
/// pointer or a zero length is treated as an empty list. The scalar type is always
/// `ScalarType::F64`.
///
/// Returns null after recording an error when an enum code is invalid, when a
/// dependency list is too large to be addressed, or when construction fails. A
/// non-null result is an owning pointer created with `Box::into_raw`.
///
/// # Safety
///
/// Every non-null `*_deps_ptr` must be valid for reading `2 * len` consecutive
/// `usize` values, where `len` is the matching `*_deps_len` argument.
#[cfg(feature = "external")]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn diffsol_ode_new_external(
    matrix_type: i32,
    linear_solver: i32,
    ode_solver: i32,
    rhs_state_deps_ptr: *const usize,
    rhs_state_deps_len: usize,
    rhs_input_deps_ptr: *const usize,
    rhs_input_deps_len: usize,
    mass_state_deps_ptr: *const usize,
    mass_state_deps_len: usize,
) -> *mut OdeWrapper {
    /// Copies `pair_count` index pairs out of a flat buffer of `2 * pair_count`
    /// values; `None` means the buffer would be too large to form a slice.
    unsafe fn read_index_pairs(
        data: *const usize,
        pair_count: usize,
    ) -> Option<Vec<(usize, usize)>> {
        if data.is_null() || pair_count == 0 {
            return Some(Vec::new());
        }
        // `slice::from_raw_parts` requires the byte length to fit in `isize`.
        let max_pairs = isize::MAX as usize / (2 * std::mem::size_of::<usize>());
        if pair_count > max_pairs {
            return None;
        }
        // SAFETY: `data` is non-null and the caller guarantees it is readable for
        // `2 * pair_count` values; the bound above rules out overflow.
        let flat = unsafe { std::slice::from_raw_parts(data, 2 * pair_count) };
        // Pair the elements explicitly rather than reinterpreting the buffer as
        // `(usize, usize)`, whose field order is not guaranteed by the language.
        Some(
            flat.chunks_exact(2)
                .map(|pair| (pair[0], pair[1]))
                .collect(),
        )
    }

    let Some((matrix_type, linear_solver, ode_solver)) =
        parse_ode_new_common_args(matrix_type, linear_solver, ode_solver)
    else {
        return ptr::null_mut();
    };

    // SAFETY: the pointer/length contracts are forwarded unchanged from the caller.
    let dependency_lists = unsafe {
        (
            read_index_pairs(rhs_state_deps_ptr, rhs_state_deps_len),
            read_index_pairs(rhs_input_deps_ptr, rhs_input_deps_len),
            read_index_pairs(mass_state_deps_ptr, mass_state_deps_len),
        )
    };
    let (Some(rhs_state_deps), Some(rhs_input_deps), Some(mass_state_deps)) = dependency_lists
    else {
        c_invalid_arg!("dependency list length is too large");
        return ptr::null_mut();
    };

    let scalar_type = ScalarType::F64;
    match OdeWrapper::new_external(
        rhs_state_deps,
        rhs_input_deps,
        mass_state_deps,
        scalar_type,
        matrix_type,
        linear_solver,
        ode_solver,
    ) {
        Ok(ode) => Box::into_raw(Box::new(ode)),
        Err(err) => {
            c_error!(&format!("{}", err));
            ptr::null_mut()
        }
    }
}
179
/// Creates an `OdeWrapper` from source text through `OdeWrapper::new_jit` and
/// returns it as an owning raw pointer.
///
/// `code` and the matrix/linear-solver/ODE-solver codes are validated by
/// [`parse_ode_new_jit_args`]; `jit_backend` is decoded afterwards. The scalar
/// type is always `ScalarType::F64`. Returns null after recording an error when
/// any argument is rejected or construction fails.
///
/// # Safety
///
/// `code` must be null or point to a NUL-terminated C string.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
#[unsafe(no_mangle)]
pub unsafe extern "C" fn diffsol_ode_new_jit(
    code: *const c_char,
    jit_backend: i32,
    matrix_type: i32,
    linear_solver: i32,
    ode_solver: i32,
) -> *mut OdeWrapper {
    let parsed = parse_ode_new_jit_args(code, matrix_type, linear_solver, ode_solver);
    let Some((source, matrix, linear, solver)) = parsed else {
        return ptr::null_mut();
    };
    let Some(backend) = jit_backend_from_i32(jit_backend) else {
        c_invalid_arg!("invalid jit_backend_type");
        return ptr::null_mut();
    };
    let built = OdeWrapper::new_jit(&source, backend, ScalarType::F64, matrix, linear, solver);
    match built {
        Err(err) => {
            c_error!(&err.to_string());
            ptr::null_mut()
        }
        Ok(ode) => Box::into_raw(Box::new(ode)),
    }
}
223
224#[unsafe(no_mangle)]
230pub unsafe extern "C" fn diffsol_ode_free(ode: *mut OdeWrapper) {
231 if ode.is_null() {
232 c_invalid_arg!("ode is null");
233 return;
234 }
235 unsafe {
236 drop(Box::from_raw(ode));
237 }
238}
239
240#[unsafe(no_mangle)]
247pub unsafe extern "C" fn diffsol_ode_get_ic_options(
248 ode: *const OdeWrapper,
249 out_options: *mut *mut crate::initial_condition_options::InitialConditionSolverOptions,
250) -> i32 {
251 if ode.is_null() || out_options.is_null() {
252 return c_invalid_arg!("invalid arguments to diffsol_ode_get_ic_options");
253 }
254 let ode = unsafe { &*ode };
255 let options = ode.get_ic_options();
256 let boxed = Box::new(options);
257 unsafe {
258 *out_options = Box::into_raw(boxed);
259 }
260 DIFFSOL_OK
261}
262
263#[unsafe(no_mangle)]
270pub unsafe extern "C" fn diffsol_ode_get_options(
271 ode: *const OdeWrapper,
272 out_options: *mut *mut crate::ode_options::OdeSolverOptions,
273) -> i32 {
274 if ode.is_null() || out_options.is_null() {
275 return c_invalid_arg!("invalid arguments to diffsol_ode_get_options");
276 }
277 let ode = unsafe { &*ode };
278 let options = ode.get_options();
279 let boxed = Box::new(options);
280 unsafe {
281 *out_options = Box::into_raw(boxed);
282 }
283 DIFFSOL_OK
284}
285
286#[unsafe(no_mangle)]
293pub unsafe extern "C" fn diffsol_ode_y0(
294 ode: *mut OdeWrapper,
295 params_ptr: *const f64,
296 params_len: usize,
297 out_array: *mut *mut HostArray,
298) -> i32 {
299 if ode.is_null() || out_array.is_null() || !valid_f64_ptr(params_ptr, params_len) {
300 c_invalid_arg!("invalid arguments to diffsol_ode_y0");
301 return DIFFSOL_BAD_ARG;
302 }
303 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
304 let ode = unsafe { &mut *ode };
305 match ode.y0(params) {
306 Ok(array) => {
307 let boxed = boxed_host_array(array);
308 unsafe {
309 *out_array = boxed;
310 }
311 DIFFSOL_OK
312 }
313 Err(err) => {
314 c_error!(&format!("{}", err));
315 DIFFSOL_ERR
316 }
317 }
318}
319
320#[unsafe(no_mangle)]
327pub unsafe extern "C" fn diffsol_ode_rhs(
328 ode: *mut OdeWrapper,
329 params_ptr: *const f64,
330 params_len: usize,
331 t: f64,
332 y_ptr: *const f64,
333 y_len: usize,
334 out_array: *mut *mut HostArray,
335) -> i32 {
336 if ode.is_null()
337 || out_array.is_null()
338 || !valid_f64_ptr(params_ptr, params_len)
339 || !valid_f64_ptr(y_ptr, y_len)
340 {
341 c_invalid_arg!("invalid arguments to diffsol_ode_rhs");
342 return DIFFSOL_BAD_ARG;
343 }
344 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
345 let y = HostArray::new_vector(y_ptr as *mut u8, y_len, ScalarType::F64);
346 let ode = unsafe { &mut *ode };
347 match ode.rhs(params, t, y) {
348 Ok(array) => {
349 let boxed = boxed_host_array(array);
350 unsafe {
351 *out_array = boxed;
352 }
353 DIFFSOL_OK
354 }
355 Err(err) => {
356 c_error!(&format!("{}", err));
357 DIFFSOL_ERR
358 }
359 }
360}
361
362#[unsafe(no_mangle)]
370pub unsafe extern "C" fn diffsol_ode_rhs_jac_mul(
371 ode: *mut OdeWrapper,
372 params_ptr: *const f64,
373 params_len: usize,
374 t: f64,
375 y_ptr: *const f64,
376 y_len: usize,
377 v_ptr: *const f64,
378 v_len: usize,
379 out_array: *mut *mut HostArray,
380) -> i32 {
381 if ode.is_null()
382 || out_array.is_null()
383 || !valid_f64_ptr(params_ptr, params_len)
384 || !valid_f64_ptr(y_ptr, y_len)
385 || !valid_f64_ptr(v_ptr, v_len)
386 {
387 c_invalid_arg!("invalid arguments to diffsol_ode_rhs_jac_mul");
388 return DIFFSOL_BAD_ARG;
389 }
390 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
391 let y = HostArray::new_vector(y_ptr as *mut u8, y_len, ScalarType::F64);
392 let v = HostArray::new_vector(v_ptr as *mut u8, v_len, ScalarType::F64);
393 let ode = unsafe { &mut *ode };
394 match ode.rhs_jac_mul(params, t, y, v) {
395 Ok(array) => {
396 let boxed = boxed_host_array(array);
397 unsafe {
398 *out_array = boxed;
399 }
400 DIFFSOL_OK
401 }
402 Err(err) => {
403 c_error!(&format!("{}", err));
404 DIFFSOL_ERR
405 }
406 }
407}
408
409#[unsafe(no_mangle)]
416pub unsafe extern "C" fn diffsol_ode_solve(
417 ode: *mut OdeWrapper,
418 params_ptr: *const f64,
419 params_len: usize,
420 final_time: f64,
421 out_solution: *mut *mut SolutionWrapper,
422) -> i32 {
423 if ode.is_null() || out_solution.is_null() || !valid_f64_ptr(params_ptr, params_len) {
424 c_invalid_arg!("invalid arguments to diffsol_ode_solve");
425 return DIFFSOL_BAD_ARG;
426 }
427 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
428 let ode = unsafe { &mut *ode };
429 match ode.solve(params, final_time) {
430 Ok(new_solution) => {
431 unsafe {
432 *out_solution = Box::into_raw(Box::new(new_solution));
433 }
434 DIFFSOL_OK
435 }
436 Err(err) => {
437 c_error!(&format!("{}", err));
438 DIFFSOL_ERR
439 }
440 }
441}
442
443#[unsafe(no_mangle)]
450pub unsafe extern "C" fn diffsol_ode_solve_hybrid(
451 ode: *mut OdeWrapper,
452 params_ptr: *const f64,
453 params_len: usize,
454 final_time: f64,
455 out_solution: *mut *mut SolutionWrapper,
456) -> i32 {
457 if ode.is_null() || out_solution.is_null() || !valid_f64_ptr(params_ptr, params_len) {
458 c_invalid_arg!("invalid arguments to diffsol_ode_solve_hybrid");
459 return DIFFSOL_BAD_ARG;
460 }
461 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
462 let ode = unsafe { &mut *ode };
463 match ode.solve_hybrid(params, final_time) {
464 Ok(new_solution) => {
465 unsafe {
466 *out_solution = Box::into_raw(Box::new(new_solution));
467 }
468 DIFFSOL_OK
469 }
470 Err(err) => {
471 c_error!(&format!("{}", err));
472 DIFFSOL_ERR
473 }
474 }
475}
476
477#[unsafe(no_mangle)]
484pub unsafe extern "C" fn diffsol_ode_solve_dense(
485 ode: *mut OdeWrapper,
486 params_ptr: *const f64,
487 params_len: usize,
488 t_eval_ptr: *const f64,
489 t_eval_len: usize,
490 out_solution: *mut *mut SolutionWrapper,
491) -> i32 {
492 if ode.is_null()
493 || out_solution.is_null()
494 || !valid_f64_ptr(params_ptr, params_len)
495 || !valid_f64_ptr(t_eval_ptr, t_eval_len)
496 {
497 c_invalid_arg!("invalid arguments to diffsol_ode_solve_dense");
498 return DIFFSOL_BAD_ARG;
499 }
500 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
501 let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
502 let ode = unsafe { &mut *ode };
503 match ode.solve_dense(params, t_eval) {
504 Ok(new_solution) => {
505 unsafe {
506 *out_solution = Box::into_raw(Box::new(new_solution));
507 }
508 DIFFSOL_OK
509 }
510 Err(err) => {
511 c_error!(&format!("{}", err));
512 DIFFSOL_ERR
513 }
514 }
515}
516
517#[unsafe(no_mangle)]
524pub unsafe extern "C" fn diffsol_ode_solve_hybrid_dense(
525 ode: *mut OdeWrapper,
526 params_ptr: *const f64,
527 params_len: usize,
528 t_eval_ptr: *const f64,
529 t_eval_len: usize,
530 out_solution: *mut *mut SolutionWrapper,
531) -> i32 {
532 if ode.is_null()
533 || out_solution.is_null()
534 || !valid_f64_ptr(params_ptr, params_len)
535 || !valid_f64_ptr(t_eval_ptr, t_eval_len)
536 {
537 c_invalid_arg!("invalid arguments to diffsol_ode_solve_hybrid_dense");
538 return DIFFSOL_BAD_ARG;
539 }
540 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
541 let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
542 let ode = unsafe { &mut *ode };
543 match ode.solve_hybrid_dense(params, t_eval) {
544 Ok(new_solution) => {
545 unsafe {
546 *out_solution = Box::into_raw(Box::new(new_solution));
547 }
548 DIFFSOL_OK
549 }
550 Err(err) => {
551 c_error!(&format!("{}", err));
552 DIFFSOL_ERR
553 }
554 }
555}
556
557#[unsafe(no_mangle)]
564pub unsafe extern "C" fn diffsol_ode_solve_fwd_sens(
565 ode: *mut OdeWrapper,
566 params_ptr: *const f64,
567 params_len: usize,
568 t_eval_ptr: *const f64,
569 t_eval_len: usize,
570 out_solution: *mut *mut SolutionWrapper,
571) -> i32 {
572 if ode.is_null()
573 || out_solution.is_null()
574 || !valid_f64_ptr(params_ptr, params_len)
575 || !valid_f64_ptr(t_eval_ptr, t_eval_len)
576 {
577 c_invalid_arg!("invalid arguments to diffsol_ode_solve_fwd_sens");
578 return DIFFSOL_BAD_ARG;
579 }
580 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
581 let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
582 let ode = unsafe { &mut *ode };
583 match ode.solve_fwd_sens(params, t_eval) {
584 Ok(new_solution) => {
585 unsafe {
586 *out_solution = Box::into_raw(Box::new(new_solution));
587 }
588 DIFFSOL_OK
589 }
590 Err(err) => {
591 c_error!(&format!("{}", err));
592 DIFFSOL_ERR
593 }
594 }
595}
596
597#[unsafe(no_mangle)]
604pub unsafe extern "C" fn diffsol_ode_solve_hybrid_fwd_sens(
605 ode: *mut OdeWrapper,
606 params_ptr: *const f64,
607 params_len: usize,
608 t_eval_ptr: *const f64,
609 t_eval_len: usize,
610 out_solution: *mut *mut SolutionWrapper,
611) -> i32 {
612 if ode.is_null()
613 || out_solution.is_null()
614 || !valid_f64_ptr(params_ptr, params_len)
615 || !valid_f64_ptr(t_eval_ptr, t_eval_len)
616 {
617 c_invalid_arg!("invalid arguments to diffsol_ode_solve_hybrid_fwd_sens");
618 return DIFFSOL_BAD_ARG;
619 }
620 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
621 let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
622 let ode = unsafe { &mut *ode };
623 match ode.solve_hybrid_fwd_sens(params, t_eval) {
624 Ok(new_solution) => {
625 unsafe {
626 *out_solution = Box::into_raw(Box::new(new_solution));
627 }
628 DIFFSOL_OK
629 }
630 Err(err) => {
631 c_error!(&format!("{}", err));
632 DIFFSOL_ERR
633 }
634 }
635}
636
637#[unsafe(no_mangle)]
645pub unsafe extern "C" fn diffsol_ode_solve_sum_squares_adj(
646 ode: *mut OdeWrapper,
647 params_ptr: *const f64,
648 params_len: usize,
649 data_ptr: *const f64,
650 data_rows: usize,
651 data_cols: usize,
652 data_row_stride: usize,
653 data_col_stride: usize,
654 t_eval_ptr: *const f64,
655 t_eval_len: usize,
656 out_value: *mut f64,
657 out_sens: *mut *mut HostArray,
658) -> i32 {
659 if ode.is_null()
660 || out_value.is_null()
661 || out_sens.is_null()
662 || data_ptr.is_null()
663 || !valid_f64_ptr(params_ptr, params_len)
664 || !valid_f64_ptr(t_eval_ptr, t_eval_len)
665 {
666 c_invalid_arg!("invalid arguments to diffsol_ode_solve_sum_squares_adj");
667 return DIFFSOL_BAD_ARG;
668 }
669 let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
670 let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
671 let data = HostArray::new_col_major(
672 data_ptr as *mut u8,
673 data_rows,
674 data_cols,
675 data_row_stride as isize,
676 data_col_stride as isize,
677 ScalarType::F64,
678 );
679 let ode = unsafe { &mut *ode };
680 match ode.solve_sum_squares_adj(params, data, t_eval) {
681 Ok((value, sens)) => {
682 let sens_boxed = boxed_host_array(sens);
683 unsafe {
684 *out_value = value;
685 *out_sens = sens_boxed;
686 }
687 DIFFSOL_OK
688 }
689 Err(err) => {
690 c_error!(&format!("{}", err));
691 DIFFSOL_ERR
692 }
693 }
694}
695
696#[unsafe(no_mangle)]
701pub unsafe extern "C" fn diffsol_ode_get_matrix_type(ode: *const OdeWrapper) -> i32 {
702 if ode.is_null() {
703 c_invalid_arg!("ode is null");
704 return -1;
705 }
706 let ode = unsafe { &*ode };
707 match ode.get_matrix_type() {
708 Ok(value) => matrix_type_to_i32(value),
709 Err(err) => {
710 c_error!(&format!("{}", err));
711 -1
712 }
713 }
714}
715
716#[unsafe(no_mangle)]
721pub unsafe extern "C" fn diffsol_ode_get_ode_solver(ode: *const OdeWrapper) -> i32 {
722 if ode.is_null() {
723 c_invalid_arg!("ode is null");
724 return -1;
725 }
726 let ode = unsafe { &*ode };
727 match ode.get_ode_solver() {
728 Ok(value) => ode_solver_to_i32(value),
729 Err(err) => {
730 c_error!(&format!("{}", err));
731 -1
732 }
733 }
734}
735
736#[unsafe(no_mangle)]
741pub unsafe extern "C" fn diffsol_ode_set_ode_solver(ode: *mut OdeWrapper, value: i32) -> i32 {
742 if ode.is_null() {
743 c_invalid_arg!("ode is null");
744 return DIFFSOL_BAD_ARG;
745 }
746 let value = match ode_solver_from_i32(value) {
747 Some(v) => v,
748 None => {
749 c_invalid_arg!("invalid ode_solver");
750 return DIFFSOL_BAD_ARG;
751 }
752 };
753 let ode = unsafe { &mut *ode };
754 match ode.set_ode_solver(value) {
755 Ok(()) => DIFFSOL_OK,
756 Err(err) => c_error!(&format!("{}", err)),
757 }
758}
759
760#[unsafe(no_mangle)]
765pub unsafe extern "C" fn diffsol_ode_get_linear_solver(ode: *const OdeWrapper) -> i32 {
766 if ode.is_null() {
767 c_invalid_arg!("ode is null");
768 return -1;
769 }
770 let ode = unsafe { &*ode };
771 match ode.get_linear_solver() {
772 Ok(value) => linear_solver_to_i32(value),
773 Err(err) => {
774 c_error!(&format!("{}", err));
775 -1
776 }
777 }
778}
779
780#[unsafe(no_mangle)]
785pub unsafe extern "C" fn diffsol_ode_set_linear_solver(ode: *mut OdeWrapper, value: i32) -> i32 {
786 if ode.is_null() {
787 c_invalid_arg!("ode is null");
788 return DIFFSOL_BAD_ARG;
789 }
790 let value = match linear_solver_from_i32(value) {
791 Some(v) => v,
792 None => {
793 c_invalid_arg!("invalid linear_solver");
794 return DIFFSOL_BAD_ARG;
795 }
796 };
797 let ode = unsafe { &mut *ode };
798 match ode.set_linear_solver(value) {
799 Ok(()) => DIFFSOL_OK,
800 Err(err) => c_error!(&format!("{}", err)),
801 }
802}
803
804#[unsafe(no_mangle)]
810pub unsafe extern "C" fn diffsol_ode_get_rtol(ode: *const OdeWrapper, out_value: *mut f64) -> i32 {
811 if ode.is_null() || out_value.is_null() {
812 c_invalid_arg!("invalid arguments to diffsol_ode_get_rtol");
813 return DIFFSOL_BAD_ARG;
814 }
815 let ode = unsafe { &*ode };
816 match ode.get_rtol() {
817 Ok(value) => {
818 unsafe {
819 *out_value = value;
820 }
821 DIFFSOL_OK
822 }
823 Err(err) => c_error!(&format!("{}", err)),
824 }
825}
826
827#[unsafe(no_mangle)]
832pub unsafe extern "C" fn diffsol_ode_set_rtol(ode: *mut OdeWrapper, value: f64) -> i32 {
833 if ode.is_null() {
834 c_invalid_arg!("ode is null");
835 return DIFFSOL_BAD_ARG;
836 }
837 let ode = unsafe { &mut *ode };
838 match ode.set_rtol(value) {
839 Ok(()) => DIFFSOL_OK,
840 Err(err) => c_error!(&format!("{}", err)),
841 }
842}
843
844#[unsafe(no_mangle)]
850pub unsafe extern "C" fn diffsol_ode_get_atol(ode: *const OdeWrapper, out_value: *mut f64) -> i32 {
851 if ode.is_null() || out_value.is_null() {
852 c_invalid_arg!("invalid arguments to diffsol_ode_get_atol");
853 return DIFFSOL_BAD_ARG;
854 }
855 let ode = unsafe { &*ode };
856 match ode.get_atol() {
857 Ok(value) => {
858 unsafe {
859 *out_value = value;
860 }
861 DIFFSOL_OK
862 }
863 Err(err) => c_error!(&format!("{}", err)),
864 }
865}
866
867#[unsafe(no_mangle)]
872pub unsafe extern "C" fn diffsol_ode_set_atol(ode: *mut OdeWrapper, value: f64) -> i32 {
873 if ode.is_null() {
874 c_invalid_arg!("ode is null");
875 return DIFFSOL_BAD_ARG;
876 }
877 let ode = unsafe { &mut *ode };
878 match ode.set_atol(value) {
879 Ok(()) => DIFFSOL_OK,
880 Err(err) => c_error!(&format!("{}", err)),
881 }
882}
883
884#[cfg(all(test, feature = "diffsl-external-f64"))]
885mod tests {
886 use std::ptr;
887
888 use crate::initial_condition_options::InitialConditionSolverOptions;
889 use crate::linear_solver_type::LinearSolverType;
890 use crate::linear_solver_type_c::{
891 diffsol_linear_solver_type_count, diffsol_linear_solver_type_is_valid,
892 diffsol_linear_solver_type_name, linear_solver_to_i32,
893 };
894 use crate::matrix_type::MatrixType;
895 use crate::ode_options::OdeSolverOptions;
896 use crate::ode_options_c::{
897 diffsol_ode_options_free, diffsol_ode_options_get_max_nonlinear_solver_iterations,
898 diffsol_ode_options_get_min_timestep,
899 diffsol_ode_options_set_max_nonlinear_solver_iterations,
900 diffsol_ode_options_set_min_timestep,
901 };
902 use crate::ode_solver_type::OdeSolverType;
903 use crate::ode_solver_type_c::{
904 diffsol_ode_solver_type_count, diffsol_ode_solver_type_is_valid,
905 diffsol_ode_solver_type_name, ode_solver_to_i32,
906 };
907 use crate::scalar_type::ScalarType;
908 use crate::scalar_type_c::{
909 diffsol_scalar_type_count, diffsol_scalar_type_is_valid, diffsol_scalar_type_name,
910 scalar_type_to_i32,
911 };
912 use crate::solution_wrapper_c::{
913 diffsol_solution_wrapper_get_sens, diffsol_solution_wrapper_get_ts,
914 diffsol_solution_wrapper_get_ys,
915 };
916 use crate::test_support::{
917 assert_close, assert_last_error_contains, c_string, clear_last_error, ffi_free_solution,
918 ffi_read_host_array_list_matrices, ffi_read_host_array_matrix, ffi_read_host_array_vector,
919 find_time_window, logistic_state, logistic_state_dr, mass_state_deps, rhs_input_deps,
920 rhs_state_deps, ASSERT_TOL, LOGISTIC_X0,
921 };
922 use crate::{
923 initial_condition_options_c::{
924 diffsol_ic_options_free, diffsol_ic_options_get_max_linesearch_iterations,
925 diffsol_ic_options_get_use_linesearch,
926 diffsol_ic_options_set_max_linesearch_iterations,
927 diffsol_ic_options_set_use_linesearch,
928 },
929 matrix_type_c::{
930 diffsol_matrix_type_count, diffsol_matrix_type_is_valid, diffsol_matrix_type_name,
931 matrix_type_to_i32,
932 },
933 };
934
935 use super::*;
936
937 unsafe fn make_ode_ptr(
938 matrix_type: i32,
939 linear_solver: i32,
940 ode_solver: i32,
941 ) -> *mut OdeWrapper {
942 let rhs_state_deps = rhs_state_deps();
943 let rhs_input_deps = rhs_input_deps();
944 let mass_state_deps = mass_state_deps();
945 unsafe {
946 diffsol_ode_new_external(
947 matrix_type,
948 linear_solver,
949 ode_solver,
950 rhs_state_deps.as_ptr() as *const usize,
951 rhs_state_deps.len(),
952 rhs_input_deps.as_ptr() as *const usize,
953 rhs_input_deps.len(),
954 mass_state_deps.as_ptr() as *const usize,
955 mass_state_deps.len(),
956 )
957 }
958 }
959
960 #[test]
961 fn c_api_reports_enum_metadata() {
962 clear_last_error();
963 unsafe {
964 assert_eq!(diffsol_matrix_type_count(), 3);
965 assert_eq!(diffsol_ode_solver_type_count(), 4);
966 assert_eq!(diffsol_linear_solver_type_count(), 3);
967 assert_eq!(diffsol_scalar_type_count(), 2);
968
969 assert_eq!(
970 c_string(diffsol_matrix_type_name(matrix_type_to_i32(
971 MatrixType::NalgebraDense
972 ))),
973 "nalgebra_dense"
974 );
975 assert_eq!(
976 c_string(diffsol_ode_solver_type_name(ode_solver_to_i32(
977 OdeSolverType::Bdf
978 ))),
979 "bdf"
980 );
981 assert_eq!(
982 c_string(diffsol_linear_solver_type_name(linear_solver_to_i32(
983 LinearSolverType::Default
984 ))),
985 "default"
986 );
987 assert_eq!(
988 c_string(diffsol_scalar_type_name(scalar_type_to_i32(
989 ScalarType::F64
990 ))),
991 "f64"
992 );
993 }
994 }
995
996 #[test]
997 fn c_api_invalid_enums_set_last_error() {
998 clear_last_error();
999 unsafe {
1000 assert_eq!(diffsol_matrix_type_is_valid(99), 0);
1001 assert_last_error_contains("invalid matrix_type");
1002 clear_last_error();
1003
1004 assert_eq!(diffsol_ode_solver_type_is_valid(99), 0);
1005 assert_last_error_contains("invalid ode_solver_type");
1006 clear_last_error();
1007
1008 assert_eq!(diffsol_linear_solver_type_is_valid(99), 0);
1009 assert_last_error_contains("invalid linear_solver_type");
1010 clear_last_error();
1011
1012 assert_eq!(diffsol_scalar_type_is_valid(99), 0);
1013 assert_last_error_contains("invalid scalar_type");
1014 }
1015 }
1016
1017 #[test]
1018 fn c_api_rejects_invalid_ode_arguments() {
1019 clear_last_error();
1020 unsafe {
1021 let mut out_array = ptr::null_mut();
1022 let status = diffsol_ode_y0(ptr::null_mut(), ptr::null(), 0, &mut out_array);
1023 assert_eq!(status, DIFFSOL_BAD_ARG);
1024 assert!(out_array.is_null());
1025 assert_last_error_contains("invalid arguments to diffsol_ode_y0");
1026 clear_last_error();
1027
1028 let ode = make_ode_ptr(
1029 99,
1030 linear_solver_to_i32(LinearSolverType::Default),
1031 ode_solver_to_i32(OdeSolverType::Bdf),
1032 );
1033 assert!(ode.is_null());
1034 assert_last_error_contains("invalid matrix_type");
1035 }
1036 }
1037
1038 #[test]
1039 fn c_api_full_lifecycle_matches_external_logistic_model() {
1040 clear_last_error();
1041 unsafe {
1042 let ode = make_ode_ptr(
1043 matrix_type_to_i32(MatrixType::NalgebraDense),
1044 linear_solver_to_i32(LinearSolverType::Default),
1045 ode_solver_to_i32(OdeSolverType::Bdf),
1046 );
1047 assert!(!ode.is_null());
1048
1049 assert_eq!(
1050 diffsol_ode_get_matrix_type(ode),
1051 matrix_type_to_i32(MatrixType::NalgebraDense)
1052 );
1053 assert_eq!(
1054 diffsol_ode_get_ode_solver(ode),
1055 ode_solver_to_i32(OdeSolverType::Bdf)
1056 );
1057 assert_eq!(
1058 diffsol_ode_get_linear_solver(ode),
1059 linear_solver_to_i32(LinearSolverType::Default)
1060 );
1061
1062 assert_eq!(
1063 diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
1064 DIFFSOL_OK
1065 );
1066 assert_eq!(
1067 diffsol_ode_get_ode_solver(ode),
1068 ode_solver_to_i32(OdeSolverType::Tsit45)
1069 );
1070 assert_eq!(
1071 diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Bdf)),
1072 DIFFSOL_OK
1073 );
1074
1075 assert_eq!(diffsol_ode_set_rtol(ode, 1e-8), DIFFSOL_OK);
1076 assert_eq!(diffsol_ode_set_atol(ode, 1e-8), DIFFSOL_OK);
1077 let mut rtol = 0.0;
1078 let mut atol = 0.0;
1079 assert_eq!(diffsol_ode_get_rtol(ode, &mut rtol), DIFFSOL_OK);
1080 assert_eq!(diffsol_ode_get_atol(ode, &mut atol), DIFFSOL_OK);
1081 assert_close(rtol, 1e-8, ASSERT_TOL, "rtol roundtrip");
1082 assert_close(atol, 1e-8, ASSERT_TOL, "atol roundtrip");
1083
1084 let mut ic_options: *mut InitialConditionSolverOptions = ptr::null_mut();
1085 assert_eq!(diffsol_ode_get_ic_options(ode, &mut ic_options), DIFFSOL_OK);
1086 assert!(!ic_options.is_null());
1087 let mut use_linesearch = 0;
1088 let mut max_linesearch_iterations = 0usize;
1089 assert_eq!(
1090 diffsol_ic_options_get_use_linesearch(ic_options, &mut use_linesearch),
1091 DIFFSOL_OK
1092 );
1093 assert_eq!(
1094 diffsol_ic_options_set_use_linesearch(ic_options, 1),
1095 DIFFSOL_OK
1096 );
1097 assert_eq!(
1098 diffsol_ic_options_get_use_linesearch(ic_options, &mut use_linesearch),
1099 DIFFSOL_OK
1100 );
1101 assert_eq!(use_linesearch, 1);
1102 assert_eq!(
1103 diffsol_ic_options_set_max_linesearch_iterations(ic_options, 23),
1104 DIFFSOL_OK
1105 );
1106 assert_eq!(
1107 diffsol_ic_options_get_max_linesearch_iterations(
1108 ic_options,
1109 &mut max_linesearch_iterations
1110 ),
1111 DIFFSOL_OK
1112 );
1113 assert_eq!(max_linesearch_iterations, 23);
1114 diffsol_ic_options_free(ic_options);
1115
1116 let mut ode_options: *mut OdeSolverOptions = ptr::null_mut();
1117 assert_eq!(diffsol_ode_get_options(ode, &mut ode_options), DIFFSOL_OK);
1118 assert!(!ode_options.is_null());
1119 let mut max_nonlinear_iterations = 0usize;
1120 let mut min_timestep = 0.0;
1121 assert_eq!(
1122 diffsol_ode_options_set_max_nonlinear_solver_iterations(ode_options, 17),
1123 DIFFSOL_OK
1124 );
1125 assert_eq!(
1126 diffsol_ode_options_get_max_nonlinear_solver_iterations(
1127 ode_options,
1128 &mut max_nonlinear_iterations
1129 ),
1130 DIFFSOL_OK
1131 );
1132 assert_eq!(max_nonlinear_iterations, 17);
1133 assert_eq!(
1134 diffsol_ode_options_set_min_timestep(ode_options, 1e-4),
1135 DIFFSOL_OK
1136 );
1137 assert_eq!(
1138 diffsol_ode_options_get_min_timestep(ode_options, &mut min_timestep),
1139 DIFFSOL_OK
1140 );
1141 assert_close(min_timestep, 1e-4, ASSERT_TOL, "min_timestep roundtrip");
1142 diffsol_ode_options_free(ode_options);
1143
1144 let params = [2.0f64];
1145 let y = [0.25f64];
1146 let v = [3.0f64];
1147
1148 let mut y0_ptr = ptr::null_mut();
1149 assert_eq!(
1150 diffsol_ode_y0(ode, params.as_ptr(), params.len(), &mut y0_ptr),
1151 DIFFSOL_OK
1152 );
1153 assert_eq!(ffi_read_host_array_vector(y0_ptr), vec![LOGISTIC_X0]);
1154
1155 let mut rhs_ptr = ptr::null_mut();
1156 assert_eq!(
1157 diffsol_ode_rhs(
1158 ode,
1159 params.as_ptr(),
1160 params.len(),
1161 0.0,
1162 y.as_ptr(),
1163 y.len(),
1164 &mut rhs_ptr,
1165 ),
1166 DIFFSOL_OK
1167 );
1168 assert_close(
1169 ffi_read_host_array_vector(rhs_ptr)[0],
1170 0.375,
1171 ASSERT_TOL,
1172 "ffi rhs",
1173 );
1174
1175 let mut rhs_jac_mul_ptr = ptr::null_mut();
1176 assert_eq!(
1177 diffsol_ode_rhs_jac_mul(
1178 ode,
1179 params.as_ptr(),
1180 params.len(),
1181 0.0,
1182 y.as_ptr(),
1183 y.len(),
1184 v.as_ptr(),
1185 v.len(),
1186 &mut rhs_jac_mul_ptr,
1187 ),
1188 DIFFSOL_OK
1189 );
1190 assert_close(
1191 ffi_read_host_array_vector(rhs_jac_mul_ptr)[0],
1192 3.0,
1193 ASSERT_TOL,
1194 "ffi rhs_jac_mul",
1195 );
1196
1197 let mut solve_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1198 assert_eq!(
1199 diffsol_ode_solve(
1200 ode,
1201 params.as_ptr(),
1202 params.len(),
1203 1e-9,
1204 &mut solve_solution_ptr
1205 ),
1206 DIFFSOL_OK
1207 );
1208 assert!(!solve_solution_ptr.is_null());
1209
1210 let mut solve_ys_ptr = ptr::null_mut();
1211 let mut solve_ts_ptr = ptr::null_mut();
1212 assert_eq!(
1213 diffsol_solution_wrapper_get_ys(solve_solution_ptr, &mut solve_ys_ptr),
1214 DIFFSOL_OK
1215 );
1216 assert_eq!(
1217 diffsol_solution_wrapper_get_ts(solve_solution_ptr, &mut solve_ts_ptr),
1218 DIFFSOL_OK
1219 );
1220 let (solve_rows, solve_cols, solve_ys) = ffi_read_host_array_matrix(solve_ys_ptr);
1221 let solve_ts = ffi_read_host_array_vector(solve_ts_ptr);
1222 assert_eq!(solve_rows, 1);
1223 assert_eq!(solve_cols, solve_ts.len());
1224 assert!(!solve_ts.is_empty());
1225 assert_close(
1226 *solve_ts.last().unwrap(),
1227 1e-9,
1228 ASSERT_TOL,
1229 "ffi solve final time",
1230 );
1231 assert_close(
1232 *solve_ys.last().unwrap(),
1233 logistic_state(LOGISTIC_X0, 2.0, 1e-9),
1234 ASSERT_TOL,
1235 "ffi solve final value",
1236 );
1237 ffi_free_solution(solve_solution_ptr);
1238
1239 let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1240 assert_eq!(
1241 diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
1242 DIFFSOL_OK
1243 );
1244
1245 let t_eval = [0.25f64, 0.5f64, 1.0f64];
1246 assert_eq!(
1247 diffsol_ode_solve_dense(
1248 ode,
1249 params.as_ptr(),
1250 params.len(),
1251 t_eval.as_ptr(),
1252 t_eval.len(),
1253 &mut solution_ptr,
1254 ),
1255 DIFFSOL_OK
1256 );
1257 let mut ys_ptr = ptr::null_mut();
1258 let mut ts_ptr = ptr::null_mut();
1259 assert_eq!(
1260 diffsol_solution_wrapper_get_ys(solution_ptr, &mut ys_ptr),
1261 DIFFSOL_OK
1262 );
1263 assert_eq!(
1264 diffsol_solution_wrapper_get_ts(solution_ptr, &mut ts_ptr),
1265 DIFFSOL_OK
1266 );
1267 let (rows, cols, ys) = ffi_read_host_array_matrix(ys_ptr);
1268 let ts = ffi_read_host_array_vector(ts_ptr);
1269 assert_eq!(rows, 1);
1270 assert_eq!(cols, ts.len());
1271 let start = find_time_window(&ts, &t_eval, ASSERT_TOL);
1272 for (i, &t) in t_eval.iter().enumerate() {
1273 assert_close(ts[start + i], t, ASSERT_TOL, "ffi solution time");
1274 assert_close(
1275 ys[start + i],
1276 logistic_state(0.1, 2.0, t),
1277 5e-4,
1278 "ffi solution value",
1279 );
1280 }
1281 assert_eq!(
1282 diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Bdf)),
1283 DIFFSOL_OK
1284 );
1285
1286 let hybrid_t_eval = [0.5f64, 1.0, 1.25, 1.5, 2.0];
1287 let hybrid_ode = make_ode_ptr(
1288 matrix_type_to_i32(MatrixType::NalgebraDense),
1289 linear_solver_to_i32(LinearSolverType::Default),
1290 ode_solver_to_i32(OdeSolverType::Bdf),
1291 );
1292 assert!(!hybrid_ode.is_null());
1293 let mut hybrid_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1294 assert_eq!(
1295 diffsol_ode_solve_hybrid_dense(
1296 hybrid_ode,
1297 params.as_ptr(),
1298 params.len(),
1299 hybrid_t_eval.as_ptr(),
1300 hybrid_t_eval.len(),
1301 &mut hybrid_solution_ptr,
1302 ),
1303 DIFFSOL_OK
1304 );
1305 let mut hybrid_ys_ptr = ptr::null_mut();
1306 let mut hybrid_ts_ptr = ptr::null_mut();
1307 assert_eq!(
1308 diffsol_solution_wrapper_get_ys(hybrid_solution_ptr, &mut hybrid_ys_ptr),
1309 DIFFSOL_OK
1310 );
1311 assert_eq!(
1312 diffsol_solution_wrapper_get_ts(hybrid_solution_ptr, &mut hybrid_ts_ptr),
1313 DIFFSOL_OK
1314 );
1315 let (hybrid_rows, hybrid_cols, hybrid_ys) = ffi_read_host_array_matrix(hybrid_ys_ptr);
1316 let hybrid_ts = ffi_read_host_array_vector(hybrid_ts_ptr);
1317 assert_eq!(hybrid_rows, 1);
1318 assert_eq!(hybrid_cols, hybrid_t_eval.len());
1319 assert_eq!(hybrid_ts, hybrid_t_eval);
1320 assert_close(
1321 hybrid_ys[0],
1322 logistic_state(LOGISTIC_X0, 2.0, hybrid_t_eval[0]),
1323 5e-4,
1324 "ffi hybrid dense pre-root value",
1325 );
1326 assert_close(
1327 hybrid_ys[1],
1328 logistic_state(LOGISTIC_X0, 2.0, hybrid_t_eval[1]),
1329 5e-4,
1330 "ffi hybrid dense near-root value",
1331 );
1332 for (i, value) in hybrid_ys.iter().enumerate().skip(2) {
1333 assert_close(
1334 *value,
1335 1.0,
1336 5e-4,
1337 &format!("ffi hybrid dense post-root value[{i}]"),
1338 );
1339 }
1340 ffi_free_solution(hybrid_solution_ptr);
1341 diffsol_ode_free(hybrid_ode);
1342
1343 let analysis_ode = make_ode_ptr(
1344 matrix_type_to_i32(MatrixType::NalgebraDense),
1345 linear_solver_to_i32(LinearSolverType::Default),
1346 ode_solver_to_i32(OdeSolverType::Bdf),
1347 );
1348 assert!(!analysis_ode.is_null());
1349
1350 let mut sens_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1351 assert_eq!(
1352 diffsol_ode_solve_fwd_sens(
1353 analysis_ode,
1354 params.as_ptr(),
1355 params.len(),
1356 t_eval.as_ptr(),
1357 t_eval.len(),
1358 &mut sens_solution_ptr,
1359 ),
1360 DIFFSOL_OK
1361 );
1362 let mut sens_list = ptr::null_mut();
1363 let mut sens_len = 0usize;
1364 assert_eq!(
1365 diffsol_solution_wrapper_get_sens(sens_solution_ptr, &mut sens_list, &mut sens_len),
1366 DIFFSOL_OK
1367 );
1368 let sens_values = ffi_read_host_array_list_matrices(sens_list, sens_len);
1369 assert_eq!(sens_values.len(), 1);
1370 assert_eq!(sens_values[0].0, 1);
1371 assert_eq!(sens_values[0].1, t_eval.len());
1372 for (i, (&value, &t)) in sens_values[0].2.iter().zip(t_eval.iter()).enumerate() {
1373 assert_close(
1374 value,
1375 logistic_state_dr(LOGISTIC_X0, 2.0, t),
1376 ASSERT_TOL,
1377 &format!("ffi sensitivity[{i}]"),
1378 );
1379 }
1380
1381 let adjoint_t_eval = [0.0f64, 0.25f64, 0.5f64, 1.0f64];
1382 let adjoint_data: Vec<f64> = adjoint_t_eval
1383 .iter()
1384 .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
1385 .collect();
1386 let mut objective = 0.0;
1387 let mut adjoint_grad_ptr = ptr::null_mut();
1388 assert_eq!(
1389 diffsol_ode_solve_sum_squares_adj(
1390 analysis_ode,
1391 params.as_ptr(),
1392 params.len(),
1393 adjoint_data.as_ptr(),
1394 1,
1395 adjoint_t_eval.len(),
1396 1,
1397 1,
1398 adjoint_t_eval.as_ptr(),
1399 adjoint_t_eval.len(),
1400 &mut objective,
1401 &mut adjoint_grad_ptr,
1402 ),
1403 DIFFSOL_OK
1404 );
1405 assert_close(objective, 0.0, ASSERT_TOL, "ffi adjoint objective");
1406 let grad = ffi_read_host_array_vector(adjoint_grad_ptr);
1407 assert_eq!(grad.len(), 1);
1408 assert_close(grad[0], 0.0, ASSERT_TOL, "ffi adjoint gradient");
1409
1410 ffi_free_solution(sens_solution_ptr);
1411 diffsol_ode_free(analysis_ode);
1412 ffi_free_solution(solution_ptr);
1413 diffsol_ode_free(ode);
1414 }
1415 }
1416}
1417
1418#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
1419mod jit_tests {
1420 use std::ffi::{CStr, CString};
1421 use std::ptr;
1422
1423 use crate::error_c::{diffsol_error_code, diffsol_last_error_message};
1424 use crate::initial_condition_options_c::diffsol_ic_options_free;
1425 use crate::jit::JitBackendType;
1426 use crate::jit_c::jit_backend_to_i32;
1427 use crate::linear_solver_type::LinearSolverType;
1428 use crate::linear_solver_type_c::linear_solver_to_i32;
1429 use crate::matrix_type::MatrixType;
1430 use crate::matrix_type_c::matrix_type_to_i32;
1431 use crate::ode_options_c::diffsol_ode_options_free;
1432 use crate::ode_solver_type::OdeSolverType;
1433 use crate::ode_solver_type_c::ode_solver_to_i32;
1434 #[cfg(feature = "diffsl-llvm")]
1435 use crate::solution_wrapper_c::diffsol_solution_wrapper_get_sens;
1436 use crate::solution_wrapper_c::{
1437 diffsol_solution_wrapper_get_ts, diffsol_solution_wrapper_get_ys,
1438 };
1439 #[cfg(feature = "diffsl-llvm")]
1440 use crate::test_support::ffi_read_host_array_list_matrices;
1441 use crate::test_support::{
1442 assert_close, available_jit_backends, clear_last_error, ffi_free_solution,
1443 ffi_read_host_array_matrix, ffi_read_host_array_vector, find_time_window,
1444 hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code_cstring,
1445 logistic_state, ASSERT_TOL, LOGISTIC_X0,
1446 };
1447 #[cfg(feature = "diffsl-llvm")]
1448 use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
1449
1450 use super::*;
1451
1452 unsafe fn make_ode_ptr(
1453 jit_backend: JitBackendType,
1454 matrix_type: i32,
1455 linear_solver: i32,
1456 ode_solver: i32,
1457 ) -> *mut OdeWrapper {
1458 let code = logistic_diffsl_code_cstring();
1459 unsafe {
1460 make_ode_ptr_with_code(
1461 jit_backend,
1462 code.as_ptr(),
1463 matrix_type,
1464 linear_solver,
1465 ode_solver,
1466 )
1467 }
1468 }
1469
1470 unsafe fn make_ode_ptr_with_code(
1471 jit_backend: JitBackendType,
1472 code: *const std::os::raw::c_char,
1473 matrix_type: i32,
1474 linear_solver: i32,
1475 ode_solver: i32,
1476 ) -> *mut OdeWrapper {
1477 unsafe {
1478 diffsol_ode_new_jit(
1479 code,
1480 jit_backend_to_i32(jit_backend),
1481 matrix_type,
1482 linear_solver,
1483 ode_solver,
1484 )
1485 }
1486 }
1487
1488 unsafe fn last_error_message() -> String {
1489 let ptr = unsafe { diffsol_last_error_message() };
1490 assert_eq!(unsafe { diffsol_error_code() }, 1);
1491 assert!(!ptr.is_null());
1492 unsafe { CStr::from_ptr(ptr) }.to_str().unwrap().to_owned()
1493 }
1494
    /// Drives a JIT-compiled logistic model through the whole C API and checks the results.
    ///
    /// For every backend reported by `available_jit_backends()` the test:
    /// 1. creates an ODE handle and reads back its matrix type, ODE solver and linear solver;
    /// 2. evaluates `y0`, `rhs` and `rhs_jac_mul` for a fixed parameter set;
    /// 3. runs a dense solve with the `Tsit45` solver and compares it with `logistic_state`;
    /// 4. when the `diffsl-llvm` feature is enabled, additionally runs forward
    ///    sensitivities and the sum-of-squares adjoint on a separate LLVM-backed handle.
    #[test]
    fn c_api_full_lifecycle_matches_jit_logistic_model() {
        clear_last_error();
        for jit_backend in available_jit_backends() {
            unsafe {
                let ode = make_ode_ptr(
                    jit_backend,
                    matrix_type_to_i32(MatrixType::NalgebraDense),
                    linear_solver_to_i32(LinearSolverType::Default),
                    ode_solver_to_i32(OdeSolverType::Bdf),
                );
                assert!(!ode.is_null());

                // The getters must echo the configuration used at construction time.
                assert_eq!(
                    diffsol_ode_get_matrix_type(ode),
                    matrix_type_to_i32(MatrixType::NalgebraDense)
                );
                assert_eq!(
                    diffsol_ode_get_ode_solver(ode),
                    ode_solver_to_i32(OdeSolverType::Bdf)
                );
                assert_eq!(
                    diffsol_ode_get_linear_solver(ode),
                    linear_solver_to_i32(LinearSolverType::Default)
                );

                // Single model parameter, plus a state `y` and direction `v` used for the
                // pointwise rhs / Jacobian-vector evaluations below.
                let params = [2.0f64];
                let y = [0.25f64];
                let v = [3.0f64];

                let mut y0_ptr = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_y0(ode, params.as_ptr(), params.len(), &mut y0_ptr),
                    DIFFSOL_OK
                );
                assert_eq!(ffi_read_host_array_vector(y0_ptr), vec![LOGISTIC_X0]);

                let mut rhs_ptr = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_rhs(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        0.0,
                        y.as_ptr(),
                        y.len(),
                        &mut rhs_ptr,
                    ),
                    DIFFSOL_OK
                );
                // NOTE(review): 0.375 matches r*y*(1-y) with r = 2.0, y = 0.25, presumably
                // the logistic rhs defined in test_support; confirm against the model source.
                assert_close(
                    ffi_read_host_array_vector(rhs_ptr)[0],
                    0.375,
                    ASSERT_TOL,
                    "jit ffi rhs",
                );

                let mut rhs_jac_mul_ptr = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_rhs_jac_mul(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        0.0,
                        y.as_ptr(),
                        y.len(),
                        v.as_ptr(),
                        v.len(),
                        &mut rhs_jac_mul_ptr,
                    ),
                    DIFFSOL_OK
                );
                // NOTE(review): 3.0 matches r*(1-2y)*v with r = 2.0, y = 0.25, v = 3.0 under
                // the same presumed logistic rhs.
                assert_close(
                    ffi_read_host_array_vector(rhs_jac_mul_ptr)[0],
                    3.0,
                    ASSERT_TOL,
                    "jit ffi rhs_jac_mul",
                );

                let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                let t_eval = [0.25f64, 0.5f64, 1.0f64];
                // Switch to the Tsit45 solver for the dense solve; Bdf is restored afterwards.
                assert_eq!(
                    diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
                    DIFFSOL_OK
                );
                assert_eq!(
                    diffsol_ode_solve_dense(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        &mut solution_ptr,
                    ),
                    DIFFSOL_OK
                );
                let mut ys_ptr = ptr::null_mut();
                let mut ts_ptr = ptr::null_mut();
                assert_eq!(
                    diffsol_solution_wrapper_get_ys(solution_ptr, &mut ys_ptr),
                    DIFFSOL_OK
                );
                assert_eq!(
                    diffsol_solution_wrapper_get_ts(solution_ptr, &mut ts_ptr),
                    DIFFSOL_OK
                );
                let (rows, cols, ys) = ffi_read_host_array_matrix(ys_ptr);
                let ts = ffi_read_host_array_vector(ts_ptr);
                // One state row, and one column per reported time point.
                assert_eq!(rows, 1);
                assert_eq!(cols, ts.len());
                // `start` is used as the offset of the `t_eval` window inside `ts`.
                let start = find_time_window(&ts, &t_eval, ASSERT_TOL);
                for (i, &t) in t_eval.iter().enumerate() {
                    assert_close(ts[start + i], t, ASSERT_TOL, "jit ffi solution time");
                    // Solver output is compared with the reference value using a looser
                    // tolerance (5e-4) than the time stamps.
                    assert_close(
                        ys[start + i],
                        logistic_state(LOGISTIC_X0, 2.0, t),
                        5e-4,
                        "jit ffi solution value",
                    );
                }
                assert_eq!(
                    diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Bdf)),
                    DIFFSOL_OK
                );

                // Sensitivity / adjoint checks. The handle below is always built with the
                // LLVM backend, independent of the loop's `jit_backend`, so this section runs
                // once per loop iteration whenever the `diffsl-llvm` feature is enabled.
                #[cfg(feature = "diffsl-llvm")]
                {
                    let analysis_code = logistic_diffsl_code_cstring();
                    let analysis_ode = make_ode_ptr_with_code(
                        JitBackendType::Llvm,
                        analysis_code.as_ptr(),
                        matrix_type_to_i32(MatrixType::NalgebraDense),
                        linear_solver_to_i32(LinearSolverType::Default),
                        ode_solver_to_i32(OdeSolverType::Bdf),
                    );
                    assert!(!analysis_ode.is_null());

                    let mut sens_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                    assert_eq!(
                        diffsol_ode_solve_fwd_sens(
                            analysis_ode,
                            params.as_ptr(),
                            params.len(),
                            t_eval.as_ptr(),
                            t_eval.len(),
                            &mut sens_solution_ptr,
                        ),
                        DIFFSOL_OK
                    );
                    let mut sens_list = ptr::null_mut();
                    let mut sens_len = 0usize;
                    assert_eq!(
                        diffsol_solution_wrapper_get_sens(
                            sens_solution_ptr,
                            &mut sens_list,
                            &mut sens_len
                        ),
                        DIFFSOL_OK
                    );
                    let sens_values = ffi_read_host_array_list_matrices(sens_list, sens_len);
                    // Exactly one sensitivity matrix, with one row and one column per
                    // requested time; this also guarantees the zip below visits every `t`.
                    assert_eq!(sens_values.len(), 1);
                    assert_eq!(sens_values[0].0, 1);
                    assert_eq!(sens_values[0].1, t_eval.len());
                    for (i, (&value, &t)) in sens_values[0].2.iter().zip(t_eval.iter()).enumerate()
                    {
                        assert_close(
                            value,
                            logistic_state_dr(LOGISTIC_X0, 2.0, t),
                            ASSERT_TOL,
                            &format!("jit ffi sensitivity[{i}]"),
                        );
                    }

                    // The adjoint data is generated from `logistic_state` with the same
                    // parameter value that is passed to the solver, so the objective is
                    // expected to be (close to) zero.
                    let adjoint_t_eval = [0.0f64, 0.25f64, 0.5f64, 1.0f64];
                    let adjoint_data: Vec<f64> = adjoint_t_eval
                        .iter()
                        .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
                        .collect();
                    let mut objective = 0.0;
                    let mut adjoint_grad_ptr = ptr::null_mut();
                    assert_eq!(
                        diffsol_ode_solve_sum_squares_adj(
                            analysis_ode,
                            params.as_ptr(),
                            params.len(),
                            adjoint_data.as_ptr(),
                            1,
                            adjoint_t_eval.len(),
                            1,
                            1,
                            adjoint_t_eval.as_ptr(),
                            adjoint_t_eval.len(),
                            &mut objective,
                            &mut adjoint_grad_ptr,
                        ),
                        DIFFSOL_OK
                    );
                    assert_close(objective, 0.0, ASSERT_TOL, "jit ffi adjoint objective");
                    let grad = ffi_read_host_array_vector(adjoint_grad_ptr);
                    assert_eq!(grad.len(), 1);
                    // Only finiteness of the gradient is checked here, not its value.
                    assert!(
                        grad[0].is_finite(),
                        "jit ffi adjoint gradient should be finite"
                    );

                    ffi_free_solution(sens_solution_ptr);
                    diffsol_ode_free(analysis_ode);
                }
                ffi_free_solution(solution_ptr);
                diffsol_ode_free(ode);
            }
        }
    }
1708
    /// Checks that the JIT constructor and the ODE entry points reject bad arguments.
    ///
    /// Each constructor case clears the error state, makes one invalid call, expects a
    /// null handle, and then inspects the recorded error through `last_error_message()`
    /// (which itself asserts that an error code of `1` is set). The remaining cases call
    /// ODE functions with a null handle and expect `DIFFSOL_BAD_ARG`.
    #[test]
    fn c_api_rejects_invalid_jit_arguments() {
        unsafe {
            // Null source pointer.
            clear_last_error();
            assert!(diffsol_ode_new_jit(
                ptr::null(),
                jit_backend_to_i32(available_jit_backends()[0]),
                matrix_type_to_i32(MatrixType::NalgebraDense),
                linear_solver_to_i32(LinearSolverType::Default),
                ode_solver_to_i32(OdeSolverType::Bdf),
            )
            .is_null());
            assert!(last_error_message().contains("code is null"));

            // Source bytes that are a valid C string (NUL-terminated) but not valid UTF-8.
            clear_last_error();
            let invalid_utf8 = CString::from_vec_with_nul(vec![0xff, 0]).unwrap();
            assert!(diffsol_ode_new_jit(
                invalid_utf8.as_ptr(),
                jit_backend_to_i32(available_jit_backends()[0]),
                matrix_type_to_i32(MatrixType::NalgebraDense),
                linear_solver_to_i32(LinearSolverType::Default),
                ode_solver_to_i32(OdeSolverType::Bdf),
            )
            .is_null());
            assert!(last_error_message().contains("valid UTF-8"));

            // From here on the source is valid; one enum argument at a time is replaced
            // by the out-of-range value 99.
            clear_last_error();
            let code = logistic_diffsl_code_cstring();
            assert!(diffsol_ode_new_jit(
                code.as_ptr(),
                99,
                matrix_type_to_i32(MatrixType::NalgebraDense),
                linear_solver_to_i32(LinearSolverType::Default),
                ode_solver_to_i32(OdeSolverType::Bdf),
            )
            .is_null());
            assert!(last_error_message().contains("invalid jit_backend_type"));

            clear_last_error();
            assert!(diffsol_ode_new_jit(
                code.as_ptr(),
                jit_backend_to_i32(available_jit_backends()[0]),
                99,
                linear_solver_to_i32(LinearSolverType::Default),
                ode_solver_to_i32(OdeSolverType::Bdf),
            )
            .is_null());
            assert!(last_error_message().contains("invalid matrix_type"));

            clear_last_error();
            assert!(diffsol_ode_new_jit(
                code.as_ptr(),
                jit_backend_to_i32(available_jit_backends()[0]),
                matrix_type_to_i32(MatrixType::NalgebraDense),
                99,
                ode_solver_to_i32(OdeSolverType::Bdf),
            )
            .is_null());
            assert!(last_error_message().contains("invalid linear_solver"));

            clear_last_error();
            assert!(diffsol_ode_new_jit(
                code.as_ptr(),
                jit_backend_to_i32(available_jit_backends()[0]),
                matrix_type_to_i32(MatrixType::NalgebraDense),
                linear_solver_to_i32(LinearSolverType::Default),
                99,
            )
            .is_null());
            assert!(last_error_message().contains("invalid ode_solver"));

            // Well-formed arguments but source text that is not a valid model: only a
            // non-zero error code is required, the message text is not checked.
            clear_last_error();
            let invalid_code = CString::new("not valid diffsl").unwrap();
            assert!(diffsol_ode_new_jit(
                invalid_code.as_ptr(),
                jit_backend_to_i32(available_jit_backends()[0]),
                matrix_type_to_i32(MatrixType::NalgebraDense),
                linear_solver_to_i32(LinearSolverType::Default),
                ode_solver_to_i32(OdeSolverType::Bdf),
            )
            .is_null());
            assert!(diffsol_error_code() != 0);

            // Null ODE handle passed to the option getters.
            let mut ic_options = ptr::null_mut();
            assert_eq!(
                diffsol_ode_get_ic_options(ptr::null_mut(), &mut ic_options),
                DIFFSOL_BAD_ARG
            );
            let mut ode_options = ptr::null_mut();
            assert_eq!(
                diffsol_ode_get_options(ptr::null_mut(), &mut ode_options),
                DIFFSOL_BAD_ARG
            );

            // Null ODE handle passed to the evaluation functions.
            let mut out_array = ptr::null_mut();
            assert_eq!(
                diffsol_ode_y0(ptr::null_mut(), ptr::null(), 0, &mut out_array),
                DIFFSOL_BAD_ARG
            );
            assert_eq!(
                diffsol_ode_rhs(
                    ptr::null_mut(),
                    ptr::null(),
                    0,
                    0.0,
                    ptr::null(),
                    0,
                    &mut out_array,
                ),
                DIFFSOL_BAD_ARG
            );
            assert_eq!(
                diffsol_ode_rhs_jac_mul(
                    ptr::null_mut(),
                    ptr::null(),
                    0,
                    0.0,
                    ptr::null(),
                    0,
                    ptr::null(),
                    0,
                    &mut out_array,
                ),
                DIFFSOL_BAD_ARG
            );

            // Freeing null pointers must record an error rather than crash.
            clear_last_error();
            diffsol_ode_free(ptr::null_mut());
            assert!(last_error_message().contains("ode is null"));

            clear_last_error();
            diffsol_host_array_list_free(ptr::null_mut(), 0);
            assert!(last_error_message().contains("host array list is null"));
        }
    }
1844
    /// Exercises the success and failure branches of the ODE wrapper functions on a JIT model.
    ///
    /// For every available backend the test covers, in order:
    /// - option handles and tolerance getters/setters on a valid handle;
    /// - solver / linear-solver setters and their getters;
    /// - successful `solve` and `solve_dense` calls;
    /// - calls with an empty parameter array, which are expected to return `DIFFSOL_ERR`;
    /// - calls with null handles, null output pointers or out-of-range enum values, which
    ///   are expected to return `DIFFSOL_BAD_ARG` (or `-1` for the integer getters).
    ///
    /// Sensitivity and adjoint entry points are only exercised when the `diffsl-llvm`
    /// feature is enabled and the current backend is `JitBackendType::Llvm`.
    #[test]
    fn c_api_jit_wrapper_branches_cover_runtime_success_and_errors() {
        for jit_backend in available_jit_backends() {
            unsafe {
                let ode = make_ode_ptr(
                    jit_backend,
                    matrix_type_to_i32(MatrixType::NalgebraDense),
                    linear_solver_to_i32(LinearSolverType::Default),
                    ode_solver_to_i32(OdeSolverType::Bdf),
                );
                assert!(!ode.is_null());

                // Option handles obtained from a valid ODE are released right away.
                let mut ic_options = ptr::null_mut();
                let mut ode_options = ptr::null_mut();
                assert_eq!(diffsol_ode_get_ic_options(ode, &mut ic_options), DIFFSOL_OK);
                assert_eq!(diffsol_ode_get_options(ode, &mut ode_options), DIFFSOL_OK);
                diffsol_ic_options_free(ic_options);
                diffsol_ode_options_free(ode_options);

                // Relative tolerance: expected default of 1e-6, then a round trip of 1e-4.
                let mut out_value = 0.0;
                assert_eq!(diffsol_ode_get_rtol(ode, &mut out_value), DIFFSOL_OK);
                assert_close(out_value, 1e-6, ASSERT_TOL, "jit ffi default rtol");
                assert_eq!(diffsol_ode_set_rtol(ode, 1e-4), DIFFSOL_OK);
                assert_eq!(diffsol_ode_get_rtol(ode, &mut out_value), DIFFSOL_OK);
                assert_close(out_value, 1e-4, ASSERT_TOL, "jit ffi updated rtol");

                // Absolute tolerance: expected default of 1e-6, then a round trip of 1e-5.
                assert_eq!(diffsol_ode_get_atol(ode, &mut out_value), DIFFSOL_OK);
                assert_close(out_value, 1e-6, ASSERT_TOL, "jit ffi default atol");
                assert_eq!(diffsol_ode_set_atol(ode, 1e-5), DIFFSOL_OK);
                assert_eq!(diffsol_ode_get_atol(ode, &mut out_value), DIFFSOL_OK);
                assert_close(out_value, 1e-5, ASSERT_TOL, "jit ffi updated atol");

                // Changing the linear solver and ODE solver must be visible through the
                // getters, while the matrix type stays as constructed.
                assert_eq!(
                    diffsol_ode_set_linear_solver(ode, linear_solver_to_i32(LinearSolverType::Lu)),
                    DIFFSOL_OK
                );
                assert_eq!(
                    diffsol_ode_get_linear_solver(ode),
                    linear_solver_to_i32(LinearSolverType::Lu)
                );
                assert_eq!(
                    diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
                    DIFFSOL_OK
                );
                assert_eq!(
                    diffsol_ode_get_ode_solver(ode),
                    ode_solver_to_i32(OdeSolverType::Tsit45)
                );
                assert_eq!(
                    diffsol_ode_get_matrix_type(ode),
                    matrix_type_to_i32(MatrixType::NalgebraDense)
                );

                // Successful solves with a valid parameter array; only the status code is
                // checked, the solutions are freed unread.
                let params = [2.0f64];
                let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_solve(ode, params.as_ptr(), params.len(), 1.0, &mut solution_ptr),
                    DIFFSOL_OK
                );
                ffi_free_solution(solution_ptr);

                let t_eval = [0.25f64, 0.5f64, 1.0f64];
                let mut dense_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_solve_dense(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        &mut dense_solution_ptr,
                    ),
                    DIFFSOL_OK
                );
                ffi_free_solution(dense_solution_ptr);

                // Runtime failures: an empty parameter array is passed to every entry point
                // and each call is expected to report `DIFFSOL_ERR`.
                let no_params: [f64; 0] = [];
                let y = [0.25f64];
                let v = [3.0f64];
                let mut out_array = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_y0(ode, no_params.as_ptr(), no_params.len(), &mut out_array),
                    DIFFSOL_ERR
                );
                assert_eq!(
                    diffsol_ode_rhs(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        0.0,
                        y.as_ptr(),
                        y.len(),
                        &mut out_array,
                    ),
                    DIFFSOL_ERR
                );
                assert_eq!(
                    diffsol_ode_rhs_jac_mul(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        0.0,
                        y.as_ptr(),
                        y.len(),
                        v.as_ptr(),
                        v.len(),
                        &mut out_array,
                    ),
                    DIFFSOL_ERR
                );

                // The same output slot is reused for all failing solve calls.
                let mut err_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
                assert_eq!(
                    diffsol_ode_solve(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        1.0,
                        &mut err_solution_ptr,
                    ),
                    DIFFSOL_ERR
                );
                assert_eq!(
                    diffsol_ode_solve_hybrid(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        1.0,
                        &mut err_solution_ptr,
                    ),
                    DIFFSOL_ERR
                );
                assert_eq!(
                    diffsol_ode_solve_dense(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        &mut err_solution_ptr,
                    ),
                    DIFFSOL_ERR
                );
                assert_eq!(
                    diffsol_ode_solve_hybrid_dense(
                        ode,
                        no_params.as_ptr(),
                        no_params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        &mut err_solution_ptr,
                    ),
                    DIFFSOL_ERR
                );

                // Sensitivity / adjoint failures with empty parameters, LLVM backend only.
                #[cfg(feature = "diffsl-llvm")]
                if matches!(jit_backend, JitBackendType::Llvm) {
                    assert_eq!(
                        diffsol_ode_solve_fwd_sens(
                            ode,
                            no_params.as_ptr(),
                            no_params.len(),
                            t_eval.as_ptr(),
                            t_eval.len(),
                            &mut err_solution_ptr,
                        ),
                        DIFFSOL_ERR
                    );
                    assert_eq!(
                        diffsol_ode_solve_hybrid_fwd_sens(
                            ode,
                            no_params.as_ptr(),
                            no_params.len(),
                            t_eval.as_ptr(),
                            t_eval.len(),
                            &mut err_solution_ptr,
                        ),
                        DIFFSOL_ERR
                    );

                    let adjoint_data: Vec<f64> = t_eval
                        .iter()
                        .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
                        .collect();
                    let mut objective = 0.0;
                    let mut sens_ptr = ptr::null_mut();
                    assert_eq!(
                        diffsol_ode_solve_sum_squares_adj(
                            ode,
                            no_params.as_ptr(),
                            no_params.len(),
                            adjoint_data.as_ptr(),
                            1,
                            t_eval.len(),
                            1,
                            1,
                            t_eval.as_ptr(),
                            t_eval.len(),
                            &mut objective,
                            &mut sens_ptr,
                        ),
                        DIFFSOL_ERR
                    );
                }

                // Argument validation: null handles make the integer getters return -1 and
                // every other call return `DIFFSOL_BAD_ARG`; 99 is an out-of-range enum value.
                assert_eq!(diffsol_ode_get_matrix_type(ptr::null()), -1);
                assert_eq!(diffsol_ode_get_ode_solver(ptr::null()), -1);
                assert_eq!(diffsol_ode_get_linear_solver(ptr::null()), -1);
                assert_eq!(
                    diffsol_ode_set_ode_solver(ptr::null_mut(), 0),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(
                    diffsol_ode_set_linear_solver(ptr::null_mut(), 0),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(diffsol_ode_set_ode_solver(ode, 99), DIFFSOL_BAD_ARG);
                assert_eq!(diffsol_ode_set_linear_solver(ode, 99), DIFFSOL_BAD_ARG);
                assert_eq!(
                    diffsol_ode_get_rtol(ptr::null(), &mut out_value),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(diffsol_ode_get_rtol(ode, ptr::null_mut()), DIFFSOL_BAD_ARG);
                assert_eq!(diffsol_ode_set_rtol(ptr::null_mut(), 1e-3), DIFFSOL_BAD_ARG);
                assert_eq!(
                    diffsol_ode_get_atol(ptr::null(), &mut out_value),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(diffsol_ode_get_atol(ode, ptr::null_mut()), DIFFSOL_BAD_ARG);
                assert_eq!(diffsol_ode_set_atol(ptr::null_mut(), 1e-3), DIFFSOL_BAD_ARG);
                // Valid handle and parameters, but a null output pointer for the solution.
                assert_eq!(
                    diffsol_ode_solve(ode, params.as_ptr(), params.len(), 1.0, ptr::null_mut()),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(
                    diffsol_ode_solve_hybrid(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        1.0,
                        ptr::null_mut(),
                    ),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(
                    diffsol_ode_solve_dense(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        ptr::null_mut(),
                    ),
                    DIFFSOL_BAD_ARG
                );
                assert_eq!(
                    diffsol_ode_solve_hybrid_dense(
                        ode,
                        params.as_ptr(),
                        params.len(),
                        t_eval.as_ptr(),
                        t_eval.len(),
                        ptr::null_mut(),
                    ),
                    DIFFSOL_BAD_ARG
                );
                // Null output pointers for the sensitivity / adjoint calls, LLVM backend only.
                #[cfg(feature = "diffsl-llvm")]
                if matches!(jit_backend, JitBackendType::Llvm) {
                    assert_eq!(
                        diffsol_ode_solve_fwd_sens(
                            ode,
                            params.as_ptr(),
                            params.len(),
                            t_eval.as_ptr(),
                            t_eval.len(),
                            ptr::null_mut(),
                        ),
                        DIFFSOL_BAD_ARG
                    );
                    assert_eq!(
                        diffsol_ode_solve_hybrid_fwd_sens(
                            ode,
                            params.as_ptr(),
                            params.len(),
                            t_eval.as_ptr(),
                            t_eval.len(),
                            ptr::null_mut(),
                        ),
                        DIFFSOL_BAD_ARG
                    );
                    // `t_eval` doubles as the data buffer in the two calls below; each call
                    // nulls one of the two output pointers (objective, then gradient).
                    let mut objective = 0.0;
                    let mut sens_ptr = ptr::null_mut();
                    assert_eq!(
                        diffsol_ode_solve_sum_squares_adj(
                            ode,
                            params.as_ptr(),
                            params.len(),
                            t_eval.as_ptr(),
                            1,
                            t_eval.len(),
                            1,
                            1,
                            t_eval.as_ptr(),
                            t_eval.len(),
                            ptr::null_mut(),
                            &mut sens_ptr,
                        ),
                        DIFFSOL_BAD_ARG
                    );
                    assert_eq!(
                        diffsol_ode_solve_sum_squares_adj(
                            ode,
                            params.as_ptr(),
                            params.len(),
                            t_eval.as_ptr(),
                            1,
                            t_eval.len(),
                            1,
                            1,
                            t_eval.as_ptr(),
                            t_eval.len(),
                            &mut objective,
                            ptr::null_mut(),
                        ),
                        DIFFSOL_BAD_ARG
                    );
                }

                diffsol_ode_free(ode);
            }
        }
    }
2177
2178 #[test]
2179 fn c_api_hybrid_jit_solver_paths_match_expected_values() {
2180 for jit_backend in available_jit_backends() {
2181 unsafe {
2182 let code = CString::new(hybrid_logistic_diffsl_code()).unwrap();
2183 let ode = make_ode_ptr_with_code(
2184 jit_backend,
2185 code.as_ptr(),
2186 matrix_type_to_i32(MatrixType::NalgebraDense),
2187 linear_solver_to_i32(LinearSolverType::Default),
2188 ode_solver_to_i32(OdeSolverType::Bdf),
2189 );
2190 assert!(!ode.is_null());
2191
2192 let params = [2.0f64];
2193 let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
2194 assert_eq!(
2195 diffsol_ode_solve_hybrid(
2196 ode,
2197 params.as_ptr(),
2198 params.len(),
2199 2.0,
2200 &mut solution_ptr
2201 ),
2202 DIFFSOL_OK
2203 );
2204 let mut ys_ptr = ptr::null_mut();
2205 let mut ts_ptr = ptr::null_mut();
2206 assert_eq!(
2207 diffsol_solution_wrapper_get_ys(solution_ptr, &mut ys_ptr),
2208 DIFFSOL_OK
2209 );
2210 assert_eq!(
2211 diffsol_solution_wrapper_get_ts(solution_ptr, &mut ts_ptr),
2212 DIFFSOL_OK
2213 );
2214 let (_rows, cols, ys) = ffi_read_host_array_matrix(ys_ptr);
2215 let ts = ffi_read_host_array_vector(ts_ptr);
2216 assert!(cols >= 1);
2217 assert_close(*ts.last().unwrap(), 2.0, 5e-4, "jit hybrid solve time");
2218 assert_close(
2219 *ys.last().unwrap(),
2220 hybrid_logistic_state(2.0, 2.0),
2221 5e-4,
2222 "jit hybrid solve value",
2223 );
2224 ffi_free_solution(solution_ptr);
2225
2226 #[cfg(feature = "diffsl-llvm")]
2227 if matches!(jit_backend, JitBackendType::Llvm) {
2228 let t_eval = [0.25f64, 0.5f64, 1.0f64];
2229 let mut sens_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
2230 assert_eq!(
2231 diffsol_ode_solve_hybrid_fwd_sens(
2232 ode,
2233 params.as_ptr(),
2234 params.len(),
2235 t_eval.as_ptr(),
2236 t_eval.len(),
2237 &mut sens_solution_ptr,
2238 ),
2239 DIFFSOL_OK
2240 );
2241 let mut sens_list = ptr::null_mut();
2242 let mut sens_len = 0usize;
2243 assert_eq!(
2244 diffsol_solution_wrapper_get_sens(
2245 sens_solution_ptr,
2246 &mut sens_list,
2247 &mut sens_len
2248 ),
2249 DIFFSOL_OK
2250 );
2251 let sens_values = ffi_read_host_array_list_matrices(sens_list, sens_len);
2252 for (i, (&value, &t)) in sens_values[0].2.iter().zip(t_eval.iter()).enumerate()
2253 {
2254 assert_close(
2255 value,
2256 hybrid_logistic_state_dr(2.0, t),
2257 5e-4,
2258 &format!("jit hybrid sensitivity[{i}]"),
2259 );
2260 }
2261 ffi_free_solution(sens_solution_ptr);
2262 }
2263
2264 diffsol_ode_free(ode);
2265 }
2266 }
2267 }
2268}