diffsol_c/
adjoint_checkpoint.rs1use std::{
2 any::Any,
3 sync::{Arc, Mutex, MutexGuard},
4};
5
6use diffsol::{CheckpointingPath, DefaultSolver, DiffSl, DiffsolError, OdeEquationsImplicit};
7
8use crate::{
9 error::DiffsolRtError, linear_solver_type::LinearSolverType,
10 ode_solver_tag::OdeSolverMethodTag, ode_solver_type::OdeSolverType, scalar_type::Scalar,
11};
12
13pub(crate) trait AdjointCheckpoint: Any + Send {
14 fn as_any(&self) -> &dyn Any;
15 fn method(&self) -> OdeSolverType;
16 fn linear_solver(&self) -> LinearSolverType;
17 fn params(&self) -> &[f64];
18}
19
20impl dyn AdjointCheckpoint + '_ {
21 pub(crate) fn data<M, CG, Tag>(
22 &self,
23 ) -> Result<&AdjointCheckpointData<M, CG, Tag>, DiffsolError>
24 where
25 M: diffsol::Matrix<T: Scalar> + DefaultSolver + 'static,
26 CG: diffsol::CodegenModule + 'static,
27 DiffSl<M, CG>: OdeEquationsImplicit<V = M::V> + 'static,
28 Tag: OdeSolverMethodTag<M, CG> + 'static,
29 {
30 self.as_any()
31 .downcast_ref::<AdjointCheckpointData<M, CG, Tag>>()
32 .ok_or_else(|| {
33 DiffsolError::Other(
34 "Adjoint checkpoint is incompatible with this ODE solver".to_string(),
35 )
36 })
37 }
38}
39
40pub(crate) struct AdjointCheckpointData<M, CG, Tag>
41where
42 M: diffsol::Matrix<T: Scalar> + DefaultSolver,
43 CG: diffsol::CodegenModule,
44 DiffSl<M, CG>: OdeEquationsImplicit<V = M::V>,
45 Tag: OdeSolverMethodTag<M, CG>,
46{
47 pub(crate) checkpointing: CheckpointingPath<DiffSl<M, CG>, Tag::State>,
48 params: Vec<f64>,
49 method: OdeSolverType,
50 linear_solver: LinearSolverType,
51}
52
53impl<M, CG, Tag> AdjointCheckpointData<M, CG, Tag>
54where
55 M: diffsol::Matrix<T: Scalar> + DefaultSolver,
56 CG: diffsol::CodegenModule,
57 DiffSl<M, CG>: OdeEquationsImplicit<V = M::V>,
58 Tag: OdeSolverMethodTag<M, CG>,
59{
60 pub(crate) fn new(
61 checkpointing: CheckpointingPath<DiffSl<M, CG>, Tag::State>,
62 params: Vec<f64>,
63 method: OdeSolverType,
64 linear_solver: LinearSolverType,
65 ) -> Self {
66 Self {
67 checkpointing,
68 params,
69 method,
70 linear_solver,
71 }
72 }
73}
74
75impl<M, CG, Tag> AdjointCheckpoint for AdjointCheckpointData<M, CG, Tag>
76where
77 M: diffsol::Matrix<T: Scalar> + DefaultSolver + 'static,
78 CG: diffsol::CodegenModule + 'static,
79 DiffSl<M, CG>: OdeEquationsImplicit<V = M::V> + 'static,
80 Tag: OdeSolverMethodTag<M, CG> + 'static,
81{
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn method(&self) -> OdeSolverType {
87 self.method
88 }
89
90 fn linear_solver(&self) -> LinearSolverType {
91 self.linear_solver
92 }
93
94 fn params(&self) -> &[f64] {
95 &self.params
96 }
97}
98
99#[derive(Clone)]
103pub struct AdjointCheckpointWrapper(Arc<Mutex<Box<dyn AdjointCheckpoint>>>);
104
105impl AdjointCheckpointWrapper {
106 pub(crate) fn new(checkpoint: Box<dyn AdjointCheckpoint>) -> Self {
107 Self(Arc::new(Mutex::new(checkpoint)))
108 }
109
110 pub(crate) fn guard(
111 &self,
112 ) -> Result<MutexGuard<'_, Box<dyn AdjointCheckpoint>>, DiffsolRtError> {
113 self.0.lock().map_err(|_| {
114 DiffsolError::Other("Adjoint checkpoint mutex poisoned".to_string()).into()
115 })
116 }
117}