Skip to main content

diffsol_c/
solve.rs

// Delegate solver types selected at runtime by the host to concrete solver
// types in Rust.
3
4#[cfg(feature = "diffsl-external-dynamic")]
5use std::path::PathBuf;
6
7use crate::adjoint_checkpoint::AdjointCheckpointWrapper;
8use crate::error::DiffsolRtError;
9use crate::host_array::HostArray;
10use crate::host_array::ToHostArray;
11#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
12use crate::jit::JitBackendType;
13#[cfg(feature = "external")]
14use crate::scalar_type::ExternalScalar;
15use crate::scalar_type::Scalar;
16use crate::scalar_type::ScalarType;
17use crate::{
18    generate_ic_option_accessors, generate_ode_option_accessors, generate_option_accessors,
19    option_value_from_store, option_value_to_store,
20};
21use crate::{generate_trait_ic_option_accessors, generate_trait_ode_option_accessors};
22use diffsol::ObjectModule;
23use diffsol::OdeBuilder;
24use diffsol::{
25    error::DiffsolError,
26    matrix::{MatrixHost, MatrixRef},
27    CodegenModule, ConstantOp, DefaultDenseMatrix, DefaultSolver, DenseMatrix, DiffSl,
28    MatrixCommon, NonLinearOp, NonLinearOpJacobian, OdeEquations, OdeSolverProblem, Op, Vector,
29    VectorCommon, VectorHost, VectorRef,
30};
31#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
32use diffsol::{CodegenModuleCompile, CodegenModuleJit};
33use num_traits::{FromPrimitive, ToPrimitive}; // for from_f64 and to_f64
34use paste::paste;
35
36use crate::solve_serialization::{unsupported_serialization_error, SolveSerialization};
37use crate::{
38    linear_solver_type::LinearSolverType, matrix_type::MatrixType, ode_solver_type::OdeSolverType,
39};
40use crate::{
41    matrix_type::MatrixKind,
42    valid_linear_solver::{validate_linear_solver, KluValidator, LuValidator},
43};
44
// `GenericSolve` implements `Solve` for each matrix type, bridging diffsol and the host.
46
47use crate::solution::Solution;
48pub(crate) type SolveResult = Result<Box<dyn Solution>, DiffsolRtError>;
49
50pub(crate) trait Solve {
51    fn matrix_type(&self) -> MatrixType;
52    fn nstates(&self) -> usize;
53    fn nparams(&self) -> usize;
54    fn nout(&self) -> usize;
55    fn has_stop(&self) -> bool;
56
57    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError>;
58
59    fn rhs_jac_mul(
60        &mut self,
61        params: &[f64],
62        t: f64,
63        y: &[f64],
64        v: &[f64],
65    ) -> Result<HostArray, DiffsolRtError>;
66
67    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError>;
68
69    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError>;
70    fn set_rtol(&mut self, rtol: f64);
71    fn rtol(&self) -> f64;
72    fn set_atol(&mut self, atol: f64);
73    fn atol(&self) -> f64;
74    fn set_t0(&mut self, t0: f64);
75    fn t0(&self) -> f64;
76    fn set_h0(&mut self, h0: f64);
77    fn h0(&self) -> f64;
78    fn set_integrate_out(&mut self, integrate_out: bool);
79    fn integrate_out(&self) -> bool;
80    fn set_sens_rtol(&mut self, sens_rtol: Option<f64>);
81    fn sens_rtol(&self) -> Option<f64>;
82    fn set_sens_atol(&mut self, sens_atol: Option<f64>);
83    fn sens_atol(&self) -> Option<f64>;
84    fn set_out_rtol(&mut self, out_rtol: Option<f64>);
85    fn out_rtol(&self) -> Option<f64>;
86    fn set_out_atol(&mut self, out_atol: Option<f64>);
87    fn out_atol(&self) -> Option<f64>;
88    fn set_param_rtol(&mut self, param_rtol: Option<f64>);
89    fn param_rtol(&self) -> Option<f64>;
90    fn set_param_atol(&mut self, param_atol: Option<f64>);
91    fn param_atol(&self) -> Option<f64>;
92    fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
93        unsupported_serialization_error(
94            "ODE serialization is only supported for JIT-backed solvers",
95        )
96    }
97
98    // New API: solution object support
99    fn solve(
100        &mut self,
101        method: OdeSolverType,
102        linear_solver: LinearSolverType,
103        params: &[f64],
104        final_time: f64,
105    ) -> SolveResult;
106
107    fn solve_dense(
108        &mut self,
109        method: OdeSolverType,
110        linear_solver: LinearSolverType,
111        params: &[f64],
112        t_eval: &[f64],
113    ) -> SolveResult;
114
115    fn solve_fwd_sens(
116        &mut self,
117        method: OdeSolverType,
118        linear_solver: LinearSolverType,
119        params: &[f64],
120        t_eval: &[f64],
121    ) -> SolveResult;
122
123    fn solve_continuous_adjoint(
124        &mut self,
125        method: OdeSolverType,
126        linear_solver: LinearSolverType,
127        params: &[f64],
128        final_time: f64,
129    ) -> Result<(HostArray, HostArray), DiffsolRtError>;
130
131    fn solve_adjoint_fwd(
132        &mut self,
133        method: OdeSolverType,
134        linear_solver: LinearSolverType,
135        params: &[f64],
136        t_eval: &[f64],
137    ) -> Result<(Box<dyn Solution>, AdjointCheckpointWrapper), DiffsolRtError>;
138
139    fn solve_adjoint_bkwd(
140        &mut self,
141        method: OdeSolverType,
142        linear_solver: LinearSolverType,
143        checkpoint: &AdjointCheckpointWrapper,
144        t_eval: &[f64],
145        dgdu_eval: HostArray,
146    ) -> Result<HostArray, DiffsolRtError>;
147
148    generate_trait_ic_option_accessors! {
149        use_linesearch: bool,
150        max_linesearch_iterations: usize,
151        max_newton_iterations: usize,
152        max_linear_solver_setups: usize,
153        step_reduction_factor: f64,
154        armijo_constant: f64
155    }
156    generate_trait_ode_option_accessors! {
157        max_nonlinear_solver_iterations: usize,
158        max_error_test_failures: usize,
159        min_timestep: f64,
160        update_jacobian_after_steps: usize,
161        update_rhs_jacobian_after_steps: usize,
162        threshold_to_update_jacobian: f64,
163        threshold_to_update_rhs_jacobian: f64
164    }
165}
/// Build a boxed [`Solve`] backed by a statically linked external DiffSL
/// module (`diffsl::ExternalModule`), selected by `matrix_type` and
/// `scalar_type`.
///
/// The dependency lists are forwarded unchanged to
/// `GenericSolve::from_external`; sensitivities are not included.
///
/// # Errors
///
/// Returns an error if the requested scalar type was compiled out (its
/// `diffsl-external-f32` / `diffsl-external-f64` feature is disabled), or if
/// the problem cannot be built.
#[cfg(feature = "external")]
pub(crate) fn solve_factory_external(
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    let solve: Box<dyn Solve> = match matrix_type {
        MatrixType::NalgebraDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            // Only reachable when a scalar feature is disabled. With both
            // features on, this arm is dead and would trip `unreachable_patterns`.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolRtError::from(DiffsolError::Other(
                    "Unsupported scalar type for NalgebraDense".to_string(),
                )));
            }
        },
        MatrixType::FaerDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            // See the NalgebraDense fallback above.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolRtError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerDense".to_string(),
                )));
            }
        },
        MatrixType::FaerSparse => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            // See the NalgebraDense fallback above.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolRtError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerSparse".to_string(),
                )));
            }
        },
    };
    Ok(solve)
}
242
/// Build a boxed [`Solve`] backed by an external DiffSL module loaded at
/// runtime from `path`, selected by `matrix_type` and `scalar_type`.
///
/// The dependency lists are forwarded unchanged to
/// `GenericSolve::from_external_dynamic`; sensitivities are not included.
#[cfg(feature = "diffsl-external-dynamic")]
pub(crate) fn solve_factory_external_dynamic(
    path: PathBuf,
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Every (matrix, scalar) pairing is built the same way; only the concrete
    // types differ, so expand one template per pairing.
    macro_rules! load_dynamic {
        ($matrix:ty, $scalar:ty) => {
            Box::new(
                GenericSolve::<$matrix, diffsl::ExternalDynModule<$scalar>>::from_external_dynamic(
                    path,
                    rhs_state_deps,
                    rhs_input_deps,
                    mass_state_deps,
                    false,
                )?,
            )
        };
    }

    let solver: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => {
            load_dynamic!(diffsol::NalgebraMat<f32>, f32)
        }
        (MatrixType::NalgebraDense, ScalarType::F64) => {
            load_dynamic!(diffsol::NalgebraMat<f64>, f64)
        }
        (MatrixType::FaerDense, ScalarType::F32) => {
            load_dynamic!(diffsol::FaerMat<f32>, f32)
        }
        (MatrixType::FaerDense, ScalarType::F64) => {
            load_dynamic!(diffsol::FaerMat<f64>, f64)
        }
        (MatrixType::FaerSparse, ScalarType::F32) => {
            load_dynamic!(diffsol::FaerSparseMat<f32>, f32)
        }
        (MatrixType::FaerSparse, ScalarType::F64) => {
            load_dynamic!(diffsol::FaerSparseMat<f64>, f64)
        }
    };
    Ok(solver)
}
322
/// Compile DiffSL `code` with the selected JIT backend and wrap the result in
/// a boxed [`Solve`] for `matrix_type` / `scalar_type`.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
pub(crate) fn solve_factory_jit(
    code: &str,
    jit_backend: JitBackendType,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Resolve the backend to a monomorphised factory first, then call it once.
    let build: fn(&str, MatrixType, ScalarType) -> Result<Box<dyn Solve>, DiffsolRtError> =
        match jit_backend {
            #[cfg(feature = "diffsl-cranelift")]
            JitBackendType::Cranelift => {
                solve_factory_with_jit_backend::<diffsol::CraneliftJitModule>
            }
            #[cfg(feature = "diffsl-llvm")]
            JitBackendType::Llvm => solve_factory_with_jit_backend::<diffsol::LlvmModule>,
        };
    build(code, matrix_type, scalar_type)
}
343
/// Compile DiffSL `code` with code-generation backend `CG` and box the
/// resulting [`GenericSolve`] for the requested matrix and scalar type.
///
/// `CG` must provide `SolveSerialization` for all six matrix/scalar pairings,
/// because the pairing is only known at runtime.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn solve_factory_with_jit_backend<CG>(
    code: &str,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError>
