1use diffsol::{
5 error::DiffsolError,
6 matrix::{MatrixHost, MatrixRef},
7 CodegenModule, ConstantOp, DefaultDenseMatrix, DefaultSolver, DiffSl, MatrixCommon,
8 NonLinearOp, NonLinearOpJacobian, OdeBuilder, OdeEquations, OdeSolverProblem, Op, Vector,
9 VectorCommon, VectorHost, VectorRef,
10};
11#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
12use diffsol::{CodegenModuleCompile, CodegenModuleJit};
13use num_traits::{FromPrimitive, ToPrimitive}; use paste::paste;
15
16use crate::error::DiffsolJsError;
17use crate::host_array::{HostArray, ToHostArray};
18#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
19use crate::jit::JitBackendType;
20#[cfg(feature = "external")]
21use crate::scalar_type::ExternalScalar;
22use crate::scalar_type::{Scalar, ScalarType};
23use crate::{
24 generate_ic_option_accessors, generate_ode_option_accessors, generate_option_accessors,
25 generate_trait_ic_option_accessors, generate_trait_ode_option_accessors,
26 option_value_from_store, option_value_to_store,
27};
28
29use crate::{
30 linear_solver_type::LinearSolverType,
31 matrix_type::{MatrixKind, MatrixType},
32 ode_solver_type::OdeSolverType,
33 valid_linear_solver::{validate_linear_solver, KluValidator, LuValidator},
34};
35
36use crate::solution::Solution;
/// Return type shared by the `Solve::solve*` entry points: a boxed, type-erased
/// [`Solution`] on success, or a [`DiffsolJsError`] on failure.
pub(crate) type SolveResult = Result<Box<dyn Solution>, DiffsolJsError>;
40
/// Object-safe interface over a compiled ODE problem.
///
/// Lets callers hold a `Box<dyn Solve>` and drive a [`GenericSolve`] without naming its
/// matrix backend or codegen-module type parameters. All `params` slices are plain
/// `f64` values that implementations convert to the problem's scalar type.
pub(crate) trait Solve {
    /// Matrix backend the underlying problem was built with.
    fn matrix_type(&self) -> MatrixType;
    /// Number of state variables in the model.
    fn nstates(&self) -> usize;
    /// Number of model parameters every `params` slice must supply.
    fn nparams(&self) -> usize;
    /// Number of model outputs.
    fn nout(&self) -> usize;
    /// Whether the model defines a stop (root) function.
    fn has_stop(&self) -> bool;

    /// Sets `params` on the model, then evaluates the right-hand side at `(t, y)`.
    fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolJsError>;

    /// Sets `params` on the model, then evaluates the product of the right-hand-side
    /// Jacobian at `(t, y)` with the vector `v`.
    fn rhs_jac_mul(
        &mut self,
        params: &[f64],
        t: f64,
        y: &[f64],
        v: &[f64],
    ) -> Result<HostArray, DiffsolJsError>;

    /// Sets `params` on the model, then evaluates the initial state at the problem's `t0`.
    fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolJsError>;

    /// Returns an error when `linear_solver` cannot be used with this problem's matrix type.
    fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolJsError>;
    /// Sets the relative tolerance used by subsequent solves.
    fn set_rtol(&mut self, rtol: f64);
    /// Current relative tolerance.
    fn rtol(&self) -> f64;
    /// Sets the absolute tolerance used by subsequent solves.
    fn set_atol(&mut self, atol: f64);
    /// Current absolute tolerance.
    fn atol(&self) -> f64;

    /// Integrates the model with `method`/`linear_solver` from `t0` up to `final_time`.
    fn solve(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        final_time: f64,
    ) -> SolveResult;

    /// Like [`Solve::solve`], but through the solver's hybrid code path
    /// (presumably for models with stop/reset behaviour — confirm in `OdeSolverType`).
    fn solve_hybrid(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        final_time: f64,
    ) -> SolveResult;

