1use crate::{GetParam, WithDevice};
2use custos::{cached, get_device, number::Float, Alloc, CDatatype, CacheBuf, Device, GraphReturn};
3use custos_math::{correlate_valid_mut, Matrix};
4use gradients_derive::NoParams;
5
6pub struct KernelBlock<'a, T> {
7 pub weights: Matrix<'a, T>,
8 bias: Matrix<'a, T>,
9 }
11
12impl<'a, T> KernelBlock<'a, T> {
13 pub fn new<D: Alloc<T> + GraphReturn>(
14 device: &'a D,
15 shape: (usize, usize),
16 bias_shape: (usize, usize),
17 ) -> Self
18 where
19 T: Float,
20 {
21 let mut weights = Matrix::new(device, shape);
22 weights.rand(T::one().neg(), T::one());
23
24 let mut bias = Matrix::new(device, bias_shape);
25 bias.rand(T::one().neg(), T::one());
26
27 KernelBlock { weights, bias }
28 }
29}
30
#[doc(hidden)]
#[derive(NoParams)]
/// A 2D convolution layer built from `kernel_blocks` independent kernels.
pub struct Conv2D<'a, T> {
    /// (rows, cols) of each kernel's weight matrix.
    pub kernel_shape: (usize, usize),
    /// (rows, cols) of a single input image.
    input_shape: (usize, usize),
    /// (rows, cols) of one kernel's output; "valid" correlation, i.e.
    /// `input - kernel + 1` per dimension (computed in `new`).
    output_shape: (usize, usize),
    /// One weight/bias block per kernel.
    kernels: Vec<KernelBlock<'a, T>>,
    /// Inputs cached by `forward` for use in `backward`.
    inputs: Option<Matrix<'a, T>>,
    /// Device used to allocate/cache the output buffers.
    device: Device,
}
41
42impl<'a, T> Conv2D<'a, T>
43where
44 T: Float + CDatatype,
45{
46 pub fn new<D: Alloc<T> + GraphReturn>(
47 device: &'a D,
48 input_shape: (usize, usize),
49 kernel_shape: (usize, usize),
50 kernel_blocks: usize,
51 ) -> Conv2D<'a, T> {
52 let output_shape = (
53 input_shape.0 - kernel_shape.0 + 1,
54 input_shape.1 - kernel_shape.1 + 1,
55 );
56 let kernels = (0..kernel_blocks)
57 .into_iter()
58 .map(|_x| KernelBlock::new(device, kernel_shape, output_shape))
59 .collect();
60
61 Conv2D {
62 device: device.as_dev(),
63 kernel_shape,
64 output_shape,
65 input_shape,
66 kernels,
67 inputs: None,
68 }
69 }
70
71 pub fn forward(&mut self, inputs: &Matrix<'a, T>) -> Matrix<'a, T> {
72 let samples = inputs.rows();
73
74 self.inputs = Some(inputs.shallow_or_clone());
75 let (out_rows, out_cols) = self.output_shape;
76
77 let mut output = get_device!(self.device, CacheBuf<T>)
78 .cached(inputs.rows() * out_rows * out_cols * self.kernels.len());
79 output.clear();
80
81 for row in 0..inputs.rows() {
84 let img_start = row * inputs.cols();
85 let single_image = &inputs[img_start..img_start + inputs.cols()];
86
87 for (idx, kernel_block) in self.kernels.iter().enumerate() {
88 let start = idx * out_rows * out_cols + img_start;
89 let output_slice = &mut output[start..start + out_rows * out_cols + img_start];
90 output_slice.copy_from_slice(&kernel_block.bias);
91 correlate_valid_mut(
94 single_image,
95 self.input_shape,
96 &kernel_block.weights,
97 kernel_block.weights.dims(),
98 output_slice,
99 );
100 }
101 }
102
103 (output, samples, out_rows * out_cols * self.kernels.len()).into()
104 }
105 pub fn backward(&mut self, grad: &Matrix<'a, T>) -> Matrix<'a, T> {
106 let inputs = self.inputs.as_ref().unwrap();
107 let (out_rows, out_cols) = self.output_shape;
108 let (kernel_rows, kernel_cols) = self.kernel_shape;
109 let mut dkernel = cached::<T>(&grad.device, kernel_rows * kernel_cols * self.kernels.len());
110 dkernel.clear();
111
112 for row in 0..inputs.rows() {
113 let start = row * inputs.cols();
114 let single_image = &inputs[start..start + inputs.cols()];
115
116 for (idx, kernel) in self.kernels.iter_mut().enumerate() {
117 let start = idx * out_rows * out_cols;
118 let grad_slice = &grad[start..start + out_rows * out_cols];
119
120 let start = idx * kernel_rows * kernel_cols;
121 let dkernel_slice = &mut dkernel[start..start + kernel_rows * kernel_cols];
122
123 correlate_valid_mut(
124 single_image,
125 self.input_shape,
126 grad_slice,
127 (out_rows, out_cols),
128 dkernel_slice,
129 );
130
131 for (idx, value) in kernel.weights.iter_mut().enumerate() {
133 *value -= dkernel_slice[idx] * T::one() / T::from_u64(1000);
134 }
135 }
136 }
137
138 grad.shallow_or_clone()
140 }
141}
142
143impl<'a, T> Default for Conv2D<'a, T> {
144 fn default() -> Self {
145 Self {
146 device: Default::default(),
147 inputs: Default::default(),
148 kernel_shape: Default::default(),
149 output_shape: Default::default(),
150 input_shape: Default::default(),
151 kernels: Default::default(),
152 }
153 }
154}