use crate::dimutils;
/// A type whose number of tensor axes (its rank) is known at compile time.
pub trait TensorDimension {
    /// Number of indices needed to address a single element.
    const DIM: usize;

    /// Whether the type has rank zero; defaults to `Self::DIM == 0`.
    const ISSCALAR: bool = Self::DIM == 0;

    /// Reports the value of [`TensorDimension::ISSCALAR`] for this type.
    fn is_scalar(&self) -> bool {
        <Self as TensorDimension>::ISSCALAR
    }
}
/// A tensor that knows, at runtime, its extent along each of `DIM` axes.
pub trait TensorSize<const DIM: usize>: TensorDimension {
    /// `true` when the const parameter `DIM` agrees with
    /// [`TensorDimension::DIM`] for the implementing type.
    const VALIDDIMS: bool = DIM == Self::DIM;

    /// Returns the extent of every axis.
    fn size(&self) -> [usize; DIM];

    /// Returns `true` when each component of `index` is strictly smaller than
    /// the extent of the same axis reported by [`TensorSize::size`].
    fn inbounds(&self, index: [usize; DIM]) -> bool {
        let extents = self.size();
        (0..DIM).all(|axis| index[axis] < extents[axis])
    }
}
pub trait Broadcastable<const DIM : usize> : TensorSize<DIM> + Sized {
type Element;
fn bget(&self, index:[usize;DIM]) -> Option<Self::Element>;
fn mod_bget(&self,index:[isize;DIM]) -> Self::Element {
self.bget(dimutils::modular_index(index,self.size())).expect("Broken contract in Broadcast impl. Modular access out of bounds.")
}
fn feedto(&self,receiver :&mut impl BroadcastReceiver<DIM,Element=Self::Element>) -> Option<()>{
receiver.receive(self)
}
fn lazy_updim<const NEWDIM : usize>(&self, size : [usize;NEWDIM] ) -> LazyUpdim<Self,DIM,NEWDIM>
{
assert!(Self::VALIDDIMS, "Invalid dimensions entered in Trait implementation.");
assert!(NEWDIM >= DIM, "Updimmed tensor cannot have fewer indices than the initial one.");
LazyUpdim {size,reference:&self}
}
fn reshaped<const NEWDIM : usize>(self,size: [usize;NEWDIM]) -> ReShaped<Self,NEWDIM,DIM> {
ReShaped {underlying : self,size}
}
fn mapindex<F:Fn([usize;M],[usize;M]) -> [usize;DIM],const M : usize>(&self,indexclosure:F,sizeclosure : impl Fn([usize;DIM]) -> [usize;M] )
-> MapIndex<Self,F,DIM,M>
{
MapIndex {reference: self,indexclosure,size:sizeclosure(self.size())}
}
fn bmap<T,F :Fn(Self::Element) -> T>(&self,foo : F) -> BMap<T,Self,F,DIM>{
assert!(Self::VALIDDIMS, "Invalid dimensions entered in Trait implementation.");
BMap {reference:self,closure : foo}
}
fn bcloned(self) -> BCloned<Self,DIM> {
BCloned {underlying:self}
}
fn broadcast2<'a,'b,T : Broadcastable<N>,const N : usize >(&'a self,foo: &'b T)
-> Broadcast2<LazyUpdim<'a,Self,DIM,{dimutils::cmax(DIM,N)}>,LazyUpdim<'b,T,N,{dimutils::cmax(DIM,N)}>,{dimutils::cmax(DIM,N)}>
{
let commondims : [usize;dimutils::cmax(DIM,N)] = dimutils::commondims(self.size(),foo.size()).expect("F error handling");
Broadcast2 {first : self.lazy_updim(commondims),second : foo.lazy_updim(commondims)}
}
fn bc_iter(&self) -> BroadcastIterator<Self,DIM> {
BroadcastIterator {reference:self,state : dimutils::ndim_iterator(self.size())}
}
}
/// A mutable `DIM`-dimensional tensor.
///
/// It can be written element by element, or refilled from a
/// [`Broadcastable`] source.
pub trait BroadcastReceiver<const DIM: usize>: TensorSize<DIM> + Sized {
    /// The element type stored by the receiver.
    type Element;

