cvode_wrap/lib.rs
1//! A wrapper around cvode and cvodes from the sundials tool suite.
2//!
3//! Users should be mostly interested in [`SolverSensi`] and [`SolverNoSensi`].
4//!
5//! # Building sundials
6//!
7//! To build sundials, activate the `sundials-sys/build_libraries` feature.
8//!
9//! # Examples
10//!
11//! ## Oscillator
12//!
13//! An oscillatory system defined by `x'' = -k * x`.
14//!
15//! ### Without sensitivities
16//!
17//! ```rust
18//! use cvode_wrap::*;
19//! let y0 = [0., 1.];
20//! //define the right-hand-side
21//! fn f(_t: Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], k: &Realtype) -> RhsResult {
22//! *ydot = [y[1], -y[0] * k];
23//! RhsResult::Ok
24//! }
25//! //initialize the solver
26//! let mut solver = SolverNoSensi::new(
27//! LinearMultistepMethod::Adams,
28//! f,
29//! 0.,
30//! &y0,
31//! 1e-4,
32//! AbsTolerance::scalar(1e-4),
33//! 1e-2,
34//! )
35//! .unwrap();
36//! //and solve
37//! let ts: Vec<_> = (1..100).collect();
38//! println!("0,{},{}", y0[0], y0[1]);
39//! for &t in &ts {
40//! let (_tret, &[x, xdot]) = solver.step(t as _, StepKind::Normal).unwrap();
41//! println!("{},{},{}", t, x, xdot);
42//! }
43//! ```
44//!
45//! ### With sensitivities
46//!
47//! The sensitivities are computed with respect to `x(0)`, `x'(0)` and `k`.
48//!
49//! ```rust
50//! use cvode_wrap::*;
51//! let y0 = [0., 1.];
52//! //define the right-hand-side
53//! fn f(_t: Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], k: &Realtype) -> RhsResult {
54//! *ydot = [y[1], -y[0] * k];
55//! RhsResult::Ok
56//! }
57//! //define the sensitivity function for the right hand side
58//! fn fs(
59//! _t: Realtype,
60//! y: &[Realtype; 2],
61//! _ydot: &[Realtype; 2],
62//! ys: [&[Realtype; 2]; N_SENSI],
63//! ysdot: [&mut [Realtype; 2]; N_SENSI],
64//! k: &Realtype,
65//! ) -> RhsResult {
66//! // Mind that when indexing sensitivities, the first index
67//! // is the parameter index, and the second the state variable
68//! // index
69//! *ysdot[0] = [ys[0][1], -ys[0][0] * k];
70//! *ysdot[1] = [ys[1][1], -ys[1][0] * k];
71//! *ysdot[2] = [ys[2][1], -ys[2][0] * k - y[0]];
72//! RhsResult::Ok
73//! }
74//!
75//! const N_SENSI: usize = 3;
76//!
77//! // the sensitivities in order are d/dy0[0], d/dy0[1] and d/dk
78//! let ys0 = [[1., 0.], [0., 1.], [0., 0.]];
79//!
80//! //initialize the solver
81//! let mut solver = SolverSensi::new(
82//! LinearMultistepMethod::Adams,
83//! f,
84//! fs,
85//! 0.,
86//! &y0,
87//! &ys0,
88//! 1e-4,
89//! AbsTolerance::scalar(1e-4),
90//! SensiAbsTolerance::scalar([1e-4; N_SENSI]),
91//! 1e-2,
92//! )
93//! .unwrap();
94//! //and solve
95//! let ts: Vec<_> = (1..100).collect();
96//! println!("0,{},{}", y0[0], y0[1]);
97//! for &t in &ts {
98//! let (_tret, &[x, xdot], [&[dy0_dy00, dy1_dy00], &[dy0_dy01, dy1_dy01], &[dy0_dk, dy1_dk]]) =
99//! solver.step(t as _, StepKind::Normal).unwrap();
100//! println!(
101//! "{},{},{},{},{},{},{},{},{}",
102//! t, x, xdot, dy0_dy00, dy1_dy00, dy0_dy01, dy1_dy01, dy0_dk, dy1_dk
103//! );
104//! }
105//! ```
106use std::{ffi::c_void, os::raw::c_int, ptr::NonNull};
107
108use sundials_sys::realtype;
109
110mod nvector;
111pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
112
113mod cvode;
114mod cvode_sens;
115
116pub use cvode::Solver as SolverNoSensi;
117pub use cvode_sens::Solver as SolverSensi;
118
119/// The floatting-point type sundials was compiled with
120pub type Realtype = realtype;
121
122#[repr(i32)]
123#[derive(Debug)]
124/// An integration method.
125pub enum LinearMultistepMethod {
126 /// Recomended for non-stiff problems.
127 Adams = sundials_sys::CV_ADAMS,
128 /// Recommended for stiff problems.
129 Bdf = sundials_sys::CV_BDF,
130}
131
/// The value returned by a user-supplied right-hand-side Rust function.
///
/// Adapted from the Sundials CVODE guide version 5.7 (BSD licensed), section 4.6.1:
///
/// > If a recoverable error occurred, `cvode` will attempt to correct,
/// > if the error is unrecoverable, the integration is halted.
/// >
/// > A recoverable failure error return is typically used to flag a value of
/// > the dependent variable y that is “illegal” in some way (e.g., negative where
/// > only a non-negative value is physically meaningful). If such a return is
/// > made, `cvode` will attempt to recover (possibly repeating the nonlinear solve,
/// > or reducing the step size) in order to avoid this recoverable error return.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RhsResult {
    /// Indicates that there was no error.
    Ok,
    /// Indicates that there was a recoverable error, and carries its code.
    RecoverableError(u8),
    /// Indicates that there was a non-recoverable error, and carries its code.
    NonRecoverableError(u8),
}
152
153/// Type of integration step
154#[repr(i32)]
155pub enum StepKind {
156 /// The `NORMAL`option causes the solver to take internal steps
157 /// until it has reached or just passed the user-specified time.
158 /// The solver then interpolates in order to return an approximate
159 /// value of y at the desired time.
160 Normal = sundials_sys::CV_NORMAL,
161 /// The `CV_ONE_STEP` option tells the solver to take just one
162 /// internal step and then return thesolution at the point reached
163 /// by that step.
164 OneStep = sundials_sys::CV_ONE_STEP,
165}
166
/// The error type for this crate.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// propagated with `?` into `Box<dyn std::error::Error>` and similar
/// catch-all error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// A pointer that must not be null was null.
    ///
    /// `func_id` identifies the call that produced the pointer.
    NullPointerError { func_id: &'static str },
    /// A call reported a flag other than success.
    ///
    /// `func_id` identifies the call and `flag` is the raw value it returned.
    ErrorCode { func_id: &'static str, flag: c_int },
}

impl std::fmt::Display for Error {
    // `std::fmt::Result` is spelled out in full because this crate defines
    // its own `Result` alias at module level.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::NullPointerError { func_id } => {
                write!(f, "{} returned a null pointer", func_id)
            }
            Error::ErrorCode { func_id, flag } => {
                write!(f, "{} failed with error code {}", func_id, flag)
            }
        }
    }
}

