Skip to main content

diffsol_c/
ode_c.rs

1#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
2use std::ffi::CStr;
3#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
4use std::os::raw::c_char;
5use std::ptr;
6
7use crate::c_api_utils::{valid_f64_ptr, DIFFSOL_BAD_ARG, DIFFSOL_ERR, DIFFSOL_OK};
8use crate::host_array::HostArray;
9#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
10use crate::jit_c::jit_backend_from_i32;
11use crate::linear_solver_type_c::{linear_solver_from_i32, linear_solver_to_i32};
12use crate::matrix_type_c::{matrix_type_from_i32, matrix_type_to_i32};
13use crate::ode::OdeWrapper;
14use crate::ode_solver_type_c::{ode_solver_from_i32, ode_solver_to_i32};
15use crate::scalar_type::ScalarType;
16use crate::solution_wrapper::SolutionWrapper;
17use crate::{c_error, c_invalid_arg};
18
19fn boxed_host_array(array: HostArray) -> *mut HostArray {
20    Box::into_raw(Box::new(array))
21}
22
23fn parse_ode_new_common_args(
24    matrix_type: i32,
25    linear_solver: i32,
26    ode_solver: i32,
27) -> Option<(
28    crate::matrix_type::MatrixType,
29    crate::linear_solver_type::LinearSolverType,
30    crate::ode_solver_type::OdeSolverType,
31)> {
32    let matrix_type = match matrix_type_from_i32(matrix_type) {
33        Some(value) => value,
34        None => {
35            c_invalid_arg!("invalid matrix_type");
36            return None;
37        }
38    };
39    let linear_solver = match linear_solver_from_i32(linear_solver) {
40        Some(value) => value,
41        None => {
42            c_invalid_arg!("invalid linear_solver");
43            return None;
44        }
45    };
46    let ode_solver = match ode_solver_from_i32(ode_solver) {
47        Some(value) => value,
48        None => {
49            c_invalid_arg!("invalid ode_solver");
50            return None;
51        }
52    };
53    Some((matrix_type, linear_solver, ode_solver))
54}
55
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
/// Validate and decode every argument of `diffsol_ode_new_jit` except the backend.
///
/// Copies the DiffSL source out of the C string into an owned `String`, then
/// decodes the shared enum codes. Returns `None` after recording an error when
/// `code` is null, is not valid UTF-8, or any enum code is unknown.
fn parse_ode_new_jit_args(
    code: *const c_char,
    matrix_type: i32,
    linear_solver: i32,
    ode_solver: i32,
) -> Option<(
    String,
    crate::matrix_type::MatrixType,
    crate::linear_solver_type::LinearSolverType,
    crate::ode_solver_type::OdeSolverType,
)> {
    if code.is_null() {
        c_invalid_arg!("code is null");
        return None;
    }
    // SAFETY: `code` is non-null and, per the contract of `diffsol_ode_new_jit`,
    // points to a NUL-terminated string that stays valid for the whole call.
    let c_source = unsafe { CStr::from_ptr(code) };
    let Ok(utf8_source) = c_source.to_str() else {
        c_error!("code is not valid UTF-8");
        return None;
    };
    let owned_source = utf8_source.to_owned();
    let (matrix_type, linear_solver, ode_solver) =
        parse_ode_new_common_args(matrix_type, linear_solver, ode_solver)?;
    Some((owned_source, matrix_type, linear_solver, ode_solver))
}
84
85/// Free a list of host arrays previously returned by this library.
86///
87/// # Safety
88/// `list` must be either null or a pointer returned by this library for a list
89/// of length `len`. Each pointed-to element remains owned separately.
90#[unsafe(no_mangle)]
91pub unsafe extern "C" fn diffsol_host_array_list_free(list: *mut *mut HostArray, len: usize) {
92    if list.is_null() {
93        c_invalid_arg!("host array list is null");
94        return;
95    }
96    unsafe {
97        drop(Box::from_raw(std::ptr::slice_from_raw_parts_mut(list, len)));
98    }
99}
100
#[cfg(feature = "external")]
/// Construct an external-backed ODE wrapper.
///
/// Each dependency list is passed as a flat C array of `len` `(usize, usize)`
/// pairs, i.e. `2 * len` consecutive `usize` values. On success an owning
/// pointer is returned (release it with `diffsol_ode_free`). On failure null is
/// returned after an error has been recorded. That covers an invalid enum code,
/// a dependency pointer that is null while its length is non-zero, or a length
/// whose element count overflows `usize`.
///
/// # Safety
/// Dependency pointers must be either null with length `0` or point to valid,
/// `usize`-aligned memory containing `len` `(usize, usize)` pairs (`2 * len`
/// consecutive `usize` values) for the duration of this call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn diffsol_ode_new_external(
    matrix_type: i32,
    linear_solver: i32,
    ode_solver: i32,
    rhs_state_deps_ptr: *const usize,
    rhs_state_deps_len: usize,
    rhs_input_deps_ptr: *const usize,
    rhs_input_deps_len: usize,
    mass_state_deps_ptr: *const usize,
    mass_state_deps_len: usize,
) -> *mut OdeWrapper {
    /// Copy `len` pairs out of a flat array of `2 * len` `usize` values.
    ///
    /// Returns `None` when `data` is null but `len` is non-zero, or when
    /// `2 * len` overflows `usize`. The data is read as `[usize]` rather than
    /// `[(usize, usize)]` because Rust does not specify tuple layout, while a
    /// C caller passes a plain `size_t` array.
    ///
    /// # Safety
    /// If `data` is non-null and `len > 0`, it must be aligned and valid for
    /// `2 * len` `usize` reads.
    unsafe fn read_dep_pairs(data: *const usize, len: usize) -> Option<Vec<(usize, usize)>> {
        if len == 0 {
            return Some(Vec::new());
        }
        if data.is_null() {
            return None;
        }
        let flat_len = len.checked_mul(2)?;
        // SAFETY: `data` is non-null and, per this function's contract, valid for
        // `flat_len` aligned `usize` reads.
        let flat = unsafe { std::slice::from_raw_parts(data, flat_len) };
        Some(flat.chunks_exact(2).map(|pair| (pair[0], pair[1])).collect())
    }

    let Some((matrix_type, linear_solver, ode_solver)) =
        parse_ode_new_common_args(matrix_type, linear_solver, ode_solver)
    else {
        return ptr::null_mut();
    };

    // SAFETY (all three reads): forwarded from this function's `# Safety` contract.
    let rhs_state_deps = unsafe { read_dep_pairs(rhs_state_deps_ptr, rhs_state_deps_len) };
    let Some(rhs_state_deps) = rhs_state_deps else {
        c_invalid_arg!("invalid rhs_state_deps");
        return ptr::null_mut();
    };
    let rhs_input_deps = unsafe { read_dep_pairs(rhs_input_deps_ptr, rhs_input_deps_len) };
    let Some(rhs_input_deps) = rhs_input_deps else {
        c_invalid_arg!("invalid rhs_input_deps");
        return ptr::null_mut();
    };
    let mass_state_deps = unsafe { read_dep_pairs(mass_state_deps_ptr, mass_state_deps_len) };
    let Some(mass_state_deps) = mass_state_deps else {
        c_invalid_arg!("invalid mass_state_deps");
        return ptr::null_mut();
    };

    let scalar_type = ScalarType::F64;
    match OdeWrapper::new_external(
        rhs_state_deps,
        rhs_input_deps,
        mass_state_deps,
        scalar_type,
        matrix_type,
        linear_solver,
        ode_solver,
    ) {
        Ok(ode) => Box::into_raw(Box::new(ode)),
        Err(err) => {
            c_error!(&err.to_string());
            ptr::null_mut()
        }
    }
}
179
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
/// Construct a JIT-backed ODE wrapper from DiffSL source code.
///
/// The wrapper always uses `f64` scalars. Returns an owning pointer on success.
/// If any argument is invalid or the wrapper cannot be built, it returns null
/// after recording an error.
///
/// # Safety
/// `code` must be a valid, null-terminated UTF-8 string for the duration of
/// this call. The backend and solver enum values must be valid values defined by
/// this library.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn diffsol_ode_new_jit(
    code: *const c_char,
    jit_backend: i32,
    matrix_type: i32,
    linear_solver: i32,
    ode_solver: i32,
) -> *mut OdeWrapper {
    let parsed = parse_ode_new_jit_args(code, matrix_type, linear_solver, ode_solver);
    let Some((source, matrix, linear, solver)) = parsed else {
        return ptr::null_mut();
    };
    let Some(backend) = jit_backend_from_i32(jit_backend) else {
        c_invalid_arg!("invalid jit_backend_type");
        return ptr::null_mut();
    };
    let built = OdeWrapper::new_jit(&source, backend, ScalarType::F64, matrix, linear, solver);
    match built {
        Ok(wrapper) => Box::into_raw(Box::new(wrapper)),
        Err(err) => {
            c_error!(&err.to_string());
            ptr::null_mut()
        }
    }
}
223
224/// Free an ODE wrapper previously returned by this library.
225///
226/// # Safety
227/// `ode` must be either null or a pointer returned by this library that has not
228/// already been freed.
229#[unsafe(no_mangle)]
230pub unsafe extern "C" fn diffsol_ode_free(ode: *mut OdeWrapper) {
231    if ode.is_null() {
232        c_invalid_arg!("ode is null");
233        return;
234    }
235    unsafe {
236        drop(Box::from_raw(ode));
237    }
238}
239
240/// Return a handle to the initial-condition solver options for an ODE.
241///
242/// # Safety
243/// `ode` must be a valid pointer created by this library. `out_options` must be
244/// a valid, writable pointer to receive ownership of the returned options
245/// object.
246#[unsafe(no_mangle)]
247pub unsafe extern "C" fn diffsol_ode_get_ic_options(
248    ode: *const OdeWrapper,
249    out_options: *mut *mut crate::initial_condition_options::InitialConditionSolverOptions,
250) -> i32 {
251    if ode.is_null() || out_options.is_null() {
252        return c_invalid_arg!("invalid arguments to diffsol_ode_get_ic_options");
253    }
254    let ode = unsafe { &*ode };
255    let options = ode.get_ic_options();
256    let boxed = Box::new(options);
257    unsafe {
258        *out_options = Box::into_raw(boxed);
259    }
260    DIFFSOL_OK
261}
262
263/// Return a handle to the ODE solver options for an ODE.
264///
265/// # Safety
266/// `ode` must be a valid pointer created by this library. `out_options` must be
267/// a valid, writable pointer to receive ownership of the returned options
268/// object.
269#[unsafe(no_mangle)]
270pub unsafe extern "C" fn diffsol_ode_get_options(
271    ode: *const OdeWrapper,
272    out_options: *mut *mut crate::ode_options::OdeSolverOptions,
273) -> i32 {
274    if ode.is_null() || out_options.is_null() {
275        return c_invalid_arg!("invalid arguments to diffsol_ode_get_options");
276    }
277    let ode = unsafe { &*ode };
278    let options = ode.get_options();
279    let boxed = Box::new(options);
280    unsafe {
281        *out_options = Box::into_raw(boxed);
282    }
283    DIFFSOL_OK
284}
285
286/// Evaluate the initial condition vector for an ODE.
287///
288/// # Safety
289/// `ode` must be a valid mutable pointer created by this library. `params_ptr`
290/// must be either null with `params_len == 0` or point to `params_len`
291/// readable `f64` values. `out_array` must be a valid, writable pointer.
292#[unsafe(no_mangle)]
293pub unsafe extern "C" fn diffsol_ode_y0(
294    ode: *mut OdeWrapper,
295    params_ptr: *const f64,
296    params_len: usize,
297    out_array: *mut *mut HostArray,
298) -> i32 {
299    if ode.is_null() || out_array.is_null() || !valid_f64_ptr(params_ptr, params_len) {
300        c_invalid_arg!("invalid arguments to diffsol_ode_y0");
301        return DIFFSOL_BAD_ARG;
302    }
303    let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
304    let ode = unsafe { &mut *ode };
305    match ode.y0(params) {
306        Ok(array) => {
307            let boxed = boxed_host_array(array);
308            unsafe {
309                *out_array = boxed;
310            }
311            DIFFSOL_OK
312        }
313        Err(err) => {
314            c_error!(&format!("{}", err));
315            DIFFSOL_ERR
316        }
317    }
318}
319
320/// Evaluate the ODE right-hand side at a given time and state.
321///
322/// # Safety
323/// `ode` must be a valid mutable pointer created by this library. `params_ptr`
324/// and `y_ptr` must point to readable `f64` buffers of the specified lengths,
325/// unless the corresponding length is zero. `out_array` must be writable.
326#[unsafe(no_mangle)]
327pub unsafe extern "C" fn diffsol_ode_rhs(
328    ode: *mut OdeWrapper,
329    params_ptr: *const f64,
330    params_len: usize,
331    t: f64,
332    y_ptr: *const f64,
333    y_len: usize,
334    out_array: *mut *mut HostArray,
335) -> i32 {
336    if ode.is_null()
337        || out_array.is_null()
338        || !valid_f64_ptr(params_ptr, params_len)
339        || !valid_f64_ptr(y_ptr, y_len)
340    {
341        c_invalid_arg!("invalid arguments to diffsol_ode_rhs");
342        return DIFFSOL_BAD_ARG;
343    }
344    let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
345    let y = HostArray::new_vector(y_ptr as *mut u8, y_len, ScalarType::F64);
346    let ode = unsafe { &mut *ode };
347    match ode.rhs(params, t, y) {
348        Ok(array) => {
349            let boxed = boxed_host_array(array);
350            unsafe {
351                *out_array = boxed;
352            }
353            DIFFSOL_OK
354        }
355        Err(err) => {
356            c_error!(&format!("{}", err));
357            DIFFSOL_ERR
358        }
359    }
360}
361
362/// Evaluate the ODE Jacobian-vector product at a given time and state.
363///
364/// # Safety
365/// `ode` must be a valid mutable pointer created by this library. `params_ptr`,
366/// `y_ptr`, and `v_ptr` must point to readable `f64` buffers of the specified
367/// lengths, unless the corresponding length is zero. `out_array` must be
368/// writable.
369#[unsafe(no_mangle)]
370pub unsafe extern "C" fn diffsol_ode_rhs_jac_mul(
371    ode: *mut OdeWrapper,
372    params_ptr: *const f64,
373    params_len: usize,
374    t: f64,
375    y_ptr: *const f64,
376    y_len: usize,
377    v_ptr: *const f64,
378    v_len: usize,
379    out_array: *mut *mut HostArray,
380) -> i32 {
381    if ode.is_null()
382        || out_array.is_null()
383        || !valid_f64_ptr(params_ptr, params_len)
384        || !valid_f64_ptr(y_ptr, y_len)
385        || !valid_f64_ptr(v_ptr, v_len)
386    {
387        c_invalid_arg!("invalid arguments to diffsol_ode_rhs_jac_mul");
388        return DIFFSOL_BAD_ARG;
389    }
390    let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
391    let y = HostArray::new_vector(y_ptr as *mut u8, y_len, ScalarType::F64);
392    let v = HostArray::new_vector(v_ptr as *mut u8, v_len, ScalarType::F64);
393    let ode = unsafe { &mut *ode };
394    match ode.rhs_jac_mul(params, t, y, v) {
395        Ok(array) => {
396            let boxed = boxed_host_array(array);
397            unsafe {
398                *out_array = boxed;
399            }
400            DIFFSOL_OK
401        }
402        Err(err) => {
403            c_error!(&format!("{}", err));
404            DIFFSOL_ERR
405        }
406    }
407}
408
409/// Solve an ODE up to a final time.
410///
411/// # Safety
412/// `ode` must be a valid mutable pointer created by this library. `params_ptr`
413/// must point to `params_len` readable `f64` values unless `params_len == 0`.
414/// `out_solution` must be a valid, writable pointer.
415#[unsafe(no_mangle)]
416pub unsafe extern "C" fn diffsol_ode_solve(
417    ode: *mut OdeWrapper,
418    params_ptr: *const f64,
419    params_len: usize,
420    final_time: f64,
421    out_solution: *mut *mut SolutionWrapper,
422) -> i32 {
423    if ode.is_null() || out_solution.is_null() || !valid_f64_ptr(params_ptr, params_len) {
424        c_invalid_arg!("invalid arguments to diffsol_ode_solve");
425        return DIFFSOL_BAD_ARG;
426    }
427    let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
428    let ode = unsafe { &mut *ode };
429    match ode.solve(params, final_time) {
430        Ok(new_solution) => {
431            unsafe {
432                *out_solution = Box::into_raw(Box::new(new_solution));
433            }
434            DIFFSOL_OK
435        }
436        Err(err) => {
437            c_error!(&format!("{}", err));
438            DIFFSOL_ERR
439        }
440    }
441}
442
443/// Solve a hybrid ODE up to a final time, automatically applying resets after roots.
444///
445/// # Safety
446/// `ode` must be a valid mutable pointer created by this library. `params_ptr`
447/// must point to `params_len` readable `f64` values unless `params_len == 0`.
448/// `out_solution` must be a valid, writable pointer.
449#[unsafe(no_mangle)]
450pub unsafe extern "C" fn diffsol_ode_solve_hybrid(
451    ode: *mut OdeWrapper,
452    params_ptr: *const f64,
453    params_len: usize,
454    final_time: f64,
455    out_solution: *mut *mut SolutionWrapper,
456) -> i32 {
457    if ode.is_null() || out_solution.is_null() || !valid_f64_ptr(params_ptr, params_len) {
458        c_invalid_arg!("invalid arguments to diffsol_ode_solve_hybrid");
459        return DIFFSOL_BAD_ARG;
460    }
461    let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
462    let ode = unsafe { &mut *ode };
463    match ode.solve_hybrid(params, final_time) {
464        Ok(new_solution) => {
465            unsafe {
466                *out_solution = Box::into_raw(Box::new(new_solution));
467            }
468            DIFFSOL_OK
469        }
470        Err(err) => {
471            c_error!(&format!("{}", err));
472            DIFFSOL_ERR
473        }
474    }
475}
476
477/// Solve an ODE and sample the solution at requested times.
478///
479/// # Safety
480/// `ode` must be a valid mutable pointer created by this library. `params_ptr`
481/// and `t_eval_ptr` must point to readable `f64` buffers of the specified
482/// lengths, unless the corresponding length is zero. `out_solution` must be writable.
483#[unsafe(no_mangle)]
484pub unsafe extern "C" fn diffsol_ode_solve_dense(
485    ode: *mut OdeWrapper,
486    params_ptr: *const f64,
487    params_len: usize,
488    t_eval_ptr: *const f64,
489    t_eval_len: usize,
490    out_solution: *mut *mut SolutionWrapper,
491) -> i32 {
492    if ode.is_null()
493        || out_solution.is_null()
494        || !valid_f64_ptr(params_ptr, params_len)
495        || !valid_f64_ptr(t_eval_ptr, t_eval_len)
496    {
497        c_invalid_arg!("invalid arguments to diffsol_ode_solve_dense");
498        return DIFFSOL_BAD_ARG;
499    }
500    let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
501    let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
502    let ode = unsafe { &mut *ode };
503    match ode.solve_dense(params, t_eval) {
504        Ok(new_solution) => {
505            unsafe {
506                *out_solution = Box::into_raw(Box::new(new_solution));
507            }
508            DIFFSOL_OK
509        }
510        Err(err) => {
511            c_error!(&format!("{}", err));
512            DIFFSOL_ERR
513        }
514    }
515}
516
517/// Solve a hybrid ODE and sample the solution at requested times.
518///
519/// # Safety
520/// `ode` must be a valid mutable pointer created by this library. `params_ptr`
521/// and `t_eval_ptr` must point to readable `f64` buffers of the specified
522/// lengths, unless the corresponding length is zero. `out_solution` must be writable.
523#[unsafe(no_mangle)]
524pub unsafe extern "C" fn diffsol_ode_solve_hybrid_dense(
525    ode: *mut OdeWrapper,
526    params_ptr: *const f64,
527    params_len: usize,
528    t_eval_ptr: *const f64,
529    t_eval_len: usize,
530    out_solution: *mut *mut SolutionWrapper,
531) -> i32 {
532    if ode.is_null()
533        || out_solution.is_null()
534        || !valid_f64_ptr(params_ptr, params_len)
535        || !valid_f64_ptr(t_eval_ptr, t_eval_len)
536    {
537        c_invalid_arg!("invalid arguments to diffsol_ode_solve_hybrid_dense");
538        return DIFFSOL_BAD_ARG;
539    }
540    let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
541    let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
542    let ode = unsafe { &mut *ode };
543    match ode.solve_hybrid_dense(params, t_eval) {
544        Ok(new_solution) => {
545            unsafe {
546                *out_solution = Box::into_raw(Box::new(new_solution));
547            }
548            DIFFSOL_OK
549        }
550        Err(err) => {
551            c_error!(&format!("{}", err));
552            DIFFSOL_ERR
553        }
554    }
555}
556
557/// Solve an ODE and sample forward sensitivities at requested times.
558///
559/// # Safety
560/// `ode` must be a valid mutable pointer created by this library. `params_ptr`
561/// and `t_eval_ptr` must point to readable `f64` buffers of the specified
562/// lengths, unless the corresponding length is zero. `out_solution` must be writable.
563#[unsafe(no_mangle)]
564pub unsafe extern "C" fn diffsol_ode_solve_fwd_sens(
565    ode: *mut OdeWrapper,
566    params_ptr: *const f64,
567    params_len: usize,
568    t_eval_ptr: *const f64,
569    t_eval_len: usize,
570    out_solution: *mut *mut SolutionWrapper,
571) -> i32 {
572    if ode.is_null()
573        || out_solution.is_null()
574        || !valid_f64_ptr(params_ptr, params_len)
575        || !valid_f64_ptr(t_eval_ptr, t_eval_len)
576    {
577        c_invalid_arg!("invalid arguments to diffsol_ode_solve_fwd_sens");
578        return DIFFSOL_BAD_ARG;
579    }
580    let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
581    let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
582    let ode = unsafe { &mut *ode };
583    match ode.solve_fwd_sens(params, t_eval) {
584        Ok(new_solution) => {
585            unsafe {
586                *out_solution = Box::into_raw(Box::new(new_solution));
587            }
588            DIFFSOL_OK
589        }
590        Err(err) => {
591            c_error!(&format!("{}", err));
592            DIFFSOL_ERR
593        }
594    }
595}
596
597/// Solve a hybrid ODE with forward sensitivities at requested times.
598///
599/// # Safety
600/// `ode` must be a valid mutable pointer created by this library. `params_ptr`
601/// and `t_eval_ptr` must point to readable `f64` buffers of the specified
602/// lengths, unless the corresponding length is zero. `out_solution` must be writable.
603#[unsafe(no_mangle)]
604pub unsafe extern "C" fn diffsol_ode_solve_hybrid_fwd_sens(
605    ode: *mut OdeWrapper,
606    params_ptr: *const f64,
607    params_len: usize,
608    t_eval_ptr: *const f64,
609    t_eval_len: usize,
610    out_solution: *mut *mut SolutionWrapper,
611) -> i32 {
612    if ode.is_null()
613        || out_solution.is_null()
614        || !valid_f64_ptr(params_ptr, params_len)
615        || !valid_f64_ptr(t_eval_ptr, t_eval_len)
616    {
617        c_invalid_arg!("invalid arguments to diffsol_ode_solve_hybrid_fwd_sens");
618        return DIFFSOL_BAD_ARG;
619    }
620    let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
621    let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
622    let ode = unsafe { &mut *ode };
623    match ode.solve_hybrid_fwd_sens(params, t_eval) {
624        Ok(new_solution) => {
625            unsafe {
626                *out_solution = Box::into_raw(Box::new(new_solution));
627            }
628            DIFFSOL_OK
629        }
630        Err(err) => {
631            c_error!(&format!("{}", err));
632            DIFFSOL_ERR
633        }
634    }
635}
636
637/// Solve the sum-of-squares adjoint problem for an ODE.
638///
639/// # Safety
640/// `ode` must be a valid mutable pointer created by this library. `params_ptr`,
641/// `data_ptr`, and `t_eval_ptr` must point to readable buffers matching the
642/// provided dimensions. `out_value` and `out_sens` must be valid, writable
643/// pointers.
644#[unsafe(no_mangle)]
645pub unsafe extern "C" fn diffsol_ode_solve_sum_squares_adj(
646    ode: *mut OdeWrapper,
647    params_ptr: *const f64,
648    params_len: usize,
649    data_ptr: *const f64,
650    data_rows: usize,
651    data_cols: usize,
652    data_row_stride: usize,
653    data_col_stride: usize,
654    t_eval_ptr: *const f64,
655    t_eval_len: usize,
656    out_value: *mut f64,
657    out_sens: *mut *mut HostArray,
658) -> i32 {
659    if ode.is_null()
660        || out_value.is_null()
661        || out_sens.is_null()
662        || data_ptr.is_null()
663        || !valid_f64_ptr(params_ptr, params_len)
664        || !valid_f64_ptr(t_eval_ptr, t_eval_len)
665    {
666        c_invalid_arg!("invalid arguments to diffsol_ode_solve_sum_squares_adj");
667        return DIFFSOL_BAD_ARG;
668    }
669    let params = HostArray::new_vector(params_ptr as *mut u8, params_len, ScalarType::F64);
670    let t_eval = HostArray::new_vector(t_eval_ptr as *mut u8, t_eval_len, ScalarType::F64);
671    let data = HostArray::new_col_major(
672        data_ptr as *mut u8,
673        data_rows,
674        data_cols,
675        data_row_stride as isize,
676        data_col_stride as isize,
677        ScalarType::F64,
678    );
679    let ode = unsafe { &mut *ode };
680    match ode.solve_sum_squares_adj(params, data, t_eval) {
681        Ok((value, sens)) => {
682            let sens_boxed = boxed_host_array(sens);
683            unsafe {
684                *out_value = value;
685                *out_sens = sens_boxed;
686            }
687            DIFFSOL_OK
688        }
689        Err(err) => {
690            c_error!(&format!("{}", err));
691            DIFFSOL_ERR
692        }
693    }
694}
695
696/// Return the matrix type configured for an ODE.
697///
698/// # Safety
699/// `ode` must be a valid pointer created by this library.
700#[unsafe(no_mangle)]
701pub unsafe extern "C" fn diffsol_ode_get_matrix_type(ode: *const OdeWrapper) -> i32 {
702    if ode.is_null() {
703        c_invalid_arg!("ode is null");
704        return -1;
705    }
706    let ode = unsafe { &*ode };
707    match ode.get_matrix_type() {
708        Ok(value) => matrix_type_to_i32(value),
709        Err(err) => {
710            c_error!(&format!("{}", err));
711            -1
712        }
713    }
714}
715
716/// Return the ODE solver enum configured for an ODE.
717///
718/// # Safety
719/// `ode` must be a valid pointer created by this library.
720#[unsafe(no_mangle)]
721pub unsafe extern "C" fn diffsol_ode_get_ode_solver(ode: *const OdeWrapper) -> i32 {
722    if ode.is_null() {
723        c_invalid_arg!("ode is null");
724        return -1;
725    }
726    let ode = unsafe { &*ode };
727    match ode.get_ode_solver() {
728        Ok(value) => ode_solver_to_i32(value),
729        Err(err) => {
730            c_error!(&format!("{}", err));
731            -1
732        }
733    }
734}
735
736/// Set the ODE solver enum for an ODE.
737///
738/// # Safety
739/// `ode` must be a valid mutable pointer created by this library.
740#[unsafe(no_mangle)]
741pub unsafe extern "C" fn diffsol_ode_set_ode_solver(ode: *mut OdeWrapper, value: i32) -> i32 {
742    if ode.is_null() {
743        c_invalid_arg!("ode is null");
744        return DIFFSOL_BAD_ARG;
745    }
746    let value = match ode_solver_from_i32(value) {
747        Some(v) => v,
748        None => {
749            c_invalid_arg!("invalid ode_solver");
750            return DIFFSOL_BAD_ARG;
751        }
752    };
753    let ode = unsafe { &mut *ode };
754    match ode.set_ode_solver(value) {
755        Ok(()) => DIFFSOL_OK,
756        Err(err) => c_error!(&format!("{}", err)),
757    }
758}
759
760/// Return the linear solver enum configured for an ODE.
761///
762/// # Safety
763/// `ode` must be a valid pointer created by this library.
764#[unsafe(no_mangle)]
765pub unsafe extern "C" fn diffsol_ode_get_linear_solver(ode: *const OdeWrapper) -> i32 {
766    if ode.is_null() {
767        c_invalid_arg!("ode is null");
768        return -1;
769    }
770    let ode = unsafe { &*ode };
771    match ode.get_linear_solver() {
772        Ok(value) => linear_solver_to_i32(value),
773        Err(err) => {
774            c_error!(&format!("{}", err));
775            -1
776        }
777    }
778}
779
780/// Set the linear solver enum for an ODE.
781///
782/// # Safety
783/// `ode` must be a valid mutable pointer created by this library.
784#[unsafe(no_mangle)]
785pub unsafe extern "C" fn diffsol_ode_set_linear_solver(ode: *mut OdeWrapper, value: i32) -> i32 {
786    if ode.is_null() {
787        c_invalid_arg!("ode is null");
788        return DIFFSOL_BAD_ARG;
789    }
790    let value = match linear_solver_from_i32(value) {
791        Some(v) => v,
792        None => {
793            c_invalid_arg!("invalid linear_solver");
794            return DIFFSOL_BAD_ARG;
795        }
796    };
797    let ode = unsafe { &mut *ode };
798    match ode.set_linear_solver(value) {
799        Ok(()) => DIFFSOL_OK,
800        Err(err) => c_error!(&format!("{}", err)),
801    }
802}
803
804/// Return the relative tolerance configured for an ODE.
805///
806/// # Safety
807/// `ode` must be a valid pointer created by this library. `out_value` must be a
808/// valid, writable pointer.
809#[unsafe(no_mangle)]
810pub unsafe extern "C" fn diffsol_ode_get_rtol(ode: *const OdeWrapper, out_value: *mut f64) -> i32 {
811    if ode.is_null() || out_value.is_null() {
812        c_invalid_arg!("invalid arguments to diffsol_ode_get_rtol");
813        return DIFFSOL_BAD_ARG;
814    }
815    let ode = unsafe { &*ode };
816    match ode.get_rtol() {
817        Ok(value) => {
818            unsafe {
819                *out_value = value;
820            }
821            DIFFSOL_OK
822        }
823        Err(err) => c_error!(&format!("{}", err)),
824    }
825}
826
827/// Set the relative tolerance for an ODE.
828///
829/// # Safety
830/// `ode` must be a valid mutable pointer created by this library.
831#[unsafe(no_mangle)]
832pub unsafe extern "C" fn diffsol_ode_set_rtol(ode: *mut OdeWrapper, value: f64) -> i32 {
833    if ode.is_null() {
834        c_invalid_arg!("ode is null");
835        return DIFFSOL_BAD_ARG;
836    }
837    let ode = unsafe { &mut *ode };
838    match ode.set_rtol(value) {
839        Ok(()) => DIFFSOL_OK,
840        Err(err) => c_error!(&format!("{}", err)),
841    }
842}
843
844/// Return the absolute tolerance configured for an ODE.
845///
846/// # Safety
847/// `ode` must be a valid pointer created by this library. `out_value` must be a
848/// valid, writable pointer.
849#[unsafe(no_mangle)]
850pub unsafe extern "C" fn diffsol_ode_get_atol(ode: *const OdeWrapper, out_value: *mut f64) -> i32 {
851    if ode.is_null() || out_value.is_null() {
852        c_invalid_arg!("invalid arguments to diffsol_ode_get_atol");
853        return DIFFSOL_BAD_ARG;
854    }
855    let ode = unsafe { &*ode };
856    match ode.get_atol() {
857        Ok(value) => {
858            unsafe {
859                *out_value = value;
860            }
861            DIFFSOL_OK
862        }
863        Err(err) => c_error!(&format!("{}", err)),
864    }
865}
866
867/// Set the absolute tolerance for an ODE.
868///
869/// # Safety
870/// `ode` must be a valid mutable pointer created by this library.
871#[unsafe(no_mangle)]
872pub unsafe extern "C" fn diffsol_ode_set_atol(ode: *mut OdeWrapper, value: f64) -> i32 {
873    if ode.is_null() {
874        c_invalid_arg!("ode is null");
875        return DIFFSOL_BAD_ARG;
876    }
877    let ode = unsafe { &mut *ode };
878    match ode.set_atol(value) {
879        Ok(()) => DIFFSOL_OK,
880        Err(err) => c_error!(&format!("{}", err)),
881    }
882}
883
#[cfg(all(test, feature = "diffsl-external-f64"))]
mod tests {
    // End-to-end tests of the C API, driven by the externally linked logistic
    // model that is only available with the `diffsl-external-f64` feature.