    /// Integrates the model and reports the solution at the requested `t_eval` times.
    fn solve_dense(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Hybrid-path counterpart of [`Solve::solve_dense`].
    fn solve_hybrid_dense(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Integrates the model together with its forward parameter sensitivities,
    /// reporting both at the requested `t_eval` times.
    fn solve_fwd_sens(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Hybrid-path counterpart of [`Solve::solve_fwd_sens`].
    fn solve_hybrid_fwd_sens(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        params: &[f64],
        t_eval: &[f64],
    ) -> SolveResult;

    /// Computes a sum-of-squares objective against `data` at `t_eval` and its gradient
    /// with respect to the parameters via an adjoint (backwards) solve.
    ///
    /// `method`/`linear_solver` drive the forward pass and
    /// `backwards_method`/`backwards_linear_solver` the backward pass. Returns
    /// `(objective, gradient)`. NOTE(review): `data` appears to be laid out as
    /// (outputs x len(t_eval)) in the tests — confirm against the solver.
    #[allow(clippy::type_complexity)]
    #[allow(clippy::too_many_arguments)]
    fn solve_sum_squares_adj(
        &mut self,
        method: OdeSolverType,
        linear_solver: LinearSolverType,
        backwards_method: OdeSolverType,
        backwards_linear_solver: LinearSolverType,
        params: &[f64],
        data: HostArray,
        t_eval: &[f64],
    ) -> Result<(f64, HostArray), DiffsolJsError>;

    // Accessor declarations for the initial-condition (consistent-initialisation)
    // solver options; the exact method names are produced by the macro.
    generate_trait_ic_option_accessors! {
        use_linesearch: bool,
        max_linesearch_iterations: usize,
        max_newton_iterations: usize,
        max_linear_solver_setups: usize,
        step_reduction_factor: f64,
        armijo_constant: f64
    }
    // Accessor declarations for the ODE integrator options; the exact method names
    // are produced by the macro.
    generate_trait_ode_option_accessors! {
        max_nonlinear_solver_iterations: usize,
        max_error_test_failures: usize,
        min_timestep: f64,
        update_jacobian_after_steps: usize,
        update_rhs_jacobian_after_steps: usize,
        threshold_to_update_jacobian: f64,
        threshold_to_update_rhs_jacobian: f64
    }
}
/// Builds a [`Solve`] around an externally compiled DiffSL model (`diffsl::ExternalModule`).
///
/// The three dependency lists are forwarded unchanged to [`GenericSolve::from_external`],
/// with sensitivities disabled.
///
/// # Errors
/// Returns an error when `scalar_type` is not enabled through the
/// `diffsl-external-f32` / `diffsl-external-f64` features, or when building the
/// problem fails.
#[cfg(feature = "external")]
pub(crate) fn solve_factory_external(
    rhs_state_deps: Vec<(usize, usize)>,
    rhs_input_deps: Vec<(usize, usize)>,
    mass_state_deps: Vec<(usize, usize)>,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError> {
    let solve: Box<dyn Solve> = match matrix_type {
        MatrixType::NalgebraDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::NalgebraMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            // Only reachable when one of the scalar features is disabled; with both
            // enabled every variant is covered above and rustc would warn.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolJsError::from(DiffsolError::Other(
                    "Unsupported scalar type for NalgebraDense".to_string(),
                )));
            }
        },
        MatrixType::FaerDense => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            // See the NalgebraDense arm: unreachable when both features are enabled.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolJsError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerDense".to_string(),
                )));
            }
        },
        MatrixType::FaerSparse => match scalar_type {
            #[cfg(feature = "diffsl-external-f32")]
            ScalarType::F32 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f32>,
                diffsl::ExternalModule<f32>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            #[cfg(feature = "diffsl-external-f64")]
            ScalarType::F64 => Box::new(GenericSolve::<
                diffsol::FaerSparseMat<f64>,
                diffsl::ExternalModule<f64>,
            >::from_external(
                rhs_state_deps, rhs_input_deps, mass_state_deps, false
            )?),
            // See the NalgebraDense arm: unreachable when both features are enabled.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(DiffsolJsError::from(DiffsolError::Other(
                    "Unsupported scalar type for FaerSparse".to_string(),
                )));
            }
        },
    };
    Ok(solve)
}
222
/// Builds a [`Solve`] by JIT-compiling the DiffSL source `code` with the chosen backend.
///
/// Only backends whose cargo feature is enabled have an arm here; the matrix/scalar
/// dispatch itself lives in `solve_factory_with_jit_backend`.
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
pub(crate) fn solve_factory_jit(
    code: &str,
    jit_backend: JitBackendType,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError> {
    // Resolve the runtime backend choice to a monomorphised factory, then call it once.
    let build = match jit_backend {
        #[cfg(feature = "diffsl-cranelift")]
        JitBackendType::Cranelift => solve_factory_with_jit_backend::<diffsol::CraneliftJitModule>,
        #[cfg(feature = "diffsl-llvm")]
        JitBackendType::Llvm => solve_factory_with_jit_backend::<diffsol::LlvmModule>,
    };
    build(code, matrix_type, scalar_type)
}
243
/// Compiles DiffSL `code` with codegen backend `CG` and wraps the resulting problem
/// for the requested matrix backend and floating-point precision.
///
/// # Errors
/// Any compilation or problem-building failure is converted into [`DiffsolJsError`].
#[cfg(any(feature = "diffsl-cranelift", feature = "diffsl-llvm"))]
fn solve_factory_with_jit_backend<CG>(
    code: &str,
    matrix_type: MatrixType,
    scalar_type: ScalarType,
) -> Result<Box<dyn Solve>, DiffsolJsError>
where
    CG: CodegenModule + CodegenModuleJit + CodegenModuleCompile,
{
    // Each (matrix, scalar) pair is a distinct monomorphised problem type, so the
    // builder call is stamped out once per combination.
    macro_rules! compile_for {
        ($matrix:ty) => {
            Box::new(GenericSolve {
                problem: OdeBuilder::<$matrix>::new().build_from_diffsl::<CG>(code)?,
            })
        };
    }

    let solve: Box<dyn Solve> = match (matrix_type, scalar_type) {
        (MatrixType::NalgebraDense, ScalarType::F32) => compile_for!(diffsol::NalgebraMat<f32>),
        (MatrixType::NalgebraDense, ScalarType::F64) => compile_for!(diffsol::NalgebraMat<f64>),
        (MatrixType::FaerDense, ScalarType::F32) => compile_for!(diffsol::FaerMat<f32>),
        (MatrixType::FaerDense, ScalarType::F64) => compile_for!(diffsol::FaerMat<f64>),
        (MatrixType::FaerSparse, ScalarType::F32) => compile_for!(diffsol::FaerSparseMat<f32>),
        (MatrixType::FaerSparse, ScalarType::F64) => compile_for!(diffsol::FaerSparseMat<f64>),
    };
    Ok(solve)
}
293
/// Concrete [`Solve`] implementation: a diffsol ODE problem whose equations are a
/// [`DiffSl`] model with matrix backend `M` and codegen module `CG`.
pub(crate) struct GenericSolve<M, CG>
where
    M: MatrixHost<T: Scalar>,
    M::V: Vector + VectorHost,
    CG: CodegenModule,
{
    // The wrapped problem. Parameters (via `setup_problem`) and tolerances (via
    // `set_atol`/`set_rtol`) are mutated in place before each evaluation or solve.
    problem: OdeSolverProblem<DiffSl<M, CG>>,
}
302
303impl<M, CG> GenericSolve<M, CG>
304where
305 M: MatrixHost<T: Scalar>,
306 M::V: Vector + VectorHost + DefaultDenseMatrix,
307 CG: CodegenModule,
308{
309 pub(crate) fn setup_problem(&mut self, params: &[f64]) -> Result<(), DiffsolJsError> {
310 let params: Vec<M::T> = params.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
311 let params = M::V::from_slice(¶ms, M::C::default());
312
313 let nparams = self.problem.eqn.nparams();
315 if params.len() == nparams {
316 self.problem.eqn.set_params(¶ms);
317 Ok(())
318 } else {
319 Err(DiffsolError::Other(format!(
320 "Expecting {} params but got {}",
321 nparams,
322 params.len()
323 ))
324 .into())
325 }
326 }
327}
328
#[cfg(feature = "external")]
impl<M> GenericSolve<M, diffsl::ExternalModule<M::T>>
where
    M: MatrixHost<T: ExternalScalar>,
    M::V: Vector + VectorHost + DefaultDenseMatrix,
{
    /// Wraps an externally compiled DiffSL module in a [`GenericSolve`].
    ///
    /// The dependency lists and `include_sensitivities` are handed straight to
    /// `DiffSl::from_external`. The resulting problem starts with every parameter set
    /// to `0.0`.
    ///
    /// # Errors
    /// Propagates failures from building the equations or the problem.
    pub fn from_external(
        rhs_state_deps: Vec<(usize, usize)>,
        rhs_input_deps: Vec<(usize, usize)>,
        mass_state_deps: Vec<(usize, usize)>,
        include_sensitivities: bool,
    ) -> Result<Self, DiffsolJsError> {
        let context = M::C::default();
        let equations = DiffSl::<M, diffsl::ExternalModule<M::T>>::from_external(
            context,
            rhs_state_deps,
            rhs_input_deps,
            mass_state_deps,
            include_sensitivities,
        )?;

        // Placeholder values only; every `Solve` entry point overwrites them through
        // `setup_problem` before evaluating or integrating.
        let zero_params = vec![0.0; equations.nparams()];
        let problem = OdeBuilder::<M>::new()
            .p(zero_params)
            .build_from_eqn(equations)?;

        Ok(Self { problem })
    }
}
353
354impl<M, CG> Solve for GenericSolve<M, CG>
355where
356 M: MatrixHost<T: Scalar + ToPrimitive>
357 + DefaultSolver
358 + LuValidator<M>
359 + KluValidator<M>
360 + MatrixKind,
361 CG: CodegenModule,
362 for<'b> <<M::V as DefaultDenseMatrix>::M as MatrixCommon>::Inner: ToHostArray<M::T> + Clone,
363 for<'b> <M::V as VectorCommon>::Inner: ToHostArray<M::T> + Clone,
364 M::V: VectorHost + DefaultDenseMatrix + Send + Sync + 'static,
365 <M::V as DefaultDenseMatrix>::M: Send + Sync,
366 for<'b> &'b M::V: VectorRef<M::V>,
367 for<'b> &'b M: MatrixRef<M>,
368{
369 fn matrix_type(&self) -> MatrixType {
370 MatrixType::from_diffsol::<M>()
371 }
372
373 fn nstates(&self) -> usize {
374 self.problem.eqn.nstates()
375 }
376
377 fn nparams(&self) -> usize {
378 self.problem.eqn.nparams()
379 }
380
381 fn nout(&self) -> usize {
382 self.problem.eqn.nout()
383 }
384
385 fn has_stop(&self) -> bool {
386 self.problem.eqn.root().is_some()
387 }
388
389 fn check(&self, linear_solver: LinearSolverType) -> Result<(), DiffsolJsError> {
390 validate_linear_solver::<M>(linear_solver)
391 }
392
393 fn set_atol(&mut self, atol: f64) {
394 self.problem.atol.fill(M::T::from_f64(atol).unwrap());
395 }
396
397 fn atol(&self) -> f64 {
398 self.problem.atol[0].to_f64().unwrap()
399 }
400
401 fn set_rtol(&mut self, rtol: f64) {
402 self.problem.rtol = M::T::from_f64(rtol).unwrap();
403 }
404
405 fn rtol(&self) -> f64 {
406 self.problem.rtol.to_f64().unwrap()
407 }
408
409 generate_ic_option_accessors! {
410 use_linesearch: bool,
411 max_linesearch_iterations: usize,
412 max_newton_iterations: usize,
413 max_linear_solver_setups: usize,
414 step_reduction_factor: f64,
415 armijo_constant: f64
416 }
417
418 generate_ode_option_accessors! {
419 max_nonlinear_solver_iterations: usize,
420 max_error_test_failures: usize,
421 min_timestep: f64,
422 update_jacobian_after_steps: usize,
423 update_rhs_jacobian_after_steps: usize,
424 threshold_to_update_jacobian: f64,
425 threshold_to_update_rhs_jacobian: f64
426 }
427
428 fn y0(&mut self, params: &[f64]) -> Result<HostArray, DiffsolJsError> {
429 self.setup_problem(params)?;
430 let n = self.problem.eqn.nstates();
431 let mut y0 = M::V::zeros(n, M::C::default());
432 let t0 = self.problem.t0;
433 self.problem.eqn.init().call_inplace(t0, &mut y0);
434 Ok((*y0.inner()).clone().to_host_array())
435 }
436
437 fn rhs(&mut self, params: &[f64], t: f64, y: &[f64]) -> Result<HostArray, DiffsolJsError> {
438 self.setup_problem(params)?;
439 let n = self.problem.eqn.nstates();
440 let y = y
441 .iter()
442 .map(|&x| M::T::from_f64(x).unwrap())
443 .collect::<Vec<_>>();
444 let y_vec = M::V::from_slice(&y, M::C::default());
445 let mut dydt = M::V::zeros(n, M::C::default());
446 self.problem
447 .eqn
448 .rhs()
449 .call_inplace(&y_vec, M::T::from_f64(t).unwrap(), &mut dydt);
450 Ok((*dydt.inner()).clone().to_host_array())
451 }
452
453 fn rhs_jac_mul(
454 &mut self,
455
456 params: &[f64],
457 t: f64,
458 y: &[f64],
459 v: &[f64],
460 ) -> Result<HostArray, DiffsolJsError> {
461 self.setup_problem(params)?;
462 let n = self.problem.eqn.nstates();
463 let y = y
464 .iter()
465 .map(|&x| M::T::from_f64(x).unwrap())
466 .collect::<Vec<_>>();
467 let v = v
468 .iter()
469 .map(|&x| M::T::from_f64(x).unwrap())
470 .collect::<Vec<_>>();
471 let y_vec = M::V::from_slice(&y, M::C::default());
472 let v_vec = M::V::from_slice(&v, M::C::default());
473 let mut dydt = M::V::zeros(n, M::C::default());
474 self.problem.eqn.rhs().jac_mul_inplace(
475 &y_vec,
476 M::T::from_f64(t).unwrap(),
477 &v_vec,
478 &mut dydt,
479 );
480 Ok((*dydt.inner()).clone().to_host_array())
481 }
482
483 fn solve(
484 &mut self,
485 method: OdeSolverType,
486 linear_solver: LinearSolverType,
487 params: &[f64],
488 final_time: f64,
489 ) -> SolveResult {
490 self.check(linear_solver)?;
491 self.setup_problem(params)?;
492 let final_time = M::T::from_f64(final_time).unwrap();
493 let soln = match linear_solver {
494 LinearSolverType::Default => {
495 method.solve::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time)
496 }
497 LinearSolverType::Lu => {
498 method.solve::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time)
499 }
500 LinearSolverType::Klu => {
501 method.solve::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time)
502 }
503 };
504 Ok(Box::new(soln?))
505 }
506
507 fn solve_dense(
508 &mut self,
509 method: OdeSolverType,
510 linear_solver: LinearSolverType,
511 params: &[f64],
512 t_eval: &[f64],
513 ) -> SolveResult {
514 self.check(linear_solver)?;
515 self.setup_problem(params)?;
516
517 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
518 let soln =
519 match linear_solver {
520 LinearSolverType::Default => method
521 .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
522 LinearSolverType::Lu => method
523 .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
524 LinearSolverType::Klu => method
525 .solve_dense::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
526 };
527 Ok(Box::new(soln?))
528 }
529
530 fn solve_hybrid(
531 &mut self,
532 method: OdeSolverType,
533 linear_solver: LinearSolverType,
534 params: &[f64],
535 final_time: f64,
536 ) -> SolveResult {
537 self.check(linear_solver)?;
538 self.setup_problem(params)?;
539 let final_time = M::T::from_f64(final_time).unwrap();
540 let soln = match linear_solver {
541 LinearSolverType::Default => method
542 .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, final_time),
543 LinearSolverType::Lu => method
544 .solve_hybrid::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, final_time),
545 LinearSolverType::Klu => method
546 .solve_hybrid::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, final_time),
547 };
548 Ok(Box::new(soln?))
549 }
550
551 fn solve_fwd_sens(
552 &mut self,
553 method: OdeSolverType,
554 linear_solver: LinearSolverType,
555 params: &[f64],
556 t_eval: &[f64],
557 ) -> SolveResult {
558 self.check(linear_solver)?;
559 self.setup_problem(params)?;
560
561 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
562 let soln = match linear_solver {
563 LinearSolverType::Default => {
564 method.solve_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval)
565 }
566 LinearSolverType::Lu => method
567 .solve_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
568 LinearSolverType::Klu => method
569 .solve_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(&mut self.problem, &t_eval),
570 };
571 Ok(Box::new(soln?))
572 }
573
574 fn solve_hybrid_dense(
575 &mut self,
576 method: OdeSolverType,
577 linear_solver: LinearSolverType,
578 params: &[f64],
579 t_eval: &[f64],
580 ) -> SolveResult {
581 self.check(linear_solver)?;
582 self.setup_problem(params)?;
583
584 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
585 let soln = match linear_solver {
586 LinearSolverType::Default => method
587 .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut self.problem, &t_eval),
588 LinearSolverType::Lu => method
589 .solve_hybrid_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut self.problem, &t_eval),
590 LinearSolverType::Klu => method
591 .solve_hybrid_dense::<M, CG, <M as KluValidator<M>>::LS>(
592 &mut self.problem,
593 &t_eval,
594 ),
595 };
596 Ok(Box::new(soln?))
597 }
598
599 fn solve_hybrid_fwd_sens(
600 &mut self,
601 method: OdeSolverType,
602 linear_solver: LinearSolverType,
603 params: &[f64],
604 t_eval: &[f64],
605 ) -> SolveResult {
606 self.check(linear_solver)?;
607 self.setup_problem(params)?;
608
609 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
610 let soln = match linear_solver {
611 LinearSolverType::Default => method
612 .solve_hybrid_fwd_sens::<M, CG, <M as DefaultSolver>::LS>(
613 &mut self.problem,
614 &t_eval,
615 ),
616 LinearSolverType::Lu => method
617 .solve_hybrid_fwd_sens::<M, CG, <M as LuValidator<M>>::LS>(
618 &mut self.problem,
619 &t_eval,
620 ),
621 LinearSolverType::Klu => method
622 .solve_hybrid_fwd_sens::<M, CG, <M as KluValidator<M>>::LS>(
623 &mut self.problem,
624 &t_eval,
625 ),
626 };
627 Ok(Box::new(soln?))
628 }
629
630 fn solve_sum_squares_adj(
631 &mut self,
632
633 method: OdeSolverType,
634 linear_solver: LinearSolverType,
635 backwards_method: OdeSolverType,
636 backwards_linear_solver: LinearSolverType,
637 params: &[f64],
638 data: HostArray,
639 t_eval: &[f64],
640 ) -> Result<(f64, HostArray), DiffsolJsError> {
641 self.check(linear_solver)?;
642 self.setup_problem(params)?;
643
644 let data = data.as_array()?;
645 let t_eval: Vec<M::T> = t_eval.iter().map(|&x| M::T::from_f64(x).unwrap()).collect();
646
647 let previous_integrate_out = self.problem.integrate_out;
648 self.problem.integrate_out = true;
649 let result = match linear_solver {
650 LinearSolverType::Default => method
651 .solve_sum_squares_adj::<M, CG, <M as DefaultSolver>::LS>(
652 &mut self.problem,
653 data,
654 &t_eval,
655 backwards_method,
656 backwards_linear_solver,
657 ),
658 LinearSolverType::Lu => method
659 .solve_sum_squares_adj::<M, CG, <M as LuValidator<M>>::LS>(
660 &mut self.problem,
661 data,
662 &t_eval,
663 backwards_method,
664 backwards_linear_solver,
665 ),
666 LinearSolverType::Klu => method
667 .solve_sum_squares_adj::<M, CG, <M as KluValidator<M>>::LS>(
668 &mut self.problem,
669 data,
670 &t_eval,
671 backwards_method,
672 backwards_linear_solver,
673 ),
674 };
675 self.problem.integrate_out = previous_integrate_out;
676 let (y, y_sens) = result?;
677
678 Ok((
679 y.to_f64().unwrap(),
680 (*y_sens.inner()).clone().to_host_array(),
681 ))
682 }
683}
684
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    use diffsol::{
        CodegenModuleCompile, CodegenModuleJit, Context, OdeBuilder, OdeEquations, Vector,
    };

    #[cfg(feature = "diffsl-llvm")]
    use crate::test_support::{
        hybrid_logistic_state_dr, logistic_integral, logistic_state_dr, matrix_host,
    };
    use crate::{
        host_array::FromHostArray,
        linear_solver_type::LinearSolverType,
        matrix_type::MatrixType,
        ode_solver_type::OdeSolverType,
        scalar_type::ScalarType,
        test_support::{
            assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
            logistic_state, LOGISTIC_X0,
        },
    };

    use super::{solve_factory_with_jit_backend, GenericSolve, Solve};

    /// Builds a dense f64 `GenericSolve` for the shared logistic test model.
    fn make_generic_solve<CG>() -> GenericSolve<diffsol::NalgebraMat<f64>, CG>
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
            .build_from_diffsl::<CG>(logistic_diffsl_code())
            .unwrap();
        GenericSolve { problem }
    }

    /// Checks that the JIT factory accepts every matrix/scalar combination.
    fn assert_factory_supports_all_matrix_and_scalar_types<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        for matrix_type in [
            MatrixType::NalgebraDense,
            MatrixType::FaerDense,
            MatrixType::FaerSparse,
        ] {
            for scalar_type in [ScalarType::F32, ScalarType::F64] {
                assert!(solve_factory_with_jit_backend::<CG>(
                    logistic_diffsl_code(),
                    matrix_type,
                    scalar_type,
                )
                .is_ok());
            }
        }
    }

    /// Exercises the metadata getters, tolerance accessors and the y0/rhs/jac helpers.
    fn assert_solve_metadata_and_helpers<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let mut solve = make_generic_solve::<CG>();
        assert_eq!(solve.matrix_type(), MatrixType::NalgebraDense);
        assert_eq!(solve.nstates(), 1);
        assert_eq!(solve.nparams(), 1);
        assert_eq!(solve.nout(), 1);
        assert!(!solve.has_stop());
        assert!(solve.check(LinearSolverType::Default).is_ok());
        assert!(solve.check(LinearSolverType::Lu).is_ok());
        assert!(solve.check(LinearSolverType::Klu).is_err());

        solve.set_atol(1e-5);
        solve.set_rtol(1e-4);
        assert_close(solve.atol(), 1e-5, 1e-12, "solve atol");
        assert_close(solve.rtol(), 1e-4, 1e-12, "solve rtol");

        let y0 = Vec::<f64>::from_host_array(solve.y0(&[2.0]).unwrap()).unwrap();
        assert_eq!(y0, vec![LOGISTIC_X0]);

        let rhs = Vec::<f64>::from_host_array(solve.rhs(&[2.0], 0.0, &[0.25]).unwrap()).unwrap();
        assert_eq!(rhs.len(), 1);
        assert_close(rhs[0], 0.375, 1e-12, "solve rhs");

        let jac_mul =
            Vec::<f64>::from_host_array(solve.rhs_jac_mul(&[2.0], 0.0, &[0.25], &[3.0]).unwrap())
                .unwrap();
        assert_eq!(jac_mul.len(), 1);
        assert_close(jac_mul[0], 3.0, 1e-12, "solve rhs jac mul");
    }

    /// Runs the plain, dense and hybrid solve paths and compares with analytic values.
    fn assert_solve_runtime_paths<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let mut solve = make_generic_solve::<CG>();
        let soln = solve
            .solve(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 1.0)
            .unwrap();
        let ts = Vec::<f64>::from_host_array(soln.get_ts()).unwrap();
        let ys = Vec::<Vec<f64>>::from_host_array(soln.get_ys()).unwrap();
        assert_close(*ts.last().unwrap(), 1.0, 5e-4, "solve final time");
        assert_close(
            ys[0][ts.len() - 1],
            logistic_state(LOGISTIC_X0, 2.0, 1.0),
            5e-4,
            "solve final value",
        );

        let mut solve = make_generic_solve::<CG>();
        let dense = solve
            .solve_dense(
                OdeSolverType::Tsit45,
                LinearSolverType::Lu,
                &[2.0],
                &[0.25, 0.5, 1.0],
            )
            .unwrap();
        let ts = Vec::<f64>::from_host_array(dense.get_ts()).unwrap();
        let ys = Vec::<Vec<f64>>::from_host_array(dense.get_ys()).unwrap();
        assert_eq!(ts, vec![0.25, 0.5, 1.0]);
        for (i, &t) in ts.iter().enumerate() {
            assert_close(
                ys[0][i],
                logistic_state(LOGISTIC_X0, 2.0, t),
                5e-4,
                &format!("solve_dense[{i}]"),
            );
        }

        let mut solve = make_generic_solve::<CG>();
        let err = match solve.solve(OdeSolverType::Bdf, LinearSolverType::Default, &[], 1.0) {
            Ok(_) => panic!("expected parameter count mismatch"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut hybrid_solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid = hybrid_solve
            .solve_hybrid(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], 2.0)
            .unwrap();
        let hybrid_ts = Vec::<f64>::from_host_array(hybrid.get_ts()).unwrap();
        let hybrid_ys = Vec::<Vec<f64>>::from_host_array(hybrid.get_ys()).unwrap();
        assert_close(
            *hybrid_ts.last().unwrap(),
            2.0,
            5e-4,
            "solve_hybrid final time",
        );
        assert_close(
            hybrid_ys[0][hybrid_ts.len() - 1],
            hybrid_logistic_state(2.0, 2.0),
            5e-4,
            "solve_hybrid final value",
        );

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
            .build_from_diffsl::<CG>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut hybrid_solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid_dense = hybrid_solve
            .solve_hybrid_dense(
                OdeSolverType::Tsit45,
                LinearSolverType::Lu,
                &[2.0],
                &[0.5, 1.0, 1.5, 2.0],
            )
            .unwrap();
        let hybrid_dense_ts = Vec::<f64>::from_host_array(hybrid_dense.get_ts()).unwrap();
        let hybrid_dense_ys = Vec::<Vec<f64>>::from_host_array(hybrid_dense.get_ys()).unwrap();
        assert_eq!(hybrid_dense_ts, vec![0.5, 1.0, 1.5, 2.0]);
        for (i, &t) in hybrid_dense_ts.iter().enumerate() {
            assert_close(
                hybrid_dense_ys[0][i],
                hybrid_logistic_state(2.0, t),
                5e-4,
                &format!("solve_hybrid_dense[{i}]"),
            );
        }
    }

    /// Checks that `setup_problem` rejects a wrong parameter count and stores a valid one.
    fn assert_setup_problem_validates_parameter_count<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let mut solve = make_generic_solve::<CG>();
        let err = solve.setup_problem(&[]).unwrap_err();
        assert!(err.to_string().contains("Expecting 1 params but got 0"));

        solve.setup_problem(&[2.0]).unwrap();
        let mut params = solve.problem.context().vector_zeros(1);
        solve.problem.eqn.get_params(&mut params);
        assert_eq!(params.get_index(0), 2.0);
    }

    /// Exercises the forward-sensitivity and adjoint paths (LLVM backend only).
    #[cfg(feature = "diffsl-llvm")]
    fn assert_solve_sensitivity_paths() {
        let t_eval = [0.25, 0.5, 1.0];

        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
        let sens = solve
            .solve_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
            .unwrap();
        let sens_values = sens.get_sens();
        assert_eq!(sens_values.len(), 1);
        let sens_matrix =
            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_matrix[0][i],
                logistic_state_dr(LOGISTIC_X0, 2.0, t),
                5e-4,
                &format!("solve_fwd_sens[{i}]"),
            );
        }

        let hybrid_problem = OdeBuilder::<diffsol::NalgebraMat<f64>>::new()
            .build_from_diffsl::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code())
            .unwrap();
        let mut solve = GenericSolve {
            problem: hybrid_problem,
        };
        let hybrid_sens = solve
            .solve_hybrid_fwd_sens(OdeSolverType::Bdf, LinearSolverType::Lu, &[2.0], &t_eval)
            .unwrap();
        let sens_values = hybrid_sens.get_sens();
        let sens_matrix =
            Vec::<Vec<f64>>::from_host_array(sens_values.into_iter().next().unwrap()).unwrap();
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                sens_matrix[0][i],
                hybrid_logistic_state_dr(2.0, t),
                5e-4,
                &format!("solve_hybrid_fwd_sens[{i}]"),
            );
        }

        let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
        let adjoint_data: Vec<f64> = adjoint_t_eval
            .iter()
            .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
            .collect();
        let mut solve = make_generic_solve::<diffsol::LlvmModule>();
        let (objective, gradient) = solve
            .solve_sum_squares_adj(
                OdeSolverType::Bdf,
                LinearSolverType::Lu,
                OdeSolverType::TrBdf2,
                LinearSolverType::Lu,
                &[2.0],
                matrix_host(1, adjoint_t_eval.len(), &adjoint_data),
                &adjoint_t_eval,
            )
            .unwrap();
        assert!(objective.is_finite());
        let gradient = Vec::<f64>::from_host_array(gradient).unwrap();
        assert_eq!(gradient.len(), 1);
        assert!(gradient[0].is_finite());
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_cranelift() {
        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::CraneliftJitModule>();
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn solve_trait_helpers_and_runtime_paths_for_cranelift() {
        assert_solve_metadata_and_helpers::<diffsol::CraneliftJitModule>();
        assert_solve_runtime_paths::<diffsol::CraneliftJitModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_factory_supports_all_jit_matrix_and_scalar_types_for_llvm() {
        assert_factory_supports_all_matrix_and_scalar_types::<diffsol::LlvmModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_trait_helpers_and_runtime_paths_for_llvm() {
        assert_solve_metadata_and_helpers::<diffsol::LlvmModule>();
        assert_solve_runtime_paths::<diffsol::LlvmModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn solve_trait_sensitivity_paths_for_llvm() {
        assert_solve_sensitivity_paths();
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn setup_problem_validates_parameter_count_for_cranelift() {
        assert_setup_problem_validates_parameter_count::<diffsol::CraneliftJitModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn setup_problem_validates_parameter_count_for_llvm() {
        assert_setup_problem_validates_parameter_count::<diffsol::LlvmModule>();
    }
}