// nalgebra_numpy/from_numpy.rs

1use nalgebra::base::{SliceStorage, SliceStorageMut};
2use nalgebra::{Dynamic, Matrix};
3use numpy::npyffi;
4use numpy::npyffi::objects::PyArrayObject;
5use pyo3::{types::PyAny, AsPyPointer};
6
/// A single matrix dimension as it was requested at compile time, for use in error messages.
///
/// Variant order matters: the derived ordering sorts every `Static` size before `Dynamic`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Dimension {
	/// The dimension has a fixed size known at compile time.
	Static(usize),

	/// The dimension is only known at runtime.
	Dynamic,
}

/// The requested `(rows, columns)` shape of a matrix, for use in error messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Shape(Dimension, Dimension);

18/// Error that can occur when converting from Python to a nalgebra matrix.
19#[derive(Clone, Eq, PartialEq, Debug)]
20pub enum Error {
21	/// The Python object is not a [`numpy.ndarray`](https://numpy.org/devdocs/reference/arrays.ndarray.html).
22	WrongObjectType(WrongObjectTypeError),
23
24	/// The input array is not compatible with the requested nalgebra matrix.
25	IncompatibleArray(IncompatibleArrayError),
26
27	/// The input array is not properly aligned.
28	UnalignedArray(UnalignedArrayError),
29}
30
/// The Python object is not a [`numpy.ndarray`](https://numpy.org/devdocs/reference/arrays.ndarray.html).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WrongObjectTypeError {
	/// Name of the Python type that was found instead.
	pub actual: String,
}

37/// Error indicating that the input array is not compatible with the requested nalgebra matrix.
38#[derive(Clone, Eq, PartialEq, Debug)]
39pub struct IncompatibleArrayError {
40	pub expected_shape: Shape,
41	pub actual_shape: Vec<usize>,
42	pub expected_dtype: numpy::DataType,
43	pub actual_dtype: String,
44}
45
/// numpy does not flag the input array's data as aligned (`NPY_ARRAY_ALIGNED`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnalignedArrayError;

