#[cfg(feature = "std")]
pub mod broadcast;
#[cfg(all(feature = "const_shapes", feature = "std"))]
pub mod static_shape;
use core::fmt;
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[cfg(feature = "std")]
use ndarray::Dimension as NdDimension;
/// Describes the shape of an array as one `usize` extent per axis.
///
/// Implementors expose their extents as a slice; the axis count and the total
/// element count are derived from that slice by the provided methods.
pub trait Dimension: Clone + PartialEq + Eq + fmt::Debug + Send + Sync + 'static {
    /// Number of axes when the type fixes it, or `None` when the rank is only
    /// known at runtime (as for `IxDyn`).
    const NDIM: Option<usize>;

    /// The `ndarray` dimension type this shape converts to and from.
    #[doc(hidden)]
    #[cfg(feature = "std")]
    type NdarrayDim: ndarray::Dimension;

    /// Companion dimension type on the lower-rank side. In this file `Ix2`
    /// uses `Ix1`, while `Ix0` and `IxDyn` use themselves.
    type Smaller: Dimension;

    /// Companion dimension type on the higher-rank side. In this file `Ix2`
    /// uses `Ix3`, while `Ix6` and `IxDyn` use `IxDyn`.
    type Larger: Dimension;

    /// Borrows the per-axis extents.
    fn as_slice(&self) -> &[usize];

    /// Mutably borrows the per-axis extents.
    fn as_slice_mut(&mut self) -> &mut [usize];

    /// Number of axes, i.e. the length of [`Dimension::as_slice`].
    fn ndim(&self) -> usize {
        let extents = self.as_slice();
        extents.len()
    }

    /// Product of all extents; an empty extent list yields `1`.
    fn size(&self) -> usize {
        self.as_slice().iter().copied().product()
    }

    /// Converts this shape into its `ndarray` counterpart.
    #[doc(hidden)]
    #[cfg(feature = "std")]
    fn to_ndarray_dim(&self) -> Self::NdarrayDim;

    /// Builds this shape from its `ndarray` counterpart.
    #[doc(hidden)]
    #[cfg(feature = "std")]
    fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self;

    /// Builds a shape from a slice of extents, returning `None` when the
    /// implementor cannot represent the slice.
    fn from_dim_slice(shape: &[usize]) -> Option<Self>;
}
/// Generates a fixed-rank shape type `$name` that stores `$n` extents, along
/// with its `Debug`, `From<[usize; $n]>` and `Dimension` implementations.
///
/// `$ndarray_ty` is the matching `ndarray` dimension type (only used when the
/// `std` feature is enabled). `$smaller` and `$larger` become the
/// `Dimension::Smaller` and `Dimension::Larger` associated types.
macro_rules! impl_fixed_dimension {
    ($name:ident, $n:expr, $ndarray_ty:ty, $smaller:ty, $larger:ty) => {
        #[doc = concat!("Shape of an array with a fixed rank of ", stringify!($n), ".")]
        #[derive(Clone, Eq, PartialEq, Hash)]
        pub struct $name {
            shape: [usize; $n],
        }

        impl $name {
            /// Wraps the given per-axis extents.
            #[inline]
            pub const fn new(shape: [usize; $n]) -> Self {
                $name { shape }
            }
        }

        impl fmt::Debug for $name {
            /// Prints the extents as a plain list such as `[3, 4]`.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                // A fresh `{:?}` spec is used on purpose: flags given to the
                // outer formatter (e.g. `{:#?}`) are not forwarded.
                let extents: &[usize] = &self.shape;
                f.write_fmt(format_args!("{:?}", extents))
            }
        }

        impl From<[usize; $n]> for $name {
            #[inline]
            fn from(shape: [usize; $n]) -> Self {
                $name { shape }
            }
        }

        impl Dimension for $name {
            const NDIM: Option<usize> = Some($n);
            #[cfg(feature = "std")]
            type NdarrayDim = $ndarray_ty;
            type Smaller = $smaller;
            type Larger = $larger;

            #[inline]
            fn as_slice(&self) -> &[usize] {
                &self.shape[..]
            }

            #[inline]
            fn as_slice_mut(&mut self) -> &mut [usize] {
                &mut self.shape[..]
            }

            #[cfg(feature = "std")]
            fn to_ndarray_dim(&self) -> Self::NdarrayDim {
                ndarray::Dim(self.shape)
            }

            #[cfg(feature = "std")]
            fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
                let axes = dim.as_array_view();
                let extents = axes.as_slice().expect("ndarray dim should be contiguous");
                let mut shape = [0usize; $n];
                shape.copy_from_slice(extents);
                $name { shape }
            }

            /// Accepts only slices holding exactly the fixed number of extents.
            fn from_dim_slice(shape: &[usize]) -> Option<Self> {
                if shape.len() == $n {
                    let mut extents = [0usize; $n];
                    extents.copy_from_slice(shape);
                    Some($name { shape: extents })
                } else {
                    None
                }
            }
        }
    };
}
// Fixed-rank shape types for ranks 1 through 6. Each invocation passes, in
// order: the type name, the rank, the matching `ndarray` dimension type, and
// the types used for `Dimension::Smaller` and `Dimension::Larger`.
// `Ix0` and `IxDyn`, referenced at both ends of the chain, are written by
// hand further down in this file.
impl_fixed_dimension!(Ix1, 1, ndarray::Ix1, Ix0, Ix2);
impl_fixed_dimension!(Ix2, 2, ndarray::Ix2, Ix1, Ix3);
impl_fixed_dimension!(Ix3, 3, ndarray::Ix3, Ix2, Ix4);
impl_fixed_dimension!(Ix4, 4, ndarray::Ix4, Ix3, Ix5);
impl_fixed_dimension!(Ix5, 5, ndarray::Ix5, Ix4, Ix6);
// No `Ix7` is generated here, so `Ix6` names `IxDyn` as its `Larger` type.
impl_fixed_dimension!(Ix6, 6, ndarray::Ix6, Ix5, IxDyn);
/// Shape of an array with no axes; it carries no extents at all.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct Ix0;

