// tch 0.0.1
//
// PyTorch wrappers for Rust.
use crate::tensor::Tensor;

/// Configuration for a [`BatchNorm2D`] layer.
///
/// The fields are public so callers can override individual settings,
/// for example with struct-update syntax on top of `Default::default()`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchNorm2DConfig {
    /// Forwarded as the `cudnn_enabled` argument of `Tensor::batch_norm`.
    pub cudnn_enabled: bool,
    /// Forwarded as the `eps` argument of `Tensor::batch_norm`; in PyTorch's
    /// batch normalization this is the term added to the variance for
    /// numerical stability.
    pub eps: f64,
    /// Forwarded as the `momentum` argument of `Tensor::batch_norm`; in
    /// PyTorch's batch normalization this weights the update of the running
    /// statistics.
    pub momentum: f64,
}

impl Default for BatchNorm2DConfig {
    fn default() -> Self {
        BatchNorm2DConfig {
            cudnn_enabled: true,
            eps: 1e-5,
            momentum: 0.1,
        }
    }
}

/// A 2D batch-normalization layer.
///
/// Holds the configuration and the four tensors that `forward_t` hands to
/// `Tensor::batch_norm`.
pub struct BatchNorm2D {
    /// Settings (`momentum`, `eps`, `cudnn_enabled`) forwarded on every call.
    config: BatchNorm2DConfig,
    /// Running mean, created in `new` via `zeros_no_train` under the name
    /// `"running_mean"`.
    running_mean: Tensor,
    /// Running variance, created in `new` via `ones_no_train` under the name
    /// `"running_var"`.
    running_var: Tensor,
    /// Weight tensor, created in `new` via `uniform` under the name `"weight"`.
    ws: Tensor,
    /// Bias tensor, created in `new` via `zeros` under the name `"bias"`.
    bs: Tensor,
}

impl BatchNorm2D {
    /// Builds a batch-normalization layer whose tensors all have shape
    /// `[out_dim]`.
    ///
    /// The tensors are created through `vs` in this order and under these
    /// names: `"running_mean"` (`zeros_no_train`), `"running_var"`
    /// (`ones_no_train`), `"weight"` (`uniform` with bounds `0.0` and `1.0`)
    /// and `"bias"` (`zeros`).
    pub fn new(
        vs: &super::var_store::Path,
        out_dim: i64,
        config: BatchNorm2DConfig,
    ) -> BatchNorm2D {
        let shape = [out_dim];

        // Keep the creation order stable: statistics first, then weight and bias.
        let running_mean = vs.zeros_no_train("running_mean", &shape);
        let running_var = vs.ones_no_train("running_var", &shape);
        let ws = vs.uniform("weight", &shape, 0.0, 1.0);
        let bs = vs.zeros("bias", &shape);

        Self {
            config,
            running_mean,
            running_var,
            ws,
            bs,
        }
    }
}

impl super::module::ModuleT for BatchNorm2D {
    /// Applies batch normalization to `xs` by calling `Tensor::batch_norm`
    /// with this layer's weight, bias, running statistics and configuration.
    ///
    /// `train` is passed through unchanged to `Tensor::batch_norm`.
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
        let cfg = &self.config;

        let weight = Some(&self.ws);
        let bias = Some(&self.bs);
        let running_mean = Some(&self.running_mean);
        let running_var = Some(&self.running_var);

        Tensor::batch_norm(
            xs,
            weight,
            bias,
            running_mean,
            running_var,
            train,
            cfg.momentum,
            cfg.eps,
            cfg.cudnn_enabled,
        )
    }
}