Skip to main content

diffsol_c/
solve.rs

// Dispatch the solver configuration selected at runtime by the host to
// concrete, statically-typed diffsol solver types in Rust.
3
4use diffsol::{
5    CodegenModule, ConstantOp, DefaultDenseMatrix, DefaultSolver, DiffSl, MatrixCommon,
6    NonLinearOp, NonLinearOpJacobian, OdeBuilder, OdeEquations, OdeSolverProblem, Op, Vector,
7    VectorCommon, VectorHost, VectorRef,
8    error::DiffsolError,
9    matrix::{MatrixHost, MatrixRef},
10};
11#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
12use diffsol::{CodegenModuleCompile, CodegenModuleJit};
13use num_traits::{FromPrimitive, ToPrimitive}; // for from_f64 and to_f64
14use paste::paste;
15
16use crate::error::DiffsolJsError;
17use crate::host_array::{HostArray, ToHostArray};
18#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
19use crate::jit::JitBackendType;
20#[cfg(feature = "external")]
21use crate::scalar_type::ExternalScalar;
22use crate::scalar_type::{Scalar, ScalarType};
23use crate::{
24    generate_ic_option_accessors, generate_ode_option_accessors, generate_option_accessors,
25    generate_trait_ic_option_accessors, generate_trait_ode_option_accessors,
26    option_value_from_store, option_value_to_store,
27};
28
29use crate::{
30    linear_solver_type::LinearSolverType,
31    matrix_type::{MatrixKind, MatrixType},
32    ode_solver_type::OdeSolverType,
33    valid_linear_solver::{KluValidator, LuValidator, validate_linear_solver},
34};
35
// Each matrix/scalar combination implements `Solve` as the bridge between diffsol and the host.
37
38use crate::solution::Solution;
39pub(crate) type SolveResult = Result<Box<dyn Solution>, DiffsolJsError>;
40
41pub(crate) trait Solve {
42    fn matrix_type(&self) -> MatrixType;
43
44    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolJsError>;
45
46    fn rhs_jac_mul(
47        &mut self,
48        params: &[f64],
49        t: f64,
50        y: &[f64],
51        v: &[f64],
52    ) -> Result<HostArray, DiffsolJsError>;
53
54    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolJsError>;
55
56    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolJsError>;
57    fn set_rtol(&mut self, rtol: f64);
58    fn rtol(&self) -> f64;
59    fn set_atol(&mut self, atol: f64);
60    fn atol(&self) -> f64;
61
62    // New API: solution object support
63    fn solve(
64        &mut self,
65        method: OdeSolverType,
66        linear_solver: LinearSolverType,
67        params: &[f64],
68        final_time: f64,
69    ) -> SolveResult;
70
71    fn solve_hybrid(
72        &mut self,
73        method: OdeSolverType,
74        linear_solver: LinearSolverType,
75        params: &[f64],
76        final_time: f64,
77    ) -> SolveResult;
78
79    fn solve_dense(
80        &mut self,
81        method: OdeSolverType,
82        linear_solver: LinearSolverType,
83        params: &[f64],
84        t_eval: &[f64],
85    ) -> SolveResult;
86
87    fn solve_hybrid_dense(
88        &mut self,
89        method: OdeSolverType,
90        linear_solver: LinearSolverType,
91        params: &[f64],
92        t_eval: &[f64],
93    ) -> SolveResult;
94
95    fn solve_fwd_sens(
96        &mut self,
97        method: OdeSolverType,
98        linear_solver: LinearSolverType,
99        params: &[f64],
100        t_eval: &[f64],
101    ) -> SolveResult;
102
103    fn solve_hybrid_fwd_sens(
104        &mut self,
105        method: OdeSolverType,
106        linear_solver: LinearSolverType,
107        params: &[f64],
108        t_eval: &[f64],
109    ) -> SolveResult;
110
111    #[allow(clippy::type_complexity)]
112    #[allow(clippy::too_many_arguments)]
113    fn solve_sum_squares_adj(
114        &mut self,
115        method: OdeSolverType,
116        linear_solver: LinearSolverType,
117        backwards_method: OdeSolverType,
118        backwards_linear_solver: LinearSolverType,
119        params: &[f64],
120        data: HostArray,
121        t_eval: &[f64],
122    ) -> Result<(f64, HostArray), DiffsolJsError>;
123
124    generate_trait_ic_option_accessors! {
125        use_linesearch: bool,
126        max_linesearch_iterations: usize,
127        max_newton_iterations: usize,
128        max_linear_solver_setups: usize,
129        step_reduction_factor: f64,
130        armijo_constant: f64
131    }
132    generate_trait_ode_option_accessors! {
133        max_nonlinear_solver_iterations: usize,
134        max_error_test_failures: usize,
135        min_timestep: f64,
136        update_jacobian_after_steps: usize,
137        update_rhs_jacobian_after_steps: usize,
138        threshold_to_update_jacobian: f64,
139        threshold_to_update_rhs_jacobian: f64
140    }
141}
/// Build a [`Solve`] instance backed by an externally linked diffsl module.
///
/// `matrix_type` and `scalar_type` select the concrete diffsol matrix type. The
/// three dependency lists are forwarded unchanged to `DiffSl::from_external`
/// (presumably `(row, column)` sparsity entries — confirm against diffsl). The
/// module is always created with `include_sensitivities = false`.
///
/// # Errors
///
/// Returns an error if `scalar_type` is not enabled by the
/// `diffsl-external-f32` / `diffsl-external-f64` features, or if building the
/// equations or the problem fails.
#[cfg(feature = "external")]
pub(crate) fn solve_factory_external(
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError> {
    // In each inner match, the `_` fallback only matters when a scalar feature
    // is disabled. With both features enabled, every `ScalarType` variant is
    // already covered and the arm is an unreachable pattern. The `allow` keeps
    // that configuration free of `unreachable_patterns` warnings.
    let solve: Box<dyn Solve> = match matrix_type {
        MatrixType::NalgebraDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolJsError::from(DiffsolError::Other(
                    "Unsupported scalar type for NalgebraDense".to_string(),
                )));
            }
        },
        MatrixType::FaerDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolJsError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerDense".to_string(),
                )));
            }
        },
        MatrixType::FaerSparse => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolJsError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerSparse".to_string(),
                )));
            }
        },
    };
    Ok(solve)
}
218
/// Build a [`Solve`] instance by JIT-compiling DiffSL `code` with the chosen backend.
///
/// This function only picks the codegen backend; dispatch on matrix and scalar
/// type happens in `solve_factory_with_jit_backend`.
///
/// # Errors
///
/// Propagates any error from compiling `code` or building the problem.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
pub(crate) fn solve_factory_jit(
    code: &str,
    jit_backend: JitBackendType,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError> {
    // Pick the factory monomorphised for the backend, then invoke it once.
    let factory = match jit_backend {
        #[cfg(feature = "diffsl-cranelift")]
        JitBackendType::Cranelift => {
            solve_factory_with_jit_backend::<diffsol::CraneliftJitModule>
        }
        #[cfg(feature = "diffsl-llvm")]
        JitBackendType::Llvm => solve_factory_with_jit_backend::<diffsol::LlvmModule>,
    };
    factory(code, matrix_type, scalar_type)
}
239
/// Build a [`Solve`] instance for the JIT backend `CG`.
///
/// `matrix_type` and `scalar_type` select the concrete diffsol matrix type.
///
/// # Errors
///
/// Propagates any error from compiling `code` or building the problem.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn solve_factory_with_jit_backend<CG>(
    code: &str,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError>
where
    CG: CodegenModule + CodegenModuleJit + CodegenModuleCompile,
{
    // Compile `code` for matrix type `$mat` and box the resulting solver.
    macro_rules! compile_for {
        ($mat:ty) => {{
            let problem = OdeBuilder::<$mat>::new().build_from_diffsl::<CG>(code)?;
            Box::new(GenericSolve { problem })
        }};
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => compile_for!(diffsol::NalgebraMat<f32>),
        (MatrixType::NalgebraDense, ScalarType::F64) => compile_for!(diffsol::NalgebraMat<f64>),
        (MatrixType::FaerDense, ScalarType::F32) => compile_for!(diffsol::FaerMat<f32>),
        (MatrixType::FaerDense, ScalarType::F64) => compile_for!(diffsol::FaerMat<f64>),
        (MatrixType::FaerSparse, ScalarType::F32) => compile_for!(diffsol::FaerSparseMat<f32>),
        (MatrixType::FaerSparse, ScalarType::F64) => compile_for!(diffsol::FaerSparseMat<f64>),
    };
    Ok(solve)
}
289
290pub(crate) struct GenericSolve<M, CG>
291where
292    M: MatrixHost<T: Scalar>,
293    M::V: Vector + VectorHost,
294    CG: CodegenModule,
295{
296    problem: OdeSolverProblem<DiffSl<M, CG>>,
297}
298
299impl<M, CG> GenericSolve<M, CG>
300where
301    M: MatrixHost<T: Scalar>,
302    M::V: Vector + VectorHost + DefaultDenseMatrix,
303    CG: CodegenModule,
304{
305    pub(crate) fn setup_problem(&mut self, params: &[f64]) -> Result<(), DiffsolJsError> {
306        let params: Vec<M::T> = params.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
307        let params = M::V::from_slice(&params, M::C::default());
308
309        // Attempt to set problem from params and config
310        let nparams = self.problem.eqn.nparams();
311        if params.len() == nparams {
312            self.problem.eqn.set_params(&params);
313            Ok(())
314        } else {
315            Err(DiffsolError::Other(format!(
316                "Expecting {} params but got {}",
317                nparams,
318                params.len()
319            ))
320            .into())
321        }
322    }
323}
324
#[cfg(feature = "external")]
impl<M> GenericSolve<M, diffsl::ExternalModule<M::T>>
where
    M: MatrixHost<T: ExternalScalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Construct a solver around an externally linked diffsl module.
    ///
    /// The dependency lists and `include_sensitivities` are forwarded to
    /// `DiffSl::from_external`. The problem is built with every parameter set
    /// to `0.0` as a placeholder; the solve methods in this file load real
    /// values through `setup_problem`.
    ///
    /// # Errors
    ///
    /// Propagates failures from creating the equations or building the problem.
    pub fn from_external(
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolJsError> {
        let ctx = M::C::default();
        let eqn = DiffSl::<M, diffsl::ExternalModule<M::T>>::from_external(
            ctx,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )?;
        let placeholder_params = vec![0.0; eqn.nparams()];
        let problem = OdeBuilder::<M>::new()
            .p(placeholder_params)
            .build_from_eqn(eqn)?;
        Ok(Self { problem })
    }
}
349
350impl<M, CG> Solve for GenericSolve<M, CG>
351where
352    M: MatrixHost<T: Scalar + ToPrimitive>
353        + DefaultSolver
354        + LuValidator<M>
355        + KluValidator<M>
356        + MatrixKind,
357    CG: CodegenModule,
358    for<'b> <<M::V as DefaultDenseMatrix>::M as MatrixCommon>::Inner: ToHostArray<M::T> + Clone,
359    for<'b> <M::V as VectorCommon>::Inner: ToHostArray<M::T> + Clone,
360    M::V: VectorHost + DefaultDenseMatrix + Send + Sync + 'static,
361    <M::V as DefaultDenseMatrix>::M: Send + Sync,
362    for<'b> &'b M::V: VectorRef<M::V>,
363    for<'b> &'b M: MatrixRef<M>,
364{
365    fn matrix_type(&self) -> MatrixType {
366        MatrixType::from_diffsol::<M>()
367    }
368
369    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolJsError> {
370        validate_linear_solver::<M>(linear_solver)
371    }
372
373    fn set_atol(&mut self, atol: f64) {
374        self.problem.atol.fill(M::T::from_f64(atol).unwrap());
375    }
376
377    fn atol(&self) -> f64 {
378        self.problem.atol[0].to_f64().unwrap()
379    }
380
381    fn set_rtol(&mut self, rtol: f64) {
382        self.problem.rtol = M::T::from_f64(rtol).unwrap();
383    }
384
385    fn rtol(&self) -> f64 {
386        self.problem.rtol.to_f64().unwrap()
387    }
388
389    generate_ic_option_accessors! {
390        use_linesearch: bool,
391        max_linesearch_iterations: usize,
392        max_newton_iterations: usize,
393        max_linear_solver_setups: usize,
394        step_reduction_factor: f64,
395        armijo_constant: f64
396    }
397
398    generate_ode_option_accessors! {
399        max_nonlinear_solver_iterations: usize,
400        max_error_test_failures: usize,
401        min_timestep: f64,
402        update_jacobian_after_steps: usize,
403        update_rhs_jacobian_after_steps: usize,
404        threshold_to_update_jacobian: f64,
405        threshold_to_update_rhs_jacobian: f64
406    }
407
408    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolJsError> {
409        self.setup_problem(params)?;
410        let n = self.problem.eqn.nstates();
411        let mut y0 = M::V::zeros(n, M::C::default());
412        let t0 = self.problem.t0;
413        self.problem.eqn.init().call_inplace(t0, &mut y0);
414        Ok((*y0.inner()).clone().to_host_array())
415    }
416
417    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolJsError> {
418        self.setup_problem(params)?;
419        let n = self.problem.eqn.nstates();
420        let y = y
421            .iter()
422            .map(|&x| M::T::from_f64(x).unwrap())
423            .collect::<Vec<_>>();
424        let y_vec = M::V::from_slice(&y, M::C::default());
425        let mut dydt = M::V::zeros(n, M::C::default());
426        self.problem
427            .eqn
428            .rhs()
429            .call_inplace(&y_vec, M::T::from_f64(t).unwrap(), &mut dydt);
430        Ok((*dydt.inner()).clone().to_host_array())
431    }
432
433    fn rhs_jac_mul(
434        &mut self,
435
436        params: &[f64],
437        t: f64,
438        y: &[f64],
439        v: &[f64],
440    ) -> Result<HostArray, DiffsolJsError> {
441        self.setup_problem(params)?;
442        let n = self.problem.eqn.nstates();
443        let y = y
444            .iter()
445            .map(|&x| M::T::from_f64(x).unwrap())
446            .collect::<Vec<_>>();
447        let v = v
448            .iter()
449            .map(|&x| M::T::from_f64(x).unwrap())
450            .collect::<Vec<_>>();
451        let y_vec = M::V::from_slice(&y, M::C::default());
452        let v_vec = M::V::from_slice(&v, M::C::default());
453        let mut dydt = M::V::zeros(n, M::C::default());
454        self.problem.eqn.rhs().jac_mul_inplace(
455            &y_vec,
456            M::T::from_f64(t).unwrap(),
457            &v_vec,
458            &mut dydt,
459        );
460        Ok((*dydt.inner()).clone().to_host_array())
461    }
462
463    fn solve(
464        &mut self,
465        method: OdeSolverType,
466        linear_solver: LinearSolverType,
467        params: &[f64],
468        final_time: f64,
469    ) -> SolveResult {
470        self.check(linear_solver)?;
471        self.setup_problem(params)?;
472        let final_time = M::T::from_f64(final_time).unwrap();
473        let soln = match linear_solver {
474            LinearSolverType::Default => {
475                method.solve::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time)
476            }
477            LinearSolverType::Lu => {
478                method.solve::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time)
479            }
480            LinearSolverType::Klu => {
481                method.solve::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time)
482            }
483        };
484        Ok(Box::new(soln?))
485    }
486
487    fn solve_dense(
488        &mut self,
489        method: OdeSolverType,
490        linear_solver: LinearSolverType,
491        params: &[f64],
492        t_eval: &[f64],
493    ) -> SolveResult {
494        self.check(linear_solver)?;
495        self.setup_problem(params)?;
496
497        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
498        let soln =
499            match linear_solver {
500                LinearSolverType::Default => method
501                    .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
502                LinearSolverType::Lu => method
503                    .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
504                LinearSolverType::Klu => method
505                    .solve_dense::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
506            };
507        Ok(Box::new(soln?))
508    }
509
510    fn solve_hybrid(
511        &mut self,
512        method: OdeSolverType,
513        linear_solver: LinearSolverType,
514        params: &[f64],
515        final_time: f64,
516    ) -> SolveResult {
517        self.check(linear_solver)?;
518        self.setup_problem(params)?;
519        let final_time = M::T::from_f64(final_time).unwrap();
520        let soln = match linear_solver {
521            LinearSolverType::Default => method
522                .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time),
523            LinearSolverType::Lu => method
524                .solve_hybrid::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time),
525            LinearSolverType::Klu => method
526                .solve_hybrid::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time),
527        };
528        Ok(Box::new(soln?))
529    }
530
531    fn solve_fwd_sens(
532        &mut self,
533        method: OdeSolverType,
534        linear_solver: LinearSolverType,
535        params: &[f64],
536        t_eval: &[f64],
537    ) -> SolveResult {
538        self.check(linear_solver)?;
539        self.setup_problem(params)?;
540
541        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
542        let soln = match linear_solver {
543            LinearSolverType::Default => {
544                method.solve_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval)
545            }
546            LinearSolverType::Lu => method
547                .solve_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
548            LinearSolverType::Klu => method
549                .solve_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
550        };
551        Ok(Box::new(soln?))
552    }
553
554    fn solve_hybrid_dense(
555        &mut self,
556        method: OdeSolverType,
557        linear_solver: LinearSolverType,
558        params: &[f64],
559        t_eval: &[f64],
560    ) -> SolveResult {
561        self.check(linear_solver)?;
562        self.setup_problem(params)?;
563
564        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
565        let soln = match linear_solver {
566            LinearSolverType::Default => method
567                .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
568            LinearSolverType::Lu => method
569                .solve_hybrid_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
570            LinearSolverType::Klu => method
571                .solve_hybrid_dense::<M, CG, <M as KluValidator<M>>::LS>(
572                    &mut self.problem,
573                    &t_eval,
574                ),
575        };
576        Ok(Box::new(soln?))
577    }
578
579    fn solve_hybrid_fwd_sens(
580        &mut self,
581        method: OdeSolverType,
582        linear_solver: LinearSolverType,
583        params: &[f64],
584        t_eval: &[f64],
585    ) -> SolveResult {
586        self.check(linear_solver)?;
587        self.setup_problem(params)?;
588
589        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
590        let soln = match linear_solver {
591            LinearSolverType::Default => method
592                .solve_hybrid_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(
593                    &mut self.problem,
594                    &t_eval,
595                ),
596            LinearSolverType::Lu => method
597                .solve_hybrid_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(
598                    &mut self.problem,
599                    &t_eval,
600                ),
601            LinearSolverType::Klu => method
602                .solve_hybrid_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(
603                    &mut self.problem,
604                    &t_eval,
605                ),
606        };
607        Ok(Box::new(soln?))
608    }
609
610    fn solve_sum_squares_adj(
611        &mut self,
612
613        method: OdeSolverType,
614        linear_solver: LinearSolverType,
615        backwards_method: OdeSolverType,
616        backwards_linear_solver: LinearSolverType,
617        params: &[f64],
618        data: HostArray,
619        t_eval: &[f64],
620    ) -> Result<(f64, HostArray), DiffsolJsError> {
621        self.check(linear_solver)?;
622        self.setup_problem(params)?;
623
624        let data = data.as_array()?;
625        let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
626
627        let previous_integrate_out = self.problem.integrate_out;
628        self.problem.integrate_out = true;
629        let result = match linear_solver {
630            LinearSolverType::Default => method
631                .solve_sum_squares_adj::<M, CG, <M as DefaultSolver>::LS>(
632                    &mut self.problem,
633                    data,
634                    &t_eval,
635                    backwards_method,
636                    backwards_linear_solver,
637                ),
638            LinearSolverType::Lu => method
639                .solve_sum_squares_adj::<M, CG, <M as LuValidator<M>>::LS>(
640                    &mut self.problem,
641                    data,
642                    &t_eval,
643                    backwards_method,
644                    backwards_linear_solver,
645                ),
646            LinearSolverType::Klu => method
647                .solve_sum_squares_adj::<M, CG, <M as KluValidator<M>>::LS>(
648                    &mut self.problem,
649                    data,
650                    &t_eval,
651                    backwards_method,
652                    backwards_linear_solver,
653                ),
654        };
655        self.problem.integrate_out = previous_integrate_out;
656        let (y, y_sens) = result?;
657
658        Ok((
659            y.to_f64().unwrap(),
660            (*y_sens.inner()).clone().to_host_array(),
661        ))
662    }
663}
664
665#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
666mod tests {
667    use diffsol::{
668        CodegenModuleCompile, CodegenModuleJit, Context, OdeBuilder, OdeEquations, Vector,
669    };
670
671    use crate::{
672        host_array::FromHostArray,
673        linear_solver_type::LinearSolverType,
674        matrix_type::MatrixType,
675        ode_solver_type::OdeSolverType,
676        scalar_type::ScalarType,
677        test_support::{
678            LOGISTIC_X0, assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state,
679            hybrid_logistic_state_dr, logistic_diffsl_code, logistic_integral, logistic_state,
680            logistic_state_dr, matrix_host,
681        },
682    };
683
684    use super::{GenericSolve, Solve, solve_factory_with_jit_backend};
685
686    fn make_generic_solve<CG>() -> GenericSolve<diffsol::NalgebraMat<f64>, CG>
687    where
688        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
689    {
690        let problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
691            .build_from_diffsl::<CG>(logistic_diffsl_code())
692            .unwrap();
693        GenericSolve { problem }
694    }
695
696    fn assert_factory_supports_all_matrix_and_scalar_types<CG>()
697    where
698        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
699    {
700        for matrix_type in [
701            MatrixType::NalgebraDense,
702            MatrixType::FaerDense,
703            MatrixType::FaerSparse,
704        ] {
705            for scalar_type in [ScalarType::F32, ScalarType::F64] {
706                assert!(
707                    solve_factory_with_jit_backend::<CG>(
708                        logistic_diffsl_code(),
709                        matrix_type,
710                        scalar_type,
711                    )
712                    .is_ok()
713                );
714            }
715        }
716    }
717
718    fn assert_solve_metadata_and_helpers<CG>()
719    where
720        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
721    {
722        let mut solve = make_generic_solve::<CG>();
723        assert_eq!(solve.matrix_type(), MatrixType::NalgebraDense);
724        assert!(solve.check(LinearSolverType::Default).is_ok());
725        assert!(solve.check(LinearSolverType::Lu).is_ok());
726        assert!(solve.check(LinearSolverType::Klu).is_err());
727
728        solve.set_atol(1e-5);
729        solve.set_rtol(1e-4);
730        assert_close(solve.atol(), 1e-5, 1e-12, "solve atol");
731        assert_close(solve.rtol(), 1e-4, 1e-12, "solve rtol");
732
733        let y0 = Vec::<f64>::from_host_array(solve.y0(&[2.0]).unwrap()).unwrap();
734        assert_eq!(y0, vec![LOGISTIC_X0]);
735
736        let rhs = Vec::<f64>::from_host_array(solve.rhs(&[2.0], 0.0, &[0.25]).unwrap()).unwrap();
737        assert_eq!(rhs.len(), 1);
738        assert_close(rhs[0], 0.375, 1e-12, "solve rhs");
739
740        let jac_mul =
741            Vec::<f64>::from_host_array(solve.rhs_jac_mul(&[2.0], 0.0, &[0.25], &[3.0]).unwrap())
742                .unwrap();
743        assert_eq!(jac_mul.len(), 1);
744        assert_close(jac_mul[0], 3.0, 1e-12, "solve rhs jac mul");
745    }
746
747    fn assert_solve_runtime_paths<CG>()
748    where
749        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
750    {
751        let mut solve = make_generic_solve::<CG>();
752        let soln = solve
753            .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 1.0)
754            .unwrap();
755        let ts = Vec::<f64>::from_host_array(soln.get_ts()).unwrap();
756        let ys = Vec::<Vec<f64>>::from_host_array(soln.get_ys()).unwrap();
757        assert_close(*ts.last().unwrap(), 1.0, 5e-4, "solve final time");
758        assert_close(
759            ys[0][ts.len() - 1],
760            logistic_state(LOGISTIC_X0, 2.0, 1.0),
761            5e-4,
762            "solve final value",
763        );
764
765        let mut solve = make_generic_solve::<CG>();
766        let dense = solve
767            .solve_dense(
768                OdeSolverType::Tsit45,
769                LinearSolverType::Lu,
770                &[2.0],
771                &[0.25, 0.5, 1.0],
772            )
773            .unwrap();
774        let ts = Vec::<f64>::from_host_array(dense.get_ts()).unwrap();
775        let ys = Vec::<Vec<f64>>::from_host_array(dense.get_ys()).unwrap();
776        assert_eq!(ts, vec![0.25, 0.5, 1.0]);
777        for (i, &t) in ts.iter().enumerate() {
778            assert_close(
779                ys[0][i],
780                logistic_state(LOGISTIC_X0, 2.0, t),
781                5e-4,
782                &format!("solve_dense[{i}]"),
783            );
784        }
785
786        let mut solve = make_generic_solve::<CG>();
787        let err = match solve.solve(OdeSolverType::Bdf, LinearSolverType::Default, &[], 1.0) {
788            Ok(_) => panic!("expected parameter count mismatch"),
789            Err(err) => err,
790        };
791        assert!(err.to_string().contains("Expecting 1 params but got 0"));
792
793        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
794            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
795            .unwrap();
796        let mut hybrid_solve = GenericSolve {
797            problem: hybrid_problem,
798        };
799        let hybrid = hybrid_solve
800            .solve_hybrid(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 2.0)
801            .unwrap();
802        let hybrid_ts = Vec::<f64>::from_host_array(hybrid.get_ts()).unwrap();
803        let hybrid_ys = Vec::<Vec<f64>>::from_host_array(hybrid.get_ys()).unwrap();
804        assert_close(
805            *hybrid_ts.last().unwrap(),
806            2.0,
807            5e-4,
808            "solve_hybrid final time",
809        );
810        assert_close(
811            hybrid_ys[0][hybrid_ts.len() - 1],
812            hybrid_logistic_state(2.0, 2.0),
813            5e-4,
814            "solve_hybrid final value",
815        );
816
817        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
818            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
819            .unwrap();
820        let mut hybrid_solve = GenericSolve {
821            problem: hybrid_problem,
822        };
823        let hybrid_dense = hybrid_solve
824            .solve_hybrid_dense(
825                OdeSolverType::Tsit45,
826                LinearSolverType::Lu,
827                &[2.0],
828                &[0.5, 1.0, 1.5, 2.0],
829            )
830            .unwrap();
831        let hybrid_dense_ts = Vec::<f64>::from_host_array(hybrid_dense.get_ts()).unwrap();
832        let hybrid_dense_ys = Vec::<Vec<f64>>::from_host_array(hybrid_dense.get_ys()).unwrap();
833        assert_eq!(hybrid_dense_ts, vec![0.5, 1.0, 1.5, 2.0]);
834        for (i, &t) in hybrid_dense_ts.iter().enumerate() {
835            assert_close(
836                hybrid_dense_ys[0][i],
837                hybrid_logistic_state(2.0, t),
838                5e-4,
839                &format!("solve_hybrid_dense[{i}]"),
840            );
841        }
842    }
843
844    #[cfg(feature = "diffsl-llvm")]
845    fn assert_solve_sensitivity_paths() {
846        let t_eval = [0.25, 0.5, 1.0];
847
848        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
849        let sens = solve
850            .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
851            .unwrap();
852        let sens_values = sens.get_sens();
853        assert_eq!(sens_values.len(), 1);
854        let sens_matrix =
855            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
856        for (i, &t) in t_eval.iter().enumerate() {
857            assert_close(
858                sens_matrix[0][i],
859                logistic_state_dr(LOGISTIC_X0, 2.0, t),
860                5e-4,
861                &format!("solve_fwd_sens[{i}]"),
862            );
863        }
864
865        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
866            .build_from_diffsl::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code())
867            .unwrap();
868        let mut solve = GenericSolve {
869            problem: hybrid_problem,
870        };
871        let hybrid_sens = solve
872            .solve_hybrid_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
873            .unwrap();
874        let sens_values = hybrid_sens.get_sens();
875        let sens_matrix =
876            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
877        for (i, &t) in t_eval.iter().enumerate() {
878            assert_close(
879                sens_matrix[0][i],
880                hybrid_logistic_state_dr(2.0, t),
881                5e-4,
882                &format!("solve_hybrid_fwd_sens[{i}]"),
883            );
884        }
885
886        let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
887        let adjoint_data: Vec<f64> = adjoint_t_eval
888            .iter()
889            .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
890            .collect();
891        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
892        let (objective, gradient) = solve
893            .solve_sum_squares_adj(
894                OdeSolverType::Bdf,
895                LinearSolverType::Lu,
896                OdeSolverType::TrBdf2,
897                LinearSolverType::Lu,
898                &[2.0],
899                matrix_host(1, adjoint_t_eval.len(), &adjoint_data),
900                &adjoint_t_eval,
901            )
902            .unwrap();
903        assert!(objective.is_finite());
904        let gradient = Vec::<f64>::from_host_array(gradient).unwrap();
905        assert_eq!(gradient.len(), 1);
906        assert!(gradient[0].is_finite());
907    }
908
909    #[cfg(feature = "diffsl-cranelift")]
910    #[test]
911    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_cranelift() {
912        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::CraneliftJitModule>();
913    }
914
915    #[cfg(feature = "diffsl-cranelift")]
916    #[test]
917    fn solve_trait_helpers_and_runtime_paths_for_cranelift() {
918        assert_solve_metadata_and_helpers::<diffsol::CraneliftJitModule>();
919        assert_solve_runtime_paths::<diffsol::CraneliftJitModule>();
920    }
921
922    #[cfg(feature = "diffsl-llvm")]
923    #[test]
924    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_llvm() {
925        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::LlvmModule>();
926    }
927
928    #[cfg(feature = "diffsl-llvm")]
929    #[test]
930    fn solve_trait_helpers_and_runtime_paths_for_llvm() {
931        assert_solve_metadata_and_helpers::<diffsol::LlvmModule>();
932        assert_solve_runtime_paths::<diffsol::LlvmModule>();
933    }
934
935    #[cfg(feature = "diffsl-llvm")]
936    #[test]
937    fn solve_trait_sensitivity_paths_for_llvm() {
938        assert_solve_sensitivity_paths();
939    }
940
941    #[cfg(feature = "diffsl-cranelift")]
942    #[test]
943    fn setup_problem_validates_parameter_count_for_cranelift() {
944        let mut solve = make_generic_solve::<diffsol::CraneliftJitModule>();
945        let err = solve.setup_problem(&[]).unwrap_err();
946        assert!(err.to_string().contains("Expecting 1 params but got 0"));
947
948        solve.setup_problem(&[2.0]).unwrap();
949        let mut params = solve.problem.context().vector_zeros(1);
950        solve.problem.eqn.get_params(&mut params);
951        assert_eq!(params.get_index(0), 2.0);
952    }
953
954    #[cfg(feature = "diffsl-llvm")]
955    #[test]
956    fn setup_problem_validates_parameter_count_for_llvm() {
957        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
958        let err = solve.setup_problem(&[]).unwrap_err();
959        assert!(err.to_string().contains("Expecting 1 params but got 0"));
960
961        solve.setup_problem(&[2.0]).unwrap();
962        let mut params = solve.problem.context().vector_zeros(1);
963        solve.problem.eqn.get_params(&mut params);
964        assert_eq!(params.get_index(0), 2.0);
965    }
966}