    /// Returns a mutable reference to the element at `index`, or `None` if
    /// there is no such element.
    fn bget_mut<'a>(&'a mut self, index: [usize; DIM]) -> Option<&'a mut Self::Element>;

    /// Fills `self` from `broadcast`; `None` signals failure.
    fn receive(
        &mut self,
        broadcast: &impl Broadcastable<DIM, Element = Self::Element>,
    ) -> Option<()>;

    /// Returns an iterator of mutable references to the elements.
    ///
    /// Indices are visited in the order produced by
    /// `dimutils::ndim_iterator(self.size())`.
    ///
    /// # Safety
    ///
    /// The iterator calls `bget_mut` through a raw pointer for every index. It
    /// hands each result out with the full lifetime `'a`, so several `&mut`
    /// references into `self` can be alive at the same time.
    ///
    /// The caller must guarantee that `bget_mut` maps distinct indices to
    /// distinct, non-overlapping elements for the whole iteration. Otherwise
    /// aliasing `&mut` references are created, which is undefined behaviour.
    ///
    /// NOTE(review): every `next` call also re-creates a `&mut Self` from the
    /// pointer while earlier element references may be live. Confirm with
    /// `cargo miri test` that this is accepted.
    unsafe fn bc_iter_mut<'a>(&'a mut self) -> BroadcastIterMut<'a, Self, DIM> {
        // Read the size before `self` is turned into a raw pointer.
        let state = dimutils::ndim_iterator(self.size());
        BroadcastIterMut {
            reference: std::ptr::NonNull::from(self),
            state,
            _marker: std::marker::PhantomData,
        }
    }
}
/// Iterator over the elements of a [`Broadcastable`]; produced by `bc_iter`.
pub struct BroadcastIterator<'a, T: Broadcastable<DIM>, const DIM: usize> {
    /// The tensor being read.
    reference: &'a T,
    /// Source of the next index to read.
    state: dimutils::NDimIteratorHelper<DIM>,
}

impl<'a, T, const DIM: usize> Iterator for BroadcastIterator<'a, T, DIM>
where
    T: Broadcastable<DIM>,
{
    type Item = T::Element;

    /// Reads the element at the next index.
    ///
    /// Returns `None` once the indices are exhausted, and also when `bget`
    /// yields `None` for an index.
    fn next(&mut self) -> Option<Self::Item> {
        let tensor = self.reference;
        self.state.next().and_then(|position| tensor.bget(position))
    }

    /// Forwards the estimate of the underlying index iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.state.size_hint()
    }
}
/// Iterator of mutable element references into a [`BroadcastReceiver`].
/// Created by the `unsafe` method `BroadcastReceiver::bc_iter_mut`.
pub struct BroadcastIterMut<'a,T : BroadcastReceiver<DIM>,const DIM : usize> {
/// Raw pointer to the receiver. A raw pointer is presumably used so that
/// `next` can hand out references with the full lifetime `'a` on every call.
reference : std::ptr::NonNull<T>,
/// Source of the next index to visit.
state :dimutils::NDimIteratorHelper<DIM>,
/// Ties the iterator to the `&'a mut` borrow of the receiver taken by
/// `bc_iter_mut`.
_marker: std::marker::PhantomData<&'a mut T>,
}
impl<'a,T : BroadcastReceiver<DIM>,const DIM : usize> Iterator for BroadcastIterMut<'a,T,DIM> {
type Item = &'a mut T::Element;
/// Advances to the next index and returns `bget_mut` for it.
/// Returns `None` when the indices are exhausted or `bget_mut` returns `None`.
fn next(&mut self) -> Option< &'a mut T::Element> {
let index = self.state.next()?;
// SAFETY: relies on the caller contract of `bc_iter_mut`.
// - `reference` was created from a `&'a mut T`, and the iterator type carries
//   that `'a`.
// - All yielded references share lifetime `'a`, so `bget_mut` must return
//   non-overlapping elements for distinct indices.
// NOTE(review): a fresh `&'a mut T` is created here while references yielded
// by earlier calls may still be live. This is likely rejected by Stacked
// Borrows — confirm with `cargo miri test`.
let castreference = unsafe {self.reference.as_mut::<'a>()};
castreference.bget_mut(index)
}
/// Forwards the estimate of the underlying index iterator.
fn size_hint(&self) -> (usize,Option<usize>){
self.state.size_hint()
}
}
/// A view of an `M`-dimensional tensor reinterpreted with `N`-dimensional
/// extents. Created by `Broadcastable::reshaped`.
///
/// Elements are matched by linear position:
/// - an index into the view is linearized against the view's extents
///   (`dimutils::linearize_index`);
/// - the result is delinearized against the underlying extents
///   (`dimutils::delinearize_index`).
///
/// NOTE(review): the product of the new extents is never checked against the
/// underlying element count. What a mismatched count yields depends on
/// `dimutils::delinearize_index` — confirm.
pub struct ReShaped<T: Broadcastable<M>, const N: usize, const M: usize> {
    /// The wrapped tensor, owned by the view.
    underlying: T,
    /// Extents of the reshaped view.
    size: [usize; N],
}

