1use diffsol::error::{DiffsolError, OdeSolverError};
5use diffsol::ode_equations::OdeEquationsImplicitSensWithReset;
6use diffsol::{
7 matrix::MatrixRef, DefaultDenseMatrix, DiffSl, LinearSolver, Matrix, OdeSolverMethod,
8 OdeSolverProblem, OdeSolverState, Vector, VectorHost, VectorRef,
9};
10use diffsol::{
11 ode_solver_error, AdjointOdeSolverMethod, Checkpointing, CodegenModule, DefaultSolver,
12 DenseMatrix, MatrixCommon, OdeEquations, OdeSolverStopReason, Op, SensitivitiesOdeSolverMethod,
13 Solution, VectorViewMut,
14};
15use ndarray::ArrayView2;
16use num_traits::{FromPrimitive, Zero}; use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19
20use crate::scalar_type::Scalar;
21use crate::utils::is_sens_available;
22use crate::{
23 linear_solver_type::LinearSolverType,
24 valid_linear_solver::{KluValidator, LuValidator},
25};
26
/// Integration method selected at runtime and dispatched to the matching
/// diffsol solver constructor (see the `impl OdeSolverType` methods).
///
/// Serialized in snake_case via serde: `"bdf"`, `"esdirk34"`, `"tr_bdf2"`,
/// `"tsit45"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum OdeSolverType {
    /// Dispatches to diffsol's `bdf` solver (takes a linear solver type).
    Bdf,
    /// Dispatches to diffsol's `esdirk34` solver (takes a linear solver type).
    Esdirk34,
    /// Dispatches to diffsol's `tr_bdf2` solver (takes a linear solver type).
    TrBdf2,
    /// Dispatches to diffsol's `tsit45` solver, which is constructed without a
    /// linear solver; it cannot be used for the adjoint backward pass.
    Tsit45,
}
41
42fn apply_state_reset<Eqn, S>(
43 problem: &OdeSolverProblem<Eqn>,
44 state: &mut S,
45) -> Result<(), DiffsolError>
46where
47 Eqn: OdeEquations,
48 S: OdeSolverState<Eqn::V>,
49{
50 let eqn = &problem.eqn;
51 if let Some(reset_fn) = eqn.reset() {
52 state.state_mut_op(eqn, &reset_fn)?;
53 }
54 Ok(())
55}
56
57fn apply_state_reset_with_sens<Eqn, S>(
58 problem: &OdeSolverProblem<Eqn>,
59 state: &mut S,
60 root_idx: usize,
61) -> Result<(), DiffsolError>
62where
63 Eqn: OdeEquationsImplicitSensWithReset,
64 S: OdeSolverState<Eqn::V>,
65{
66 let eqn = &problem.eqn;
67 match (eqn.reset(), eqn.root()) {
68 (None, _) => Ok(()),
69 (Some(_), None) => Err(ode_solver_error!(ResetRequiresRootOperator)),
70 (Some(reset_fn), Some(root_fn)) => {
71 state.state_mut_op_with_sens_and_reset(eqn, &reset_fn, &root_fn, root_idx)?;
72 Ok(())
73 }
74 }
75}
76
77impl OdeSolverType {
78 pub(crate) fn solve<M, CG, LS>(
79 &self,
80 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
81 final_time: M::T,
82 ) -> Result<Solution<M::V>, DiffsolError>
83 where
84 M: Matrix<T: Scalar>,
85 CG: CodegenModule,
86 M::V: VectorHost + DefaultDenseMatrix,
87 LS: LinearSolver<M>,
88 for<'b> &'b M::V: VectorRef<M::V>,
89 for<'b> &'b M: MatrixRef<M>,
90 {
91 match self {
92 OdeSolverType::Bdf => {
93 let solver = problem.bdf::<LS>()?;
94 let mut soln = Solution::new(final_time);
95 solver.solve_soln(&mut soln)?;
96 Ok(soln)
97 }
98 OdeSolverType::Esdirk34 => {
99 let solver = problem.esdirk34::<LS>()?;
100 let mut soln = Solution::new(final_time);
101 solver.solve_soln(&mut soln)?;
102 Ok(soln)
103 }
104 OdeSolverType::TrBdf2 => {
105 let solver = problem.tr_bdf2::<LS>()?;
106 let mut soln = Solution::new(final_time);
107 solver.solve_soln(&mut soln)?;
108 Ok(soln)
109 }
110 OdeSolverType::Tsit45 => {
111 let solver = problem.tsit45()?;
112 let mut soln = Solution::new(final_time);
113 solver.solve_soln(&mut soln)?;
114 Ok(soln)
115 }
116 }
117 }
118
119 pub(crate) fn solve_dense<M, CG, LS>(
120 &self,
121 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
122 t_eval: &[M::T],
123 ) -> Result<Solution<M::V>, DiffsolError>
124 where
125 M: Matrix<T: Scalar>,
126 CG: CodegenModule,
127 M::V: VectorHost + DefaultDenseMatrix,
128 LS: LinearSolver<M>,
129 for<'b> &'b M::V: VectorRef<M::V>,
130 for<'b> &'b M: MatrixRef<M>,
131 {
132 match self {
133 OdeSolverType::Bdf => {
134 let solver = problem.bdf::<LS>()?;
135 let mut soln = Solution::new_dense(t_eval.to_vec())?;
136 solver.solve_soln(&mut soln)?;
137 Ok(soln)
138 }
139 OdeSolverType::Esdirk34 => {
140 let solver = problem.esdirk34::<LS>()?;
141 let mut soln = Solution::new_dense(t_eval.to_vec())?;
142 solver.solve_soln(&mut soln)?;
143 Ok(soln)
144 }
145 OdeSolverType::TrBdf2 => {
146 let solver = problem.tr_bdf2::<LS>()?;
147 let mut soln = Solution::new_dense(t_eval.to_vec())?;
148 solver.solve_soln(&mut soln)?;
149 Ok(soln)
150 }
151 OdeSolverType::Tsit45 => {
152 let solver = problem.tsit45()?;
153 let mut soln = Solution::new_dense(t_eval.to_vec())?;
154 solver.solve_soln(&mut soln)?;
155 Ok(soln)
156 }
157 }
158 }
159
160 pub(crate) fn solve_hybrid<M, CG, LS>(
161 &self,
162 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
163 final_time: M::T,
164 ) -> Result<Solution<M::V>, DiffsolError>
165 where
166 M: Matrix<T: Scalar>,
167 CG: CodegenModule,
168 M::V: VectorHost + DefaultDenseMatrix,
169 LS: LinearSolver<M>,
170 for<'b> &'b M::V: VectorRef<M::V>,
171 for<'b> &'b M: MatrixRef<M>,
172 {
173 match self {
174 OdeSolverType::Bdf => {
175 let mut soln = Solution::new(final_time);
176 let mut solver = problem.bdf::<LS>()?;
177 while !soln.is_complete() {
178 solver = solver.solve_soln(&mut soln)?;
179 let root_idx = match soln.stop_reason {
180 Some(OdeSolverStopReason::RootFound(_, root_idx))
181 if !soln.is_complete() =>
182 {
183 root_idx
184 }
185 _ => continue,
186 };
187 let mut state = solver.into_state();
188 problem.eqn.set_model_index(root_idx);
189 apply_state_reset(problem, &mut state)?;
190 solver = problem.bdf_solver::<LS>(state)?;
191 }
192 Ok(soln)
193 }
194 OdeSolverType::Esdirk34 => {
195 let mut soln = Solution::new(final_time);
196 let mut solver = problem.esdirk34::<LS>()?;
197 while !soln.is_complete() {
198 solver = solver.solve_soln(&mut soln)?;
199 let root_idx = match soln.stop_reason {
200 Some(OdeSolverStopReason::RootFound(_, root_idx))
201 if !soln.is_complete() =>
202 {
203 root_idx
204 }
205 _ => continue,
206 };
207 let mut state = solver.into_state();
208 problem.eqn.set_model_index(root_idx);
209 apply_state_reset(problem, &mut state)?;
210 solver = problem.esdirk34_solver::<LS>(state)?;
211 }
212 Ok(soln)
213 }
214 OdeSolverType::TrBdf2 => {
215 let mut soln = Solution::new(final_time);
216 let mut solver = problem.tr_bdf2::<LS>()?;
217 while !soln.is_complete() {
218 solver = solver.solve_soln(&mut soln)?;
219 let root_idx = match soln.stop_reason {
220 Some(OdeSolverStopReason::RootFound(_, root_idx))
221 if !soln.is_complete() =>
222 {
223 root_idx
224 }
225 _ => continue,
226 };
227 let mut state = solver.into_state();
228 problem.eqn.set_model_index(root_idx);
229 apply_state_reset(problem, &mut state)?;
230 solver = problem.tr_bdf2_solver::<LS>(state)?;
231 }
232 Ok(soln)
233 }
234 OdeSolverType::Tsit45 => {
235 let mut soln = Solution::new(final_time);
236 let mut solver = problem.tsit45()?;
237 while !soln.is_complete() {
238 solver = solver.solve_soln(&mut soln)?;
239 let root_idx = match soln.stop_reason {
240 Some(OdeSolverStopReason::RootFound(_, root_idx))
241 if !soln.is_complete() =>
242 {
243 root_idx
244 }
245 _ => continue,
246 };
247 let mut state = solver.into_state();
248 problem.eqn.set_model_index(root_idx);
249 apply_state_reset(problem, &mut state)?;
250 solver = problem.tsit45_solver(state)?;
251 }
252 Ok(soln)
253 }
254 }
255 }
256
257 pub(crate) fn solve_hybrid_dense<M, CG, LS>(
258 &self,
259 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
260 t_eval: &[M::T],
261 ) -> Result<Solution<M::V>, DiffsolError>
262 where
263 M: Matrix<T: Scalar>,
264 CG: CodegenModule,
265 M::V: VectorHost + DefaultDenseMatrix,
266 LS: LinearSolver<M>,
267 for<'b> &'b M::V: VectorRef<M::V>,
268 for<'b> &'b M: MatrixRef<M>,
269 {
270 match self {
271 OdeSolverType::Bdf => {
272 let mut soln = Solution::new_dense(t_eval.to_vec())?;
273 let mut solver = problem.bdf::<LS>()?;
274 while !soln.is_complete() {
275 solver = solver.solve_soln(&mut soln)?;
276 let root_idx = match soln.stop_reason {
277 Some(OdeSolverStopReason::RootFound(_, root_idx))
278 if !soln.is_complete() =>
279 {
280 root_idx
281 }
282 _ => continue,
283 };
284 let mut state = solver.into_state();
285 problem.eqn.set_model_index(root_idx);
286 apply_state_reset(problem, &mut state)?;
287 solver = problem.bdf_solver::<LS>(state)?;
288 }
289 Ok(soln)
290 }
291 OdeSolverType::Esdirk34 => {
292 let mut soln = Solution::new_dense(t_eval.to_vec())?;
293 let mut solver = problem.esdirk34::<LS>()?;
294 while !soln.is_complete() {
295 solver = solver.solve_soln(&mut soln)?;
296 let root_idx = match soln.stop_reason {
297 Some(OdeSolverStopReason::RootFound(_, root_idx))
298 if !soln.is_complete() =>
299 {
300 root_idx
301 }
302 _ => continue,
303 };
304 let mut state = solver.into_state();
305 problem.eqn.set_model_index(root_idx);
306 apply_state_reset(problem, &mut state)?;
307 solver = problem.esdirk34_solver::<LS>(state)?;
308 }
309 Ok(soln)
310 }
311 OdeSolverType::TrBdf2 => {
312 let mut soln = Solution::new_dense(t_eval.to_vec())?;
313 let mut solver = problem.tr_bdf2::<LS>()?;
314 while !soln.is_complete() {
315 solver = solver.solve_soln(&mut soln)?;
316 let root_idx = match soln.stop_reason {
317 Some(OdeSolverStopReason::RootFound(_, root_idx))
318 if !soln.is_complete() =>
319 {
320 root_idx
321 }
322 _ => continue,
323 };
324 let mut state = solver.into_state();
325 problem.eqn.set_model_index(root_idx);
326 apply_state_reset(problem, &mut state)?;
327 solver = problem.tr_bdf2_solver::<LS>(state)?;
328 }
329 Ok(soln)
330 }
331 OdeSolverType::Tsit45 => {
332 let mut soln = Solution::new_dense(t_eval.to_vec())?;
333 let mut solver = problem.tsit45()?;
334 while !soln.is_complete() {
335 solver = solver.solve_soln(&mut soln)?;
336 let root_idx = match soln.stop_reason {
337 Some(OdeSolverStopReason::RootFound(_, root_idx))
338 if !soln.is_complete() =>
339 {
340 root_idx
341 }
342 _ => continue,
343 };
344 let mut state = solver.into_state();
345 problem.eqn.set_model_index(root_idx);
346 apply_state_reset(problem, &mut state)?;
347 solver = problem.tsit45_solver(state)?;
348 }
349 Ok(soln)
350 }
351 }
352 }
353
354 fn check_sens_available() -> Result<(), DiffsolError> {
355 if !is_sens_available() {
356 return Err(DiffsolError::Other(
357 "Sensitivity analysis is not supported on Windows, please use a linux or macOS system.".to_string(),
358 ));
359 }
360 Ok(())
361 }
362
363 #[allow(clippy::type_complexity)]
364 pub(crate) fn solve_fwd_sens<M, CG, LS>(
365 &self,
366 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
367 t_eval: &[M::T],
368 ) -> Result<Solution<M::V>, DiffsolError>
369 where
370 M: Matrix<T: Scalar> + DefaultSolver,
371 CG: CodegenModule,
372 M::V: VectorHost + DefaultDenseMatrix,
373 LS: LinearSolver<M>,
374 for<'b> &'b M::V: VectorRef<M::V>,
375 for<'b> &'b M: MatrixRef<M>,
376 {
377 Self::check_sens_available()?;
378 match self {
379 OdeSolverType::Bdf => {
380 let solver = problem.bdf_sens::<LS>()?;
381 let mut soln = Solution::new_dense(t_eval.to_vec())?;
382 solver.solve_soln_sensitivities(&mut soln)?;
383 Ok(soln)
384 }
385 OdeSolverType::Esdirk34 => {
386 let solver = problem.esdirk34_sens::<LS>()?;
387 let mut soln = Solution::new_dense(t_eval.to_vec())?;
388 solver.solve_soln_sensitivities(&mut soln)?;
389 Ok(soln)
390 }
391 OdeSolverType::TrBdf2 => {
392 let solver = problem.tr_bdf2_sens::<LS>()?;
393 let mut soln = Solution::new_dense(t_eval.to_vec())?;
394 solver.solve_soln_sensitivities(&mut soln)?;
395 Ok(soln)
396 }
397 OdeSolverType::Tsit45 => {
398 let solver = problem.tsit45_sens()?;
399 let mut soln = Solution::new_dense(t_eval.to_vec())?;
400 solver.solve_soln_sensitivities(&mut soln)?;
401 Ok(soln)
402 }
403 }
404 }
405
406 #[allow(clippy::type_complexity)]
407 pub(crate) fn solve_hybrid_fwd_sens<M, CG, LS>(
408 &self,
409 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
410 t_eval: &[M::T],
411 ) -> Result<Solution<M::V>, DiffsolError>
412 where
413 M: Matrix<T: Scalar> + DefaultSolver,
414 CG: CodegenModule,
415 M::V: VectorHost + DefaultDenseMatrix,
416 LS: LinearSolver<M>,
417 for<'b> &'b M::V: VectorRef<M::V>,
418 for<'b> &'b M: MatrixRef<M>,
419 {
420 Self::check_sens_available()?;
421 match self {
422 OdeSolverType::Bdf => {
423 let mut soln = Solution::new_dense(t_eval.to_vec())?;
424 let mut solver = problem.bdf_sens::<LS>()?;
425 while !soln.is_complete() {
426 solver = solver.solve_soln_sensitivities(&mut soln)?;
427 let root_idx = match soln.stop_reason {
428 Some(OdeSolverStopReason::RootFound(_, root_idx))
429 if !soln.is_complete() =>
430 {
431 root_idx
432 }
433 _ => continue,
434 };
435 let mut state = solver.into_state();
436 problem.eqn.set_model_index(root_idx);
437 apply_state_reset_with_sens(problem, &mut state, root_idx)?;
438 solver = problem.bdf_solver_sens::<LS>(state)?;
439 }
440 Ok(soln)
441 }
442 OdeSolverType::Esdirk34 => {
443 let mut soln = Solution::new_dense(t_eval.to_vec())?;
444 let mut solver = problem.esdirk34_sens::<LS>()?;
445 while !soln.is_complete() {
446 solver = solver.solve_soln_sensitivities(&mut soln)?;
447 let root_idx = match soln.stop_reason {
448 Some(OdeSolverStopReason::RootFound(_, root_idx))
449 if !soln.is_complete() =>
450 {
451 root_idx
452 }
453 _ => continue,
454 };
455 let mut state = solver.into_state();
456 problem.eqn.set_model_index(root_idx);
457 apply_state_reset_with_sens(problem, &mut state, root_idx)?;
458 solver = problem.esdirk34_solver_sens::<LS>(state)?;
459 }
460 Ok(soln)
461 }
462 OdeSolverType::TrBdf2 => {
463 let mut soln = Solution::new_dense(t_eval.to_vec())?;
464 let mut solver = problem.tr_bdf2_sens::<LS>()?;
465 while !soln.is_complete() {
466 solver = solver.solve_soln_sensitivities(&mut soln)?;
467 let root_idx = match soln.stop_reason {
468 Some(OdeSolverStopReason::RootFound(_, root_idx))
469 if !soln.is_complete() =>
470 {
471 root_idx
472 }
473 _ => continue,
474 };
475 let mut state = solver.into_state();
476 problem.eqn.set_model_index(root_idx);
477 apply_state_reset_with_sens(problem, &mut state, root_idx)?;
478 solver = problem.tr_bdf2_solver_sens::<LS>(state)?;
479 }
480 Ok(soln)
481 }
482 OdeSolverType::Tsit45 => {
483 let mut soln = Solution::new_dense(t_eval.to_vec())?;
484 let mut solver = problem.tsit45_sens()?;
485 while !soln.is_complete() {
486 solver = solver.solve_soln_sensitivities(&mut soln)?;
487 let root_idx = match soln.stop_reason {
488 Some(OdeSolverStopReason::RootFound(_, root_idx))
489 if !soln.is_complete() =>
490 {
491 root_idx
492 }
493 _ => continue,
494 };
495 let mut state = solver.into_state();
496 problem.eqn.set_model_index(root_idx);
497 apply_state_reset_with_sens(problem, &mut state, root_idx)?;
498 solver = problem.tsit45_solver_sens(state)?;
499 }
500 Ok(soln)
501 }
502 }
503 }
504
505 pub(crate) fn solve_sum_squares_adj<'a, M, CG, LS>(
506 &self,
507 problem: &mut OdeSolverProblem<DiffSl<M, CG>>,
508 data: ArrayView2<'a, M::T>,
509 t_eval: &[M::T],
510 backwards_method: OdeSolverType,
511 backwards_linear_solver: LinearSolverType,
512 ) -> Result<(M::T, M::V), DiffsolError>
513 where
514 M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M>,
515 CG: CodegenModule,
516 M::V: VectorHost + DefaultDenseMatrix,
517 LS: LinearSolver<M>,
518 for<'b> &'b M::V: VectorRef<M::V>,
519 for<'b> &'b M: MatrixRef<M>,
520 {
521 Self::check_sens_available()?;
522 match self {
523 OdeSolverType::Bdf => self._solve_sum_squares_adj(
524 problem.bdf::<LS>()?,
525 data,
526 t_eval,
527 backwards_method,
528 backwards_linear_solver,
529 ),
530 OdeSolverType::Esdirk34 => self._solve_sum_squares_adj(
531 problem.esdirk34::<LS>()?,
532 data,
533 t_eval,
534 backwards_method,
535 backwards_linear_solver,
536 ),
537 OdeSolverType::TrBdf2 => self._solve_sum_squares_adj(
538 problem.tr_bdf2::<LS>()?,
539 data,
540 t_eval,
541 backwards_method,
542 backwards_linear_solver,
543 ),
544 OdeSolverType::Tsit45 => self._solve_sum_squares_adj(
545 problem.tsit45()?,
546 data,
547 t_eval,
548 backwards_method,
549 backwards_linear_solver,
550 ),
551 }
552 }
553
554 pub(crate) fn _solve_sum_squares_adj<'data, 'solver, M, CG, S>(
555 &self,
556 mut solver: S,
557 data: ArrayView2<'data, M::T>,
558 t_eval: &[M::T],
559 backwards_method: OdeSolverType,
560 backwards_linear_solver: LinearSolverType,
561 ) -> Result<(M::T, M::V), DiffsolError>
562 where
563 M: Matrix<T: Scalar> + DefaultSolver + LuValidator<M> + KluValidator<M>,
564 CG: CodegenModule,
565 M::V: VectorHost + DefaultDenseMatrix,
566 S: OdeSolverMethod<'solver, DiffSl<M, CG>>,
567 for<'b> &'b M::V: VectorRef<M::V>,
568 for<'b> &'b M: MatrixRef<M>,
569 {
570 let (chk, ys, stop_reason) = solver.solve_dense_with_checkpointing(t_eval, None)?;
571 let eqn = solver.problem().eqn();
572 let ctx = eqn.context();
573 let mut g_m = <M::V as DefaultDenseMatrix>::M::zeros(eqn.nout(), t_eval.len(), ctx.clone());
574 let mut y = M::T::zero();
575 for j in 0..g_m.ncols() {
576 let ys_col = ys.column(j);
577 let mut tmp = M::V::from_slice(data.column(j).as_slice().unwrap(), ctx.clone());
579 tmp.axpy_v(
581 M::T::from_f64(2.0).unwrap(),
582 &ys_col,
583 M::T::from_f64(-2.0).unwrap(),
584 );
585 g_m.column_mut(j).copy_from(&tmp);
586
587 let norm = tmp.norm(2);
589 y += M::T::from_f64(1.0 / 4.0).unwrap() * norm * norm;
590 }
591 let mut y_sens = match backwards_linear_solver {
592 LinearSolverType::Default => backwards_method
593 .solve_adjoint_backwards::<M, CG, <M as DefaultSolver>::LS, S>(
594 solver.problem(),
595 chk,
596 stop_reason,
597 &g_m,
598 t_eval,
599 Some(1),
600 )?,
601 LinearSolverType::Lu => backwards_method
602 .solve_adjoint_backwards::<M, CG, <M as LuValidator<M>>::LS, S>(
603 solver.problem(),
604 chk,
605 stop_reason,
606 &g_m,
607 t_eval,
608 Some(1),
609 )?,
610 LinearSolverType::Klu => backwards_method
611 .solve_adjoint_backwards::<M, CG, <M as KluValidator<M>>::LS, S>(
612 solver.problem(),
613 chk,
614 stop_reason,
615 &g_m,
616 t_eval,
617 Some(1),
618 )?,
619 };
620 Ok((y, y_sens.pop().unwrap()))
621 }
622
623 pub(crate) fn solve_adjoint_backwards<'solver, M, CG, LS, S>(
624 &self,
625 problem: &'solver OdeSolverProblem<DiffSl<M, CG>>,
626 checkpointing: Checkpointing<'solver, DiffSl<M, CG>, S>,
627 _stop_reason: OdeSolverStopReason<M::T>,
628 g_m: &<M::V as DefaultDenseMatrix>::M,
629 t_eval: &[M::T],
630 nout_override: Option<usize>,
631 ) -> Result<Vec<M::V>, DiffsolError>
632 where
633 M: Matrix<T: Scalar> + DefaultSolver,
634 CG: CodegenModule,
635 M::V: VectorHost + DefaultDenseMatrix,
636 S: OdeSolverMethod<'solver, DiffSl<M, CG>>,
637 LS: LinearSolver<M>,
638 for<'b> &'b M::V: VectorRef<M::V>,
639 for<'b> &'b M: MatrixRef<M>,
640 {
641 match self {
642 OdeSolverType::Bdf => problem
643 .bdf_solver_adjoint::<LS, _>(checkpointing, nout_override)?
644 .solve_adjoint_backwards_pass(None, t_eval, &[g_m])
645 .map(|res| res.into_common().sg),
646 OdeSolverType::Esdirk34 => problem
647 .esdirk34_solver_adjoint::<LS, _>(checkpointing, nout_override)?
648 .solve_adjoint_backwards_pass(None, t_eval, &[g_m])
649 .map(|res| res.into_common().sg),
650 OdeSolverType::TrBdf2 => problem
651 .tr_bdf2_solver_adjoint::<LS, _>(checkpointing, nout_override)?
652 .solve_adjoint_backwards_pass(None, t_eval, &[g_m])
653 .map(|res| res.into_common().sg),
654 OdeSolverType::Tsit45 => Err(DiffsolError::Other(
655 "Tsit45 solver does not support adjoint sensitivity analysis.".to_string(),
656 )),
657 }
658 }
659}
660
// Tests for the runtime dispatch in `OdeSolverType`. They compile small DiffSL
// models with a codegen backend, so the module only builds when
// `diffsl-cranelift` or `diffsl-llvm` is enabled. The forward-sensitivity and
// adjoint tests use `diffsol::LlvmModule` and are gated on `diffsl-llvm`.
#[cfg(all(test, any(feature = "diffsl-cranelift", feature = "diffsl-llvm")))]
mod tests {
    use diffsol::{
        CodegenModuleCompile, CodegenModuleJit, DefaultSolver, DenseMatrix, OdeBuilder,
        OdeSolverProblem, Vector,
    };