    use std::ptr;

    use crate::initial_condition_options::InitialConditionSolverOptions;
    use crate::linear_solver_type::LinearSolverType;
    use crate::linear_solver_type_c::{
        diffsol_linear_solver_type_count, diffsol_linear_solver_type_is_valid,
        diffsol_linear_solver_type_name, linear_solver_to_i32,
    };
    use crate::matrix_type::MatrixType;
    use crate::ode_options::OdeSolverOptions;
    use crate::ode_options_c::{
        diffsol_ode_options_free, diffsol_ode_options_get_max_nonlinear_solver_iterations,
        diffsol_ode_options_get_min_timestep,
        diffsol_ode_options_set_max_nonlinear_solver_iterations,
        diffsol_ode_options_set_min_timestep,
    };
    use crate::ode_solver_type::OdeSolverType;
    use crate::ode_solver_type_c::{
        diffsol_ode_solver_type_count, diffsol_ode_solver_type_is_valid,
        diffsol_ode_solver_type_name, ode_solver_to_i32,
    };
    use crate::scalar_type::ScalarType;
    use crate::scalar_type_c::{
        diffsol_scalar_type_count, diffsol_scalar_type_is_valid, diffsol_scalar_type_name,
        scalar_type_to_i32,
    };
    use crate::solution_wrapper_c::{
        diffsol_solution_wrapper_get_sens, diffsol_solution_wrapper_get_ts,
        diffsol_solution_wrapper_get_ys,
    };
    use crate::test_support::{
        assert_close, assert_last_error_contains, c_string, clear_last_error, ffi_free_solution,
        ffi_read_host_array_list_matrices, ffi_read_host_array_matrix, ffi_read_host_array_vector,
        find_time_window, logistic_state, logistic_state_dr, mass_state_deps, rhs_input_deps,
        rhs_state_deps, ASSERT_TOL, LOGISTIC_X0,
    };
    use crate::{
        initial_condition_options_c::{
            diffsol_ic_options_free, diffsol_ic_options_get_max_linesearch_iterations,
            diffsol_ic_options_get_use_linesearch,
            diffsol_ic_options_set_max_linesearch_iterations,
            diffsol_ic_options_set_use_linesearch,
        },
        matrix_type_c::{
            diffsol_matrix_type_count, diffsol_matrix_type_is_valid, diffsol_matrix_type_name,
            matrix_type_to_i32,
        },
    };