impl<T: Broadcastable<M>, const N: usize, const M: usize> TensorDimension for ReShaped<T, N, M> {
    // The view is addressed with `N` indices: it implements `TensorSize<N>`
    // and `Broadcastable<N>`. Its rank must therefore be `N`, not the
    // underlying `T::DIM` (== `M`). Otherwise `VALIDDIMS` is false for every
    // rank-changing reshape, and `bmap`/`lazy_updim` on it panic.
    const DIM: usize = N;
}

impl<T: Broadcastable<M>, const N: usize, const M: usize> TensorSize<N> for ReShaped<T, N, M> {
    /// The extents given to `reshaped`.
    fn size(&self) -> [usize; N] {
        self.size
    }
}

impl<T: Broadcastable<M>, const N: usize, const M: usize> Broadcastable<N> for ReShaped<T, N, M> {
    type Element = T::Element;

    /// Returns `None` when `index` lies outside the view's extents.
    /// Otherwise reads the underlying tensor at the same linear position.
    fn bget(&self, index: [usize; N]) -> Option<T::Element> {
        if !self.inbounds(index) {
            return None;
        }
        let linearindex = dimutils::linearize_index(index, self.size());
        let innerindex = dimutils::delinearize_index(linearindex, self.underlying.size());
        self.underlying.bget(innerindex)
    }
}

impl<'a, T: Broadcastable<M>, const N: usize, const M: usize> IntoIterator for &'a ReShaped<T, N, M> {
    type Item = T::Element;
    type IntoIter = BroadcastIterator<'a, ReShaped<T, N, M>, N>;

    fn into_iter(self) -> Self::IntoIter {
        self.bc_iter()
    }
}
/// Adaptor that clones the referenced element on every access. Created by
/// `Broadcastable::bcloned`.
///
/// It is `Broadcastable` only when the wrapped tensor yields `&E` with
/// `E: Clone`; it then yields owned `E` values.
pub struct BCloned<T: Broadcastable<N>, const N: usize> {
    /// The wrapped tensor, owned by the adaptor.
    underlying: T,
}

impl<T: Broadcastable<N>, const N: usize> TensorDimension for BCloned<T, N> {
    const DIM: usize = T::DIM;
}

impl<T: Broadcastable<N>, const N: usize> TensorSize<N> for BCloned<T, N> {
    /// Same extents as the wrapped tensor.
    fn size(&self) -> [usize; N] {
        self.underlying.size()
    }
}

impl<'b, T: Broadcastable<N, Element = &'b E>, E: 'b + Clone, const N: usize> Broadcastable<N>
    for BCloned<T, N>
{
    type Element = E;

    /// Clones the element the wrapped tensor references at `index`.
    /// Returns `None` if the wrapped tensor returns `None`.
    fn bget(&self, index: [usize; N]) -> Option<Self::Element> {
        self.underlying.bget(index).cloned()
    }
}

