1use std::{convert::TryInto, os::raw::c_int, pin::Pin};
4
5use sundials_sys::{SUNLinearSolver, SUNMatrix};
6
7use crate::{
8 check_flag_is_succes, check_non_null, AbsTolerance, CvodeMemoryBlock,
9 CvodeMemoryBlockNonNullPtr, LinearMultistepMethod, NVectorSerial, NVectorSerialHeapAllocated,
10 Realtype, Result, RhsResult, StepKind,
11};
12
/// Bundle that is handed to CVODE as its single opaque user-data pointer.
///
/// It pairs the caller's own data with the right-hand-side callback so that
/// `wrap_f` can recover both of them from the one pointer it receives.
struct WrappingUserData<UserData, F> {
    /// The caller-supplied data, passed by reference to `f` on every call.
    actual_user_data: UserData,
    /// The right-hand-side callback invoked by `wrap_f`.
    f: F,
}
17
/// A CVODE-backed ODE solver for a system of `N` equations.
///
/// `UserData` is the type of the value forwarded to the right-hand-side
/// callback `F`. The native handles stored here are released by the `Drop`
/// implementation of this type.
pub struct Solver<UserData, F, const N: usize> {
    /// Handle to the CVODE memory block obtained from `CVodeCreate`.
    mem: CvodeMemoryBlockNonNullPtr,
    /// State vector: seeded with the initial condition in `new`, then passed
    /// to `CVode` as the output vector on every `step`.
    y0: NVectorSerialHeapAllocated<N>,
    /// Dense matrix created with `SUNDenseMatrix` and attached to `linsolver`.
    sunmatrix: SUNMatrix,
    /// Dense linear solver created with `SUNLinSol_Dense` and attached to `mem`.
    linsolver: SUNLinearSolver,
    /// Absolute tolerance; `new` reads it back from here when configuring CVODE.
    // NOTE(review): presumably also kept so a vector tolerance outlives the
    // call that registers it -- confirm whether CVODE copies the vector.
    atol: AbsTolerance<N>,
    /// Callback and user data, boxed and pinned so the address registered
    /// through `CVodeSetUserData` stays valid when the `Solver` itself moves.
    user_data: Pin<Box<WrappingUserData<UserData, F>>>,
}
36
37extern "C" fn wrap_f<UserData, F, const N: usize>(
38 t: Realtype,
39 y: *const NVectorSerial<N>,
40 ydot: *mut NVectorSerial<N>,
41 data: *const WrappingUserData<UserData, F>,
42) -> c_int
43where
44 F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
45{
46 let y = unsafe { &*y }.as_slice();
47 let ydot = unsafe { &mut *ydot }.as_slice_mut();
48 let WrappingUserData {
49 actual_user_data: data,
50 f,
51 } = unsafe { &*data };
52 let res = f(t, y, ydot, data);
53 match res {
54 RhsResult::Ok => 0,
55 RhsResult::RecoverableError(e) => e as c_int,
56 RhsResult::NonRecoverableError(e) => -(e as c_int),
57 }
58}
59
60impl<UserData, F, const N: usize> Solver<UserData, F, N>
61where
62 F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
63{
64 pub fn new(
66 method: LinearMultistepMethod,
67 f: F,
68 t0: Realtype,
69 y0: &[Realtype; N],
70 rtol: Realtype,
71 atol: AbsTolerance<N>,
72 user_data: UserData,
73 ) -> Result<Self> {
74 assert_eq!(y0.len(), N);
75 let mem: CvodeMemoryBlockNonNullPtr = {
76 let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int) };
77 check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
78 };
79 let y0 = NVectorSerialHeapAllocated::new_from(y0);
80 let matrix = {
81 let matrix = unsafe {
82 sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap())
83 };
84 check_non_null(matrix, "SUNDenseMatrix")?
85 };
86 let linsolver = {
87 let linsolver = unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
88 check_non_null(linsolver, "SUNDenseLinearSolver")?
89 };
90 let user_data = Box::pin(WrappingUserData {
91 actual_user_data: user_data,
92 f,
93 });
94 let res = Solver {
95 mem,
96 y0,
97 sunmatrix: matrix.as_ptr(),
98 linsolver: linsolver.as_ptr(),
99 atol,
100 user_data,
101 };
102 {
103 let fn_ptr = wrap_f::<UserData, F, N> as extern "C" fn(_, _, _, _) -> _;
104 let flag = unsafe {
105 sundials_sys::CVodeInit(
106 mem.as_raw(),
107 Some(std::mem::transmute(fn_ptr)),
108 t0,
109 res.y0.as_raw(),
110 )
111 };
112 check_flag_is_succes(flag, "CVodeInit")?;
113 }
114 match &res.atol {
115 &AbsTolerance::Scalar(atol) => {
116 let flag = unsafe { sundials_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
117 check_flag_is_succes(flag, "CVodeSStolerances")?;
118 }
119 AbsTolerance::Vector(atol) => {
120 let flag =
121 unsafe { sundials_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
122 check_flag_is_succes(flag, "CVodeSVtolerances")?;
123 }
124 }
125 {
126 let flag = unsafe {
127 sundials_sys::CVodeSetLinearSolver(
128 mem.as_raw(),
129 linsolver.as_ptr(),
130 matrix.as_ptr(),
131 )
132 };
133 check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
134 }
135 {
136 let flag = unsafe {
137 sundials_sys::CVodeSetUserData(
138 mem.as_raw(),
139 std::mem::transmute(res.user_data.as_ref().get_ref()),
140 )
141 };
142 check_flag_is_succes(flag, "CVodeSetUserData")?;
143 }
144 Ok(res)
145 }
146
147 pub fn step(
153 &mut self,
154 tout: Realtype,
155 step_kind: StepKind,
156 ) -> Result<(Realtype, &[Realtype; N])> {
157 let mut tret = 0.;
158 let flag = unsafe {
159 sundials_sys::CVode(
160 self.mem.as_raw(),
161 tout,
162 self.y0.as_raw(),
163 &mut tret,
164 step_kind as c_int,
165 )
166 };
167 check_flag_is_succes(flag, "CVode")?;
168 Ok((tret, self.y0.as_slice()))
169 }
170}
171
172impl<UserData, F, const N: usize> Drop for Solver<UserData, F, N> {
173 fn drop(&mut self) {
174 unsafe { sundials_sys::CVodeFree(&mut self.mem.as_raw()) }
175 unsafe { sundials_sys::SUNLinSolFree(self.linsolver) };
176 unsafe { sundials_sys::SUNMatDestroy(self.sunmatrix) };
177 }
178}
179
#[cfg(test)]
mod tests {
    use crate::RhsResult;

    use super::*;

    /// Right-hand side of the harmonic oscillator written as a first-order
    /// system: `y0' = y1`, `y1' = -y0`.
    fn f(
        _t: super::Realtype,
        y: &[Realtype; 2],
        ydot: &mut [Realtype; 2],
        _data: &(),
    ) -> RhsResult {
        ydot[0] = y[1];
        ydot[1] = -y[0];
        RhsResult::Ok
    }

    /// Building a solver with scalar tolerances must succeed.
    #[test]
    fn create() {
        let initial_state = [0., 1.];
        let outcome = Solver::new(
            LinearMultistepMethod::Adams,
            f,
            0.,
            &initial_state,
            1e-4,
            AbsTolerance::Scalar(1e-4),
            (),
        );
        let _solver = outcome.unwrap();
    }
}