Skip to main content

diffsol_c/
adjoint_checkpoint.rs

1use std::{
2    any::Any,
3    sync::{Arc, Mutex, MutexGuard},
4};
5
6use diffsol::{CheckpointingPath, DefaultSolver, DiffSl, DiffsolError, OdeEquationsImplicit};
7
8use crate::{
9    error::DiffsolRtError, linear_solver_type::LinearSolverType,
10    ode_solver_tag::OdeSolverMethodTag, ode_solver_type::OdeSolverType, scalar_type::Scalar,
11};
12
13pub(crate) trait AdjointCheckpoint: Any + Send {
14    fn as_any(&self) -> &dyn Any;
15    fn method(&self) -> OdeSolverType;
16    fn linear_solver(&self) -> LinearSolverType;
17    fn params(&self) -> &[f64];
18}
19
20impl dyn AdjointCheckpoint + '_ {
21    pub(crate) fn data<M, CG, Tag>(
22        &self,
23    ) -> Result<&AdjointCheckpointData<M, CG, Tag>, DiffsolError>
24    where
25        M: diffsol::Matrix<T: Scalar> + DefaultSolver + 'static,
26        CG: diffsol::CodegenModule + 'static,
27        DiffSl<M, CG>: OdeEquationsImplicit<V = M::V> + 'static,
28        Tag: OdeSolverMethodTag<M, CG> + 'static,
29    {
30        self.as_any()
31            .downcast_ref::<AdjointCheckpointData<M, CG, Tag>>()
32            .ok_or_else(|| {
33                DiffsolError::Other(
34                    "Adjoint checkpoint is incompatible with this ODE solver".to_string(),
35                )
36            })
37    }
38}
39
40pub(crate) struct AdjointCheckpointData<M, CG, Tag>
41where
42    M: diffsol::Matrix<T: Scalar> + DefaultSolver,
43    CG: diffsol::CodegenModule,
44    DiffSl<M, CG>: OdeEquationsImplicit<V = M::V>,
45    Tag: OdeSolverMethodTag<M, CG>,
46{
47    pub(crate) checkpointing: CheckpointingPath<DiffSl<M, CG>, Tag::State>,
48    params: Vec<f64>,
49    method: OdeSolverType,
50    linear_solver: LinearSolverType,
51}
52
53impl<M, CG, Tag> AdjointCheckpointData<M, CG, Tag>
54where
55    M: diffsol::Matrix<T: Scalar> + DefaultSolver,
56    CG: diffsol::CodegenModule,
57    DiffSl<M, CG>: OdeEquationsImplicit<V = M::V>,
58    Tag: OdeSolverMethodTag<M, CG>,
59{
60    pub(crate) fn new(
61        checkpointing: CheckpointingPath<DiffSl<M, CG>, Tag::State>,
62        params: Vec<f64>,
63        method: OdeSolverType,
64        linear_solver: LinearSolverType,
65    ) -> Self {
66        Self {
67            checkpointing,
68            params,
69            method,
70            linear_solver,
71        }
72    }
73}
74
75impl<M, CG, Tag> AdjointCheckpoint for AdjointCheckpointData<M, CG, Tag>
76where
77    M: diffsol::Matrix<T: Scalar> + DefaultSolver + 'static,
78    CG: diffsol::CodegenModule + 'static,
79    DiffSl<M, CG>: OdeEquationsImplicit<V = M::V> + 'static,
80    Tag: OdeSolverMethodTag<M, CG> + 'static,
81{
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85
86    fn method(&self) -> OdeSolverType {
87        self.method
88    }
89
90    fn linear_solver(&self) -> LinearSolverType {
91        self.linear_solver
92    }
93
94    fn params(&self) -> &[f64] {
95        &self.params
96    }
97}
98
99/// Wrapper for an adjoint checkpoint that allows it to be shared across threads and mutated safely.    
100///
101/// Mutex is required because Checkpointing uses interior mutability when interpolating checkpoints
102#[derive(Clone)]
103pub struct AdjointCheckpointWrapper(Arc<Mutex<Box<dyn AdjointCheckpoint>>>);
104
105impl AdjointCheckpointWrapper {
106    pub(crate) fn new(checkpoint: Box<dyn AdjointCheckpoint>) -> Self {
107        Self(Arc::new(Mutex::new(checkpoint)))
108    }
109
110    pub(crate) fn guard(
111        &self,
112    ) -> Result<MutexGuard<'_, Box<dyn AdjointCheckpoint>>, DiffsolRtError> {
113        self.0.lock().map_err(|_| {
114            DiffsolError::Other("Adjoint checkpoint mutex poisoned".to_string()).into()
115        })
116    }
117}