// gradients/layers/conv2d.rs

1use crate::{GetParam, WithDevice};
2use custos::{cached, get_device, number::Float, Alloc, CDatatype, CacheBuf, Device, GraphReturn};
3use custos_math::{correlate_valid_mut, Matrix};
4use gradients_derive::NoParams;
5
/// One convolution kernel (its weights) paired with a bias matrix.
///
/// A [`Conv2D`] layer owns one `KernelBlock` per output channel.
pub struct KernelBlock<'a, T> {
    /// Flattened kernel weights; dimensions are the `shape` passed to [`Self::new`].
    pub weights: Matrix<'a, T>,
    /// Bias sized by `bias_shape` in [`Self::new`] — presumably one value per
    /// output element, since `Conv2D::new` passes the output shape here (TODO confirm).
    bias: Matrix<'a, T>,
    //dweights: Matrix<T>
}
11
12impl<'a, T> KernelBlock<'a, T> {
13    pub fn new<D: Alloc<T> + GraphReturn>(
14        device: &'a D,
15        shape: (usize, usize),
16        bias_shape: (usize, usize),
17    ) -> Self
18    where
19        T: Float,
20    {
21        let mut weights = Matrix::new(device, shape);
22        weights.rand(T::one().neg(), T::one());
23
24        let mut bias = Matrix::new(device, bias_shape);
25        bias.rand(T::one().neg(), T::one());
26
27        KernelBlock { weights, bias }
28    }
29}
30
/// A 2D "valid" convolution layer: applies a set of kernels to each input
/// image and stores the inputs for the backward pass.
#[doc(hidden)]
#[derive(NoParams)]
pub struct Conv2D<'a, T> {
    /// (rows, cols) of every convolution kernel.
    pub kernel_shape: (usize, usize),
    // (rows, cols) of a single input image; each matrix row holds one flattened image.
    input_shape: (usize, usize),
    // (rows, cols) of one kernel's output: input_shape - kernel_shape + 1 per dim.
    output_shape: (usize, usize),
    // One weights+bias block per output channel.
    kernels: Vec<KernelBlock<'a, T>>,
    // Inputs of the most recent forward pass, kept for backward (None before first forward).
    inputs: Option<Matrix<'a, T>>,
    // Device handle used to allocate cached output buffers in forward.
    device: Device,
}
41
42impl<'a, T> Conv2D<'a, T>
43where
44    T: Float + CDatatype,
45{
46    pub fn new<D: Alloc<T> + GraphReturn>(
47        device: &'a D,
48        input_shape: (usize, usize),
49        kernel_shape: (usize, usize),
50        kernel_blocks: usize,
51    ) -> Conv2D<'a, T> {
52        let output_shape = (
53            input_shape.0 - kernel_shape.0 + 1,
54            input_shape.1 - kernel_shape.1 + 1,
55        );
56        let kernels = (0..kernel_blocks)
57            .into_iter()
58            .map(|_x| KernelBlock::new(device, kernel_shape, output_shape))
59            .collect();
60
61        Conv2D {
62            device: device.as_dev(),
63            kernel_shape,
64            output_shape,
65            input_shape,
66            kernels,
67            inputs: None,
68        }
69    }
70
71    pub fn forward(&mut self, inputs: &Matrix<'a, T>) -> Matrix<'a, T> {
72        let samples = inputs.rows();
73
74        self.inputs = Some(inputs.shallow_or_clone());
75        let (out_rows, out_cols) = self.output_shape;
76
77        let mut output = get_device!(self.device, CacheBuf<T>)
78            .cached(inputs.rows() * out_rows * out_cols * self.kernels.len());
79        output.clear();
80
81        //output.clear();
82
83        for row in 0..inputs.rows() {
84            let img_start = row * inputs.cols();
85            let single_image = &inputs[img_start..img_start + inputs.cols()];
86
87            for (idx, kernel_block) in self.kernels.iter().enumerate() {
88                let start = idx * out_rows * out_cols + img_start;
89                let output_slice = &mut output[start..start + out_rows * out_cols + img_start];
90                output_slice.copy_from_slice(&kernel_block.bias);
91                //assign_to_lhs(output_slice, &kernel_block.bias, |a, b| *a = b);
92
93                correlate_valid_mut(
94                    single_image,
95                    self.input_shape,
96                    &kernel_block.weights,
97                    kernel_block.weights.dims(),
98                    output_slice,
99                );
100            }
101        }
102
103        (output, samples, out_rows * out_cols * self.kernels.len()).into()
104    }
105    pub fn backward(&mut self, grad: &Matrix<'a, T>) -> Matrix<'a, T> {
106        let inputs = self.inputs.as_ref().unwrap();
107        let (out_rows, out_cols) = self.output_shape;
108        let (kernel_rows, kernel_cols) = self.kernel_shape;
109        let mut dkernel = cached::<T>(&grad.device, kernel_rows * kernel_cols * self.kernels.len());
110        dkernel.clear();
111
112        for row in 0..inputs.rows() {
113            let start = row * inputs.cols();
114            let single_image = &inputs[start..start + inputs.cols()];
115
116            for (idx, kernel) in self.kernels.iter_mut().enumerate() {
117                let start = idx * out_rows * out_cols;
118                let grad_slice = &grad[start..start + out_rows * out_cols];
119
120                let start = idx * kernel_rows * kernel_cols;
121                let dkernel_slice = &mut dkernel[start..start + kernel_rows * kernel_cols];
122
123                correlate_valid_mut(
124                    single_image,
125                    self.input_shape,
126                    grad_slice,
127                    (out_rows, out_cols),
128                    dkernel_slice,
129                );
130
131                // step
132                for (idx, value) in kernel.weights.iter_mut().enumerate() {
133                    *value -= dkernel_slice[idx] * T::one() / T::from_u64(1000);
134                }
135            }
136        }
137
138        // need to calculate w. r. t. inputs
139        grad.shallow_or_clone()
140    }
141}
142
143impl<'a, T> Default for Conv2D<'a, T> {
144    fn default() -> Self {
145        Self {
146            device: Default::default(),
147            inputs: Default::default(),
148            kernel_shape: Default::default(),
149            output_shape: Default::default(),
150            input_shape: Default::default(),
151            kernels: Default::default(),
152        }
153    }
154}