Skip to main content

diffsol_c/
solve.rs

1// Delegate solver types selected at runtime in Host to concrete solver types
2// in Rust.
3
4use diffsol::{
5    error::DiffsolError,
6    matrix::{MatrixHost, MatrixRef},
7    CodegenModule, ConstantOp, DefaultDenseMatrix, DefaultSolver, DiffSl, MatrixCommon,
8    NonLinearOp, NonLinearOpJacobian, OdeBuilder, OdeEquations, OdeSolverProblem, Op, Vector,
9    VectorCommon, VectorHost, VectorRef,
10};
11#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
12use diffsol::{CodegenModuleCompile, CodegenModuleJit};
13use num_traits::{FromPrimitive, ToPrimitive}; // for from_f64 and to_f64
14use paste::paste;
15
16use crate::error::DiffsolRtError;
17use crate::host_array::{HostArray, ToHostArray};
18#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
19use crate::jit::JitBackendType;
20#[cfg(feature = "external")]
21use crate::scalar_type::ExternalScalar;
22use crate::scalar_type::{Scalar, ScalarType};
23use crate::{
24    generate_ic_option_accessors, generate_ode_option_accessors, generate_option_accessors,
25    generate_trait_ic_option_accessors, generate_trait_ode_option_accessors,
26    option_value_from_store, option_value_to_store,
27};
28
29use crate::{
30    linear_solver_type::LinearSolverType,
31    matrix_type::{MatrixKind, MatrixType},
32    ode_solver_type::OdeSolverType,
33    valid_linear_solver::{validate_linear_solver, KluValidator, LuValidator},
34};
35
// Each matrix type implements `Solve` as the bridge between diffsol and the host.
37
38use crate::solution::Solution;
/// Result of the `solve*` entry points on [`Solve`]: a boxed, type-erased
/// [`Solution`] on success, or a [`DiffsolRtError`] on failure.
pub(crate) type SolveResult = Result<Box<dyn Solution>, DiffsolRtError>;
40
41pub(crate) trait Solve {
42    fn matrix_type(&self) -> MatrixType;
43    fn nstates(&self) -> usize;
44    fn nparams(&self) -> usize;
45    fn nout(&self) -> usize;
46    fn has_stop(&self) -> bool;
47
48    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError>;
49
50    fn rhs_jac_mul(
51        &mut self,
52        params: &[f64],
53        t: f64,
54        y: &[f64],
55        v: &[f64],
56    ) -> Result<HostArray, DiffsolRtError>;
57
58    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError>;
59
60    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError>;
61    fn set_rtol(&mut self, rtol: f64);
62    fn rtol(&self) -> f64;
63    fn set_atol(&mut self, atol: f64);
64    fn atol(&self) -> f64;
65
66    // New API: solution object support
67    fn solve(
68        &mut self,
69        method: OdeSolverType,
70        linear_solver: LinearSolverType,
71        params: &[f64],
72        final_time: f64,
73    ) -> SolveResult;
74
75    fn solve_hybrid(
76        &mut self,
77        method: OdeSolverType,
78        linear_solver: LinearSolverType,
79        params: &[f64],
80        final_time: f64,
81    ) -> SolveResult;
82
83    fn solve_dense(
84        &mut self,
85        method: OdeSolverType,
86        linear_solver: LinearSolverType,
87        params: &[f64],
88        t_eval: &[f64],
89    ) -> SolveResult;
90
91    fn solve_hybrid_dense(
92        &mut self,
93        method: OdeSolverType,
94        linear_solver: LinearSolverType,
95        params: &[f64],
96        t_eval: &[f64],
97    ) -> SolveResult;
98
99    fn solve_fwd_sens(
100        &mut self,
101        method: OdeSolverType,
102        linear_solver: LinearSolverType,
103        params: &[f64],
104        t_eval: &[f64],
105    ) -> SolveResult;
106
107    fn solve_hybrid_fwd_sens(
108        &mut self,
109        method: OdeSolverType,
110        linear_solver: LinearSolverType,
111        params: &[f64],
112        t_eval: &[f64],
113    ) -> SolveResult;
114
115    #[allow(clippy::type_complexity)]
116    #[allow(clippy::too_many_arguments)]
117    fn solve_sum_squares_adj(
118        &mut self,
119        method: OdeSolverType,
120        linear_solver: LinearSolverType,
121        backwards_method: OdeSolverType,
122        backwards_linear_solver: LinearSolverType,
123        params: &[f64],
124        data: HostArray,
125        t_eval: &[f64],
126    ) -> Result<(f64, HostArray), DiffsolRtError>;
127
128    generate_trait_ic_option_accessors! {
129        use_linesearch: bool,
130        max_linesearch_iterations: usize,
131        max_newton_iterations: usize,
132        max_linear_solver_setups: usize,
133        step_reduction_factor: f64,
134        armijo_constant: f64
135    }
136    generate_trait_ode_option_accessors! {
137        max_nonlinear_solver_iterations: usize,
138        max_error_test_failures: usize,
139        min_timestep: f64,
140        update_jacobian_after_steps: usize,
141        update_rhs_jacobian_after_steps: usize,
142        threshold_to_update_jacobian: f64,
143        threshold_to_update_rhs_jacobian: f64
144    }
145}
/// Builds a boxed [`Solve`] backed by an externally compiled module, picking
/// the concrete matrix and scalar types from the runtime selectors.
///
/// The three dependency lists are forwarded unchanged to
/// `GenericSolve::from_external`.
///
/// # Errors
///
/// Returns an error if the requested scalar type is not compiled in for the
/// requested matrix type, or if constructing the solver fails.
#[cfg(feature = "external")]
pub(crate) fn solve_factory_external(
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    // Every combination is built without sensitivities.
    const INCLUDE_SENSITIVITIES: bool = false;

    // Error for a scalar type whose support was not compiled in.
    let unsupported = |matrix_name: &str| {
        DiffsolRtError::from(DiffsolError::Other(format!(
            "Unsupported scalar type for {matrix_name}"
        )))
    };

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        #[cfg(feature = "diffsl-external-f32")]
        (MatrixType::NalgebraDense, ScalarType::F32) => Box::new(
            GenericSolve::<diffsol::NalgebraMat<f32>, diffsl::ExternalModule<f32>>::from_external(
                rhs_state_deps,
                rhs_input_deps,
                mass_state_deps,
                INCLUDE_SENSITIVITIES,
            )?,
        ),
        #[cfg(feature = "diffsl-external-f64")]
        (MatrixType::NalgebraDense, ScalarType::F64) => Box::new(
            GenericSolve::<diffsol::NalgebraMat<f64>, diffsl::ExternalModule<f64>>::from_external(
                rhs_state_deps,
                rhs_input_deps,
                mass_state_deps,
                INCLUDE_SENSITIVITIES,
            )?,
        ),
        #[cfg(feature = "diffsl-external-f32")]
        (MatrixType::FaerDense, ScalarType::F32) => Box::new(
            GenericSolve::<diffsol::FaerMat<f32>, diffsl::ExternalModule<f32>>::from_external(
                rhs_state_deps,
                rhs_input_deps,
                mass_state_deps,
                INCLUDE_SENSITIVITIES,
            )?,
        ),
        #[cfg(feature = "diffsl-external-f64")]
        (MatrixType::FaerDense, ScalarType::F64) => Box::new(
            GenericSolve::<diffsol::FaerMat<f64>, diffsl::ExternalModule<f64>>::from_external(
                rhs_state_deps,
                rhs_input_deps,
                mass_state_deps,
                INCLUDE_SENSITIVITIES,
            )?,
        ),
        #[cfg(feature = "diffsl-external-f32")]
        (MatrixType::FaerSparse, ScalarType::F32) => Box::new(
            GenericSolve::<diffsol::FaerSparseMat<f32>, diffsl::ExternalModule<f32>>::from_external(
                rhs_state_deps,
                rhs_input_deps,
                mass_state_deps,
                INCLUDE_SENSITIVITIES,
            )?,
        ),
        #[cfg(feature = "diffsl-external-f64")]
        (MatrixType::FaerSparse, ScalarType::F64) => Box::new(
            GenericSolve::<diffsol::FaerSparseMat<f64>, diffsl::ExternalModule<f64>>::from_external(
                rhs_state_deps,
                rhs_input_deps,
                mass_state_deps,
                INCLUDE_SENSITIVITIES,
            )?,
        ),
        // Fallbacks for scalar types whose feature flag is disabled.
        (MatrixType::NalgebraDense, _) => return Err(unsupported("NalgebraDense")),
        (MatrixType::FaerDense, _) => return Err(unsupported("FaerDense")),
        (MatrixType::FaerSparse, _) => return Err(unsupported("FaerSparse")),
    };
    Ok(solve)
}
222
/// Builds a boxed [`Solve`] by JIT-compiling the DiffSL `code` with the
/// backend selected at runtime.
///
/// Only the backends enabled through Cargo features are available; the
/// actual construction is delegated to `solve_factory_with_jit_backend`.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
pub(crate) fn solve_factory_jit(
    code: &str,
    jit_backend: JitBackendType,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError> {
    #[cfg(feature = "diffsl-cranelift")]
    type CraneliftBackend = diffsol::CraneliftJitModule;
    #[cfg(feature = "diffsl-llvm")]
    type LlvmBackend = diffsol::LlvmModule;

    match jit_backend {
        #[cfg(feature = "diffsl-cranelift")]
        JitBackendType::Cranelift => {
            solve_factory_with_jit_backend::<CraneliftBackend>(code, matrix_type, scalar_type)
        }
        #[cfg(feature = "diffsl-llvm")]
        JitBackendType::Llvm => {
            solve_factory_with_jit_backend::<LlvmBackend>(code, matrix_type, scalar_type)
        }
    }
}
243
/// Builds a boxed [`Solve`] from DiffSL `code` using the code generation
/// backend `CG`, with the matrix and scalar types chosen at runtime.
///
/// # Errors
///
/// Propagates any error from building the problem out of `code`.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn solve_factory_with_jit_backend<CG>(
    code: &str,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolRtError>
where
    CG: CodegenModule + CodegenModuleJit + CodegenModuleCompile,
{
    // Builds the problem for one concrete matrix type and boxes the solver.
    macro_rules! boxed_solver {
        ($matrix:ty) => {{
            let problem = OdeBuilder::<$matrix>::new().build_from_diffsl::<CG>(code)?;
            Box::new(GenericSolve { problem })
        }};
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => boxed_solver!(diffsol::NalgebraMat<f32>),
        (MatrixType::NalgebraDense, ScalarType::F64) => boxed_solver!(diffsol::NalgebraMat<f64>),
        (MatrixType::FaerDense, ScalarType::F32) => boxed_solver!(diffsol::FaerMat<f32>),
        (MatrixType::FaerDense, ScalarType::F64) => boxed_solver!(diffsol::FaerMat<f64>),
        (MatrixType::FaerSparse, ScalarType::F32) => boxed_solver!(diffsol::FaerSparseMat<f32>),
        (MatrixType::FaerSparse, ScalarType::F64) => boxed_solver!(diffsol::FaerSparseMat<f64>),
    };
    Ok(solve)
}
293
294pub(crate) struct GenericSolve<M, CG>
295where
296    M: MatrixHost<T: Scalar>,
297    M::V: Vector + VectorHost,
298    CG: CodegenModule,
299{
300    problem: OdeSolverProblem<DiffSl<M, CG>>,
301}
302
303impl<M, CG> GenericSolve<M, CG>
304where
305    M: MatrixHost<T: Scalar>,
306    M::V: Vector + VectorHost + DefaultDenseMatrix,
307    CG: CodegenModule,
308{
309    pub(crate) fn setup_problem(&mut self, params: &[f64]) -> Result<(), DiffsolRtError> {
310        let params: Vec<M::T> = params.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
311        let params = M::V::from_slice(&params, M::C::default());
312
313        // Attempt to set problem from params and config
314        let nparams = self.problem.eqn.nparams();
315        if params.len() == nparams {
316            self.problem.eqn.set_params(&params);
317            Ok(())
318        } else {
319            Err(DiffsolError::Other(format!(
320                "Expecting {} params but got {}",
321                nparams,
322                params.len()
323            ))
324            .into())
325        }
326    }
327}
328
#[cfg(feature = "external")]
impl<M> GenericSolve<M, diffsl::ExternalModule<M::T>>
where
    M: MatrixHost<T: ExternalScalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Creates a solver on top of an externally compiled module.
    ///
    /// The dependency lists and the `include_sensitivities` flag are passed
    /// straight through to `DiffSl::from_external`.
    ///
    /// # Errors
    ///
    /// Propagates failures from creating the equations or building the
    /// problem.
    pub fn from_external(
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolRtError> {
        let ctx = M::C::default();
        let equations = DiffSl::<M, diffsl::ExternalModule<M::T>>::from_external(
            ctx,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )?;

        // The problem is built with one zero per parameter; `setup_problem`
        // installs the real values later.
        let nparams = equations.nparams();
        let problem = OdeBuilder::<M>::new()
            .p(vec![0.0; nparams])
            .build_from_eqn(equations)?;
        Ok(Self { problem })
    }
}
353
354impl<M, CG> Solve for GenericSolve<M, CG>
355where
356    M: MatrixHost<T: Scalar + ToPrimitive>
357        + DefaultSolver
358        + LuValidator<M>
359        + KluValidator<M>
360        + MatrixKind,
361    CG: CodegenModule,
362    for<'b> <<M::V as DefaultDenseMatrix>::M as MatrixCommon>::Inner: ToHostArray<M::T> + Clone,
363    for<'b> <M::V as VectorCommon>::Inner: ToHostArray<M::T> + Clone,
364    M::V: VectorHost + DefaultDenseMatrix + Send + Sync + 'static,
365    <M::V as DefaultDenseMatrix>::M: Send + Sync,
366    for<'b> &'b M::V: VectorRef<M::V>,
367    for<'b> &'b M: MatrixRef<M>,
368{
369    fn matrix_type(&self) -> MatrixType {
370        MatrixType::from_diffsol::<M>()
371    }
372
373    fn nstates(&self) -> usize {
374        self.problem.eqn.nstates()
375    }
376
377    fn nparams(&self) -> usize {
378        self.problem.eqn.nparams()
379    }
380
381    fn nout(&self) -> usize {
382        self.problem.eqn.nout()
383    }
384
385    fn has_stop(&self) -> bool {
386        self.problem.eqn.root().is_some()
387    }
388
389    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolRtError> {
390        validate_linear_solver::<M>(linear_solver)
391    }
392
393    fn set_atol(&mut self, atol: f64) {
394        self.problem.atol.fill(M::T::from_f64(atol).unwrap());
395    }
396
397    fn atol(&self) -> f64 {
398        self.problem.atol[0].to_f64().unwrap()
399    }
400
401    fn set_rtol(&mut self, rtol: f64) {
402        self.problem.rtol = M::T::from_f64(rtol).unwrap();
403    }
404
405    fn rtol(&self) -> f64 {
406        self.problem.rtol.to_f64().unwrap()
407    }
408
409    generate_ic_option_accessors! {
410        use_linesearch: bool,
411        max_linesearch_iterations: usize,
412        max_newton_iterations: usize,
413        max_linear_solver_setups: usize,
414        step_reduction_factor: f64,
415        armijo_constant: f64
416    }
417
418    generate_ode_option_accessors! {
419        max_nonlinear_solver_iterations: usize,
420        max_error_test_failures: usize,
421        min_timestep: f64,
422        update_jacobian_after_steps: usize,
423        update_rhs_jacobian_after_steps: usize,
424        threshold_to_update_jacobian: f64,
425        threshold_to_update_rhs_jacobian: f64
426    }
427
428    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolRtError> {
429        self.setup_problem(params)?;
430        let n = self.problem.eqn.nstates();
431        let mut y0 = M::V::zeros(n, M::C::default());
432        let t0 = self.problem.t0;
433        self.problem.eqn.init().call_inplace(t0, &mut y0);
434        Ok((*y0.inner()).clone().to_host_array())
435    }
436
437    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolRtError> {
438        self.setup_problem(params)?;
439        let n = self.problem.eqn.nstates();
440        let y = y
441            .iter()
442            .map(|&x| M::T::from_f64(x).unwrap())
443            .collect::<Vec<_>>();
444        let y_vec = M::V::from_slice(&y, M::C::default());
445        let mut dydt = M::V::zeros(n, M::C::default());
446        self.problem
447            .eqn
448            .rhs()
449            .call_inplace(&y_vec, M::T::from_f64(t).unwrap(), &mut dydt);
450        Ok((*dydt.inner()).clone().to_host_array())
451    }
452
453    fn rhs_jac_mul(
454        &mut self,
455
456        params: &[f64],
457        t: f64,
458        y: &[f64],
459        v: &[f64],
460    ) -> Result<HostArray, DiffsolRtError> {
461        self.setup_problem(params)?;
462        let n = self.problem.eqn.nstates();
463        let y = y
464            .iter()
465            .map(|&x| M::T::from_f64(x).unwrap())
466            .collect::<Vec<_>>();
467        let v = v
468            .iter()
469            .map(|&x| M::T::from_f64(x).unwrap())
470            .collect::<Vec<_>>();
471        let y_vec = M::V::from_slice(&y, M::C::default());
472        let v_vec = M::V::from_slice(&v, M::C::default());
473        let mut dydt = M::V::zeros(n, M::C::default());
474        self.problem.eqn.rhs().jac_mul_inplace(
475            &y_vec,
476            M::T::from_f64(t).unwrap(),
477            &v_vec,
478            &mut dydt,
479        );
480        Ok((*dydt.inner()).clone().to_host_array())
481    }
482
483    fn solve(
484        &mut self,
485        method: OdeSolverType,
486        linear_solver: LinearSolverType,
487        params: &[f64],
488        final_time: f64,
489    ) -> SolveResult {
490        self.check(linear_solver)?;
491        self.setup_problem(params)?;
492        let final_time = M::T::from_f64(final_time).unwrap();
493        let soln = match linear_solver {
494            LinearSolverType::Default => {
495                method.solve::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time)
496            }
497            LinearSolverType::Lu => {
498                method.solve::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time)
499            }
500            LinearSolverType::Klu => {
501                method.solve::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time)
502            }
503        };
504        Ok(Box::new(soln?))
505    }
506
507    fn solve_dense(
508        &mut self,
509        method: OdeSolverType,
510        linear_solver: LinearSolverType,
511        params: &[f64],
512        t_eval: &[f64],
513    ) -> SolveResult {
514        self.check(linear_solver)?;
515        self.setup_problem(params)?;
516
517        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
518        let soln =
519            match linear_solver {
520                LinearSolverType::Default => method
521                    .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
522                LinearSolverType::Lu => method
523                    .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
524                LinearSolverType::Klu => method
525                    .solve_dense::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
526            };
527        Ok(Box::new(soln?))
528    }
529
530    fn solve_hybrid(
531        &mut self,
532        method: OdeSolverType,
533        linear_solver: LinearSolverType,
534        params: &[f64],
535        final_time: f64,
536    ) -> SolveResult {
537        self.check(linear_solver)?;
538        self.setup_problem(params)?;
539        let final_time = M::T::from_f64(final_time).unwrap();
540        let soln = match linear_solver {
541            LinearSolverType::Default => method
542                .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time),
543            LinearSolverType::Lu => method
544                .solve_hybrid::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time),
545            LinearSolverType::Klu => method
546                .solve_hybrid::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time),
547        };
548        Ok(Box::new(soln?))
549    }
550
551    fn solve_fwd_sens(
552        &mut self,
553        method: OdeSolverType,
554        linear_solver: LinearSolverType,
555        params: &[f64],
556        t_eval: &[f64],
557    ) -> SolveResult {
558        self.check(linear_solver)?;
559        self.setup_problem(params)?;
560
561        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
562        let soln = match linear_solver {
563            LinearSolverType::Default => {
564                method.solve_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval)
565            }
566            LinearSolverType::Lu => method
567                .solve_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
568            LinearSolverType::Klu => method
569                .solve_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
570        };
571        Ok(Box::new(soln?))
572    }
573
574    fn solve_hybrid_dense(
575        &mut self,
576        method: OdeSolverType,
577        linear_solver: LinearSolverType,
578        params: &[f64],
579        t_eval: &[f64],
580    ) -> SolveResult {
581        self.check(linear_solver)?;
582        self.setup_problem(params)?;
583
584        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
585        let soln = match linear_solver {
586            LinearSolverType::Default => method
587                .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
588            LinearSolverType::Lu => method
589                .solve_hybrid_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
590            LinearSolverType::Klu => method
591                .solve_hybrid_dense::<M, CG, <M as KluValidator<M>>::LS>(
592                    &mut self.problem,
593                    &t_eval,
594                ),
595        };
596        Ok(Box::new(soln?))
597    }
598
599    fn solve_hybrid_fwd_sens(
600        &mut self,
601        method: OdeSolverType,
602        linear_solver: LinearSolverType,
603        params: &[f64],
604        t_eval: &[f64],
605    ) -> SolveResult {
606        self.check(linear_solver)?;
607        self.setup_problem(params)?;
608
609        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
610        let soln = match linear_solver {
611            LinearSolverType::Default => method
612                .solve_hybrid_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(
613                    &mut self.problem,
614                    &t_eval,
615                ),
616            LinearSolverType::Lu => method
617                .solve_hybrid_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(
618                    &mut self.problem,
619                    &t_eval,
620                ),
621            LinearSolverType::Klu => method
622                .solve_hybrid_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(
623                    &mut self.problem,
624                    &t_eval,
625                ),
626        };
627        Ok(Box::new(soln?))
628    }
629
630    fn solve_sum_squares_adj(
631        &mut self,
632
633        method: OdeSolverType,
634        linear_solver: LinearSolverType,
635        backwards_method: OdeSolverType,
636        backwards_linear_solver: LinearSolverType,
637        params: &[f64],
638        data: HostArray,
639        t_eval: &[f64],
640    ) -> Result<(f64, HostArray), DiffsolRtError> {
641        self.check(linear_solver)?;
642        self.setup_problem(params)?;
643
644        let data = data.as_array()?;
645        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
646
647        let result = match linear_solver {
648            LinearSolverType::Default => method
649                .solve_sum_squares_adj::<M, CG, <M as DefaultSolver>::LS>(
650                    &mut self.problem,
651                    data,
652                    &t_eval,
653                    backwards_method,
654                    backwards_linear_solver,
655                ),
656            LinearSolverType::Lu => method
657                .solve_sum_squares_adj::<M, CG, <M as LuValidator<M>>::LS>(
658                    &mut self.problem,
659                    data,
660                    &t_eval,
661                    backwards_method,
662                    backwards_linear_solver,
663                ),
664            LinearSolverType::Klu => method
665                .solve_sum_squares_adj::<M, CG, <M as KluValidator<M>>::LS>(
666                    &mut self.problem,
667                    data,
668                    &t_eval,
669                    backwards_method,
670                    backwards_linear_solver,
671                ),
672        };
673        let (y, y_sens) = result?;
674
675        Ok((
676            y.to_f64().unwrap(),
677            (*y_sens.inner()).clone().to_host_array(),
678        ))
679    }
680}
681
682#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
683mod tests {
684    use diffsol::{
685        CodegenModuleCompile, CodegenModuleJit, Context, OdeBuilder, OdeEquations, Vector,
686    };
687
688    #[cfg(feature = "diffsl-llvm")]
689    use crate::test_support::{hybrid_logistic_state_dr, logistic_state_dr, matrix_host};
690    use crate::{
691        host_array::FromHostArray,
692        linear_solver_type::LinearSolverType,
693        matrix_type::MatrixType,
694        ode_solver_type::OdeSolverType,
695        scalar_type::ScalarType,
696        test_support::{
697            assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
698            logistic_state, LOGISTIC_X0,
699        },
700    };
701
702    use super::{solve_factory_with_jit_backend, GenericSolve, Solve};
703
704    fn make_generic_solve<CG>() -> GenericSolve<diffsol::NalgebraMat<f64>, CG>
705    where
706        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
707    {
708        let problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
709            .build_from_diffsl::<CG>(logistic_diffsl_code())
710            .unwrap();
711        GenericSolve { problem }
712    }
713
714    fn assert_factory_supports_all_matrix_and_scalar_types<CG>()
715    where
716        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
717    {
718        for matrix_type in [
719            MatrixType::NalgebraDense,
720            MatrixType::FaerDense,
721            MatrixType::FaerSparse,
722        ] {
723            for scalar_type in [ScalarType::F32, ScalarType::F64] {
724                assert!(solve_factory_with_jit_backend::<CG>(
725                    logistic_diffsl_code(),
726                    matrix_type,
727                    scalar_type,
728                )
729                .is_ok());
730            }
731        }
732    }
733
734    fn assert_solve_metadata_and_helpers<CG>()
735    where
736        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
737    {
738        let mut solve = make_generic_solve::<CG>();
739        assert_eq!(solve.matrix_type(), MatrixType::NalgebraDense);
740        assert_eq!(solve.nstates(), 1);
741        assert_eq!(solve.nparams(), 1);
742        assert_eq!(solve.nout(), 1);
743        assert!(!solve.has_stop());
744        assert!(solve.check(LinearSolverType::Default).is_ok());
745        assert!(solve.check(LinearSolverType::Lu).is_ok());
746        assert!(solve.check(LinearSolverType::Klu).is_err());
747
748        solve.set_atol(1e-5);
749        solve.set_rtol(1e-4);
750        assert_close(solve.atol(), 1e-5, 1e-12, "solve atol");
751        assert_close(solve.rtol(), 1e-4, 1e-12, "solve rtol");
752
753        let y0 = Vec::<f64>::from_host_array(solve.y0(&[2.0]).unwrap()).unwrap();
754        assert_eq!(y0, vec![LOGISTIC_X0]);
755
756        let rhs = Vec::<f64>::from_host_array(solve.rhs(&[2.0], 0.0, &[0.25]).unwrap()).unwrap();
757        assert_eq!(rhs.len(), 1);
758        assert_close(rhs[0], 0.375, 1e-12, "solve rhs");
759
760        let jac_mul =
761            Vec::<f64>::from_host_array(solve.rhs_jac_mul(&[2.0], 0.0, &[0.25], &[3.0]).unwrap())
762                .unwrap();
763        assert_eq!(jac_mul.len(), 1);
764        assert_close(jac_mul[0], 3.0, 1e-12, "solve rhs jac mul");
765    }
766
767    fn assert_solve_runtime_paths<CG>()
768    where
769        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
770    {
771        let mut solve = make_generic_solve::<CG>();
772        let soln = solve
773            .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 1.0)
774            .unwrap();
775        let ts = Vec::<f64>::from_host_array(soln.get_ts()).unwrap();
776        let ys = Vec::<Vec<f64>>::from_host_array(soln.get_ys()).unwrap();
777        assert_close(*ts.last().unwrap(), 1.0, 5e-4, "solve final time");
778        assert_close(
779            ys[0][ts.len() - 1],
780            logistic_state(LOGISTIC_X0, 2.0, 1.0),
781            5e-4,
782            "solve final value",
783        );
784
785        let mut solve = make_generic_solve::<CG>();
786        let dense = solve
787            .solve_dense(
788                OdeSolverType::Tsit45,
789                LinearSolverType::Lu,
790                &[2.0],
791                &[0.25, 0.5, 1.0],
792            )
793            .unwrap();
794        let ts = Vec::<f64>::from_host_array(dense.get_ts()).unwrap();
795        let ys = Vec::<Vec<f64>>::from_host_array(dense.get_ys()).unwrap();
796        assert_eq!(ts, vec![0.25, 0.5, 1.0]);
797        for (i, &t) in ts.iter().enumerate() {
798            assert_close(
799                ys[0][i],
800                logistic_state(LOGISTIC_X0, 2.0, t),
801                5e-4,
802                &format!("solve_dense[{i}]"),
803            );
804        }
805
806        let mut solve = make_generic_solve::<CG>();
807        let err = match solve.solve(OdeSolverType::Bdf, LinearSolverType::Default, &[], 1.0) {
808            Ok(_) => panic!("expected parameter count mismatch"),
809            Err(err) => err,
810        };
811        assert!(err.to_string().contains("Expecting 1 params but got 0"));
812
813        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
814            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
815            .unwrap();
816        let mut hybrid_solve = GenericSolve {
817            problem: hybrid_problem,
818        };
819        let hybrid = hybrid_solve
820            .solve_hybrid(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 2.0)
821            .unwrap();
822        let hybrid_ts = Vec::<f64>::from_host_array(hybrid.get_ts()).unwrap();
823        let hybrid_ys = Vec::<Vec<f64>>::from_host_array(hybrid.get_ys()).unwrap();
824        assert_close(
825            *hybrid_ts.last().unwrap(),
826            2.0,
827            5e-4,
828            "solve_hybrid final time",
829        );
830        assert_close(
831            hybrid_ys[0][hybrid_ts.len() - 1],
832            hybrid_logistic_state(2.0, 2.0),
833            5e-4,
834            "solve_hybrid final value",
835        );
836
837        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
838            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
839            .unwrap();
840        let mut hybrid_solve = GenericSolve {
841            problem: hybrid_problem,
842        };
843        let hybrid_dense = hybrid_solve
844            .solve_hybrid_dense(
845                OdeSolverType::Tsit45,
846                LinearSolverType::Lu,
847                &[2.0],
848                &[0.5, 1.0, 1.5, 2.0],
849            )
850            .unwrap();
851        let hybrid_dense_ts = Vec::<f64>::from_host_array(hybrid_dense.get_ts()).unwrap();
852        let hybrid_dense_ys = Vec::<Vec<f64>>::from_host_array(hybrid_dense.get_ys()).unwrap();
853        assert_eq!(hybrid_dense_ts, vec![0.5, 1.0, 1.5, 2.0]);
854        for (i, &t) in hybrid_dense_ts.iter().enumerate() {
855            assert_close(
856                hybrid_dense_ys[0][i],
857                hybrid_logistic_state(2.0, t),
858                5e-4,
859                &format!("solve_hybrid_dense[{i}]"),
860            );
861        }
862    }
863
864    #[cfg(feature = "diffsl-llvm")]
865    fn assert_solve_sensitivity_paths() {
866        let t_eval = [0.25, 0.5, 1.0];
867
868        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
869        let sens = solve
870            .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
871            .unwrap();
872        let sens_values = sens.get_sens();
873        assert_eq!(sens_values.len(), 1);
874        let sens_matrix =
875            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
876        for (i, &t) in t_eval.iter().enumerate() {
877            assert_close(
878                sens_matrix[0][i],
879                logistic_state_dr(LOGISTIC_X0, 2.0, t),
880                5e-4,
881                &format!("solve_fwd_sens[{i}]"),
882            );
883        }
884
885        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
886            .build_from_diffsl::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code())
887            .unwrap();
888        let mut solve = GenericSolve {
889            problem: hybrid_problem,
890        };
891        let hybrid_sens = solve
892            .solve_hybrid_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
893            .unwrap();
894        let sens_values = hybrid_sens.get_sens();
895        let sens_matrix =
896            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
897        for (i, &t) in t_eval.iter().enumerate() {
898            assert_close(
899                sens_matrix[0][i],
900                hybrid_logistic_state_dr(2.0, t),
901                5e-4,
902                &format!("solve_hybrid_fwd_sens[{i}]"),
903            );
904        }
905
906        let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
907        let adjoint_data: Vec<f64> = adjoint_t_eval
908            .iter()
909            .map(|&t| logistic_state(LOGISTIC_X0, 2.0, t))
910            .collect();
911        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
912        let (objective, gradient) = solve
913            .solve_sum_squares_adj(
914                OdeSolverType::Bdf,
915                LinearSolverType::Lu,
916                OdeSolverType::TrBdf2,
917                LinearSolverType::Lu,
918                &[2.0],
919                matrix_host(1, adjoint_t_eval.len(), &adjoint_data),
920                &adjoint_t_eval,
921            )
922            .unwrap();
923        assert!(objective.is_finite());
924        let gradient = Vec::<f64>::from_host_array(gradient).unwrap();
925        assert_eq!(gradient.len(), 1);
926        assert!(gradient[0].is_finite());
927    }
928
929    #[cfg(feature = "diffsl-cranelift")]
930    #[test]
931    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_cranelift() {
932        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::CraneliftJitModule>();
933    }
934
935    #[cfg(feature = "diffsl-cranelift")]
936    #[test]
937    fn solve_trait_helpers_and_runtime_paths_for_cranelift() {
938        assert_solve_metadata_and_helpers::<diffsol::CraneliftJitModule>();
939        assert_solve_runtime_paths::<diffsol::CraneliftJitModule>();
940    }
941
942    #[cfg(feature = "diffsl-llvm")]
943    #[test]
944    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_llvm() {
945        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::LlvmModule>();
946    }
947
948    #[cfg(feature = "diffsl-llvm")]
949    #[test]
950    fn solve_trait_helpers_and_runtime_paths_for_llvm() {
951        assert_solve_metadata_and_helpers::<diffsol::LlvmModule>();
952        assert_solve_runtime_paths::<diffsol::LlvmModule>();
953    }
954
955    #[cfg(feature = "diffsl-llvm")]
956    #[test]
957    fn solve_trait_sensitivity_paths_for_llvm() {
958        assert_solve_sensitivity_paths();
959    }
960
961    #[cfg(feature = "diffsl-cranelift")]
962    #[test]
963    fn setup_problem_validates_parameter_count_for_cranelift() {
964        let mut solve = make_generic_solve::<diffsol::CraneliftJitModule>();
965        let err = solve.setup_problem(&[]).unwrap_err();
966        assert!(err.to_string().contains("Expecting 1 params but got 0"));
967
968        solve.setup_problem(&[2.0]).unwrap();
969        let mut params = solve.problem.context().vector_zeros(1);
970        solve.problem.eqn.get_params(&mut params);
971        assert_eq!(params.get_index(0), 2.0);
972    }
973
974    #[cfg(feature = "diffsl-llvm")]
975    #[test]
976    fn setup_problem_validates_parameter_count_for_llvm() {
977        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
978        let err = solve.setup_problem(&[]).unwrap_err();
979        assert!(err.to_string().contains("Expecting 1 params but got 0"));
980
981        solve.setup_problem(&[2.0]).unwrap();
982        let mut params = solve.problem.context().vector_zeros(1);
983        solve.problem.eqn.get_params(&mut params);
984        assert_eq!(params.get_index(0), 2.0);
985    }
986}