1use diffsol::error::DiffsolError;
5use diffsol::{
6 AdjointOdeSolverMethod, Checkpointing, CodegenModule, DefaultSolver, DenseMatrix, MatrixCommon,
7 OdeEquations, OdeSolverState, OdeSolverStopReason, Op, SensitivitiesOdeSolverMethod, Solution,
8 VectorViewMut,
9};
10use diffsol::{
11 DefaultDenseMatrix, DiffSl, LinearSolver, Matrix, OdeSolverMethod, OdeSolverProblem, Vector,
12 VectorHost, VectorRef, matrix::MatrixRef,
13};
14use ndarray::ArrayView2;
15use num_traits::{FromPrimitive, Zero}; use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18
19use crate::scalar_type::Scalar;
20use crate::utils::is_sens_available;
21use crate::{
22 linear_solver_type::LinearSolverType,
23 valid_linear_solver::{KluValidator, LuValidator},
24};
25
26#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
33#[serde(rename_all = "snake_case")]
34pub enum OdeSolverType {
35 Bdf,
36 Esdirk34,
37 TrBdf2,
38 Tsit45,
39}
40
41impl OdeSolverType {
42 pub(crate) fn solve<M, CG, LS>(
43 &self,
44 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
45 final_time: M::T,
46 ) -> Result<Solution<M::V>, DiffsolError>
47 where
48 M: Matrix<T: Scalar>,
49 CG: CodegenModule,
50 M::V: VectorHost + DefaultDenseMatrix,
51 LS: LinearSolver<M>,
52 for<'b> &'b M::V: VectorRef<M::V>,
53 for<'b> &'b M: MatrixRef<M>,
54 {
55 match self {
56 OdeSolverType::Bdf => {
57 let solver = problem.bdf::<LS>()?;
58 let mut soln = Solution::new(final_time);
59 solver.solve_soln(&mut soln)?;
60 Ok(soln)
61 }
62 OdeSolverType::Esdirk34 => {
63 let solver = problem.esdirk34::<LS>()?;
64 let mut soln = Solution::new(final_time);
65 solver.solve_soln(&mut soln)?;
66 Ok(soln)
67 }
68 OdeSolverType::TrBdf2 => {
69 let solver = problem.tr_bdf2::<LS>()?;
70 let mut soln = Solution::new(final_time);
71 solver.solve_soln(&mut soln)?;
72 Ok(soln)
73 }
74 OdeSolverType::Tsit45 => {
75 let solver = problem.tsit45()?;
76 let mut soln = Solution::new(final_time);
77 solver.solve_soln(&mut soln)?;
78 Ok(soln)
79 }
80 }
81 }
82
83 pub(crate) fn solve_dense<M, CG, LS>(
84 &self,
85 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
86 t_eval: &[M::T],
87 ) -> Result<Solution<M::V>, DiffsolError>
88 where
89 M: Matrix<T: Scalar>,
90 CG: CodegenModule,
91 M::V: VectorHost + DefaultDenseMatrix,
92 LS: LinearSolver<M>,
93 for<'b> &'b M::V: VectorRef<M::V>,
94 for<'b> &'b M: MatrixRef<M>,
95 {
96 match self {
97 OdeSolverType::Bdf => {
98 let solver = problem.bdf::<LS>()?;
99 let mut soln = Solution::new_dense(t_eval.to_vec())?;
100 solver.solve_soln(&mut soln)?;
101 Ok(soln)
102 }
103 OdeSolverType::Esdirk34 => {
104 let solver = problem.esdirk34::<LS>()?;
105 let mut soln = Solution::new_dense(t_eval.to_vec())?;
106 solver.solve_soln(&mut soln)?;
107 Ok(soln)
108 }
109 OdeSolverType::TrBdf2 => {
110 let solver = problem.tr_bdf2::<LS>()?;
111 let mut soln = Solution::new_dense(t_eval.to_vec())?;
112 solver.solve_soln(&mut soln)?;
113 Ok(soln)
114 }
115 OdeSolverType::Tsit45 => {
116 let solver = problem.tsit45()?;
117 let mut soln = Solution::new_dense(t_eval.to_vec())?;
118 solver.solve_soln(&mut soln)?;
119 Ok(soln)
120 }
121 }
122 }
123
124 pub(crate) fn solve_hybrid<M, CG, LS>(
125 &self,
126 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
127 final_time: M::T,
128 ) -> Result<Solution<M::V>, DiffsolError>
129 where
130 M: Matrix<T: Scalar>,
131 CG: CodegenModule,
132 M::V: VectorHost + DefaultDenseMatrix,
133 LS: LinearSolver<M>,
134 for<'b> &'b M::V: VectorRef<M::V>,
135 for<'b> &'b M: MatrixRef<M>,
136 {
137 match self {
138 OdeSolverType::Bdf => {
139 let mut soln = Solution::new(final_time);
140 let mut solver = problem.bdf::<LS>()?;
141 while !soln.is_complete() {
142 solver = solver.solve_soln(&mut soln)?;
143 let root_idx = match soln.stop_reason {
144 Some(OdeSolverStopReason::RootFound(_, root_idx))
145 if !soln.is_complete() =>
146 {
147 root_idx
148 }
149 _ => continue,
150 };
151 let state = solver.into_state();
152 problem.eqn.set_model_index(root_idx);
153 let mut restarted_solver = problem.bdf_solver::<LS>(state)?;
154 restarted_solver.reset()?;
155 solver = restarted_solver;
156 }
157 Ok(soln)
158 }
159 OdeSolverType::Esdirk34 => {
160 let mut soln = Solution::new(final_time);
161 let mut solver = problem.esdirk34::<LS>()?;
162 while !soln.is_complete() {
163 solver = solver.solve_soln(&mut soln)?;
164 let root_idx = match soln.stop_reason {
165 Some(OdeSolverStopReason::RootFound(_, root_idx))
166 if !soln.is_complete() =>
167 {
168 root_idx
169 }
170 _ => continue,
171 };
172 let state = solver.into_state();
173 problem.eqn.set_model_index(root_idx);
174 let mut restarted_solver = problem.esdirk34_solver::<LS>(state)?;
175 restarted_solver.reset()?;
176 solver = restarted_solver;
177 }
178 Ok(soln)
179 }
180 OdeSolverType::TrBdf2 => {
181 let mut soln = Solution::new(final_time);
182 let mut solver = problem.tr_bdf2::<LS>()?;
183 while !soln.is_complete() {
184 solver = solver.solve_soln(&mut soln)?;
185 let root_idx = match soln.stop_reason {
186 Some(OdeSolverStopReason::RootFound(_, root_idx))
187 if !soln.is_complete() =>
188 {
189 root_idx
190 }
191 _ => continue,
192 };
193 let state = solver.into_state();
194 problem.eqn.set_model_index(root_idx);
195 let mut restarted_solver = problem.tr_bdf2_solver::<LS>(state)?;
196 restarted_solver.reset()?;
197 solver = restarted_solver;
198 }
199 Ok(soln)
200 }
201 OdeSolverType::Tsit45 => {
202 let mut soln = Solution::new(final_time);
203 let mut solver = problem.tsit45()?;
204 while !soln.is_complete() {
205 solver = solver.solve_soln(&mut soln)?;
206 let root_idx = match soln.stop_reason {
207 Some(OdeSolverStopReason::RootFound(_, root_idx))
208 if !soln.is_complete() =>
209 {
210 root_idx
211 }
212 _ => continue,
213 };
214 let state = solver.into_state();
215 problem.eqn.set_model_index(root_idx);
216 let mut restarted_solver = problem.tsit45_solver(state)?;
217 restarted_solver.reset()?;
218 solver = restarted_solver;
219 }
220 Ok(soln)
221 }
222 }
223 }
224
225 pub(crate) fn solve_hybrid_dense<M, CG, LS>(
226 &self,
227 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
228 t_eval: &[M::T],
229 ) -> Result<Solution<M::V>, DiffsolError>
230 where
231 M: Matrix<T: Scalar>,
232 CG: CodegenModule,
233 M::V: VectorHost + DefaultDenseMatrix,
234 LS: LinearSolver<M>,
235 for<'b> &'b M::V: VectorRef<M::V>,
236 for<'b> &'b M: MatrixRef<M>,
237 {
238 match self {
239 OdeSolverType::Bdf => {
240 let mut soln = Solution::new_dense(t_eval.to_vec())?;
241 let mut solver = problem.bdf::<LS>()?;
242 while !soln.is_complete() {
243 solver = solver.solve_soln(&mut soln)?;
244 let root_idx = match soln.stop_reason {
245 Some(OdeSolverStopReason::RootFound(_, root_idx))
246 if !soln.is_complete() =>
247 {
248 root_idx
249 }
250 _ => continue,
251 };
252 let state = solver.into_state();
253 problem.eqn.set_model_index(root_idx);
254 let mut restarted_solver = problem.bdf_solver::<LS>(state)?;
255 restarted_solver.reset()?;
256 solver = restarted_solver;
257 }
258 Ok(soln)
259 }
260 OdeSolverType::Esdirk34 => {
261 let mut soln = Solution::new_dense(t_eval.to_vec())?;
262 let mut solver = problem.esdirk34::<LS>()?;
263 while !soln.is_complete() {
264 solver = solver.solve_soln(&mut soln)?;
265 let root_idx = match soln.stop_reason {
266 Some(OdeSolverStopReason::RootFound(_, root_idx))
267 if !soln.is_complete() =>
268 {
269 root_idx
270 }
271 _ => continue,
272 };
273 let state = solver.into_state();
274 problem.eqn.set_model_index(root_idx);
275 let mut restarted_solver = problem.esdirk34_solver::<LS>(state)?;
276 restarted_solver.reset()?;
277 solver = restarted_solver;
278 }
279 Ok(soln)
280 }
281 OdeSolverType::TrBdf2 => {
282 let mut soln = Solution::new_dense(t_eval.to_vec())?;
283 let mut solver = problem.tr_bdf2::<LS>()?;
284 while !soln.is_complete() {
285 solver = solver.solve_soln(&mut soln)?;
286 let root_idx = match soln.stop_reason {
287 Some(OdeSolverStopReason::RootFound(_, root_idx))
288 if !soln.is_complete() =>
289 {
290 root_idx
291 }
292 _ => continue,
293 };
294 let state = solver.into_state();
295 problem.eqn.set_model_index(root_idx);
296 let mut restarted_solver = problem.tr_bdf2_solver::<LS>(state)?;
297 restarted_solver.reset()?;
298 solver = restarted_solver;
299 }
300 Ok(soln)
301 }
302 OdeSolverType::Tsit45 => {
303 let mut soln = Solution::new_dense(t_eval.to_vec())?;
304 let mut solver = problem.tsit45()?;
305 while !soln.is_complete() {
306 solver = solver.solve_soln(&mut soln)?;
307 let root_idx = match soln.stop_reason {
308 Some(OdeSolverStopReason::RootFound(_, root_idx))
309 if !soln.is_complete() =>
310 {
311 root_idx
312 }
313 _ => continue,
314 };
315 let state = solver.into_state();
316 problem.eqn.set_model_index(root_idx);
317 let mut restarted_solver = problem.tsit45_solver(state)?;
318 restarted_solver.reset()?;
319 solver = restarted_solver;
320 }
321 Ok(soln)
322 }
323 }
324 }
325
326 fn check_sens_available() -> Result<(), DiffsolError> {
327 if !is_sens_available() {
328 return Err(DiffsolError::Other(
329 "Sensitivity analysis is not supported on Windows, please use a linux or macOS system.".to_string(),
330 ));
331 }
332 Ok(())
333 }
334
335 #[allow(clippy::type_complexity)]
336 pub(crate) fn solve_fwd_sens<M, CG, LS>(
337 &self,
338 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
339 t_eval: &[M::T],
340 ) -> Result<Solution<M::V>, DiffsolError>
341 where
342 M: Matrix<T: Scalar> + DefaultSolver,
343 CG: CodegenModule,
344 M::V: VectorHost + DefaultDenseMatrix,
345 LS: LinearSolver<M>,
346 for<'b> &'b M::V: VectorRef<M::V>,
347 for<'b> &'b M: MatrixRef<M>,
348 {
349 Self::check_sens_available()?;
350 match self {
351 OdeSolverType::Bdf => {
352 let solver = problem.bdf_sens::<LS>()?;
353 let mut soln = Solution::new_dense(t_eval.to_vec())?;
354 solver.solve_soln_sensitivities(&mut soln)?;
355 Ok(soln)
356 }
357 OdeSolverType::Esdirk34 => {
358 let solver = problem.esdirk34_sens::<LS>()?;
359 let mut soln = Solution::new_dense(t_eval.to_vec())?;
360 solver.solve_soln_sensitivities(&mut soln)?;
361 Ok(soln)
362 }
363 OdeSolverType::TrBdf2 => {
364 let solver = problem.tr_bdf2_sens::<LS>()?;
365 let mut soln = Solution::new_dense(t_eval.to_vec())?;
366 solver.solve_soln_sensitivities(&mut soln)?;
367 Ok(soln)
368 }
369 OdeSolverType::Tsit45 => {
370 let solver = problem.tsit45_sens()?;
371 let mut soln = Solution::new_dense(t_eval.to_vec())?;
372 solver.solve_soln_sensitivities(&mut soln)?;
373 Ok(soln)
374 }
375 }
376 }
377
378 #[allow(clippy::type_complexity)]
379 pub(crate) fn solve_hybrid_fwd_sens<M, CG, LS>(
380 &self,
381 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
382 t_eval: &[M::T],
383 ) -> Result<Solution<M::V>, DiffsolError>
384 where
385 M: Matrix<T: Scalar> + DefaultSolver,
386 CG: CodegenModule,
387 M::V: VectorHost + DefaultDenseMatrix,
388 LS: LinearSolver<M>,
389 for<'b> &'b M::V: VectorRef<M::V>,
390 for<'b> &'b M: MatrixRef<M>,
391 {
392 Self::check_sens_available()?;
393 match self {
394 OdeSolverType::Bdf => {
395 let mut soln = Solution::new_dense(t_eval.to_vec())?;
396 let mut solver = problem.bdf_sens::<LS>()?;
397 while !soln.is_complete() {
398 solver = solver.solve_soln_sensitivities(&mut soln)?;
399 let root_idx = match soln.stop_reason {
400 Some(OdeSolverStopReason::RootFound(_, root_idx))
401 if !soln.is_complete() =>
402 {
403 root_idx
404 }
405 _ => continue,
406 };
407 let state = solver.into_state();
408 problem.eqn.set_model_index(root_idx);
409 let mut restarted_solver = problem.bdf_solver_sens::<LS>(state)?;
410 restarted_solver.reset_with_sens()?;
411 solver = restarted_solver;
412 }
413 Ok(soln)
414 }
415 OdeSolverType::Esdirk34 => {
416 let mut soln = Solution::new_dense(t_eval.to_vec())?;
417 let mut solver = problem.esdirk34_sens::<LS>()?;
418 while !soln.is_complete() {
419 solver = solver.solve_soln_sensitivities(&mut soln)?;
420 let root_idx = match soln.stop_reason {
421 Some(OdeSolverStopReason::RootFound(_, root_idx))
422 if !soln.is_complete() =>
423 {
424 root_idx
425 }
426 _ => continue,
427 };
428 let state = solver.into_state();
429 problem.eqn.set_model_index(root_idx);
430 let mut restarted_solver = problem.esdirk34_solver_sens::<LS>(state)?;
431 restarted_solver.reset_with_sens()?;
432 solver = restarted_solver;
433 }
434 Ok(soln)
435 }
436 OdeSolverType::TrBdf2 => {
437 let mut soln = Solution::new_dense(t_eval.to_vec())?;
438 let mut solver = problem.tr_bdf2_sens::<LS>()?;
439 while !soln.is_complete() {
440 solver = solver.solve_soln_sensitivities(&mut soln)?;
441 let root_idx = match soln.stop_reason {
442 Some(OdeSolverStopReason::RootFound(_, root_idx))
443 if !soln.is_complete() =>
444 {
445 root_idx
446 }
447 _ => continue,
448 };
449 let state = solver.into_state();
450 problem.eqn.set_model_index(root_idx);
451 let mut restarted_solver = problem.tr_bdf2_solver_sens::<LS>(state)?;
452 restarted_solver.reset_with_sens()?;
453 solver = restarted_solver;
454 }
455 Ok(soln)
456 }
457 OdeSolverType::Tsit45 => {
458 let mut soln = Solution::new_dense(t_eval.to_vec())?;
459 let mut solver = problem.tsit45_sens()?;
460 while !soln.is_complete() {
461 solver = solver.solve_soln_sensitivities(&mut soln)?;
462 let root_idx = match soln.stop_reason {
463 Some(OdeSolverStopReason::RootFound(_, root_idx))
464 if !soln.is_complete() =>
465 {
466 root_idx
467 }
468 _ => continue,
469 };
470 let state = solver.into_state();
471 problem.eqn.set_model_index(root_idx);
472 let mut restarted_solver = problem.tsit45_solver_sens(state)?;
473 restarted_solver.reset_with_sens()?;
474 solver = restarted_solver;
475 }
476 Ok(soln)
477 }
478 }
479 }
480
481 pub(crate) fn solve_sum_squares_adj<'a, M, CG, LS>(
482 &self,
483 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
484 data: ArrayView2<'a, M::T>,
485 t_eval: &[M::T],
486 backwards_method: OdeSolverType,
487 backwards_linear_solver: LinearSolverType,
488 ) -> Result<(M::T, M::V), DiffsolError>
489 where
490 M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M>,
491 CG: CodegenModule,
492 M::V: VectorHost + DefaultDenseMatrix,
493 LS: LinearSolver<M>,
494 for<'b> &'b M::V: VectorRef<M::V>,
495 for<'b> &'b M: MatrixRef<M>,
496 {
497 Self::check_sens_available()?;
498 match self {
499 OdeSolverType::Bdf => self._solve_sum_squares_adj(
500 problem.bdf::<LS>()?,
501 data,
502 t_eval,
503 backwards_method,
504 backwards_linear_solver,
505 ),
506 OdeSolverType::Esdirk34 => self._solve_sum_squares_adj(
507 problem.esdirk34::<LS>()?,
508 data,
509 t_eval,
510 backwards_method,
511 backwards_linear_solver,
512 ),
513 OdeSolverType::TrBdf2 => self._solve_sum_squares_adj(
514 problem.tr_bdf2::<LS>()?,
515 data,
516 t_eval,
517 backwards_method,
518 backwards_linear_solver,
519 ),
520 OdeSolverType::Tsit45 => self._solve_sum_squares_adj(
521 problem.tsit45()?,
522 data,
523 t_eval,
524 backwards_method,
525 backwards_linear_solver,
526 ),
527 }
528 }
529
530 pub(crate) fn _solve_sum_squares_adj<'data, 'solver, M, CG, S>(
531 &self,
532 mut solver: S,
533 data: ArrayView2<'data, M::T>,
534 t_eval: &[M::T],
535 backwards_method: OdeSolverType,
536 backwards_linear_solver: LinearSolverType,
537 ) -> Result<(M::T, M::V), DiffsolError>
538 where
539 M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M>,
540 CG: CodegenModule,
541 M::V: VectorHost + DefaultDenseMatrix,
542 S: OdeSolverMethod<'solver, DiffSl<M, CG>>,
543 for<'b> &'b M::V: VectorRef<M::V>,
544 for<'b> &'b M: MatrixRef<M>,
545 {
546 let (chk, ys) = solver.solve_dense_with_checkpointing(t_eval, None)?;
547 let eqn = solver.problem().eqn();
548 let ctx = eqn.context();
549 let mut g_m = <M::V as DefaultDenseMatrix>::M::zeros(eqn.nout(), t_eval.len(), ctx.clone());
550 let mut y = M::T::zero();
551 for j in 0..g_m.ncols() {
552 let ys_col = ys.column(j);
553 let mut tmp = M::V::from_slice(data.column(j).as_slice().unwrap(), ctx.clone());
555 tmp.axpy_v(
557 M::T::from_f64(2.0).unwrap(),
558 &ys_col,
559 M::T::from_f64(-2.0).unwrap(),
560 );
561 g_m.column_mut(j).copy_from(&tmp);
562
563 let norm = tmp.norm(2);
565 y += M::T::from_f64(1.0 / 4.0).unwrap() * norm * norm;
566 }
567 let mut y_sens = match backwards_linear_solver {
568 LinearSolverType::Default => backwards_method
569 .solve_adjoint_backwards::<M, CG, <M as DefaultSolver>::LS, S>(
570 solver.problem(),
571 chk,
572 &g_m,
573 t_eval,
574 Some(1),
575 )?,
576 LinearSolverType::Lu => backwards_method
577 .solve_adjoint_backwards::<M, CG, <M as LuValidator<M>>::LS, S>(
578 solver.problem(),
579 chk,
580 &g_m,
581 t_eval,
582 Some(1),
583 )?,
584 LinearSolverType::Klu => backwards_method
585 .solve_adjoint_backwards::<M, CG, <M as KluValidator<M>>::LS, S>(
586 solver.problem(),
587 chk,
588 &g_m,
589 t_eval,
590 Some(1),
591 )?,
592 };
593 Ok((y, y_sens.pop().unwrap()))
594 }
595
596 pub(crate) fn solve_adjoint_backwards<'solver, M, CG, LS, S>(
597 &self,
598 problem: &'solver OdeSolverProblem<DiffSl<M, CG>>,
599 checkpointing: Checkpointing<'solver, DiffSl<M, CG>, S>,
600 g_m: &<M::V as DefaultDenseMatrix>::M,
601 t_eval: &[M::T],
602 nout_override: Option<usize>,
603 ) -> Result<Vec<M::V>, DiffsolError>
604 where
605 M: Matrix<T: Scalar> + DefaultSolver,
606 CG: CodegenModule,
607 M::V: VectorHost + DefaultDenseMatrix,
608 S: OdeSolverMethod<'solver, DiffSl<M, CG>>,
609 LS: LinearSolver<M>,
610 for<'b> &'b M::V: VectorRef<M::V>,
611 for<'b> &'b M: MatrixRef<M>,
612 {
613 match self {
614 OdeSolverType::Bdf => problem
615 .bdf_solver_adjoint::<LS, _>(checkpointing, nout_override)?
616 .solve_adjoint_backwards_pass(t_eval, &[g_m])
617 .map(|res| res.into_common().sg),
618 OdeSolverType::Esdirk34 => problem
619 .esdirk34_solver_adjoint::<LS, _>(checkpointing, nout_override)?
620 .solve_adjoint_backwards_pass(t_eval, &[g_m])
621 .map(|res| res.into_common().sg),
622 OdeSolverType::TrBdf2 => problem
623 .tr_bdf2_solver_adjoint::<LS, _>(checkpointing, nout_override)?
624 .solve_adjoint_backwards_pass(t_eval, &[g_m])
625 .map(|res| res.into_common().sg),
626 OdeSolverType::Tsit45 => Err(DiffsolError::Other(
627 "Tsit45 solver does not support adjoint sensitivity analysis.".to_string(),
628 )),
629 }
630 }
631}
632
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    //! Runtime-dispatch tests: each `OdeSolverType` variant is driven through the
    //! plain, dense, hybrid, forward-sensitivity and adjoint entry points using the
    //! logistic models provided by `crate::test_support`.

    use diffsol::Vector;
    use diffsol::{
        CodegenModuleCompile, CodegenModuleJit, DefaultSolver, DenseMatrix, OdeBuilder,
        OdeSolverProblem,
    };
    use ndarray::Array2;

    use crate::linear_solver_type::LinearSolverType;
    use crate::test_support::{
        LOGISTIC_X0, assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state,
        hybrid_logistic_state_dr, logistic_diffsl_code, logistic_integral, logistic_state,
        logistic_state_dr,
    };
    use crate::valid_linear_solver::LuValidator;

    use super::OdeSolverType;

    type M = diffsol::NalgebraMat<f64>;

    #[cfg(feature = "diffsl-llvm")]
    type Llvm = diffsol::LlvmModule;

    #[cfg(feature = "diffsl-cranelift")]
    type Cranelift = diffsol::CraneliftJitModule;

    /// Every solver variant, in declaration order.
    const ALL_METHODS: [OdeSolverType; 4] = [
        OdeSolverType::Bdf,
        OdeSolverType::Esdirk34,
        OdeSolverType::TrBdf2,
        OdeSolverType::Tsit45,
    ];

    /// Variants whose hybrid restart paths are additionally checked one by one.
    const NON_BDF_METHODS: [OdeSolverType; 3] = [
        OdeSolverType::Esdirk34,
        OdeSolverType::TrBdf2,
        OdeSolverType::Tsit45,
    ];

    /// Output times for the logistic-model and forward-sensitivity checks.
    const T_EVAL: [f64; 3] = [0.25, 0.5, 1.0];

    /// Output times for the hybrid-model solves (up to t = 2.0).
    const HYBRID_T_EVAL: [f64; 5] = [0.5, 1.0, 1.25, 1.5, 2.0];

    /// Observation times for the sum-of-squares adjoint checks.
    #[cfg(feature = "diffsl-llvm")]
    const ADJOINT_T_EVAL: [f64; 4] = [0.0, 0.25, 0.5, 1.0];

    /// Tolerance passed to every `assert_close` call.
    const TOL: f64 = 5e-4;

    /// Builds a problem from DiffSL `code` with `p = [2.0]` and 1e-6 tolerances.
    fn build_problem<CG>(code: &str) -> OdeSolverProblem<diffsol::DiffSl<M, CG>>
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let builder = OdeBuilder::<M>::new().p([2.0]).rtol(1e-6).atol([1e-6]);
        builder.build_from_diffsl::<CG>(code).unwrap()
    }

    /// Checks the last recorded time against `t_final` and, if `expected_value` is
    /// given, the first state component at that time. Labels are
    /// `"{label} final time"` / `"{label} final value"`.
    fn assert_final_point(
        soln: &diffsol::Solution<diffsol::NalgebraVec<f64>>,
        t_final: f64,
        expected_value: Option<f64>,
        label: &str,
    ) {
        let last_t = *soln.ts.last().unwrap();
        assert_close(last_t, t_final, TOL, &format!("{label} final time"));
        if let Some(expected) = expected_value {
            let last_y = soln.ys.get_index(0, soln.ts.len() - 1);
            assert_close(last_y, expected, TOL, &format!("{label} final value"));
        }
    }

    /// Checks that `soln` was recorded exactly at `t_eval` and that its first state
    /// component matches `expected` at every point.
    fn assert_dense_solution_matches_expected(
        soln: &diffsol::Solution<diffsol::NalgebraVec<f64>>,
        t_eval: &[f64],
        expected: impl Fn(f64) -> f64,
    ) {
        assert_eq!(soln.ts, t_eval);
        for (col, &t) in t_eval.iter().enumerate() {
            let actual = soln.ys.get_index(0, col);
            assert_close(actual, expected(t), TOL, &format!("solution[{col}]"));
        }
    }

    /// Checks the first sensitivity's first component against `expected` at every
    /// point of `t_eval`; labels are `"{label}[{i}]"`.
    #[cfg(feature = "diffsl-llvm")]
    fn assert_sensitivities_match_expected(
        soln: &diffsol::Solution<diffsol::NalgebraVec<f64>>,
        t_eval: &[f64],
        expected: impl Fn(f64) -> f64,
        label: &str,
    ) {
        for (col, &t) in t_eval.iter().enumerate() {
            let actual = soln.y_sens[0].get_index(0, col);
            assert_close(actual, expected(t), TOL, &format!("{label}[{col}]"));
        }
    }

    /// One-row observation matrix of `logistic_integral` values at `t_eval`.
    #[cfg(feature = "diffsl-llvm")]
    fn logistic_integral_data(t_eval: &[f64]) -> Array2<f64> {
        let observations = t_eval
            .iter()
            .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
            .collect();
        Array2::from_shape_vec((1, t_eval.len()), observations).unwrap()
    }

    /// Checks that an adjoint result is finite and has a single gradient entry.
    #[cfg(feature = "diffsl-llvm")]
    fn assert_finite_objective_and_gradient(
        objective: f64,
        gradient: &diffsol::NalgebraVec<f64>,
    ) {
        assert!(objective.is_finite());
        assert_eq!(gradient.len(), 1);
        assert!(gradient.get_index(0).is_finite());
    }

    fn test_all_solver_variants<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        for method in ALL_METHODS {
            let mut problem = build_problem::<CG>(logistic_diffsl_code());
            let final_soln = method
                .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 1.0)
                .unwrap();
            assert_final_point(
                &final_soln,
                1.0,
                Some(logistic_state(LOGISTIC_X0, 2.0, 1.0)),
                "solve",
            );

            let mut problem = build_problem::<CG>(logistic_diffsl_code());
            let dense_soln = method
                .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &T_EVAL)
                .unwrap();
            assert_dense_solution_matches_expected(&dense_soln, &T_EVAL, |t| {
                logistic_state(LOGISTIC_X0, 2.0, t)
            });
        }
    }

    fn test_all_hybrid_solver_variants<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        for method in ALL_METHODS {
            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
            let final_soln = method
                .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
                .unwrap();
            assert_final_point(
                &final_soln,
                2.0,
                Some(hybrid_logistic_state(2.0, 2.0)),
                "hybrid",
            );

            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
            let dense_soln = method
                .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(
                    &mut problem,
                    &HYBRID_T_EVAL,
                )
                .unwrap();
            assert_dense_solution_matches_expected(&dense_soln, &HYBRID_T_EVAL, |t| {
                hybrid_logistic_state(2.0, t)
            });
        }
    }

    fn test_all_solver_variants_with_lu<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        for method in ALL_METHODS {
            let mut problem = build_problem::<CG>(logistic_diffsl_code());
            let final_soln = method
                .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 1.0)
                .unwrap();
            assert_final_point(&final_soln, 1.0, None, "lu solve");

            let mut problem = build_problem::<CG>(logistic_diffsl_code());
            let dense_soln = method
                .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &T_EVAL)
                .unwrap();
            assert_dense_solution_matches_expected(&dense_soln, &T_EVAL, |t| {
                logistic_state(LOGISTIC_X0, 2.0, t)
            });
        }
    }

    fn test_all_hybrid_solver_variants_with_lu<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        for method in ALL_METHODS {
            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
            let final_soln = method
                .solve_hybrid::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 2.0)
                .unwrap();
            assert_final_point(&final_soln, 2.0, None, "lu hybrid");

            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
            let dense_soln = method
                .solve_hybrid_dense::<M, CG, <M as LuValidator<M>>::LS>(
                    &mut problem,
                    &HYBRID_T_EVAL,
                )
                .unwrap();
            assert_dense_solution_matches_expected(&dense_soln, &HYBRID_T_EVAL, |t| {
                hybrid_logistic_state(2.0, t)
            });
        }
    }

    fn assert_direct_hybrid_restart_path_for_method<CG>(method: OdeSolverType)
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
        let final_soln = method
            .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
            .unwrap();
        assert_final_point(
            &final_soln,
            2.0,
            Some(hybrid_logistic_state(2.0, 2.0)),
            "direct hybrid restart",
        );

        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
        let dense_soln = method
            .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &HYBRID_T_EVAL)
            .unwrap();
        assert_dense_solution_matches_expected(&dense_soln, &HYBRID_T_EVAL, |t| {
            hybrid_logistic_state(2.0, t)
        });
    }

    #[cfg(feature = "diffsl-llvm")]
    fn test_all_sensitivity_solver_variants() {
        for method in ALL_METHODS {
            let mut problem = build_problem::<Llvm>(logistic_diffsl_code());
            let soln = method
                .solve_fwd_sens::<M, Llvm, <M as DefaultSolver>::LS>(&mut problem, &T_EVAL)
                .unwrap();
            assert_sensitivities_match_expected(
                &soln,
                &T_EVAL,
                |t| logistic_state_dr(LOGISTIC_X0, 2.0, t),
                "fwd_sens",
            );

            let mut problem = build_problem::<Llvm>(hybrid_logistic_diffsl_code());
            let soln = method
                .solve_hybrid_fwd_sens::<M, Llvm, <M as DefaultSolver>::LS>(
                    &mut problem,
                    &T_EVAL,
                )
                .unwrap();
            assert_sensitivities_match_expected(
                &soln,
                &T_EVAL,
                |t| hybrid_logistic_state_dr(2.0, t),
                "hybrid_fwd_sens",
            );
        }
    }

    #[cfg(feature = "diffsl-llvm")]
    fn test_lu_sensitivity_and_adjoint_solver_variants() {
        for method in ALL_METHODS {
            let mut problem = build_problem::<Llvm>(logistic_diffsl_code());
            let soln = method
                .solve_fwd_sens::<M, Llvm, <M as LuValidator<M>>::LS>(&mut problem, &T_EVAL)
                .unwrap();
            assert_sensitivities_match_expected(
                &soln,
                &T_EVAL,
                |t| logistic_state_dr(LOGISTIC_X0, 2.0, t),
                "lu fwd_sens",
            );
        }

        let data = logistic_integral_data(&ADJOINT_T_EVAL);
        let mut problem = build_problem::<Llvm>(logistic_diffsl_code());
        let (objective, gradient) = OdeSolverType::Bdf
            .solve_sum_squares_adj::<M, Llvm, <M as LuValidator<M>>::LS>(
                &mut problem,
                data.view(),
                &ADJOINT_T_EVAL,
                OdeSolverType::TrBdf2,
                LinearSolverType::Lu,
            )
            .unwrap();
        assert_finite_objective_and_gradient(objective, &gradient);
    }

    #[cfg(feature = "diffsl-llvm")]
    fn test_direct_hybrid_sensitivity_restart_paths() {
        for method in NON_BDF_METHODS {
            let mut problem = build_problem::<Llvm>(hybrid_logistic_diffsl_code());
            let soln = method
                .solve_hybrid_fwd_sens::<M, Llvm, <M as DefaultSolver>::LS>(
                    &mut problem,
                    &T_EVAL,
                )
                .unwrap();
            assert_sensitivities_match_expected(
                &soln,
                &T_EVAL,
                |t| hybrid_logistic_state_dr(2.0, t),
                "direct hybrid fwd sens",
            );
        }
    }

    #[cfg(feature = "diffsl-llvm")]
    fn test_adjoint_backwards_methods_and_klu_branch() {
        let data = logistic_integral_data(&ADJOINT_T_EVAL);

        for backwards_method in [OdeSolverType::Esdirk34, OdeSolverType::TrBdf2] {
            let mut problem = build_problem::<Llvm>(logistic_diffsl_code());
            let (objective, gradient) = OdeSolverType::Bdf
                .solve_sum_squares_adj::<M, Llvm, <M as DefaultSolver>::LS>(
                    &mut problem,
                    data.view(),
                    &ADJOINT_T_EVAL,
                    backwards_method,
                    LinearSolverType::Klu,
                )
                .unwrap();
            assert_finite_objective_and_gradient(objective, &gradient);
        }

        // Tsit45 has no adjoint solver, so the backwards pass must be rejected.
        let mut problem = build_problem::<Llvm>(logistic_diffsl_code());
        let err = OdeSolverType::Bdf
            .solve_sum_squares_adj::<M, Llvm, <M as DefaultSolver>::LS>(
                &mut problem,
                data.view(),
                &ADJOINT_T_EVAL,
                OdeSolverType::Tsit45,
                LinearSolverType::Default,
            )
            .unwrap_err();
        let message = err.to_string();
        assert!(message.contains("Tsit45 solver does not support adjoint sensitivity analysis"));
    }

    #[cfg(feature = "diffsl-llvm")]
    fn test_all_adjoint_solver_variants() {
        let data = logistic_integral_data(&ADJOINT_T_EVAL);

        for method in ALL_METHODS {
            let mut problem = build_problem::<Llvm>(logistic_diffsl_code());
            let (objective, gradient) = method
                .solve_sum_squares_adj::<M, Llvm, <M as DefaultSolver>::LS>(
                    &mut problem,
                    data.view(),
                    &ADJOINT_T_EVAL,
                    OdeSolverType::Bdf,
                    LinearSolverType::Default,
                )
                .unwrap();
            assert_finite_objective_and_gradient(objective, &gradient);
        }
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn runtime_dispatch_solves_all_variants_for_cranelift() {
        test_all_solver_variants::<Cranelift>();
        test_all_solver_variants_with_lu::<Cranelift>();
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn runtime_dispatch_solves_all_hybrid_variants_for_cranelift() {
        test_all_hybrid_solver_variants::<Cranelift>();
        test_all_hybrid_solver_variants_with_lu::<Cranelift>();
        for method in NON_BDF_METHODS {
            assert_direct_hybrid_restart_path_for_method::<Cranelift>(method);
        }
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn runtime_dispatch_solves_all_variants_for_llvm() {
        test_all_solver_variants::<Llvm>();
        test_all_solver_variants_with_lu::<Llvm>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn runtime_dispatch_solves_all_hybrid_variants_for_llvm() {
        test_all_hybrid_solver_variants::<Llvm>();
        test_all_hybrid_solver_variants_with_lu::<Llvm>();
        for method in NON_BDF_METHODS {
            assert_direct_hybrid_restart_path_for_method::<Llvm>(method);
        }
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn runtime_dispatch_solves_all_forward_sensitivity_variants_for_llvm() {
        test_all_sensitivity_solver_variants();
        test_lu_sensitivity_and_adjoint_solver_variants();
        test_direct_hybrid_sensitivity_restart_paths();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn runtime_dispatch_solves_all_adjoint_variants_for_llvm() {
        test_all_adjoint_solver_variants();
        test_adjoint_backwards_methods_and_klu_branch();
    }
}
1086}