zenu_layer/layers/conv2d.rs

1use std::{cell::RefCell, collections::HashMap};
2
3use rand_distr::{Distribution, StandardNormal};
4use zenu_autograd::{
5    creator::{rand::normal, zeros::zeros},
6    nn::conv2d::{conv2d, Conv2dConfigs},
7    Variable,
8};
9use zenu_matrix::{device::Device, dim::DimTrait, nn::conv2d::conv2d_out_size, num::Num};
10
11use crate::{Module, Parameters};
12
13pub struct Conv2d<T: Num, D: Device> {
14    pub filter: Variable<T, D>,
15    pub bias: Option<Variable<T, D>>,
16    config: RefCell<Option<Conv2dConfigs<T>>>,
17    stride: (usize, usize),
18    padding: (usize, usize),
19}
20
21impl<T: Num, D: Device> Module<T, D> for Conv2d<T, D> {
22    type Input = Variable<T, D>;
23    type Output = Variable<T, D>;
24    fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
25        if self.config.borrow().is_none() {
26            let input_shape = input.get_data().shape();
27            let filter_shape = self.filter.get_data().shape();
28            let output_shape = conv2d_out_size(
29                input_shape.slice(),
30                filter_shape.slice(),
31                self.padding,
32                self.stride,
33            );
34            let config = Conv2dConfigs::new(
35                input_shape,
36                output_shape.into(),
37                filter_shape,
38                self.stride,
39                self.padding,
40                20,
41            );
42            *self.config.borrow_mut() = Some(config);
43        }
44        conv2d(
45            input,
46            self.filter.clone(),
47            self.stride,
48            self.padding,
49            self.bias.clone(),
50            Some(self.config.borrow().as_ref().unwrap().clone()),
51        )
52    }
53}
54
55impl<T: Num, D: Device> Parameters<T, D> for Conv2d<T, D> {
56    fn weights(&self) -> HashMap<String, Variable<T, D>> {
57        HashMap::new()
58            .into_iter()
59            .chain(std::iter::once((
60                String::from("conv2d.filter"),
61                self.filter.clone(),
62            )))
63            .collect()
64    }
65
66    fn biases(&self) -> HashMap<String, Variable<T, D>> {
67        self.bias
68            .as_ref()
69            .map(|bias| {
70                HashMap::new()
71                    .into_iter()
72                    .chain(std::iter::once((String::from("conv2d.bias"), bias.clone())))
73                    .collect()
74            })
75            .unwrap_or_default()
76    }
77}
78
79impl<T: Num, D: Device> Conv2d<T, D> {
80    #[must_use]
81    pub fn new(
82        input_channel: usize,
83        output_channel: usize,
84        kernel_size: (usize, usize),
85        stride: (usize, usize),
86        padding: (usize, usize),
87        bias: bool,
88    ) -> Self
89    where
90        StandardNormal: Distribution<T>,
91    {
92        let filter_shape = [output_channel, input_channel, kernel_size.0, kernel_size.1];
93        let bias = if bias {
94            let bias = zeros([1, output_channel, 1, 1]);
95            bias.set_is_train(true);
96            bias.set_name("conv2d.bias");
97            Some(bias)
98        } else {
99            None
100        };
101        let filter = normal(T::zero(), T::one(), None, filter_shape);
102
103        filter.set_is_train(true);
104        filter.set_name("conv2d.filter");
105
106        Conv2d {
107            filter,
108            bias,
109            config: RefCell::new(None),
110            stride,
111            padding,
112        }
113    }
114
115    pub fn to<Dout: Device>(self) -> Conv2d<T, Dout> {
116        Conv2d {
117            filter: self.filter.to(),
118            bias: self.bias.map(|b| b.to()),
119            config: RefCell::new(None),
120            stride: self.stride,
121            padding: self.padding,
122        }
123    }
124}