cvode_wrap/
nvector.rs

1use std::{
2    convert::TryInto,
3    ops::{Deref, DerefMut},
4    ptr::NonNull,
5};
6
7use sundials_sys::realtype;
8
9/// A sundials `N_Vector_Serial`.
10#[repr(transparent)]
11#[derive(Debug)]
12pub struct NVectorSerial<const SIZE: usize> {
13    inner: sundials_sys::_generic_N_Vector,
14}
15
16impl<const SIZE: usize> NVectorSerial<SIZE> {
17    pub(crate) unsafe fn as_raw(&self) -> sundials_sys::N_Vector {
18        std::mem::transmute(&self.inner)
19    }
20
21    /// Returns a reference to the inner slice of the vector.
22    pub fn as_slice(&self) -> &[realtype; SIZE] {
23        unsafe { &*(sundials_sys::N_VGetArrayPointer_Serial(self.as_raw()) as *const [f64; SIZE]) }
24    }
25
26    /// Returns a mutable reference to the inner slice of the vector.
27    pub fn as_slice_mut(&mut self) -> &mut [realtype; SIZE] {
28        unsafe {
29            &mut *(sundials_sys::N_VGetArrayPointer_Serial(self.as_raw()) as *mut [f64; SIZE])
30        }
31    }
32}
33
34#[repr(transparent)]
35#[derive(Debug)]
36/// An owning pointer to a sundials [`NVectorSerial`] on the heap.
37pub struct NVectorSerialHeapAllocated<const SIZE: usize> {
38    inner: NonNull<NVectorSerial<SIZE>>,
39}
40
41impl<const SIZE: usize> Deref for NVectorSerialHeapAllocated<SIZE> {
42    type Target = NVectorSerial<SIZE>;
43
44    fn deref(&self) -> &Self::Target {
45        unsafe { self.inner.as_ref() }
46    }
47}
48
49impl<const SIZE: usize> DerefMut for NVectorSerialHeapAllocated<SIZE> {
50    fn deref_mut(&mut self) -> &mut Self::Target {
51        unsafe { self.inner.as_mut() }
52    }
53}
54
55impl<const SIZE: usize> NVectorSerialHeapAllocated<SIZE> {
56    unsafe fn new_inner_uninitialized() -> NonNull<NVectorSerial<SIZE>> {
57        let raw_c = sundials_sys::N_VNew_Serial(SIZE.try_into().unwrap());
58        NonNull::new(raw_c as *mut NVectorSerial<SIZE>).unwrap()
59    }
60
61    /// Creates a new vector, filled with 0.
62    pub fn new() -> Self {
63        let inner = unsafe {
64            let x = Self::new_inner_uninitialized();
65            let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
66            for off in 0..SIZE {
67                *ptr.add(off) = 0.;
68            }
69            x
70        };
71        Self { inner }
72    }
73
74    /// Creates a new vector, filled with data from `data`.
75    pub fn new_from(data: &[realtype; SIZE]) -> Self {
76        let inner = unsafe {
77            let x = Self::new_inner_uninitialized();
78            let ptr = sundials_sys::N_VGetArrayPointer_Serial(x.as_ref().as_raw());
79            std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, SIZE);
80            x
81        };
82        Self { inner }
83    }
84}
85
86impl<const SIZE: usize> Drop for NVectorSerialHeapAllocated<SIZE> {
87    fn drop(&mut self) {
88        unsafe { sundials_sys::N_VDestroy(self.as_raw()) }
89    }
90}