use crate::network::{Initializer, Layer, LayerDims, Optimizer, XavierUniform};
use crate::{ConvGeometryIsValid, Float, tensor::Tensor};
use rand::{Rng, SeedableRng, rngs::StdRng};
use std::array;

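/// Computes one spatial output dimension of the convolution:
/// `(input + 2 * pad - kernel) / stride + 1`.
///
/// Returns `0` when the geometry is invalid: a zero stride, a kernel larger
/// than the padded input, or a stride that does not evenly divide the span.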
#[doc(hidden)]
pub const fn conv_out_dim(input: usize, pad: usize, kernel: usize, stride: usize) -> usize {
    if stride == 0 {
        return 0;
    }
    let padded = input + 2 * pad;
    if padded < kernel {
        return 0;
    }
    let numer = padded - kernel;
    if !numer.is_multiple_of(stride) {
        return 0;
    }
    numer / stride + 1
}

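/// A single convolution filter with weights of shape `(H, W, D)` (height,
/// width, input channels) plus the matching gradient accumulator.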
#[derive(Debug, Clone)]
pub struct Filter<const H: usize, const W: usize, const D: usize> {
    weights: Tensor<crate::shape!(H, W, D)>,
    grads: Box<[Float]>,
}

impl<const H: usize, const W: usize, const D: usize> Filter<H, W, D> {
    fn zeroed() -> Self {
        Self {
            weights: Tensor::<crate::shape!(H, W, D)>::from_boxed(
                vec![0.0 as Float; H * W * D].into_boxed_slice(),
            ),
            grads: vec![0.0 as Float; H * W * D].into_boxed_slice(),
        }
    }

    fn weights(&self) -> &[Float] {
        self.weights.raw_slice()
    }

    fn grads_mut(&mut self) -> &mut [Float] {
        &mut self.grads[..]
    }
}

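/// A 2D convolution layer: `IC` input channels over an `IH × IW` image in CHW
/// layout, `OC` filters of size `FH × FW` (stored HWC per filter), stride `S`,
/// and zero-padding `P`. Output dimensions are computed by `conv_out_dim`.
///
/// Illustrative sketch (the dimensions here are arbitrary and assume the
/// geometry satisfies `ConvGeometryIsValid`):
///
/// ```ignore
/// // 28×28 single-channel input, eight 3×3 filters, stride 1, padding 1.
/// let conv = Conv::<28, 28, 1, 3, 3, 8, 1, 1>::seeded(42);
/// let input = conv.input_from_data([0.0; 1 * 28 * 28]);
/// let mut output = conv.create_output_space();
/// conv.forward(&input, &mut output);
/// ```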
#[derive(Debug)]
pub struct Conv<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> {
    filters: [Filter<FH, FW, IC>; OC],
    biases: Box<[Float; OC]>,
    bias_grads: Box<[Float; OC]>,
}

impl<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> Conv<IW, IH, IC, FH, FW, OC, S, P>
where
    [(); IC * IH * IW]:,
    [(); OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)]:,
    (): ConvGeometryIsValid<IH, IW, FH, FW, S, P>,
{
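    /// Creates a layer with Xavier-uniform weight initialization and zero
    /// biases, drawing randomness from the thread-local RNG.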
    pub fn init() -> Self {
        Self::with_initializer(XavierUniform)
    }

    pub fn seeded(seed: u64) -> Self {
        Self::with_initializer_and_seed(XavierUniform, seed)
    }

    pub fn with_initializer<I: Initializer>(initializer: I) -> Self {
        let mut rng = rand::rng();
        Self::with_initializer_and_rng(initializer, &mut rng)
    }

    pub fn with_initializer_and_seed<I: Initializer>(initializer: I, seed: u64) -> Self {
        let mut rng = StdRng::seed_from_u64(seed);
        Self::with_initializer_and_rng(initializer, &mut rng)
    }

    pub fn with_initializer_and_rng<I: Initializer, R: Rng + ?Sized>(
        initializer: I,
        rng: &mut R,
    ) -> Self {
        let mut conv = Conv {
            filters: array::from_fn(|_| Filter::zeroed()),
            biases: Box::new([0.0 as Float; OC]),
            bias_grads: Box::new([0.0 as Float; OC]),
        };
        let fan_in = FH * FW * IC;
        let fan_out = FH * FW * OC;
        for filter in &mut conv.filters {
            initializer.fill(filter.weights.raw_mut_slice(), fan_in, fan_out, rng);
        }
        conv
    }

    pub fn create_output_space(&self) -> <Self as ConvIO>::Output {
        Tensor::<crate::shape!(
            OC,
            conv_out_dim(IH, P, FH, S),
            conv_out_dim(IW, P, FW, S)
        )>::from_boxed(
            vec![
                0.0 as Float;
                OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)
            ]
            .into_boxed_slice(),
        )
    }

    pub fn input_from_data(&self, data: [Float; IC * IH * IW]) -> <Self as ConvIO>::Input {
        Tensor::<crate::shape!(IC, IH, IW)>::from_boxed(Vec::from(data).into_boxed_slice())
    }

    pub fn forward(
        &self,
        input: &Tensor<crate::shape!(IC, IH, IW)>,
        output: &mut Tensor<
            crate::shape!(OC, conv_out_dim(IH, P, FH, S), conv_out_dim(IW, P, FW, S)),
        >,
    ) {
        let input_arr: &[Float; IC * IH * IW] = input.raw_slice().try_into().expect("bad input");
        let output_arr: &mut [Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)] =
            output.raw_mut_slice().try_into().expect("bad output");
        self.forward_flat(input_arr, output_arr);
    }

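    /// Direct (naive) convolution over flat slices. Inputs are indexed as
    /// `ic * IH * IW + y * IW + x`, filter weights as `(ky * FW + kx) * IC + ic`,
    /// and outputs as `oc * out_h * out_w + y * out_w + x`; positions that fall
    /// into the zero-padding contribute nothing to the sum.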
    pub fn forward_flat(
        &self,
        input: &[Float; IC * IH * IW],
        output: &mut [Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)],
    ) {
        let out_h = conv_out_dim(IH, P, FH, S);
        let out_w = conv_out_dim(IW, P, FW, S);

        for oc in 0..OC {
            let filter_data = self.filters[oc].weights();

            for y in 0..out_h {
                for x in 0..out_w {
                    let mut sum = self.biases[oc];

                    for ky in 0..FH {
                        for kx in 0..FW {
                            for ic in 0..IC {
                                let in_y = y * S + ky;
                                let in_x = x * S + kx;
                                let in_y = in_y as isize - P as isize;
                                let in_x = in_x as isize - P as isize;

                                if in_y >= 0
                                    && in_y < IH as isize
                                    && in_x >= 0
                                    && in_x < IW as isize
                                {
                                    let in_y = in_y as usize;
                                    let in_x = in_x as usize;
                                    let input_idx = ic * IH * IW + in_y * IW + in_x;
                                    let filter_idx = (ky * FW + kx) * IC + ic;
                                    sum += filter_data[filter_idx] * input[input_idx];
                                }
                            }
                        }
                    }

                    let output_idx = oc * out_h * out_w + y * out_w + x;
                    output[output_idx] = sum;
                }
            }
        }
    }

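    /// Backward pass over flat slices: accumulates filter and bias gradients
    /// on top of whatever is already stored (call `zero_grad` between batches)
    /// and writes the gradient with respect to the input into `input_grad`,
    /// which is zeroed first.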
    pub fn backward_flat(
        &mut self,
        input: &[Float; IC * IH * IW],
        output_grad: &[Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)],
        input_grad: &mut [Float; IC * IH * IW],
    ) {
        let out_h = conv_out_dim(IH, P, FH, S);
        let out_w = conv_out_dim(IW, P, FW, S);

        input_grad.fill(0.0);

        for oc in 0..OC {
            let Filter { weights, grads } = &mut self.filters[oc];
            let filter_weights = weights.raw_slice();
            let filter_grads = &mut grads[..];

            for y in 0..out_h {
                for x in 0..out_w {
                    let output_idx = oc * out_h * out_w + y * out_w + x;
                    let grad = output_grad[output_idx];
                    self.bias_grads[oc] += grad;

                    for ky in 0..FH {
                        for kx in 0..FW {
                            for ic in 0..IC {
                                let in_y = y * S + ky;
                                let in_x = x * S + kx;
                                let in_y = in_y as isize - P as isize;
                                let in_x = in_x as isize - P as isize;

                                if in_y >= 0
                                    && in_y < IH as isize
                                    && in_x >= 0
                                    && in_x < IW as isize
                                {
                                    let in_y = in_y as usize;
                                    let in_x = in_x as usize;
                                    let input_idx = ic * IH * IW + in_y * IW + in_x;
                                    let filter_idx = (ky * FW + kx) * IC + ic;

                                    filter_grads[filter_idx] += grad * input[input_idx];
                                    input_grad[input_idx] += grad * filter_weights[filter_idx];
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}

impl<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> LayerDims for Conv<IW, IH, IC, FH, FW, OC, S, P>
where
    [(); IC * IH * IW]:,
    [(); OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)]:,
    (): ConvGeometryIsValid<IH, IW, FH, FW, S, P>,
{
    const INPUT: usize = IC * IH * IW;
    const OUTPUT: usize = OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S);
}

impl<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> Layer<{ IC * IH * IW }, { OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S) }>
    for Conv<IW, IH, IC, FH, FW, OC, S, P>
where
    [(); IC * IH * IW]:,
    [(); OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)]:,
    (): ConvGeometryIsValid<IH, IW, FH, FW, S, P>,
{
    fn forward(
        &self,
        input: &[Float; IC * IH * IW],
        output: &mut [Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)],
    ) {
        self.forward_flat(input, output);
    }

    fn backward(
        &mut self,
        input: &[Float; IC * IH * IW],
        _output: &[Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)],
        output_grad: &[Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)],
        input_grad: &mut [Float; IC * IH * IW],
    ) {
        self.backward_flat(input, output_grad, input_grad);
    }

    fn zero_grad(&mut self) {
        self.bias_grads.fill(0.0);
        for filter in &mut self.filters {
            filter.grads_mut().fill(0.0);
        }
    }

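    /// Hands each filter's weights and then the biases to the optimizer as
    /// separate parameter slots, advancing `slot` once per parameter group and
    /// clearing the corresponding gradients afterwards.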
    fn apply_gradients(&mut self, optimizer: &mut dyn Optimizer, slot: &mut usize, scale: Float) {
        for filter in &mut self.filters {
            optimizer.update_parameter(
                *slot,
                filter.weights.raw_mut_slice(),
                filter.grads.as_ref(),
                scale,
            );
            *slot += 1;
            filter.grads_mut().fill(0.0);
        }
        optimizer.update_parameter(
            *slot,
            self.biases.as_mut_slice(),
            self.bias_grads.as_slice(),
            scale,
        );
        *slot += 1;
        self.bias_grads.fill(0.0);
    }
}

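/// Associated input/output tensor and shape types for a convolution layer,
/// so generic code can name them without repeating the const parameters.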
#[allow(dead_code)]
pub trait ConvIO {
    type Output;
    type Input;
    type OutputShape;
    type InputShape;
    type FilterShape;
    const N: usize;
}

impl<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> ConvIO for Conv<IW, IH, IC, FH, FW, OC, S, P>
where
    [(); IC * IH * IW]:,
    [(); OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)]:,
{
    const N: usize = IC * IH * IW;
    type Input = Tensor<crate::shape!(IC, IH, IW)>;
    type Output = Tensor<Self::OutputShape>;
    type InputShape = crate::shape!(IC, IH, IW);
    type OutputShape = crate::shape!(OC, conv_out_dim(IH, P, FH, S), conv_out_dim(IW, P, FW, S));
    type FilterShape = crate::shape!(FH, FW, IC);
}

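/// Flat-array types, element counts, and constructors for driving a `Conv`
/// generically over its const parameters.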
#[allow(dead_code)]
pub trait ConvOps: ConvIO {
    type InputArray;
    type OutputArray;
    type FilterArray;

    const INPUT_SIZE: usize;
    const OUTPUT_SIZE: usize;
    const FILTER_SIZE: usize;

    fn init() -> Self;
    fn forward_flat(&self, input: &Self::InputArray, output: &mut Self::OutputArray);
    fn input_from_fn<F: FnMut(usize) -> Float>(f: F) -> Self::InputArray;
    fn output_zeroed() -> Self::OutputArray;
}

impl<
    const IW: usize,
    const IH: usize,
    const IC: usize,
    const FH: usize,
    const FW: usize,
    const OC: usize,
    const S: usize,
    const P: usize,
> ConvOps for Conv<IW, IH, IC, FH, FW, OC, S, P>
where
    [(); FH * FW * IC]:,
    [(); IC * IH * IW]:,
    [(); OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)]:,
    (): ConvGeometryIsValid<IH, IW, FH, FW, S, P>,
{
    type InputArray = [Float; IC * IH * IW];
    type OutputArray = [Float; OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S)];
    type FilterArray = [Float; FH * FW * IC];

    const INPUT_SIZE: usize = IC * IH * IW;
    const OUTPUT_SIZE: usize = OC * conv_out_dim(IH, P, FH, S) * conv_out_dim(IW, P, FW, S);
    const FILTER_SIZE: usize = FH * FW * IC;

    fn init() -> Self {
        Conv::<IW, IH, IC, FH, FW, OC, S, P>::init()
    }

    fn forward_flat(&self, input: &Self::InputArray, output: &mut Self::OutputArray) {
        Conv::<IW, IH, IC, FH, FW, OC, S, P>::forward_flat(self, input, output);
    }

    fn input_from_fn<F: FnMut(usize) -> Float>(f: F) -> Self::InputArray {
        array::from_fn(f)
    }

    fn output_zeroed() -> Self::OutputArray {
        array::from_fn(|_| 0.0 as Float)
    }
}

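// The tests below validate the analytic gradients from `backward_flat` against
// central finite differences of the scalar objective `sum(output * output_grad)`.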
#[cfg(test)]
mod tests {
    use super::*;

    type ConvCase = Conv<3, 3, 1, 2, 2, 1, 1, 0>;
    const IN_SIZE: usize = 3 * 3;
    const OUT_SIZE: usize = 4;

    fn approx_eq(a: Float, b: Float, eps: Float) {
        let diff = (a - b).abs();
        assert!(diff <= eps, "expected {a} ~= {b} (diff={diff}, eps={eps})");
    }

    fn configured_conv() -> ConvCase {
        let mut conv = ConvCase::init();
        for (i, w) in conv.filters[0]
            .weights
            .raw_mut_slice()
            .iter_mut()
            .enumerate()
        {
            *w = 0.1 * (i as Float + 1.0);
        }
        conv.biases[0] = 0.05;
        conv
    }

    fn objective(
        conv: &ConvCase,
        input: &[Float; IN_SIZE],
        output_grad: &[Float; OUT_SIZE],
    ) -> Float {
        let mut output = [0.0; OUT_SIZE];
        conv.forward_flat(input, &mut output);
        output
            .iter()
            .zip(output_grad.iter())
            .map(|(o, g)| o * g)
            .sum()
    }

    #[test]
    fn input_gradient_matches_finite_difference() {
        let mut conv = configured_conv();
        let input = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let output_grad = [0.3, -0.2, 0.1, 0.4];
        let mut input_grad = [0.0; IN_SIZE];

        conv.zero_grad();
        conv.backward_flat(&input, &output_grad, &mut input_grad);

        let eps = 1e-7;
        for i in 0..IN_SIZE {
            let mut plus = input;
            let mut minus = input;
            plus[i] += eps;
            minus[i] -= eps;
            let f_plus = objective(&conv, &plus, &output_grad);
            let f_minus = objective(&conv, &minus, &output_grad);
            let numeric = (f_plus - f_minus) / (2.0 * eps);
            approx_eq(input_grad[i], numeric, 1e-6);
        }
    }

    #[test]
    fn weight_update_matches_finite_difference_gradient() {
        let mut conv = configured_conv();
        let input = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let output_grad = [0.3, -0.2, 0.1, 0.4];
        let mut input_grad = [0.0; IN_SIZE];
        let weight_idx = 2;

        let eps = 1e-7;
        let mut conv_plus = configured_conv();
        conv_plus.filters[0].weights.raw_mut_slice()[weight_idx] += eps;
        let mut conv_minus = configured_conv();
        conv_minus.filters[0].weights.raw_mut_slice()[weight_idx] -= eps;
        let numeric = (objective(&conv_plus, &input, &output_grad)
            - objective(&conv_minus, &input, &output_grad))
            / (2.0 * eps);

        conv.zero_grad();
        conv.backward_flat(&input, &output_grad, &mut input_grad);
        let analytic = conv.filters[0].grads[weight_idx];

        approx_eq(analytic, numeric, 1e-6);
    }
}