use super::*;
use std::ops::{Index, IndexMut};
/// A shape of rank `DIM` whose per-axis extents live in a runtime array rather than
/// in the type.
///
/// Entry `i` of the wrapped array is the extent (length) of axis `i`, so
/// `SDynamic([2, 3])` describes a 2 × 3 shape.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SDynamic<const DIM: usize>(pub [usize; DIM]);

impl<const DIM: usize> Index<usize> for SDynamic<DIM> {
    type Output = usize;

    /// Returns a reference to the extent of axis `axis`.
    ///
    /// # Panics
    /// Panics if `axis >= DIM`.
    fn index(&self, axis: usize) -> &Self::Output {
        let SDynamic(extents) = self;
        &extents[axis]
    }
}

impl<const DIM: usize> IndexMut<usize> for SDynamic<DIM> {
    /// Returns a mutable reference to the extent of axis `axis`.
    ///
    /// # Panics
    /// Panics if `axis >= DIM`.
    fn index_mut(&mut self, axis: usize) -> &mut Self::Output {
        let SDynamic(extents) = self;
        &mut extents[axis]
    }
}
/// Exposes an `SDynamic` to the generic `Shape` machinery.
///
/// Indices are plain `DIM`-element arrays. An index belongs to the shape when it is
/// in bounds along every axis.
impl<const DIM: usize> Shape for SDynamic<DIM> {
    type Index = [usize; DIM];
    type IndexIter = multi_product::MultiProduct<DIM>;

    const DIM: usize = DIM;

    /// Returns `true` when every coordinate of `ix` is strictly below the extent of
    /// its axis. This is vacuously `true` for a rank-0 shape.
    fn contains(&self, ix: [usize; DIM]) -> bool {
        ix.iter()
            .zip(&self.0)
            .all(|(&coord, &extent)| coord < extent)
    }

    /// Number of valid indices: the product of all extents.
    /// The empty product gives `1` when `DIM == 0`.
    fn cardinality(&self) -> usize {
        self.0.iter().copied().product()
    }

    /// Builds the index iterator from the extents via `multi_product::MultiProduct`,
    /// which is defined elsewhere. The iteration order is whatever that type yields.
    fn indices(&self) -> Self::IndexIter {
        let extents = self.0;
        multi_product::MultiProduct::new(extents)
    }
}
/// Broadcasting two dynamic shapes of the same rank succeeds only when they are
/// identical. On mismatch, both operands are returned inside the error.
impl<const DIM: usize> Broadcast for SDynamic<DIM> {
    type Shape = SDynamic<DIM>;

    #[inline]
    fn broadcast(self, rhs: Self) -> Result<Self, IncompatibleShapes<Self>> {
        if self != rhs {
            return Err(IncompatibleShapes {
                left: self,
                right: rhs,
            });
        }
        Ok(self)
    }
}
/// Broadcasting a dynamic shape against `S0` always succeeds and returns the dynamic
/// shape unchanged. `S0` is defined elsewhere; it is presumably the rank-0 (scalar)
/// shape.
impl<const DIM: usize> Broadcast<S0> for SDynamic<DIM> {
    type Shape = SDynamic<DIM>;

    #[inline]
    fn broadcast(self, _scalar: S0) -> Result<Self, IncompatibleShapes<Self, S0>> {
        // The `S0` operand contributes no axes, so `self` already is the result.
        Ok(self)
    }
}
/// Mirror of the `SDynamic`/`S0` case: `S0` broadcast against a dynamic shape adopts
/// that shape. This direction never reports an error.
impl<const DIM: usize> Broadcast<SDynamic<DIM>> for S0 {
    type Shape = SDynamic<DIM>;

    #[inline]
    fn broadcast(
        self,
        rhs: SDynamic<DIM>,
    ) -> Result<Self::Shape, IncompatibleShapes<S0, SDynamic<DIM>>> {
        // `self` carries no extents, so the right-hand shape is the result as-is.
        Ok(rhs)
    }
}
impl<const DIM: usize> std::fmt::Display for SDynamic<DIM> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "SDynamic({:?})", self.0)
}
}
/// Generates `Concat` impls that join an `SDynamic<N>` with an `SDynamic<M>` into an
/// `SDynamic<{N + M}>`.
///
/// Accepted forms (arms are tried top to bottom):
/// - `impl_add_dim!(N + M)`: emits the single impl for the rank pair `(N, M)`.
/// - `impl_add_dim!([N1, N2, ...] + M)`: one `(Ni, M)` expansion per left rank.
/// - `impl_add_dim!(X + [M1, M2, ...])`: one `(X, Mj)` expansion per right rank.
///   `X` is captured as a raw token tree, so it may itself be a bracketed list.
///   Lists on both sides therefore expand to the full cross product.
macro_rules! impl_add_dim {
// Base arm: a single (left rank, right rank) pair given as integer literals.
// A bracketed list cannot start a `literal` fragment, so list inputs fall through
// to the later arms.
($n:literal + $m:literal) => {
impl Concat<SDynamic<$m>> for SDynamic<$n> {
// The braces are required for an arithmetic expression in const-generic
// argument position; both operands are concrete literals here.
type Shape = SDynamic<{$n + $m}>;
// Joins the two extent arrays with `concat_arrays!`, which is defined elsewhere.
// NOTE(review): presumably `self`'s extents come first, then `rhs`'s.
// Confirm against that macro.
fn concat(self: SDynamic<$n>, rhs: SDynamic<$m>) -> Self::Shape {
SDynamic(concat_arrays!(self.0, rhs.0))
}
// Same join, applied to bare index arrays instead of shapes.
// `IndexOf` is a project alias defined elsewhere; it is presumably `[usize; N + M]`
// for an `SDynamic` shape.
fn concat_indices(left: [usize; $n], rhs: [usize; $m]) -> IndexOf<Self::Shape> {
concat_arrays!(left, rhs)
}
}
};
// List on the left, single literal on the right: recurse into the base arm once
// per left rank.
([$($n:literal),*] + $m:literal) => {
$(impl_add_dim!($n + $m);)*
};
// List on the right: forward the left side untouched as a token tree (a literal or
// a list) and recurse once per right rank.
($ns:tt + [$($m:literal),*]) => {
$(impl_add_dim!($ns + $m);)*
}
}
// Emit `Concat` impls for every rank pair with both ranks in 0..=6.
// That is 7 × 7 = 49 impls, with result ranks from 0 to 12.
impl_add_dim!([0, 1, 2, 3, 4, 5, 6] + [0, 1, 2, 3, 4, 5, 6]);