use ndarray::ShapeBuilder;
use crate::dimension::Dimension;
use crate::dtype::Element;
use crate::error::{FerrayError, FerrayResult};
use crate::layout::MemoryLayout;
pub struct Array<T: Element, D: Dimension> {
pub(crate) inner: ndarray::Array<T, D::NdarrayDim>,
pub(crate) dim: D,
}
impl<T: Element, D: Dimension> Array<T, D> {
pub(crate) fn from_ndarray(inner: ndarray::Array<T, D::NdarrayDim>) -> Self {
let dim = D::from_ndarray_dim(&inner.raw_dim());
Self { inner, dim }
}
pub(crate) fn into_ndarray(self) -> ndarray::Array<T, D::NdarrayDim> {
self.inner
}
pub fn from_elem(dim: D, elem: T) -> FerrayResult<Self> {
let nd_dim = dim.to_ndarray_dim();
let inner = ndarray::Array::from_elem(nd_dim, elem);
Ok(Self { inner, dim })
}
pub fn zeros(dim: D) -> FerrayResult<Self> {
Self::from_elem(dim, T::zero())
}
pub fn ones(dim: D) -> FerrayResult<Self> {
Self::from_elem(dim, T::one())
}
pub fn from_vec(dim: D, data: Vec<T>) -> FerrayResult<Self> {
let expected = dim.size();
if data.len() != expected {
return Err(FerrayError::shape_mismatch(format!(
"data length {} does not match shape {:?} (expected {})",
data.len(),
dim.as_slice(),
expected,
)));
}
let nd_dim = dim.to_ndarray_dim();
let inner = ndarray::Array::from_shape_vec(nd_dim, data)
.map_err(|e| FerrayError::shape_mismatch(format!("ndarray shape error: {e}")))?;
Ok(Self { inner, dim })
}
pub fn from_vec_f(dim: D, data: Vec<T>) -> FerrayResult<Self> {
let expected = dim.size();
if data.len() != expected {
return Err(FerrayError::shape_mismatch(format!(
"data length {} does not match shape {:?} (expected {})",
data.len(),
dim.as_slice(),
expected,
)));
}
let nd_dim = dim.to_ndarray_dim();
let inner = ndarray::Array::from_shape_vec(nd_dim.f(), data)
.map_err(|e| FerrayError::shape_mismatch(format!("ndarray shape error: {e}")))?;
let dim = D::from_ndarray_dim(&inner.raw_dim());
Ok(Self { inner, dim })
}
pub fn from_iter_1d(iter: impl IntoIterator<Item = T>) -> FerrayResult<Self>
where
D: Dimension<NdarrayDim = ndarray::Ix1>,
{
let inner = ndarray::Array::from_iter(iter);
let dim = D::from_ndarray_dim(&inner.raw_dim());
Ok(Self { inner, dim })
}
pub fn layout(&self) -> MemoryLayout {
if self.inner.is_standard_layout() {
MemoryLayout::C
} else {
let shape = self.dim.as_slice();
let strides = self.strides_isize();
crate::layout::detect_layout(shape, &strides)
}
}
pub(crate) fn strides_isize(&self) -> Vec<isize> {
self.inner.strides().to_vec()
}
#[inline]
pub fn ndim(&self) -> usize {
self.dim.ndim()
}
#[inline]
pub fn shape(&self) -> &[usize] {
self.inner.shape()
}
#[inline]
pub fn strides(&self) -> &[isize] {
self.inner.strides()
}
#[inline]
pub fn size(&self) -> usize {
self.inner.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
#[inline]
pub fn as_ptr(&self) -> *const T {
self.inner.as_ptr()
}
#[inline]
pub fn as_mut_ptr(&mut self) -> *mut T {
self.inner.as_mut_ptr()
}
pub fn as_slice(&self) -> Option<&[T]> {
self.inner.as_slice()
}
pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
self.inner.as_slice_mut()
}
#[inline]
pub fn dim(&self) -> &D {
&self.dim
}
}
impl<T: Element, D: Dimension> From<ndarray::Array<T, D::NdarrayDim>> for Array<T, D> {
fn from(inner: ndarray::Array<T, D::NdarrayDim>) -> Self {
Self::from_ndarray(inner)
}
}
impl<T: Element, D: Dimension> From<Array<T, D>> for ndarray::Array<T, D::NdarrayDim> {
fn from(arr: Array<T, D>) -> Self {
arr.into_ndarray()
}
}
impl<T: Element, D: Dimension> Clone for Array<T, D> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
dim: self.dim.clone(),
}
}
}
impl<T: Element + PartialEq, D: Dimension> PartialEq for Array<T, D> {
fn eq(&self, other: &Self) -> bool {
self.inner == other.inner
}
}
impl<T: Element + Eq, D: Dimension> Eq for Array<T, D> {}
#[cfg(test)]
mod tests {
use super::*;
use crate::dimension::{Ix1, Ix2, IxDyn};
#[test]
fn create_zeros() {
let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
assert_eq!(arr.shape(), &[3, 4]);
assert_eq!(arr.size(), 12);
assert_eq!(arr.ndim(), 2);
assert!(!arr.is_empty());
}
#[test]
fn create_from_vec() {
let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
assert_eq!(arr.shape(), &[4]);
assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3, 4]);
}
#[test]
fn create_from_vec_shape_mismatch() {
let res = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0]);
assert!(res.is_err());
}
#[test]
fn from_iter_1d() {
let arr = Array::<f64, Ix1>::from_iter_1d((0..5).map(|x| x as f64)).unwrap();
assert_eq!(arr.shape(), &[5]);
}
#[test]
fn layout_c_contiguous() {
let arr = Array::<f64, Ix2>::zeros(Ix2::new([3, 4])).unwrap();
assert_eq!(arr.layout(), MemoryLayout::C);
}
#[test]
fn from_vec_f_order() {
let arr =
Array::<f64, Ix2>::from_vec_f(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
.unwrap();
assert_eq!(arr.shape(), &[2, 3]);
assert_eq!(arr.layout(), MemoryLayout::Fortran);
}
#[test]
fn clone_array() {
let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
let b = a.clone();
assert_eq!(a, b);
}
#[test]
fn ndarray_roundtrip() {
let original = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), original.clone()).unwrap();
let nd: ndarray::Array<f64, ndarray::Ix2> = arr.into();
let arr2: Array<f64, Ix2> = nd.into();
assert_eq!(arr2.as_slice().unwrap(), &original[..]);
}
#[test]
fn dynamic_rank() {
let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![1.0; 6]).unwrap();
assert_eq!(arr.ndim(), 2);
assert_eq!(arr.shape(), &[2, 3]);
}
}