gtensor 1.0.0

Reverse-mode autodifferentiation of computational graphs with tensors and more for machine learning.
Documentation

use super::dataset::Dataset;
use crate::tensor::shape::Shape;
use crate::tensor::slice::TensorSlice;

impl Dataset {
    /// Returns an iterator that yields the dataset one batch-of-one at a time.
    ///
    /// This simply forwards to [`Dataset::iter_batched`] with a batch size of 1.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Dataset::iter_batched`].
    pub fn iter(&self) -> DatasetIter<'_> {
        self.iter_batched(1)
    }

    /// Returns an iterator that yields `(features, labels)` slices in batches.
    ///
    /// The batch shapes are copies of the dataset's feature and label shapes
    /// with index 0 overwritten by `batch_size`. Each step of the iterator
    /// hands out `Shape::len()` consecutive values from the feature buffer
    /// (and likewise for the label buffer).
    ///
    /// Only the feature buffer is validated here. The label buffer is sliced
    /// lazily by the iterator, so a label buffer that is too short for the
    /// requested batches panics during iteration rather than in this call.
    ///
    /// # Panics
    ///
    /// * If the batched feature shape contains zero elements (for example
    ///   when `batch_size` is 0).
    /// * If the length of the feature buffer is not a multiple of the
    ///   batched feature shape's element count.
    pub fn iter_batched(&self, batch_size: usize) -> DatasetIter<'_> {
        let mut feature_shape = self.feature_shape;
        feature_shape[0] = batch_size;

        let mut label_shape = self.label_shape;
        label_shape[0] = batch_size;

        let feature_len = feature_shape.len();
        let label_len = label_shape.len();

        // Without this guard the divisibility check below would itself panic
        // with an unhelpful "remainder with a divisor of zero" message.
        if feature_len == 0 {
            panic!(
                "Cannot iterate dataset batched; batch size must be non-zero and the batch \
                 shape must not be empty, feature shape: {}",
                feature_shape
            );
        }

        if self.features.len() % feature_len != 0 {
            // The `\` continuations keep source indentation out of the message.
            panic!(
                "Cannot iterate dataset batched; provided batch size is not valid! \
                 (Number of features must be evenly divisible by the batch size), \
                 feature shape: {}, feature vector len: {}",
                feature_shape,
                self.features.len()
            );
        }

        DatasetIter {
            features: &self.features,
            labels: &self.labels,
            feature_shape,
            feature_len,
            label_shape,
            label_len,
            feature_curr: 0,
            label_curr: 0,
        }
    }
}

/// Iterator over a dataset's feature and label buffers, yielding one
/// `(features, labels)` pair of tensor slices per step.
///
/// Created by `Dataset::iter` and `Dataset::iter_batched`.
pub struct DatasetIter<'a> {
    /// Borrowed buffer of feature values that each feature slice is cut from.
    features: &'a Vec<f32>,
    /// Borrowed buffer of label values that each label slice is cut from.
    labels: &'a Vec<f32>,
    /// Shape attached to every yielded feature slice.
    feature_shape: Shape,
    /// Number of `f32` values taken from `features` per step.
    feature_len: usize,
    /// Shape attached to every yielded label slice.
    label_shape: Shape,
    /// Number of `f32` values taken from `labels` per step.
    label_len: usize,
    /// Offset into `features` where the next slice starts.
    feature_curr: usize,
    /// Offset into `labels` where the next slice starts.
    label_curr: usize,
}

impl<'a> Iterator for DatasetIter<'a> {
    /// A `(features, labels)` pair of slices borrowed from the dataset.
    type Item = (TensorSlice<'a>, TensorSlice<'a>);

    /// Yields the next feature/label pair and advances both cursors.
    ///
    /// Returns `None` once the feature cursor has reached the end of the
    /// feature buffer.
    ///
    /// # Panics
    ///
    /// Panics if either slice range extends past the end of its buffer.
    fn next(&mut self) -> Option<Self::Item> {
        // Exhaustion is decided by the feature buffer alone.
        if self.feature_curr >= self.features.len() {
            return None;
        }

        let feature_end = self.feature_curr + self.feature_len;
        let feature_data = &self.features[self.feature_curr..feature_end];

        let label_end = self.label_curr + self.label_len;
        let label_data = &self.labels[self.label_curr..label_end];

        let pair = (
            TensorSlice {
                data: feature_data,
                shape: self.feature_shape,
            },
            TensorSlice {
                data: label_data,
                shape: self.label_shape,
            },
        );

        // Move both cursors to the start of the following batch.
        self.feature_curr = feature_end;
        self.label_curr = label_end;

        Some(pair)
    }
}