1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
use zenu_autograd::{
    creator::{rand::normal, zeros::zeros},
    functions::conv2d::conv2d,
    Variable,
};
use zenu_matrix::{dim::DimTrait, matrix::MatrixBase, num::Num};

use crate::Layer;

/// A 2-D convolution layer.
///
/// Geometry (channel counts, kernel extent, stride, padding) is fixed at
/// construction. The kernel starts out unset and is filled in by
/// `init_parameters` or `load_parameters`; the optional bias is allocated by
/// `Conv2d::new`.
pub struct Conv2d<T: Num> {
    // Optional additive bias passed to `conv2d`; `None` when built without one.
    bias: Option<Variable<T>>,
    // Convolution weights; `None` until initialized or loaded.
    kernel: Option<Variable<T>>,
    // Channel count the input's axis 1 must match (see `shape_check`).
    in_channels: usize,
    // Number of output channels; leading axis of the kernel.
    out_channels: usize,
    // Spatial kernel extent; `.0` / `.1` become the last two kernel axes.
    kernel_size: (usize, usize),
    // Per-axis stride forwarded unchanged to `conv2d`.
    stride: (usize, usize),
    // Per-axis padding forwarded unchanged to `conv2d`.
    padding: (usize, usize),
}

impl<T: Num> Conv2d<T> {
    /// Creates a convolution layer with the given geometry.
    ///
    /// * `in_channels` / `out_channels` — channel counts used for the kernel
    ///   shape and the input shape check.
    /// * `kernel_size`, `stride`, `padding` — per-axis pairs; `kernel_size`
    ///   fixes the last two kernel axes, the other two are forwarded to `conv2d`.
    /// * `bias` — when `true`, a zero-filled bias of shape
    ///   `[out_channels, 1, 1, 1]` is allocated right away.
    ///
    /// The kernel is left unset; `init_parameters` or `load_parameters` must
    /// run before the layer is called.
    #[must_use]
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
        bias: bool,
    ) -> Self {
        // NOTE(review): the `[out_channels, 1, 1, 1]` layout depends on how
        // `conv2d` broadcasts its bias against the output — confirm there.
        let bias_param = bias.then(|| zeros([out_channels, 1, 1, 1]));

        Self {
            bias: bias_param,
            kernel: None,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        }
    }

    /// Returns a clone of the kernel variable, or `None` if it has not been
    /// initialized or loaded yet.
    #[must_use]
    pub fn kernel(&self) -> Option<Variable<T>> {
        self.kernel.as_ref().cloned()
    }
}

impl<T: Num> Layer<T> for Conv2d<T> {
    /// Draws a fresh kernel of shape
    /// `[out_channels, in_channels, kernel_size.0, kernel_size.1]` via `normal`
    /// with parameters `T::zero()` and `T::one()` (presumably mean 0 and
    /// standard deviation 1 — confirm against `normal`). `seed` is forwarded
    /// to `normal` unchanged.
    ///
    /// The bias, if any, is not touched here: it keeps whatever value it
    /// already has (zeros from `Conv2d::new`, or a loaded value).
    fn init_parameters(&mut self, seed: Option<u64>)
    where
        rand_distr::StandardNormal: rand::prelude::Distribution<T>,
    {
        let shape = [
            self.out_channels,
            self.in_channels,
            self.kernel_size.0,
            self.kernel_size.1,
        ];
        self.kernel = Some(normal(T::zero(), T::one(), seed, shape));
    }

    /// Applies the convolution (with the optional bias) to `input`.
    ///
    /// # Panics
    /// Panics if `input` fails [`shape_check`](Self::shape_check), or if the
    /// kernel has not been initialized or loaded.
    fn call(&self, input: Variable<T>) -> Variable<T> {
        self.shape_check(&input);
        let kernel = self
            .kernel()
            .expect("Conv2d kernel is not set: call init_parameters or load_parameters first");
        conv2d(input, kernel, self.bias.clone(), self.stride, self.padding)
    }

    /// Returns the layer's parameters: the kernel, followed by the bias when
    /// the layer was built with one.
    ///
    /// # Panics
    /// Panics if the kernel has not been initialized or loaded.
    fn parameters(&self) -> Vec<Variable<T>> {
        let kernel = self
            .kernel()
            .expect("Conv2d kernel is not set: call init_parameters or load_parameters first");
        // The bias is learnable as well; leaving it out would hide it from
        // whatever consumes `parameters()` (e.g. an optimizer or a saver).
        std::iter::once(kernel).chain(self.bias.clone()).collect()
    }

    /// Validates that `input` is 4-D and that its axis 1 (treated as the
    /// channel axis) equals `in_channels`.
    ///
    /// # Panics
    /// Panics with a descriptive message when either check fails.
    fn shape_check(&self, input: &Variable<T>) {
        let input_shape = input.get_data().shape();
        assert_eq!(input_shape.len(), 4, "Input must be 4D tensor");
        assert_eq!(
            input_shape[1], self.in_channels,
            "Input channel must be equal to in_channels"
        );
    }

    /// Restores parameters in the order produced by `parameters`: kernel
    /// first, then the bias.
    ///
    /// The bias entry is only consulted when this layer was built with a bias,
    /// and it is optional, so a kernel-only list still loads (leaving the
    /// current bias in place).
    ///
    /// # Panics
    /// Panics if `parameters` is empty.
    fn load_parameters(&mut self, parameters: &[Variable<T>]) {
        let (kernel, rest) = parameters
            .split_first()
            .expect("Conv2d::load_parameters needs at least the kernel parameter");
        self.kernel = Some(kernel.clone());
        if let (Some(slot), Some(bias)) = (self.bias.as_mut(), rest.first()) {
            *slot = bias.clone();
        }
    }
}