gtensor 1.0.0

Reverse-mode automatic differentiation of computational graphs built from tensors (and more), for machine learning.
Documentation

use std::fmt::Display;

use serde::{Serialize, Deserialize};
use anyhow::{Result, anyhow};

/// Tensor shape stored as exactly six axis sizes.
///
/// Shapes built from `usize` arrays through `ToShape` fill the leading axes
/// from the array and pad the remaining axes with `1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Serialize, Deserialize)]
pub struct Shape([usize; 6]);

impl Shape {
    /// Returns the total number of elements described by this shape, i.e. the
    /// product of all six axes.
    pub fn len(&self) -> usize {
        self.0.iter().product()
    }

    /// Reshapes `self` in place to `new`.
    ///
    /// # Panics
    ///
    /// Panics if `new` describes a different number of elements than `self`.
    /// Converting a `usize` array that contains a zero axis via [`ToShape`]
    /// also panics.
    pub fn reshape(&mut self, new: impl ToShape) {
        self.try_reshape(new).unwrap()
    }

    /// Returns this shape with a leading batch axis of size `batch_size`; every
    /// existing axis moves one slot to the right.
    ///
    /// Only six axes fit, so the previous last axis (index 5) is discarded.
    /// Callers should make sure it is `1`; otherwise the element count is no
    /// longer `batch_size * self.len()`. `batch_size` itself is not validated.
    #[must_use]
    pub fn add_batch(mut self, batch_size: usize) -> Self {
        // Shift axes 0..5 into slots 1..6; `copy_within` handles the overlap.
        self.0.copy_within(0..5, 1);
        self.0[0] = batch_size;
        self
    }

    /// Reshapes `self` in place to `new`, provided both shapes describe the
    /// same number of elements.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving `self` unchanged, if `new.len() != self.len()`.
    pub fn try_reshape(&mut self, new: impl ToShape) -> Result<()> {
        let new = new.to_shape();

        if new.len() != self.len() {
            return Err(anyhow!("
                Length of Old Shape does not match length of New Shape! (Old: {}), (New: {})
            ", self, new));
        }

        // Commit only after validation so a failed reshape leaves `self` intact.
        *self = new;
        Ok(())
    }

    /// Returns the first two axes; the remaining axes are not included.
    pub fn as_array2(&self) -> [usize; 2] {
        [self.0[0], self.0[1]]
    }

    /// Returns the first three axes; the remaining axes are not included.
    pub fn as_array3(&self) -> [usize; 3] {
        [self.0[0], self.0[1], self.0[2]]
    }

    /// Returns the first four axes; the remaining axes are not included.
    pub fn as_array4(&self) -> [usize; 4] {
        [self.0[0], self.0[1], self.0[2], self.0[3]]
    }

    /// Returns the first five axes; the last axis is not included.
    pub fn as_array5(&self) -> [usize; 5] {
        [self.0[0], self.0[1], self.0[2], self.0[3], self.0[4]]
    }

    /// Returns all six axes as a plain array.
    pub fn as_array6(&self) -> [usize; 6] {
        self.0
    }
}

/// Formats the shape as a bracketed, comma-separated list of all six axes,
/// e.g. `[2, 3, 1, 1, 1, 1]`.
impl Display for Shape {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let [a0, a1, a2, a3, a4, a5] = self.0;
        write!(f, "[{}, {}, {}, {}, {}, {}]", a0, a1, a2, a3, a4, a5)
    }
}

impl std::ops::Index<usize> for Shape {
    type Output = usize;

    fn index(&self, index: usize) -> &Self::Output {
        &self.0[index]
    }
}

impl std::ops::IndexMut<usize> for Shape {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.0[index]
    }
}

/// Conversion into a [`Shape`].
///
/// Implemented in this file for `Shape` itself (identity) and, through the
/// `impl_shape!` macro, for `[usize; N]` with `N` from 1 to 6.
pub trait ToShape {
    /// Consumes `self` and returns the equivalent [`Shape`].
    fn to_shape(self) -> Shape;
}

/// Identity conversion: a `Shape` is already a shape.
impl ToShape for Shape {
    fn to_shape(self) -> Self {
        self
    }
}

// Array-to-shape conversions for every rank a `Shape` can hold (1 through 6).
macros::impl_shape!(1, 2, 3, 4, 5, 6);

mod macros {
    /// Implements `ToShape` for `[usize; N]` for each listed length `N`.
    ///
    /// The array's values fill the leading axes of the six-axis `Shape`; any
    /// remaining axes are padded with `1`. Lengths above 6 would fail to
    /// compile, because the array is copied into a six-slot buffer.
    ///
    /// # Panics
    ///
    /// The generated `to_shape` panics if any axis in the array is zero.
    macro_rules! impl_shape {
        ($($n:literal),*) => {
            $(
            impl ToShape for [usize; $n] {
                fn to_shape(self) -> Shape {
                    assert!(
                        self.iter().all(|&axis| axis != 0),
                        "Shape Axes must be non-zero!"
                    );

                    // Unused trailing axes default to 1 so `len()` is unaffected.
                    let mut arr = [1usize; 6];
                    arr[..$n].copy_from_slice(&self);
                    Shape(arr)
                }
            }
            )*
        };
    }
    pub(super) use impl_shape;
}