Skip to main content

diffsol_c/
solve.rs

// Dispatch solver configurations selected at runtime by the host to concrete,
// statically typed diffsol solver types in Rust.
3
4#[cfg(feature = "diffsl-external-dynamic")]
5use std::path::PathBuf;
6
7use crate::error::DiffsolRtError;
8use crate::host_array::HostArray;
9use crate::host_array::ToHostArray;
10#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
11use crate::jit::JitBackendType;
12#[cfg(feature = "external")]
13use crate::scalar_type::ExternalScalar;
14use crate::scalar_type::Scalar;
15use crate::scalar_type::ScalarType;
16use crate::{
17    generate_ic_option_accessors, generate_ode_option_accessors, generate_option_accessors,
18    option_value_from_store, option_value_to_store,
19};
20use crate::{generate_trait_ic_option_accessors, generate_trait_ode_option_accessors};
21use diffsol::ObjectModule;
22use diffsol::OdeBuilder;
23use diffsol::{
24    error::DiffsolError,
25    matrix::{MatrixHost, MatrixRef},
26    CodegenModule, ConstantOp, DefaultDenseMatrix, DefaultSolver, DiffSl, MatrixCommon,
27    NonLinearOp, NonLinearOpJacobian, OdeEquations, OdeSolverProblem, Op, Vector, VectorCommon,
28    VectorHost, VectorRef,
29};
30#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
31use diffsol::{CodegenModuleCompile, CodegenModuleJit};
32use num_traits::{FromPrimitive, ToPrimitive}; // for from_f64 and to_f64
33use paste::paste;
34
35use crate::solve_serialization::{unsupported_serialization_error, SolveSerialization};
36use crate::{
37    linear_solver_type::LinearSolverType, matrix_type::MatrixType, ode_solver_type::OdeSolverType,
38};
39use crate::{
40    matrix_type::MatrixKind,
41    valid_linear_solver::{validate_linear_solver, KluValidator, LuValidator},
42};
43
// Each matrix type implements `Solve` as a bridge between diffsol and the host.
45
46use crate::solution::Solution;
47pub(crate) type SolveResult = Result<Box<dyn Solution>, DiffsolRtError>;
48
49pub(crate) trait Solve {
50    fn matrix_type(&self) -> MatrixType;
51    fn nstates(&self) -> usize;
52    fn nparams(&self) -> usize;
53    fn nout(&self) -> usize;
54    fn has_stop(&self) -> bool;
55
56    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError>;
57
58    fn rhs_jac_mul(
59        &mut self,
60        params: &[f64],
61        t: f64,
62        y: &[f64],
63        v: &[f64],
64    ) -> Result<HostArray, DiffsolRtError>;
65
66    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError>;
67
68    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError>;
69    fn set_rtol(&mut self, rtol: f64);
70    fn rtol(&self) -> f64;
71    fn set_atol(&mut self, atol: f64);
72    fn atol(&self) -> f64;
73    fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
74        unsupported_serialization_error(
75            "ODE serialization is only supported for JIT-backed solvers",
76        )
77    }
78
79    // New API: solution object support
80    fn solve(
81        &mut self,
82        method: OdeSolverType,
83        linear_solver: LinearSolverType,
84        params: &[f64],
85        final_time: f64,
86    ) -> SolveResult;
87
88    fn solve_hybrid(
89        &mut self,
90        method: OdeSolverType,
91        linear_solver: LinearSolverType,
92        params: &[f64],
93        final_time: f64,
94    ) -> SolveResult;
95
96    fn solve_dense(
97        &mut self,
98        method: OdeSolverType,
99        linear_solver: LinearSolverType,
100        params: &[f64],
101        t_eval: &[f64],
102    ) -> SolveResult;
103
104    fn solve_hybrid_dense(
105        &mut self,
106        method: OdeSolverType,
107        linear_solver: LinearSolverType,
108        params: &[f64],
109        t_eval: &[f64],
110    ) -> SolveResult;
111
112    fn solve_fwd_sens(
113        &mut self,
114        method: OdeSolverType,
115        linear_solver: LinearSolverType,
116        params: &[f64],
117        t_eval: &[f64],
118    ) -> SolveResult;
119
120    fn solve_hybrid_fwd_sens(
121        &mut self,
122        method: OdeSolverType,
123        linear_solver: LinearSolverType,
124        params: &[f64],
125        t_eval: &[f64],
126    ) -> SolveResult;
127
128    #[allow(clippy::type_complexity)]
129    #[allow(clippy::too_many_arguments)]
130    fn solve_sum_squares_adj(
131        &mut self,
132        method: OdeSolverType,
133        linear_solver: LinearSolverType,
134        backwards_method: OdeSolverType,
135        backwards_linear_solver: LinearSolverType,
136        params: &[f64],
137        data: HostArray,
138        t_eval: &[f64],
139    ) -> Result<(f64, HostArray), DiffsolRtError>;
140
141    generate_trait_ic_option_accessors! {
142        use_linesearch: bool,
143        max_linesearch_iterations: usize,
144        max_newton_iterations: usize,
145        max_linear_solver_setups: usize,
146        step_reduction_factor: f64,
147        armijo_constant: f64
148    }
149    generate_trait_ode_option_accessors! {
150        max_nonlinear_solver_iterations: usize,
151        max_error_test_failures: usize,
152        min_timestep: f64,
153        update_jacobian_after_steps: usize,
154        update_rhs_jacobian_after_steps: usize,
155        threshold_to_update_jacobian: f64,
156        threshold_to_update_rhs_jacobian: f64
157    }
158}
// Crate-internal factories: build a boxed `Solve` for a runtime-selected matrix type and scalar type.
/// Builds a [`Solve`] over an external (non-dynamic) DiffSL module.
///
/// The three dependency lists are forwarded unchanged to `GenericSolve::from_external`
/// with sensitivities disabled. A scalar type is only available when its
/// `diffsl-external-f32` / `diffsl-external-f64` feature is compiled in. Any other
/// combination yields an "Unsupported scalar type" error naming the matrix type.
#[cfg(feature = "external")]
pub(crate) fn solve_factory_external(
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        // --- Nalgebra dense ---
        #[cfg(feature = "diffsl-external-f32")]
        (MatrixType::NalgebraDense, ScalarType::F32) => Box::new(GenericSolve::<
            diffsol::NalgebraMat<f32>,
            diffsl::ExternalModule<f32>,
        >::from_external(
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            false,
        )?),
        #[cfg(feature = "diffsl-external-f64")]
        (MatrixType::NalgebraDense, ScalarType::F64) => Box::new(GenericSolve::<
            diffsol::NalgebraMat<f64>,
            diffsl::ExternalModule<f64>,
        >::from_external(
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            false,
        )?),
        (MatrixType::NalgebraDense, _) => {
            return Err(DiffsolRtError::from(DiffsolError::Other(
                "Unsupported scalar type for NalgebraDense".to_string(),
            )));
        }

        // --- Faer dense ---
        #[cfg(feature = "diffsl-external-f32")]
        (MatrixType::FaerDense, ScalarType::F32) => Box::new(GenericSolve::<
            diffsol::FaerMat<f32>,
            diffsl::ExternalModule<f32>,
        >::from_external(
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            false,
        )?),
        #[cfg(feature = "diffsl-external-f64")]
        (MatrixType::FaerDense, ScalarType::F64) => Box::new(GenericSolve::<
            diffsol::FaerMat<f64>,
            diffsl::ExternalModule<f64>,
        >::from_external(
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            false,
        )?),
        (MatrixType::FaerDense, _) => {
            return Err(DiffsolRtError::from(DiffsolError::Other(
                "Unsupported scalar type for FaerDense".to_string(),
            )));
        }

        // --- Faer sparse ---
        #[cfg(feature = "diffsl-external-f32")]
        (MatrixType::FaerSparse, ScalarType::F32) => Box::new(GenericSolve::<
            diffsol::FaerSparseMat<f32>,
            diffsl::ExternalModule<f32>,
        >::from_external(
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            false,
        )?),
        #[cfg(feature = "diffsl-external-f64")]
        (MatrixType::FaerSparse, ScalarType::F64) => Box::new(GenericSolve::<
            diffsol::FaerSparseMat<f64>,
            diffsl::ExternalModule<f64>,
        >::from_external(
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            false,
        )?),
        (MatrixType::FaerSparse, _) => {
            return Err(DiffsolRtError::from(DiffsolError::Other(
                "Unsupported scalar type for FaerSparse".to_string(),
            )));
        }
    };
    Ok(solve)
}
235
/// Builds a [`Solve`] over a DiffSL module loaded dynamically from `path`.
///
/// Every (matrix, scalar) combination is supported. The dependency lists are forwarded
/// unchanged to `GenericSolve::from_external_dynamic` with sensitivities disabled.
#[cfg(feature = "diffsl-external-dynamic")]
pub(crate) fn solve_factory_external_dynamic(
    path: PathBuf,
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Expands to the boxed constructor call for one (matrix, scalar) pair. Only the
    // selected match arm runs, so each expansion may take ownership of the inputs.
    macro_rules! dynamic_solve {
        ($mat:ty, $scalar:ty) => {
            Box::new(
                GenericSolve::<$mat, diffsl::ExternalDynModule<$scalar>>::from_external_dynamic(
                    path,
                    rhs_state_deps,
                    rhs_input_deps,
                    mass_state_deps,
                    false,
                )?,
            )
        };
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => {
            dynamic_solve!(diffsol::NalgebraMat<f32>, f32)
        }
        (MatrixType::NalgebraDense, ScalarType::F64) => {
            dynamic_solve!(diffsol::NalgebraMat<f64>, f64)
        }
        (MatrixType::FaerDense, ScalarType::F32) => {
            dynamic_solve!(diffsol::FaerMat<f32>, f32)
        }
        (MatrixType::FaerDense, ScalarType::F64) => {
            dynamic_solve!(diffsol::FaerMat<f64>, f64)
        }
        (MatrixType::FaerSparse, ScalarType::F32) => {
            dynamic_solve!(diffsol::FaerSparseMat<f32>, f32)
        }
        (MatrixType::FaerSparse, ScalarType::F64) => {
            dynamic_solve!(diffsol::FaerSparseMat<f64>, f64)
        }
    };
    Ok(solve)
}
315
/// Builds a JIT-compiled [`Solve`] from DiffSL source `code` using `jit_backend`.
///
/// Only backends whose cargo feature (`diffsl-cranelift` / `diffsl-llvm`) is enabled can
/// be selected. The matrix/scalar dispatch is delegated to `solve_factory_with_jit_backend`.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
pub(crate) fn solve_factory_jit(
    code: &str,
    jit_backend: JitBackendType,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Pick the backend-specific instantiation first, then call it once.
    let build = match jit_backend {
        #[cfg(feature = "diffsl-cranelift")]
        JitBackendType::Cranelift => solve_factory_with_jit_backend::<diffsol::CraneliftJitModule>,
        #[cfg(feature = "diffsl-llvm")]
        JitBackendType::Llvm => solve_factory_with_jit_backend::<diffsol::LlvmModule>,
    };
    build(code, matrix_type, scalar_type)
}
336
/// Compiles DiffSL `code` with code-generation backend `CG` into a [`Solve`] for the
/// requested matrix and scalar types.
///
/// The `SolveSerialization` bounds let every resulting `GenericSolve` implement
/// [`Solve`], whichever matrix type is chosen at runtime.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn solve_factory_with_jit_backend<CG>(
    code: &str,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError>
