Skip to main content

diffsol_c/
host_array.rs

1use crate::{
2    error::DiffsolJsError,
3    scalar_type::{Scalar, ScalarType},
4};
5use diffsol::{FaerScalar, NalgebraScalar};
6use ndarray::{ArrayView2, ShapeBuilder};
7use std::any::Any;
8
/// Conversion of an owned Rust container into a [`HostArray`].
///
/// In the implementations in this file the type parameter `T` is the scalar
/// element type of the container being converted.
pub trait ToHostArray<T> {
    /// Consumes `self` and returns a [`HostArray`] describing its data.
    fn to_host_array(self) -> HostArray;
}
12
/// Conversion of a [`HostArray`] into a Rust value of type `T`.
///
/// In the implementations in this file `T` is always the implementing type
/// itself (e.g. `FromHostArray<Vec<f64>> for Vec<f64>`).
pub trait FromHostArray<T> {
    /// Consumes `array` and builds a `T` from it.
    ///
    /// # Errors
    ///
    /// Returns a [`DiffsolJsError`] when the conversion is not possible; the
    /// implementations in this file fail on an unexpected number of dimensions
    /// or an incompatible element type.
    fn from_host_array(array: HostArray) -> Result<T, DiffsolJsError>;
}
16
17impl<T: Scalar + FaerScalar + 'static> ToHostArray<T> for faer::Mat<T> {
18    fn to_host_array(self) -> HostArray {
19        let owner = Box::new(self);
20        let nrows = owner.nrows();
21        let ncols = owner.ncols();
22        let row_stride = owner.row_stride();
23        let col_stride = owner.col_stride();
24        let ptr = owner.as_ptr() as *mut u8;
25        HostArray::new_col_major(ptr, nrows, ncols, row_stride, col_stride, T::scalar_type())
26            .with_owner(owner)
27    }
28}
29
30impl<T: Scalar + NalgebraScalar + 'static> ToHostArray<T> for nalgebra::DMatrix<T> {
31    fn to_host_array(self) -> HostArray {
32        let owner = Box::new(self);
33        let nrows = owner.nrows();
34        let ncols = owner.ncols();
35        let (row_stride, col_stride) = owner.strides();
36        let row_stride = row_stride as isize;
37        let col_stride = col_stride as isize;
38        let ptr = owner.as_ptr() as *mut u8;
39        HostArray::new_col_major(ptr, nrows, ncols, row_stride, col_stride, T::scalar_type())
40            .with_owner(owner)
41    }
42}
43
44impl<T: Scalar + NalgebraScalar + 'static> ToHostArray<T> for nalgebra::DVector<T> {
45    fn to_host_array(self) -> HostArray {
46        let owner = Box::new(self);
47        let len = owner.len();
48        let ptr = owner.as_ptr() as *mut u8;
49        HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
50    }
51}
52
53impl<T: Scalar + FaerScalar + 'static> ToHostArray<T> for faer::Col<T> {
54    fn to_host_array(self) -> HostArray {
55        let owner = Box::new(self);
56        let len = owner.nrows();
57        let ptr = owner.as_ptr() as *mut u8;
58        HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
59    }
60}
61
62impl<T: Scalar + 'static> ToHostArray<T> for Vec<T> {
63    fn to_host_array(self) -> HostArray {
64        let owner = Box::new(self);
65        let len = owner.len();
66        let ptr = owner.as_ptr() as *mut u8;
67        HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
68    }
69}
70
impl<'h, T: Scalar> FromHostArray<ArrayView2<'h, T>> for ArrayView2<'h, T> {
    /// Reinterprets `array` as a 2-D view by delegating to
    /// [`HostArray::as_array`]; no data is copied.
    ///
    /// # Errors
    ///
    /// Propagates the error from `as_array` (wrong number of dimensions or
    /// element type mismatch).
    ///
    /// NOTE(review): `array` is taken by value and is dropped when this
    /// function returns, while the lifetime `'h` of the returned view is
    /// chosen by the caller and is not tied to `array`. If `array` carries an
    /// owner, the view appears to outlive the buffer it points into — confirm
    /// that callers only use this with arrays whose memory is kept alive
    /// elsewhere.
    fn from_host_array(array: HostArray) -> Result<Self, DiffsolJsError> {
        array.as_array()
    }
}
76
77impl FromHostArray<Vec<f32>> for Vec<f32> {
78    fn from_host_array(array: HostArray) -> Result<Self, DiffsolJsError> {
79        array.as_slice::<f32>().map(|slice| slice.to_vec())
80    }
81}
82
83impl FromHostArray<Vec<f64>> for Vec<f64> {
84    fn from_host_array(array: HostArray) -> Result<Self, DiffsolJsError> {
85        match array.dtype() {
86            ScalarType::F32 => Ok(array
87                .as_slice::<f32>()?
88                .iter()
89                .map(|&value| value as f64)
90                .collect()),
91            ScalarType::F64 => Ok(array.as_slice::<f64>()?.to_vec()),
92        }
93    }
94}
95
96impl FromHostArray<Vec<Vec<f32>>> for Vec<Vec<f32>> {
97    fn from_host_array(array: HostArray) -> Result<Self, DiffsolJsError> {
98        array.expect_ndim(2)?;
99        let view = array.as_array::<f32>()?;
100        Ok((0..view.nrows())
101            .map(|row| (0..view.ncols()).map(|col| view[(row, col)]).collect())
102            .collect())
103    }
104}
105
106impl FromHostArray<Vec<Vec<f64>>> for Vec<Vec<f64>> {
107    fn from_host_array(array: HostArray) -> Result<Self, DiffsolJsError> {
108        array.expect_ndim(2)?;
109        match array.dtype() {
110            ScalarType::F32 => {
111                let view = array.as_array::<f32>()?;
112                Ok((0..view.nrows())
113                    .map(|row| {
114                        (0..view.ncols())
115                            .map(|col| view[(row, col)] as f64)
116                            .collect()
117                    })
118                    .collect())
119            }
120            ScalarType::F64 => {
121                let view = array.as_array::<f64>()?;
122                Ok((0..view.nrows())
123                    .map(|row| (0..view.ncols()).map(|col| view[(row, col)]).collect())
124                    .collect())
125            }
126        }
127    }
128}
129
/// A read-only array that is allocated in Rust and can be safely accessed in the host language
/// (e.g. Python) without copying.
///
/// The array is described by a raw data pointer plus a shape and per-dimension strides; it may
/// additionally hold the Rust value that owns the underlying memory.
pub struct HostArray {
    /// Element type of the data behind `ptr`.
    dtype: ScalarType,
    /// Extent of each dimension; its length is the number of dimensions.
    shape: Vec<usize>,
    /// Stride of each dimension, in bytes (the constructors multiply element strides by the
    /// element size).
    strides: Vec<usize>,
    /// Pointer to the first element.
    ptr: *mut u8,
    /// Optional boxed value kept alongside the pointer; it is dropped together with the array.
    /// The `to_host_array` conversions store the container that `ptr` points into here.
    owner: Option<Box<dyn Any>>,
}
138
139fn scalar_size(dtype: ScalarType) -> usize {
140    match dtype {
141        ScalarType::F32 => std::mem::size_of::<f32>(),
142        ScalarType::F64 => std::mem::size_of::<f64>(),
143    }
144}
145
146impl HostArray {
147    pub fn new(ptr: *mut u8, shape: Vec<usize>, strides: Vec<usize>, dtype: ScalarType) -> Self {
148        Self {
149            ptr,
150            shape,
151            strides,
152            dtype,
153            owner: None,
154        }
155    }
156    pub fn new_vector(ptr: *mut u8, len: usize, dtype: ScalarType) -> Self {
157        let elem_size = scalar_size(dtype);
158        Self {
159            ptr,
160            shape: vec![len],
161            strides: vec![elem_size],
162            dtype,
163            owner: None,
164        }
165    }
166    pub fn alloc_vector(len: usize, dtype: ScalarType) -> Self {
167        match dtype {
168            ScalarType::F32 => {
169                let mut data = vec![0f32; len];
170                let ptr = data.as_mut_ptr() as *mut u8;
171                HostArray::new_vector(ptr, len, dtype).with_owner(Box::new(data))
172            }
173            ScalarType::F64 => {
174                let mut data = vec![0f64; len];
175                let ptr = data.as_mut_ptr() as *mut u8;
176                HostArray::new_vector(ptr, len, dtype).with_owner(Box::new(data))
177            }
178        }
179    }
180    pub fn new_col_major(
181        ptr: *mut u8,
182        rows: usize,
183        cols: usize,
184        row_stride_elems: isize,
185        col_stride_elems: isize,
186        dtype: ScalarType,
187    ) -> Self {
188        let elem_size = scalar_size(dtype);
189        Self {
190            ptr,
191            shape: vec![rows, cols],
192            strides: vec![
193                elem_size * (row_stride_elems as usize),
194                elem_size * (col_stride_elems as usize),
195            ],
196            dtype,
197            owner: None,
198        }
199    }
200    fn with_owner(mut self, owner: Box<dyn Any>) -> Self {
201        self.owner = Some(owner);
202        self
203    }
204    pub(crate) fn data_ptr(&self) -> *const u8 {
205        self.ptr as *const u8
206    }
207    pub(crate) fn ndim(&self) -> usize {
208        self.shape.len()
209    }
210    pub(crate) fn dim(&self, index: usize) -> usize {
211        self.shape.get(index).copied().unwrap_or(0)
212    }
213    pub(crate) fn stride(&self, index: usize) -> usize {
214        self.strides.get(index).copied().unwrap_or(0)
215    }
216    pub(crate) fn dtype(&self) -> ScalarType {
217        self.dtype
218    }
219    fn expect_ndim(&self, expected: usize) -> Result<(), DiffsolJsError> {
220        if self.shape.len() != expected {
221            return Err(DiffsolJsError::from(diffsol::error::DiffsolError::Other(
222                format!("Expected a {expected}D array"),
223            )));
224        }
225        Ok(())
226    }
227    pub fn as_array<'h, T: Scalar>(&self) -> Result<ArrayView2<'h, T>, DiffsolJsError> {
228        self.expect_ndim(2)?;
229        if self.dtype != T::scalar_type() {
230            return Err(DiffsolJsError::from(diffsol::error::DiffsolError::Other(
231                "Data type mismatch".to_string(),
232            )));
233        }
234        let rows = self.shape[0];
235        let cols = self.shape[1];
236        let row_stride_bytes = self.strides[0];
237        let col_stride_bytes = self.strides[1];
238        let row_stride_elems = row_stride_bytes / std::mem::size_of::<T>();
239        let col_stride_elems = col_stride_bytes / std::mem::size_of::<T>();
240        unsafe {
241            Ok(ArrayView2::from_shape_ptr(
242                (rows, cols).strides((row_stride_elems, col_stride_elems)),
243                self.ptr as *const T,
244            ))
245        }
246    }
247    pub fn as_slice<T: Scalar>(&self) -> Result<&[T], DiffsolJsError> {
248        self.expect_ndim(1)?;
249        if self.dtype != T::scalar_type() {
250            return Err(DiffsolJsError::from(diffsol::error::DiffsolError::Other(
251                "Data type mismatch".to_string(),
252            )));
253        }
254        let len = self.shape[0];
255        Ok(unsafe { std::slice::from_raw_parts(self.ptr as *const T, len) })
256    }
257}
258
259impl Drop for HostArray {
260    fn drop(&mut self) {
261        if let Some(owner) = self.owner.take() {
262            drop(owner);
263        }
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::{FromHostArray, HostArray, ToHostArray};

    /// Converting a 2-D array into a flat `Vec<f64>` must be rejected.
    #[test]
    fn vector_from_host_array_rejects_non_1d_input() {
        // Backing storage for a 2x2 matrix; it outlives `host`, which does not
        // own the buffer.
        let data: Vec<f64> = vec![vec![1.0f64, 2.0], vec![3.0, 4.0]]
            .into_iter()
            .flatten()
            .collect();
        let raw = data.as_ptr() as *mut u8;
        let host = HostArray::new_col_major(raw, 2, 2, 1, 2, super::ScalarType::F64);

        let outcome = Vec::<f64>::from_host_array(host);
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Expected a 1D array"));
    }

    /// A `Vec<f64>` survives the trip through a `HostArray` unchanged.
    #[test]
    fn vector_round_trips_from_1d_host_array() {
        let original = vec![1.0f64, 2.0, 3.0];
        let host = original.to_host_array();
        let restored = Vec::<f64>::from_host_array(host).unwrap();
        assert_eq!(restored, vec![1.0, 2.0, 3.0]);
    }

    /// Converting a 1-D array into nested rows must be rejected.
    #[test]
    fn matrix_from_host_array_rejects_non_2d_input() {
        let host = vec![1.0f64, 2.0, 3.0].to_host_array();
        let outcome = Vec::<Vec<f64>>::from_host_array(host);
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Expected a 2D array"));
    }
}