cvode_wrap/
cvode_sens.rs

//! Wrapper around CVODES (the sensitivity-enabled variant of CVODE), solving an ODE together with its forward sensitivities.
2
3use std::{convert::TryInto, os::raw::c_int, pin::Pin};
4
5use sundials_sys::{SUNLinearSolver, SUNMatrix, CV_STAGGERED};
6
7use crate::{
8    check_flag_is_succes, check_non_null, AbsTolerance, CvodeMemoryBlock,
9    CvodeMemoryBlockNonNullPtr, LinearMultistepMethod, NVectorSerial, NVectorSerialHeapAllocated,
10    Realtype, Result, RhsResult, SensiAbsTolerance, StepKind,
11};
12
/// Bundle registered with CVODES as its `user_data` pointer.
///
/// The C trampolines receive a pointer to this struct and recover both the
/// user's closures and the user's own data from it.
struct WrappingUserData<UserData, F, FS> {
    /// The data supplied by the caller, forwarded to both right-hand sides.
    actual_user_data: UserData,
    /// Right-hand side of the ODE.
    f: F,
    /// Right-hand side of the sensitivity equations.
    fs: FS,
}
18
19/// The ODE solver with sensitivities.
20///
21/// # Type Arguments
22///
23/// - `F` is the type of the right-hand side function
24///
25/// - `FS` is the type of the sensitivities right-hand side function
26///
27///  - `UserData` is the type of the supplementary arguments for the
28/// right-hand-side. If unused, should be `()`.
29///
30/// - `N` is the "problem size", that is the dimension of the state space.
31///
32/// - `N_SENSI` is the number of sensitivities computed
33pub struct Solver<UserData, F, FS, const N: usize, const N_SENSI: usize> {
34    mem: CvodeMemoryBlockNonNullPtr,
35    y0: NVectorSerialHeapAllocated<N>,
36    y_s0: Box<[NVectorSerialHeapAllocated<N>; N_SENSI]>,
37    sunmatrix: SUNMatrix,
38    linsolver: SUNLinearSolver,
39    atol: AbsTolerance<N>,
40    atol_sens: SensiAbsTolerance<N, N_SENSI>,
41    user_data: Pin<Box<WrappingUserData<UserData, F, FS>>>,
42    sensi_out_buffer: [NVectorSerialHeapAllocated<N>; N_SENSI],
43}
44
45extern "C" fn wrap_f<UserData, F, FS, const N: usize>(
46    t: Realtype,
47    y: *const NVectorSerial<N>,
48    ydot: *mut NVectorSerial<N>,
49    data: *const WrappingUserData<UserData, F, FS>,
50) -> c_int
51where
52    F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
53{
54    let y = unsafe { &*y }.as_slice();
55    let ydot = unsafe { &mut *ydot }.as_slice_mut();
56    let WrappingUserData {
57        actual_user_data: data,
58        f,
59        ..
60    } = unsafe { &*data };
61    let res = f(t, y, ydot, data);
62    match res {
63        RhsResult::Ok => 0,
64        RhsResult::RecoverableError(e) => e as c_int,
65        RhsResult::NonRecoverableError(e) => -(e as c_int),
66    }
67}
68
69extern "C" fn wrap_f_sens<UserData, F, FS, const N: usize, const N_SENSI: usize>(
70    _n_s: c_int,
71    t: Realtype,
72    y: *const NVectorSerial<N>,
73    ydot: *const NVectorSerial<N>,
74    y_s: *const [*const NVectorSerial<N>; N_SENSI],
75    y_sdot: *mut [*mut NVectorSerial<N>; N_SENSI],
76    data: *const WrappingUserData<UserData, F, FS>,
77    _tmp1: *const NVectorSerial<N>,
78    _tmp2: *const NVectorSerial<N>,
79) -> c_int
80where
81    FS: Fn(
82        Realtype,
83        &[Realtype; N],
84        &[Realtype; N],
85        [&[Realtype; N]; N_SENSI],
86        [&mut [Realtype; N]; N_SENSI],
87        &UserData,
88    ) -> RhsResult,
89{
90    let y = unsafe { &*y }.as_slice();
91    let ydot = unsafe { &*ydot }.as_slice();
92    let y_s = unsafe { &*y_s };
93    let y_s: [&[Realtype; N]; N_SENSI] =
94        array_init::from_iter(y_s.iter().map(|&v| unsafe { &*v }.as_slice())).unwrap();
95    let y_sdot = unsafe { &mut *y_sdot };
96    let y_sdot: [&mut [Realtype; N]; N_SENSI] = array_init::from_iter(
97        y_sdot
98            .iter_mut()
99            .map(|&mut v| unsafe { &mut *v }.as_slice_mut()),
100    )
101    .unwrap();
102    let WrappingUserData {
103        actual_user_data: data,
104        fs,
105        ..
106    } = unsafe { &*data };
107    let res = fs(t, y, ydot, y_s, y_sdot, data);
108    match res {
109        RhsResult::Ok => 0,
110        RhsResult::RecoverableError(e) => e as c_int,
111        RhsResult::NonRecoverableError(e) => -(e as c_int),
112    }
113}
114
115impl<UserData, F, FS, const N: usize, const N_SENSI: usize> Solver<UserData, F, FS, N, N_SENSI>
116where
117    F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
118    FS: Fn(
119        Realtype,
120        &[Realtype; N],
121        &[Realtype; N],
122        [&[Realtype; N]; N_SENSI],
123        [&mut [Realtype; N]; N_SENSI],
124        &UserData,
125    ) -> RhsResult,
126{
127    /// Creates a new solver.
128    #[allow(clippy::clippy::too_many_arguments)]
129    pub fn new(
130        method: LinearMultistepMethod,
131        f: F,
132        f_sens: FS,
133        t0: Realtype,
134        y0: &[Realtype; N],
135        y_s0: &[[Realtype; N]; N_SENSI],
136        rtol: Realtype,
137        atol: AbsTolerance<N>,
138        atol_sens: SensiAbsTolerance<N, N_SENSI>,
139        user_data: UserData,
140    ) -> Result<Self> {
141        assert_eq!(y0.len(), N);
142        let mem: CvodeMemoryBlockNonNullPtr = {
143            let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int) };
144            check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
145        };
146        let y0 = NVectorSerialHeapAllocated::new_from(y0);
147        let y_s0 = Box::new(
148            array_init::from_iter(
149                y_s0.iter()
150                    .map(|arr| NVectorSerialHeapAllocated::new_from(arr)),
151            )
152            .unwrap(),
153        );
154        let matrix = {
155            let matrix = unsafe {
156                sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap())
157            };
158            check_non_null(matrix, "SUNDenseMatrix")?
159        };
160        let linsolver = {
161            let linsolver = unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
162            check_non_null(linsolver, "SUNDenseLinearSolver")?
163        };
164        let user_data = Box::pin(WrappingUserData {
165            actual_user_data: user_data,
166            f,
167            fs: f_sens,
168        });
169        let res = Solver {
170            mem,
171            y0,
172            y_s0,
173            sunmatrix: matrix.as_ptr(),
174            linsolver: linsolver.as_ptr(),
175            atol,
176            atol_sens,
177            user_data,
178            sensi_out_buffer: array_init::array_init(|_| NVectorSerialHeapAllocated::new()),
179        };
180        {
181            let flag = unsafe {
182                sundials_sys::CVodeSetUserData(
183                    mem.as_raw(),
184                    res.user_data.as_ref().get_ref() as *const _ as _,
185                )
186            };
187            check_flag_is_succes(flag, "CVodeSetUserData")?;
188        }
189        {
190            let fn_ptr = wrap_f::<UserData, F, FS, N> as extern "C" fn(_, _, _, _) -> _;
191            let flag = unsafe {
192                sundials_sys::CVodeInit(
193                    mem.as_raw(),
194                    Some(std::mem::transmute(fn_ptr)),
195                    t0,
196                    res.y0.as_raw(),
197                )
198            };
199            check_flag_is_succes(flag, "CVodeInit")?;
200        }
201        {
202            let fn_ptr = wrap_f_sens::<UserData, F, FS, N, N_SENSI>
203                as extern "C" fn(_, _, _, _, _, _, _, _, _) -> _;
204            let flag = unsafe {
205                sundials_sys::CVodeSensInit(
206                    mem.as_raw(),
207                    N_SENSI as c_int,
208                    CV_STAGGERED as _,
209                    Some(std::mem::transmute(fn_ptr)),
210                    res.y_s0.as_ptr() as _,
211                )
212            };
213            check_flag_is_succes(flag, "CVodeSensInit")?;
214        }
215        match &res.atol {
216            &AbsTolerance::Scalar(atol) => {
217                let flag = unsafe { sundials_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
218                check_flag_is_succes(flag, "CVodeSStolerances")?;
219            }
220            AbsTolerance::Vector(atol) => {
221                let flag =
222                    unsafe { sundials_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
223                check_flag_is_succes(flag, "CVodeSVtolerances")?;
224            }
225        }
226        match &res.atol_sens {
227            SensiAbsTolerance::Scalar(atol) => {
228                let flag = unsafe {
229                    sundials_sys::CVodeSensSStolerances(mem.as_raw(), rtol, atol.as_ptr() as _)
230                };
231                check_flag_is_succes(flag, "CVodeSensSStolerances")?;
232            }
233            SensiAbsTolerance::Vector(atol) => {
234                let flag = unsafe {
235                    sundials_sys::CVodeSensSVtolerances(mem.as_raw(), rtol, atol.as_ptr() as _)
236                };
237                check_flag_is_succes(flag, "CVodeSensSVtolerances")?;
238            }
239        }
240        {
241            let flag = unsafe {
242                sundials_sys::CVodeSetLinearSolver(
243                    mem.as_raw(),
244                    linsolver.as_ptr(),
245                    matrix.as_ptr(),
246                )
247            };
248            check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
249        }
250        Ok(res)
251    }
252
253    /// Takes a step according to `step_kind` (see [`StepKind`]).
254    ///
255    /// Returns a tuple `(t_out,&y(t_out),[&dy_dp(tout)])` where `t_out` is the time
256    /// reached by the solver as dictated by `step_kind`, `y(t_out)` is an
257    /// array of the state variables at that time, and the i-th `dy_dp(tout)` is an array
258    /// of the sensitivities of all variables with respect to parameter i.
259    #[allow(clippy::clippy::type_complexity)]
260    pub fn step(
261        &mut self,
262        tout: Realtype,
263        step_kind: StepKind,
264    ) -> Result<(Realtype, &[Realtype; N], [&[Realtype; N]; N_SENSI])> {
265        let mut tret = 0.;
266        let flag = unsafe {
267            sundials_sys::CVode(
268                self.mem.as_raw(),
269                tout,
270                self.y0.as_raw(),
271                &mut tret,
272                step_kind as c_int,
273            )
274        };
275        check_flag_is_succes(flag, "CVode")?;
276        let flag = unsafe {
277            sundials_sys::CVodeGetSens(
278                self.mem.as_raw(),
279                &mut tret,
280                self.sensi_out_buffer.as_mut_ptr() as _,
281            )
282        };
283        check_flag_is_succes(flag, "CVodeGetSens")?;
284        let sensi_ptr_array =
285            array_init::from_iter(self.sensi_out_buffer.iter().map(|v| v.as_slice())).unwrap();
286        Ok((tret, self.y0.as_slice(), sensi_ptr_array))
287    }
288}
289
290impl<UserData, F, FS, const N: usize, const N_SENSI: usize> Drop
291    for Solver<UserData, F, FS, N, N_SENSI>
292{
293    fn drop(&mut self) {
294        unsafe { sundials_sys::CVodeFree(&mut self.mem.as_raw()) }
295        unsafe { sundials_sys::SUNLinSolFree(self.linsolver) };
296        unsafe { sundials_sys::SUNMatDestroy(self.sunmatrix) };
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Harmonic oscillator: `y0' = y1`, `y1' = -y0`.
    fn f(_t: Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], _data: &()) -> RhsResult {
        *ydot = [y[1], -y[0]];
        RhsResult::Ok
    }

    /// Sensitivity right-hand side that sets every sensitivity derivative to zero.
    fn fs<const N_SENSI: usize>(
        _t: Realtype,
        _y: &[Realtype; 2],
        _ydot: &[Realtype; 2],
        _ys: [&[Realtype; 2]; N_SENSI],
        mut ysdot: [&mut [Realtype; 2]; N_SENSI],
        _data: &(),
    ) -> RhsResult {
        // Iterate by `&mut` instead of the deprecated `std::array::IntoIter::new`.
        for ysdot_i in ysdot.iter_mut() {
            **ysdot_i = [0., 0.];
        }
        RhsResult::Ok
    }

    #[test]
    fn create() {
        let y0 = [0., 1.];
        let y_s0 = [[0.; 2]; 4];
        let _solver = Solver::new(
            LinearMultistepMethod::Adams,
            f,
            fs,
            0.,
            &y0,
            &y_s0,
            1e-4,
            AbsTolerance::scalar(1e-4),
            SensiAbsTolerance::scalar([1e-4; 4]),
            (),
        )
        .unwrap();
    }
}