cvode_wrap/
cvode.rs

//! Wrapper around the SUNDIALS CVODE ODE solver, without sensitivity analysis.
2
3use std::{convert::TryInto, os::raw::c_int, pin::Pin};
4
5use sundials_sys::{SUNLinearSolver, SUNMatrix};
6
7use crate::{
8    check_flag_is_succes, check_non_null, AbsTolerance, CvodeMemoryBlock,
9    CvodeMemoryBlockNonNullPtr, LinearMultistepMethod, NVectorSerial, NVectorSerialHeapAllocated,
10    Realtype, Result, RhsResult, StepKind,
11};
12
/// Bundle registered with CVODE as its opaque user-data pointer.
///
/// `Solver` keeps it pinned on the heap; `wrap_f` recovers it from the raw pointer to call
/// the user's right-hand side together with the user's supplementary arguments.
struct WrappingUserData<UserData, F> {
    /// Supplementary arguments forwarded to `f` on every evaluation.
    actual_user_data: UserData,
    /// The user-supplied right-hand-side function.
    f: F,
}
17
18/// The ODE solver without sensitivities.
19///
20/// # Type Arguments
21///
22/// - `F` is the type of the right-hand side function
23///
24///  - `UserData` is the type of the supplementary arguments for the
25/// right-hand-side. If unused, should be `()`.
26///
27/// - `N` is the "problem size", that is the dimension of the state space.
28pub struct Solver<UserData, F, const N: usize> {
29    mem: CvodeMemoryBlockNonNullPtr,
30    y0: NVectorSerialHeapAllocated<N>,
31    sunmatrix: SUNMatrix,
32    linsolver: SUNLinearSolver,
33    atol: AbsTolerance<N>,
34    user_data: Pin<Box<WrappingUserData<UserData, F>>>,
35}
36
37extern "C" fn wrap_f<UserData, F, const N: usize>(
38    t: Realtype,
39    y: *const NVectorSerial<N>,
40    ydot: *mut NVectorSerial<N>,
41    data: *const WrappingUserData<UserData, F>,
42) -> c_int
43where
44    F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
45{
46    let y = unsafe { &*y }.as_slice();
47    let ydot = unsafe { &mut *ydot }.as_slice_mut();
48    let WrappingUserData {
49        actual_user_data: data,
50        f,
51    } = unsafe { &*data };
52    let res = f(t, y, ydot, data);
53    match res {
54        RhsResult::Ok => 0,
55        RhsResult::RecoverableError(e) => e as c_int,
56        RhsResult::NonRecoverableError(e) => -(e as c_int),
57    }
58}
59
60impl<UserData, F, const N: usize> Solver<UserData, F, N>
61where
62    F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
63{
64    /// Create a new solver.
65    pub fn new(
66        method: LinearMultistepMethod,
67        f: F,
68        t0: Realtype,
69        y0: &[Realtype; N],
70        rtol: Realtype,
71        atol: AbsTolerance<N>,
72        user_data: UserData,
73    ) -> Result<Self> {
74        assert_eq!(y0.len(), N);
75        let mem: CvodeMemoryBlockNonNullPtr = {
76            let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int) };
77            check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
78        };
79        let y0 = NVectorSerialHeapAllocated::new_from(y0);
80        let matrix = {
81            let matrix = unsafe {
82                sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap())
83            };
84            check_non_null(matrix, "SUNDenseMatrix")?
85        };
86        let linsolver = {
87            let linsolver = unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
88            check_non_null(linsolver, "SUNDenseLinearSolver")?
89        };
90        let user_data = Box::pin(WrappingUserData {
91            actual_user_data: user_data,
92            f,
93        });
94        let res = Solver {
95            mem,
96            y0,
97            sunmatrix: matrix.as_ptr(),
98            linsolver: linsolver.as_ptr(),
99            atol,
100            user_data,
101        };
102        {
103            let fn_ptr = wrap_f::<UserData, F, N> as extern "C" fn(_, _, _, _) -> _;
104            let flag = unsafe {
105                sundials_sys::CVodeInit(
106                    mem.as_raw(),
107                    Some(std::mem::transmute(fn_ptr)),
108                    t0,
109                    res.y0.as_raw(),
110                )
111            };
112            check_flag_is_succes(flag, "CVodeInit")?;
113        }
114        match &res.atol {
115            &AbsTolerance::Scalar(atol) => {
116                let flag = unsafe { sundials_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
117                check_flag_is_succes(flag, "CVodeSStolerances")?;
118            }
119            AbsTolerance::Vector(atol) => {
120                let flag =
121                    unsafe { sundials_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
122                check_flag_is_succes(flag, "CVodeSVtolerances")?;
123            }
124        }
125        {
126            let flag = unsafe {
127                sundials_sys::CVodeSetLinearSolver(
128                    mem.as_raw(),
129                    linsolver.as_ptr(),
130                    matrix.as_ptr(),
131                )
132            };
133            check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
134        }
135        {
136            let flag = unsafe {
137                sundials_sys::CVodeSetUserData(
138                    mem.as_raw(),
139                    std::mem::transmute(res.user_data.as_ref().get_ref()),
140                )
141            };
142            check_flag_is_succes(flag, "CVodeSetUserData")?;
143        }
144        Ok(res)
145    }
146
147    /// Takes a step according to `step_kind` (see [`StepKind`]).
148    ///
149    /// Returns a tuple `(t_out,&y(t_out))` where `t_out` is the time
150    /// reached by the solver as dictated by `step_kind`, and `y(t_out)` is an
151    /// array of the state variables at that time.
152    pub fn step(
153        &mut self,
154        tout: Realtype,
155        step_kind: StepKind,
156    ) -> Result<(Realtype, &[Realtype; N])> {
157        let mut tret = 0.;
158        let flag = unsafe {
159            sundials_sys::CVode(
160                self.mem.as_raw(),
161                tout,
162                self.y0.as_raw(),
163                &mut tret,
164                step_kind as c_int,
165            )
166        };
167        check_flag_is_succes(flag, "CVode")?;
168        Ok((tret, self.y0.as_slice()))
169    }
170}
171
172impl<UserData, F, const N: usize> Drop for Solver<UserData, F, N> {
173    fn drop(&mut self) {
174        unsafe { sundials_sys::CVodeFree(&mut self.mem.as_raw()) }
175        unsafe { sundials_sys::SUNLinSolFree(self.linsolver) };
176        unsafe { sundials_sys::SUNMatDestroy(self.sunmatrix) };
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Harmonic oscillator: `y0' = y1`, `y1' = -y0`.
    fn f(_t: Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], _data: &()) -> RhsResult {
        ydot[0] = y[1];
        ydot[1] = -y[0];
        RhsResult::Ok
    }

    /// Building a solver for a two-dimensional problem must succeed.
    #[test]
    fn create() {
        let initial_state = [0., 1.];
        let abs_tol = AbsTolerance::Scalar(1e-4);
        let rel_tol = 1e-4;
        let _solver = Solver::new(
            LinearMultistepMethod::Adams,
            f,
            0.,
            &initial_state,
            rel_tol,
            abs_tol,
            (),
        )
        .unwrap();
    }
}
210}