mod shared;
use std::any::type_name;
use std::fmt;
use std::ops::Deref;
use ndarray::{
ArrayView, ArrayViewMut, Dimension, IntoDimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn,
};
use pyo3::{Borrowed, Bound, CastError, FromPyObject, PyAny, PyResult};
use crate::array::{PyArray, PyArrayMethods};
use crate::convert::NpyIndex;
use crate::dtype::Element;
use crate::error::{BorrowError, NotContiguousError};
use crate::npyffi::flags;
use crate::untyped_array::PyUntypedArrayMethods;
use shared::{acquire, acquire_mut, release, release_mut};
/// A shared (read-only) borrow of a NumPy array.
///
/// Creating one (`try_new`, `clone`) registers a shared borrow with the crate's
/// borrow bookkeeping, and dropping it releases that borrow again. Holding this
/// registration is what lets the accessors below hand out views of the element data.
// `repr(transparent)` keeps the layout identical to the single wrapped field;
// `PyReadwriteArray`'s `Deref` impl casts a pointer to that type into a pointer
// to this one, which relies on both wrappers having the same layout.
#[repr(transparent)]
pub struct PyReadonlyArray<'py, T, D>
where
T: Element,
D: Dimension,
{
// The borrowed array object.
array: Bound<'py, PyArray<T, D>>,
}
/// Shared borrow of a zero-dimensional array.
pub type PyReadonlyArray0<'py, T> = PyReadonlyArray<'py, T, Ix0>;
/// Shared borrow of a one-dimensional array.
pub type PyReadonlyArray1<'py, T> = PyReadonlyArray<'py, T, Ix1>;
/// Shared borrow of a two-dimensional array.
pub type PyReadonlyArray2<'py, T> = PyReadonlyArray<'py, T, Ix2>;
/// Shared borrow of a three-dimensional array.
pub type PyReadonlyArray3<'py, T> = PyReadonlyArray<'py, T, Ix3>;
/// Shared borrow of a four-dimensional array.
pub type PyReadonlyArray4<'py, T> = PyReadonlyArray<'py, T, Ix4>;
/// Shared borrow of a five-dimensional array.
pub type PyReadonlyArray5<'py, T> = PyReadonlyArray<'py, T, Ix5>;
/// Shared borrow of a six-dimensional array.
pub type PyReadonlyArray6<'py, T> = PyReadonlyArray<'py, T, Ix6>;
/// Shared borrow of an array whose dimensionality is only known at runtime.
pub type PyReadonlyArrayDyn<'py, T> = PyReadonlyArray<'py, T, IxDyn>;
impl<'py, T, D> Deref for PyReadonlyArray<'py, T, D>
where
    T: Element,
    D: Dimension,
{
    type Target = Bound<'py, PyArray<T, D>>;

    /// Exposes the wrapped array handle, so the array's own methods can be
    /// called directly on the borrow.
    fn deref(&self) -> &Self::Target {
        let Self { array } = self;
        array
    }
}
impl<'a, 'py, T: Element + 'a, D: Dimension + 'a> FromPyObject<'a, 'py>
    for PyReadonlyArray<'py, T, D>
{
    type Error = CastError<'a, 'py>;

    /// Downcasts `obj` to a `PyArray<T, D>` and takes a shared borrow of it.
    ///
    /// Only the downcast is reported through `Self::Error`. `readonly` returns
    /// the borrow directly, so borrow conflicts are not surfaced as an `Err` here.
    fn extract(obj: Borrowed<'a, 'py, PyAny>) -> Result<Self, Self::Error> {
        obj.cast::<PyArray<T, D>>().map(|array| array.readonly())
    }
}
impl<'py, T, D> PyReadonlyArray<'py, T, D>
where
    T: Element,
    D: Dimension,
{
    /// Registers a new shared borrow of `array` and wraps it.
    ///
    /// # Errors
    ///
    /// Returns the `BorrowError` reported by the borrow bookkeeping when a
    /// conflicting borrow prevents taking a shared one.
    pub(crate) fn try_new(array: Bound<'py, PyArray<T, D>>) -> Result<Self, BorrowError> {
        let (py, raw) = (array.py(), array.as_array_ptr());
        acquire(py, raw)?;
        Ok(Self { array })
    }

    /// Returns an `ndarray` view of the borrowed elements.
    #[inline(always)]
    pub fn as_array(&self) -> ArrayView<'_, T, D> {
        // SAFETY: this value holds a registered shared borrow for as long as the
        // returned view (tied to `&self`) lives, so the borrow bookkeeping rules
        // out a concurrent exclusive borrow of the same data.
        unsafe { self.array.as_array() }
    }

    /// Returns the borrowed elements as one slice.
    ///
    /// # Errors
    ///
    /// Returns `NotContiguousError` when the array cannot be exposed as a
    /// single contiguous slice.
    #[inline(always)]
    pub fn as_slice(&self) -> Result<&[T], NotContiguousError> {
        // SAFETY: same reasoning as `as_array`; the shared borrow outlives `&self`.
        unsafe { self.array.as_slice() }
    }

    /// Returns a reference to the element at `index`. The result is `None` when
    /// the underlying lookup fails (presumably an out-of-bounds index).
    #[inline(always)]
    pub fn get<I>(&self, index: I) -> Option<&T>
    where
        I: NpyIndex<Dim = D>,
    {
        // SAFETY: same reasoning as `as_array`.
        unsafe { self.array.get(index) }
    }
}
#[cfg(feature = "nalgebra")]
impl<'py, N, D> PyReadonlyArray<'py, N, D>
where
N: nalgebra::Scalar + Element,
D: Dimension,
{
/// Tries to view the borrowed array as a [`nalgebra::MatrixView`] with the
/// caller-chosen dimensions `R`/`C` and strides `RStride`/`CStride`. These are
/// usually picked by type inference on the result.
///
/// Returns `None` when the array-level `try_as_matrix` does. Presumably that
/// happens when the array's shape or strides do not fit the requested matrix type.
#[doc(alias = "nalgebra")]
pub fn try_as_matrix<R, C, RStride, CStride>(
&self,
) -> Option<nalgebra::MatrixView<'_, N, R, C, RStride, CStride>>
where
R: nalgebra::Dim,
C: nalgebra::Dim,
RStride: nalgebra::Dim,
CStride: nalgebra::Dim,
{
// SAFETY: this value holds a registered shared borrow for as long as the
// returned view (tied to `&self`) is alive.
unsafe { self.array.try_as_matrix() }
}
}
#[cfg(feature = "nalgebra")]
impl<'py, N> PyReadonlyArray<'py, N, Ix1>
where
    N: nalgebra::Scalar + Element,
{
    /// Views this one-dimensional array as a [`nalgebra::DMatrixView`] with
    /// dynamic shape and dynamic strides.
    ///
    /// # Panics
    ///
    /// Panics if the array's layout cannot be expressed as such a view, that is,
    /// if [`try_as_matrix`](PyReadonlyArray::try_as_matrix) returns `None`.
    #[doc(alias = "nalgebra")]
    pub fn as_matrix(&self) -> nalgebra::DMatrixView<'_, N, nalgebra::Dyn, nalgebra::Dyn> {
        self.try_as_matrix()
            .expect("array layout cannot be viewed as a dynamically-shaped nalgebra matrix")
    }
}
#[cfg(feature = "nalgebra")]
impl<'py, N> PyReadonlyArray<'py, N, Ix2>
where
    N: nalgebra::Scalar + Element,
{
    /// Views this two-dimensional array as a [`nalgebra::DMatrixView`] with
    /// dynamic shape and dynamic strides.
    ///
    /// # Panics
    ///
    /// Panics if the array's layout cannot be expressed as such a view, that is,
    /// if [`try_as_matrix`](PyReadonlyArray::try_as_matrix) returns `None`.
    #[doc(alias = "nalgebra")]
    pub fn as_matrix(&self) -> nalgebra::DMatrixView<'_, N, nalgebra::Dyn, nalgebra::Dyn> {
        self.try_as_matrix()
            .expect("array layout cannot be viewed as a dynamically-shaped nalgebra matrix")
    }
}
impl<'py, T, D> Clone for PyReadonlyArray<'py, T, D>
where
    T: Element,
    D: Dimension,
{
    /// Creates another shared borrow of the same array.
    ///
    /// The new value registers its own shared borrow, so each copy releases
    /// independently when dropped.
    ///
    /// # Panics
    ///
    /// Panics if the additional shared borrow cannot be acquired. One way this
    /// happens is calling `clone` on a `PyReadwriteArray`: it resolves here
    /// through `Deref`, and the exclusive borrow it holds forbids any shared borrow.
    fn clone(&self) -> Self {
        acquire(self.array.py(), self.array.as_array_ptr())
            .expect("failed to acquire an additional shared borrow of the array");
        Self {
            array: self.array.clone(),
        }
    }
}
impl<'py, T, D> Drop for PyReadonlyArray<'py, T, D>
where
    T: Element,
    D: Dimension,
{
    /// Returns the shared borrow registered by `try_new` or `clone` to the
    /// borrow bookkeeping.
    fn drop(&mut self) {
        let py = self.array.py();
        let raw = self.array.as_array_ptr();
        release(py, raw);
    }
}
impl<'py, T, D> fmt::Debug for PyReadonlyArray<'py, T, D>
where
    T: Element,
    D: Dimension,
{
    /// Prints the wrapper type with its element and dimension type names.
    /// The array contents are not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let elem = type_name::<T>();
        let dim = type_name::<D>();
        f.debug_struct(&format!("PyReadonlyArray<{elem}, {dim}>"))
            .finish()
    }
}
/// An exclusive (read-write) borrow of a NumPy array.
///
/// Creating one (`try_new`) registers an exclusive borrow with the crate's borrow
/// bookkeeping, and dropping it releases that borrow again.
// `repr(transparent)` is load-bearing: this type's `Deref` impl reinterprets
// `&PyReadwriteArray` as `&PyReadonlyArray`. That is only valid because both are
// transparent wrappers around the same single field.
#[repr(transparent)]
pub struct PyReadwriteArray<'py, T, D>
where
T: Element,
D: Dimension,
{
// The exclusively borrowed array object.
array: Bound<'py, PyArray<T, D>>,
}
/// Exclusive borrow of a zero-dimensional array.
pub type PyReadwriteArray0<'py, T> = PyReadwriteArray<'py, T, Ix0>;
/// Exclusive borrow of a one-dimensional array.
pub type PyReadwriteArray1<'py, T> = PyReadwriteArray<'py, T, Ix1>;
/// Exclusive borrow of a two-dimensional array.
pub type PyReadwriteArray2<'py, T> = PyReadwriteArray<'py, T, Ix2>;
/// Exclusive borrow of a three-dimensional array.
pub type PyReadwriteArray3<'py, T> = PyReadwriteArray<'py, T, Ix3>;
/// Exclusive borrow of a four-dimensional array.
pub type PyReadwriteArray4<'py, T> = PyReadwriteArray<'py, T, Ix4>;
/// Exclusive borrow of a five-dimensional array.
pub type PyReadwriteArray5<'py, T> = PyReadwriteArray<'py, T, Ix5>;
/// Exclusive borrow of a six-dimensional array.
pub type PyReadwriteArray6<'py, T> = PyReadwriteArray<'py, T, Ix6>;
/// Exclusive borrow of an array whose dimensionality is only known at runtime.
pub type PyReadwriteArrayDyn<'py, T> = PyReadwriteArray<'py, T, IxDyn>;
/// Lets an exclusive borrow be used where a shared borrow is expected, exposing
/// the read-only accessors of `PyReadonlyArray`.
impl<'py, T, D> Deref for PyReadwriteArray<'py, T, D>
where
T: Element,
D: Dimension,
{
type Target = PyReadonlyArray<'py, T, D>;
fn deref(&self) -> &Self::Target {
// SAFETY: both types are `#[repr(transparent)]` wrappers around the same
// `Bound<'py, PyArray<T, D>>`, so the pointer cast is layout-compatible, and
// an exclusive borrow may be read as a shared one for the duration of `&self`.
// No additional shared borrow is registered here. That is why `clone()`
// reached through this impl fails (see `cannot_clone_exclusive_borrow_via_deref`).
unsafe { &*(self as *const Self as *const Self::Target) }
}
}
impl<'py, T, D> From<PyReadwriteArray<'py, T, D>> for PyReadonlyArray<'py, T, D>
where
T: Element,
D: Dimension,
{
fn from(value: PyReadwriteArray<'py, T, D>) -> Self {
let array = value.array.clone();
::std::mem::drop(value);
Self::try_new(array)
.expect("releasing an exclusive reference should immediately permit a shared reference")
}
}
impl<'a, 'py, T: Element + 'a, D: Dimension + 'a> FromPyObject<'a, 'py>
    for PyReadwriteArray<'py, T, D>
{
    type Error = CastError<'a, 'py>;

    /// Downcasts `obj` to a `PyArray<T, D>` and takes an exclusive borrow of it.
    ///
    /// Only the downcast is reported through `Self::Error`. `readwrite` returns
    /// the borrow directly, so borrow conflicts are not surfaced as an `Err` here.
    fn extract(obj: Borrowed<'a, 'py, PyAny>) -> Result<Self, Self::Error> {
        obj.cast::<PyArray<T, D>>().map(|array| array.readwrite())
    }
}
impl<'py, T, D> PyReadwriteArray<'py, T, D>
where
T: Element,
D: Dimension,
{
pub(crate) fn try_new(array: Bound<'py, PyArray<T, D>>) -> Result<Self, BorrowError> {
acquire_mut(array.py(), array.as_array_ptr())?;
Ok(Self { array })
}
#[inline(always)]
pub fn as_array_mut(&mut self) -> ArrayViewMut<'_, T, D> {
unsafe { self.array.as_array_mut() }
}
#[inline(always)]
pub fn as_slice_mut(&mut self) -> Result<&mut [T], NotContiguousError> {
unsafe { self.array.as_slice_mut() }
}
#[inline(always)]
pub fn get_mut<I>(&mut self, index: I) -> Option<&mut T>
where
I: NpyIndex<Dim = D>,
{
unsafe { self.array.get_mut(index) }
}
pub fn make_nonwriteable(self) -> PyReadonlyArray<'py, T, D> {
unsafe {
(*self.as_array_ptr()).flags &= !flags::NPY_ARRAY_WRITEABLE;
}
self.into()
}
}
#[cfg(feature = "nalgebra")]
impl<'py, N, D> PyReadwriteArray<'py, N, D>
where
N: nalgebra::Scalar + Element,
D: Dimension,
{
/// Tries to view the borrowed array as a mutable [`nalgebra::MatrixViewMut`]
/// with the caller-chosen dimensions `R`/`C` and strides `RStride`/`CStride`.
///
/// Returns `None` when the array-level `try_as_matrix_mut` does. Presumably
/// that happens when the shape or strides do not fit the requested matrix type.
// NOTE(review): this takes `&self` yet returns a mutable view, so two calls can
// hand out aliasing mutable views of the same data. `&mut self` would be sound
// but is a breaking signature change; confirm before altering.
#[doc(alias = "nalgebra")]
pub fn try_as_matrix_mut<R, C, RStride, CStride>(
&self,
) -> Option<nalgebra::MatrixViewMut<'_, N, R, C, RStride, CStride>>
where
R: nalgebra::Dim,
C: nalgebra::Dim,
RStride: nalgebra::Dim,
CStride: nalgebra::Dim,
{
// SAFETY: relies on this value holding the registered exclusive borrow; see
// the NOTE above regarding `&self`.
unsafe { self.array.try_as_matrix_mut() }
}
}
#[cfg(feature = "nalgebra")]
impl<'py, N> PyReadwriteArray<'py, N, Ix1>
where
    N: nalgebra::Scalar + Element,
{
    /// Views this one-dimensional array as a mutable [`nalgebra::DMatrixViewMut`]
    /// with dynamic shape and dynamic strides.
    ///
    /// # Panics
    ///
    /// Panics if the array's layout cannot be expressed as such a view, that is,
    /// if [`try_as_matrix_mut`](PyReadwriteArray::try_as_matrix_mut) returns `None`.
    // NOTE(review): takes `&self` yet returns a mutable view, so two calls can yield
    // aliasing mutable views. `&mut self` would be sound but is a breaking change.
    #[doc(alias = "nalgebra")]
    pub fn as_matrix_mut(&self) -> nalgebra::DMatrixViewMut<'_, N, nalgebra::Dyn, nalgebra::Dyn> {
        self.try_as_matrix_mut()
            .expect("array layout cannot be viewed as a dynamically-shaped nalgebra matrix")
    }
}
#[cfg(feature = "nalgebra")]
impl<'py, N> PyReadwriteArray<'py, N, Ix2>
where
    N: nalgebra::Scalar + Element,
{
    /// Views this two-dimensional array as a mutable [`nalgebra::DMatrixViewMut`]
    /// with dynamic shape and dynamic strides.
    ///
    /// # Panics
    ///
    /// Panics if the array's layout cannot be expressed as such a view, that is,
    /// if [`try_as_matrix_mut`](PyReadwriteArray::try_as_matrix_mut) returns `None`.
    // NOTE(review): takes `&self` yet returns a mutable view, so two calls can yield
    // aliasing mutable views. `&mut self` would be sound but is a breaking change.
    #[doc(alias = "nalgebra")]
    pub fn as_matrix_mut(&self) -> nalgebra::DMatrixViewMut<'_, N, nalgebra::Dyn, nalgebra::Dyn> {
        self.try_as_matrix_mut()
            .expect("array layout cannot be viewed as a dynamically-shaped nalgebra matrix")
    }
}
impl<'py, T> PyReadwriteArray<'py, T, Ix1>
where
    T: Element,
{
    /// Resizes this one-dimensional array to `dims`, consuming and returning the
    /// exclusive borrow.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying resize, for example when
    /// another view of the array is alive. On that path `self` is dropped, which
    /// releases the exclusive borrow exactly once.
    ///
    /// # Panics
    ///
    /// Panics if the exclusive borrow cannot be re-registered after the resize.
    pub fn resize<ID: IntoDimension>(self, dims: ID) -> PyResult<Self> {
        // SAFETY: owning `self` proves we hold the exclusive borrow of the array,
        // so no view obtained through the borrow bookkeeping can observe the resize.
        unsafe {
            self.array.resize(dims)?;
        }

        // The resize may have changed the array's data and shape. Presumably the
        // borrow registry is keyed on them, so drop the old registration and
        // register the exclusive borrow again for the array's new state.
        let py = self.array.py();
        let ptr = self.array.as_array_ptr();
        release_mut(py, ptr);
        acquire_mut(py, ptr)
            .expect("re-acquiring the exclusive borrow right after releasing it should succeed");
        Ok(self)
    }
}
impl<'py, T, D> Drop for PyReadwriteArray<'py, T, D>
where
    T: Element,
    D: Dimension,
{
    /// Returns the exclusive borrow registered by `try_new` to the borrow
    /// bookkeeping.
    fn drop(&mut self) {
        let py = self.array.py();
        let raw = self.array.as_array_ptr();
        release_mut(py, raw);
    }
}
impl<'py, T, D> fmt::Debug for PyReadwriteArray<'py, T, D>
where
    T: Element,
    D: Dimension,
{
    /// Prints the wrapper type with its element and dimension type names.
    /// The array contents are not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let elem = type_name::<T>();
        let dim = type_name::<D>();
        f.debug_struct(&format!("PyReadwriteArray<{elem}, {dim}>"))
            .finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use pyo3::{types::IntoPyDict, Python};
use crate::array::PyArray1;
use pyo3::ffi::c_str;
// `Debug` output shows only the wrapper type with its element and dimension
// type names.
#[test]
fn test_debug_formatting() {
Python::attach(|py| {
let array = PyArray::<f64, _>::zeros(py, (1, 2, 3), false);
{
let shared = array.readonly();
assert_eq!(
format!("{shared:?}"),
"PyReadonlyArray<f64, ndarray::dimension::dim::Dim<[usize; 3]>>"
);
}
{
let exclusive = array.readwrite();
assert_eq!(
format!("{exclusive:?}"),
"PyReadwriteArray<f64, ndarray::dimension::dim::Dim<[usize; 3]>>"
);
}
});
}
// `PyReadwriteArray` has no `Clone` impl, so `clone` resolves through `Deref`
// to `PyReadonlyArray::clone`. That must panic, because a shared borrow cannot
// be taken while the exclusive one is held.
#[test]
#[should_panic(expected = "AlreadyBorrowed")]
fn cannot_clone_exclusive_borrow_via_deref() {
Python::attach(|py| {
let array = PyArray::<f64, _>::zeros(py, (3, 2, 1), false);
let exclusive = array.readwrite();
let _shared = exclusive.clone();
});
}
// While a second view (`array[:]`) is alive the resize is expected to fail.
// The consumed `PyReadwriteArray` is then dropped on the error path and must
// release its exclusive borrow only once.
#[test]
fn failed_resize_does_not_double_release() {
Python::attach(|py| {
let array = PyArray::<f64, _>::zeros(py, 10, false);
let locals = [("array", &array)]
.into_py_dict(py)
.expect("Operation failed");
let _view = py
.eval(c_str!("array[:]"), None, Some(&locals))
.expect("Operation failed")
.cast_into::<PyArray1<f64>>()
.expect("Operation failed");
let exclusive = array.readwrite();
assert!(exclusive.resize(100).is_err());
});
}
// Resizing to the current length must succeed, including the release/re-acquire
// of the exclusive borrow afterwards.
#[test]
fn ineffective_resize_does_not_conflict() {
Python::attach(|py| {
let array = PyArray::<f64, _>::zeros(py, 10, false);
let exclusive = array.readwrite();
assert!(exclusive.resize(10).is_ok());
});
}
}