use std::borrow::Cow;
use std::ffi::{CStr, CString};
use std::ptr::null_mut;
use ndarray::{Dimension, IxDyn};
use pyo3::types::PyAnyMethods;
use pyo3::{Borrowed, Bound, FromPyObject, PyResult};
use crate::array::PyArray;
use crate::dtype::Element;
use crate::npyffi::{array::PY_ARRAY_API, NPY_CASTING, NPY_ORDER};
/// Marker trait for the result types that [`inner`], [`dot`] and [`einsum`] can
/// return: anything extractable (via [`FromPyObject`]) from the object NumPy
/// produces for element type `T`.
///
/// It is implemented below for arrays of element type `T` (of any dimension `D`)
/// and for the element type `T` itself, so callers choose between an array or a
/// scalar result through the `OUT` type parameter.
pub trait ArrayOrScalar<'a, 'py, T>: FromPyObject<'a, 'py> {}
/// Array results: the NumPy output is extracted as a `PyArray<T, D>`.
impl<'a, 'py, T, D> ArrayOrScalar<'a, 'py, T> for Bound<'py, PyArray<T, D>>
where
T: Element + 'a,
D: Dimension + 'a,
{
}
/// Scalar results: the NumPy output is extracted as a single element of type `T`.
impl<'a, 'py, T> ArrayOrScalar<'a, 'py, T> for T where T: Element + FromPyObject<'a, 'py> {}
/// Computes the inner product of two arrays using NumPy's `PyArray_InnerProduct`.
///
/// The object NumPy returns is extracted into `OUT`, which can be either an
/// array type or the scalar element type `T` (see [`ArrayOrScalar`]).
///
/// # Errors
///
/// Propagates any Python exception raised by NumPy, and the extraction error
/// if the result cannot be converted into `OUT`.
pub fn inner<'py, T, DIN1, DIN2, OUT>(
    array1: &Bound<'py, PyArray<T, DIN1>>,
    array2: &Bound<'py, PyArray<T, DIN2>>,
) -> PyResult<OUT>
where
    T: Element,
    DIN1: Dimension,
    DIN2: Dimension,
    OUT: for<'a> ArrayOrScalar<'a, 'py, T>,
{
    let py = array1.py();
    let (lhs, rhs) = (array1.as_ptr(), array2.as_ptr());
    // SAFETY: `lhs` and `rhs` are taken from live `Bound` references held by the
    // caller, so both object pointers stay valid for the duration of the call.
    let raw = unsafe { PY_ARRAY_API.PyArray_InnerProduct(py, lhs, rhs) };
    // SAFETY: `raw` is treated as an owned (new) reference, or null with a Python
    // exception set, which is the contract `from_owned_ptr_or_err` requires.
    let product = unsafe { Bound::from_owned_ptr_or_err(py, raw) }?;
    product.extract().map_err(Into::into)
}
/// Computes the dot (matrix) product of two arrays using NumPy's
/// `PyArray_MatrixProduct`.
///
/// The object NumPy returns is extracted into `OUT`, which can be either an
/// array type or the scalar element type `T` (see [`ArrayOrScalar`]).
///
/// # Errors
///
/// Propagates any Python exception raised by NumPy, and the extraction error
/// if the result cannot be converted into `OUT`.
pub fn dot<'py, T, DIN1, DIN2, OUT>(
    array1: &Bound<'py, PyArray<T, DIN1>>,
    array2: &Bound<'py, PyArray<T, DIN2>>,
) -> PyResult<OUT>
where
    T: Element,
    DIN1: Dimension,
    DIN2: Dimension,
    OUT: for<'a> ArrayOrScalar<'a, 'py, T>,
{
    let py = array1.py();
    let left = array1.as_ptr();
    let right = array2.as_ptr();
    // SAFETY: both pointers come from live `Bound` references held by the caller
    // and therefore remain valid while NumPy computes the product.
    let raw_result = unsafe { PY_ARRAY_API.PyArray_MatrixProduct(py, left, right) };
    // SAFETY: `raw_result` is treated as an owned (new) reference, or null with a
    // Python exception set, matching the contract of `from_owned_ptr_or_err`.
    let matrix = unsafe { Bound::from_owned_ptr_or_err(py, raw_result) }?;
    matrix.extract().map_err(Into::into)
}
pub fn einsum<'py, T, OUT>(
subscripts: &str,
arrays: &[Borrowed<'_, 'py, PyArray<T, IxDyn>>],
) -> PyResult<OUT>
where
T: Element,
OUT: for<'a> ArrayOrScalar<'a, 'py, T>,
{
let subscripts = match CStr::from_bytes_with_nul(subscripts.as_bytes()) {
Ok(subscripts) => Cow::Borrowed(subscripts),
Err(_) => Cow::Owned(CString::new(subscripts).expect("Operation failed")),
};
let py = arrays[0].py();
let obj = unsafe {
let result = PY_ARRAY_API.PyArray_EinsteinSum(
py,
subscripts.as_ptr() as _,
arrays.len() as _,
arrays.as_ptr() as _,
null_mut(),
NPY_ORDER::NPY_KEEPORDER,
NPY_CASTING::NPY_NO_CASTING,
null_mut(),
);
Bound::from_owned_ptr_or_err(py, result)?
};
obj.extract().map_err(Into::into)
}
/// Convenience wrapper around the `einsum` function.
///
/// Takes a string literal of subscripts followed by one or more array
/// identifiers. Each array is converted to a dynamic-dimensional borrowed view.
/// The subscripts get a trailing NUL appended at compile time, so the function
/// can use them as a C string without allocating.
#[macro_export]
macro_rules! einsum {
    ($subscripts:literal $(,$array:ident)+ $(,)*) => {{
        // Collect the operands into a fixed-size array so the function receives a slice.
        let operands = [$($array.to_dyn().as_borrowed()),+];
        $crate::einsum(concat!($subscripts, "\0"), &operands)
    }};
}