1use std::{convert::TryInto, os::raw::c_int, pin::Pin};
4
5use sundials_sys::{SUNLinearSolver, SUNMatrix, CV_STAGGERED};
6
7use crate::{
8 check_flag_is_succes, check_non_null, AbsTolerance, CvodeMemoryBlock,
9 CvodeMemoryBlockNonNullPtr, LinearMultistepMethod, NVectorSerial, NVectorSerialHeapAllocated,
10 Realtype, Result, RhsResult, SensiAbsTolerance, StepKind,
11};
12
/// Bundle that `Solver` keeps pinned on the heap and hands to CVODE as its user-data
/// pointer (via `CVodeSetUserData`).
///
/// The right-hand side `f` and the sensitivity right-hand side `fs` are stored next to
/// the caller's data so that the generic `extern "C"` trampolines `wrap_f` and
/// `wrap_f_sens` can recover all three from the single pointer CVODE passes back.
struct WrappingUserData<UserData, F, FS> {
    /// The caller's data, passed by reference to `f` and `fs` on every call.
    actual_user_data: UserData,
    /// Right-hand side of the ODE, invoked through `wrap_f`.
    f: F,
    /// Right-hand side of the sensitivity equations, invoked through `wrap_f_sens`.
    fs: FS,
}
18
/// CVODE solver for an `N`-dimensional ODE with `N_SENSI` forward sensitivity vectors.
///
/// Owns the CVODE memory block, the dense matrix and linear solver attached to it, and
/// the vectors and user data CVODE holds pointers to. The SUNDIALS objects are released
/// in `Drop`; the Rust-owned fields are dropped after that.
pub struct Solver<UserData, F, FS, const N: usize, const N_SENSI: usize> {
    /// CVODE integrator memory, released with `CVodeFree` on drop.
    mem: CvodeMemoryBlockNonNullPtr,
    /// Initial state passed to `CVodeInit`; `step` also uses it as the output buffer
    /// for the current state.
    y0: NVectorSerialHeapAllocated<N>,
    /// Initial sensitivity vectors passed to `CVodeSensInit`.
    y_s0: Box<[NVectorSerialHeapAllocated<N>; N_SENSI]>,
    /// Matrix used by the linear solver, destroyed with `SUNMatDestroy` on drop.
    sunmatrix: SUNMatrix,
    /// Linear solver attached with `CVodeSetLinearSolver`, freed with `SUNLinSolFree` on drop.
    linsolver: SUNLinearSolver,
    /// Absolute tolerance for the state.
    /// NOTE(review): presumably kept alive because the vector variant is handed to
    /// `CVodeSVtolerances` by pointer — confirm whether CVODE copies it.
    atol: AbsTolerance<N>,
    /// Absolute tolerances for the sensitivities (same caveat as `atol`).
    atol_sens: SensiAbsTolerance<N, N_SENSI>,
    /// Right-hand sides and the caller's data. Pinned because a raw pointer to it is
    /// registered with `CVodeSetUserData`, so its address must not change while `mem` lives.
    user_data: Pin<Box<WrappingUserData<UserData, F, FS>>>,
    /// Output buffer filled by `CVodeGetSens` on every `step`.
    sensi_out_buffer: [NVectorSerialHeapAllocated<N>; N_SENSI],
}
44
45extern "C" fn wrap_f<UserData, F, FS, const N: usize>(
46 t: Realtype,
47 y: *const NVectorSerial<N>,
48 ydot: *mut NVectorSerial<N>,
49 data: *const WrappingUserData<UserData, F, FS>,
50) -> c_int
51where
52 F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
53{
54 let y = unsafe { &*y }.as_slice();
55 let ydot = unsafe { &mut *ydot }.as_slice_mut();
56 let WrappingUserData {
57 actual_user_data: data,
58 f,
59 ..
60 } = unsafe { &*data };
61 let res = f(t, y, ydot, data);
62 match res {
63 RhsResult::Ok => 0,
64 RhsResult::RecoverableError(e) => e as c_int,
65 RhsResult::NonRecoverableError(e) => -(e as c_int),
66 }
67}
68
69extern "C" fn wrap_f_sens<UserData, F, FS, const N: usize, const N_SENSI: usize>(
70 _n_s: c_int,
71 t: Realtype,
72 y: *const NVectorSerial<N>,
73 ydot: *const NVectorSerial<N>,
74 y_s: *const [*const NVectorSerial<N>; N_SENSI],
75 y_sdot: *mut [*mut NVectorSerial<N>; N_SENSI],
76 data: *const WrappingUserData<UserData, F, FS>,
77 _tmp1: *const NVectorSerial<N>,
78 _tmp2: *const NVectorSerial<N>,
79) -> c_int
80where
81 FS: Fn(
82 Realtype,
83 &[Realtype; N],
84 &[Realtype; N],
85 [&[Realtype; N]; N_SENSI],
86 [&mut [Realtype; N]; N_SENSI],
87 &UserData,
88 ) -> RhsResult,
89{
90 let y = unsafe { &*y }.as_slice();
91 let ydot = unsafe { &*ydot }.as_slice();
92 let y_s = unsafe { &*y_s };
93 let y_s: [&[Realtype; N]; N_SENSI] =
94 array_init::from_iter(y_s.iter().map(|&v| unsafe { &*v }.as_slice())).unwrap();
95 let y_sdot = unsafe { &mut *y_sdot };
96 let y_sdot: [&mut [Realtype; N]; N_SENSI] = array_init::from_iter(
97 y_sdot
98 .iter_mut()
99 .map(|&mut v| unsafe { &mut *v }.as_slice_mut()),
100 )
101 .unwrap();
102 let WrappingUserData {
103 actual_user_data: data,
104 fs,
105 ..
106 } = unsafe { &*data };
107 let res = fs(t, y, ydot, y_s, y_sdot, data);
108 match res {
109 RhsResult::Ok => 0,
110 RhsResult::RecoverableError(e) => e as c_int,
111 RhsResult::NonRecoverableError(e) => -(e as c_int),
112 }
113}
114
115impl<UserData, F, FS, const N: usize, const N_SENSI: usize> Solver<UserData, F, FS, N, N_SENSI>
116where
117 F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
118 FS: Fn(
119 Realtype,
120 &[Realtype; N],
121 &[Realtype; N],
122 [&[Realtype; N]; N_SENSI],
123 [&mut [Realtype; N]; N_SENSI],
124 &UserData,
125 ) -> RhsResult,
126{
127 #[allow(clippy::clippy::too_many_arguments)]
129 pub fn new(
130 method: LinearMultistepMethod,
131 f: F,
132 f_sens: FS,
133 t0: Realtype,
134 y0: &[Realtype; N],
135 y_s0: &[[Realtype; N]; N_SENSI],
136 rtol: Realtype,
137 atol: AbsTolerance<N>,
138 atol_sens: SensiAbsTolerance<N, N_SENSI>,
139 user_data: UserData,
140 ) -> Result<Self> {
141 assert_eq!(y0.len(), N);
142 let mem: CvodeMemoryBlockNonNullPtr = {
143 let mem_maybenull = unsafe { sundials_sys::CVodeCreate(method as c_int) };
144 check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
145 };
146 let y0 = NVectorSerialHeapAllocated::new_from(y0);
147 let y_s0 = Box::new(
148 array_init::from_iter(
149 y_s0.iter()
150 .map(|arr| NVectorSerialHeapAllocated::new_from(arr)),
151 )
152 .unwrap(),
153 );
154 let matrix = {
155 let matrix = unsafe {
156 sundials_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap())
157 };
158 check_non_null(matrix, "SUNDenseMatrix")?
159 };
160 let linsolver = {
161 let linsolver = unsafe { sundials_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
162 check_non_null(linsolver, "SUNDenseLinearSolver")?
163 };
164 let user_data = Box::pin(WrappingUserData {
165 actual_user_data: user_data,
166 f,
167 fs: f_sens,
168 });
169 let res = Solver {
170 mem,
171 y0,
172 y_s0,
173 sunmatrix: matrix.as_ptr(),
174 linsolver: linsolver.as_ptr(),
175 atol,
176 atol_sens,
177 user_data,
178 sensi_out_buffer: array_init::array_init(|_| NVectorSerialHeapAllocated::new()),
179 };
180 {
181 let flag = unsafe {
182 sundials_sys::CVodeSetUserData(
183 mem.as_raw(),
184 res.user_data.as_ref().get_ref() as *const _ as _,
185 )
186 };
187 check_flag_is_succes(flag, "CVodeSetUserData")?;
188 }
189 {
190 let fn_ptr = wrap_f::<UserData, F, FS, N> as extern "C" fn(_, _, _, _) -> _;
191 let flag = unsafe {
192 sundials_sys::CVodeInit(
193 mem.as_raw(),
194 Some(std::mem::transmute(fn_ptr)),
195 t0,
196 res.y0.as_raw(),
197 )
198 };
199 check_flag_is_succes(flag, "CVodeInit")?;
200 }
201 {
202 let fn_ptr = wrap_f_sens::<UserData, F, FS, N, N_SENSI>
203 as extern "C" fn(_, _, _, _, _, _, _, _, _) -> _;
204 let flag = unsafe {
205 sundials_sys::CVodeSensInit(
206 mem.as_raw(),
207 N_SENSI as c_int,
208 CV_STAGGERED as _,
209 Some(std::mem::transmute(fn_ptr)),
210 res.y_s0.as_ptr() as _,
211 )
212 };
213 check_flag_is_succes(flag, "CVodeSensInit")?;
214 }
215 match &res.atol {
216 &AbsTolerance::Scalar(atol) => {
217 let flag = unsafe { sundials_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
218 check_flag_is_succes(flag, "CVodeSStolerances")?;
219 }
220 AbsTolerance::Vector(atol) => {
221 let flag =
222 unsafe { sundials_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
223 check_flag_is_succes(flag, "CVodeSVtolerances")?;
224 }
225 }
226 match &res.atol_sens {
227 SensiAbsTolerance::Scalar(atol) => {
228 let flag = unsafe {
229 sundials_sys::CVodeSensSStolerances(mem.as_raw(), rtol, atol.as_ptr() as _)
230 };
231 check_flag_is_succes(flag, "CVodeSensSStolerances")?;
232 }
233 SensiAbsTolerance::Vector(atol) => {
234 let flag = unsafe {
235 sundials_sys::CVodeSensSVtolerances(mem.as_raw(), rtol, atol.as_ptr() as _)
236 };
237 check_flag_is_succes(flag, "CVodeSensSVtolerances")?;
238 }
239 }
240 {
241 let flag = unsafe {
242 sundials_sys::CVodeSetLinearSolver(
243 mem.as_raw(),
244 linsolver.as_ptr(),
245 matrix.as_ptr(),
246 )
247 };
248 check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
249 }
250 Ok(res)
251 }
252
253 #[allow(clippy::clippy::type_complexity)]
260 pub fn step(
261 &mut self,
262 tout: Realtype,
263 step_kind: StepKind,
264 ) -> Result<(Realtype, &[Realtype; N], [&[Realtype; N]; N_SENSI])> {
265 let mut tret = 0.;
266 let flag = unsafe {
267 sundials_sys::CVode(
268 self.mem.as_raw(),
269 tout,
270 self.y0.as_raw(),
271 &mut tret,
272 step_kind as c_int,
273 )
274 };
275 check_flag_is_succes(flag, "CVode")?;
276 let flag = unsafe {
277 sundials_sys::CVodeGetSens(
278 self.mem.as_raw(),
279 &mut tret,
280 self.sensi_out_buffer.as_mut_ptr() as _,
281 )
282 };
283 check_flag_is_succes(flag, "CVodeGetSens")?;
284 let sensi_ptr_array =
285 array_init::from_iter(self.sensi_out_buffer.iter().map(|v| v.as_slice())).unwrap();
286 Ok((tret, self.y0.as_slice(), sensi_ptr_array))
287 }
288}
289
290impl<UserData, F, FS, const N: usize, const N_SENSI: usize> Drop
291 for Solver<UserData, F, FS, N, N_SENSI>
292{
293 fn drop(&mut self) {
294 unsafe { sundials_sys::CVodeFree(&mut self.mem.as_raw()) }
295 unsafe { sundials_sys::SUNLinSolFree(self.linsolver) };
296 unsafe { sundials_sys::SUNMatDestroy(self.sunmatrix) };
297 }
298}
299
#[cfg(test)]
mod tests {
    use crate::RhsResult;

    use super::*;

    /// Harmonic oscillator: `y0' = y1`, `y1' = -y0`.
    fn f(_t: Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], _data: &()) -> RhsResult {
        *ydot = [y[1], -y[0]];
        RhsResult::Ok
    }

    /// Sensitivity right-hand side that sets every sensitivity derivative to zero.
    fn fs<const N_SENSI: usize>(
        _t: Realtype,
        _y: &[Realtype; 2],
        _ydot: &[Realtype; 2],
        _ys: [&[Realtype; 2]; N_SENSI],
        mut ysdot: [&mut [Realtype; 2]; N_SENSI],
        _data: &(),
    ) -> RhsResult {
        // Iterate by mutable reference instead of the deprecated `std::array::IntoIter::new`.
        for ysdot_i in ysdot.iter_mut() {
            **ysdot_i = [0., 0.];
        }
        RhsResult::Ok
    }

    /// Constructing a solver with four sensitivity vectors must succeed.
    #[test]
    fn create() {
        let y0 = [0., 1.];
        let y_s0 = [[0.; 2]; 4];
        let _solver = Solver::new(
            LinearMultistepMethod::Adams,
            f,
            fs,
            0.,
            &y0,
            &y_s0,
            1e-4,
            AbsTolerance::scalar(1e-4),
            SensiAbsTolerance::scalar([1e-4; 4]),
            (),
        )
        .unwrap();
    }
}