tch 0.0.1

PyTorch wrappers for Rust
use crate::tensor::Tensor;

/// An ordered stack of layers applied one after another.
///
/// Each layer's output is fed as the input to the next layer. Layers are
/// stored as boxed trait objects so heterogeneous module types can be mixed.
pub struct Sequential {
    /// Layers in application order (index 0 is applied first).
    layers: Vec<Box<dyn super::module::Module>>,
}

impl Sequential {
    /// Creates an empty sequence with no layers.
    ///
    /// Layers are appended with the `add` / `add_fn` builder methods.
    pub fn new() -> Sequential {
        Sequential { layers: vec![] }
    }
}

impl Default for Sequential {
    /// Returns an empty sequence; equivalent to [`Sequential::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl super::module::Module for Sequential {
    /// Runs `xs` through every layer in order and returns the final output.
    ///
    /// With no layers, the result is `xs.shallow_clone()`.
    fn forward(&self, xs: &Tensor) -> Tensor {
        match self.layers.split_first() {
            None => xs.shallow_clone(),
            Some((head, tail)) => {
                // The first layer consumes the borrowed input; every later
                // layer consumes the previous layer's owned output.
                let mut acc = head.forward(xs);
                for layer in tail {
                    acc = layer.forward(&acc);
                }
                acc
            }
        }
    }
}

impl Sequential {
    /// Appends `layer` to the end of the sequence and returns the sequence.
    ///
    /// Takes `self` by value (builder style), so the returned value must be
    /// used. Discarding it drops every layer added so far.
    #[must_use = "`add` consumes the sequence and returns it; the result must be used"]
    pub fn add<M: super::module::Module + 'static>(mut self, layer: M) -> Self {
        self.layers.push(Box::new(layer));
        self
    }

    /// Appends the closure `f` as a layer by wrapping it in `Func`.
    ///
    /// Same builder semantics as [`Sequential::add`].
    #[must_use = "`add_fn` consumes the sequence and returns it; the result must be used"]
    pub fn add_fn<F>(self, f: F) -> Self
    where
        F: 'static,
        F: Fn(&Tensor) -> Tensor,
    {
        self.add(super::func::Func::new(f))
    }
}

/// An ordered stack of layers whose forward pass also takes a `train` flag.
///
/// Each layer's output is fed as the input to the next layer, and the same
/// `train` flag is passed to every layer.
pub struct SequentialT {
    /// Layers in application order (index 0 is applied first).
    layers: Vec<Box<dyn super::module::ModuleT>>,
}

impl SequentialT {
    /// Creates an empty sequence with no layers.
    ///
    /// Layers are appended with the `add` / `add_fn` / `add_fn_t` builder methods.
    pub fn new() -> SequentialT {
        SequentialT { layers: vec![] }
    }
}

impl Default for SequentialT {
    /// Returns an empty sequence; equivalent to [`SequentialT::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl super::module::ModuleT for SequentialT {
    /// Runs `xs` through every layer in order, forwarding `train` unchanged
    /// to each one, and returns the final output.
    ///
    /// With no layers, the result is `xs.shallow_clone()`.
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
        match self.layers.split_first() {
            None => xs.shallow_clone(),
            Some((head, tail)) => {
                // The first layer consumes the borrowed input; every later
                // layer consumes the previous layer's owned output.
                let mut acc = head.forward_t(xs, train);
                for layer in tail {
                    acc = layer.forward_t(&acc, train);
                }
                acc
            }
        }
    }
}

impl SequentialT {
    /// Appends `layer` to the end of the sequence and returns the sequence.
    ///
    /// Takes `self` by value (builder style), so the returned value must be
    /// used. Discarding it drops every layer added so far.
    #[must_use = "`add` consumes the sequence and returns it; the result must be used"]
    pub fn add<M: super::module::ModuleT + 'static>(mut self, layer: M) -> Self {
        self.layers.push(Box::new(layer));
        self
    }

    /// Appends the closure `f`, which ignores the `train` flag, by wrapping it
    /// in `Func`.
    ///
    /// Same builder semantics as [`SequentialT::add`].
    #[must_use = "`add_fn` consumes the sequence and returns it; the result must be used"]
    pub fn add_fn<F>(self, f: F) -> Self
    where
        F: 'static,
        F: Fn(&Tensor) -> Tensor,
    {
        self.add(super::func::Func::new(f))
    }

    /// Appends the closure `f`, which receives the `train` flag, by wrapping
    /// it in `FuncT`.
    ///
    /// Same builder semantics as [`SequentialT::add`].
    #[must_use = "`add_fn_t` consumes the sequence and returns it; the result must be used"]
    pub fn add_fn_t<F>(self, f: F) -> Self
    where
        F: 'static,
        F: Fn(&Tensor, bool) -> Tensor,
    {
        self.add(super::func::FuncT::new(f))
    }
}