impl<'a, 'b, T: Broadcastable<N, Element = &'b E>, E: Clone + 'a + 'b, const N: usize> IntoIterator
    for &'a BCloned<T, N>
{
    type Item = E;
    type IntoIter = BroadcastIterator<'a, BCloned<T, N>, N>;

    fn into_iter(self) -> Self::IntoIter {
        self.bc_iter()
    }
}
/// Lazy `M`-dimensional view of an `N`-dimensional tensor whose indices are
/// translated by a closure. Created by `Broadcastable::mapindex`.
pub struct MapIndex<'a, T: Broadcastable<N>, F: Fn([usize; M], [usize; M]) -> [usize; N], const N: usize, const M: usize> {
    /// The tensor being viewed.
    reference: &'a T,
    /// Maps `(index into the view, extents of the view)` to an index into
    /// `reference`.
    indexclosure: F,
    /// Extents of the view.
    size: [usize; M],
}

impl<'a, T: Broadcastable<N>, F: Fn([usize; M], [usize; M]) -> [usize; N], const N: usize, const M: usize>
    TensorDimension for MapIndex<'a, T, F, N, M>
{
    const DIM: usize = M;
}

impl<'a, T: Broadcastable<N>, F: Fn([usize; M], [usize; M]) -> [usize; N], const N: usize, const M: usize>
    TensorSize<M> for MapIndex<'a, T, F, N, M>
{
    /// The extents computed by the size closure passed to `mapindex`.
    fn size(&self) -> [usize; M] {
        self.size
    }
}

impl<'a, T: Broadcastable<N>, F: Fn([usize; M], [usize; M]) -> [usize; N], const N: usize, const M: usize>
    Broadcastable<M> for MapIndex<'a, T, F, N, M>
{
    type Element = T::Element;

    /// Returns `None` when `index` is outside the view.
    /// Otherwise reads `reference` at `indexclosure(index, size)`.
    fn bget(&self, index: [usize; M]) -> Option<T::Element> {
        if !self.inbounds(index) {
            return None;
        }
        let source_index: [usize; N] = (self.indexclosure)(index, self.size);
        self.reference.bget(source_index)
    }
}

impl<'b, 'a, T: Broadcastable<N>, F: Fn([usize; M], [usize; M]) -> [usize; N], const N: usize, const M: usize>
    IntoIterator for &'b MapIndex<'a, T, F, N, M>
{
    type Item = T::Element;
    type IntoIter = BroadcastIterator<'b, MapIndex<'a, T, F, N, M>, M>;

    fn into_iter(self) -> Self::IntoIter {
        self.bc_iter()
    }
}

impl<'a, T: Broadcastable<N>, F: Fn([usize; M], [usize; M]) -> [usize; N], const N: usize, const M: usize>
    MapIndex<'a, T, F, N, M>
{
    /// Iterates over the view's elements in `bc_iter` order.
    pub fn iter(&self) -> BroadcastIterator<MapIndex<'a, T, F, N, M>, M> {
        // Same entry point as the sibling adaptors' `iter` methods.
        self.bc_iter()
    }
}
/// Elementwise pairing of two `N`-dimensional tensors.
///
/// Created by `Broadcastable::broadcast2` after both operands have been lifted
/// to a common shape.
pub struct Broadcast2<A: Broadcastable<N>, B: Broadcastable<N>, const N: usize> {
    /// Left operand; its extents are reported as the extents of the pair.
    first: A,
    /// Right operand.
    second: B,
}

impl<A, B, const N: usize> TensorDimension for Broadcast2<A, B, N>
where
    A: Broadcastable<N>,
    B: Broadcastable<N>,
{
    const DIM: usize = N;
}

impl<A, B, const N: usize> TensorSize<N> for Broadcast2<A, B, N>
where
    A: Broadcastable<N>,
    B: Broadcastable<N>,
{
    /// Reports the extents of the first operand only.
    fn size(&self) -> [usize; N] {
        let Self { first, .. } = self;
        first.size()
    }
}

impl<A, B, const N: usize> Broadcastable<N> for Broadcast2<A, B, N>
where
    A: Broadcastable<N>,
    B: Broadcastable<N>,
{
    type Element = (A::Element, B::Element);

    /// Reads both operands at `index`.
    ///
    /// Returns `None` if either read fails. The second operand is not read
    /// when the first read fails.
    fn bget(&self, index: [usize; N]) -> Option<Self::Element> {
        let left = self.first.bget(index)?;
        let right = self.second.bget(index)?;
        Some((left, right))
    }
}