    use super::*;

    /// Build an ODE through `diffsol_ode_new_external`, using the dependency
    /// patterns supplied by `test_support`.
    ///
    /// Returns whatever the constructor returns. This is null when an enum
    /// argument is invalid (see `c_api_rejects_invalid_ode_arguments`).
    unsafe fn make_ode_ptr(
        matrix_type: i32,
        linear_solver: i32,
        ode_solver: i32,
    ) -> *mut OdeWrapper {
        let rhs_state_deps = rhs_state_deps();
        let rhs_input_deps = rhs_input_deps();
        let mass_state_deps = mass_state_deps();
        unsafe {
            diffsol_ode_new_external(
                matrix_type,
                linear_solver,
                ode_solver,
                rhs_state_deps.as_ptr() as *const usize,
                rhs_state_deps.len(),
                rhs_input_deps.as_ptr() as *const usize,
                rhs_input_deps.len(),
                mass_state_deps.as_ptr() as *const usize,
                mass_state_deps.len(),
            )
        }
    }

    /// The enum count and name accessors exposed over the C API return the
    /// expected values.
    #[test]
    fn c_api_reports_enum_metadata() {
        clear_last_error();
        unsafe {
            assert_eq!(diffsol_matrix_type_count(), 3);
            assert_eq!(diffsol_ode_solver_type_count(), 4);
            assert_eq!(diffsol_linear_solver_type_count(), 3);
            assert_eq!(diffsol_scalar_type_count(), 2);

            assert_eq!(
                c_string(diffsol_matrix_type_name(matrix_type_to_i32(
                    MatrixType::NalgebraDense
                ))),
                "nalgebra_dense"
            );
            assert_eq!(
                c_string(diffsol_ode_solver_type_name(ode_solver_to_i32(
                    OdeSolverType::Bdf
                ))),
                "bdf"
            );
            assert_eq!(
                c_string(diffsol_linear_solver_type_name(linear_solver_to_i32(
                    LinearSolverType::Default
                ))),
                "default"
            );
            assert_eq!(
                c_string(diffsol_scalar_type_name(scalar_type_to_i32(
                    ScalarType::F64
                ))),
                "f64"
            );
        }
    }