impl fmt::Debug for Ix0 {
    /// Always prints the empty list `[]`, ignoring any formatter flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[]")
    }
}
impl Dimension for Ix0 {
const NDIM: Option<usize> = Some(0);
#[cfg(feature = "std")]
type NdarrayDim = ndarray::Ix0;
type Smaller = Ix0;
type Larger = Ix1;
#[inline]
fn as_slice(&self) -> &[usize] {
&[]
}
#[inline]
fn as_slice_mut(&mut self) -> &mut [usize] {
&mut []
}
#[cfg(feature = "std")]
fn to_ndarray_dim(&self) -> Self::NdarrayDim {
ndarray::Dim(())
}
#[cfg(feature = "std")]
fn from_ndarray_dim(_dim: &Self::NdarrayDim) -> Self {
Self
}
fn from_dim_slice(shape: &[usize]) -> Option<Self> {
if shape.is_empty() { Some(Self) } else { None }
}
}
/// Shape whose number of axes is decided at runtime; the extents are stored
/// in a `Vec<usize>`.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct IxDyn {
    shape: Vec<usize>,
}

impl IxDyn {
    /// Builds a dynamic shape by copying `shape` into a new vector.
    #[must_use]
    pub fn new(shape: &[usize]) -> Self {
        let shape = Vec::from(shape);
        IxDyn { shape }
    }
}
impl fmt::Debug for IxDyn {
    /// Prints the extents as a plain list such as `[2, 5, 3]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A fresh `{:?}` spec is used on purpose: flags given to the outer
        // formatter (e.g. `{:#?}`) are not forwarded.
        let extents: &[usize] = self.shape.as_slice();
        f.write_fmt(format_args!("{:?}", extents))
    }
}
impl From<Vec<usize>> for IxDyn {
fn from(shape: Vec<usize>) -> Self {
Self { shape }
}
}
impl From<&[usize]> for IxDyn {
fn from(shape: &[usize]) -> Self {
Self::new(shape)
}
}
impl Dimension for IxDyn {
const NDIM: Option<usize> = None;
#[cfg(feature = "std")]
type NdarrayDim = ndarray::IxDyn;
type Smaller = IxDyn;
type Larger = IxDyn;
#[inline]
fn as_slice(&self) -> &[usize] {
&self.shape
}
#[inline]
fn as_slice_mut(&mut self) -> &mut [usize] {
&mut self.shape
}
#[cfg(feature = "std")]
fn to_ndarray_dim(&self) -> Self::NdarrayDim {
ndarray::IxDyn(&self.shape)
}
#[cfg(feature = "std")]
fn from_ndarray_dim(dim: &Self::NdarrayDim) -> Self {
let view = dim.as_array_view();
let s = view.as_slice().expect("ndarray IxDyn should be contiguous");
Self { shape: s.to_vec() }
}
fn from_dim_slice(shape: &[usize]) -> Option<Self> {
Some(Self {
shape: shape.to_vec(),
})
}
}
/// Newtype wrapping a `usize` axis number.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Axis(pub usize);