impl<'b, A, B, const N: usize> IntoIterator for &'b Broadcast2<A, B, N>
where
    A: Broadcastable<N>,
    B: Broadcastable<N>,
{
    type Item = (A::Element, B::Element);
    type IntoIter = BroadcastIterator<'b, Broadcast2<A, B, N>, N>;

    fn into_iter(self) -> Self::IntoIter {
        self.bc_iter()
    }
}

impl<A, B, const N: usize> Broadcast2<A, B, N>
where
    A: Broadcastable<N>,
    B: Broadcastable<N>,
{
    /// Iterates over the element pairs in `bc_iter` order.
    pub fn iter(&self) -> BroadcastIterator<Broadcast2<A, B, N>, N> {
        self.bc_iter()
    }
}
/// Lazy `DIM`-dimensional view of an `OLDDIM`-dimensional tensor with new
/// extents. Created by `Broadcastable::lazy_updim`.
///
/// - Axis `i < OLDDIM` of the view reads axis `i` of the underlying tensor.
/// - Along an underlying axis whose extent is at most 1, every view position
///   reads position 0; this is what broadcasting needs.
/// - View axes `OLDDIM..DIM` are only bounds-checked.
pub struct LazyUpdim<'a, T: Broadcastable<OLDDIM>, const OLDDIM: usize, const DIM: usize> {
    /// Extents of the view.
    size: [usize; DIM],
    /// The tensor being viewed.
    reference: &'a T,
}

impl<'a, T: Broadcastable<OLDDIM>, const OLDDIM: usize, const DIM: usize> TensorDimension
    for LazyUpdim<'a, T, OLDDIM, DIM>
{
    const DIM: usize = DIM;
}

impl<'a, T: Broadcastable<OLDDIM>, const OLDDIM: usize, const DIM: usize> TensorSize<DIM>
    for LazyUpdim<'a, T, OLDDIM, DIM>
{
    /// The extents given to `lazy_updim`.
    fn size(&self) -> [usize; DIM] {
        self.size
    }
}

impl<'a, T: Broadcastable<OLDDIM>, const OLDDIM: usize, const DIM: usize> Broadcastable<DIM>
    for LazyUpdim<'a, T, OLDDIM, DIM>
{
    type Element = T::Element;

    /// Returns `None` when `index` is outside the view. Otherwise reads the
    /// underlying tensor, mapping broadcast axes to position 0.
    fn bget(&self, index: [usize; DIM]) -> Option<Self::Element> {
        assert!(DIM >= OLDDIM);
        if !self.inbounds(index) {
            return None;
        }
        // Decide per axis from the *underlying* extents. On a broadcast axis
        // the view is wider than the size-1 operand. Testing the view's
        // extents would pass an out-of-range index through, and the read
        // would fail with `None`.
        let oldsize: [usize; OLDDIM] = self.reference.size();
        let newindex: [usize; OLDDIM] =
            array_init::array_init(|i| if oldsize[i] > 1 { index[i] } else { 0 });
        self.reference.bget(newindex)
    }
}

impl<'b, 'a, T: Broadcastable<OLDDIM>, const OLDDIM: usize, const DIM: usize> IntoIterator
    for &'b LazyUpdim<'a, T, OLDDIM, DIM>
{
    type Item = T::Element;
    type IntoIter = BroadcastIterator<'b, LazyUpdim<'a, T, OLDDIM, DIM>, DIM>;

    fn into_iter(self) -> Self::IntoIter {
        self.bc_iter()
    }
}

impl<'a, T: Broadcastable<OLDDIM>, const OLDDIM: usize, const DIM: usize> LazyUpdim<'a, T, OLDDIM, DIM> {
    /// Iterates over the view's elements in `bc_iter` order.
    pub fn iter(&self) -> BroadcastIterator<LazyUpdim<'a, T, OLDDIM, DIM>, DIM> {
        self.bc_iter()
    }
}
/// Lazy elementwise map of a tensor through `closure`. Created by
/// `Broadcastable::bmap`.
pub struct BMap<'a, R, T: Broadcastable<DIM>, F: Fn(T::Element) -> R, const DIM: usize> {
    /// The tensor being mapped.
    reference: &'a T,
    /// Applied to every element read from `reference`.
    closure: F,
}

