cvode_wrap/
lib.rs

1//! A wrapper around cvode and cvodes from the sundials tool suite.
2//!
3//! Users should be mostly interested in [`SolverSensi`] and [`SolverNoSensi`].
4//!
5//! # Building sundials
6//!
7//! To build sundials, activate the `sundials-sys/build_libraries` feature.
8//!
9//! # Examples
10//!
11//! ## Oscillator
12//!
13//! An oscillatory system defined by `x'' = -k * x`.
14//!
15//! ### Without sensitivities
16//!
17//! ```rust
18//! use cvode_wrap::*;
19//! let y0 = [0., 1.];
20//! //define the right-hand-side
21//! fn f(_t: Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], k: &Realtype) -> RhsResult {
22//!     *ydot = [y[1], -y[0] * k];
23//!     RhsResult::Ok
24//! }
25//! //initialize the solver
26//! let mut solver = SolverNoSensi::new(
27//!     LinearMultistepMethod::Adams,
28//!     f,
29//!     0.,
30//!     &y0,
31//!     1e-4,
32//!     AbsTolerance::scalar(1e-4),
33//!     1e-2,
34//! )
35//! .unwrap();
36//! //and solve
37//! let ts: Vec<_> = (1..100).collect();
38//! println!("0,{},{}", y0[0], y0[1]);
39//! for &t in &ts {
40//!     let (_tret, &[x, xdot]) = solver.step(t as _, StepKind::Normal).unwrap();
41//!     println!("{},{},{}", t, x, xdot);
42//! }
43//! ```
44//!
45//! ### With sensitivities
46//!
47//! The sensitivities are computed with respect to `x(0)`, `x'(0)` and `k`.
48//!
49//! ```rust
50//! use cvode_wrap::*;
51//! let y0 = [0., 1.];
52//! //define the right-hand-side
53//! fn f(_t: Realtype, y: &[Realtype; 2], ydot: &mut [Realtype; 2], k: &Realtype) -> RhsResult {
54//!     *ydot = [y[1], -y[0] * k];
55//!     RhsResult::Ok
56//! }
57//! //define the sensitivity function for the right hand side
58//! fn fs(
59//!     _t: Realtype,
60//!     y: &[Realtype; 2],
61//!     _ydot: &[Realtype; 2],
62//!     ys: [&[Realtype; 2]; N_SENSI],
63//!     ysdot: [&mut [Realtype; 2]; N_SENSI],
64//!     k: &Realtype,
65//! ) -> RhsResult {
66//!     // Mind that when indexing sensitivities, the first index
67//!     // is the parameter index, and the second the state variable
68//!     // index
69//!     *ysdot[0] = [ys[0][1], -ys[0][0] * k];
70//!     *ysdot[1] = [ys[1][1], -ys[1][0] * k];
71//!     *ysdot[2] = [ys[2][1], -ys[2][0] * k - y[0]];
72//!     RhsResult::Ok
73//! }
74//!
75//! const N_SENSI: usize = 3;
76//!
77//! // the sensitivities in order are d/dy0[0], d/dy0[1] and d/dk
78//! let ys0 = [[1., 0.], [0., 1.], [0., 0.]];
79//!
80//! //initialize the solver
81//! let mut solver = SolverSensi::new(
82//!     LinearMultistepMethod::Adams,
83//!     f,
84//!     fs,
85//!     0.,
86//!     &y0,
87//!     &ys0,
88//!     1e-4,
89//!     AbsTolerance::scalar(1e-4),
90//!     SensiAbsTolerance::scalar([1e-4; N_SENSI]),
91//!     1e-2,
92//! )
93//! .unwrap();
94//! //and solve
95//! let ts: Vec<_> = (1..100).collect();
96//! println!("0,{},{}", y0[0], y0[1]);
97//! for &t in &ts {
98//!     let (_tret, &[x, xdot], [&[dy0_dy00, dy1_dy00], &[dy0_dy01, dy1_dy01], &[dy0_dk, dy1_dk]]) =
99//!         solver.step(t as _, StepKind::Normal).unwrap();
100//!     println!(
101//!         "{},{},{},{},{},{},{},{},{}",
102//!         t, x, xdot, dy0_dy00, dy1_dy00, dy0_dy01, dy1_dy01, dy0_dk, dy1_dk
103//!     );
104//! }
105//! ```
106use std::{ffi::c_void, os::raw::c_int, ptr::NonNull};
107
108use sundials_sys::realtype;
109
110mod nvector;
111pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
112
113mod cvode;
114mod cvode_sens;
115
116pub use cvode::Solver as SolverNoSensi;
117pub use cvode_sens::Solver as SolverSensi;
118
119/// The floatting-point type sundials was compiled with
120pub type Realtype = realtype;
121
122#[repr(i32)]
123#[derive(Debug)]
124/// An integration method.
125pub enum LinearMultistepMethod {
126    /// Recomended for non-stiff problems.
127    Adams = sundials_sys::CV_ADAMS,
128    /// Recommended for stiff problems.
129    Bdf = sundials_sys::CV_BDF,
130}
131
/// A return type for the right-hand-side rust function.
///
/// Adapted from the Sundials CVODE guide version 5.7 (BSD Licensed), section 4.6.1:
///
/// > If a recoverable error occurred, `cvode` will attempt to correct it;
/// > if the error is unrecoverable, the integration is halted.
/// >
/// > A recoverable failure error return is typically used to flag a value of
/// > the dependent variable `y` that is “illegal” in some way (e.g., negative where
/// > only a non-negative value is physically meaningful). If such a return is
/// > made, `cvode` will attempt to recover (possibly repeating the nonlinear solve,
/// > or reducing the step size) in order to avoid this recoverable error return.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RhsResult {
    /// Indicates that there was no error.
    Ok,
    /// Indicates that there was a recoverable error, carrying its code.
    RecoverableError(u8),
    /// Indicates that there was a non-recoverable error, carrying its code.
    NonRecoverableError(u8),
}
152
153/// Type of integration step
154#[repr(i32)]
155pub enum StepKind {
156    /// The `NORMAL`option causes the solver to take internal steps
157    /// until it has reached or just passed the user-specified time.
158    /// The solver then interpolates in order to return an approximate
159    /// value of y at the desired time.
160    Normal = sundials_sys::CV_NORMAL,
161    /// The `CV_ONE_STEP` option tells the solver to take just one
162    /// internal step and then return thesolution at the point reached
163    /// by that step.
164    OneStep = sundials_sys::CV_ONE_STEP,
165}
166
/// The error type for this crate.
#[derive(Debug)]
pub enum Error {
    /// The call identified by `func_id` produced a null pointer.
    NullPointerError { func_id: &'static str },
    /// The call identified by `func_id` returned `flag` instead of the success code.
    ErrorCode { func_id: &'static str, flag: c_int },
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::NullPointerError { func_id } => {
                write!(f, "{} returned a null pointer", func_id)
            }
            Error::ErrorCode { func_id, flag } => {
                write!(f, "{} failed with error code {}", func_id, flag)
            }
        }
    }
}