impl Axis {
    /// Returns the wrapped axis number.
    #[inline]
    #[must_use]
    pub const fn index(self) -> usize {
        let Axis(position) = self;
        position
    }
}
/// Unit tests for the shape types defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ix1_basics() {
        let d = Ix1::new([5]);
        assert_eq!(d.ndim(), 1);
        assert_eq!(d.size(), 5);
        assert_eq!(d.as_slice(), &[5]);
    }

    #[test]
    fn ix2_basics() {
        let d = Ix2::new([3, 4]);
        assert_eq!(d.ndim(), 2);
        assert_eq!(d.size(), 12);
    }

    #[test]
    fn ix0_basics() {
        let d = Ix0;
        assert_eq!(d.ndim(), 0);
        // The product over zero extents is 1.
        assert_eq!(d.size(), 1);
    }

    #[test]
    fn ixdyn_basics() {
        let d = IxDyn::new(&[2, 3, 4]);
        assert_eq!(d.ndim(), 3);
        assert_eq!(d.size(), 24);
    }

    #[test]
    fn from_dim_slice_checks_rank() {
        assert_eq!(Ix2::from_dim_slice(&[3, 4]), Some(Ix2::new([3, 4])));
        assert_eq!(Ix2::from_dim_slice(&[3]), None);
        assert_eq!(Ix0::from_dim_slice(&[]), Some(Ix0));
        assert_eq!(Ix0::from_dim_slice(&[1]), None);
        assert_eq!(
            IxDyn::from_dim_slice(&[1, 2, 3]),
            Some(IxDyn::new(&[1, 2, 3]))
        );
    }

    // `to_ndarray_dim` / `from_ndarray_dim` only exist when the `std` feature
    // is enabled, so the round-trip tests must be gated on the same feature or
    // the test module fails to compile without it.
    #[cfg(feature = "std")]
    #[test]
    fn roundtrip_ix2_ndarray() {
        let d = Ix2::new([3, 7]);
        let nd = d.to_ndarray_dim();
        let d2 = Ix2::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[cfg(feature = "std")]
    #[test]
    fn roundtrip_ixdyn_ndarray() {
        let d = IxDyn::new(&[2, 5, 3]);
        let nd = d.to_ndarray_dim();
        let d2 = IxDyn::from_ndarray_dim(&nd);
        assert_eq!(d, d2);
    }

    #[test]
    fn axis_index() {
        let a = Axis(2);
        assert_eq!(a.index(), 2);
    }
}