impl<'a, R, T, F, const DIM: usize> TensorDimension for BMap<'a, R, T, F, DIM>
where
    T: Broadcastable<DIM>,
    F: Fn(T::Element) -> R,
{
    const DIM: usize = DIM;
}

impl<'a, R, T, F, const DIM: usize> TensorSize<DIM> for BMap<'a, R, T, F, DIM>
where
    T: Broadcastable<DIM>,
    F: Fn(T::Element) -> R,
{
    /// Same extents as the mapped tensor.
    fn size(&self) -> [usize; DIM] {
        <T as TensorSize<DIM>>::size(self.reference)
    }
}

impl<'a, R, T, F, const DIM: usize> Broadcastable<DIM> for BMap<'a, R, T, F, DIM>
where
    T: Broadcastable<DIM>,
    F: Fn(T::Element) -> R,
{
    type Element = R;

    /// Reads `reference` at `index` and passes any element through `closure`.
    fn bget(&self, index: [usize; DIM]) -> Option<Self::Element> {
        let transform = &self.closure;
        self.reference.bget(index).map(transform)
    }
}

impl<'b, 'a, R, T, F, const DIM: usize> IntoIterator for &'b BMap<'a, R, T, F, DIM>
where
    T: Broadcastable<DIM>,
    F: Fn(T::Element) -> R,
{
    type Item = R;
    type IntoIter = BroadcastIterator<'b, BMap<'a, R, T, F, DIM>, DIM>;

    fn into_iter(self) -> Self::IntoIter {
        self.bc_iter()
    }
}
impl<T> TensorDimension for Vec<T> {
    /// A vector is addressed with a single index.
    const DIM: usize = 1;
}

impl<T> TensorSize<1> for Vec<T> {
    /// The single extent is the vector's length.
    fn size(&self) -> [usize; 1] {
        let length = self.len();
        [length]
    }
}

impl<'a, T> TensorDimension for &'a Vec<T> {
    /// A borrowed vector is addressed with a single index.
    const DIM: usize = 1;
}

impl<'a, T> TensorSize<1> for &'a Vec<T> {
    /// The single extent is the borrowed vector's length.
    fn size(&self) -> [usize; 1] {
        let length = self.len();
        [length]
    }
}

impl<'a, T> TensorDimension for &'a mut Vec<T> {
    /// A mutably borrowed vector is addressed with a single index.
    const DIM: usize = 1;
}

impl<'a, T> TensorSize<1> for &'a mut Vec<T> {
    /// The single extent is the borrowed vector's length.
    fn size(&self) -> [usize; 1] {
        let length = self.len();
        [length]
    }
}
impl<'a, T> Broadcastable<1> for &'a Vec<T> {
    type Element = &'a T;

    /// Borrows the element at position `index`, or `None` past the end.
    fn bget(&self, [index]: [usize; 1]) -> Option<&'a T> {
        self.get(index)
    }
}

impl<T: Copy> Broadcastable<1> for Vec<T> {
    type Element = T;

    /// Copies the element at position `index`, or `None` past the end.
    fn bget(&self, [index]: [usize; 1]) -> Option<T> {
        // `T: Copy`, so `copied` states the (cheap, bitwise) intent directly.
        self.get(index).copied()
    }
}
impl<'a, T> BroadcastReceiver<1> for &'a mut Vec<T> {
    type Element = T;

    /// Returns a mutable reference to the element at `index`, or `None` past
    /// the end.
    fn bget_mut(&mut self, [index]: [usize; 1]) -> Option<&mut T> {
        self.get_mut(index)
    }

    /// Replaces the vector's contents with the `broadcast.size()[0]` elements
    /// of `broadcast`, in index order.
    ///
    /// Returns `None` as soon as `broadcast.bget` fails. The vector then holds
    /// only the elements read before the failure.
    fn receive(&mut self, broadcast: &impl Broadcastable<1, Element = Self::Element>) -> Option<()> {
        self.clear();
        let [len] = broadcast.size();
        // The final length is known, so allocate once instead of growing
        // repeatedly while pushing.
        self.reserve(len);
        for i in 0..len {
            self.push(broadcast.bget([i])?);
        }
        Some(())
    }
}