where
    CG: CodegenModule
        + CodegenModuleJit
        + CodegenModuleCompile
        + SolveSerialization<diffsol::NalgebraMat<f32>>
        + SolveSerialization<diffsol::NalgebraMat<f64>>
        + SolveSerialization<diffsol::FaerMat<f32>>
        + SolveSerialization<diffsol::FaerMat<f64>>
        + SolveSerialization<diffsol::FaerSparseMat<f32>>
        + SolveSerialization<diffsol::FaerSparseMat<f64>>,
{
    // Compile the code for one concrete matrix type and wrap the problem.
    macro_rules! compile_for {
        ($matrix:ty) => {{
            let problem = OdeBuilder::<$matrix>::new().build_from_diffsl::<CG>(code)?;
            Box::new(GenericSolve { problem })
        }};
    }

    let solver: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => compile_for!(diffsol::NalgebraMat<f32>),
        (MatrixType::NalgebraDense, ScalarType::F64) => compile_for!(diffsol::NalgebraMat<f64>),
        (MatrixType::FaerDense, ScalarType::F32) => compile_for!(diffsol::FaerMat<f32>),
        (MatrixType::FaerDense, ScalarType::F64) => compile_for!(diffsol::FaerMat<f64>),
        (MatrixType::FaerSparse, ScalarType::F32) => compile_for!(diffsol::FaerSparseMat<f32>),
        (MatrixType::FaerSparse, ScalarType::F64) => compile_for!(diffsol::FaerSparseMat<f64>),
    };
    Ok(solver)
}
401
402pub(crate) fn solve_factory_from_serialized_diffsl(
403    serialized_diffsl: &[u8],
404    matrix_type: MatrixType,
405    scalar_type: ScalarType,
406) -> Result<Box<dyn Solve>, DiffsolRtError> {
407    let solve: Box<dyn Solve> = match matrix_type {
408        MatrixType::NalgebraDense => match scalar_type {
409            ScalarType::F32 => Box::new(
410                GenericSolve::<diffsol::NalgebraMat<f32>, ObjectModule>::from_serialized_diffsl(
411                    serialized_diffsl,
412                )?,
413            ),
414            ScalarType::F64 => Box::new(
415                GenericSolve::<diffsol::NalgebraMat<f64>, ObjectModule>::from_serialized_diffsl(
416                    serialized_diffsl,
417                )?,
418            ),
419        },
420        MatrixType::FaerDense => match scalar_type {
421            ScalarType::F32 => Box::new(
422                GenericSolve::<diffsol::FaerMat<f32>, ObjectModule>::from_serialized_diffsl(
423                    serialized_diffsl,
424                )?,
425            ),
426            ScalarType::F64 => Box::new(
427                GenericSolve::<diffsol::FaerMat<f64>, ObjectModule>::from_serialized_diffsl(
428                    serialized_diffsl,
429                )?,
430            ),
431        },
432        MatrixType::FaerSparse => match scalar_type {
433            ScalarType::F32 => Box::new(
434                GenericSolve::<diffsol::FaerSparseMat<f32>, ObjectModule>::from_serialized_diffsl(
435                    serialized_diffsl,
436                )?,
437            ),
438            ScalarType::F64 => Box::new(
439                GenericSolve::<diffsol::FaerSparseMat<f64>, ObjectModule>::from_serialized_diffsl(
440                    serialized_diffsl,
441                )?,
442            ),
443        },
444    };
445    Ok(solve)
446}
447
448pub(crate) struct GenericSolve<M, CG>
449where
450    M: MatrixHost<T: Scalar>,
451    M::V: Vector + VectorHost,
452    CG: CodegenModule,
453{
454    problem: OdeSolverProblem<DiffSl<M, CG>>,
455}
456
457impl<M, CG> GenericSolve<M, CG>
458where
459    M: MatrixHost<T: Scalar>,
460    M::V: Vector + VectorHost + DefaultDenseMatrix,
461    CG: CodegenModule,
462{
463    fn from_eqn(eqn: DiffSl<M, CG>) -> Result<Self, DiffsolRtError> {
464        let default_p = vec![0.0; eqn.nparams()];
465        let problem = OdeBuilder::<M>::new().p(default_p).build_from_eqn(eqn)?;
466        Ok(GenericSolve { problem })
467    }
468
469    pub(crate) fn setup_problem(&mut self, params: &[f64]) -> Result<(), DiffsolRtError> {
470        let params: Vec<M::T> = params.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
471        let params = M::V::from_slice(&params, M::C::default());
472
473        // Attempt to set problem from params and config
474        let nparams = self.problem.eqn.nparams();
475        if params.len() == nparams {
476            self.problem.eqn.set_params(&params);
477            // Reset model index to 0 for hybrid solves
478            self.problem.eqn.set_model_index(0);
479            Ok(())
480        } else {
481            Err(DiffsolError::Other(format!(
482                "Expecting {} params but got {}",
483                nparams,
484                params.len()
485            ))
486            .into())
487        }
488    }
489
490    pub(crate) fn serialize_eqn(&self) -> Result<Vec<u8>, DiffsolRtError>
491    where
492        DiffSl<M, CG>: serde::Serialize,
493    {
494        serde_json::to_vec(&self.problem.eqn)
495            .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))
496    }
497}
498
#[cfg(feature = "external")]
impl<M> GenericSolve<M, diffsl::ExternalModule<M::T>>
where
    M: MatrixHost<T: ExternalScalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Build a solver whose equations come from a statically linked external
    /// module.
    ///
    /// The dependency lists and `include_sensitivities` are passed straight
    /// to `DiffSl::from_external` with a default context; the equations are
    /// then wrapped with all parameters starting at zero.
    pub fn from_external(
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let context = M::C::default();
        Self::from_eqn(DiffSl::<M, diffsl::ExternalModule<M::T>>::from_external(
            context,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )?)
    }
}
521
#[cfg(feature = "diffsl-external-dynamic")]
impl<M> GenericSolve<M, diffsl::ExternalDynModule<M::T>>
where
    M: MatrixHost<T: Scalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Build a solver whose equations come from an external module loaded at
    /// runtime from `path`.
    ///
    /// All arguments are passed to `DiffSl::from_external_dynamic` with a
    /// default context; the equations are then wrapped with all parameters
    /// starting at zero.
    pub fn from_external_dynamic(
        path: impl Into<PathBuf>,
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let context = M::C::default();
        Self::from_eqn(
            DiffSl::<M, diffsl::ExternalDynModule<M::T>>::from_external_dynamic(
                path,
                context,
                rhs_state_deps,
                rhs_input_deps,
                mass_state_deps,
                include_sensitivities,
            )?,
        )
    }
}
546
547impl<M> GenericSolve<M, ObjectModule>
548where
549    M: MatrixHost<T: Scalar>,
550    M::V: Vector + VectorHost + DefaultDenseMatrix,
551{
552    pub fn from_serialized_diffsl(serialized_diffsl: &[u8]) -> Result<Self, DiffsolRtError> {
553        let eqn = serde_json::from_slice::<DiffSl<M, ObjectModule>>(serialized_diffsl)
554            .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))?;
555        Self::from_eqn(eqn)
556    }
557}
558
559impl<M, CG> Solve for GenericSolve<M, CG>
560where
561    M: MatrixHost<T: Scalar + ToPrimitive>
562        + DefaultSolver
563        + LuValidator<M>
564        + KluValidator<M>
565        + MatrixKind,
566    CG: CodegenModule,
567    for<'b> <<M::V as DefaultDenseMatrix>::M as MatrixCommon>::Inner: ToHostArray<M::T> + Clone,
568    for<'b> <M::V as VectorCommon>::Inner: ToHostArray<M::T> + Clone,
569    M::V: VectorHost + DefaultDenseMatrix + Send + Sync + 'static,
570    <M::V as DefaultDenseMatrix>::M: Send + Sync,
571    for<'b> &'b M::V: VectorRef<M::V>,
572    for<'b> &'b M: MatrixRef<M>,
573    CG: SolveSerialization<M>,
574{
575    fn matrix_type(&self) -> MatrixType {
576        MatrixType::from_diffsol::<M>()
577    }
578
579    fn nstates(&self) -> usize {
580        self.problem.eqn.nstates()
581    }
582
583    fn nparams(&self) -> usize {
584        self.problem.eqn.nparams()
585    }
586
587    fn nout(&self) -> usize {
588        self.problem.eqn.nout()
589    }
590
591    fn has_stop(&self) -> bool {
592        self.problem.eqn.root().is_some()
593    }
594
595    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError> {
596        validate_linear_solver::<M>(linear_solver)
597    }
598
599    fn set_atol(&mut self, atol: f64) {
600        self.problem.atol.fill(M::T::from_f64(atol).unwrap());
601    }
602
603    fn atol(&self) -> f64 {
604        self.problem.atol[0].to_f64().unwrap()
605    }
606
607    fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
608        CG::serialized_diffsl(self)
609    }
610
611    fn set_rtol(&mut self, rtol: f64) {
612        self.problem.rtol = M::T::from_f64(rtol).unwrap();
613    }
614
615    fn rtol(&self) -> f64 {
616        self.problem.rtol.to_f64().unwrap()
617    }
618
619    fn set_t0(&mut self, t0: f64) {
620        self.problem.t0 = M::T::from_f64(t0).unwrap();
621    }
622
623    fn t0(&self) -> f64 {
624        self.problem.t0.to_f64().unwrap()
625    }
626
627    fn set_h0(&mut self, h0: f64) {
628        self.problem.h0 = M::T::from_f64(h0).unwrap();
629    }
630
631    fn h0(&self) -> f64 {
632        self.problem.h0.to_f64().unwrap()
633    }
634
635    fn set_integrate_out(&mut self, integrate_out: bool) {
636        self.problem.integrate_out = integrate_out;
637    }
638
639    fn integrate_out(&self) -> bool {
640        self.problem.integrate_out
641    }
642
643    fn set_sens_rtol(&mut self, sens_rtol: Option<f64>) {
644        self.problem.sens_rtol = sens_rtol.map(|value| M::T::from_f64(value).unwrap());
645    }
646
647    fn sens_rtol(&self) -> Option<f64> {
648        self.problem.sens_rtol.map(|value| value.to_f64().unwrap())
649    }
650
651    fn set_sens_atol(&mut self, sens_atol: Option<f64>) {
652        self.problem.sens_atol = sens_atol.map(|value| {
653            M::V::from_element(
654                self.problem.eqn.nstates(),
655                M::T::from_f64(value).unwrap(),
656                M::C::default(),
657            )
658        });
659    }
660
661    fn sens_atol(&self) -> Option<f64> {
662        self.problem
663            .sens_atol
664            .as_ref()
665            .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
666    }
667
668    fn set_out_rtol(&mut self, out_rtol: Option<f64>) {
669        self.problem.out_rtol = out_rtol.map(|value| M::T::from_f64(value).unwrap());
670    }
671
672    fn out_rtol(&self) -> Option<f64> {
673        self.problem.out_rtol.map(|value| value.to_f64().unwrap())
674    }
675
676    fn set_out_atol(&mut self, out_atol: Option<f64>) {
677        self.problem.out_atol = out_atol.map(|value| {
678            let len = if self.problem.eqn.out().is_some() {
679                self.problem.eqn.nout()
680            } else {
681                self.problem.eqn.nstates()
682            };
683            M::V::from_element(len, M::T::from_f64(value).unwrap(), M::C::default())
684        });
685    }
686
687    fn out_atol(&self) -> Option<f64> {
688        self.problem
689            .out_atol
690            .as_ref()
691            .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
692    }
693
694    fn set_param_rtol(&mut self, param_rtol: Option<f64>) {
695        self.problem.param_rtol = param_rtol.map(|value| M::T::from_f64(value).unwrap());
696    }
697
698    fn param_rtol(&self) -> Option<f64> {
699        self.problem.param_rtol.map(|value| value.to_f64().unwrap())
700    }
701
702    fn set_param_atol(&mut self, param_atol: Option<f64>) {
703        self.problem.param_atol = param_atol.map(|value| {
704            M::V::from_element(
705                self.problem.eqn.nparams(),
706                M::T::from_f64(value).unwrap(),
707                M::C::default(),
708            )
709        });
710    }
711
712    fn param_atol(&self) -> Option<f64> {
713        self.problem
714            .param_atol
715            .as_ref()
716            .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
717    }
718
719    generate_ic_option_accessors! {
720        use_linesearch: bool,
721        max_linesearch_iterations: usize,
722        max_newton_iterations: usize,
723        max_linear_solver_setups: usize,
724        step_reduction_factor: f64,
725        armijo_constant: f64
726    }
727
728    generate_ode_option_accessors! {
729        max_nonlinear_solver_iterations: usize,
730        max_error_test_failures: usize,
731        min_timestep: f64,
732        update_jacobian_after_steps: usize,
733        update_rhs_jacobian_after_steps: usize,
734        threshold_to_update_jacobian: f64,
735        threshold_to_update_rhs_jacobian: f64
736    }
737
738    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError> {
739        self.setup_problem(params)?;
740        let n = self.problem.eqn.nstates();
741        let mut y0 = M::V::zeros(n, M::C::default());
742        let t0 = self.problem.t0;
743        self.problem.eqn.init().call_inplace(t0, &mut y0);
744        Ok((*y0.inner()).clone().to_host_array())
745    }
746
747    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError> {
748        self.setup_problem(params)?;
749        let n = self.problem.eqn.nstates();
750        let y = y
751            .iter()
752            .map(|&x| M::T::from_f64(x).unwrap())
753            .collect::<Vec<_>>();
754        let y_vec = M::V::from_slice(&y, M::C::default());
755        let mut dydt = M::V::zeros(n, M::C::default());
756        self.problem
757            .eqn
758            .rhs()
759            .call_inplace(&y_vec, M::T::from_f64(t).unwrap(), &mut dydt);
760        Ok((*dydt.inner()).clone().to_host_array())
761    }
762
763    fn rhs_jac_mul(
764        &mut self,
765
766        params: &[f64],
767        t: f64,
768        y: &[f64],
769        v: &[f64],
770    ) -> Result<HostArray, DiffsolRtError> {
771        self.setup_problem(params)?;
772        let n = self.problem.eqn.nstates();
773        let y = y
774            .iter()
775            .map(|&x| M::T::from_f64(x).unwrap())
776            .collect::<Vec<_>>();
777        let v = v
778            .iter()
779            .map(|&x| M::T::from_f64(x).unwrap())
780            .collect::<Vec<_>>();
781        let y_vec = M::V::from_slice(&y, M::C::default());
782        let v_vec = M::V::from_slice(&v, M::C::default());
783        let mut dydt = M::V::zeros(n, M::C::default());
784        self.problem.eqn.rhs().jac_mul_inplace(
785            &y_vec,
786            M::T::from_f64(t).unwrap(),
787            &v_vec,
788            &mut dydt,
789        );
790        Ok((*dydt.inner()).clone().to_host_array())
791    }
792
793    fn solve(
794        &mut self,
795        method: OdeSolverType,
796        linear_solver: LinearSolverType,
797        params: &[f64],
798        final_time: f64,
799    ) -> SolveResult {
800        self.check(linear_solver)?;
801        self.setup_problem(params)?;
802        let final_time = M::T::from_f64(final_time).unwrap();
803        let soln = match linear_solver {
804            LinearSolverType::Default => {
805                method.solve::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time)
806            }
807            LinearSolverType::Lu => {
808                method.solve::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time)
809            }
810            LinearSolverType::Klu => {
811                method.solve::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time)
812            }
813        };
814        Ok(Box::new(soln?))
815    }
816
817    fn solve_dense(
818        &mut self,
819        method: OdeSolverType,
820        linear_solver: LinearSolverType,
821        params: &[f64],
822        t_eval: &[f64],
823    ) -> SolveResult {
824        self.check(linear_solver)?;
825        self.setup_problem(params)?;
826
827        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
828        let soln =
829            match linear_solver {
830                LinearSolverType::Default => method
831                    .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
832                LinearSolverType::Lu => method
833                    .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
834                LinearSolverType::Klu => method
835                    .solve_dense::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
836            };
837        Ok(Box::new(soln?))
838    }
839
840    fn solve_fwd_sens(
841        &mut self,
842        method: OdeSolverType,
843        linear_solver: LinearSolverType,
844        params: &[f64],
845        t_eval: &[f64],
846    ) -> SolveResult {
847        self.check(linear_solver)?;
848        self.setup_problem(params)?;
849
850        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
851        let soln = match linear_solver {
852            LinearSolverType::Default => {
853                method.solve_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval)
854            }
855            LinearSolverType::Lu => method
856                .solve_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
857            LinearSolverType::Klu => method
858                .solve_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
859        };
860        Ok(Box::new(soln?))
861    }
862
863    fn solve_continuous_adjoint(
864        &mut self,
865        method: OdeSolverType,
866        linear_solver: LinearSolverType,
867        params: &[f64],
868        final_time: f64,
869    ) -> Result<(HostArray, HostArray), DiffsolRtError> {
870        self.check(linear_solver)?;
871        self.setup_problem(params)?;
872
873        let final_time = M::T::from_f64(final_time).unwrap();
874        let integrate_out = self.problem.integrate_out;
875        self.problem.integrate_out = true;
876        let result = match linear_solver {
877            LinearSolverType::Default => method
878                .solve_continuous_adjoint::<M, CG, <M as DefaultSolver>::LS>(
879                    &mut self.problem,
880                    final_time,
881                ),
882            LinearSolverType::Lu => method
883                .solve_continuous_adjoint::<M, CG, <M as LuValidator<M>>::LS>(
884                    &mut self.problem,
885                    final_time,
886                ),
887            LinearSolverType::Klu => method
888                .solve_continuous_adjoint::<M, CG, <M as KluValidator<M>>::LS>(
889                    &mut self.problem,
890                    final_time,
891                ),
892        };
893        self.problem.integrate_out = integrate_out;
894        let (integral, gradient) = result?;
895        Ok((
896            (*integral.inner()).clone().to_host_array(),
897            gradient.to_host_array(),
898        ))
899    }
900
901    fn solve_adjoint_fwd(
902        &mut self,
903        method: OdeSolverType,
904        linear_solver: LinearSolverType,
905        params: &[f64],
906        t_eval: &[f64],
907    ) -> Result<(Box<dyn Solution>, AdjointCheckpointWrapper), DiffsolRtError> {
908        self.check(linear_solver)?;
909        self.setup_problem(params)?;
910
911        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
912        let (soln, checkpoint) = match linear_solver {
913            LinearSolverType::Default => method
914                .solve_adjoint_fwd::<M, CG, <M as DefaultSolver>::LS>(
915                    &mut self.problem,
916                    &t_eval,
917                    params,
918                    linear_solver,
919                ),
920            LinearSolverType::Lu => method.solve_adjoint_fwd::<M, CG, <M as LuValidator<M>>::LS>(
921                &mut self.problem,
922                &t_eval,
923                params,
924                linear_solver,
925            ),
926            LinearSolverType::Klu => method.solve_adjoint_fwd::<M, CG, <M as KluValidator<M>>::LS>(
927                &mut self.problem,
928                &t_eval,
929                params,
930                linear_solver,
931            ),
932        }?;
933        Ok((Box::new(soln), AdjointCheckpointWrapper::new(checkpoint)))
934    }
935
936    fn solve_adjoint_bkwd(
937        &mut self,
938        method: OdeSolverType,
939        linear_solver: LinearSolverType,
940        checkpoint: &AdjointCheckpointWrapper,
941        t_eval: &[f64],
942        dgdu_eval: HostArray,
943    ) -> Result<HostArray, DiffsolRtError> {
944        self.check(linear_solver)?;
945        let checkpoint = checkpoint.guard()?;
946        self.setup_problem(checkpoint.params())?;
947
948        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
949        let dgdu_eval = host_array_to_dense_matrix::<M>(dgdu_eval)?;
950        if dgdu_eval.nrows() != self.problem.eqn.nout() {
951            return Err(DiffsolError::Other(format!(
952                "Expected dgdu_eval to have {} rows, got {}",
953                self.problem.eqn.nout(),
954                dgdu_eval.nrows()
955            ))
956            .into());
957        }
958        if dgdu_eval.ncols() != t_eval.len() {
959            return Err(DiffsolError::Other(format!(
960                "Expected dgdu_eval to have {} columns, got {}",
961                t_eval.len(),
962                dgdu_eval.ncols()
963            ))
964            .into());
965        }
966
967        let gradient = match linear_solver {
968            LinearSolverType::Default => method
969                .solve_adjoint_bkwd::<M, CG, <M as DefaultSolver>::LS>(
970                    &mut self.problem,
971                    checkpoint.as_ref(),
972                    &dgdu_eval,
973                    &t_eval,
974                ),
975            LinearSolverType::Lu => method.solve_adjoint_bkwd::<M, CG, <M as LuValidator<M>>::LS>(
976                &mut self.problem,
977                checkpoint.as_ref(),
978                &dgdu_eval,
979                &t_eval,
980            ),
981            LinearSolverType::Klu => method
982                .solve_adjoint_bkwd::<M, CG, <M as KluValidator<M>>::LS>(
983                    &mut self.problem,
984                    checkpoint.as_ref(),
985                    &dgdu_eval,
986                    &t_eval,
987                ),
988        }?;
989        Ok(gradient.to_host_array())
990    }
991}
992
993fn host_array_to_dense_matrix<M>(
994    array: HostArray,
995) -> Result<<M::V as DefaultDenseMatrix>::M, DiffsolRtError>
996where
997    M: MatrixHost<T: Scalar>,
998    M::V: DefaultDenseMatrix,
999{
1000    let view = array.as_array::<M::T>()?;
1001    let mut values = Vec::with_capacity(view.nrows() * view.ncols());
1002    for col in 0..view.ncols() {
1003        for row in 0..view.nrows() {
1004            values.push(view[(row, col)]);
1005        }
1006    }
1007    Ok(<M::V as DefaultDenseMatrix>::M::from_vec(
1008        view.nrows(),
1009        view.ncols(),
1010        values,
1011        M::C::default(),
1012    ))
1013}
1014
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    // Tests for `GenericSolve` and the JIT solve factory. The assertion helpers
    // are generic over the scalar type and the JIT codegen module, so the same
    // checks run for whichever of the Cranelift / LLVM backends is enabled.
    use diffsol::MatrixCommon;
    use diffsol::{
        CodegenModuleCompile, CodegenModuleJit, Context, OdeBuilder, OdeEquations, Vector,
    };

    #[cfg(feature = "diffsl-llvm")]
    use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
    use crate::{
        host_array::FromHostArray,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{
            assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
            logistic_state, LOGISTIC_X0,
        },
    };

    use super::{solve_factory_with_jit_backend, GenericSolve, Solve, SolveSerialization};

    /// Build a `GenericSolve` over a nalgebra dense matrix for the logistic
    /// test model, compiled with codegen module `CG`.
    fn make_generic_solve<T, CG>() -> GenericSolve<diffsol::NalgebraMat<T>, CG>
    where
        T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<T>>,
        diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
        <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
        <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
    {
        let problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
            .build_from_diffsl::<CG>(logistic_diffsl_code())
            .unwrap();
        GenericSolve { problem }
    }

    /// Assert that `solve_factory_with_jit_backend` succeeds for every
    /// combination of {nalgebra dense, faer dense, faer sparse} x {f32, f64}.
    fn assert_factory_supports_all_matrix_and_scalar_types<CG>()
    where
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<f32>>
            + SolveSerialization<diffsol::NalgebraMat<f64>>
            + SolveSerialization<diffsol::FaerMat<f32>>
            + SolveSerialization<diffsol::FaerMat<f64>>
            + SolveSerialization<diffsol::FaerSparseMat<f32>>
            + SolveSerialization<diffsol::FaerSparseMat<f64>>,
    {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            for scalar_type in [ScalarType::F32, ScalarType::F64] {
                assert!(solve_factory_with_jit_backend::<CG>(
                    logistic_diffsl_code(),
                    matrix_type,
                    scalar_type,
                )
                .is_ok());
            }
        }
    }

    /// Exercise the metadata getters, tolerance/option setters and getters,
    /// and the `y0` / `rhs` / `rhs_jac_mul` helpers on the logistic model.
    ///
    /// `tol` is the comparison tolerance; callers pass a looser value for f32.
    fn assert_solve_metadata_and_helpers<T, CG>(tol: f64)
    where
        T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<T>>,
        diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
        <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
        <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
        GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
    {
        let mut solve = make_generic_solve::<T, CG>();
        assert_eq!(solve.matrix_type(), MatrixType::NalgebraDense);
        assert_eq!(solve.nstates(), 1);
        assert_eq!(solve.nparams(), 1);
        assert_eq!(solve.nout(), 1);
        assert!(!solve.has_stop());
        // KLU is expected to be rejected for the nalgebra dense matrix type.
        assert!(solve.check(LinearSolverType::Default).is_ok());
        assert!(solve.check(LinearSolverType::Lu).is_ok());
        assert!(solve.check(LinearSolverType::Klu).is_err());

        solve.set_atol(1e-5);
        solve.set_rtol(1e-4);
        assert_close(solve.atol(), 1e-5, tol, "solve atol");
        assert_close(solve.rtol(), 1e-4, tol, "solve rtol");
        solve.set_t0(0.125);
        solve.set_h0(0.25);
        solve.set_integrate_out(true);
        assert_close(solve.t0(), 0.125, tol, "solve t0");
        assert_close(solve.h0(), 0.25, tol, "solve h0");
        assert!(solve.integrate_out());

        solve.set_sens_rtol(Some(1e-3));
        solve.set_sens_atol(Some(1e-4));
        solve.set_out_rtol(Some(2e-3));
        solve.set_out_atol(Some(2e-4));
        solve.set_param_rtol(Some(3e-3));
        solve.set_param_atol(Some(3e-4));
        assert_close(solve.sens_rtol().unwrap(), 1e-3, tol, "solve sens rtol");
        assert_close(solve.sens_atol().unwrap(), 1e-4, tol, "solve sens atol");
        assert_close(solve.out_rtol().unwrap(), 2e-3, tol, "solve out rtol");
        assert_close(solve.out_atol().unwrap(), 2e-4, tol, "solve out atol");
        assert_close(solve.param_rtol().unwrap(), 3e-3, tol, "solve param rtol");
        assert_close(solve.param_atol().unwrap(), 3e-4, tol, "solve param atol");

        // Optional tolerances can be cleared back to `None`.
        solve.set_sens_rtol(None);
        solve.set_sens_atol(None);
        solve.set_out_rtol(None);
        solve.set_out_atol(None);
        solve.set_param_rtol(None);
        solve.set_param_atol(None);
        assert_eq!(solve.sens_rtol(), None);
        assert_eq!(solve.sens_atol(), None);
        assert_eq!(solve.out_rtol(), None);
        assert_eq!(solve.out_atol(), None);
        assert_eq!(solve.param_rtol(), None);
        assert_eq!(solve.param_atol(), None);

        let y0 = Vec::<f64>::from_host_array(solve.y0(&[2.0]).unwrap()).unwrap();
        assert_eq!(y0.len(), 1);
        assert_close(y0[0], LOGISTIC_X0, tol, "y0[0]");

        // Presumably a logistic rhs r*y*(1-y): 2 * 0.25 * 0.75 = 0.375.
        let rhs = Vec::<f64>::from_host_array(solve.rhs(&[2.0], 0.0, &[0.25]).unwrap()).unwrap();
        assert_eq!(rhs.len(), 1);
        assert_close(rhs[0], 0.375, tol, "solve rhs");

        let jac_mul =
            Vec::<f64>::from_host_array(solve.rhs_jac_mul(&[2.0], 0.0, &[0.25], &[3.0]).unwrap())
                .unwrap();
        assert_eq!(jac_mul.len(), 1);
        assert_close(jac_mul[0], 3.0, tol, "solve rhs jac mul");
    }

    /// Exercise `solve` and `solve_dense` on the logistic and hybrid logistic
    /// models, plus the parameter-count error path.
    ///
    /// `tol` is the comparison tolerance against the analytic states. Hybrid
    /// assertions carry a "hybrid" prefix in their failure labels so they can
    /// be told apart from the plain logistic checks.
    fn assert_solve_runtime_paths<T, CG>(tol: f64)
    where
        T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<T>>,
        diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
        <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
        <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
        GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
    {
        let mut solve = make_generic_solve::<T, CG>();
        let soln = solve
            .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 1.0)
            .unwrap();
        let ts = Vec::<f64>::from_host_array(soln.get_ts()).unwrap();
        let ys = Vec::<Vec<f64>>::from_host_array(soln.get_ys()).unwrap();
        assert_close(*ts.last().unwrap(), 1.0, tol, "solve final time");
        assert_close(
            ys[0][ts.len() - 1],
            logistic_state(LOGISTIC_X0, 2.0, 1.0),
            tol,
            "solve final value",
        );

        let mut solve = make_generic_solve::<T, CG>();
        let dense = solve
            .solve_dense(
                OdeSolverType::Tsit45,
                LinearSolverType::Lu,
                &[2.0],
                &[0.25, 0.5, 1.0],
            )
            .unwrap();
        let ts = Vec::<f64>::from_host_array(dense.get_ts()).unwrap();
        let ys = Vec::<Vec<f64>>::from_host_array(dense.get_ys()).unwrap();
        assert_eq!(ts, vec![0.25, 0.5, 1.0]);
        for (i, &t) in ts.iter().enumerate() {
            assert_close(
                ys[0][i],
                logistic_state(LOGISTIC_X0, 2.0, t),
                tol,
                &format!("solve_dense[{i}]"),
            );
        }

        // An empty parameter slice must be rejected for a 1-parameter model.
        let mut solve = make_generic_solve::<T, CG>();
        let err = match solve.solve(OdeSolverType::Bdf, LinearSolverType::Default, &[], 1.0) {
            Ok(_) => panic!("expected parameter count mismatch"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut hybrid_solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid = hybrid_solve
            .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 2.0)
            .unwrap();
        let hybrid_ts = Vec::<f64>::from_host_array(hybrid.get_ts()).unwrap();
        let hybrid_ys = Vec::<Vec<f64>>::from_host_array(hybrid.get_ys()).unwrap();
        assert_close(*hybrid_ts.last().unwrap(), 2.0, tol, "hybrid solve final time");
        assert_close(
            hybrid_ys[0][hybrid_ts.len() - 1],
            hybrid_logistic_state(2.0, 2.0),
            tol,
            "hybrid solve final value",
        );

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut hybrid_solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid_dense = hybrid_solve
            .solve_dense(
                OdeSolverType::Tsit45,
                LinearSolverType::Lu,
                &[2.0],
                &[0.5, 1.0, 1.5, 2.0],
            )
            .unwrap();
        let hybrid_dense_ts = Vec::<f64>::from_host_array(hybrid_dense.get_ts()).unwrap();
        let hybrid_dense_ys = Vec::<Vec<f64>>::from_host_array(hybrid_dense.get_ys()).unwrap();
        assert_eq!(hybrid_dense_ts, vec![0.5, 1.0, 1.5, 2.0]);
        for (i, &t) in hybrid_dense_ts.iter().enumerate() {
            assert_close(
                hybrid_dense_ys[0][i],
                hybrid_logistic_state(2.0, t),
                tol,
                &format!("hybrid solve_dense[{i}]"),
            );
        }
    }

    /// Check forward-sensitivity solves (LLVM backend, f64 only) against the
    /// analytic derivatives of the logistic and hybrid logistic states.
    #[cfg(feature = "diffsl-llvm")]
    fn assert_solve_sensitivity_paths() {
        let t_eval = [0.25, 0.5, 1.0];

        let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
        let sens = solve
            .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
            .unwrap();
        let sens_values = sens.get_sens();
        assert_eq!(sens_values.len(), 1);
        let sens_matrix =
            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_matrix[0][i],
                logistic_state_dr(LOGISTIC_X0, 2.0, t),
                5e-4,
                &format!("solve_fwd_sens[{i}]"),
            );
        }

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
            .build_from_diffsl::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid_sens = solve
            .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
            .unwrap();
        let sens_values = hybrid_sens.get_sens();
        let sens_matrix =
            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_matrix[0][i],
                hybrid_logistic_state_dr(2.0, t),
                5e-4,
                &format!("hybrid solve_fwd_sens[{i}]"),
            );
        }
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_cranelift() {
        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::CraneliftJitModule>();
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn solve_trait_helpers_and_runtime_paths_for_cranelift() {
        assert_solve_metadata_and_helpers::<f64, diffsol::CraneliftJitModule>(1e-12);
        assert_solve_runtime_paths::<f64, diffsol::CraneliftJitModule>(5e-4);
        assert_solve_metadata_and_helpers::<f32, diffsol::CraneliftJitModule>(5e-3);
        assert_solve_runtime_paths::<f32, diffsol::CraneliftJitModule>(5e-3);
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_llvm() {
        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::LlvmModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_trait_helpers_and_runtime_paths_for_llvm() {
        assert_solve_metadata_and_helpers::<f64, diffsol::LlvmModule>(1e-12);
        assert_solve_runtime_paths::<f64, diffsol::LlvmModule>(5e-4);
        assert_solve_metadata_and_helpers::<f32, diffsol::LlvmModule>(5e-3);
        assert_solve_runtime_paths::<f32, diffsol::LlvmModule>(5e-3);
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_trait_sensitivity_paths_for_llvm() {
        assert_solve_sensitivity_paths();
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn setup_problem_validates_parameter_count_for_cranelift() {
        let mut solve = make_generic_solve::<f64, diffsol::CraneliftJitModule>();
        let err = solve.setup_problem(&[]).unwrap_err();
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        // A correctly sized slice is written through to the equations.
        solve.setup_problem(&[2.0]).unwrap();
        let mut params = solve.problem.context().vector_zeros(1);
        solve.problem.eqn.get_params(&mut params);
        assert_eq!(params.get_index(0), 2.0);
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn setup_problem_validates_parameter_count_for_llvm() {
        let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
        let err = solve.setup_problem(&[]).unwrap_err();
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        // A correctly sized slice is written through to the equations.
        solve.setup_problem(&[2.0]).unwrap();
        let mut params = solve.problem.context().vector_zeros(1);
        solve.problem.eqn.get_params(&mut params);
        assert_eq!(params.get_index(0), 2.0);
    }
}