Skip to main content

diffsol_c/
solve.rs

1// Delegate solver types selected at runtime in Host to concrete solver types
2// in Rust.
3
4use diffsol::{
5    error::DiffsolError,
6    matrix::{MatrixHost, MatrixRef},
7    CodegenModule, ConstantOp, DefaultDenseMatrix, DefaultSolver, DiffSl, MatrixCommon,
8    NonLinearOp, NonLinearOpJacobian, OdeBuilder, OdeEquations, OdeSolverProblem, Op, Vector,
9    VectorCommon, VectorHost, VectorRef,
10};
11#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
12use diffsol::{CodegenModuleCompile, CodegenModuleJit};
13use num_traits::{FromPrimitive, ToPrimitive}; // for from_f64 and to_f64
14use paste::paste;
15
16use crate::error::DiffsolJsError;
17use crate::host_array::{HostArray, ToHostArray};
18#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
19use crate::jit::JitBackendType;
20#[cfg(feature = "external")]
21use crate::scalar_type::ExternalScalar;
22use crate::scalar_type::{Scalar, ScalarType};
23use crate::{
24    generate_ic_option_accessors, generate_ode_option_accessors, generate_option_accessors,
25    generate_trait_ic_option_accessors, generate_trait_ode_option_accessors,
26    option_value_from_store, option_value_to_store,
27};
28
29use crate::{
30    linear_solver_type::LinearSolverType,
31    matrix_type::{MatrixKind, MatrixType},
32    ode_solver_type::OdeSolverType,
33    valid_linear_solver::{validate_linear_solver, KluValidator, LuValidator},
34};
35
// `GenericSolve` implements `Solve` for every supported matrix type, acting as the bridge between diffsol and the Host.
37
38use crate::solution::Solution;
/// Result of a solve: a boxed, type-erased [`Solution`] on success, or a
/// [`DiffsolJsError`] on failure.
pub(crate) type SolveResult = Result<Box<dyn Solution>, DiffsolJsError>;
40
/// Runtime-dispatched interface to a concrete diffsol problem.
///
/// The factory functions in this module choose the matrix type, scalar type and
/// code-generation backend at runtime and return a `Box<dyn Solve>`, so Host-side
/// callers can drive any combination through this single trait. Parameters,
/// times and states cross the boundary as `f64` slices; results come back as
/// [`HostArray`]s or as a [`SolveResult`].
pub(crate) trait Solve {
    /// The runtime tag of the matrix type backing this problem.
    fn matrix_type(&self) -> MatrixType;
    /// Number of state variables.
    fn nstates(&self) -> usize;
    /// Number of parameters the model expects.
    fn nparams(&self) -> usize;
    /// Number of model outputs.
    fn nout(&self) -> usize;
    /// Whether the model defines a stop (root) function.
    fn has_stop(&self) -> bool;

    /// Evaluates the right-hand side at time `t` and state `y` using `params`.
    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolJsError>;

    /// Evaluates the right-hand-side Jacobian at (`t`, `y`) multiplied by the
    /// vector `v`, using `params`.
    fn rhs_jac_mul(
        &mut self,
        params: &[f64],
        t: f64,
        y: &[f64],
        v: &[f64],
    ) -> Result<HostArray, DiffsolJsError>;

    /// Evaluates the initial state using `params`.
    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolJsError>;

    /// Checks that `linear_solver` can be used with this problem's matrix type.
    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolJsError>;
    /// Sets the relative tolerance.
    fn set_rtol(&mut self, rtol: f64);
    /// Returns the relative tolerance.
    fn rtol(&self) -> f64;
    /// Sets the absolute tolerance.
    fn set_atol(&mut self, atol: f64);
    /// Returns the absolute tolerance.
    fn atol(&self) -> f64;

    // New API: solution object support

    /// Solves from the initial time up to `final_time` with the given ODE
    /// method and linear solver.
    fn solve(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        final_time: f64,
    ) -> SolveResult;

    /// Like [`Solve::solve`], but dispatched to `OdeSolverType::solve_hybrid`
    /// (presumably the event/reset-aware path — confirm there).
    fn solve_hybrid(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        final_time: f64,
    ) -> SolveResult;