    /// Out-of-range enum values are reported as invalid, and each check
    /// records a descriptive last-error message.
    #[test]
    fn c_api_invalid_enums_set_last_error() {
        clear_last_error();
        unsafe {
            assert_eq!(diffsol_matrix_type_is_valid(99), 0);
            assert_last_error_contains("invalid matrix_type");
            clear_last_error();

            assert_eq!(diffsol_ode_solver_type_is_valid(99), 0);
            assert_last_error_contains("invalid ode_solver_type");
            clear_last_error();

            assert_eq!(diffsol_linear_solver_type_is_valid(99), 0);
            assert_last_error_contains("invalid linear_solver_type");
            clear_last_error();

            assert_eq!(diffsol_scalar_type_is_valid(99), 0);
            assert_last_error_contains("invalid scalar_type");
        }
    }

    /// A null ODE handle and an invalid constructor enum are both rejected.
    /// Rejection means a bad-argument status (or a null result) plus a
    /// last-error message.
    #[test]
    fn c_api_rejects_invalid_ode_arguments() {
        clear_last_error();
        unsafe {
            let mut out_array = ptr::null_mut();
            let status = diffsol_ode_y0(ptr::null_mut(), ptr::null(), 0, &mut out_array);
            assert_eq!(status, DIFFSOL_BAD_ARG);
            assert!(out_array.is_null());
            assert_last_error_contains("invalid arguments to diffsol_ode_y0");
            clear_last_error();

            let ode = make_ode_ptr(
                99,
                linear_solver_to_i32(LinearSolverType::Default),
                ode_solver_to_i32(OdeSolverType::Bdf),
            );
            assert!(ode.is_null());
            assert_last_error_contains("invalid matrix_type");
        }
    }

    /// Exercise the whole ODE C API against the external logistic model and
    /// compare results with the analytic helpers from `test_support`. Covered:
    /// getters/setters, option objects, y0/rhs/jacobian evaluation, plain,
    /// dense and hybrid solves, forward sensitivities and the adjoint solve.
    #[test]
    fn c_api_full_lifecycle_matches_external_logistic_model() {
        clear_last_error();
        unsafe {
            let ode = make_ode_ptr(
                matrix_type_to_i32(MatrixType::NalgebraDense),
                linear_solver_to_i32(LinearSolverType::Default),
                ode_solver_to_i32(OdeSolverType::Bdf),
            );
            assert!(!ode.is_null());

            // Construction arguments are reported back unchanged.
            assert_eq!(
                diffsol_ode_get_matrix_type(ode),
                matrix_type_to_i32(MatrixType::NalgebraDense)
            );
            assert_eq!(
                diffsol_ode_get_ode_solver(ode),
                ode_solver_to_i32(OdeSolverType::Bdf)
            );
            assert_eq!(
                diffsol_ode_get_linear_solver(ode),
                linear_solver_to_i32(LinearSolverType::Default)
            );

            // Switching the solver round-trips, then BDF is restored.
            assert_eq!(
                diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
                DIFFSOL_OK
            );
            assert_eq!(
                diffsol_ode_get_ode_solver(ode),
                ode_solver_to_i32(OdeSolverType::Tsit45)
            );
            assert_eq!(
                diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Bdf)),
                DIFFSOL_OK
            );

            // Tolerances round-trip.
            assert_eq!(diffsol_ode_set_rtol(ode, 1e-8), DIFFSOL_OK);
            assert_eq!(diffsol_ode_set_atol(ode, 1e-8), DIFFSOL_OK);
            let mut rtol = 0.0;
            let mut atol = 0.0;
            assert_eq!(diffsol_ode_get_rtol(ode, &mut rtol), DIFFSOL_OK);
            assert_eq!(diffsol_ode_get_atol(ode, &mut atol), DIFFSOL_OK);
            assert_close(rtol, 1e-8, ASSERT_TOL, "rtol roundtrip");
            assert_close(atol, 1e-8, ASSERT_TOL, "atol roundtrip");

            // Initial-condition options: fetch a handle, round-trip two settings, free it.
            let mut ic_options: *mut InitialConditionSolverOptions = ptr::null_mut();
            assert_eq!(diffsol_ode_get_ic_options(ode, &mut ic_options), DIFFSOL_OK);
            assert!(!ic_options.is_null());
            let mut use_linesearch = 0;
            let mut max_linesearch_iterations = 0usize;
            assert_eq!(
                diffsol_ic_options_get_use_linesearch(ic_options, &mut use_linesearch),
                DIFFSOL_OK
            );
            assert_eq!(
                diffsol_ic_options_set_use_linesearch(ic_options, 1),
                DIFFSOL_OK
            );
            assert_eq!(
                diffsol_ic_options_get_use_linesearch(ic_options, &mut use_linesearch),
                DIFFSOL_OK
            );
            assert_eq!(use_linesearch, 1);
            assert_eq!(
                diffsol_ic_options_set_max_linesearch_iterations(ic_options, 23),
                DIFFSOL_OK
            );
            assert_eq!(
                diffsol_ic_options_get_max_linesearch_iterations(
                    ic_options,
                    &mut max_linesearch_iterations
                ),
                DIFFSOL_OK
            );
            assert_eq!(max_linesearch_iterations, 23);
            diffsol_ic_options_free(ic_options);

            // ODE solver options: same pattern as the IC options above.
            let mut ode_options: *mut OdeSolverOptions = ptr::null_mut();
            assert_eq!(diffsol_ode_get_options(ode, &mut ode_options), DIFFSOL_OK);
            assert!(!ode_options.is_null());
            let mut max_nonlinear_iterations = 0usize;
            let mut min_timestep = 0.0;
            assert_eq!(
                diffsol_ode_options_set_max_nonlinear_solver_iterations(ode_options, 17),
                DIFFSOL_OK
            );
            assert_eq!(
                diffsol_ode_options_get_max_nonlinear_solver_iterations(
                    ode_options,
                    &mut max_nonlinear_iterations
                ),
                DIFFSOL_OK
            );
            assert_eq!(max_nonlinear_iterations, 17);
            assert_eq!(
                diffsol_ode_options_set_min_timestep(ode_options, 1e-4),
                DIFFSOL_OK
            );
            assert_eq!(
                diffsol_ode_options_get_min_timestep(ode_options, &mut min_timestep),
                DIFFSOL_OK
            );
            assert_close(min_timestep, 1e-4, ASSERT_TOL, "min_timestep roundtrip");
            diffsol_ode_options_free(ode_options);