where
    CG: CodegenModule
        + CodegenModuleJit
        + CodegenModuleCompile
        + SolveSerialization<diffsol::NalgebraMat<f32>>
        + SolveSerialization<diffsol::NalgebraMat<f64>>
        + SolveSerialization<diffsol::FaerMat<f32>>
        + SolveSerialization<diffsol::FaerMat<f64>>
        + SolveSerialization<diffsol::FaerSparseMat<f32>>
        + SolveSerialization<diffsol::FaerSparseMat<f64>>,
{
    // Builds the problem for matrix type `$mat` from `code` and wraps it.
    macro_rules! jit_solve {
        ($mat:ty) => {{
            let problem = OdeBuilder::<$mat>::new().build_from_diffsl::<CG>(code)?;
            Box::new(GenericSolve { problem })
        }};
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => jit_solve!(diffsol::NalgebraMat<f32>),
        (MatrixType::NalgebraDense, ScalarType::F64) => jit_solve!(diffsol::NalgebraMat<f64>),
        (MatrixType::FaerDense, ScalarType::F32) => jit_solve!(diffsol::FaerMat<f32>),
        (MatrixType::FaerDense, ScalarType::F64) => jit_solve!(diffsol::FaerMat<f64>),
        (MatrixType::FaerSparse, ScalarType::F32) => jit_solve!(diffsol::FaerSparseMat<f32>),
        (MatrixType::FaerSparse, ScalarType::F64) => jit_solve!(diffsol::FaerSparseMat<f64>),
    };
    Ok(solve)
}
394
395pub(crate) fn solve_factory_from_serialized_diffsl(
396    serialized_diffsl: &[u8],
397    matrix_type: MatrixType,
398    scalar_type: ScalarType,
399) -> Result<Box<dyn Solve>, DiffsolRtError> {
400    let solve: Box<dyn Solve> = match matrix_type {
401        MatrixType::NalgebraDense => match scalar_type {
402            ScalarType::F32 => Box::new(
403                GenericSolve::<diffsol::NalgebraMat<f32>, ObjectModule>::from_serialized_diffsl(
404                    serialized_diffsl,
405                )?,
406            ),
407            ScalarType::F64 => Box::new(
408                GenericSolve::<diffsol::NalgebraMat<f64>, ObjectModule>::from_serialized_diffsl(
409                    serialized_diffsl,
410                )?,
411            ),
412        },
413        MatrixType::FaerDense => match scalar_type {
414            ScalarType::F32 => Box::new(
415                GenericSolve::<diffsol::FaerMat<f32>, ObjectModule>::from_serialized_diffsl(
416                    serialized_diffsl,
417                )?,
418            ),
419            ScalarType::F64 => Box::new(
420                GenericSolve::<diffsol::FaerMat<f64>, ObjectModule>::from_serialized_diffsl(
421                    serialized_diffsl,
422                )?,
423            ),
424        },
425        MatrixType::FaerSparse => match scalar_type {
426            ScalarType::F32 => Box::new(
427                GenericSolve::<diffsol::FaerSparseMat<f32>, ObjectModule>::from_serialized_diffsl(
428                    serialized_diffsl,
429                )?,
430            ),
431            ScalarType::F64 => Box::new(
432                GenericSolve::<diffsol::FaerSparseMat<f64>, ObjectModule>::from_serialized_diffsl(
433                    serialized_diffsl,
434                )?,
435            ),
436        },
437    };
438    Ok(solve)
439}
440
441pub(crate) struct GenericSolve<M, CG>
442where
443    M: MatrixHost<T: Scalar>,
444    M::V: Vector + VectorHost,
445    CG: CodegenModule,
446{
447    problem: OdeSolverProblem<DiffSl<M, CG>>,
448}
449
450impl<M, CG> GenericSolve<M, CG>
451where
452    M: MatrixHost<T: Scalar>,
453    M::V: Vector + VectorHost + DefaultDenseMatrix,
454    CG: CodegenModule,
455{
456    fn from_eqn(eqn: DiffSl<M, CG>) -> Result<Self, DiffsolRtError> {
457        let default_p = vec![0.0; eqn.nparams()];
458        let problem = OdeBuilder::<M>::new().p(default_p).build_from_eqn(eqn)?;
459        Ok(GenericSolve { problem })
460    }
461
462    pub(crate) fn setup_problem(&mut self, params: &[f64]) -> Result<(), DiffsolRtError> {
463        let params: Vec<M::T> = params.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
464        let params = M::V::from_slice(&params, M::C::default());
465
466        // Attempt to set problem from params and config
467        let nparams = self.problem.eqn.nparams();
468        if params.len() == nparams {
469            self.problem.eqn.set_params(&params);
470            Ok(())
471        } else {
472            Err(DiffsolError::Other(format!(
473                "Expecting {} params but got {}",
474                nparams,
475                params.len()
476            ))
477            .into())
478        }
479    }
480
481    pub(crate) fn serialize_eqn(&self) -> Result<Vec<u8>, DiffsolRtError>
482    where
483        DiffSl<M, CG>: serde::Serialize,
484    {
485        serde_json::to_vec(&self.problem.eqn)
486            .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))
487    }
488}
489
#[cfg(feature = "external")]
impl<M> GenericSolve<M, diffsl::ExternalModule<M::T>>
where
    M: MatrixHost<T: ExternalScalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Builds a solver over the external DiffSL module (`diffsl::ExternalModule`).
    ///
    /// Unlike the dynamic variant, no library path is taken. The dependency lists and
    /// `include_sensitivities` are passed straight to `DiffSl::from_external` with a
    /// default context, and the resulting equations go through `from_eqn`.
    pub fn from_external(
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let context = M::C::default();
        DiffSl::<M, diffsl::ExternalModule<M::T>>::from_external(
            context,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )
        .map_err(DiffsolRtError::from)
        .and_then(Self::from_eqn)
    }
}
512
#[cfg(feature = "diffsl-external-dynamic")]
impl<M> GenericSolve<M, diffsl::ExternalDynModule<M::T>>
where
    M: MatrixHost<T: Scalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Builds a solver over a DiffSL module loaded dynamically from `path`.
    ///
    /// `path`, the dependency lists and `include_sensitivities` are passed straight to
    /// `DiffSl::from_external_dynamic` with a default context, and the resulting
    /// equations go through `from_eqn`.
    pub fn from_external_dynamic(
        path: impl Into<PathBuf>,
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let library = path;
        let context = M::C::default();
        DiffSl::<M, diffsl::ExternalDynModule<M::T>>::from_external_dynamic(
            library,
            context,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )
        .map_err(DiffsolRtError::from)
        .and_then(Self::from_eqn)
    }
}
537
538impl<M> GenericSolve<M, ObjectModule>
539where
540    M: MatrixHost<T: Scalar>,
541    M::V: Vector + VectorHost + DefaultDenseMatrix,
542{
543    pub fn from_serialized_diffsl(serialized_diffsl: &[u8]) -> Result<Self, DiffsolRtError> {
544        let eqn = serde_json::from_slice::<DiffSl<M, ObjectModule>>(serialized_diffsl)
545            .map_err(|e| DiffsolRtError::from(DiffsolError::Other(e.to_string())))?;
546        Self::from_eqn(eqn)
547    }
548}
549
550impl<M, CG> Solve for GenericSolve<M, CG>
551where
552    M: MatrixHost<T: Scalar + ToPrimitive>
553        + DefaultSolver
554        + LuValidator<M>
555        + KluValidator<M>
556        + MatrixKind,
557    CG: CodegenModule,
558    for<'b> <<M::V as DefaultDenseMatrix>::M as MatrixCommon>::Inner: ToHostArray<M::T> + Clone,
559    for<'b> <M::V as VectorCommon>::Inner: ToHostArray<M::T> + Clone,
560    M::V: VectorHost + DefaultDenseMatrix + Send + Sync + 'static,
561    <M::V as DefaultDenseMatrix>::M: Send + Sync,
562    for<'b> &'b M::V: VectorRef<M::V>,
563    for<'b> &'b M: MatrixRef<M>,
564    CG: SolveSerialization<M>,
565{
566    fn matrix_type(&self) -> MatrixType {
567        MatrixType::from_diffsol::<M>()
568    }
569
570    fn nstates(&self) -> usize {
571        self.problem.eqn.nstates()
572    }
573
574    fn nparams(&self) -> usize {
575        self.problem.eqn.nparams()
576    }
577
578    fn nout(&self) -> usize {
579        self.problem.eqn.nout()
580    }
581
582    fn has_stop(&self) -> bool {
583        self.problem.eqn.root().is_some()
584    }
585
586    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError> {
587        validate_linear_solver::<M>(linear_solver)
588    }
589
590    fn set_atol(&mut self, atol: f64) {
591        self.problem.atol.fill(M::T::from_f64(atol).unwrap());
592    }
593
594    fn atol(&self) -> f64 {
595        self.problem.atol[0].to_f64().unwrap()
596    }
597
598    fn serialized_diffsl(&self) -> Result<Vec<u8>, DiffsolRtError> {
599        CG::serialized_diffsl(self)
600    }
601
602    fn set_rtol(&mut self, rtol: f64) {
603        self.problem.rtol = M::T::from_f64(rtol).unwrap();
604    }
605
606    fn rtol(&self) -> f64 {
607        self.problem.rtol.to_f64().unwrap()
608    }
609
610    generate_ic_option_accessors! {
611        use_linesearch: bool,
612        max_linesearch_iterations: usize,
613        max_newton_iterations: usize,
614        max_linear_solver_setups: usize,
615        step_reduction_factor: f64,
616        armijo_constant: f64
617    }
618
619    generate_ode_option_accessors! {
620        max_nonlinear_solver_iterations: usize,
621        max_error_test_failures: usize,
622        min_timestep: f64,
623        update_jacobian_after_steps: usize,
624        update_rhs_jacobian_after_steps: usize,
625        threshold_to_update_jacobian: f64,
626        threshold_to_update_rhs_jacobian: f64
627    }
628
629    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError> {
630        self.setup_problem(params)?;
631        let n = self.problem.eqn.nstates();
632        let mut y0 = M::V::zeros(n, M::C::default());
633        let t0 = self.problem.t0;
634        self.problem.eqn.init().call_inplace(t0, &mut y0);
635        Ok((*y0.inner()).clone().to_host_array())
636    }
637
638    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError> {
639        self.setup_problem(params)?;
640        let n = self.problem.eqn.nstates();
641        let y = y
642            .iter()
643            .map(|&x| M::T::from_f64(x).unwrap())
644            .collect::<Vec<_>>();
645        let y_vec = M::V::from_slice(&y, M::C::default());
646        let mut dydt = M::V::zeros(n, M::C::default());
647        self.problem
648            .eqn
649            .rhs()
650            .call_inplace(&y_vec, M::T::from_f64(t).unwrap(), &mut dydt);
651        Ok((*dydt.inner()).clone().to_host_array())
652    }
653
654    fn rhs_jac_mul(
655        &mut self,
656
657        params: &[f64],
658        t: f64,
659        y: &[f64],
660        v: &[f64],
661    ) -> Result<HostArray, DiffsolRtError> {
662        self.setup_problem(params)?;
663        let n = self.problem.eqn.nstates();
664        let y = y
665            .iter()
666            .map(|&x| M::T::from_f64(x).unwrap())
667            .collect::<Vec<_>>();
668        let v = v
669            .iter()
670            .map(|&x| M::T::from_f64(x).unwrap())
671            .collect::<Vec<_>>();
672        let y_vec = M::V::from_slice(&y, M::C::default());
673        let v_vec = M::V::from_slice(&v, M::C::default());
674        let mut dydt = M::V::zeros(n, M::C::default());
675        self.problem.eqn.rhs().jac_mul_inplace(
676            &y_vec,
677            M::T::from_f64(t).unwrap(),
678            &v_vec,
679            &mut dydt,
680        );
681        Ok((*dydt.inner()).clone().to_host_array())
682    }
683
684    fn solve(
685        &mut self,
686        method: OdeSolverType,
687        linear_solver: LinearSolverType,
688        params: &[f64],
689        final_time: f64,
690    ) -> SolveResult {
691        self.check(linear_solver)?;
692        self.setup_problem(params)?;
693        let final_time = M::T::from_f64(final_time).unwrap();
694        let soln = match linear_solver {
695            LinearSolverType::Default => {
696                method.solve::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time)
697            }
698            LinearSolverType::Lu => {
699                method.solve::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time)
700            }
701            LinearSolverType::Klu => {
702                method.solve::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time)
703            }
704        };
705        Ok(Box::new(soln?))
706    }
707
708    fn solve_dense(
709        &mut self,
710        method: OdeSolverType,
711        linear_solver: LinearSolverType,
712        params: &[f64],
713        t_eval: &[f64],
714    ) -> SolveResult {
715        self.check(linear_solver)?;
716        self.setup_problem(params)?;
717
718        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
719        let soln =
720            match linear_solver {
721                LinearSolverType::Default => method
722                    .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
723                LinearSolverType::Lu => method
724                    .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
725                LinearSolverType::Klu => method
726                    .solve_dense::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
727            };
728        Ok(Box::new(soln?))
729    }
730
731    fn solve_hybrid(
732        &mut self,
733        method: OdeSolverType,
734        linear_solver: LinearSolverType,
735        params: &[f64],
736        final_time: f64,
737    ) -> SolveResult {
738        self.check(linear_solver)?;
739        self.setup_problem(params)?;
740        let final_time = M::T::from_f64(final_time).unwrap();
741        let soln = match linear_solver {
742            LinearSolverType::Default => method
743                .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time),
744            LinearSolverType::Lu => method
745                .solve_hybrid::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time),
746            LinearSolverType::Klu => method
747                .solve_hybrid::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time),
748        };
749        Ok(Box::new(soln?))
750    }
751
752    fn solve_fwd_sens(
753        &mut self,
754        method: OdeSolverType,
755        linear_solver: LinearSolverType,
756        params: &[f64],
757        t_eval: &[f64],
758    ) -> SolveResult {
759        self.check(linear_solver)?;
760        self.setup_problem(params)?;
761
762        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
763        let soln = match linear_solver {
764            LinearSolverType::Default => {
765                method.solve_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval)
766            }
767            LinearSolverType::Lu => method
768                .solve_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
769            LinearSolverType::Klu => method
770                .solve_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
771        };
772        Ok(Box::new(soln?))
773    }
774
775    fn solve_hybrid_dense(
776        &mut self,
777        method: OdeSolverType,
778        linear_solver: LinearSolverType,
779        params: &[f64],
780        t_eval: &[f64],
781    ) -> SolveResult {
782        self.check(linear_solver)?;
783        self.setup_problem(params)?;
784
785        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
786        let soln = match linear_solver {
787            LinearSolverType::Default => method
788                .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
789            LinearSolverType::Lu => method
790                .solve_hybrid_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
791            LinearSolverType::Klu => method
792                .solve_hybrid_dense::<M, CG, <M as KluValidator<M>>::LS>(
793                    &mut self.problem,
794                    &t_eval,
795                ),
796        };
797        Ok(Box::new(soln?))
798    }
799
800    fn solve_hybrid_fwd_sens(
801        &mut self,
802        method: OdeSolverType,
803        linear_solver: LinearSolverType,
804        params: &[f64],
805        t_eval: &[f64],
806    ) -> SolveResult {
807        self.check(linear_solver)?;
808        self.setup_problem(params)?;
809
810        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
811        let soln = match linear_solver {
812            LinearSolverType::Default => method
813                .solve_hybrid_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(
814                    &mut self.problem,
815                    &t_eval,
816                ),
817            LinearSolverType::Lu => method
818                .solve_hybrid_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(
819                    &mut self.problem,
820                    &t_eval,
821                ),
822            LinearSolverType::Klu => method
823                .solve_hybrid_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(
824                    &mut self.problem,
825                    &t_eval,
826                ),
827        };
828        Ok(Box::new(soln?))
829    }
830
831    fn solve_sum_squares_adj(
832        &mut self,
833
834        method: OdeSolverType,
835        linear_solver: LinearSolverType,
836        backwards_method: OdeSolverType,
837        backwards_linear_solver: LinearSolverType,
838        params: &[f64],
839        data: HostArray,
840        t_eval: &[f64],
841    ) -> Result<(f64, HostArray), DiffsolRtError> {
842        self.check(linear_solver)?;
843        self.setup_problem(params)?;
844
845        let data = data.as_array()?;
846        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
847
848        let result = match linear_solver {
849            LinearSolverType::Default => method
850                .solve_sum_squares_adj::<M, CG, <M as DefaultSolver>::LS>(
851                    &mut self.problem,
852                    data,
853                    &t_eval,
854                    backwards_method,
855                    backwards_linear_solver,
856                ),
857            LinearSolverType::Lu => method
858                .solve_sum_squares_adj::<M, CG, <M as LuValidator<M>>::LS>(
859                    &mut self.problem,
860                    data,
861                    &t_eval,
862                    backwards_method,
863                    backwards_linear_solver,
864                ),
865            LinearSolverType::Klu => method
866                .solve_sum_squares_adj::<M, CG, <M as KluValidator<M>>::LS>(
867                    &mut self.problem,
868                    data,
869                    &t_eval,
870                    backwards_method,
871                    backwards_linear_solver,
872                ),
873        };
874        let (y, y_sens) = result?;
875
876        Ok((
877            y.to_f64().unwrap(),
878            (*y_sens.inner()).clone().to_host_array(),
879        ))
880    }
881}
882
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    // Tests for `GenericSolve` and the JIT solve factory. Each test is compiled
    // only for the JIT backend(s) enabled by the crate's feature flags.
    use diffsol::MatrixCommon;
    use diffsol::{
        CodegenModuleCompile, CodegenModuleJit, Context, OdeBuilder, OdeEquations, Vector,
    };

    #[cfg(feature = "diffsl-llvm")]
    use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr, matrix_host};
    use crate::{
        host_array::FromHostArray,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{
            assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
            logistic_state, LOGISTIC_X0,
        },
    };

    use super::{solve_factory_with_jit_backend, GenericSolve, Solve, SolveSerialization};

    /// Builds a `GenericSolve` over a nalgebra dense matrix with scalar `T`,
    /// compiling `logistic_diffsl_code()` with the codegen backend `CG`.
    fn make_generic_solve<T, CG>() -> GenericSolve<diffsol::NalgebraMat<T>, CG>
    where
        T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<T>>,
        diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
        <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
        <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
    {
        let problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
            .build_from_diffsl::<CG>(logistic_diffsl_code())
            .unwrap();
        GenericSolve { problem }
    }

    /// Asserts that `solve_factory_with_jit_backend` succeeds for every
    /// combination of matrix type and scalar type with backend `CG`.
    fn assert_factory_supports_all_matrix_and_scalar_types<CG>()
    where
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<f32>>
            + SolveSerialization<diffsol::NalgebraMat<f64>>
            + SolveSerialization<diffsol::FaerMat<f32>>
            + SolveSerialization<diffsol::FaerMat<f64>>
            + SolveSerialization<diffsol::FaerSparseMat<f32>>
            + SolveSerialization<diffsol::FaerSparseMat<f64>>,
    {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            for scalar_type in [ScalarType::F32, ScalarType::F64] {
                let result = solve_factory_with_jit_backend::<CG>(
                    logistic_diffsl_code(),
                    matrix_type,
                    scalar_type,
                );
                // Name the matrix type so a failure points at the broken combination.
                assert!(
                    result.is_ok(),
                    "solve factory failed for matrix type {matrix_type:?}"
                );
            }
        }
    }

    /// Checks the metadata and helper methods of a nalgebra-backed `GenericSolve`:
    /// reported sizes, linear-solver validation, tolerance setters/getters, and
    /// `y0` / `rhs` / `rhs_jac_mul` evaluated with the single parameter set to 2.0.
    /// `tol` is the comparison tolerance (looser for `f32`).
    fn assert_solve_metadata_and_helpers<T, CG>(tol: f64)
    where
        T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<T>>,
        diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
        <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
        <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
        GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
    {
        let mut solve = make_generic_solve::<T, CG>();
        assert_eq!(solve.matrix_type(), MatrixType::NalgebraDense);
        assert_eq!(solve.nstates(), 1);
        assert_eq!(solve.nparams(), 1);
        assert_eq!(solve.nout(), 1);
        assert!(!solve.has_stop());
        // Dense nalgebra matrices accept the default and LU solvers but not KLU.
        assert!(solve.check(LinearSolverType::Default).is_ok());
        assert!(solve.check(LinearSolverType::Lu).is_ok());
        assert!(solve.check(LinearSolverType::Klu).is_err());

        solve.set_atol(1e-5);
        solve.set_rtol(1e-4);
        assert_close(solve.atol(), 1e-5, tol, "solve atol");
        assert_close(solve.rtol(), 1e-4, tol, "solve rtol");

        let y0 = Vec::<f64>::from_host_array(solve.y0(&[2.0]).unwrap()).unwrap();
        assert_eq!(y0.len(), 1);
        assert_close(y0[0], LOGISTIC_X0, tol, "y0[0]");

        // Expected values match r * x * (1 - x) with r = 2, x = 0.25 (assumed to be
        // the model in `logistic_diffsl_code`; confirm in test_support).
        let rhs = Vec::<f64>::from_host_array(solve.rhs(&[2.0], 0.0, &[0.25]).unwrap()).unwrap();
        assert_eq!(rhs.len(), 1);
        assert_close(rhs[0], 0.375, tol, "solve rhs");

        // Jacobian-vector product: r * (1 - 2x) * v = 2 * 0.5 * 3 = 3 under the same assumption.
        let jac_mul =
            Vec::<f64>::from_host_array(solve.rhs_jac_mul(&[2.0], 0.0, &[0.25], &[3.0]).unwrap())
                .unwrap();
        assert_eq!(jac_mul.len(), 1);
        assert_close(jac_mul[0], 3.0, tol, "solve rhs jac mul");
    }

    /// Exercises `solve`, `solve_dense`, `solve_hybrid` and `solve_hybrid_dense`
    /// against the closed-form values from `logistic_state` and
    /// `hybrid_logistic_state`, and checks that a wrong parameter count is
    /// reported as an error. Each path uses a freshly built `GenericSolve`.
    fn assert_solve_runtime_paths<T, CG>(tol: f64)
    where
        T: crate::scalar_type::Scalar + diffsol::NalgebraScalar,
        CG: diffsol::CodegenModule
            + CodegenModuleJit
            + CodegenModuleCompile
            + SolveSerialization<diffsol::NalgebraMat<T>>,
        diffsol::NalgebraMat<T>: diffsol::matrix::MatrixHost,
        <diffsol::NalgebraMat<T> as MatrixCommon>::T: crate::scalar_type::Scalar,
        <diffsol::NalgebraMat<T> as MatrixCommon>::V: diffsol::DefaultDenseMatrix,
        GenericSolve<diffsol::NalgebraMat<T>, CG>: Solve,
    {
        let mut solve = make_generic_solve::<T, CG>();
        let soln = solve
            .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 1.0)
            .unwrap();
        let ts = Vec::<f64>::from_host_array(soln.get_ts()).unwrap();
        let ys = Vec::<Vec<f64>>::from_host_array(soln.get_ys()).unwrap();
        assert_close(*ts.last().unwrap(), 1.0, tol, "solve final time");
        assert_close(
            ys[0][ts.len() - 1],
            logistic_state(LOGISTIC_X0, 2.0, 1.0),
            tol,
            "solve final value",
        );

        let mut solve = make_generic_solve::<T, CG>();
        let dense = solve
            .solve_dense(
                OdeSolverType::Tsit45,
                LinearSolverType::Lu,
                &[2.0],
                &[0.25, 0.5, 1.0],
            )
            .unwrap();
        let ts = Vec::<f64>::from_host_array(dense.get_ts()).unwrap();
        let ys = Vec::<Vec<f64>>::from_host_array(dense.get_ys()).unwrap();
        assert_eq!(ts, vec![0.25, 0.5, 1.0]);
        for (i, &t) in ts.iter().enumerate() {
            assert_close(
                ys[0][i],
                logistic_state(LOGISTIC_X0, 2.0, t),
                tol,
                &format!("solve_dense[{i}]"),
            );
        }

        let mut solve = make_generic_solve::<T, CG>();
        let err = match solve.solve(OdeSolverType::Bdf, LinearSolverType::Default, &[], 1.0) {
            Ok(_) => panic!("expected parameter count mismatch"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut hybrid_solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid = hybrid_solve
            .solve_hybrid(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 2.0)
            .unwrap();
        let hybrid_ts = Vec::<f64>::from_host_array(hybrid.get_ts()).unwrap();
        let hybrid_ys = Vec::<Vec<f64>>::from_host_array(hybrid.get_ys()).unwrap();
        assert_close(
            *hybrid_ts.last().unwrap(),
            2.0,
            tol,
            "solve_hybrid final time",
        );
        assert_close(
            hybrid_ys[0][hybrid_ts.len() - 1],
            hybrid_logistic_state(2.0, 2.0),
            tol,
            "solve_hybrid final value",
        );

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<T>>::new()
            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut hybrid_solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid_dense = hybrid_solve
            .solve_hybrid_dense(
                OdeSolverType::Tsit45,
                LinearSolverType::Lu,
                &[2.0],
                &[0.5, 1.0, 1.5, 2.0],
            )
            .unwrap();
        let hybrid_dense_ts = Vec::<f64>::from_host_array(hybrid_dense.get_ts()).unwrap();
        let hybrid_dense_ys = Vec::<Vec<f64>>::from_host_array(hybrid_dense.get_ys()).unwrap();
        assert_eq!(hybrid_dense_ts, vec![0.5, 1.0, 1.5, 2.0]);
        for (i, &t) in hybrid_dense_ts.iter().enumerate() {
            assert_close(
                hybrid_dense_ys[0][i],
                hybrid_logistic_state(2.0, t),
                tol,
                &format!("solve_hybrid_dense[{i}]"),
            );
        }
    }

    /// LLVM-only, `f64` checks of the sensitivity paths:
    /// - forward sensitivities against `logistic_state_dr`;
    /// - hybrid forward sensitivities against `hybrid_logistic_state_dr`;
    /// - the adjoint sum-of-squares path, fed data sampled from `logistic_state`,
    ///   which is only checked for a finite objective and a finite one-entry gradient.
    #[cfg(feature = "diffsl-llvm")]
    fn assert_solve_sensitivity_paths() {
        let t_eval = [0.25, 0.5, 1.0];

        let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
        let sens = solve
            .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
            .unwrap();
        let sens_values = sens.get_sens();
        assert_eq!(sens_values.len(), 1);
        let sens_matrix =
            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_matrix[0][i],
                logistic_state_dr(LOGISTIC_X0, 2.0, t),
                5e-4,
                &format!("solve_fwd_sens[{i}]"),
            );
        }

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
            .build_from_diffsl::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid_sens = solve
            .solve_hybrid_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
            .unwrap();
        let sens_values = hybrid_sens.get_sens();
        // One parameter, so exactly one sensitivity array is expected, mirroring
        // the non-hybrid check above.
        assert_eq!(sens_values.len(), 1);
        let sens_matrix =
            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_matrix[0][i],
                hybrid_logistic_state_dr(2.0, t),
                5e-4,
                &format!("solve_hybrid_fwd_sens[{i}]"),
            );
        }

        let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
        let adjoint_data: Vec<f64> = adjoint_t_eval
            .iter()
            .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
            .collect();
        let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
        let (objective, gradient) = solve
            .solve_sum_squares_adj(
                OdeSolverType::Bdf,
                LinearSolverType::Lu,
                OdeSolverType::TrBdf2,
                LinearSolverType::Lu,
                &[2.0],
                matrix_host(1, adjoint_t_eval.len(), &adjoint_data),
                &adjoint_t_eval,
            )
            .unwrap();
        assert!(objective.is_finite());
        let gradient = Vec::<f64>::from_host_array(gradient).unwrap();
        assert_eq!(gradient.len(), 1);
        assert!(gradient[0].is_finite());
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_cranelift() {
        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::CraneliftJitModule>();
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn solve_trait_helpers_and_runtime_paths_for_cranelift() {
        assert_solve_metadata_and_helpers::<f64, diffsol::CraneliftJitModule>(1e-12);
        assert_solve_runtime_paths::<f64, diffsol::CraneliftJitModule>(5e-4);
        assert_solve_metadata_and_helpers::<f32, diffsol::CraneliftJitModule>(5e-3);
        assert_solve_runtime_paths::<f32, diffsol::CraneliftJitModule>(5e-3);
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_llvm() {
        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::LlvmModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_trait_helpers_and_runtime_paths_for_llvm() {
        assert_solve_metadata_and_helpers::<f64, diffsol::LlvmModule>(1e-12);
        assert_solve_runtime_paths::<f64, diffsol::LlvmModule>(5e-4);
        assert_solve_metadata_and_helpers::<f32, diffsol::LlvmModule>(5e-3);
        assert_solve_runtime_paths::<f32, diffsol::LlvmModule>(5e-3);
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_trait_sensitivity_paths_for_llvm() {
        assert_solve_sensitivity_paths();
    }

    // `setup_problem` must reject a wrong parameter count and, on success, store
    // the given parameters in the equations.
    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn setup_problem_validates_parameter_count_for_cranelift() {
        let mut solve = make_generic_solve::<f64, diffsol::CraneliftJitModule>();
        let err = solve.setup_problem(&[]).unwrap_err();
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        solve.setup_problem(&[2.0]).unwrap();
        let mut params = solve.problem.context().vector_zeros(1);
        solve.problem.eqn.get_params(&mut params);
        assert_eq!(params.get_index(0), 2.0);
    }

    // Same check as the Cranelift variant, using the LLVM backend.
    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn setup_problem_validates_parameter_count_for_llvm() {
        let mut solve = make_generic_solve::<f64, diffsol::LlvmModule>();
        let err = solve.setup_problem(&[]).unwrap_err();
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        solve.setup_problem(&[2.0]).unwrap();
        let mut params = solve.problem.context().vector_zeros(1);
        solve.problem.eqn.get_params(&mut params);
        assert_eq!(params.get_index(0), 2.0);
    }
}