// Lets `Error` be boxed as `Box<dyn std::error::Error>` and propagated with `?` by users.
impl std::error::Error for Error {}
173
174/// An enum representing the choice between a scalar or vector absolute tolerance
175pub enum AbsTolerance<const SIZE: usize> {
176    Scalar(Realtype),
177    Vector(NVectorSerialHeapAllocated<SIZE>),
178}
179
180impl<const SIZE: usize> AbsTolerance<SIZE> {
181    pub fn scalar(atol: Realtype) -> Self {
182        AbsTolerance::Scalar(atol)
183    }
184
185    pub fn vector(atol: &[Realtype; SIZE]) -> Self {
186        let atol = NVectorSerialHeapAllocated::new_from(atol);
187        AbsTolerance::Vector(atol)
188    }
189}
190
191/// An enum representing the choice between scalars or vectors absolute tolerances
192/// for sensitivities.
193pub enum SensiAbsTolerance<const SIZE: usize, const N_SENSI: usize> {
194    Scalar([Realtype; N_SENSI]),
195    Vector([NVectorSerialHeapAllocated<SIZE>; N_SENSI]),
196}
197
198impl<const SIZE: usize, const N_SENSI: usize> SensiAbsTolerance<SIZE, N_SENSI> {
199    pub fn scalar(atol: [Realtype; N_SENSI]) -> Self {
200        SensiAbsTolerance::Scalar(atol)
201    }
202
203    pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI]) -> Self {
204        SensiAbsTolerance::Vector(
205            array_init::from_iter(
206                atol.iter()
207                    .map(|arr| NVectorSerialHeapAllocated::new_from(arr)),
208            )
209            .unwrap(),
210        )
211    }
212}
213
214/// A short-hand for `std::result::Result<T, crate::Error>`
215pub type Result<T> = std::result::Result<T, Error>;
216
217fn check_non_null<T>(ptr: *mut T, func_id: &'static str) -> Result<NonNull<T>> {
218    NonNull::new(ptr).ok_or(Error::NullPointerError { func_id })
219}
220
221fn check_flag_is_succes(flag: c_int, func_id: &'static str) -> Result<()> {
222    if flag == sundials_sys::CV_SUCCESS {
223        Ok(())
224    } else {
225        Err(Error::ErrorCode { flag, func_id })
226    }
227}
228
/// Opaque stand-in for the C-side CVODE memory block.
///
/// In this file it is only referenced through `NonNull<CvodeMemoryBlock>`; its contents are
/// never accessed from Rust. It follows the Rustonomicon pattern for opaque FFI types:
/// the zero-sized `_private` field gives it no Rust-visible layout. The `PhantomData` marker
/// opts the type out of `Send`, `Sync` and `Unpin`, because nothing guarantees that the C
/// object may be shared across threads or moved.
#[repr(C)]
struct CvodeMemoryBlock {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
233
234#[repr(transparent)]
235#[derive(Debug, Clone, Copy)]
236struct CvodeMemoryBlockNonNullPtr {
237    ptr: NonNull<CvodeMemoryBlock>,
238}
239
240impl CvodeMemoryBlockNonNullPtr {
241    fn new(ptr: NonNull<CvodeMemoryBlock>) -> Self {
242        Self { ptr }
243    }
244
245    fn as_raw(self) -> *mut c_void {
246        self.ptr.as_ptr() as *mut c_void
247    }
248}
249
250impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
251    fn from(x: NonNull<CvodeMemoryBlock>) -> Self {
252        Self::new(x)
253    }
254}