use std::marker::PhantomData;
use std::ops::Deref;
use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
use pyo3::{
intern,
sync::PyOnceLock,
types::{PyAnyMethods, PyDict},
Borrowed, FromPyObject, Py, PyAny, PyErr, PyResult,
};
use crate::array::PyArrayMethods;
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray, PyUntypedArray};
mod sealed {
    /// Private supertrait of [`Coerce`](super::Coerce).
    ///
    /// The trait is `pub` but lives in a private module, so code outside this module
    /// cannot name it and therefore cannot implement `Coerce` for its own types.
    pub trait Sealed {}
}

use sealed::Sealed;

/// Policy marker that decides whether [`PyArrayLike`] extraction may change the
/// element type of its input.
///
/// Only [`TypeMustMatch`] and [`AllowTypeChange`] implement this trait.
pub trait Coerce: Sealed {
    /// `true` if extraction may convert the input to a different element type.
    const ALLOW_TYPE_CHANGE: bool;
}

/// Coercion policy under which existing NumPy arrays are never converted to a
/// different element type.
///
/// `numpy.asarray` is invoked without a `dtype` argument under this policy.
#[derive(Debug)]
pub struct TypeMustMatch;

/// Coercion policy that permits converting the input to the requested element type.
///
/// `numpy.asarray` is invoked with the target `dtype` under this policy.
#[derive(Debug)]
pub struct AllowTypeChange;

impl Sealed for TypeMustMatch {}
impl Sealed for AllowTypeChange {}

impl Coerce for TypeMustMatch {
    const ALLOW_TYPE_CHANGE: bool = false;
}

impl Coerce for AllowTypeChange {
    const ALLOW_TYPE_CHANGE: bool = true;
}
/// Function-argument receiver for values that either are a NumPy array or can be
/// turned into one, exposed as a read-only array through `Deref`.
///
/// `T` is the element type and `D` the dimensionality. `C` selects the coercion
/// policy and defaults to [`TypeMustMatch`].
#[derive(Debug)]
#[repr(transparent)]
pub struct PyArrayLike<'py, T: Element, D: Dimension, C: Coerce = TypeMustMatch>(
    PyReadonlyArray<'py, T, D>,
    PhantomData<C>,
);
impl<'py, T: Element, D: Dimension, C: Coerce> Deref for PyArrayLike<'py, T, D, C> {
    type Target = PyReadonlyArray<'py, T, D>;

    /// Borrows the wrapped read-only array; the coercion marker carries no data.
    fn deref(&self) -> &Self::Target {
        let Self(readonly, _policy) = self;
        readonly
    }
}
impl<'a, 'py, T, D, C> FromPyObject<'a, 'py> for PyArrayLike<'py, T, D, C>
where
T: Element + 'py,
D: Dimension + 'py,
C: Coerce,
Vec<T>: FromPyObject<'a, 'py>,
{
type Error = PyErr;
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
if let Ok(array) = ob.cast::<PyArray<T, D>>() {
return Ok(Self(array.readonly(), PhantomData));
}
let py = ob.py();
if (C::ALLOW_TYPE_CHANGE || ob.cast::<PyUntypedArray>().is_err())
&& matches!(D::NDIM, None | Some(1))
{
if let Ok(vec) = ob.extract::<Vec<T>>() {
let array = Array1::from(vec)
.into_dimensionality()
.expect("D being compatible to Ix1")
.into_pyarray(py)
.readonly();
return Ok(Self(array, PhantomData));
}
}
static AS_ARRAY: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
let as_array = AS_ARRAY
.get_or_try_init(py, || {
get_array_module(py)?.getattr("asarray").map(Into::into)
})?
.bind(py);
let kwargs = if C::ALLOW_TYPE_CHANGE {
let kwargs = PyDict::new(py);
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
Some(kwargs)
} else {
None
};
let array = as_array.call((ob,), kwargs.as_ref())?.extract()?;
Ok(Self(array, PhantomData))
}
}
pub type PyArrayLike0<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix0, C>;
pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;
pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;
pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;
pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;
pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;
pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;
pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;