50/// Create a nalgebra view from a numpy array.
51///
52/// The array dtype must match the output type exactly.
53/// If desired, you can convert the array to the desired type in Python
54/// using [`numpy.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html).
55///
56/// # Safety
57/// This function creates a const slice that references data owned by Python.
58/// The user must ensure that the data is not modified through other pointers or references.
59#[allow(clippy::needless_lifetimes)]
60pub unsafe fn matrix_slice_from_numpy<'a, N, R, C>(
61	_py: pyo3::Python,
62	input: &'a PyAny,
63) -> Result<nalgebra::MatrixSlice<'a, N, R, C, Dynamic, Dynamic>, Error>
64where
65	N: nalgebra::Scalar + numpy::Element,
66	R: nalgebra::Dim,
67	C: nalgebra::Dim,
68{
69	matrix_slice_from_numpy_ptr(input.as_ptr())
70}
71
72/// Create a mutable nalgebra view from a numpy array.
73///
74/// The array dtype must match the output type exactly.
75/// If desired, you can convert the array to the desired type in Python
76/// using [`numpy.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html).
77///
78/// # Safety
79/// This function creates a mutable slice that references data owned by Python.
80/// The user must ensure that no other Rust references to the same data exist.
81#[allow(clippy::needless_lifetimes)]
82pub unsafe fn matrix_slice_mut_from_numpy<'a, N, R, C>(
83	_py: pyo3::Python,
84	input: &'a PyAny,
85) -> Result<nalgebra::MatrixSliceMut<'a, N, R, C, Dynamic, Dynamic>, Error>
86where
87	N: nalgebra::Scalar + numpy::Element,
88	R: nalgebra::Dim,
89	C: nalgebra::Dim,
90{
91	matrix_slice_mut_from_numpy_ptr(input.as_ptr())
92}
93
94/// Create an owning nalgebra matrix from a numpy array.
95///
96/// The data is copied into the matrix.
97///
98/// The array dtype must match the output type exactly.
99/// If desired, you can convert the array to the desired type in Python
100/// using [`numpy.ndarray.astype`](https://numpy.org/devdocs/reference/generated/numpy.ndarray.astype.html).
101pub fn matrix_from_numpy<N, R, C>(py: pyo3::Python, input: &PyAny) -> Result<nalgebra::MatrixMN<N, R, C>, Error>
102where
103	N: nalgebra::Scalar + numpy::Element,
104	R: nalgebra::Dim,
105	C: nalgebra::Dim,
106	nalgebra::base::default_allocator::DefaultAllocator: nalgebra::base::allocator::Allocator<N, R, C>,
107{
108	Ok(unsafe { matrix_slice_from_numpy::<N, R, C>(py, input) }?.into_owned())
109}
110
111/// Same as [`matrix_slice_from_numpy`], but takes a raw [`PyObject`](pyo3::ffi::PyObject) pointer.
112#[allow(clippy::missing_safety_doc)]
113pub unsafe fn matrix_slice_from_numpy_ptr<'a, N, R, C>(
114	array: *mut pyo3::ffi::PyObject,
115) -> Result<nalgebra::MatrixSlice<'a, N, R, C, Dynamic, Dynamic>, Error>
116where
117	N: nalgebra::Scalar + numpy::Element,
118	R: nalgebra::Dim,
119	C: nalgebra::Dim,
120{
121	let array = cast_to_py_array(array)?;
122	let shape = check_array_compatible::<N, R, C>(array)?;
123	check_array_alignment(array)?;
124
125	let row_stride = Dynamic::new(*(*array).strides.add(0) as usize / std::mem::size_of::<N>());
126	let col_stride = Dynamic::new(*(*array).strides.add(1) as usize / std::mem::size_of::<N>());
127	let storage = SliceStorage::<N, R, C, Dynamic, Dynamic>::from_raw_parts((*array).data as *const N, shape, (row_stride, col_stride));
128
129	Ok(Matrix::from_data(storage))
130}
131
132/// Same as [`matrix_slice_mut_from_numpy`], but takes a raw [`PyObject`](pyo3::ffi::PyObject) pointer.
133#[allow(clippy::missing_safety_doc)]
134pub unsafe fn matrix_slice_mut_from_numpy_ptr<'a, N, R, C>(
135	array: *mut pyo3::ffi::PyObject,
136) -> Result<nalgebra::MatrixSliceMut<'a, N, R, C, Dynamic, Dynamic>, Error>
137where
138	N: nalgebra::Scalar + numpy::Element,
139	R: nalgebra::Dim,
140	C: nalgebra::Dim,
141{
142	let array = cast_to_py_array(array)?;
143	let shape = check_array_compatible::<N, R, C>(array)?;
144	check_array_alignment(array)?;
145
146	let row_stride = Dynamic::new(*(*array).strides.add(0) as usize / std::mem::size_of::<N>());
147	let col_stride = Dynamic::new(*(*array).strides.add(1) as usize / std::mem::size_of::<N>());
148	let storage = SliceStorageMut::<N, R, C, Dynamic, Dynamic>::from_raw_parts((*array).data as *mut N, shape, (row_stride, col_stride));
149
150	Ok(Matrix::from_data(storage))
151}
152
153/// Check if an object is numpy array and cast the pointer.
154unsafe fn cast_to_py_array(object: *mut pyo3::ffi::PyObject) -> Result<*mut PyArrayObject, WrongObjectTypeError> {
155	if npyffi::array::PyArray_Check(object) == 1 {
156		Ok(&mut *(object as *mut npyffi::objects::PyArrayObject))
157	} else {
158		Err(WrongObjectTypeError {
159			actual: object_type_string(object),
160		})
161	}
162}
163
164/// Check if a numpy array is compatible and return the runtime shape.
165unsafe fn check_array_compatible<N, R, C>(array: *mut PyArrayObject) -> Result<(R, C), IncompatibleArrayError>
166where
167	N: numpy::Element,
168	R: nalgebra::Dim,
169	C: nalgebra::Dim,
170{
171	// Delay semi-expensive construction of error object using a lambda.
172	let make_error = || {
173		let expected_shape = Shape(
174			R::try_to_usize().map(Dimension::Static).unwrap_or(Dimension::Dynamic),
175			C::try_to_usize().map(Dimension::Static).unwrap_or(Dimension::Dynamic),
176		);
177		IncompatibleArrayError {
178			expected_shape,
179			actual_shape: shape(array),
180			expected_dtype: N::DATA_TYPE,
181			actual_dtype: data_type_string(array),
182		}
183	};
184
185	// Input array must have two dimensions.
186	if (*array).nd != 2 {
187		return Err(make_error());
188	}
189
190	let input_rows = *(*array).dimensions.add(0) as usize;
191	let input_cols = *(*array).dimensions.add(1) as usize;
192
193	// Check number of rows in input array.
194	if R::try_to_usize().map(|expected| input_rows == expected) == Some(false) {
195		return Err(make_error());
196	}
197
198	// Check number of columns in input array.
199	if C::try_to_usize().map(|expected| input_cols == expected) == Some(false) {
200		return Err(make_error());
201	}
202
203	// Check the data type of the input array.
204	if npyffi::array::PY_ARRAY_API.PyArray_EquivTypenums((*(*array).descr).type_num, N::ffi_dtype() as u32 as i32) != 1 {
205		return Err(make_error());
206	}
207
208	// All good.
209	Ok((R::from_usize(input_rows), C::from_usize(input_cols)))
210}
211
212unsafe fn check_array_alignment(array: *mut PyArrayObject) -> Result<(), UnalignedArrayError> {
213	if (*array).flags & npyffi::flags::NPY_ARRAY_ALIGNED != 0 {
214		Ok(())
215	} else {
216		Err(UnalignedArrayError)
217	}
218}
219
220/// Get a string representing the type of a Python object.
221unsafe fn object_type_string(object: *mut pyo3::ffi::PyObject) -> String {
222	let py_type = (*object).ob_type;
223	let name = (*py_type).tp_name;
224	let name = std::ffi::CStr::from_ptr(name).to_bytes();
225	String::from_utf8_lossy(name).into_owned()
226}
227
228/// Get a string representing the data type of a numpy array.
229unsafe fn data_type_string(array: *mut PyArrayObject) -> String {
230	// Convert the dtype to string.
231	// Don't forget to call Py_DecRef in all paths if py_name isn't null.
232	let py_name = pyo3::ffi::PyObject_Str((*array).descr as *mut pyo3::ffi::PyObject);
233	if py_name.is_null() {
234		return String::from("<error converting dtype to string>");
235	}
236
237	let mut size = 0isize;
238	let data = pyo3::ffi::PyUnicode_AsUTF8AndSize(py_name, &mut size as *mut isize);
239	if data.is_null() {
240		pyo3::ffi::Py_DecRef(py_name);
241		return String::from("<invalid UTF-8 in dtype>");
242	}
243
244	let name = std::slice::from_raw_parts(data as *mut u8, size as usize);
245	let name = String::from_utf8_unchecked(name.to_vec());
246	pyo3::ffi::Py_DecRef(py_name);
247	name
248}
249
250/// Get the shape of a numpy array as [`Vec`].
251unsafe fn shape(object: *mut numpy::npyffi::objects::PyArrayObject) -> Vec<usize> {
252	let num_dims = (*object).nd;
253	let dimensions = std::slice::from_raw_parts((*object).dimensions as *const usize, num_dims as usize);
254	dimensions.to_vec()
255}
256
257impl From<WrongObjectTypeError> for Error {
258	fn from(other: WrongObjectTypeError) -> Self {
259		Self::WrongObjectType(other)
260	}
261}
262
263impl From<IncompatibleArrayError> for Error {
264	fn from(other: IncompatibleArrayError) -> Self {
265		Self::IncompatibleArray(other)
266	}
267}
268
269impl From<UnalignedArrayError> for Error {
270	fn from(other: UnalignedArrayError) -> Self {
271		Self::UnalignedArray(other)
272	}
273}
274
275impl std::fmt::Display for Dimension {
276	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
277		match self {
278			Self::Dynamic => write!(f, "Dynamic"),
279			Self::Static(x) => write!(f, "{}", x),
280		}
281	}
282}
283
284impl std::fmt::Display for Shape {
285	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
286		let Self(rows, cols) = self;
287		write!(f, "[{}, {}]", rows, cols)
288	}
289}
290
291impl std::fmt::Display for Error {
292	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
293		match self {
294			Self::WrongObjectType(e) => write!(f, "{}", e),
295			Self::IncompatibleArray(e) => write!(f, "{}", e),
296			Self::UnalignedArray(e) => write!(f, "{}", e),
297		}
298	}
299}
300
301impl std::fmt::Display for WrongObjectTypeError {
302	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
303		write!(f, "wrong object type: expected a numpy.ndarray, found {}", self.actual)
304	}
305}
306
307impl std::fmt::Display for IncompatibleArrayError {
308	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
309		write!(
310			f,
311			"incompatible array: expected ndarray(shape={}, dtype='{}'), found ndarray(shape={:?}, dtype={:?})",
312			self.expected_shape,
313			FormatDataType(&self.expected_dtype),
314			self.actual_shape,
315			self.actual_dtype,
316		)
317	}
318}
319
320impl std::fmt::Display for UnalignedArrayError {
321	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
322		write!(f, "the input array is not properly aligned for this platform")
323	}
324}
325
326/// Helper to format [`numpy::DataType`] more consistently.
327struct FormatDataType<'a>(&'a numpy::DataType);
328
329impl std::fmt::Display for FormatDataType<'_> {
330	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
331		let Self(dtype) = self;
332		match dtype {
333			numpy::DataType::Bool => write!(f, "bool"),
334			numpy::DataType::Complex32 => write!(f, "complex32"),
335			numpy::DataType::Complex64 => write!(f, "complex64"),
336			numpy::DataType::Float32 => write!(f, "float32"),
337			numpy::DataType::Float64 => write!(f, "float64"),
338			numpy::DataType::Int8 => write!(f, "int8"),
339			numpy::DataType::Int16 => write!(f, "int16"),
340			numpy::DataType::Int32 => write!(f, "int32"),
341			numpy::DataType::Int64 => write!(f, "int64"),
342			numpy::DataType::Object => write!(f, "object"),
343			numpy::DataType::Uint8 => write!(f, "uint8"),
344			numpy::DataType::Uint16 => write!(f, "uint16"),
345			numpy::DataType::Uint32 => write!(f, "uint32"),
346			numpy::DataType::Uint64 => write!(f, "uint64"),
347		}
348	}
349}
350
351impl std::error::Error for Error {}
352impl std::error::Error for WrongObjectTypeError {}
353impl std::error::Error for IncompatibleArrayError {}