    #[cfg(feature = "diffsl-llvm")]
    use crate::linear_solver_type::LinearSolverType;
    use crate::test_support::{
        assert_close, hybrid_logistic_diffsl_code, hybrid_logistic_state, logistic_diffsl_code,
        logistic_state, LOGISTIC_X0,
    };
    #[cfg(feature = "diffsl-llvm")]
    use crate::test_support::{hybrid_logistic_state_dr, logistic_integral, logistic_state_dr};
    use crate::valid_linear_solver::LuValidator;
    #[cfg(feature = "diffsl-llvm")]
    use ndarray::Array2;

    use super::OdeSolverType;

    // Matrix type used by every test in this module.
    type M = diffsol::NalgebraMat<f64>;

    // Builds a problem from DiffSL `code` with one parameter p = 2.0 and
    // rtol = atol = 1e-6. The expected-value helpers below are called with the
    // same 2.0 so they describe this parameterisation.
    fn build_problem<CG>(code: &str) -> OdeSolverProblem<diffsol::DiffSl<M, CG>>
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        OdeBuilder::<M>::new()
            .p([2.0])
            .rtol(1e-6)
            .atol([1e-6])
            .build_from_diffsl::<CG>(code)
            .unwrap()
    }

    // Checks that a dense solution was sampled exactly at `t_eval` and that its
    // first component matches `expected(t)` to within 5e-4 at every point.
    fn assert_dense_solution_matches_expected(
        soln: &diffsol::Solution<diffsol::NalgebraVec<f64>>,
        t_eval: &[f64],
        expected: impl Fn(f64) -> f64,
    ) {
        assert_eq!(soln.ts, t_eval);
        for (i, &t) in t_eval.iter().enumerate() {
            assert_close(
                soln.ys.get_index(0, i),
                expected(t),
                5e-4,
                &format!("solution[{i}]"),
            );
        }
    }

    // Runs `solve` and `solve_dense` for every method with the default linear
    // solver on the logistic model, comparing against `logistic_state`.
    fn test_all_solver_variants<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let t_eval = [0.25, 0.5, 1.0];
        for method in [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ] {
            let mut problem = build_problem::<CG>(logistic_diffsl_code());
            let soln = method
                .solve::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 1.0)
                .unwrap();
            assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "solve final time");
            assert_close(
                soln.ys.get_index(0, soln.ts.len() - 1),
                logistic_state(LOGISTIC_X0, 2.0, 1.0),
                5e-4,
                "solve final value",
            );

            // Fresh problem for the dense run so state from the previous solve
            // cannot leak into it.
            let mut problem = build_problem::<CG>(logistic_diffsl_code());
            let soln = method
                .solve_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
                .unwrap();
            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
                logistic_state(LOGISTIC_X0, 2.0, t)
            });
        }
    }

    // Runs `solve_hybrid` and `solve_hybrid_dense` for every method with the
    // default linear solver on the hybrid logistic model, comparing against
    // `hybrid_logistic_state`.
    fn test_all_hybrid_solver_variants<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        for method in [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ] {
            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
            let soln = method
                .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
                .unwrap();
            assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "hybrid final time");
            assert_close(
                soln.ys.get_index(0, soln.ts.len() - 1),
                hybrid_logistic_state(2.0, 2.0),
                5e-4,
                "hybrid final value",
            );

            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
            let soln = method
                .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
                .unwrap();
            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
                hybrid_logistic_state(2.0, t)
            });
        }
    }

    // Same as `test_all_solver_variants`, but with the LU linear solver
    // (`<M as LuValidator<M>>::LS`). The non-dense run only checks the final
    // time.
    fn test_all_solver_variants_with_lu<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let t_eval = [0.25, 0.5, 1.0];
        for method in [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ] {
            let mut problem = build_problem::<CG>(logistic_diffsl_code());
            let soln = method
                .solve::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 1.0)
                .unwrap();
            assert_close(*soln.ts.last().unwrap(), 1.0, 5e-4, "lu solve final time");

            let mut problem = build_problem::<CG>(logistic_diffsl_code());
            let soln = method
                .solve_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
                .unwrap();
            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
                logistic_state(LOGISTIC_X0, 2.0, t)
            });
        }
    }

    // Same as `test_all_hybrid_solver_variants`, but with the LU linear
    // solver. The non-dense run only checks the final time.
    fn test_all_hybrid_solver_variants_with_lu<CG>()
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];
        for method in [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ] {
            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
            let soln = method
                .solve_hybrid::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, 2.0)
                .unwrap();
            assert_close(*soln.ts.last().unwrap(), 2.0, 5e-4, "lu hybrid final time");

            let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
            let soln = method
                .solve_hybrid_dense::<M, CG, <M as LuValidator<M>>::LS>(&mut problem, &t_eval)
                .unwrap();
            assert_dense_solution_matches_expected(&soln, &t_eval, |t| {
                hybrid_logistic_state(2.0, t)
            });
        }
    }

    // Exercises the hybrid restart path (root found -> reset -> new solver)
    // for a single `method`, checking both final value and dense output.
    fn assert_direct_hybrid_restart_path_for_method<CG>(method: OdeSolverType)
    where
        CG: diffsol::CodegenModule + CodegenModuleJit + CodegenModuleCompile,
    {
        let t_eval = [0.5, 1.0, 1.25, 1.5, 2.0];

        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
        let soln = method
            .solve_hybrid::<M, CG, <M as DefaultSolver>::LS>(&mut problem, 2.0)
            .unwrap();
        assert_close(
            *soln.ts.last().unwrap(),
            2.0,
            5e-4,
            "direct hybrid restart final time",
        );
        assert_close(
            soln.ys.get_index(0, soln.ts.len() - 1),
            hybrid_logistic_state(2.0, 2.0),
            5e-4,
            "direct hybrid restart final value",
        );

        let mut problem = build_problem::<CG>(hybrid_logistic_diffsl_code());
        let soln = method
            .solve_hybrid_dense::<M, CG, <M as DefaultSolver>::LS>(&mut problem, &t_eval)
            .unwrap();
        assert_dense_solution_matches_expected(&soln, &t_eval, |t| hybrid_logistic_state(2.0, t));
    }

    // Forward sensitivities (plain and hybrid) for every method, compared
    // against the `*_state_dr` helpers from `test_support`.
    #[cfg(feature = "diffsl-llvm")]
    fn test_all_sensitivity_solver_variants() {
        let t_eval = [0.25, 0.5, 1.0];
        for method in [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ] {
            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
            let soln = method
                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
                    &mut problem,
                    &t_eval,
                )
                .unwrap();
            for (i, &t) in t_eval.iter().enumerate() {
                assert_close(
                    soln.y_sens[0].get_index(0, i),
                    logistic_state_dr(LOGISTIC_X0, 2.0, t),
                    5e-4,
                    &format!("fwd_sens[{i}]"),
                );
            }

            let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
            let soln = method
                .solve_hybrid_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
                    &mut problem,
                    &t_eval,
                )
                .unwrap();
            for (i, &t) in t_eval.iter().enumerate() {
                assert_close(
                    soln.y_sens[0].get_index(0, i),
                    hybrid_logistic_state_dr(2.0, t),
                    5e-4,
                    &format!("hybrid_fwd_sens[{i}]"),
                );
            }
        }
    }

    // Forward sensitivities with the LU solver for every method, then one
    // adjoint run (BDF forward, TR-BDF2 + LU backward) that only checks the
    // objective and gradient are finite.
    #[cfg(feature = "diffsl-llvm")]
    fn test_lu_sensitivity_and_adjoint_solver_variants() {
        let t_eval = [0.25, 0.5, 1.0];
        for method in [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ] {
            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
            let soln = method
                .solve_fwd_sens::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
                    &mut problem,
                    &t_eval,
                )
                .unwrap();
            for (i, &t) in t_eval.iter().enumerate() {
                assert_close(
                    soln.y_sens[0].get_index(0, i),
                    logistic_state_dr(LOGISTIC_X0, 2.0, t),
                    5e-4,
                    &format!("lu fwd_sens[{i}]"),
                );
            }
        }

        // Single-row data (one output, one column per time point) built from
        // `logistic_integral`.
        let adjoint_t_eval = [0.0, 0.25, 0.5, 1.0];
        let data = Array2::from_shape_vec(
            (1, adjoint_t_eval.len()),
            adjoint_t_eval
                .iter()
                .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
                .collect(),
        )
        .unwrap();

        let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
        let (objective, gradient) = OdeSolverType::Bdf
            .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as LuValidator<M>>::LS>(
                &mut problem,
                data.view(),
                &adjoint_t_eval,
                OdeSolverType::TrBdf2,
                LinearSolverType::Lu,
            )
            .unwrap();
        assert!(objective.is_finite());
        assert_eq!(gradient.len(), 1);
        assert!(gradient.get_index(0).is_finite());
    }

    // Hybrid forward sensitivities over a longer horizon, checking both values
    // and sensitivities. Bdf is not included in this method list.
    #[cfg(feature = "diffsl-llvm")]
    fn test_direct_hybrid_sensitivity_restart_paths() {
        let t_eval = [0.5, 1.0, 2.5, 3.0, 4.5];
        for method in [
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ] {
            let mut problem = build_problem::<diffsol::LlvmModule>(hybrid_logistic_diffsl_code());
            let soln = method
                .solve_hybrid_fwd_sens::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
                    &mut problem,
                    &t_eval,
                )
                .unwrap();
            for (i, &t) in t_eval.iter().enumerate() {
                assert_close(
                    soln.ys.get_index(0, i),
                    hybrid_logistic_state(2.0, t),
                    5e-4,
                    &format!("direct hybrid value[{i}]"),
                );
                assert_close(
                    soln.y_sens[0].get_index(0, i),
                    hybrid_logistic_state_dr(2.0, t),
                    5e-4,
                    &format!("direct hybrid fwd sens[{i}]"),
                );
            }
        }
    }

    // Adjoint runs with the KLU backward solver for Esdirk34 and TrBdf2 as
    // backward methods, plus the error path for Tsit45 as backward method.
    #[cfg(feature = "diffsl-llvm")]
    fn test_adjoint_backwards_methods_and_klu_branch() {
        let t_eval = [0.0, 0.25, 0.5, 1.0];
        let data = Array2::from_shape_vec(
            (1, t_eval.len()),
            t_eval
                .iter()
                .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
                .collect(),
        )
        .unwrap();

        for backwards_method in [OdeSolverType::Esdirk34, OdeSolverType::TrBdf2] {
            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
            let (objective, gradient) = OdeSolverType::Bdf
                .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
                    &mut problem,
                    data.view(),
                    &t_eval,
                    backwards_method,
                    LinearSolverType::Klu,
                )
                .unwrap();
            assert!(objective.is_finite());
            assert_eq!(gradient.len(), 1);
            assert!(gradient.get_index(0).is_finite());
        }

        // Tsit45 has no adjoint solver; the error message is asserted verbatim.
        let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
        let err = OdeSolverType::Bdf
            .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
                &mut problem,
                data.view(),
                &t_eval,
                OdeSolverType::Tsit45,
                LinearSolverType::Default,
            )
            .unwrap_err();
        assert!(err
            .to_string()
            .contains("Tsit45 solver does not support adjoint sensitivity analysis"));
    }

    // Every method as the forward solver, with BDF + default linear solver for
    // the backward pass; only finiteness and gradient length are checked.
    #[cfg(feature = "diffsl-llvm")]
    fn test_all_adjoint_solver_variants() {
        let t_eval = [0.0, 0.25, 0.5, 1.0];
        let data = Array2::from_shape_vec(
            (1, t_eval.len()),
            t_eval
                .iter()
                .map(|&t| logistic_integral(LOGISTIC_X0, 2.0, t))
                .collect(),
        )
        .unwrap();

        for method in [
            OdeSolverType::Bdf,
            OdeSolverType::Esdirk34,
            OdeSolverType::TrBdf2,
            OdeSolverType::Tsit45,
        ] {
            let mut problem = build_problem::<diffsol::LlvmModule>(logistic_diffsl_code());
            let (objective, gradient) = method
                .solve_sum_squares_adj::<M, diffsol::LlvmModule, <M as DefaultSolver>::LS>(
                    &mut problem,
                    data.view(),
                    &t_eval,
                    OdeSolverType::Bdf,
                    crate::linear_solver_type::LinearSolverType::Default,
                )
                .unwrap();
            assert!(objective.is_finite());
            assert_eq!(gradient.len(), 1);
            assert!(gradient.get_index(0).is_finite());
        }
    }

    // --- #[test] entry points, one group per codegen backend ---

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn runtime_dispatch_solves_all_variants_for_cranelift() {
        test_all_solver_variants::<diffsol::CraneliftJitModule>();
        test_all_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
    }

    #[cfg(feature = "diffsl-cranelift")]
    #[test]
    fn runtime_dispatch_solves_all_hybrid_variants_for_cranelift() {
        test_all_hybrid_solver_variants::<diffsol::CraneliftJitModule>();
        test_all_hybrid_solver_variants_with_lu::<diffsol::CraneliftJitModule>();
        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
            OdeSolverType::Esdirk34,
        );
        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
            OdeSolverType::TrBdf2,
        );
        assert_direct_hybrid_restart_path_for_method::<diffsol::CraneliftJitModule>(
            OdeSolverType::Tsit45,
        );
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn runtime_dispatch_solves_all_variants_for_llvm() {
        test_all_solver_variants::<diffsol::LlvmModule>();
        test_all_solver_variants_with_lu::<diffsol::LlvmModule>();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn runtime_dispatch_solves_all_hybrid_variants_for_llvm() {
        test_all_hybrid_solver_variants::<diffsol::LlvmModule>();
        test_all_hybrid_solver_variants_with_lu::<diffsol::LlvmModule>();
        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(
            OdeSolverType::Esdirk34,
        );
        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::TrBdf2);
        assert_direct_hybrid_restart_path_for_method::<diffsol::LlvmModule>(OdeSolverType::Tsit45);
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn runtime_dispatch_solves_all_forward_sensitivity_variants_for_llvm() {
        test_all_sensitivity_solver_variants();
        test_lu_sensitivity_and_adjoint_solver_variants();
        test_direct_hybrid_sensitivity_restart_paths();
    }

    #[cfg(feature = "diffsl-llvm")]
    #[test]
    fn runtime_dispatch_solves_all_adjoint_variants_for_llvm() {
        test_all_adjoint_solver_variants();
        test_adjoint_backwards_methods_and_klu_branch();
    }
}