1use nalgebra::base::{SliceStorage, SliceStorageMut};
2use nalgebra::{Dynamic, Matrix};
3use numpy::npyffi;
4use numpy::npyffi::objects::PyArrayObject;
5use pyo3::{types::PyAny, AsPyPointer};
6
/// One dimension of a matrix shape, as known from the nalgebra type.
///
/// `Static(n)` is a dimension fixed to `n` at compile time; `Dynamic` is a
/// dimension whose size is not fixed by the type.
///
/// The derived ordering follows variant order, so every `Static(_)` compares
/// less than `Dynamic`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Dimension {
    Static(usize),
    Dynamic,
}
13
14#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
16pub struct Shape(Dimension, Dimension);
17
18#[derive(Clone, Eq, PartialEq, Debug)]
20pub enum Error {
21 WrongObjectType(WrongObjectTypeError),
23
24 IncompatibleArray(IncompatibleArrayError),
26
27 UnalignedArray(UnalignedArrayError),
29}
30
/// Error returned when the input object is not a numpy array.
///
/// `actual` is the name of the Python type that was received instead.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WrongObjectTypeError {
    pub actual: String,
}
36
37#[derive(Clone, Eq, PartialEq, Debug)]
39pub struct IncompatibleArrayError {
40 pub expected_shape: Shape,
41 pub actual_shape: Vec<usize>,
42 pub expected_dtype: numpy::DataType,
43 pub actual_dtype: String,
44}
45
/// Error returned when the array's memory layout is not suitably aligned for
/// the requested element type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnalignedArrayError;
49
50#[allow(clippy::needless_lifetimes)]
60pub unsafe fn matrix_slice_from_numpy<'a, N, R, C>(
61 _py: pyo3::Python,
62 input: &'a PyAny,
63) -> Result<nalgebra::MatrixSlice<'a, N, R, C, Dynamic, Dynamic>, Error>
64where
65 N: nalgebra::Scalar + numpy::Element,
66 R: nalgebra::Dim,
67 C: nalgebra::Dim,
68{
69 matrix_slice_from_numpy_ptr(input.as_ptr())
70}
71
72#[allow(clippy::needless_lifetimes)]
82pub unsafe fn matrix_slice_mut_from_numpy<'a, N, R, C>(
83 _py: pyo3::Python,
84 input: &'a PyAny,
85) -> Result<nalgebra::MatrixSliceMut<'a, N, R, C, Dynamic, Dynamic>, Error>
86where
87 N: nalgebra::Scalar + numpy::Element,
88 R: nalgebra::Dim,
89 C: nalgebra::Dim,
90{
91 matrix_slice_mut_from_numpy_ptr(input.as_ptr())
92}
93
94pub fn matrix_from_numpy<N, R, C>(py: pyo3::Python, input: &PyAny) -> Result<nalgebra::MatrixMN<N, R, C>, Error>
102where
103 N: nalgebra::Scalar + numpy::Element,
104 R: nalgebra::Dim,
105 C: nalgebra::Dim,
106 nalgebra::base::default_allocator::DefaultAllocator: nalgebra::base::allocator::Allocator<N, R, C>,
107{
108 Ok(unsafe { matrix_slice_from_numpy::<N, R, C>(py, input) }?.into_owned())
109}
110
111#[allow(clippy::missing_safety_doc)]
113pub unsafe fn matrix_slice_from_numpy_ptr<'a, N, R, C>(
114 array: *mut pyo3::ffi::PyObject,
115) -> Result<nalgebra::MatrixSlice<'a, N, R, C, Dynamic, Dynamic>, Error>
116where
117 N: nalgebra::Scalar + numpy::Element,
118 R: nalgebra::Dim,
119 C: nalgebra::Dim,
120{
121 let array = cast_to_py_array(array)?;
122 let shape = check_array_compatible::<N, R, C>(array)?;
123 check_array_alignment(array)?;
124
125 let row_stride = Dynamic::new(*(*array).strides.add(0) as usize / std::mem::size_of::<N>());
126 let col_stride = Dynamic::new(*(*array).strides.add(1) as usize / std::mem::size_of::<N>());
127 let storage = SliceStorage::<N, R, C, Dynamic, Dynamic>::from_raw_parts((*array).data as *const N, shape, (row_stride, col_stride));
128
129 Ok(Matrix::from_data(storage))
130}
131
132#[allow(clippy::missing_safety_doc)]
134pub unsafe fn matrix_slice_mut_from_numpy_ptr<'a, N, R, C>(
135 array: *mut pyo3::ffi::PyObject,
136) -> Result<nalgebra::MatrixSliceMut<'a, N, R, C, Dynamic, Dynamic>, Error>
137where
138 N: nalgebra::Scalar + numpy::Element,
139 R: nalgebra::Dim,
140 C: nalgebra::Dim,
141{
142 let array = cast_to_py_array(array)?;
143 let shape = check_array_compatible::<N, R, C>(array)?;
144 check_array_alignment(array)?;
145
146 let row_stride = Dynamic::new(*(*array).strides.add(0) as usize / std::mem::size_of::<N>());
147 let col_stride = Dynamic::new(*(*array).strides.add(1) as usize / std::mem::size_of::<N>());
148 let storage = SliceStorageMut::<N, R, C, Dynamic, Dynamic>::from_raw_parts((*array).data as *mut N, shape, (row_stride, col_stride));
149
150 Ok(Matrix::from_data(storage))
151}
152
153unsafe fn cast_to_py_array(object: *mut pyo3::ffi::PyObject) -> Result<*mut PyArrayObject, WrongObjectTypeError> {
155 if npyffi::array::PyArray_Check(object) == 1 {
156 Ok(&mut *(object as *mut npyffi::objects::PyArrayObject))
157 } else {
158 Err(WrongObjectTypeError {
159 actual: object_type_string(object),
160 })
161 }
162}
163
164unsafe fn check_array_compatible<N, R, C>(array: *mut PyArrayObject) -> Result<(R, C), IncompatibleArrayError>
166where
167 N: numpy::Element,
168 R: nalgebra::Dim,
169 C: nalgebra::Dim,
170{
171 let make_error = || {
173 let expected_shape = Shape(
174 R::try_to_usize().map(Dimension::Static).unwrap_or(Dimension::Dynamic),
175 C::try_to_usize().map(Dimension::Static).unwrap_or(Dimension::Dynamic),
176 );
177 IncompatibleArrayError {
178 expected_shape,
179 actual_shape: shape(array),
180 expected_dtype: N::DATA_TYPE,
181 actual_dtype: data_type_string(array),
182 }
183 };
184
185 if (*array).nd != 2 {
187 return Err(make_error());
188 }
189
190 let input_rows = *(*array).dimensions.add(0) as usize;
191 let input_cols = *(*array).dimensions.add(1) as usize;
192
193 if R::try_to_usize().map(|expected| input_rows == expected) == Some(false) {
195 return Err(make_error());
196 }
197
198 if C::try_to_usize().map(|expected| input_cols == expected) == Some(false) {
200 return Err(make_error());
201 }
202
203 if npyffi::array::PY_ARRAY_API.PyArray_EquivTypenums((*(*array).descr).type_num, N::ffi_dtype() as u32 as i32) != 1 {
205 return Err(make_error());
206 }
207
208 Ok((R::from_usize(input_rows), C::from_usize(input_cols)))
210}
211
212unsafe fn check_array_alignment(array: *mut PyArrayObject) -> Result<(), UnalignedArrayError> {
213 if (*array).flags & npyffi::flags::NPY_ARRAY_ALIGNED != 0 {
214 Ok(())
215 } else {
216 Err(UnalignedArrayError)
217 }
218}
219
220unsafe fn object_type_string(object: *mut pyo3::ffi::PyObject) -> String {
222 let py_type = (*object).ob_type;
223 let name = (*py_type).tp_name;
224 let name = std::ffi::CStr::from_ptr(name).to_bytes();
225 String::from_utf8_lossy(name).into_owned()
226}
227
228unsafe fn data_type_string(array: *mut PyArrayObject) -> String {
230 let py_name = pyo3::ffi::PyObject_Str((*array).descr as *mut pyo3::ffi::PyObject);
233 if py_name.is_null() {
234 return String::from("<error converting dtype to string>");
235 }
236
237 let mut size = 0isize;
238 let data = pyo3::ffi::PyUnicode_AsUTF8AndSize(py_name, &mut size as *mut isize);
239 if data.is_null() {
240 pyo3::ffi::Py_DecRef(py_name);
241 return String::from("<invalid UTF-8 in dtype>");
242 }
243
244 let name = std::slice::from_raw_parts(data as *mut u8, size as usize);
245 let name = String::from_utf8_unchecked(name.to_vec());
246 pyo3::ffi::Py_DecRef(py_name);
247 name
248}
249
250unsafe fn shape(object: *mut numpy::npyffi::objects::PyArrayObject) -> Vec<usize> {
252 let num_dims = (*object).nd;
253 let dimensions = std::slice::from_raw_parts((*object).dimensions as *const usize, num_dims as usize);
254 dimensions.to_vec()
255}
256
257impl From<WrongObjectTypeError> for Error {
258 fn from(other: WrongObjectTypeError) -> Self {
259 Self::WrongObjectType(other)
260 }
261}
262
263impl From<IncompatibleArrayError> for Error {
264 fn from(other: IncompatibleArrayError) -> Self {
265 Self::IncompatibleArray(other)
266 }
267}
268
269impl From<UnalignedArrayError> for Error {
270 fn from(other: UnalignedArrayError) -> Self {
271 Self::UnalignedArray(other)
272 }
273}
274
275impl std::fmt::Display for Dimension {
276 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
277 match self {
278 Self::Dynamic => write!(f, "Dynamic"),
279 Self::Static(x) => write!(f, "{}", x),
280 }
281 }
282}
283
284impl std::fmt::Display for Shape {
285 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
286 let Self(rows, cols) = self;
287 write!(f, "[{}, {}]", rows, cols)
288 }
289}
290
291impl std::fmt::Display for Error {
292 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
293 match self {
294 Self::WrongObjectType(e) => write!(f, "{}", e),
295 Self::IncompatibleArray(e) => write!(f, "{}", e),
296 Self::UnalignedArray(e) => write!(f, "{}", e),
297 }
298 }
299}
300
301impl std::fmt::Display for WrongObjectTypeError {
302 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
303 write!(f, "wrong object type: expected a numpy.ndarray, found {}", self.actual)
304 }
305}
306
307impl std::fmt::Display for IncompatibleArrayError {
308 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
309 write!(
310 f,
311 "incompatible array: expected ndarray(shape={}, dtype='{}'), found ndarray(shape={:?}, dtype={:?})",
312 self.expected_shape,
313 FormatDataType(&self.expected_dtype),
314 self.actual_shape,
315 self.actual_dtype,
316 )
317 }
318}
319
320impl std::fmt::Display for UnalignedArrayError {
321 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
322 write!(f, "the input array is not properly aligned for this platform")
323 }
324}
325
326struct FormatDataType<'a>(&'a numpy::DataType);
328
329impl std::fmt::Display for FormatDataType<'_> {
330 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
331 let Self(dtype) = self;
332 match dtype {
333 numpy::DataType::Bool => write!(f, "bool"),
334 numpy::DataType::Complex32 => write!(f, "complex32"),
335 numpy::DataType::Complex64 => write!(f, "complex64"),
336 numpy::DataType::Float32 => write!(f, "float32"),
337 numpy::DataType::Float64 => write!(f, "float64"),
338 numpy::DataType::Int8 => write!(f, "int8"),
339 numpy::DataType::Int16 => write!(f, "int16"),
340 numpy::DataType::Int32 => write!(f, "int32"),
341 numpy::DataType::Int64 => write!(f, "int64"),
342 numpy::DataType::Object => write!(f, "object"),
343 numpy::DataType::Uint8 => write!(f, "uint8"),
344 numpy::DataType::Uint16 => write!(f, "uint16"),
345 numpy::DataType::Uint32 => write!(f, "uint32"),
346 numpy::DataType::Uint64 => write!(f, "uint64"),
347 }
348 }
349}
350
351impl std::error::Error for Error {}
352impl std::error::Error for WrongObjectTypeError {}
353impl std::error::Error for IncompatibleArrayError {}