tch 0.0.1

PyTorch wrappers for rust
use crate::tensor::Tensor;

pub struct Conv2DConfig {
    pub stride: i64,
    pub padding: i64,
    pub dilation: i64,
    pub groups: i64,
    pub bias: bool,
}

impl Default for Conv2DConfig {
    fn default() -> Self {
        Conv2DConfig {
            stride: 1,
            padding: 0,
            dilation: 1,
            groups: 1,
            bias: true,
        }
    }
}

pub struct Conv2D {
    ws: Tensor,
    bs: Tensor,
    stride: [i64; 2],
    padding: [i64; 2],
    dilation: [i64; 2],
    groups: i64,
}

impl Conv2D {
    pub fn new(
        vs: &super::var_store::Path,
        in_dim: i64,
        out_dim: i64,
        ksize: i64,
        config: Conv2DConfig,
    ) -> Conv2D {
        let Conv2DConfig {
            stride,
            padding,
            dilation,
            groups,
            bias,
        } = config;
        let bs = if bias {
            vs.zeros("bias", &[out_dim])
        } else {
            Tensor::zeros(&[out_dim], (crate::Kind::Float, vs.device()))
        };
        let ws = vs.kaiming_uniform("weight", &[out_dim, in_dim, ksize, ksize]);
        Conv2D {
            ws,
            bs,
            stride: [stride, stride],
            padding: [padding, padding],
            dilation: [dilation, dilation],
            groups,
        }
    }
}

impl super::module::Module for Conv2D {
    fn forward(&self, xs: &Tensor) -> Tensor {
        Tensor::conv2d(
            &xs,
            &self.ws,
            &self.bs,
            &self.stride,
            &self.padding,
            &self.dilation,
            self.groups,
        )
    }
}