            // Single parameter (growth rate 2.0), a probe state y and a direction v.
            let params = [2.0f64];
            let y = [0.25f64];
            let v = [3.0f64];

            let mut y0_ptr = ptr::null_mut();
            assert_eq!(
                diffsol_ode_y0(ode, params.as_ptr(), params.len(), &mut y0_ptr),
                DIFFSOL_OK
            );
            assert_eq!(ffi_read_host_array_vector(y0_ptr), vec![LOGISTIC_X0]);

            // Expected values are consistent with a logistic rhs r*y*(1 - y):
            // 2 * 0.25 * 0.75 = 0.375, and its Jacobian r*(1 - 2y) applied to v
            // gives 2 * 0.5 * 3 = 3.
            let mut rhs_ptr = ptr::null_mut();
            assert_eq!(
                diffsol_ode_rhs(
                    ode,
                    params.as_ptr(),
                    params.len(),
                    0.0,
                    y.as_ptr(),
                    y.len(),
                    &mut rhs_ptr,
                ),
                DIFFSOL_OK
            );
            assert_close(
                ffi_read_host_array_vector(rhs_ptr)[0],
                0.375,
                ASSERT_TOL,
                "ffi rhs",
            );

            let mut rhs_jac_mul_ptr = ptr::null_mut();
            assert_eq!(
                diffsol_ode_rhs_jac_mul(
                    ode,
                    params.as_ptr(),
                    params.len(),
                    0.0,
                    y.as_ptr(),
                    y.len(),
                    v.as_ptr(),
                    v.len(),
                    &mut rhs_jac_mul_ptr,
                ),
                DIFFSOL_OK
            );
            assert_close(
                ffi_read_host_array_vector(rhs_jac_mul_ptr)[0],
                3.0,
                ASSERT_TOL,
                "ffi rhs_jac_mul",
            );

            // Plain solve to a tiny final time; the last output matches the analytic state.
            let mut solve_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
            assert_eq!(
                diffsol_ode_solve(
                    ode,
                    params.as_ptr(),
                    params.len(),
                    1e-9,
                    &mut solve_solution_ptr
                ),
                DIFFSOL_OK
            );
            assert!(!solve_solution_ptr.is_null());

            let mut solve_ys_ptr = ptr::null_mut();
            let mut solve_ts_ptr = ptr::null_mut();
            assert_eq!(
                diffsol_solution_wrapper_get_ys(solve_solution_ptr, &mut solve_ys_ptr),
                DIFFSOL_OK
            );
            assert_eq!(
                diffsol_solution_wrapper_get_ts(solve_solution_ptr, &mut solve_ts_ptr),
                DIFFSOL_OK
            );
            let (solve_rows, solve_cols, solve_ys) = ffi_read_host_array_matrix(solve_ys_ptr);
            let solve_ts = ffi_read_host_array_vector(solve_ts_ptr);
            assert_eq!(solve_rows, 1);
            assert_eq!(solve_cols, solve_ts.len());
            assert!(!solve_ts.is_empty());
            assert_close(
                *solve_ts.last().unwrap(),
                1e-9,
                ASSERT_TOL,
                "ffi solve final time",
            );
            assert_close(
                *solve_ys.last().unwrap(),
                logistic_state(LOGISTIC_X0, 2.0, 1e-9),
                ASSERT_TOL,
                "ffi solve final value",
            );
            ffi_free_solution(solve_solution_ptr);

            // Dense solve with Tsit45 at explicit evaluation times.
            let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
            assert_eq!(
                diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
                DIFFSOL_OK
            );

