use std::borrow::Cow;
use std::ffi::{CStr, CString};
use std::ptr::null_mut;
use ndarray::{Dimension, IxDyn};
use pyo3::{AsPyPointer, FromPyObject, FromPyPointer, PyAny, PyNativeType, PyResult};
use crate::array::PyArray;
use crate::dtype::Element;
use crate::npyffi::{array::PY_ARRAY_API, NPY_CASTING, NPY_ORDER};
/// Marker trait for the result types accepted by the product functions in
/// this module (`inner`, `dot` and `einsum`), each of which bounds its `OUT`
/// type parameter by it.
///
/// The trait has no methods of its own; every implementor must be extractable
/// from a Python object via [`FromPyObject`]. It is implemented for array
/// references with element type `T` and for element types `T` themselves.
pub trait ArrayOrScalar<'py, T>: FromPyObject<'py> {}
// Array case: a borrowed array of elements `T` with any dimensionality `D`
// is a valid result type for the product functions.
impl<'py, T, D> ArrayOrScalar<'py, T> for &'py PyArray<T, D>
where
T: Element,
D: Dimension,
{
}
// Scalar case: any element type that can itself be extracted from a Python
// object is a valid result type for the product functions.
impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}
/// Computes the inner product of two arrays by calling NumPy's
/// `PyArray_InnerProduct`.
///
/// The result is extracted into `OUT`, which may be either an array reference
/// or a scalar of the element type (see [`ArrayOrScalar`]).
///
/// # Errors
///
/// Returns an error if the NumPy call fails or if the resulting Python object
/// cannot be extracted as `OUT`.
pub fn inner<'py, T, DIN1, DIN2, OUT>(
    array1: &'py PyArray<T, DIN1>,
    array2: &'py PyArray<T, DIN2>,
) -> PyResult<OUT>
where
    T: Element,
    DIN1: Dimension,
    DIN2: Dimension,
    OUT: ArrayOrScalar<'py, T>,
{
    let py = array1.py();
    // SAFETY: both pointers are obtained from live array references. The
    // returned pointer is treated as an owned reference, and a null result is
    // turned into an error by `from_owned_ptr_or_err`.
    let product = unsafe {
        let raw = PY_ARRAY_API.PyArray_InnerProduct(py, array1.as_ptr(), array2.as_ptr());
        PyAny::from_owned_ptr_or_err(py, raw)
    }?;
    product.extract()
}
/// Computes the matrix product of two arrays by calling NumPy's
/// `PyArray_MatrixProduct`.
///
/// The result is extracted into `OUT`, which may be either an array reference
/// or a scalar of the element type (see [`ArrayOrScalar`]).
///
/// # Errors
///
/// Returns an error if the NumPy call fails or if the resulting Python object
/// cannot be extracted as `OUT`.
pub fn dot<'py, T, DIN1, DIN2, OUT>(
    array1: &'py PyArray<T, DIN1>,
    array2: &'py PyArray<T, DIN2>,
) -> PyResult<OUT>
where
    T: Element,
    DIN1: Dimension,
    DIN2: Dimension,
    OUT: ArrayOrScalar<'py, T>,
{
    let py = array1.py();
    let lhs = array1.as_ptr();
    let rhs = array2.as_ptr();
    // SAFETY: `lhs` and `rhs` are obtained from live array references. The
    // returned pointer is treated as an owned reference, and a null result is
    // turned into an error by `from_owned_ptr_or_err`.
    let product = unsafe {
        PyAny::from_owned_ptr_or_err(py, PY_ARRAY_API.PyArray_MatrixProduct(py, lhs, rhs))
    }?;
    product.extract()
}
pub fn einsum<'py, T, OUT>(subscripts: &str, arrays: &[&'py PyArray<T, IxDyn>]) -> PyResult<OUT>
where
T: Element,
OUT: ArrayOrScalar<'py, T>,
{
let subscripts = match CStr::from_bytes_with_nul(subscripts.as_bytes()) {
Ok(subscripts) => Cow::Borrowed(subscripts),
Err(_) => Cow::Owned(CString::new(subscripts).unwrap()),
};
let py = arrays[0].py();
let obj = unsafe {
let result = PY_ARRAY_API.PyArray_EinsteinSum(
py,
subscripts.as_ptr() as _,
arrays.len() as _,
arrays.as_ptr() as _,
null_mut(),
NPY_ORDER::NPY_KEEPORDER,
NPY_CASTING::NPY_NO_CASTING,
null_mut(),
);
PyAny::from_owned_ptr_or_err(py, result)?
};
obj.extract()
}
/// Convenience wrapper around `$crate::einsum` (presumably the `einsum`
/// function re-exported at the crate root) taking a subscripts literal
/// followed by one or more array identifiers.
///
/// Each array is converted with `to_dyn()` so that arrays of different
/// dimensionality can be collected into a single operand list. A NUL byte is
/// appended to the subscripts literal at compile time via `concat!`.
#[macro_export]
macro_rules! einsum {
    ($spec:literal $(, $operand:ident)+ $(,)*) => {{
        let operands = [$($operand.to_dyn()),+];
        $crate::einsum(concat!($spec, "\0"), &operands)
    }};
}