1use crate::{
2 error::DiffsolJsError,
3 scalar_type::{Scalar, ScalarType},
4};
5use diffsol::{FaerScalar, NalgebraScalar};
6use ndarray::{ArrayView2, ShapeBuilder};
7use std::any::Any;
8
9pub trait ToHostArray<T> {
10 fn to_host_array(self) -> HostArray;
11}
12
13pub trait FromHostArray<T> {
14 fn from_host_array(array: HostArray) -> Result<T, DiffsolJsError>;
15}
16
17impl<T: Scalar + FaerScalar + 'static> ToHostArray<T> for faer::Mat<T> {
18 fn to_host_array(self) -> HostArray {
19 let owner = Box::new(self);
20 let nrows = owner.nrows();
21 let ncols = owner.ncols();
22 let row_stride = owner.row_stride();
23 let col_stride = owner.col_stride();
24 let ptr = owner.as_ptr() as *mut u8;
25 HostArray::new_col_major(ptr, nrows, ncols, row_stride, col_stride, T::scalar_type())
26 .with_owner(owner)
27 }
28}
29
30impl<T: Scalar + NalgebraScalar + 'static> ToHostArray<T> for nalgebra::DMatrix<T> {
31 fn to_host_array(self) -> HostArray {
32 let owner = Box::new(self);
33 let nrows = owner.nrows();
34 let ncols = owner.ncols();
35 let (row_stride, col_stride) = owner.strides();
36 let row_stride = row_stride as isize;
37 let col_stride = col_stride as isize;
38 let ptr = owner.as_ptr() as *mut u8;
39 HostArray::new_col_major(ptr, nrows, ncols, row_stride, col_stride, T::scalar_type())
40 .with_owner(owner)
41 }
42}
43
44impl<T: Scalar + NalgebraScalar + 'static> ToHostArray<T> for nalgebra::DVector<T> {
45 fn to_host_array(self) -> HostArray {
46 let owner = Box::new(self);
47 let len = owner.len();
48 let ptr = owner.as_ptr() as *mut u8;
49 HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
50 }
51}
52
53impl<T: Scalar + FaerScalar + 'static> ToHostArray<T> for faer::Col<T> {
54 fn to_host_array(self) -> HostArray {
55 let owner = Box::new(self);
56 let len = owner.nrows();
57 let ptr = owner.as_ptr() as *mut u8;
58 HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
59 }
60}
61
62impl<T: Scalar + 'static> ToHostArray<T> for Vec<T> {
63 fn to_host_array(self) -> HostArray {
64 let owner = Box::new(self);
65 let len = owner.len();
66 let ptr = owner.as_ptr() as *mut u8;
67 HostArray::new_vector(ptr, len, T::scalar_type()).with_owner(owner)
68 }
69}
70
71impl<'h, T: Scalar> FromHostArray<ArrayView2<'h, T>> for ArrayView2<'h, T> {
72 fn from_host_array(array: HostArray) -> Result<Self, DiffsolJsError> {
73 array.as_array()
74 }
75}
76
77impl FromHostArray<Vec<f32>> for Vec<f32> {
78 fn from_host_array(array: HostArray) -> Result<Self, DiffsolJsError> {
79 array.as_slice::<f32>().map(|slice| slice.to_vec())
80 }
81}
82
83impl FromHostArray<Vec<f64>> for Vec<f64> {
84 fn from_host_array(array: HostArray) -> Result<Self, DiffsolJsError> {
85 match array.dtype() {
86 ScalarType::F32 => Ok(array
87 .as_slice::<f32>()?
88 .iter()
89 .map(|&value| value as f64)
90 .collect()),
91 ScalarType::F64 => Ok(array.as_slice::<f64>()?.to_vec()),
92 }
93 }
94}
95
96impl FromHostArray<Vec<Vec<f32>>> for Vec<Vec<f32>> {
97 fn from_host_array(array: HostArray) -> Result<Self, DiffsolJsError> {
98 array.expect_ndim(2)?;
99 let view = array.as_array::<f32>()?;
100 Ok((0..view.nrows())
101 .map(|row| (0..view.ncols()).map(|col| view[(row, col)]).collect())
102 .collect())
103 }
104}
105
106impl FromHostArray<Vec<Vec<f64>>> for Vec<Vec<f64>> {
107 fn from_host_array(array: HostArray) -> Result<Self, DiffsolJsError> {
108 array.expect_ndim(2)?;
109 match array.dtype() {
110 ScalarType::F32 => {
111 let view = array.as_array::<f32>()?;
112 Ok((0..view.nrows())
113 .map(|row| {
114 (0..view.ncols())
115 .map(|col| view[(row, col)] as f64)
116 .collect()
117 })
118 .collect())
119 }
120 ScalarType::F64 => {
121 let view = array.as_array::<f64>()?;
122 Ok((0..view.nrows())
123 .map(|row| (0..view.ncols()).map(|col| view[(row, col)]).collect())
124 .collect())
125 }
126 }
127 }
128}
129
130pub struct HostArray {
132 dtype: ScalarType,
133 shape: Vec<usize>,
134 strides: Vec<usize>,
135 ptr: *mut u8,
136 owner: Option<Box<dyn Any>>,
137}
138
139fn scalar_size(dtype: ScalarType) -> usize {
140 match dtype {
141 ScalarType::F32 => std::mem::size_of::<f32>(),
142 ScalarType::F64 => std::mem::size_of::<f64>(),
143 }
144}
145
146impl HostArray {
147 pub fn new(ptr: *mut u8, shape: Vec<usize>, strides: Vec<usize>, dtype: ScalarType) -> Self {
148 Self {
149 ptr,
150 shape,
151 strides,
152 dtype,
153 owner: None,
154 }
155 }
156 pub fn new_vector(ptr: *mut u8, len: usize, dtype: ScalarType) -> Self {
157 let elem_size = scalar_size(dtype);
158 Self {
159 ptr,
160 shape: vec![len],
161 strides: vec![elem_size],
162 dtype,
163 owner: None,
164 }
165 }
166 pub fn alloc_vector(len: usize, dtype: ScalarType) -> Self {
167 match dtype {
168 ScalarType::F32 => {
169 let mut data = vec![0f32; len];
170 let ptr = data.as_mut_ptr() as *mut u8;
171 HostArray::new_vector(ptr, len, dtype).with_owner(Box::new(data))
172 }
173 ScalarType::F64 => {
174 let mut data = vec![0f64; len];
175 let ptr = data.as_mut_ptr() as *mut u8;
176 HostArray::new_vector(ptr, len, dtype).with_owner(Box::new(data))
177 }
178 }
179 }
180 pub fn new_col_major(
181 ptr: *mut u8,
182 rows: usize,
183 cols: usize,
184 row_stride_elems: isize,
185 col_stride_elems: isize,
186 dtype: ScalarType,
187 ) -> Self {
188 let elem_size = scalar_size(dtype);
189 Self {
190 ptr,
191 shape: vec![rows, cols],
192 strides: vec![
193 elem_size * (row_stride_elems as usize),
194 elem_size * (col_stride_elems as usize),
195 ],
196 dtype,
197 owner: None,
198 }
199 }
200 fn with_owner(mut self, owner: Box<dyn Any>) -> Self {
201 self.owner = Some(owner);
202 self
203 }
204 pub(crate) fn data_ptr(&self) -> *const u8 {
205 self.ptr as *const u8
206 }
207 pub(crate) fn ndim(&self) -> usize {
208 self.shape.len()
209 }
210 pub(crate) fn dim(&self, index: usize) -> usize {
211 self.shape.get(index).copied().unwrap_or(0)
212 }
213 pub(crate) fn stride(&self, index: usize) -> usize {
214 self.strides.get(index).copied().unwrap_or(0)
215 }
216 pub(crate) fn dtype(&self) -> ScalarType {
217 self.dtype
218 }
219 fn expect_ndim(&self, expected: usize) -> Result<(), DiffsolJsError> {
220 if self.shape.len() != expected {
221 return Err(DiffsolJsError::from(diffsol::error::DiffsolError::Other(
222 format!("Expected a {expected}D array"),
223 )));
224 }
225 Ok(())
226 }
227 pub fn as_array<'h, T: Scalar>(&self) -> Result<ArrayView2<'h, T>, DiffsolJsError> {
228 self.expect_ndim(2)?;
229 if self.dtype != T::scalar_type() {
230 return Err(DiffsolJsError::from(diffsol::error::DiffsolError::Other(
231 "Data type mismatch".to_string(),
232 )));
233 }
234 let rows = self.shape[0];
235 let cols = self.shape[1];
236 let row_stride_bytes = self.strides[0];
237 let col_stride_bytes = self.strides[1];
238 let row_stride_elems = row_stride_bytes / std::mem::size_of::<T>();
239 let col_stride_elems = col_stride_bytes / std::mem::size_of::<T>();
240 unsafe {
241 Ok(ArrayView2::from_shape_ptr(
242 (rows, cols).strides((row_stride_elems, col_stride_elems)),
243 self.ptr as *const T,
244 ))
245 }
246 }
247 pub fn as_slice<T: Scalar>(&self) -> Result<&[T], DiffsolJsError> {
248 self.expect_ndim(1)?;
249 if self.dtype != T::scalar_type() {
250 return Err(DiffsolJsError::from(diffsol::error::DiffsolError::Other(
251 "Data type mismatch".to_string(),
252 )));
253 }
254 let len = self.shape[0];
255 Ok(unsafe { std::slice::from_raw_parts(self.ptr as *const T, len) })
256 }
257}
258
259impl Drop for HostArray {
260 fn drop(&mut self) {
261 if let Some(owner) = self.owner.take() {
262 drop(owner);
263 }
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::{FromHostArray, HostArray, ScalarType, ToHostArray};

    /// A 2-D array must not convert into a flat vector.
    #[test]
    fn vector_from_host_array_rejects_non_1d_input() {
        // The host array below borrows this buffer without owning it, so the
        // buffer is declared first and outlives it.
        let backing: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let matrix = HostArray::new_col_major(
            backing.as_ptr() as *mut u8,
            2,
            2,
            1,
            2,
            ScalarType::F64,
        );
        let message = Vec::<f64>::from_host_array(matrix)
            .unwrap_err()
            .to_string();
        assert!(message.contains("Expected a 1D array"));
    }

    /// A vector survives the trip through a host array unchanged.
    #[test]
    fn vector_round_trips_from_1d_host_array() {
        let original = vec![1.0f64, 2.0, 3.0];
        let restored = Vec::<f64>::from_host_array(original.to_host_array()).unwrap();
        assert_eq!(restored, vec![1.0, 2.0, 3.0]);
    }

    /// A 1-D array must not convert into nested row vectors.
    #[test]
    fn matrix_from_host_array_rejects_non_2d_input() {
        let flat = vec![1.0f64, 2.0, 3.0].to_host_array();
        let message = Vec::<Vec<f64>>::from_host_array(flat)
            .unwrap_err()
            .to_string();
        assert!(message.contains("Expected a 2D array"));
    }
}