Skip to main content

diffsol_c/
host_array.rs

1use crate::{
2    error::DiffsolRtError,
3    scalar_type::{Scalar, ScalarType, ToScalarType},
4};
5use diffsol::{FaerScalar, NalgebraScalar, Vector};
6use ndarray::{ArrayView2, ShapeBuilder};
7use std::any::Any;
8
/// Converts a Rust value into a [`HostArray`] that can be handed to the host language.
///
/// The type parameter lets a single Rust type carry several conversions, e.g. `Vec<f64>`
/// as a 1D array versus `Vec<V>` (with `V: Vector`) as a 2D array of columns.
pub trait ToHostArray<T> {
    /// Consumes `self` and returns a `HostArray` describing its data; the impls in this
    /// module keep the converted value alive as the array's owner instead of copying it
    /// (except `Vec<V: Vector>`, which packs the vectors into a fresh buffer).
    fn to_host_array(self) -> HostArray;
}

/// Builds a Rust value of type `T` from a [`HostArray`] received from the host language.
pub trait FromHostArray<T> {
    /// Converts `array` into `T`; the impls in this module return an error when the
    /// array's rank or element type does not fit `T`.
    fn from_host_array(array: HostArray) -> Result<T, DiffsolRtError>;
}
16
17impl<V> ToHostArray<Vec<V>> for Vec<V>
18where
19    V: Vector,
20    V::T: Scalar + 'static,
21{
22    fn to_host_array(self) -> HostArray {
23        let ncols = self.len();
24        let nrows = self.first().map(|column| column.len()).unwrap_or(0);
25        let mut owner = Vec::with_capacity(nrows * ncols);
26        for column in self {
27            assert_eq!(
28                column.len(),
29                nrows,
30                "all vector columns must have the same length"
31            );
32            for row in 0..nrows {
33                owner.push(column.get_index(row));
34            }
35        }
36        let ptr = owner.as_mut_ptr() as *mut u8;
37        HostArray::new_col_major(ptr, nrows, ncols, 1, nrows as isize, V::T::scalar_type())
38            .with_owner(Box::new(owner))
39    }
40}
41
42impl<T: Scalar + FaerScalar + 'static> ToHostArray<T> for faer::Mat<T> {
43    fn to_host_array(self) -> HostArray {
44        let owner = Box::new(self);
45        let nrows = owner.nrows();
46        let ncols = owner.ncols();
47        let row_stride = owner.row_stride();
48        let col_stride = owner.col_stride();
49        let ptr = owner.as_ptr() as *mut u8;
50        HostArray::new_col_major(ptr, nrows, ncols, row_stride, col_stride, T::scalar_type())
51            .with_owner(owner)
52    }
53}
54
55impl<T: Scalar + NalgebraScalar + 'static> ToHostArray<T> for nalgebra::DMatrix<T> {
56    fn to_host_array(self) -> HostArray {
57        let owner = Box::new(self);
58        let nrows = owner.nrows();
59        let ncols = owner.ncols();
60        let (row_stride, col_stride) = owner.strides();
61        let row_stride = row_stride as isize;
62        let col_stride = col_stride as isize;
63        let ptr = owner.as_ptr() as *mut u8;
64        HostArray::new_col_major(ptr, nrows, ncols, row_stride, col_stride, T::scalar_type())
65            .with_owner(owner)
66    }
67}
68
69impl<T: Scalar + NalgebraScalar + 'static> ToHostArray<T> for nalgebra::DVector<T> {
70    fn to_host_array(self) -> HostArray {
71        let owner = Box::new(self);
72        let len = owner.len();
73        let ptr = owner.as_ptr() as *mut u8;
74        HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
75    }
76}
77
78impl<T: Scalar + FaerScalar + 'static> ToHostArray<T> for faer::Col<T> {
79    fn to_host_array(self) -> HostArray {
80        let owner = Box::new(self);
81        let len = owner.nrows();
82        let ptr = owner.as_ptr() as *mut u8;
83        HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
84    }
85}
86
87impl<T: Scalar + 'static> ToHostArray<T> for Vec<T> {
88    fn to_host_array(self) -> HostArray {
89        let owner = Box::new(self);
90        let len = owner.len();
91        let ptr = owner.as_ptr() as *mut u8;
92        HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
93    }
94}
95
96impl<'h, T: Scalar> FromHostArray<ArrayView2<'h, T>> for ArrayView2<'h, T> {
97    fn from_host_array(array: HostArray) -> Result<Self, DiffsolRtError> {
98        array.as_array()
99    }
100}
101
102impl FromHostArray<Vec<f32>> for Vec<f32> {
103    fn from_host_array(array: HostArray) -> Result<Self, DiffsolRtError> {
104        array.as_slice::<f32>().map(|slice| slice.to_vec())
105    }
106}
107
108impl FromHostArray<Vec<f64>> for Vec<f64> {
109    fn from_host_array(array: HostArray) -> Result<Self, DiffsolRtError> {
110        match array.dtype() {
111            ScalarType::F32 => Ok(array
112                .as_slice::<f32>()?
113                .iter()
114                .map(|&value| value as f64)
115                .collect()),
116            ScalarType::F64 => Ok(array.as_slice::<f64>()?.to_vec()),
117        }
118    }
119}
120
121impl FromHostArray<Vec<Vec<f32>>> for Vec<Vec<f32>> {
122    fn from_host_array(array: HostArray) -> Result<Self, DiffsolRtError> {
123        array.expect_ndim(2)?;
124        let view = array.as_array::<f32>()?;
125        Ok((0..view.nrows())
126            .map(|row| (0..view.ncols()).map(|col| view[(row, col)]).collect())
127            .collect())
128    }
129}
130
131impl FromHostArray<Vec<Vec<f64>>> for Vec<Vec<f64>> {
132    fn from_host_array(array: HostArray) -> Result<Self, DiffsolRtError> {
133        array.expect_ndim(2)?;
134        match array.dtype() {
135            ScalarType::F32 => {
136                let view = array.as_array::<f32>()?;
137                Ok((0..view.nrows())
138                    .map(|row| {
139                        (0..view.ncols())
140                            .map(|col| view[(row, col)] as f64)
141                            .collect()
142                    })
143                    .collect())
144            }
145            ScalarType::F64 => {
146                let view = array.as_array::<f64>()?;
147                Ok((0..view.nrows())
148                    .map(|row| (0..view.ncols()).map(|col| view[(row, col)]).collect())
149                    .collect())
150            }
151        }
152    }
153}
154
/// A read-only, strided array that can be accessed from the host language (e.g. Python)
/// without copying.
///
/// The data either lives in Rust-allocated storage kept alive by `owner` (see the
/// `ToHostArray` impls and `alloc_vector`), or is foreign memory wrapped via
/// `HostArray::new`, in which case whoever created the array must keep that memory valid.
pub struct HostArray {
    /// Element type of the buffer.
    dtype: ScalarType,
    /// Length of each axis.
    shape: Vec<usize>,
    /// Distance between consecutive elements along each axis, in bytes.
    strides: Vec<usize>,
    /// Address of the first element (index 0 on every axis).
    ptr: *mut u8,
    /// Rust value owning the memory behind `ptr`; `None` when the memory is not owned here.
    owner: Option<Box<dyn Any>>,
}
163
164fn scalar_size(dtype: ScalarType) -> usize {
165    match dtype {
166        ScalarType::F32 => std::mem::size_of::<f32>(),
167        ScalarType::F64 => std::mem::size_of::<f64>(),
168    }
169}
170
171impl HostArray {
172    pub fn new(ptr: *mut u8, shape: Vec<usize>, strides: Vec<usize>, dtype: ScalarType) -> Self {
173        Self {
174            ptr,
175            shape,
176            strides,
177            dtype,
178            owner: None,
179        }
180    }
181    pub fn new_vector(ptr: *mut u8, len: usize, dtype: ScalarType) -> Self {
182        let elem_size = scalar_size(dtype);
183        Self {
184            ptr,
185            shape: vec![len],
186            strides: vec![elem_size],
187            dtype,
188            owner: None,
189        }
190    }
191    pub fn alloc_vector(len: usize, dtype: ScalarType) -> Self {
192        match dtype {
193            ScalarType::F32 => {
194                let mut data = vec![0f32; len];
195                let ptr = data.as_mut_ptr() as *mut u8;
196                HostArray::new_vector(ptr, len, dtype).with_owner(Box::new(data))
197            }
198            ScalarType::F64 => {
199                let mut data = vec![0f64; len];
200                let ptr = data.as_mut_ptr() as *mut u8;
201                HostArray::new_vector(ptr, len, dtype).with_owner(Box::new(data))
202            }
203        }
204    }
205    pub fn new_col_major(
206        ptr: *mut u8,
207        rows: usize,
208        cols: usize,
209        row_stride_elems: isize,
210        col_stride_elems: isize,
211        dtype: ScalarType,
212    ) -> Self {
213        let elem_size = scalar_size(dtype);
214        Self {
215            ptr,
216            shape: vec![rows, cols],
217            strides: vec![
218                elem_size * (row_stride_elems as usize),
219                elem_size * (col_stride_elems as usize),
220            ],
221            dtype,
222            owner: None,
223        }
224    }
225    fn with_owner(mut self, owner: Box<dyn Any>) -> Self {
226        self.owner = Some(owner);
227        self
228    }
229    pub(crate) fn data_ptr(&self) -> *const u8 {
230        self.ptr as *const u8
231    }
232    pub(crate) fn ndim(&self) -> usize {
233        self.shape.len()
234    }
235    pub(crate) fn dim(&self, index: usize) -> usize {
236        self.shape.get(index).copied().unwrap_or(0)
237    }
238    pub(crate) fn stride(&self, index: usize) -> usize {
239        self.strides.get(index).copied().unwrap_or(0)
240    }
241    pub(crate) fn dtype(&self) -> ScalarType {
242        self.dtype
243    }
244    fn expect_ndim(&self, expected: usize) -> Result<(), DiffsolRtError> {
245        if self.shape.len() != expected {
246            return Err(DiffsolRtError::from(diffsol::error::DiffsolError::Other(
247                format!("Expected a {expected}D array"),
248            )));
249        }
250        Ok(())
251    }
252    pub fn as_array<'h, T: Scalar>(&self) -> Result<ArrayView2<'h, T>, DiffsolRtError> {
253        self.expect_ndim(2)?;
254        if self.dtype != T::scalar_type() {
255            return Err(DiffsolRtError::from(diffsol::error::DiffsolError::Other(
256                "Data type mismatch".to_string(),
257            )));
258        }
259        let rows = self.shape[0];
260        let cols = self.shape[1];
261        let row_stride_bytes = self.strides[0];
262        let col_stride_bytes = self.strides[1];
263        let row_stride_elems = row_stride_bytes / std::mem::size_of::<T>();
264        let col_stride_elems = col_stride_bytes / std::mem::size_of::<T>();
265        unsafe {
266            Ok(ArrayView2::from_shape_ptr(
267                (rows, cols).strides((row_stride_elems, col_stride_elems)),
268                self.ptr as *const T,
269            ))
270        }
271    }
272    pub fn as_slice<T: Scalar>(&self) -> Result<&[T], DiffsolRtError> {
273        self.expect_ndim(1)?;
274        if self.dtype != T::scalar_type() {
275            return Err(DiffsolRtError::from(diffsol::error::DiffsolError::Other(
276                "Data type mismatch".to_string(),
277            )));
278        }
279        let len = self.shape[0];
280        Ok(unsafe { std::slice::from_raw_parts(self.ptr as *const T, len) })
281    }
282}
283
284impl Drop for HostArray {
285    fn drop(&mut self) {
286        if let Some(owner) = self.owner.take() {
287            drop(owner);
288        }
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::{FromHostArray, HostArray, ScalarType, ToHostArray};

    #[test]
    fn vector_from_host_array_rejects_non_1d_input() {
        // A 2x2 column-major matrix; the rank check fails before any element is read.
        let data = vec![1.0f64, 2.0, 3.0, 4.0];
        let host = HostArray::new_col_major(data.as_ptr() as *mut u8, 2, 2, 1, 2, ScalarType::F64);
        let message = match Vec::<f64>::from_host_array(host) {
            Ok(_) => panic!("a 2D array must not convert into Vec<f64>"),
            Err(err) => err.to_string(),
        };
        assert!(message.contains("Expected a 1D array"));
    }

    #[test]
    fn vector_round_trips_from_1d_host_array() {
        let expected = vec![1.0f64, 2.0, 3.0];
        let host = expected.clone().to_host_array();
        let actual = Vec::<f64>::from_host_array(host).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn matrix_from_host_array_rejects_non_2d_input() {
        let host = vec![1.0f64, 2.0, 3.0].to_host_array();
        let message = match Vec::<Vec<f64>>::from_host_array(host) {
            Ok(_) => panic!("a 1D array must not convert into Vec<Vec<f64>>"),
            Err(err) => err.to_string(),
        };
        assert!(message.contains("Expected a 2D array"));
    }
}