1use crate::{
2 error::DiffsolRtError,
3 scalar_type::{Scalar, ScalarType, ToScalarType},
4};
5use diffsol::{FaerScalar, NalgebraScalar, Vector};
6use ndarray::{ArrayView2, ShapeBuilder};
7use std::any::Any;
8
9pub trait ToHostArray<T> {
10 fn to_host_array(self) -> HostArray;
11}
12
13pub trait FromHostArray<T> {
14 fn from_host_array(array: HostArray) -> Result<T, DiffsolRtError>;
15}
16
17impl<V> ToHostArray<Vec<V>> for Vec<V>
18where
19 V: Vector,
20 V::T: Scalar + 'static,
21{
22 fn to_host_array(self) -> HostArray {
23 let ncols = self.len();
24 let nrows = self.first().map(|column| column.len()).unwrap_or(0);
25 let mut owner = Vec::with_capacity(nrows * ncols);
26 for column in self {
27 assert_eq!(
28 column.len(),
29 nrows,
30 "all vector columns must have the same length"
31 );
32 for row in 0..nrows {
33 owner.push(column.get_index(row));
34 }
35 }
36 let ptr = owner.as_mut_ptr() as *mut u8;
37 HostArray::new_col_major(ptr, nrows, ncols, 1, nrows as isize, V::T::scalar_type())
38 .with_owner(Box::new(owner))
39 }
40}
41
42impl<T: Scalar + FaerScalar + 'static> ToHostArray<T> for faer::Mat<T> {
43 fn to_host_array(self) -> HostArray {
44 let owner = Box::new(self);
45 let nrows = owner.nrows();
46 let ncols = owner.ncols();
47 let row_stride = owner.row_stride();
48 let col_stride = owner.col_stride();
49 let ptr = owner.as_ptr() as *mut u8;
50 HostArray::new_col_major(ptr, nrows, ncols, row_stride, col_stride, T::scalar_type())
51 .with_owner(owner)
52 }
53}
54
55impl<T: Scalar + NalgebraScalar + 'static> ToHostArray<T> for nalgebra::DMatrix<T> {
56 fn to_host_array(self) -> HostArray {
57 let owner = Box::new(self);
58 let nrows = owner.nrows();
59 let ncols = owner.ncols();
60 let (row_stride, col_stride) = owner.strides();
61 let row_stride = row_stride as isize;
62 let col_stride = col_stride as isize;
63 let ptr = owner.as_ptr() as *mut u8;
64 HostArray::new_col_major(ptr, nrows, ncols, row_stride, col_stride, T::scalar_type())
65 .with_owner(owner)
66 }
67}
68
69impl<T: Scalar + NalgebraScalar + 'static> ToHostArray<T> for nalgebra::DVector<T> {
70 fn to_host_array(self) -> HostArray {
71 let owner = Box::new(self);
72 let len = owner.len();
73 let ptr = owner.as_ptr() as *mut u8;
74 HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
75 }
76}
77
78impl<T: Scalar + FaerScalar + 'static> ToHostArray<T> for faer::Col<T> {
79 fn to_host_array(self) -> HostArray {
80 let owner = Box::new(self);
81 let len = owner.nrows();
82 let ptr = owner.as_ptr() as *mut u8;
83 HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
84 }
85}
86
87impl<T: Scalar + 'static> ToHostArray<T> for Vec<T> {
88 fn to_host_array(self) -> HostArray {
89 let owner = Box::new(self);
90 let len = owner.len();
91 let ptr = owner.as_ptr() as *mut u8;
92 HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
93 }
94}
95
96impl<'h, T: Scalar> FromHostArray<ArrayView2<'h, T>> for ArrayView2<'h, T> {
97 fn from_host_array(array: HostArray) -> Result<Self, DiffsolRtError> {
98 array.as_array()
99 }
100}
101
102impl FromHostArray<Vec<f32>> for Vec<f32> {
103 fn from_host_array(array: HostArray) -> Result<Self, DiffsolRtError> {
104 array.as_slice::<f32>().map(|slice| slice.to_vec())
105 }
106}
107
108impl FromHostArray<Vec<f64>> for Vec<f64> {
109 fn from_host_array(array: HostArray) -> Result<Self, DiffsolRtError> {
110 match array.dtype() {
111 ScalarType::F32 => Ok(array
112 .as_slice::<f32>()?
113 .iter()
114 .map(|&value| value as f64)
115 .collect()),
116 ScalarType::F64 => Ok(array.as_slice::<f64>()?.to_vec()),
117 }
118 }
119}
120
121impl FromHostArray<Vec<Vec<f32>>> for Vec<Vec<f32>> {
122 fn from_host_array(array: HostArray) -> Result<Self, DiffsolRtError> {
123 array.expect_ndim(2)?;
124 let view = array.as_array::<f32>()?;
125 Ok((0..view.nrows())
126 .map(|row| (0..view.ncols()).map(|col| view[(row, col)]).collect())
127 .collect())
128 }
129}
130
131impl FromHostArray<Vec<Vec<f64>>> for Vec<Vec<f64>> {
132 fn from_host_array(array: HostArray) -> Result<Self, DiffsolRtError> {
133 array.expect_ndim(2)?;
134 match array.dtype() {
135 ScalarType::F32 => {
136 let view = array.as_array::<f32>()?;
137 Ok((0..view.nrows())
138 .map(|row| {
139 (0..view.ncols())
140 .map(|col| view[(row, col)] as f64)
141 .collect()
142 })
143 .collect())
144 }
145 ScalarType::F64 => {
146 let view = array.as_array::<f64>()?;
147 Ok((0..view.nrows())
148 .map(|row| (0..view.ncols()).map(|col| view[(row, col)]).collect())
149 .collect())
150 }
151 }
152 }
153}
154
155pub struct HostArray {
157 dtype: ScalarType,
158 shape: Vec<usize>,
159 strides: Vec<usize>,
160 ptr: *mut u8,
161 owner: Option<Box<dyn Any>>,
162}
163
164fn scalar_size(dtype: ScalarType) -> usize {
165 match dtype {
166 ScalarType::F32 => std::mem::size_of::<f32>(),
167 ScalarType::F64 => std::mem::size_of::<f64>(),
168 }
169}
170
171impl HostArray {
172 pub fn new(ptr: *mut u8, shape: Vec<usize>, strides: Vec<usize>, dtype: ScalarType) -> Self {
173 Self {
174 ptr,
175 shape,
176 strides,
177 dtype,
178 owner: None,
179 }
180 }
181 pub fn new_vector(ptr: *mut u8, len: usize, dtype: ScalarType) -> Self {
182 let elem_size = scalar_size(dtype);
183 Self {
184 ptr,
185 shape: vec![len],
186 strides: vec![elem_size],
187 dtype,
188 owner: None,
189 }
190 }
191 pub fn alloc_vector(len: usize, dtype: ScalarType) -> Self {
192 match dtype {
193 ScalarType::F32 => {
194 let mut data = vec![0f32; len];
195 let ptr = data.as_mut_ptr() as *mut u8;
196 HostArray::new_vector(ptr, len, dtype).with_owner(Box::new(data))
197 }
198 ScalarType::F64 => {
199 let mut data = vec![0f64; len];
200 let ptr = data.as_mut_ptr() as *mut u8;
201 HostArray::new_vector(ptr, len, dtype).with_owner(Box::new(data))
202 }
203 }
204 }
205 pub fn new_col_major(
206 ptr: *mut u8,
207 rows: usize,
208 cols: usize,
209 row_stride_elems: isize,
210 col_stride_elems: isize,
211 dtype: ScalarType,
212 ) -> Self {
213 let elem_size = scalar_size(dtype);
214 Self {
215 ptr,
216 shape: vec![rows, cols],
217 strides: vec![
218 elem_size * (row_stride_elems as usize),
219 elem_size * (col_stride_elems as usize),
220 ],
221 dtype,
222 owner: None,
223 }
224 }
225 fn with_owner(mut self, owner: Box<dyn Any>) -> Self {
226 self.owner = Some(owner);
227 self
228 }
229 pub(crate) fn data_ptr(&self) -> *const u8 {
230 self.ptr as *const u8
231 }
232 pub(crate) fn ndim(&self) -> usize {
233 self.shape.len()
234 }
235 pub(crate) fn dim(&self, index: usize) -> usize {
236 self.shape.get(index).copied().unwrap_or(0)
237 }
238 pub(crate) fn stride(&self, index: usize) -> usize {
239 self.strides.get(index).copied().unwrap_or(0)
240 }
241 pub(crate) fn dtype(&self) -> ScalarType {
242 self.dtype
243 }
244 fn expect_ndim(&self, expected: usize) -> Result<(), DiffsolRtError> {
245 if self.shape.len() != expected {
246 return Err(DiffsolRtError::from(diffsol::error::DiffsolError::Other(
247 format!("Expected a {expected}D array"),
248 )));
249 }
250 Ok(())
251 }
252 pub fn as_array<'h, T: Scalar>(&self) -> Result<ArrayView2<'h, T>, DiffsolRtError> {
253 self.expect_ndim(2)?;
254 if self.dtype != T::scalar_type() {
255 return Err(DiffsolRtError::from(diffsol::error::DiffsolError::Other(
256 "Data type mismatch".to_string(),
257 )));
258 }
259 let rows = self.shape[0];
260 let cols = self.shape[1];
261 let row_stride_bytes = self.strides[0];
262 let col_stride_bytes = self.strides[1];
263 let row_stride_elems = row_stride_bytes / std::mem::size_of::<T>();
264 let col_stride_elems = col_stride_bytes / std::mem::size_of::<T>();
265 unsafe {
266 Ok(ArrayView2::from_shape_ptr(
267 (rows, cols).strides((row_stride_elems, col_stride_elems)),
268 self.ptr as *const T,
269 ))
270 }
271 }
272 pub fn as_slice<T: Scalar>(&self) -> Result<&[T], DiffsolRtError> {
273 self.expect_ndim(1)?;
274 if self.dtype != T::scalar_type() {
275 return Err(DiffsolRtError::from(diffsol::error::DiffsolError::Other(
276 "Data type mismatch".to_string(),
277 )));
278 }
279 let len = self.shape[0];
280 Ok(unsafe { std::slice::from_raw_parts(self.ptr as *const T, len) })
281 }
282}
283
284impl Drop for HostArray {
285 fn drop(&mut self) {
286 if let Some(owner) = self.owner.take() {
287 drop(owner);
288 }
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::{FromHostArray, HostArray, ScalarType, ToHostArray};

    /// A 2x2 matrix must not be accepted where a 1D vector is expected.
    #[test]
    fn vector_from_host_array_rejects_non_1d_input() {
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let host =
            HostArray::new_col_major(data.as_ptr() as *mut u8, 2, 2, 1, 2, ScalarType::F64);
        let message = Vec::<f64>::from_host_array(host).unwrap_err().to_string();
        assert!(message.contains("Expected a 1D array"));
    }

    /// Converting a `Vec<f64>` to a host array and back yields the same values.
    #[test]
    fn vector_round_trips_from_1d_host_array() {
        let original = vec![1.0f64, 2.0, 3.0];
        let host = original.to_host_array();
        let restored = Vec::<f64>::from_host_array(host).unwrap();
        assert_eq!(restored, vec![1.0, 2.0, 3.0]);
    }

    /// A 1D vector must not be accepted where a matrix is expected.
    #[test]
    fn matrix_from_host_array_rejects_non_2d_input() {
        let host = vec![1.0f64, 2.0, 3.0].to_host_array();
        let message = Vec::<Vec<f64>>::from_host_array(host)
            .unwrap_err()
            .to_string();
        assert!(message.contains("Expected a 2D array"));
    }
}