Skip to main content

diffsol_c/
solve.rs

1// Delegate solver types selected at runtime in Host to concrete solver types
2// in Rust.
3
4#[cfg(feature = "diffsl-external-dynamic")]
5use std::path::PathBuf;
6
7use crate::adjoint_checkpoint::AdjointCheckpointWrapper;
8use crate::error::DiffsolRtError;
9use crate::host_array::HostArray;
10use crate::host_array::ToHostArray;
11#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
12use crate::jit::JitBackendType;
13#[cfg(feature = "external")]
14use crate::scalar_type::ExternalScalar;
15use crate::scalar_type::Scalar;
16use crate::scalar_type::ScalarType;
17use crate::{
18    generate_ic_option_accessors, generate_ode_option_accessors, generate_option_accessors,
19    option_value_from_store, option_value_to_store,
20};
21use crate::{generate_trait_ic_option_accessors, generate_trait_ode_option_accessors};
22use diffsol::ObjectModule;
23use diffsol::OdeBuilder;
24use diffsol::{
25    error::DiffsolError,
26    matrix::{MatrixHost, MatrixRef},
27    CodegenModule, ConstantOp, DefaultDenseMatrix, DefaultSolver, DenseMatrix, DiffSl,
28    MatrixCommon, NonLinearOp, NonLinearOpJacobian, OdeEquations, OdeSolverProblem, Op, Vector,
29    VectorCommon, VectorHost, VectorRef,
30};
31#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
32use diffsol::{CodegenModuleCompile, CodegenModuleJit};
33use num_traits::{FromPrimitive, ToPrimitive}; // for from_f64 and to_f64
34use paste::paste;
35
36use crate::solve_serialization::{unsupported_serialization_error, SolveSerialization};
37use crate::{
38    linear_solver_type::LinearSolverType, matrix_type::MatrixType, ode_solver_type::OdeSolverType,
39};
40use crate::{
41    matrix_type::MatrixKind,
42    valid_linear_solver::{validate_linear_solver, KluValidator, LuValidator},
43};
44
// Each matrix type implements `Solve` as the bridge between diffsol and the host.
46
47use crate::solution::Solution;
48pub(crate) type SolveResult = Result<Box<dyn Solution>, DiffsolRtError>;
49
50pub(crate) trait Solve {
51    fn matrix_type(&self) -> MatrixType;
52    fn nstates(&self) -> usize;
53    fn nparams(&self) -> usize;
54    fn nout(&self) -> usize;
55    fn has_stop(&self) -> bool;
56
57    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError>;
58
59    fn rhs_jac_mul(
60        &mut self,
61        params: &[f64],
62        t: f64,
63        y: &[f64],
64        v: &[f64],
65    ) -> Result<HostArray, DiffsolRtError>;
66
67    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError>;
68
69    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError>;
70    fn set_rtol(&mut self, rtol: f64);
71    fn rtol(&self) -> f64;
72    fn set_atol(&mut self, atol: f64);
73    fn atol(&self) -> f64;
74    fn set_t0(&mut self, t0: f64);
75    fn t0(&self) -> f64;
76    fn set_h0(&mut self, h0: f64);
77    fn h0(&self) -> f64;
78    fn set_integrate_out(&mut self, integrate_out: bool);
79    fn integrate_out(&self) -> bool;
80    fn set_sens_rtol(&mut self, sens_rtol: Option<f64>);
81    fn sens_rtol(&self) -> Option<f64>;
82    fn set_sens_atol(&mut self, sens_atol: Option<f64>);
83    fn sens_atol(&self) -> Option<f64>;
84    fn set_out_rtol(&mut self, out_rtol: Option<f64>);
85    fn out_rtol(&self) -> Option<f64>;
86    fn set_out_atol(&mut self, out_atol: Option<f64>);
87    fn out_atol(&self) -> Option<f64>;
88    fn set_param_rtol(&mut self, param_rtol: Option<f64>);
89    fn param_rtol(&self) -> Option<f64>;
90    fn set_param_atol(&mut self, param_atol: Option<f64>);
91    fn param_atol(&self) -> Option<f64>;
92    fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
93        unsupported_serialization_error(
94            "ODE serialization is only supported for JIT-backed solvers",
95        )
96    }
97
98    // New API: solution object support
99    fn solve(
100        &mut self,
101        method: OdeSolverType,
102        linear_solver: LinearSolverType,
103        params: &[f64],
104        final_time: f64,
105    ) -> SolveResult;
106
107    fn solve_dense(
108        &mut self,
109        method: OdeSolverType,
110        linear_solver: LinearSolverType,
111        params: &[f64],
112        t_eval: &[f64],
113    ) -> SolveResult;
114
115    fn solve_fwd_sens(
116        &mut self,
117        method: OdeSolverType,
118        linear_solver: LinearSolverType,
119        params: &[f64],
120        t_eval: &[f64],
121    ) -> SolveResult;
122
123    fn solve_continuous_adjoint(
124        &mut self,
125        method: OdeSolverType,
126        linear_solver: LinearSolverType,
127        params: &[f64],
128        final_time: f64,
129    ) -> Result<(HostArray, HostArray), DiffsolRtError>;
130
131    fn solve_adjoint_fwd(
132        &mut self,
133        method: OdeSolverType,
134        linear_solver: LinearSolverType,
135        params: &[f64],
136        t_eval: &[f64],
137    ) -> Result<(Box<dyn Solution>, AdjointCheckpointWrapper), DiffsolRtError>;
138
139    fn solve_adjoint_bkwd(
140        &mut self,
141        method: OdeSolverType,
142        linear_solver: LinearSolverType,
143        checkpoint: &AdjointCheckpointWrapper,
144        t_eval: &[f64],
145        dgdu_eval: HostArray,
146    ) -> Result<HostArray, DiffsolRtError>;
147
148    generate_trait_ic_option_accessors! {
149        use_linesearch: bool,
150        max_linesearch_iterations: usize,
151        max_newton_iterations: usize,
152        max_linear_solver_setups: usize,
153        step_reduction_factor: f64,
154        armijo_constant: f64
155    }
156    generate_trait_ode_option_accessors! {
157        max_nonlinear_solver_iterations: usize,
158        max_error_test_failures: usize,
159        min_timestep: f64,
160        update_jacobian_after_steps: usize,
161        update_rhs_jacobian_after_steps: usize,
162        threshold_to_update_jacobian: f64,
163        threshold_to_update_rhs_jacobian: f64
164    }
165}
// Crate-internal factories that build a boxed `Solve` for a runtime-selected matrix and scalar type.
/// Build a [`Solve`] instance backed by `diffsl::ExternalModule`, dispatching on
/// the requested matrix backend and scalar type.
///
/// Sensitivities are not requested (`include_sensitivities = false`).
///
/// # Errors
/// Returns an "Unsupported scalar type for …" error when the matching
/// `diffsl-external-f32` / `diffsl-external-f64` feature is not enabled, and
/// propagates any error from `GenericSolve::from_external`.
#[cfg(feature = "external")]
pub(crate) fn solve_factory_external(
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Expands to one boxed `GenericSolve` over matrix type `$mat` with scalar `$t`.
    // Only the selected arm runs, so the owned dependency lists are moved once.
    // Unused when neither scalar feature is enabled.
    #[allow(unused_macros)]
    macro_rules! external_solve {
        ($mat:ty, $t:ty) => {
            Box::new(
                GenericSolve::<$mat, diffsl::ExternalModule<$t>>::from_external(
                    rhs_state_deps,
                    rhs_input_deps,
                    mass_state_deps,
                    false,
                )?,
            )
        };
    }
    let unsupported =
        |message: &str| DiffsolRtError::from(DiffsolError::Other(message.to_string()));

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        #[cfg(feature = "diffsl-external-f32")]
        (MatrixType::NalgebraDense, ScalarType::F32) => {
            external_solve!(diffsol::NalgebraMat<f32>, f32)
        }
        #[cfg(feature = "diffsl-external-f64")]
        (MatrixType::NalgebraDense, ScalarType::F64) => {
            external_solve!(diffsol::NalgebraMat<f64>, f64)
        }
        // `ScalarType` has exactly `F32` and `F64`, so this fallback is reachable
        // only when one of the scalar features is disabled; with both enabled it
        // would otherwise raise an `unreachable_patterns` warning.
        #[allow(unreachable_patterns)]
        (MatrixType::NalgebraDense, _) => {
            return Err(unsupported("Unsupported scalar type for NalgebraDense"));
        }
        #[cfg(feature = "diffsl-external-f32")]
        (MatrixType::FaerDense, ScalarType::F32) => {
            external_solve!(diffsol::FaerMat<f32>, f32)
        }
        #[cfg(feature = "diffsl-external-f64")]
        (MatrixType::FaerDense, ScalarType::F64) => {
            external_solve!(diffsol::FaerMat<f64>, f64)
        }
        #[allow(unreachable_patterns)]
        (MatrixType::FaerDense, _) => {
            return Err(unsupported("Unsupported scalar type for FaerDense"));
        }
        #[cfg(feature = "diffsl-external-f32")]
        (MatrixType::FaerSparse, ScalarType::F32) => {
            external_solve!(diffsol::FaerSparseMat<f32>, f32)
        }
        #[cfg(feature = "diffsl-external-f64")]
        (MatrixType::FaerSparse, ScalarType::F64) => {
            external_solve!(diffsol::FaerSparseMat<f64>, f64)
        }
        #[allow(unreachable_patterns)]
        (MatrixType::FaerSparse, _) => {
            return Err(unsupported("Unsupported scalar type for FaerSparse"));
        }
    };
    Ok(solve)
}
242
/// Build a [`Solve`] instance whose equations are loaded through
/// `diffsl::ExternalDynModule` from `path`, dispatching on matrix backend and
/// scalar type.
///
/// Sensitivities are not requested (`include_sensitivities = false`).
#[cfg(feature = "diffsl-external-dynamic")]
pub(crate) fn solve_factory_external_dynamic(
    path: PathBuf,
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // One boxed `GenericSolve` over matrix type `$mat` with scalar `$t`. Only the
    // selected arm runs, so `path` and the dependency lists are moved once.
    macro_rules! dynamic_solve {
        ($mat:ty, $t:ty) => {
            Box::new(
                GenericSolve::<$mat, diffsl::ExternalDynModule<$t>>::from_external_dynamic(
                    path,
                    rhs_state_deps,
                    rhs_input_deps,
                    mass_state_deps,
                    false,
                )?,
            )
        };
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => {
            dynamic_solve!(diffsol::NalgebraMat<f32>, f32)
        }
        (MatrixType::NalgebraDense, ScalarType::F64) => {
            dynamic_solve!(diffsol::NalgebraMat<f64>, f64)
        }
        (MatrixType::FaerDense, ScalarType::F32) => {
            dynamic_solve!(diffsol::FaerMat<f32>, f32)
        }
        (MatrixType::FaerDense, ScalarType::F64) => {
            dynamic_solve!(diffsol::FaerMat<f64>, f64)
        }
        (MatrixType::FaerSparse, ScalarType::F32) => {
            dynamic_solve!(diffsol::FaerSparseMat<f32>, f32)
        }
        (MatrixType::FaerSparse, ScalarType::F64) => {
            dynamic_solve!(diffsol::FaerSparseMat<f64>, f64)
        }
    };
    Ok(solve)
}
322
/// Build a JIT-compiled [`Solve`] instance from the DiffSL source `code`, using
/// the selected JIT backend for code generation.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
pub(crate) fn solve_factory_jit(
    code: &str,
    jit_backend: JitBackendType,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Select the factory monomorphised for the backend, then invoke it once.
    let factory: fn(&str, MatrixType, ScalarType) -> Result<Box<dyn Solve>, DiffsolRtError> =
        match jit_backend {
            #[cfg(feature = "diffsl-cranelift")]
            JitBackendType::Cranelift => {
                solve_factory_with_jit_backend::<diffsol::CraneliftJitModule>
            }
            #[cfg(feature = "diffsl-llvm")]
            JitBackendType::Llvm => solve_factory_with_jit_backend::<diffsol::LlvmModule>,
        };
    factory(code, matrix_type, scalar_type)
}
343
/// Compile `code` with code-generation backend `CG` into a problem over the
/// requested matrix type and scalar type, and box it as a [`Solve`].
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn solve_factory_with_jit_backend<CG>(
    code: &str,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError>
where
    CG: CodegenModule
        + CodegenModuleJit
        + CodegenModuleCompile
        + SolveSerialization<diffsol::NalgebraMat<f32>>
        + SolveSerialization<diffsol::NalgebraMat<f64>>
        + SolveSerialization<diffsol::FaerMat<f32>>
        + SolveSerialization<diffsol::FaerMat<f64>>
        + SolveSerialization<diffsol::FaerSparseMat<f32>>
        + SolveSerialization<diffsol::FaerSparseMat<f64>>,
{
    // Build an `OdeSolverProblem` over matrix type `$mat` from `code` and box it.
    macro_rules! jit_solve {
        ($mat:ty) => {{
            let compiled = OdeBuilder::<$mat>::new().build_from_diffsl::<CG>(code)?;
            Box::new(GenericSolve { problem: compiled })
        }};
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => jit_solve!(diffsol::NalgebraMat<f32>),
        (MatrixType::NalgebraDense, ScalarType::F64) => jit_solve!(diffsol::NalgebraMat<f64>),
        (MatrixType::FaerDense, ScalarType::F32) => jit_solve!(diffsol::FaerMat<f32>),
        (MatrixType::FaerDense, ScalarType::F64) => jit_solve!(diffsol::FaerMat<f64>),
        (MatrixType::FaerSparse, ScalarType::F32) => jit_solve!(diffsol::FaerSparseMat<f32>),
        (MatrixType::FaerSparse, ScalarType::F64) => jit_solve!(diffsol::FaerSparseMat<f64>),
    };
    Ok(solve)
}
401
402pub(crate) fn solve_factory_from_serialized_diffsl(
403    serialized_diffsl: &[u8],
404    matrix_type: MatrixType,
405    scalar_type: ScalarType,
406) -> Result<Box<dyn Solve>, DiffsolRtError> {
407    let solve: Box<dyn Solve> = match matrix_type {
408        MatrixType::NalgebraDense => match scalar_type {
409            ScalarType::F32 => Box::new(
410                GenericSolve::<diffsol::NalgebraMat<f32>, ObjectModule>::from_serialized_diffsl(
411                    serialized_diffsl,
412                )?,
413            ),
414            ScalarType::F64 => Box::new(
415                GenericSolve::<diffsol::NalgebraMat<f64>, ObjectModule>::from_serialized_diffsl(
416                    serialized_diffsl,
417                )?,
418            ),
419        },
420        MatrixType::FaerDense => match scalar_type {
421            ScalarType::F32 => Box::new(
422                GenericSolve::<diffsol::FaerMat<f32>, ObjectModule>::from_serialized_diffsl(
423                    serialized_diffsl,
424                )?,
425            ),
426            ScalarType::F64 => Box::new(
427                GenericSolve::<diffsol::FaerMat<f64>, ObjectModule>::from_serialized_diffsl(
428                    serialized_diffsl,
429                )?,
430            ),
431        },
432        MatrixType::FaerSparse => match scalar_type {
433            ScalarType::F32 => Box::new(
434                GenericSolve::<diffsol::FaerSparseMat<f32>, ObjectModule>::from_serialized_diffsl(
435                    serialized_diffsl,
436                )?,
437            ),
438            ScalarType::F64 => Box::new(
439                GenericSolve::<diffsol::FaerSparseMat<f64>, ObjectModule>::from_serialized_diffsl(
440                    serialized_diffsl,
441                )?,
442            ),
443        },
444    };
445    Ok(solve)
446}
447
448pub(crate) struct GenericSolve<M, CG>
449where
450    M: MatrixHost<T: Scalar>,
451    M::V: Vector + VectorHost,
452    CG: CodegenModule,
453{
454    problem: OdeSolverProblem<DiffSl<M, CG>>,
455}
456
457impl<M, CG> GenericSolve<M, CG>
458where
459    M: MatrixHost<T: Scalar>,
460    M::V: Vector + VectorHost + DefaultDenseMatrix,
461    CG: CodegenModule,
462{
463    fn from_eqn(eqn: DiffSl<M, CG>) -> Result<Self, DiffsolRtError> {
464        let default_p = vec![0.0; eqn.nparams()];
465        let problem = OdeBuilder::<M>::new().p(default_p).build_from_eqn(eqn)?;
466        Ok(GenericSolve { problem })
467    }
468
469    pub(crate) fn setup_problem(&mut self, params: &[f64]) -> Result<(), DiffsolRtError> {
470        let params: Vec<M::T> = params.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
471        let params = M::V::from_slice(&params, M::C::default());
472
473        // Attempt to set problem from params and config
474        let nparams = self.problem.eqn.nparams();
475        if params.len() == nparams {
476            self.problem.eqn.set_params(&params);
477            Ok(())
478        } else {
479            Err(DiffsolError::Other(format!(
480                "Expecting {} params but got {}",
481                nparams,
482                params.len()
483            ))
484            .into())
485        }
486    }
487
488    pub(crate) fn serialize_eqn(&self) -> Result<Vec<u8>, DiffsolRtError>
489    where
490        DiffSl<M, CG>: serde::Serialize,
491    {
492        serde_json::to_vec(&self.problem.eqn)
493            .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))
494    }
495}
496
/// Constructors for solvers whose equations come from `diffsl::ExternalModule`.
#[cfg(feature = "external")]
impl<M> GenericSolve<M, diffsl::ExternalModule<M::T>>
where
    M: MatrixHost<T: ExternalScalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Build equations with `DiffSl::from_external` from the supplied dependency
    /// lists, then wrap them in a problem with zero-initialised parameters.
    pub fn from_external(
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let context = M::C::default();
        let equations = DiffSl::<M, diffsl::ExternalModule<M::T>>::from_external(
            context,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )?;
        Self::from_eqn(equations)
    }
}
519
/// Constructors for solvers whose equations come from `diffsl::ExternalDynModule`.
#[cfg(feature = "diffsl-external-dynamic")]
impl<M> GenericSolve<M, diffsl::ExternalDynModule<M::T>>
where
    M: MatrixHost<T: Scalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Build equations with `DiffSl::from_external_dynamic` from the module at
    /// `path` and the supplied dependency lists, then wrap them in a problem
    /// with zero-initialised parameters.
    pub fn from_external_dynamic(
        path: impl Into<PathBuf>,
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let context = M::C::default();
        let equations = DiffSl::<M, diffsl::ExternalDynModule<M::T>>::from_external_dynamic(
            path,
            context,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )?;
        Self::from_eqn(equations)
    }
}
544
545impl<M> GenericSolve<M, ObjectModule>
546where
547    M: MatrixHost<T: Scalar>,
548    M::V: Vector + VectorHost + DefaultDenseMatrix,
549{
550    pub fn from_serialized_diffsl(serialized_diffsl: &[u8]) -> Result<Self, DiffsolRtError> {
551        let eqn = serde_json::from_slice::<DiffSl<M, ObjectModule>>(serialized_diffsl)
552            .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))?;
553        Self::from_eqn(eqn)
554    }
555}
556
557impl<M, CG> Solve for GenericSolve<M, CG>
558where
559    M: MatrixHost<T: Scalar + ToPrimitive>
560        + DefaultSolver
561        + LuValidator<M>
562        + KluValidator<M>
563        + MatrixKind,
564    CG: CodegenModule,
565    for<'b> <<M::V as DefaultDenseMatrix>::M as MatrixCommon>::Inner: ToHostArray<M::T> + Clone,
566    for<'b> <M::V as VectorCommon>::Inner: ToHostArray<M::T> + Clone,
567    M::V: VectorHost + DefaultDenseMatrix + Send + Sync + 'static,
568    <M::V as DefaultDenseMatrix>::M: Send + Sync,
569    for<'b> &'b M::V: VectorRef<M::V>,
570    for<'b> &'b M: MatrixRef<M>,
571    CG: SolveSerialization<M>,
572{
573    fn matrix_type(&self) -> MatrixType {
574        MatrixType::from_diffsol::<M>()
575    }
576
577    fn nstates(&self) -> usize {
578        self.problem.eqn.nstates()
579    }
580
581    fn nparams(&self) -> usize {
582        self.problem.eqn.nparams()
583    }
584
585    fn nout(&self) -> usize {
586        self.problem.eqn.nout()
587    }
588
589    fn has_stop(&self) -> bool {
590        self.problem.eqn.root().is_some()
591    }
592
593    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError> {
594        validate_linear_solver::<M>(linear_solver)
595    }
596
597    fn set_atol(&mut self, atol: f64) {
598        self.problem.atol.fill(M::T::from_f64(atol).unwrap());
599    }
600
601    fn atol(&self) -> f64 {
602        self.problem.atol[0].to_f64().unwrap()
603    }
604
605    fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
606        CG::serialized_diffsl(self)
607    }
608
609    fn set_rtol(&mut self, rtol: f64) {
610        self.problem.rtol = M::T::from_f64(rtol).unwrap();
611    }
612
613    fn rtol(&self) -> f64 {
614        self.problem.rtol.to_f64().unwrap()
615    }
616
617    fn set_t0(&mut self, t0: f64) {
618        self.problem.t0 = M::T::from_f64(t0).unwrap();
619    }
620
621    fn t0(&self) -> f64 {
622        self.problem.t0.to_f64().unwrap()
623    }
624
625    fn set_h0(&mut self, h0: f64) {
626        self.problem.h0 = M::T::from_f64(h0).unwrap();
627    }
628
629    fn h0(&self) -> f64 {
630        self.problem.h0.to_f64().unwrap()
631    }
632
633    fn set_integrate_out(&mut self, integrate_out: bool) {
634        self.problem.integrate_out = integrate_out;
635    }
636
637    fn integrate_out(&self) -> bool {
638        self.problem.integrate_out
639    }
640
641    fn set_sens_rtol(&mut self, sens_rtol: Option<f64>) {
642        self.problem.sens_rtol = sens_rtol.map(|value| M::T::from_f64(value).unwrap());
643    }
644
645    fn sens_rtol(&self) -> Option<f64> {
646        self.problem.sens_rtol.map(|value| value.to_f64().unwrap())
647    }
648
649    fn set_sens_atol(&mut self, sens_atol: Option<f64>) {
650        self.problem.sens_atol = sens_atol.map(|value| {
651            M::V::from_element(
652                self.problem.eqn.nstates(),
653                M::T::from_f64(value).unwrap(),
654                M::C::default(),
655            )
656        });
657    }
658
659    fn sens_atol(&self) -> Option<f64> {
660        self.problem
661            .sens_atol
662            .as_ref()
663            .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
664    }
665
666    fn set_out_rtol(&mut self, out_rtol: Option<f64>) {
667        self.problem.out_rtol = out_rtol.map(|value| M::T::from_f64(value).unwrap());
668    }
669
670    fn out_rtol(&self) -> Option<f64> {
671        self.problem.out_rtol.map(|value| value.to_f64().unwrap())
672    }
673
674    fn set_out_atol(&mut self, out_atol: Option<f64>) {
675        self.problem.out_atol = out_atol.map(|value| {
676            let len = if self.problem.eqn.out().is_some() {
677                self.problem.eqn.nout()
678            } else {
679                self.problem.eqn.nstates()
680            };
681            M::V::from_element(len, M::T::from_f64(value).unwrap(), M::C::default())
682        });
683    }
684
685    fn out_atol(&self) -> Option<f64> {
686        self.problem
687            .out_atol
688            .as_ref()
689            .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
690    }
691
692    fn set_param_rtol(&mut self, param_rtol: Option<f64>) {
693        self.problem.param_rtol = param_rtol.map(|value| M::T::from_f64(value).unwrap());
694    }
695
696    fn param_rtol(&self) -> Option<f64> {
697        self.problem.param_rtol.map(|value| value.to_f64().unwrap())
698    }
699
700    fn set_param_atol(&mut self, param_atol: Option<f64>) {
701        self.problem.param_atol = param_atol.map(|value| {
702            M::V::from_element(
703                self.problem.eqn.nparams(),
704                M::T::from_f64(value).unwrap(),
705                M::C::default(),
706            )
707        });
708    }
709
710    fn param_atol(&self) -> Option<f64> {
711        self.problem
712            .param_atol
713            .as_ref()
714            .and_then(|value| (value.len() > 0).then(|| value.get_index(0).to_f64().unwrap()))
715    }
716
717    generate_ic_option_accessors! {
718        use_linesearch: bool,
719        max_linesearch_iterations: usize,
720        max_newton_iterations: usize,
721        max_linear_solver_setups: usize,
722        step_reduction_factor: f64,
723        armijo_constant: f64
724    }
725
726    generate_ode_option_accessors! {
727        max_nonlinear_solver_iterations: usize,
728        max_error_test_failures: usize,
729        min_timestep: f64,
730        update_jacobian_after_steps: usize,
731        update_rhs_jacobian_after_steps: usize,
732        threshold_to_update_jacobian: f64,
733        threshold_to_update_rhs_jacobian: f64
734    }
735
736    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError> {
737        self.setup_problem(params)?;
738        let n = self.problem.eqn.nstates();
739        let mut y0 = M::V::zeros(n, M::C::default());
740        let t0 = self.problem.t0;
741        self.problem.eqn.init().call_inplace(t0, &mut y0);
742        Ok((*y0.inner()).clone().to_host_array())
743    }
744
745    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError> {
746        self.setup_problem(params)?;
747        let n = self.problem.eqn.nstates();
748        let y = y
749            .iter()
750            .map(|&x| M::T::from_f64(x).unwrap())
751            .collect::<Vec<_>>();
752        let y_vec = M::V::from_slice(&y, M::C::default());
753        let mut dydt = M::V::zeros(n, M::C::default());
754        self.problem
755            .eqn
756            .rhs()
757            .call_inplace(&y_vec, M::T::from_f64(t).unwrap(), &mut dydt);
758        Ok((*dydt.inner()).clone().to_host_array())
759    }
760
761    fn rhs_jac_mul(
762        &mut self,
763
764        params: &[f64],
765        t: f64,
766        y: &[f64],
767        v: &[f64],
768    ) -> Result<HostArray, DiffsolRtError> {
769        self.setup_problem(params)?;
770        let n = self.problem.eqn.nstates();
771        let y = y
772            .iter()
773            .map(|&x| M::T::from_f64(x).unwrap())
774            .collect::<Vec<_>>();
775        let v = v
776            .iter()
777            .map(|&x| M::T::from_f64(x).unwrap())
778            .collect::<Vec<_>>();
779        let y_vec = M::V::from_slice(&y, M::C::default());
780        let v_vec = M::V::from_slice(&v, M::C::default());
781        let mut dydt = M::V::zeros(n, M::C::default());
782        self.problem.eqn.rhs().jac_mul_inplace(
783            &y_vec,
784            M::T::from_f64(t).unwrap(),
785            &v_vec,
786            &mut dydt,
787        );
788        Ok((*dydt.inner()).clone().to_host_array())
789    }
790
791    fn solve(
792        &mut self,
793        method: OdeSolverType,
794        linear_solver: LinearSolverType,
795        params: &[f64],
796        final_time: f64,
797    ) -> SolveResult {
798        self.check(linear_solver)?;
799        self.setup_problem(params)?;
800        let final_time = M::T::from_f64(final_time).unwrap();
801        let soln = match linear_solver {
802            LinearSolverType::Default => {
803                method.solve::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time)
804            }
805            LinearSolverType::Lu => {
806                method.solve::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time)
807            }
808            LinearSolverType::Klu => {
809                method.solve::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time)
810            }
811        };
812        Ok(Box::new(soln?))
813    }
814
815    fn solve_dense(
816        &mut self,
817        method: OdeSolverType,
818        linear_solver: LinearSolverType,
819        params: &[f64],
820        t_eval: &[f64],
821    ) -> SolveResult {
822        self.check(linear_solver)?;
823        self.setup_problem(params)?;
824
825        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
826        let soln =
827            match linear_solver {
828                LinearSolverType::Default => method
829                    .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
830                LinearSolverType::Lu => method
831                    .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
832                LinearSolverType::Klu => method
833                    .solve_dense::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
834            };
835        Ok(Box::new(soln?))
836    }
837
838    fn solve_fwd_sens(
839        &mut self,
840        method: OdeSolverType,
841        linear_solver: LinearSolverType,
842        params: &[f64],
843        t_eval: &[f64],
844    ) -> SolveResult {
845        self.check(linear_solver)?;
846        self.setup_problem(params)?;
847
848        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
849        let soln = match linear_solver {
850            LinearSolverType::Default => {
851                method.solve_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval)
852            }
853            LinearSolverType::Lu => method
854                .solve_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
855            LinearSolverType::Klu => method
856                .solve_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
857        };
858        Ok(Box::new(soln?))
859    }
860
861    fn solve_continuous_adjoint(
862        &mut self,
863        method: OdeSolverType,
864        linear_solver: LinearSolverType,
865        params: &[f64],
866        final_time: f64,
867    ) -> Result<(HostArray, HostArray), DiffsolRtError> {
868        self.check(linear_solver)?;
869        self.setup_problem(params)?;
870
871        let final_time = M::T::from_f64(final_time).unwrap();
872        let integrate_out = self.problem.integrate_out;
873        self.problem.integrate_out = true;
874        let result = match linear_solver {
875            LinearSolverType::Default => method
876                .solve_continuous_adjoint::<M, CG, <M as DefaultSolver>::LS>(
877                    &mut self.problem,
878                    final_time,
879                ),
880            LinearSolverType::Lu => method
881                .solve_continuous_adjoint::<M, CG, <M as LuValidator<M>>::LS>(
882                    &mut self.problem,
883                    final_time,
884                ),
885            LinearSolverType::Klu => method
886                .solve_continuous_adjoint::<M, CG, <M as KluValidator<M>>::LS>(
887                    &mut self.problem,
888                    final_time,
889                ),
890        };
891        self.problem.integrate_out = integrate_out;
892        let (integral, gradient) = result?;
893        Ok((
894            (*integral.inner()).clone().to_host_array(),
895            gradient.to_host_array(),
896        ))
897    }
898
899    fn solve_adjoint_fwd(
900        &mut self,
901        method: OdeSolverType,
902        linear_solver: LinearSolverType,
903        params: &[f64],
904        t_eval: &[f64],
905    ) -> Result<(Box<dyn Solution>, AdjointCheckpointWrapper), DiffsolRtError> {
906        self.check(linear_solver)?;
907        self.setup_problem(params)?;
908
909        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
910        let (soln, checkpoint) = match linear_solver {
911            LinearSolverType::Default => method
912                .solve_adjoint_fwd::<M, CG, <M as DefaultSolver>::LS>(
913                    &mut self.problem,
914                    &t_eval,
915                    params,
916                    linear_solver,
917                ),
918            LinearSolverType::Lu => method.solve_adjoint_fwd::<M, CG, <M as LuValidator<M>>::LS>(
919                &mut self.problem,
920                &t_eval,
921                params,
922                linear_solver,
923            ),
924            LinearSolverType::Klu => method.solve_adjoint_fwd::<M, CG, <M as KluValidator<M>>::LS>(
925                &mut self.problem,
926                &t_eval,
927                params,
928                linear_solver,
929            ),
930        }?;
931        Ok((Box::new(soln), AdjointCheckpointWrapper::new(checkpoint)))
932    }
933
934    fn solve_adjoint_bkwd(
935        &mut self,
936        method: OdeSolverType,
937        linear_solver: LinearSolverType,
938        checkpoint: &AdjointCheckpointWrapper,
939        t_eval: &[f64],
940        dgdu_eval: HostArray,
941    ) -> Result<HostArray, DiffsolRtError> {
942        self.check(linear_solver)?;
943        let checkpoint = checkpoint.guard()?;
944        self.setup_problem(checkpoint.params())?;
945
946        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
947        let dgdu_eval = host_array_to_dense_matrix::<M>(dgdu_eval)?;
948        if dgdu_eval.nrows() != self.problem.eqn.nout() {
949            return Err(DiffsolError::Other(format!(
950                "Expected dgdu_eval to have {} rows, got {}",
951                self.problem.eqn.nout(),
952                dgdu_eval.nrows()
953            ))
954            .into());
955        }
956        if dgdu_eval.ncols() != t_eval.len() {
957            return Err(DiffsolError::Other(format!(
958                "Expected dgdu_eval to have {} columns, got {}",
959                t_eval.len(),
960                dgdu_eval.ncols()
961            ))
962            .into());
963        }
964
965        let gradient = match linear_solver {
966            LinearSolverType::Default => method
967                .solve_adjoint_bkwd::<M, CG, <M as DefaultSolver>::LS>(
968                    &mut self.problem,
969                    checkpoint.as_ref(),
970                    &dgdu_eval,
971                    &t_eval,
972                ),
973            LinearSolverType::Lu => method.solve_adjoint_bkwd::<M, CG, <M as LuValidator<M>>::LS>(
974                &mut self.problem,
975                checkpoint.as_ref(),
976                &dgdu_eval,
977                &t_eval,
978            ),
979            LinearSolverType::Klu => method
980                .solve_adjoint_bkwd::<M, CG, <M as KluValidator<M>>::LS>(
981                    &mut self.problem,
982                    checkpoint.as_ref(),
983                    &dgdu_eval,
984                    &t_eval,
985                ),
986        }?;
987        Ok(gradient.to_host_array())
988    }
989}
990
991fn host_array_to_dense_matrix<M>(
992    array: HostArray,
993) -> Result<<M::V as DefaultDenseMatrix>::M, DiffsolRtError>
994where
995    M: MatrixHost<T: Scalar>,
996    M::V: DefaultDenseMatrix,
997{
998    let view = array.as_array::<M::T>()?;
999    let mut values = Vec::with_capacity(view.nrows() * view.ncols());
1000    for col in 0..view.ncols() {
1001        for row in 0..view.nrows() {
1002            values.push(view[(row, col)]);
1003        }
1004    }
1005    Ok(<M::V as DefaultDenseMatrix>::M::from_vec(
1006        view.nrows(),
1007        view.ncols(),
1008        values,
1009        M::C::default(),
1010    ))
1011}
1012
// Tests for `GenericSolve` and the JIT solve factory. Every test compiles DiffSL code, so the
// module is only built when at least one JIT backend (Cranelift or LLVM) is enabled.
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    use diffsol::MatrixCommon;
    use diffsol::{
        CodegenModuleCompile, CodegenModuleJit, Context, OdeBuilder, OdeEquations, Vector,
    };

    #[cfg(feature = "diffsl-llvm")]
    use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr};
    use crate::{
        host_array::FromHostArray,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{
            assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
            logistic_state, LOGISTIC_X0,
        },
    };

    use super::{solve_factory_with_jit_backend, GenericSolve, Solve, SolveSerialization};

    /// Build a `GenericSolve` over a nalgebra dense matrix with scalar `T`, using the logistic
    /// DiffSL model compiled with backend `CG`.
    fn make_generic_solve<T, CG>() -> GenericSolve<diffsol::NalgebraMat<T>, CG>
    where
        T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<T>>,
        diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
        <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
        <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
    {
        let problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
            .build_from_diffsl::<CG>(logistic_diffsl_code())
            .unwrap();
        GenericSolve { problem }
    }

    /// Assert that `solve_factory_with_jit_backend::<CG>` succeeds for every matrix type and
    /// scalar type combination when given the logistic model.
    fn assert_factory_supports_all_matrix_and_scalar_types<CG>()
    where
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<f32>>
            + SolveSerialization<diffsol::NalgebraMat<f64>>
            + SolveSerialization<diffsol::FaerMat<f32>>
            + SolveSerialization<diffsol::FaerMat<f64>>
            + SolveSerialization<diffsol::FaerSparseMat<f32>>
            + SolveSerialization<diffsol::FaerSparseMat<f64>>,
    {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            for scalar_type in [ScalarType::F32, ScalarType::F64] {
                assert!(solve_factory_with_jit_backend::<CG>(
                    logistic_diffsl_code(),
                    matrix_type,
                    scalar_type,
                )
                .is_ok());
            }
        }
    }

    /// Exercise the metadata getters, linear-solver checks, option setters/getters and the
    /// `y0`/`rhs`/`rhs_jac_mul` helpers. `tol` is the tolerance passed to `assert_close`.
    fn assert_solve_metadata_and_helpers<T, CG>(tol: f64)
    where
        T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<T>>,
        diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
        <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
        <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
        GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
    {
        let mut solve = make_generic_solve::<T, CG>();
        assert_eq!(solve.matrix_type(), MatrixType::NalgebraDense);
        assert_eq!(solve.nstates(), 1);
        assert_eq!(solve.nparams(), 1);
        assert_eq!(solve.nout(), 1);
        assert!(!solve.has_stop());
        assert!(solve.check(LinearSolverType::Default).is_ok());
        assert!(solve.check(LinearSolverType::Lu).is_ok());
        // KLU is expected to be rejected for the nalgebra dense matrix type.
        assert!(solve.check(LinearSolverType::Klu).is_err());

        solve.set_atol(1e-5);
        solve.set_rtol(1e-4);
        assert_close(solve.atol(), 1e-5, tol, "solve atol");
        assert_close(solve.rtol(), 1e-4, tol, "solve rtol");
        solve.set_t0(0.125);
        solve.set_h0(0.25);
        solve.set_integrate_out(true);
        assert_close(solve.t0(), 0.125, tol, "solve t0");
        assert_close(solve.h0(), 0.25, tol, "solve h0");
        assert!(solve.integrate_out());

        solve.set_sens_rtol(Some(1e-3));
        solve.set_sens_atol(Some(1e-4));
        solve.set_out_rtol(Some(2e-3));
        solve.set_out_atol(Some(2e-4));
        solve.set_param_rtol(Some(3e-3));
        solve.set_param_atol(Some(3e-4));
        assert_close(solve.sens_rtol().unwrap(), 1e-3, tol, "solve sens rtol");
        assert_close(solve.sens_atol().unwrap(), 1e-4, tol, "solve sens atol");
        assert_close(solve.out_rtol().unwrap(), 2e-3, tol, "solve out rtol");
        assert_close(solve.out_atol().unwrap(), 2e-4, tol, "solve out atol");
        assert_close(solve.param_rtol().unwrap(), 3e-3, tol, "solve param rtol");
        assert_close(solve.param_atol().unwrap(), 3e-4, tol, "solve param atol");

        // Optional tolerances must round-trip back to `None`.
        solve.set_sens_rtol(None);
        solve.set_sens_atol(None);
        solve.set_out_rtol(None);
        solve.set_out_atol(None);
        solve.set_param_rtol(None);
        solve.set_param_atol(None);
        assert_eq!(solve.sens_rtol(), None);
        assert_eq!(solve.sens_atol(), None);
        assert_eq!(solve.out_rtol(), None);
        assert_eq!(solve.out_atol(), None);
        assert_eq!(solve.param_rtol(), None);
        assert_eq!(solve.param_atol(), None);

        let y0 = Vec::<f64>::from_host_array(solve.y0(&[2.0]).unwrap()).unwrap();
        assert_eq!(y0.len(), 1);
        assert_close(y0[0], LOGISTIC_X0, tol, "y0[0]");

        let rhs = Vec::<f64>::from_host_array(solve.rhs(&[2.0], 0.0, &[0.25]).unwrap()).unwrap();
        assert_eq!(rhs.len(), 1);
        assert_close(rhs[0], 0.375, tol, "solve rhs");

        let jac_mul =
            Vec::<f64>::from_host_array(solve.rhs_jac_mul(&[2.0], 0.0, &[0.25], &[3.0]).unwrap())
                .unwrap();
        assert_eq!(jac_mul.len(), 1);
        assert_close(jac_mul[0], 3.0, tol, "solve rhs jac mul");
    }

    /// Exercise `solve` and `solve_dense` on the logistic and hybrid logistic models, and
    /// check the parameter-count mismatch error. Each check uses its own label so a failure
    /// identifies which model and path broke.
    fn assert_solve_runtime_paths<T, CG>(tol: f64)
    where
        T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<T>>,
        diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
        <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
        <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
        GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
    {
        let mut solve = make_generic_solve::<T, CG>();
        let soln = solve
            .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 1.0)
            .unwrap();
        let ts = Vec::<f64>::from_host_array(soln.get_ts()).unwrap();
        let ys = Vec::<Vec<f64>>::from_host_array(soln.get_ys()).unwrap();
        assert_close(*ts.last().unwrap(), 1.0, tol, "solve final time");
        assert_close(
            ys[0][ts.len() - 1],
            logistic_state(LOGISTIC_X0, 2.0, 1.0),
            tol,
            "solve final value",
        );

        let mut solve = make_generic_solve::<T, CG>();
        let dense = solve
            .solve_dense(
                OdeSolverType::Tsit45,
                LinearSolverType::Lu,
                &[2.0],
                &[0.25, 0.5, 1.0],
            )
            .unwrap();
        let ts = Vec::<f64>::from_host_array(dense.get_ts()).unwrap();
        let ys = Vec::<Vec<f64>>::from_host_array(dense.get_ys()).unwrap();
        assert_eq!(ts, vec![0.25, 0.5, 1.0]);
        for (i, &t) in ts.iter().enumerate() {
            assert_close(
                ys[0][i],
                logistic_state(LOGISTIC_X0, 2.0, t),
                tol,
                &format!("solve_dense[{i}]"),
            );
        }

        let mut solve = make_generic_solve::<T, CG>();
        let err = match solve.solve(OdeSolverType::Bdf, LinearSolverType::Default, &[], 1.0) {
            Ok(_) => panic!("expected parameter count mismatch"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut hybrid_solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid = hybrid_solve
            .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 2.0)
            .unwrap();
        let hybrid_ts = Vec::<f64>::from_host_array(hybrid.get_ts()).unwrap();
        let hybrid_ys = Vec::<Vec<f64>>::from_host_array(hybrid.get_ys()).unwrap();
        assert_close(
            *hybrid_ts.last().unwrap(),
            2.0,
            tol,
            "hybrid solve final time",
        );
        assert_close(
            hybrid_ys[0][hybrid_ts.len() - 1],
            hybrid_logistic_state(2.0, 2.0),
            tol,
            "hybrid solve final value",
        );

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut hybrid_solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid_dense = hybrid_solve
            .solve_dense(
                OdeSolverType::Tsit45,
                LinearSolverType::Lu,
                &[2.0],
                &[0.5, 1.0, 1.5, 2.0],
            )
            .unwrap();
        let hybrid_dense_ts = Vec::<f64>::from_host_array(hybrid_dense.get_ts()).unwrap();
        let hybrid_dense_ys = Vec::<Vec<f64>>::from_host_array(hybrid_dense.get_ys()).unwrap();
        assert_eq!(hybrid_dense_ts, vec![0.5, 1.0, 1.5, 2.0]);
        for (i, &t) in hybrid_dense_ts.iter().enumerate() {
            assert_close(
                hybrid_dense_ys[0][i],
                hybrid_logistic_state(2.0, t),
                tol,
                &format!("hybrid solve_dense[{i}]"),
            );
        }
    }

    /// Compare forward sensitivities of the logistic and hybrid logistic models against
    /// their analytic derivatives (`logistic_state_dr`, `hybrid_logistic_state_dr`).
    #[cfg(feature = "diffsl-llvm")]
    fn assert_solve_sensitivity_paths() {
        let t_eval = [0.25, 0.5, 1.0];

        let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
        let sens = solve
            .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
            .unwrap();
        let sens_values = sens.get_sens();
        assert_eq!(sens_values.len(), 1);
        let sens_matrix =
            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_matrix[0][i],
                logistic_state_dr(LOGISTIC_X0, 2.0, t),
                5e-4,
                &format!("solve_fwd_sens[{i}]"),
            );
        }

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
            .build_from_diffsl::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid_sens = solve
            .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
            .unwrap();
        let sens_values = hybrid_sens.get_sens();
        // Same single-parameter model shape as above: exactly one sensitivity block expected.
        assert_eq!(sens_values.len(), 1);
        let sens_matrix =
            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_matrix[0][i],
                hybrid_logistic_state_dr(2.0, t),
                5e-4,
                &format!("hybrid solve_fwd_sens[{i}]"),
            );
        }
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_cranelift() {
        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::CraneliftJitModule>();
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn solve_trait_helpers_and_runtime_paths_for_cranelift() {
        assert_solve_metadata_and_helpers::<f64, diffsol::CraneliftJitModule>(1e-12);
        assert_solve_runtime_paths::<f64, diffsol::CraneliftJitModule>(5e-4);
        assert_solve_metadata_and_helpers::<f32, diffsol::CraneliftJitModule>(5e-3);
        assert_solve_runtime_paths::<f32, diffsol::CraneliftJitModule>(5e-3);
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_llvm() {
        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::LlvmModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_trait_helpers_and_runtime_paths_for_llvm() {
        assert_solve_metadata_and_helpers::<f64, diffsol::LlvmModule>(1e-12);
        assert_solve_runtime_paths::<f64, diffsol::LlvmModule>(5e-4);
        assert_solve_metadata_and_helpers::<f32, diffsol::LlvmModule>(5e-3);
        assert_solve_runtime_paths::<f32, diffsol::LlvmModule>(5e-3);
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_trait_sensitivity_paths_for_llvm() {
        assert_solve_sensitivity_paths();
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn setup_problem_validates_parameter_count_for_cranelift() {
        let mut solve = make_generic_solve::<f64, diffsol::CraneliftJitModule>();
        let err = solve.setup_problem(&[]).unwrap_err();
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        // A correctly sized parameter slice must be written into the equations.
        solve.setup_problem(&[2.0]).unwrap();
        let mut params = solve.problem.context().vector_zeros(1);
        solve.problem.eqn.get_params(&mut params);
        assert_eq!(params.get_index(0), 2.0);
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn setup_problem_validates_parameter_count_for_llvm() {
        let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
        let err = solve.setup_problem(&[]).unwrap_err();
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        // A correctly sized parameter slice must be written into the equations.
        solve.setup_problem(&[2.0]).unwrap();
        let mut params = solve.problem.context().vector_zeros(1);
        solve.problem.eqn.get_params(&mut params);
        assert_eq!(params.get_index(0), 2.0);
    }
}