            let t_eval = [0.25f64, 0.5f64, 1.0f64];
            assert_eq!(
                diffsol_ode_solve_dense(
                    ode,
                    params.as_ptr(),
                    params.len(),
                    t_eval.as_ptr(),
                    t_eval.len(),
                    &mut solution_ptr,
                ),
                DIFFSOL_OK
            );
            let mut ys_ptr = ptr::null_mut();
            let mut ts_ptr = ptr::null_mut();
            assert_eq!(
                diffsol_solution_wrapper_get_ys(solution_ptr, &mut ys_ptr),
                DIFFSOL_OK
            );
            assert_eq!(
                diffsol_solution_wrapper_get_ts(solution_ptr, &mut ts_ptr),
                DIFFSOL_OK
            );
            let (rows, cols, ys) = ffi_read_host_array_matrix(ys_ptr);
            let ts = ffi_read_host_array_vector(ts_ptr);
            assert_eq!(rows, 1);
            assert_eq!(cols, ts.len());
            let start = find_time_window(&ts, &t_eval, ASSERT_TOL);
            for (i, &t) in t_eval.iter().enumerate() {
                assert_close(ts[start + i], t, ASSERT_TOL, "ffi solution time");
                // Use the same initial state the model reports from `diffsol_ode_y0`
                // rather than a hard-coded literal, so this check cannot drift from it.
                assert_close(
                    ys[start + i],
                    logistic_state(LOGISTIC_X0, 2.0, t),
                    5e-4,
                    "ffi solution value",
                );
            }
            assert_eq!(
                diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Bdf)),
                DIFFSOL_OK
            );

            // Hybrid dense solve on a fresh ODE. The first two outputs follow the
            // analytic trajectory; from index 2 onward the state is expected to be 1.0.
            let hybrid_t_eval = [0.5f64, 1.0, 1.25, 1.5, 2.0];
            let hybrid_ode = make_ode_ptr(
                matrix_type_to_i32(MatrixType::NalgebraDense),
                linear_solver_to_i32(LinearSolverType::Default),
                ode_solver_to_i32(OdeSolverType::Bdf),
            );
            assert!(!hybrid_ode.is_null());
            let mut hybrid_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
            assert_eq!(
                diffsol_ode_solve_hybrid_dense(
                    hybrid_ode,
                    params.as_ptr(),
                    params.len(),
                    hybrid_t_eval.as_ptr(),
                    hybrid_t_eval.len(),
                    &mut hybrid_solution_ptr,
                ),
                DIFFSOL_OK
            );
            let mut hybrid_ys_ptr = ptr::null_mut();
            let mut hybrid_ts_ptr = ptr::null_mut();
            assert_eq!(
                diffsol_solution_wrapper_get_ys(hybrid_solution_ptr, &mut hybrid_ys_ptr),
                DIFFSOL_OK
            );
            assert_eq!(
                diffsol_solution_wrapper_get_ts(hybrid_solution_ptr, &mut hybrid_ts_ptr),
                DIFFSOL_OK
            );
            let (hybrid_rows, hybrid_cols, hybrid_ys) = ffi_read_host_array_matrix(hybrid_ys_ptr);
            let hybrid_ts = ffi_read_host_array_vector(hybrid_ts_ptr);
            assert_eq!(hybrid_rows, 1);
            assert_eq!(hybrid_cols, hybrid_t_eval.len());
            assert_eq!(hybrid_ts, hybrid_t_eval);
            assert_close(
                hybrid_ys[0],
                logistic_state(LOGISTIC_X0, 2.0, hybrid_t_eval[0]),
                5e-4,
                "ffi hybrid dense pre-root value",
            );
            assert_close(
                hybrid_ys[1],
                logistic_state(LOGISTIC_X0, 2.0, hybrid_t_eval[1]),
                5e-4,
                "ffi hybrid dense near-root value",
            );
            for (i, value) in hybrid_ys.iter().enumerate().skip(2) {
                assert_close(
                    *value,
                    1.0,
                    5e-4,
                    &format!("ffi hybrid dense post-root value[{i}]"),
                );
            }
            ffi_free_solution(hybrid_solution_ptr);
            diffsol_ode_free(hybrid_ode);

            // Sensitivity and adjoint analyses on another fresh ODE.
            let analysis_ode = make_ode_ptr(
                matrix_type_to_i32(MatrixType::NalgebraDense),
                linear_solver_to_i32(LinearSolverType::Default),
                ode_solver_to_i32(OdeSolverType::Bdf),
            );
            assert!(!analysis_ode.is_null());

            let mut sens_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
            assert_eq!(
                diffsol_ode_solve_fwd_sens(
                    analysis_ode,
                    params.as_ptr(),
                    params.len(),
                    t_eval.as_ptr(),
                    t_eval.len(),
                    &mut sens_solution_ptr,
                ),
                DIFFSOL_OK
            );
            let mut sens_list = ptr::null_mut();
            let mut sens_len = 0usize;
            assert_eq!(
                diffsol_solution_wrapper_get_sens(sens_solution_ptr, &mut sens_list, &mut sens_len),
                DIFFSOL_OK
            );
            let sens_values = ffi_read_host_array_list_matrices(sens_list, sens_len);
            assert_eq!(sens_values.len(), 1);
            assert_eq!(sens_values[0].0, 1);
            assert_eq!(sens_values[0].1, t_eval.len());
            for (i, (&value, &t)) in sens_values[0].2.iter().zip(t_eval.iter()).enumerate() {
                assert_close(
                    value,
                    logistic_state_dr(LOGISTIC_X0, 2.0, t),
                    ASSERT_TOL,
                    &format!("ffi sensitivity[{i}]"),
                );
            }

            // The "observed" data is the model's own trajectory at the same
            // parameters, so the sum-of-squares objective and its gradient are zero.
            let adjoint_t_eval = [0.0f64, 0.25f64, 0.5f64, 1.0f64];
            let adjoint_data: Vec<f64> = adjoint_t_eval
                .iter()
                .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
                .collect();
            let mut objective = 0.0;
            let mut adjoint_grad_ptr = ptr::null_mut();
            assert_eq!(
                diffsol_ode_solve_sum_squares_adj(
                    analysis_ode,
                    params.as_ptr(),
                    params.len(),
                    adjoint_data.as_ptr(),
                    1,
                    adjoint_t_eval.len(),
                    1,
                    1,
                    adjoint_t_eval.as_ptr(),
                    adjoint_t_eval.len(),
                    &mut objective,
                    &mut adjoint_grad_ptr,
                ),
                DIFFSOL_OK
            );
            assert_close(objective, 0.0, ASSERT_TOL, "ffi adjoint objective");
            let grad = ffi_read_host_array_vector(adjoint_grad_ptr);
            assert_eq!(grad.len(), 1);
            assert_close(grad[0], 0.0, ASSERT_TOL, "ffi adjoint gradient");

            ffi_free_solution(sens_solution_ptr);
            diffsol_ode_free(analysis_ode);
            ffi_free_solution(solution_ptr);
            diffsol_ode_free(ode);
        }
    }
}
1417
1418#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
1419mod jit_tests {
1420    use std::ffi::{CStr, CString};
1421    use std::ptr;
1422
1423    use crate::error_c::{diffsol_error_code, diffsol_last_error_message};
1424    use crate::initial_condition_options_c::diffsol_ic_options_free;
1425    use crate::jit::JitBackendType;
1426    use crate::jit_c::jit_backend_to_i32;
1427    use crate::linear_solver_type::LinearSolverType;
1428    use crate::linear_solver_type_c::linear_solver_to_i32;
1429    use crate::matrix_type::MatrixType;
1430    use crate::matrix_type_c::matrix_type_to_i32;
1431    use crate::ode_options_c::diffsol_ode_options_free;
1432    use crate::ode_solver_type::OdeSolverType;
1433    use crate::ode_solver_type_c::ode_solver_to_i32;
1434    #[cfg(feature = "diffsl-llvm")]
1435    use crate::solution_wrapper_c::diffsol_solution_wrapper_get_sens;
1436    use crate::solution_wrapper_c::{
1437        diffsol_solution_wrapper_get_ts, diffsol_solution_wrapper_get_ys,
1438    };
1439    #[cfg(feature = "diffsl-llvm")]
1440    use crate::test_support::ffi_read_host_array_list_matrices;
1441    use crate::test_support::{
1442        assert_close, available_jit_backends, clear_last_error, ffi_free_solution,
1443        ffi_read_host_array_matrix, ffi_read_host_array_vector, find_time_window,
1444        hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code_cstring,
1445        logistic_state, ASSERT_TOL, LOGISTIC_X0,
1446    };
1447    #[cfg(feature = "diffsl-llvm")]
1448    use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
1449
1450    use super::*;
1451
1452    unsafe fn make_ode_ptr(
1453        jit_backend: JitBackendType,
1454        matrix_type: i32,
1455        linear_solver: i32,
1456        ode_solver: i32,
1457    ) -> *mut OdeWrapper {
1458        let code = logistic_diffsl_code_cstring();
1459        unsafe {
1460            make_ode_ptr_with_code(
1461                jit_backend,
1462                code.as_ptr(),
1463                matrix_type,
1464                linear_solver,
1465                ode_solver,
1466            )
1467        }
1468    }
1469
1470    unsafe fn make_ode_ptr_with_code(
1471        jit_backend: JitBackendType,
1472        code: *const std::os::raw::c_char,
1473        matrix_type: i32,
1474        linear_solver: i32,
1475        ode_solver: i32,
1476    ) -> *mut OdeWrapper {
1477        unsafe {
1478            diffsol_ode_new_jit(
1479                code,
1480                jit_backend_to_i32(jit_backend),
1481                matrix_type,
1482                linear_solver,
1483                ode_solver,
1484            )
1485        }
1486    }
1487
1488    unsafe fn last_error_message() -> String {
1489        let ptr = unsafe { diffsol_last_error_message() };
1490        assert_eq!(unsafe { diffsol_error_code() }, 1);
1491        assert!(!ptr.is_null());
1492        unsafe { CStr::from_ptr(ptr) }.to_str().unwrap().to_owned()
1493    }
1494
1495    #[test]
1496    fn c_api_full_lifecycle_matches_jit_logistic_model() {
1497        clear_last_error();
1498        for jit_backend in available_jit_backends() {
1499            unsafe {
1500                let ode = make_ode_ptr(
1501                    jit_backend,
1502                    matrix_type_to_i32(MatrixType::NalgebraDense),
1503                    linear_solver_to_i32(LinearSolverType::Default),
1504                    ode_solver_to_i32(OdeSolverType::Bdf),
1505                );
1506                assert!(!ode.is_null());
1507
1508                assert_eq!(
1509                    diffsol_ode_get_matrix_type(ode),
1510                    matrix_type_to_i32(MatrixType::NalgebraDense)
1511                );
1512                assert_eq!(
1513                    diffsol_ode_get_ode_solver(ode),
1514                    ode_solver_to_i32(OdeSolverType::Bdf)
1515                );
1516                assert_eq!(
1517                    diffsol_ode_get_linear_solver(ode),
1518                    linear_solver_to_i32(LinearSolverType::Default)
1519                );
1520
1521                let params = [2.0f64];
1522                let y = [0.25f64];
1523                let v = [3.0f64];
1524
1525                let mut y0_ptr = ptr::null_mut();
1526                assert_eq!(
1527                    diffsol_ode_y0(ode, params.as_ptr(), params.len(), &mut y0_ptr),
1528                    DIFFSOL_OK
1529                );
1530                assert_eq!(ffi_read_host_array_vector(y0_ptr), vec![LOGISTIC_X0]);
1531
1532                let mut rhs_ptr = ptr::null_mut();
1533                assert_eq!(
1534                    diffsol_ode_rhs(
1535                        ode,
1536                        params.as_ptr(),
1537                        params.len(),
1538                        0.0,
1539                        y.as_ptr(),
1540                        y.len(),
1541                        &mut rhs_ptr,
1542                    ),
1543                    DIFFSOL_OK
1544                );
1545                assert_close(
1546                    ffi_read_host_array_vector(rhs_ptr)[0],
1547                    0.375,
1548                    ASSERT_TOL,
1549                    "jit ffi rhs",
1550                );
1551
1552                let mut rhs_jac_mul_ptr = ptr::null_mut();
1553                assert_eq!(
1554                    diffsol_ode_rhs_jac_mul(
1555                        ode,
1556                        params.as_ptr(),
1557                        params.len(),
1558                        0.0,
1559                        y.as_ptr(),
1560                        y.len(),
1561                        v.as_ptr(),
1562                        v.len(),
1563                        &mut rhs_jac_mul_ptr,
1564                    ),
1565                    DIFFSOL_OK
1566                );
1567                assert_close(
1568                    ffi_read_host_array_vector(rhs_jac_mul_ptr)[0],
1569                    3.0,
1570                    ASSERT_TOL,
1571                    "jit ffi rhs_jac_mul",
1572                );
1573
1574                let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1575                let t_eval = [0.25f64, 0.5f64, 1.0f64];
1576                assert_eq!(
1577                    diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
1578                    DIFFSOL_OK
1579                );
1580                assert_eq!(
1581                    diffsol_ode_solve_dense(
1582                        ode,
1583                        params.as_ptr(),
1584                        params.len(),
1585                        t_eval.as_ptr(),
1586                        t_eval.len(),
1587                        &mut solution_ptr,
1588                    ),
1589                    DIFFSOL_OK
1590                );
1591                let mut ys_ptr = ptr::null_mut();
1592                let mut ts_ptr = ptr::null_mut();
1593                assert_eq!(
1594                    diffsol_solution_wrapper_get_ys(solution_ptr, &mut ys_ptr),
1595                    DIFFSOL_OK
1596                );
1597                assert_eq!(
1598                    diffsol_solution_wrapper_get_ts(solution_ptr, &mut ts_ptr),
1599                    DIFFSOL_OK
1600                );
1601                let (rows, cols, ys) = ffi_read_host_array_matrix(ys_ptr);
1602                let ts = ffi_read_host_array_vector(ts_ptr);
1603                assert_eq!(rows, 1);
1604                assert_eq!(cols, ts.len());
1605                let start = find_time_window(&ts, &t_eval, ASSERT_TOL);
1606                for (i, &t) in t_eval.iter().enumerate() {
1607                    assert_close(ts[start + i], t, ASSERT_TOL, "jit ffi solution time");
1608                    assert_close(
1609                        ys[start + i],
1610                        logistic_state(LOGISTIC_X0, 2.0, t),
1611                        5e-4,
1612                        "jit ffi solution value",
1613                    );
1614                }
1615                assert_eq!(
1616                    diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Bdf)),
1617                    DIFFSOL_OK
1618                );
1619
1620                #[cfg(feature = "diffsl-llvm")]
1621                {
1622                    let analysis_code = logistic_diffsl_code_cstring();
1623                    let analysis_ode = make_ode_ptr_with_code(
1624                        JitBackendType::Llvm,
1625                        analysis_code.as_ptr(),
1626                        matrix_type_to_i32(MatrixType::NalgebraDense),
1627                        linear_solver_to_i32(LinearSolverType::Default),
1628                        ode_solver_to_i32(OdeSolverType::Bdf),
1629                    );
1630                    assert!(!analysis_ode.is_null());
1631
1632                    let mut sens_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1633                    assert_eq!(
1634                        diffsol_ode_solve_fwd_sens(
1635                            analysis_ode,
1636                            params.as_ptr(),
1637                            params.len(),
1638                            t_eval.as_ptr(),
1639                            t_eval.len(),
1640                            &mut sens_solution_ptr,
1641                        ),
1642                        DIFFSOL_OK
1643                    );
1644                    let mut sens_list = ptr::null_mut();
1645                    let mut sens_len = 0usize;
1646                    assert_eq!(
1647                        diffsol_solution_wrapper_get_sens(
1648                            sens_solution_ptr,
1649                            &mut sens_list,
1650                            &mut sens_len
1651                        ),
1652                        DIFFSOL_OK
1653                    );
1654                    let sens_values = ffi_read_host_array_list_matrices(sens_list, sens_len);
1655                    assert_eq!(sens_values.len(), 1);
1656                    assert_eq!(sens_values[0].0, 1);
1657                    assert_eq!(sens_values[0].1, t_eval.len());
1658                    for (i, (&value, &t)) in sens_values[0].2.iter().zip(t_eval.iter()).enumerate()
1659                    {
1660                        assert_close(
1661                            value,
1662                            logistic_state_dr(LOGISTIC_X0, 2.0, t),
1663                            ASSERT_TOL,
1664                            &format!("jit ffi sensitivity[{i}]"),
1665                        );
1666                    }
1667
1668                    let adjoint_t_eval = [0.0f64, 0.25f64, 0.5f64, 1.0f64];
1669                    let adjoint_data: Vec<f64> = adjoint_t_eval
1670                        .iter()
1671                        .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
1672                        .collect();
1673                    let mut objective = 0.0;
1674                    let mut adjoint_grad_ptr = ptr::null_mut();
1675                    assert_eq!(
1676                        diffsol_ode_solve_sum_squares_adj(
1677                            analysis_ode,
1678                            params.as_ptr(),
1679                            params.len(),
1680                            adjoint_data.as_ptr(),
1681                            1,
1682                            adjoint_t_eval.len(),
1683                            1,
1684                            1,
1685                            adjoint_t_eval.as_ptr(),
1686                            adjoint_t_eval.len(),
1687                            &mut objective,
1688                            &mut adjoint_grad_ptr,
1689                        ),
1690                        DIFFSOL_OK
1691                    );
1692                    assert_close(objective, 0.0, ASSERT_TOL, "jit ffi adjoint objective");
1693                    let grad = ffi_read_host_array_vector(adjoint_grad_ptr);
1694                    assert_eq!(grad.len(), 1);
1695                    assert!(
1696                        grad[0].is_finite(),
1697                        "jit ffi adjoint gradient should be finite"
1698                    );
1699
1700                    ffi_free_solution(sens_solution_ptr);
1701                    diffsol_ode_free(analysis_ode);
1702                }
1703                ffi_free_solution(solution_ptr);
1704                diffsol_ode_free(ode);
1705            }
1706        }
1707    }
1708
1709    #[test]
1710    fn c_api_rejects_invalid_jit_arguments() {
1711        unsafe {
1712            clear_last_error();
1713            assert!(diffsol_ode_new_jit(
1714                ptr::null(),
1715                jit_backend_to_i32(available_jit_backends()[0]),
1716                matrix_type_to_i32(MatrixType::NalgebraDense),
1717                linear_solver_to_i32(LinearSolverType::Default),
1718                ode_solver_to_i32(OdeSolverType::Bdf),
1719            )
1720            .is_null());
1721            assert!(last_error_message().contains("code is null"));
1722
1723            clear_last_error();
1724            let invalid_utf8 = CString::from_vec_with_nul(vec![0xff, 0]).unwrap();
1725            assert!(diffsol_ode_new_jit(
1726                invalid_utf8.as_ptr(),
1727                jit_backend_to_i32(available_jit_backends()[0]),
1728                matrix_type_to_i32(MatrixType::NalgebraDense),
1729                linear_solver_to_i32(LinearSolverType::Default),
1730                ode_solver_to_i32(OdeSolverType::Bdf),
1731            )
1732            .is_null());
1733            assert!(last_error_message().contains("valid UTF-8"));
1734
1735            clear_last_error();
1736            let code = logistic_diffsl_code_cstring();
1737            assert!(diffsol_ode_new_jit(
1738                code.as_ptr(),
1739                99,
1740                matrix_type_to_i32(MatrixType::NalgebraDense),
1741                linear_solver_to_i32(LinearSolverType::Default),
1742                ode_solver_to_i32(OdeSolverType::Bdf),
1743            )
1744            .is_null());
1745            assert!(last_error_message().contains("invalid jit_backend_type"));
1746
1747            clear_last_error();
1748            assert!(diffsol_ode_new_jit(
1749                code.as_ptr(),
1750                jit_backend_to_i32(available_jit_backends()[0]),
1751                99,
1752                linear_solver_to_i32(LinearSolverType::Default),
1753                ode_solver_to_i32(OdeSolverType::Bdf),
1754            )
1755            .is_null());
1756            assert!(last_error_message().contains("invalid matrix_type"));
1757
1758            clear_last_error();
1759            assert!(diffsol_ode_new_jit(
1760                code.as_ptr(),
1761                jit_backend_to_i32(available_jit_backends()[0]),
1762                matrix_type_to_i32(MatrixType::NalgebraDense),
1763                99,
1764                ode_solver_to_i32(OdeSolverType::Bdf),
1765            )
1766            .is_null());
1767            assert!(last_error_message().contains("invalid linear_solver"));
1768
1769            clear_last_error();
1770            assert!(diffsol_ode_new_jit(
1771                code.as_ptr(),
1772                jit_backend_to_i32(available_jit_backends()[0]),
1773                matrix_type_to_i32(MatrixType::NalgebraDense),
1774                linear_solver_to_i32(LinearSolverType::Default),
1775                99,
1776            )
1777            .is_null());
1778            assert!(last_error_message().contains("invalid ode_solver"));
1779
1780            clear_last_error();
1781            let invalid_code = CString::new("not valid diffsl").unwrap();
1782            assert!(diffsol_ode_new_jit(
1783                invalid_code.as_ptr(),
1784                jit_backend_to_i32(available_jit_backends()[0]),
1785                matrix_type_to_i32(MatrixType::NalgebraDense),
1786                linear_solver_to_i32(LinearSolverType::Default),
1787                ode_solver_to_i32(OdeSolverType::Bdf),
1788            )
1789            .is_null());
1790            assert!(diffsol_error_code() != 0);
1791
1792            let mut ic_options = ptr::null_mut();
1793            assert_eq!(
1794                diffsol_ode_get_ic_options(ptr::null_mut(), &mut ic_options),
1795                DIFFSOL_BAD_ARG
1796            );
1797            let mut ode_options = ptr::null_mut();
1798            assert_eq!(
1799                diffsol_ode_get_options(ptr::null_mut(), &mut ode_options),
1800                DIFFSOL_BAD_ARG
1801            );
1802
1803            let mut out_array = ptr::null_mut();
1804            assert_eq!(
1805                diffsol_ode_y0(ptr::null_mut(), ptr::null(), 0, &mut out_array),
1806                DIFFSOL_BAD_ARG
1807            );
1808            assert_eq!(
1809                diffsol_ode_rhs(
1810                    ptr::null_mut(),
1811                    ptr::null(),
1812                    0,
1813                    0.0,
1814                    ptr::null(),
1815                    0,
1816                    &mut out_array,
1817                ),
1818                DIFFSOL_BAD_ARG
1819            );
1820            assert_eq!(
1821                diffsol_ode_rhs_jac_mul(
1822                    ptr::null_mut(),
1823                    ptr::null(),
1824                    0,
1825                    0.0,
1826                    ptr::null(),
1827                    0,
1828                    ptr::null(),
1829                    0,
1830                    &mut out_array,
1831                ),
1832                DIFFSOL_BAD_ARG
1833            );
1834
1835            clear_last_error();
1836            diffsol_ode_free(ptr::null_mut());
1837            assert!(last_error_message().contains("ode is null"));
1838
1839            clear_last_error();
1840            diffsol_host_array_list_free(ptr::null_mut(), 0);
1841            assert!(last_error_message().contains("host array list is null"));
1842        }
1843    }
1844
1845    #[test]
1846    fn c_api_jit_wrapper_branches_cover_runtime_success_and_errors() {
1847        for jit_backend in available_jit_backends() {
1848            unsafe {
1849                let ode = make_ode_ptr(
1850                    jit_backend,
1851                    matrix_type_to_i32(MatrixType::NalgebraDense),
1852                    linear_solver_to_i32(LinearSolverType::Default),
1853                    ode_solver_to_i32(OdeSolverType::Bdf),
1854                );
1855                assert!(!ode.is_null());
1856
1857                let mut ic_options = ptr::null_mut();
1858                let mut ode_options = ptr::null_mut();
1859                assert_eq!(diffsol_ode_get_ic_options(ode, &mut ic_options), DIFFSOL_OK);
1860                assert_eq!(diffsol_ode_get_options(ode, &mut ode_options), DIFFSOL_OK);
1861                diffsol_ic_options_free(ic_options);
1862                diffsol_ode_options_free(ode_options);
1863
1864                let mut out_value = 0.0;
1865                assert_eq!(diffsol_ode_get_rtol(ode, &mut out_value), DIFFSOL_OK);
1866                assert_close(out_value, 1e-6, ASSERT_TOL, "jit ffi default rtol");
1867                assert_eq!(diffsol_ode_set_rtol(ode, 1e-4), DIFFSOL_OK);
1868                assert_eq!(diffsol_ode_get_rtol(ode, &mut out_value), DIFFSOL_OK);
1869                assert_close(out_value, 1e-4, ASSERT_TOL, "jit ffi updated rtol");
1870
1871                assert_eq!(diffsol_ode_get_atol(ode, &mut out_value), DIFFSOL_OK);
1872                assert_close(out_value, 1e-6, ASSERT_TOL, "jit ffi default atol");
1873                assert_eq!(diffsol_ode_set_atol(ode, 1e-5), DIFFSOL_OK);
1874                assert_eq!(diffsol_ode_get_atol(ode, &mut out_value), DIFFSOL_OK);
1875                assert_close(out_value, 1e-5, ASSERT_TOL, "jit ffi updated atol");
1876
1877                assert_eq!(
1878                    diffsol_ode_set_linear_solver(ode, linear_solver_to_i32(LinearSolverType::Lu)),
1879                    DIFFSOL_OK
1880                );
1881                assert_eq!(
1882                    diffsol_ode_get_linear_solver(ode),
1883                    linear_solver_to_i32(LinearSolverType::Lu)
1884                );
1885                assert_eq!(
1886                    diffsol_ode_set_ode_solver(ode, ode_solver_to_i32(OdeSolverType::Tsit45)),
1887                    DIFFSOL_OK
1888                );
1889                assert_eq!(
1890                    diffsol_ode_get_ode_solver(ode),
1891                    ode_solver_to_i32(OdeSolverType::Tsit45)
1892                );
1893                assert_eq!(
1894                    diffsol_ode_get_matrix_type(ode),
1895                    matrix_type_to_i32(MatrixType::NalgebraDense)
1896                );
1897
1898                let params = [2.0f64];
1899                let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1900                assert_eq!(
1901                    diffsol_ode_solve(ode, params.as_ptr(), params.len(), 1.0, &mut solution_ptr),
1902                    DIFFSOL_OK
1903                );
1904                ffi_free_solution(solution_ptr);
1905
1906                let t_eval = [0.25f64, 0.5f64, 1.0f64];
1907                let mut dense_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1908                assert_eq!(
1909                    diffsol_ode_solve_dense(
1910                        ode,
1911                        params.as_ptr(),
1912                        params.len(),
1913                        t_eval.as_ptr(),
1914                        t_eval.len(),
1915                        &mut dense_solution_ptr,
1916                    ),
1917                    DIFFSOL_OK
1918                );
1919                ffi_free_solution(dense_solution_ptr);
1920
1921                let no_params: [f64; 0] = [];
1922                let y = [0.25f64];
1923                let v = [3.0f64];
1924                let mut out_array = ptr::null_mut();
1925                assert_eq!(
1926                    diffsol_ode_y0(ode, no_params.as_ptr(), no_params.len(), &mut out_array),
1927                    DIFFSOL_ERR
1928                );
1929                assert_eq!(
1930                    diffsol_ode_rhs(
1931                        ode,
1932                        no_params.as_ptr(),
1933                        no_params.len(),
1934                        0.0,
1935                        y.as_ptr(),
1936                        y.len(),
1937                        &mut out_array,
1938                    ),
1939                    DIFFSOL_ERR
1940                );
1941                assert_eq!(
1942                    diffsol_ode_rhs_jac_mul(
1943                        ode,
1944                        no_params.as_ptr(),
1945                        no_params.len(),
1946                        0.0,
1947                        y.as_ptr(),
1948                        y.len(),
1949                        v.as_ptr(),
1950                        v.len(),
1951                        &mut out_array,
1952                    ),
1953                    DIFFSOL_ERR
1954                );
1955
1956                let mut err_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
1957                assert_eq!(
1958                    diffsol_ode_solve(
1959                        ode,
1960                        no_params.as_ptr(),
1961                        no_params.len(),
1962                        1.0,
1963                        &mut err_solution_ptr,
1964                    ),
1965                    DIFFSOL_ERR
1966                );
1967                assert_eq!(
1968                    diffsol_ode_solve_hybrid(
1969                        ode,
1970                        no_params.as_ptr(),
1971                        no_params.len(),
1972                        1.0,
1973                        &mut err_solution_ptr,
1974                    ),
1975                    DIFFSOL_ERR
1976                );
1977                assert_eq!(
1978                    diffsol_ode_solve_dense(
1979                        ode,
1980                        no_params.as_ptr(),
1981                        no_params.len(),
1982                        t_eval.as_ptr(),
1983                        t_eval.len(),
1984                        &mut err_solution_ptr,
1985                    ),
1986                    DIFFSOL_ERR
1987                );
1988                assert_eq!(
1989                    diffsol_ode_solve_hybrid_dense(
1990                        ode,
1991                        no_params.as_ptr(),
1992                        no_params.len(),
1993                        t_eval.as_ptr(),
1994                        t_eval.len(),
1995                        &mut err_solution_ptr,
1996                    ),
1997                    DIFFSOL_ERR
1998                );
1999
2000                #[cfg(feature = "diffsl-llvm")]
2001                if matches!(jit_backend, JitBackendType::Llvm) {
2002                    assert_eq!(
2003                        diffsol_ode_solve_fwd_sens(
2004                            ode,
2005                            no_params.as_ptr(),
2006                            no_params.len(),
2007                            t_eval.as_ptr(),
2008                            t_eval.len(),
2009                            &mut err_solution_ptr,
2010                        ),
2011                        DIFFSOL_ERR
2012                    );
2013                    assert_eq!(
2014                        diffsol_ode_solve_hybrid_fwd_sens(
2015                            ode,
2016                            no_params.as_ptr(),
2017                            no_params.len(),
2018                            t_eval.as_ptr(),
2019                            t_eval.len(),
2020                            &mut err_solution_ptr,
2021                        ),
2022                        DIFFSOL_ERR
2023                    );
2024
2025                    let adjoint_data: Vec<f64> = t_eval
2026                        .iter()
2027                        .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
2028                        .collect();
2029                    let mut objective = 0.0;
2030                    let mut sens_ptr = ptr::null_mut();
2031                    assert_eq!(
2032                        diffsol_ode_solve_sum_squares_adj(
2033                            ode,
2034                            no_params.as_ptr(),
2035                            no_params.len(),
2036                            adjoint_data.as_ptr(),
2037                            1,
2038                            t_eval.len(),
2039                            1,
2040                            1,
2041                            t_eval.as_ptr(),
2042                            t_eval.len(),
2043                            &mut objective,
2044                            &mut sens_ptr,
2045                        ),
2046                        DIFFSOL_ERR
2047                    );
2048                }
2049
2050                assert_eq!(diffsol_ode_get_matrix_type(ptr::null()), -1);
2051                assert_eq!(diffsol_ode_get_ode_solver(ptr::null()), -1);
2052                assert_eq!(diffsol_ode_get_linear_solver(ptr::null()), -1);
2053                assert_eq!(
2054                    diffsol_ode_set_ode_solver(ptr::null_mut(), 0),
2055                    DIFFSOL_BAD_ARG
2056                );
2057                assert_eq!(
2058                    diffsol_ode_set_linear_solver(ptr::null_mut(), 0),
2059                    DIFFSOL_BAD_ARG
2060                );
2061                assert_eq!(diffsol_ode_set_ode_solver(ode, 99), DIFFSOL_BAD_ARG);
2062                assert_eq!(diffsol_ode_set_linear_solver(ode, 99), DIFFSOL_BAD_ARG);
2063                assert_eq!(
2064                    diffsol_ode_get_rtol(ptr::null(), &mut out_value),
2065                    DIFFSOL_BAD_ARG
2066                );
2067                assert_eq!(diffsol_ode_get_rtol(ode, ptr::null_mut()), DIFFSOL_BAD_ARG);
2068                assert_eq!(diffsol_ode_set_rtol(ptr::null_mut(), 1e-3), DIFFSOL_BAD_ARG);
2069                assert_eq!(
2070                    diffsol_ode_get_atol(ptr::null(), &mut out_value),
2071                    DIFFSOL_BAD_ARG
2072                );
2073                assert_eq!(diffsol_ode_get_atol(ode, ptr::null_mut()), DIFFSOL_BAD_ARG);
2074                assert_eq!(diffsol_ode_set_atol(ptr::null_mut(), 1e-3), DIFFSOL_BAD_ARG);
2075                assert_eq!(
2076                    diffsol_ode_solve(ode, params.as_ptr(), params.len(), 1.0, ptr::null_mut()),
2077                    DIFFSOL_BAD_ARG
2078                );
2079                assert_eq!(
2080                    diffsol_ode_solve_hybrid(
2081                        ode,
2082                        params.as_ptr(),
2083                        params.len(),
2084                        1.0,
2085                        ptr::null_mut(),
2086                    ),
2087                    DIFFSOL_BAD_ARG
2088                );
2089                assert_eq!(
2090                    diffsol_ode_solve_dense(
2091                        ode,
2092                        params.as_ptr(),
2093                        params.len(),
2094                        t_eval.as_ptr(),
2095                        t_eval.len(),
2096                        ptr::null_mut(),
2097                    ),
2098                    DIFFSOL_BAD_ARG
2099                );
2100                assert_eq!(
2101                    diffsol_ode_solve_hybrid_dense(
2102                        ode,
2103                        params.as_ptr(),
2104                        params.len(),
2105                        t_eval.as_ptr(),
2106                        t_eval.len(),
2107                        ptr::null_mut(),
2108                    ),
2109                    DIFFSOL_BAD_ARG
2110                );
2111                #[cfg(feature = "diffsl-llvm")]
2112                if matches!(jit_backend, JitBackendType::Llvm) {
2113                    assert_eq!(
2114                        diffsol_ode_solve_fwd_sens(
2115                            ode,
2116                            params.as_ptr(),
2117                            params.len(),
2118                            t_eval.as_ptr(),
2119                            t_eval.len(),
2120                            ptr::null_mut(),
2121                        ),
2122                        DIFFSOL_BAD_ARG
2123                    );
2124                    assert_eq!(
2125                        diffsol_ode_solve_hybrid_fwd_sens(
2126                            ode,
2127                            params.as_ptr(),
2128                            params.len(),
2129                            t_eval.as_ptr(),
2130                            t_eval.len(),
2131                            ptr::null_mut(),
2132                        ),
2133                        DIFFSOL_BAD_ARG
2134                    );
2135                    let mut objective = 0.0;
2136                    let mut sens_ptr = ptr::null_mut();
2137                    assert_eq!(
2138                        diffsol_ode_solve_sum_squares_adj(
2139                            ode,
2140                            params.as_ptr(),
2141                            params.len(),
2142                            t_eval.as_ptr(),
2143                            1,
2144                            t_eval.len(),
2145                            1,
2146                            1,
2147                            t_eval.as_ptr(),
2148                            t_eval.len(),
2149                            ptr::null_mut(),
2150                            &mut sens_ptr,
2151                        ),
2152                        DIFFSOL_BAD_ARG
2153                    );
2154                    assert_eq!(
2155                        diffsol_ode_solve_sum_squares_adj(
2156                            ode,
2157                            params.as_ptr(),
2158                            params.len(),
2159                            t_eval.as_ptr(),
2160                            1,
2161                            t_eval.len(),
2162                            1,
2163                            1,
2164                            t_eval.as_ptr(),
2165                            t_eval.len(),
2166                            &mut objective,
2167                            ptr::null_mut(),
2168                        ),
2169                        DIFFSOL_BAD_ARG
2170                    );
2171                }
2172
2173                diffsol_ode_free(ode);
2174            }
2175        }
2176    }
2177
2178    #[test]
2179    fn c_api_hybrid_jit_solver_paths_match_expected_values() {
2180        for jit_backend in available_jit_backends() {
2181            unsafe {
2182                let code = CString::new(hybrid_logistic_diffsl_code()).unwrap();
2183                let ode = make_ode_ptr_with_code(
2184                    jit_backend,
2185                    code.as_ptr(),
2186                    matrix_type_to_i32(MatrixType::NalgebraDense),
2187                    linear_solver_to_i32(LinearSolverType::Default),
2188                    ode_solver_to_i32(OdeSolverType::Bdf),
2189                );
2190                assert!(!ode.is_null());
2191
2192                let params = [2.0f64];
2193                let mut solution_ptr: *mut SolutionWrapper = ptr::null_mut();
2194                assert_eq!(
2195                    diffsol_ode_solve_hybrid(
2196                        ode,
2197                        params.as_ptr(),
2198                        params.len(),
2199                        2.0,
2200                        &mut solution_ptr
2201                    ),
2202                    DIFFSOL_OK
2203                );
2204                let mut ys_ptr = ptr::null_mut();
2205                let mut ts_ptr = ptr::null_mut();
2206                assert_eq!(
2207                    diffsol_solution_wrapper_get_ys(solution_ptr, &mut ys_ptr),
2208                    DIFFSOL_OK
2209                );
2210                assert_eq!(
2211                    diffsol_solution_wrapper_get_ts(solution_ptr, &mut ts_ptr),
2212                    DIFFSOL_OK
2213                );
2214                let (_rows, cols, ys) = ffi_read_host_array_matrix(ys_ptr);
2215                let ts = ffi_read_host_array_vector(ts_ptr);
2216                assert!(cols >= 1);
2217                assert_close(*ts.last().unwrap(), 2.0, 5e-4, "jit hybrid solve time");
2218                assert_close(
2219                    *ys.last().unwrap(),
2220                    hybrid_logistic_state(2.0, 2.0),
2221                    5e-4,
2222                    "jit hybrid solve value",
2223                );
2224                ffi_free_solution(solution_ptr);
2225
2226                #[cfg(feature = "diffsl-llvm")]
2227                if matches!(jit_backend, JitBackendType::Llvm) {
2228                    let t_eval = [0.25f64, 0.5f64, 1.0f64];
2229                    let mut sens_solution_ptr: *mut SolutionWrapper = ptr::null_mut();
2230                    assert_eq!(
2231                        diffsol_ode_solve_hybrid_fwd_sens(
2232                            ode,
2233                            params.as_ptr(),
2234                            params.len(),
2235                            t_eval.as_ptr(),
2236                            t_eval.len(),
2237                            &mut sens_solution_ptr,
2238                        ),
2239                        DIFFSOL_OK
2240                    );
2241                    let mut sens_list = ptr::null_mut();
2242                    let mut sens_len = 0usize;
2243                    assert_eq!(
2244                        diffsol_solution_wrapper_get_sens(
2245                            sens_solution_ptr,
2246                            &mut sens_list,
2247                            &mut sens_len
2248                        ),
2249                        DIFFSOL_OK
2250                    );
2251                    let sens_values = ffi_read_host_array_list_matrices(sens_list, sens_len);
2252                    for (i, (&value, &t)) in sens_values[0].2.iter().zip(t_eval.iter()).enumerate()
2253                    {
2254                        assert_close(
2255                            value,
2256                            hybrid_logistic_state_dr(2.0, t),
2257                            5e-4,
2258                            &format!("jit hybrid sensitivity[{i}]"),
2259                        );
2260                    }
2261                    ffi_free_solution(sens_solution_ptr);
2262                }
2263
2264                diffsol_ode_free(ode);
2265            }
2266        }
2267    }
2268}