impl std::error::Error for Error {}
173
174/// An enum representing the choice between a scalar or vector absolute tolerance
175pub enum AbsTolerance<const SIZE: usize> {
176 Scalar(Realtype),
177 Vector(NVectorSerialHeapAllocated<SIZE>),
178}
179
180impl<const SIZE: usize> AbsTolerance<SIZE> {
181 pub fn scalar(atol: Realtype) -> Self {
182 AbsTolerance::Scalar(atol)
183 }
184
185 pub fn vector(atol: &[Realtype; SIZE]) -> Self {
186 let atol = NVectorSerialHeapAllocated::new_from(atol);
187 AbsTolerance::Vector(atol)
188 }
189}
190
191/// An enum representing the choice between scalars or vectors absolute tolerances
192/// for sensitivities.
193pub enum SensiAbsTolerance<const SIZE: usize, const N_SENSI: usize> {
194 Scalar([Realtype; N_SENSI]),
195 Vector([NVectorSerialHeapAllocated<SIZE>; N_SENSI]),
196}
197
198impl<const SIZE: usize, const N_SENSI: usize> SensiAbsTolerance<SIZE, N_SENSI> {
199 pub fn scalar(atol: [Realtype; N_SENSI]) -> Self {
200 SensiAbsTolerance::Scalar(atol)
201 }
202
203 pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI]) -> Self {
204 SensiAbsTolerance::Vector(
205 array_init::from_iter(
206 atol.iter()
207 .map(|arr| NVectorSerialHeapAllocated::new_from(arr)),
208 )
209 .unwrap(),
210 )
211 }
212}
213
214/// A short-hand for `std::result::Result<T, crate::Error>`
215pub type Result<T> = std::result::Result<T, Error>;
216
217fn check_non_null<T>(ptr: *mut T, func_id: &'static str) -> Result<NonNull<T>> {
218 NonNull::new(ptr).ok_or(Error::NullPointerError { func_id })
219}
220
221fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> {
222 if flag == sundials_sys::CV_SUCCESS {
223 Ok(())
224 } else {
225 Err(Error::ErrorCode { flag, func_id })
226 }
227}
228
/// Opaque, zero-sized type used only as the pointee of
/// [`CvodeMemoryBlockNonNullPtr`].
///
/// The private zero-length field keeps it from being constructed outside
/// this module.
#[repr(C)]
struct CvodeMemoryBlock {
    _private: [u8; 0],
}

/// A non-null pointer to a [`CvodeMemoryBlock`].
///
/// `#[repr(transparent)]` gives it the same layout as the inner [`NonNull`].
#[repr(transparent)]
#[derive(Debug, Clone, Copy)]
struct CvodeMemoryBlockNonNullPtr {
    ptr: NonNull<CvodeMemoryBlock>,
}

impl CvodeMemoryBlockNonNullPtr {
    /// Wraps an already-checked non-null pointer.
    fn new(ptr: NonNull<CvodeMemoryBlock>) -> Self {
        CvodeMemoryBlockNonNullPtr { ptr }
    }

    /// Returns the wrapped pointer as an untyped `*mut c_void`.
    fn as_raw(self) -> *mut c_void {
        let typed: *mut CvodeMemoryBlock = self.ptr.as_ptr();
        typed.cast::<c_void>()
    }
}

impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
    /// Same as [`CvodeMemoryBlockNonNullPtr::new`].
    fn from(ptr: NonNull<CvodeMemoryBlock>) -> Self {
        CvodeMemoryBlockNonNullPtr { ptr }
    }
}