    /// Solves and reports the solution at each time in `t_eval`.
    fn solve_dense(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Hybrid counterpart of [`Solve::solve_dense`].
    fn solve_hybrid_dense(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Solves with forward sensitivities, reported at each time in `t_eval`.
    fn solve_fwd_sens(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Hybrid counterpart of [`Solve::solve_fwd_sens`].
    fn solve_hybrid_fwd_sens(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Adjoint solve of a sum-of-squares objective against `data` at `t_eval`.
    ///
    /// The backward pass uses `backwards_method` / `backwards_linear_solver`.
    /// Returns the scalar objective and an array that is presumably its
    /// parameter sensitivities (see `OdeSolverType::solve_sum_squares_adj`).
    #[allow(clippy::type_complexity)]
    #[allow(clippy::too_many_arguments)]
    fn solve_sum_squares_adj(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        backwards_method: OdeSolverType,
        backwards_linear_solver: LinearSolverType,
        params: &[f64],
        data: HostArray,
        t_eval: &[f64],
    ) -> Result<(f64, HostArray), DiffsolJsError>;

    // Getter/setter declarations for the initial-condition solver options
    // below (generated by macro).
    generate_trait_ic_option_accessors! {
        use_linesearch: bool,
        max_linesearch_iterations: usize,
        max_newton_iterations: usize,
        max_linear_solver_setups: usize,
        step_reduction_factor: f64,
        armijo_constant: f64
    }
    // Getter/setter declarations for the ODE solver options below (generated
    // by macro).
    generate_trait_ode_option_accessors! {
        max_nonlinear_solver_iterations: usize,
        max_error_test_failures: usize,
        min_timestep: f64,
        update_jacobian_after_steps: usize,
        update_rhs_jacobian_after_steps: usize,
        threshold_to_update_jacobian: f64,
        threshold_to_update_rhs_jacobian: f64
    }
}
/// Builds a [`Solve`] backed by an externally linked diffsl module.
///
/// The three dependency lists are forwarded unchanged to
/// `DiffSl::from_external` (presumably `(row, column)` sparsity pairs —
/// confirm against diffsl). Sensitivities are not included.
///
/// # Errors
/// Returns an error if `scalar_type` is not enabled through the
/// `diffsl-external-f32` / `diffsl-external-f64` features, or if building the
/// underlying problem fails.
// With both scalar features enabled, the `_` fallback arms can never match;
// they are still needed for builds with only one (or neither) feature.
#[cfg(feature = "external")]
#[allow(unreachable_patterns)]
pub(crate) fn solve_factory_external(
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError> {
    // Builds a boxed `GenericSolve` for matrix type `$mat` with scalar `$t`.
    // Only one match arm runs, so moving the dependency lists here is fine.
    #[allow(unused_macros)]
    macro_rules! external_solve {
        ($mat:ty, $t:ty) => {
            Box::new(
                GenericSolve::<$mat, diffsl::ExternalModule<$t>>::from_external(
                    rhs_state_deps,
                    rhs_input_deps,
                    mass_state_deps,
                    false,
                )?,
            )
        };
    }

    let unsupported = |matrix_name: &str| {
        DiffsolJsError::from(DiffsolError::Other(format!(
            "Unsupported scalar type for {matrix_name}"
        )))
    };

    let solve: Box<dyn Solve> = match matrix_type {
        MatrixType::NalgebraDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => external_solve!(diffsol::NalgebraMat<f32>, f32),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => external_solve!(diffsol::NalgebraMat<f64>, f64),
            _ => return Err(unsupported("NalgebraDense")),
        },
        MatrixType::FaerDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => external_solve!(diffsol::FaerMat<f32>, f32),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => external_solve!(diffsol::FaerMat<f64>, f64),
            _ => return Err(unsupported("FaerDense")),
        },
        MatrixType::FaerSparse => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => external_solve!(diffsol::FaerSparseMat<f32>, f32),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => external_solve!(diffsol::FaerSparseMat<f64>, f64),
            _ => return Err(unsupported("FaerSparse")),
        },
    };
    Ok(solve)
}
222
/// Builds a [`Solve`] from DiffSL source `code`, JIT-compiled with `jit_backend`.
///
/// Only backends whose Cargo feature (`diffsl-cranelift` / `diffsl-llvm`) is
/// enabled are available.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
pub(crate) fn solve_factory_jit(
    code: &str,
    jit_backend: JitBackendType,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError> {
    // Pick the backend-specific instantiation first, then invoke it once.
    let build_for_backend: fn(
        &str,
        MatrixType,
        ScalarType,
    ) -> Result<Box<dyn Solve>, DiffsolJsError> = match jit_backend {
        #[cfg(feature = "diffsl-cranelift")]
        JitBackendType::Cranelift => solve_factory_with_jit_backend::<diffsol::CraneliftJitModule>,
        #[cfg(feature = "diffsl-llvm")]
        JitBackendType::Llvm => solve_factory_with_jit_backend::<diffsol::LlvmModule>,
    };
    build_for_backend(code, matrix_type, scalar_type)
}
243
/// Compiles DiffSL `code` with code generator `CG` and wraps the resulting
/// problem in a [`GenericSolve`] for the requested matrix and scalar types.
///
/// # Errors
/// Propagates any error returned by `OdeBuilder::build_from_diffsl`.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn solve_factory_with_jit_backend<CG>(
    code: &str,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError>
where
    CG: CodegenModule + CodegenModuleJit + CodegenModuleCompile,
{
    // Compile `code` into a problem over matrix type `$mat` and box it.
    macro_rules! compile_with {
        ($mat:ty) => {{
            let problem = OdeBuilder::<$mat>::new().build_from_diffsl::<CG>(code)?;
            Box::new(GenericSolve { problem })
        }};
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => compile_with!(diffsol::NalgebraMat<f32>),
        (MatrixType::NalgebraDense, ScalarType::F64) => compile_with!(diffsol::NalgebraMat<f64>),
        (MatrixType::FaerDense, ScalarType::F32) => compile_with!(diffsol::FaerMat<f32>),
        (MatrixType::FaerDense, ScalarType::F64) => compile_with!(diffsol::FaerMat<f64>),
        (MatrixType::FaerSparse, ScalarType::F32) => compile_with!(diffsol::FaerSparseMat<f32>),
        (MatrixType::FaerSparse, ScalarType::F64) => compile_with!(diffsol::FaerSparseMat<f64>),
    };
    Ok(solve)
}
293
/// Concrete [`Solve`] implementation: an [`OdeSolverProblem`] over DiffSL
/// equations with matrix type `M` and code-generation module `CG`.
pub(crate) struct GenericSolve<M, CG>
where
    M: MatrixHost<T: Scalar>,
    M::V: Vector + VectorHost,
    CG: CodegenModule,
{
    /// The diffsol problem. Parameters are installed on it by `setup_problem`;
    /// tolerances and solver options are mutated in place through the `Solve`
    /// accessors.
    problem: OdeSolverProblem<DiffSl<M, CG>>,
}
302
303impl<M, CG> GenericSolve<M, CG>
304where
305    M: MatrixHost<T: Scalar>,
306    M::V: Vector + VectorHost + DefaultDenseMatrix,
307    CG: CodegenModule,
308{
309    pub(crate) fn setup_problem(&mut self, params: &[f64]) -> Result<(), DiffsolJsError> {
310        let params: Vec<M::T> = params.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
311        let params = M::V::from_slice(&params, M::C::default());
312
313        // Attempt to set problem from params and config
314        let nparams = self.problem.eqn.nparams();
315        if params.len() == nparams {
316            self.problem.eqn.set_params(&params);
317            Ok(())
318        } else {
319            Err(DiffsolError::Other(format!(
320                "Expecting {} params but got {}",
321                nparams,
322                params.len()
323            ))
324            .into())
325        }
326    }
327}
328
#[cfg(feature = "external")]
impl<M> GenericSolve<M, diffsl::ExternalModule<M::T>>
where
    M: MatrixHost<T: ExternalScalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Builds a solver around an externally linked diffsl module.
    ///
    /// The dependency lists and `include_sensitivities` are forwarded unchanged
    /// to `DiffSl::from_external`. The problem is built with every parameter set
    /// to zero; real values are installed later via `setup_problem`.
    ///
    /// # Errors
    /// Propagates errors from constructing the equations or building the problem.
    pub fn from_external(
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolJsError> {
        let equations = DiffSl::<M, diffsl::ExternalModule<M::T>>::from_external(
            M::C::default(),
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )?;
        let zero_params = vec![0.0; equations.nparams()];
        OdeBuilder::<M>::new()
            .p(zero_params)
            .build_from_eqn(equations)
            .map(|problem| Self { problem })
            .map_err(DiffsolJsError::from)
    }
}
353
354impl<M, CG> Solve for GenericSolve<M, CG>
355where
356    M: MatrixHost<T: Scalar + ToPrimitive>
357        + DefaultSolver
358        + LuValidator<M>
359        + KluValidator<M>
360        + MatrixKind,
361    CG: CodegenModule,
362    for<'b> <<M::V as DefaultDenseMatrix>::M as MatrixCommon>::Inner: ToHostArray<M::T> + Clone,
363    for<'b> <M::V as VectorCommon>::Inner: ToHostArray<M::T> + Clone,
364    M::V: VectorHost + DefaultDenseMatrix + Send + Sync + 'static,
365    <M::V as DefaultDenseMatrix>::M: Send + Sync,
366    for<'b> &'b M::V: VectorRef<M::V>,
367    for<'b> &'b M: MatrixRef<M>,
368{
369    fn matrix_type(&self) -> MatrixType {
370        MatrixType::from_diffsol::<M>()
371    }
372
373    fn nstates(&self) -> usize {
374        self.problem.eqn.nstates()
375    }
376
377    fn nparams(&self) -> usize {
378        self.problem.eqn.nparams()
379    }
380
381    fn nout(&self) -> usize {
382        self.problem.eqn.nout()
383    }
384
385    fn has_stop(&self) -> bool {
386        self.problem.eqn.root().is_some()
387    }
388
389    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolJsError> {
390        validate_linear_solver::<M>(linear_solver)
391    }
392
393    fn set_atol(&mut self, atol: f64) {
394        self.problem.atol.fill(M::T::from_f64(atol).unwrap());
395    }
396
397    fn atol(&self) -> f64 {
398        self.problem.atol[0].to_f64().unwrap()
399    }
400
401    fn set_rtol(&mut self, rtol: f64) {
402        self.problem.rtol = M::T::from_f64(rtol).unwrap();
403    }
404
405    fn rtol(&self) -> f64 {
406        self.problem.rtol.to_f64().unwrap()
407    }
408
409    generate_ic_option_accessors! {
410        use_linesearch: bool,
411        max_linesearch_iterations: usize,
412        max_newton_iterations: usize,
413        max_linear_solver_setups: usize,
414        step_reduction_factor: f64,
415        armijo_constant: f64
416    }
417
418    generate_ode_option_accessors! {
419        max_nonlinear_solver_iterations: usize,
420        max_error_test_failures: usize,
421        min_timestep: f64,
422        update_jacobian_after_steps: usize,
423        update_rhs_jacobian_after_steps: usize,
424        threshold_to_update_jacobian: f64,
425        threshold_to_update_rhs_jacobian: f64
426    }
427
428    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolJsError> {
429        self.setup_problem(params)?;
430        let n = self.problem.eqn.nstates();
431        let mut y0 = M::V::zeros(n, M::C::default());
432        let t0 = self.problem.t0;
433        self.problem.eqn.init().call_inplace(t0, &mut y0);
434        Ok((*y0.inner()).clone().to_host_array())
435    }
436
437    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolJsError> {
438        self.setup_problem(params)?;
439        let n = self.problem.eqn.nstates();
440        let y = y
441            .iter()
442            .map(|&x| M::T::from_f64(x).unwrap())
443            .collect::<Vec<_>>();
444        let y_vec = M::V::from_slice(&y, M::C::default());
445        let mut dydt = M::V::zeros(n, M::C::default());
446        self.problem
447            .eqn
448            .rhs()
449            .call_inplace(&y_vec, M::T::from_f64(t).unwrap(), &mut dydt);
450        Ok((*dydt.inner()).clone().to_host_array())
451    }
452
453    fn rhs_jac_mul(
454        &mut self,
455
456        params: &[f64],
457        t: f64,
458        y: &[f64],
459        v: &[f64],
460    ) -> Result<HostArray, DiffsolJsError> {
461        self.setup_problem(params)?;
462        let n = self.problem.eqn.nstates();
463        let y = y
464            .iter()
465            .map(|&x| M::T::from_f64(x).unwrap())
466            .collect::<Vec<_>>();
467        let v = v
468            .iter()
469            .map(|&x| M::T::from_f64(x).unwrap())
470            .collect::<Vec<_>>();
471        let y_vec = M::V::from_slice(&y, M::C::default());
472        let v_vec = M::V::from_slice(&v, M::C::default());
473        let mut dydt = M::V::zeros(n, M::C::default());
474        self.problem.eqn.rhs().jac_mul_inplace(
475            &y_vec,
476            M::T::from_f64(t).unwrap(),
477            &v_vec,
478            &mut dydt,
479        );
480        Ok((*dydt.inner()).clone().to_host_array())
481    }
482
483    fn solve(
484        &mut self,
485        method: OdeSolverType,
486        linear_solver: LinearSolverType,
487        params: &[f64],
488        final_time: f64,
489    ) -> SolveResult {
490        self.check(linear_solver)?;
491        self.setup_problem(params)?;
492        let final_time = M::T::from_f64(final_time).unwrap();
493        let soln = match linear_solver {
494            LinearSolverType::Default => {
495                method.solve::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time)
496            }
497            LinearSolverType::Lu => {
498                method.solve::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time)
499            }
500            LinearSolverType::Klu => {
501                method.solve::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time)
502            }
503        };
504        Ok(Box::new(soln?))
505    }
506
507    fn solve_dense(
508        &mut self,
509        method: OdeSolverType,
510        linear_solver: LinearSolverType,
511        params: &[f64],
512        t_eval: &[f64],
513    ) -> SolveResult {
514        self.check(linear_solver)?;
515        self.setup_problem(params)?;
516
517        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
518        let soln =
519            match linear_solver {
520                LinearSolverType::Default => method
521                    .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
522                LinearSolverType::Lu => method
523                    .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
524                LinearSolverType::Klu => method
525                    .solve_dense::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
526            };
527        Ok(Box::new(soln?))
528    }
529
530    fn solve_hybrid(
531        &mut self,
532        method: OdeSolverType,
533        linear_solver: LinearSolverType,
534        params: &[f64],
535        final_time: f64,
536    ) -> SolveResult {
537        self.check(linear_solver)?;
538        self.setup_problem(params)?;
539        let final_time = M::T::from_f64(final_time).unwrap();
540        let soln = match linear_solver {
541            LinearSolverType::Default => method
542                .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time),
543            LinearSolverType::Lu => method
544                .solve_hybrid::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time),
545            LinearSolverType::Klu => method
546                .solve_hybrid::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time),
547        };
548        Ok(Box::new(soln?))
549    }
550
551    fn solve_fwd_sens(
552        &mut self,
553        method: OdeSolverType,
554        linear_solver: LinearSolverType,
555        params: &[f64],
556        t_eval: &[f64],
557    ) -> SolveResult {
558        self.check(linear_solver)?;
559        self.setup_problem(params)?;
560
561        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
562        let soln = match linear_solver {
563            LinearSolverType::Default => {
564                method.solve_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval)
565            }
566            LinearSolverType::Lu => method
567                .solve_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
568            LinearSolverType::Klu => method
569                .solve_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
570        };
571        Ok(Box::new(soln?))
572    }
573
574    fn solve_hybrid_dense(
575        &mut self,
576        method: OdeSolverType,
577        linear_solver: LinearSolverType,
578        params: &[f64],
579        t_eval: &[f64],
580    ) -> SolveResult {
581        self.check(linear_solver)?;
582        self.setup_problem(params)?;
583
584        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
585        let soln = match linear_solver {
586            LinearSolverType::Default => method
587                .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
588            LinearSolverType::Lu => method
589                .solve_hybrid_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
590            LinearSolverType::Klu => method
591                .solve_hybrid_dense::<M, CG, <M as KluValidator<M>>::LS>(
592                    &mut self.problem,
593                    &t_eval,
594                ),
595        };
596        Ok(Box::new(soln?))
597    }
598
599    fn solve_hybrid_fwd_sens(
600        &mut self,
601        method: OdeSolverType,
602        linear_solver: LinearSolverType,
603        params: &[f64],
604        t_eval: &[f64],
605    ) -> SolveResult {
606        self.check(linear_solver)?;
607        self.setup_problem(params)?;
608
609        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
610        let soln = match linear_solver {
611            LinearSolverType::Default => method
612                .solve_hybrid_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(
613                    &mut self.problem,
614                    &t_eval,
615                ),
616            LinearSolverType::Lu => method
617                .solve_hybrid_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(
618                    &mut self.problem,
619                    &t_eval,
620                ),
621            LinearSolverType::Klu => method
622                .solve_hybrid_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(
623                    &mut self.problem,
624                    &t_eval,
625                ),
626        };
627        Ok(Box::new(soln?))
628    }
629
630    fn solve_sum_squares_adj(
631        &mut self,
632
633        method: OdeSolverType,
634        linear_solver: LinearSolverType,
635        backwards_method: OdeSolverType,
636        backwards_linear_solver: LinearSolverType,
637        params: &[f64],
638        data: HostArray,
639        t_eval: &[f64],
640    ) -> Result<(f64, HostArray), DiffsolJsError> {
641        self.check(linear_solver)?;
642        self.setup_problem(params)?;
643
644        let data = data.as_array()?;
645        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
646
647        let previous_integrate_out = self.problem.integrate_out;
648        self.problem.integrate_out = true;
649        let result = match linear_solver {
650            LinearSolverType::Default => method
651                .solve_sum_squares_adj::<M, CG, <M as DefaultSolver>::LS>(
652                    &mut self.problem,
653                    data,
654                    &t_eval,
655                    backwards_method,
656                    backwards_linear_solver,
657                ),
658            LinearSolverType::Lu => method
659                .solve_sum_squares_adj::<M, CG, <M as LuValidator<M>>::LS>(
660                    &mut self.problem,
661                    data,
662                    &t_eval,
663                    backwards_method,
664                    backwards_linear_solver,
665                ),
666            LinearSolverType::Klu => method
667                .solve_sum_squares_adj::<M, CG, <M as KluValidator<M>>::LS>(
668                    &mut self.problem,
669                    data,
670                    &t_eval,
671                    backwards_method,
672                    backwards_linear_solver,
673                ),
674        };
675        self.problem.integrate_out = previous_integrate_out;
676        let (y, y_sens) = result?;
677
678        Ok((
679            y.to_f64().unwrap(),
680            (*y_sens.inner()).clone().to_host_array(),
681        ))
682    }
683}
684
685#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
686mod tests {
687    use diffsol::{
688        CodegenModuleCompile, CodegenModuleJit, Context, OdeBuilder, OdeEquations, Vector,
689    };
690
691    #[cfg(feature = "diffsl-llvm")]
692    use crate::test_support::{
693        hybrid_logistic_state_dr, logistic_integral, logistic_state_dr, matrix_host,
694    };
695    use crate::{
696        host_array::FromHostArray,
697        linear_solver_type::LinearSolverType,
698        matrix_type::MatrixType,
699        ode_solver_type::OdeSolverType,
700        scalar_type::ScalarType,
701        test_support::{
702            assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
703            logistic_state, LOGISTIC_X0,
704        },
705    };
706
707    use super::{solve_factory_with_jit_backend, GenericSolve, Solve};
708
709    fn make_generic_solve<CG>() -> GenericSolve<diffsol::NalgebraMat<f64>, CG>
710    where
711        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
712    {
713        let problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
714            .build_from_diffsl::<CG>(logistic_diffsl_code())
715            .unwrap();
716        GenericSolve { problem }
717    }
718
719    fn assert_factory_supports_all_matrix_and_scalar_types<CG>()
720    where
721        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
722    {
723        for matrix_type in [
724            MatrixType::NalgebraDense,
725            MatrixType::FaerDense,
726            MatrixType::FaerSparse,
727        ] {
728            for scalar_type in [ScalarType::F32, ScalarType::F64] {
729                assert!(solve_factory_with_jit_backend::<CG>(
730                    logistic_diffsl_code(),
731                    matrix_type,
732                    scalar_type,
733                )
734                .is_ok());
735            }
736        }
737    }
738
739    fn assert_solve_metadata_and_helpers<CG>()
740    where
741        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
742    {
743        let mut solve = make_generic_solve::<CG>();
744        assert_eq!(solve.matrix_type(), MatrixType::NalgebraDense);
745        assert_eq!(solve.nstates(), 1);
746        assert_eq!(solve.nparams(), 1);
747        assert_eq!(solve.nout(), 1);
748        assert!(!solve.has_stop());
749        assert!(solve.check(LinearSolverType::Default).is_ok());
750        assert!(solve.check(LinearSolverType::Lu).is_ok());
751        assert!(solve.check(LinearSolverType::Klu).is_err());
752
753        solve.set_atol(1e-5);
754        solve.set_rtol(1e-4);
755        assert_close(solve.atol(), 1e-5, 1e-12, "solve atol");
756        assert_close(solve.rtol(), 1e-4, 1e-12, "solve rtol");
757
758        let y0 = Vec::<f64>::from_host_array(solve.y0(&[2.0]).unwrap()).unwrap();
759        assert_eq!(y0, vec![LOGISTIC_X0]);
760
761        let rhs = Vec::<f64>::from_host_array(solve.rhs(&[2.0], 0.0, &[0.25]).unwrap()).unwrap();
762        assert_eq!(rhs.len(), 1);
763        assert_close(rhs[0], 0.375, 1e-12, "solve rhs");
764
765        let jac_mul =
766            Vec::<f64>::from_host_array(solve.rhs_jac_mul(&[2.0], 0.0, &[0.25], &[3.0]).unwrap())
767                .unwrap();
768        assert_eq!(jac_mul.len(), 1);
769        assert_close(jac_mul[0], 3.0, 1e-12, "solve rhs jac mul");
770    }
771
772    fn assert_solve_runtime_paths<CG>()
773    where
774        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
775    {
776        let mut solve = make_generic_solve::<CG>();
777        let soln = solve
778            .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 1.0)
779            .unwrap();
780        let ts = Vec::<f64>::from_host_array(soln.get_ts()).unwrap();
781        let ys = Vec::<Vec<f64>>::from_host_array(soln.get_ys()).unwrap();
782        assert_close(*ts.last().unwrap(), 1.0, 5e-4, "solve final time");
783        assert_close(
784            ys[0][ts.len() - 1],
785            logistic_state(LOGISTIC_X0, 2.0, 1.0),
786            5e-4,
787            "solve final value",
788        );
789
790        let mut solve = make_generic_solve::<CG>();
791        let dense = solve
792            .solve_dense(
793                OdeSolverType::Tsit45,
794                LinearSolverType::Lu,
795                &[2.0],
796                &[0.25, 0.5, 1.0],
797            )
798            .unwrap();
799        let ts = Vec::<f64>::from_host_array(dense.get_ts()).unwrap();
800        let ys = Vec::<Vec<f64>>::from_host_array(dense.get_ys()).unwrap();
801        assert_eq!(ts, vec![0.25, 0.5, 1.0]);
802        for (i, &t) in ts.iter().enumerate() {
803            assert_close(
804                ys[0][i],
805                logistic_state(LOGISTIC_X0, 2.0, t),
806                5e-4,
807                &format!("solve_dense[{i}]"),
808            );
809        }
810
811        let mut solve = make_generic_solve::<CG>();
812        let err = match solve.solve(OdeSolverType::Bdf, LinearSolverType::Default, &[], 1.0) {
813            Ok(_) => panic!("expected parameter count mismatch"),
814            Err(err) => err,
815        };
816        assert!(err.to_string().contains("Expecting 1 params but got 0"));
817
818        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
819            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
820            .unwrap();
821        let mut hybrid_solve = GenericSolve {
822            problem: hybrid_problem,
823        };
824        let hybrid = hybrid_solve
825            .solve_hybrid(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 2.0)
826            .unwrap();
827        let hybrid_ts = Vec::<f64>::from_host_array(hybrid.get_ts()).unwrap();
828        let hybrid_ys = Vec::<Vec<f64>>::from_host_array(hybrid.get_ys()).unwrap();
829        assert_close(
830            *hybrid_ts.last().unwrap(),
831            2.0,
832            5e-4,
833            "solve_hybrid final time",
834        );
835        assert_close(
836            hybrid_ys[0][hybrid_ts.len() - 1],
837            hybrid_logistic_state(2.0, 2.0),
838            5e-4,
839            "solve_hybrid final value",
840        );
841
842        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
843            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
844            .unwrap();
845        let mut hybrid_solve = GenericSolve {
846            problem: hybrid_problem,
847        };
848        let hybrid_dense = hybrid_solve
849            .solve_hybrid_dense(
850                OdeSolverType::Tsit45,
851                LinearSolverType::Lu,
852                &[2.0],
853                &[0.5, 1.0, 1.5, 2.0],
854            )
855            .unwrap();
856        let hybrid_dense_ts = Vec::<f64>::from_host_array(hybrid_dense.get_ts()).unwrap();
857        let hybrid_dense_ys = Vec::<Vec<f64>>::from_host_array(hybrid_dense.get_ys()).unwrap();
858        assert_eq!(hybrid_dense_ts, vec![0.5, 1.0, 1.5, 2.0]);
859        for (i, &t) in hybrid_dense_ts.iter().enumerate() {
860            assert_close(
861                hybrid_dense_ys[0][i],
862                hybrid_logistic_state(2.0, t),
863                5e-4,
864                &format!("solve_hybrid_dense[{i}]"),
865            );
866        }
867    }
868
869    #[cfg(feature = "diffsl-llvm")]
870    fn assert_solve_sensitivity_paths() {
871        let t_eval = [0.25, 0.5, 1.0];
872
873        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
874        let sens = solve
875            .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
876            .unwrap();
877        let sens_values = sens.get_sens();
878        assert_eq!(sens_values.len(), 1);
879        let sens_matrix =
880            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
881        for (i, &t) in t_eval.iter().enumerate() {
882            assert_close(
883                sens_matrix[0][i],
884                logistic_state_dr(LOGISTIC_X0, 2.0, t),
885                5e-4,
886                &format!("solve_fwd_sens[{i}]"),
887            );
888        }
889
890        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
891            .build_from_diffsl::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code())
892            .unwrap();
893        let mut solve = GenericSolve {
894            problem: hybrid_problem,
895        };
896        let hybrid_sens = solve
897            .solve_hybrid_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
898            .unwrap();
899        let sens_values = hybrid_sens.get_sens();
900        let sens_matrix =
901            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
902        for (i, &t) in t_eval.iter().enumerate() {
903            assert_close(
904                sens_matrix[0][i],
905                hybrid_logistic_state_dr(2.0, t),
906                5e-4,
907                &format!("solve_hybrid_fwd_sens[{i}]"),
908            );
909        }
910
911        let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
912        let adjoint_data: Vec<f64> = adjoint_t_eval
913            .iter()
914            .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
915            .collect();
916        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
917        let (objective, gradient) = solve
918            .solve_sum_squares_adj(
919                OdeSolverType::Bdf,
920                LinearSolverType::Lu,
921                OdeSolverType::TrBdf2,
922                LinearSolverType::Lu,
923                &[2.0],
924                matrix_host(1, adjoint_t_eval.len(), &adjoint_data),
925                &adjoint_t_eval,
926            )
927            .unwrap();
928        assert!(objective.is_finite());
929        let gradient = Vec::<f64>::from_host_array(gradient).unwrap();
930        assert_eq!(gradient.len(), 1);
931        assert!(gradient[0].is_finite());
932    }
933
934    #[cfg(feature = "diffsl-cranelift")]
935    #[test]
936    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_cranelift() {
937        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::CraneliftJitModule>();
938    }
939
940    #[cfg(feature = "diffsl-cranelift")]
941    #[test]
942    fn solve_trait_helpers_and_runtime_paths_for_cranelift() {
943        assert_solve_metadata_and_helpers::<diffsol::CraneliftJitModule>();
944        assert_solve_runtime_paths::<diffsol::CraneliftJitModule>();
945    }
946
947    #[cfg(feature = "diffsl-llvm")]
948    #[test]
949    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_llvm() {
950        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::LlvmModule>();
951    }
952
953    #[cfg(feature = "diffsl-llvm")]
954    #[test]
955    fn solve_trait_helpers_and_runtime_paths_for_llvm() {
956        assert_solve_metadata_and_helpers::<diffsol::LlvmModule>();
957        assert_solve_runtime_paths::<diffsol::LlvmModule>();
958    }
959
960    #[cfg(feature = "diffsl-llvm")]
961    #[test]
962    fn solve_trait_sensitivity_paths_for_llvm() {
963        assert_solve_sensitivity_paths();
964    }
965
966    #[cfg(feature = "diffsl-cranelift")]
967    #[test]
968    fn setup_problem_validates_parameter_count_for_cranelift() {
969        let mut solve = make_generic_solve::<diffsol::CraneliftJitModule>();
970        let err = solve.setup_problem(&[]).unwrap_err();
971        assert!(err.to_string().contains("Expecting 1 params but got 0"));
972
973        solve.setup_problem(&[2.0]).unwrap();
974        let mut params = solve.problem.context().vector_zeros(1);
975        solve.problem.eqn.get_params(&mut params);
976        assert_eq!(params.get_index(0), 2.0);
977    }
978
979    #[cfg(feature = "diffsl-llvm")]
980    #[test]
981    fn setup_problem_validates_parameter_count_for_llvm() {
982        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
983        let err = solve.setup_problem(&[]).unwrap_err();
984        assert!(err.to_string().contains("Expecting 1 params but got 0"));
985
986        solve.setup_problem(&[2.0]).unwrap();
987        let mut params = solve.problem.context().vector_zeros(1);
988        solve.problem.eqn.get_params(&mut params);
989        assert_eq!(params.get_index(0), 2.0);
990    }
991}