1use crate::{DataRepresentation, Padding};
7use ndarray::*;
8use num_traits::Float;
9
10pub struct ConvolutionLayer<F: Float> {
15 pub(in crate) kernel: Array4<F>,
17 pub(in crate) bias: Option<Array1<F>>,
18 pub(in crate) stride: usize,
19 pub(in crate) padding: Padding,
20}
21
22impl<F: 'static + Float + std::ops::AddAssign> ConvolutionLayer<F> {
23 pub fn new(
28 weights: Array4<F>,
29 bias_array: Option<Array1<F>>,
30 stride: usize,
31 padding: Padding,
32 ) -> ConvolutionLayer<F> {
33 assert!(stride > 0, "Stride of 0 passed");
34 ConvolutionLayer {
35 kernel: weights,
36 bias: bias_array,
37 stride,
38 padding,
39 }
40 }
41
42 pub fn new_tf(
46 weights: Array4<F>,
47 bias_array: Option<Array1<F>>,
48 stride: usize,
49 padding: Padding,
50 ) -> ConvolutionLayer<F> {
51 let permuted_view = weights.view().permuted_axes([3, 2, 0, 1]);
52 let permuted_array: Array4<F> =
55 Array::from_shape_vec(permuted_view.dim(), permuted_view.iter().copied().collect())
56 .unwrap();
57 ConvolutionLayer::new(permuted_array, bias_array, stride, padding)
58 }
59
60 pub fn convolve(&self, image: &DataRepresentation<F>) -> DataRepresentation<F> {
62 conv2d(
63 &self.kernel,
64 self.bias.as_ref(),
65 image,
66 self.padding,
67 self.stride,
68 )
69 }
70}
71
/// Computes the zero padding used for `Padding::Same`.
///
/// Along each spatial axis the total padding is `max(kernel - stride, 0)`
/// when the input length is a multiple of `stride`, and
/// `max(kernel - input % stride, 0)` otherwise. The total is split into a
/// smaller half `total / 2` (`pad_top` / `pad_left`) and a larger half
/// `total - total / 2` (`pad_bottom` / `pad_right`).
///
/// Returns `(pad_along_height, pad_along_width, pad_bottom, pad_top,
/// pad_right, pad_left)`.
/// NOTE(review): the bottom/top and right/left entries come in that order, and
/// `conv2d` places the image after the 3rd/5th entries, so an odd extra
/// row/column ends up before the image. Preserved as-is — confirm intended.
///
/// # Panics
/// Panics (division by zero) if `stride` is 0.
pub(in crate) fn get_padding_size(
    input_h: usize,
    input_w: usize,
    stride: usize,
    kernel_h: usize,
    kernel_w: usize,
) -> (usize, usize, usize, usize, usize, usize) {
    // `saturating_sub` is the `max(.., 0)`: a plain `usize` subtraction
    // overflows (panic in debug, wrap-around in release) whenever the stride
    // or the remainder exceeds the kernel size, and `.max(0)` on an unsigned
    // value cannot catch that.
    let total_padding = |input: usize, kernel: usize| -> usize {
        match input % stride {
            0 => kernel.saturating_sub(stride),
            rem => kernel.saturating_sub(rem),
        }
    };

    let pad_along_height = total_padding(input_h, kernel_h);
    let pad_along_width = total_padding(input_w, kernel_w);

    let pad_top = pad_along_height / 2;
    let pad_bottom = pad_along_height - pad_top;
    let pad_left = pad_along_width / 2;
    let pad_right = pad_along_width - pad_left;

    (
        pad_along_height,
        pad_along_width,
        pad_bottom,
        pad_top,
        pad_right,
        pad_left,
    )
}
110
111pub(in crate) fn im2col_ref<'a, T, F: 'a + Float>(
112 im_arr: T,
113 ker_height: usize,
114 ker_width: usize,
115 im_height: usize,
116 im_width: usize,
117 im_channel: usize,
118 stride: usize,
119) -> Array2<F>
120where
121 T: AsArray<'a, F, Ix3>,
132{
133 let im2d_arr: ArrayView3<F> = im_arr.into();
134 let new_h = (im_height - ker_height) / stride + 1;
135 let new_w = (im_width - ker_width) / stride + 1;
136 let mut cols_img: Array2<F> =
137 Array::zeros((new_h * new_w, im_channel * ker_height * ker_width));
138 let mut cont = 0_usize;
139 for i in 1..new_h + 1 {
140 for j in 1..new_w + 1 {
141 let patch = im2d_arr.slice(s![
142 ..,
143 (i - 1) * stride..((i - 1) * stride + ker_height),
144 (j - 1) * stride..((j - 1) * stride + ker_width),
145 ]);
146 let patchrow_unwrap: Array1<F> = Array::from_iter(patch.map(|a| *a));
147
148 cols_img.row_mut(cont).assign(&patchrow_unwrap);
149 cont += 1;
150 }
151 }
152 cols_img
153}
154
155pub fn conv2d<'a, T, V, F: 'static + Float + std::ops::AddAssign>(
175 kernel_weights: T,
176 bias: Option<&Array1<F>>,
177 im2d: V,
178 padding: Padding,
179 stride: usize,
180) -> DataRepresentation<F>
181where
182 V: AsArray<'a, F, Ix3>,
186 T: AsArray<'a, F, Ix4>,
187{
188 let im2d_arr: ArrayView3<F> = im2d.into();
190 let kernel_weights_arr: ArrayView4<F> = kernel_weights.into();
191 let im_col: Array2<F>; let new_im_height: usize;
193 let new_im_width: usize;
194 let weight_shape = kernel_weights_arr.shape();
195 let num_filters = weight_shape[0] as usize;
196 let num_channels_out = weight_shape[1] as usize;
197 let kernel_height = weight_shape[2] as usize;
198 let kernel_width = weight_shape[3] as usize;
199
200 let im_channel = im2d_arr.len_of(Axis(0));
202 let im_height = im2d_arr.len_of(Axis(1));
203 let im_width = im2d_arr.len_of(Axis(2));
204
205 if padding == Padding::Same {
207 let h_float = im_height as f32;
212 let w_float = im_width as f32;
213 let stride_float = stride as f32;
214
215 let new_im_height_float = (h_float / stride_float).ceil();
216 let new_im_width_float = (w_float / stride_float).ceil();
217
218 new_im_height = new_im_height_float as usize;
219 new_im_width = new_im_width_float as usize;
220 } else {
221 new_im_height = ((im_height - kernel_height) / stride) + 1;
224 new_im_width = ((im_width - kernel_width) / stride) + 1;
225 };
226
227 let filter_col = kernel_weights_arr
229 .into_shape((num_filters, kernel_height * kernel_width * num_channels_out))
230 .unwrap();
231
232 if padding == Padding::Same {
234 let (pad_num_h, pad_num_w, pad_top, pad_bottom, pad_left, pad_right) =
236 get_padding_size(im_height, im_width, stride, kernel_height, kernel_width);
237 let mut im2d_arr_pad: Array3<F> = Array::zeros((
238 num_channels_out,
239 im_height + pad_num_h,
240 im_width + pad_num_w,
241 ));
242 let pad_bottom_int = (im_height + pad_num_h) - pad_bottom;
243 let pad_right_int = (im_width + pad_num_w) - pad_right;
244 im2d_arr_pad
246 .slice_mut(s![.., pad_top..pad_bottom_int, pad_left..pad_right_int])
247 .assign(&im2d_arr);
248
249 let im_height_pad = im2d_arr_pad.len_of(Axis(1));
250 let im_width_pad = im2d_arr_pad.len_of(Axis(2));
251
252 im_col = im2col_ref(
253 im2d_arr_pad.view(),
254 kernel_height,
255 kernel_width,
256 im_height_pad,
257 im_width_pad,
258 im_channel,
259 stride,
260 );
261 } else {
262 im_col = im2col_ref(
263 im2d_arr,
264 kernel_height,
265 kernel_width,
266 im_height,
267 im_width,
268 im_channel,
269 stride,
270 );
271 };
272 let filter_transpose = filter_col.t();
273 let mul = im_col.dot(&filter_transpose);
274 let output = mul
275 .into_shape((new_im_height, new_im_width, num_filters))
276 .unwrap()
277 .permuted_axes([2, 0, 1]);
278
279 add_bias(&output, bias)
280}
281
282pub(in crate) fn add_bias<F>(x: &Array3<F>, bias: Option<&Array1<F>>) -> Array3<F>
283where
284 F: 'static + Float + std::ops::AddAssign,
285{
286 if let Some(bias_array) = bias {
287 assert!(
288 bias_array.shape()[0] == x.shape()[0],
289 "Bias array has the wrong shape {:?} for vec of shape {:?}",
290 bias_array.shape(),
291 x.shape()
292 );
293 (x + &bias_array
298 .clone()
299 .insert_axis(Axis(1))
300 .insert_axis(Axis(2))
301 .broadcast(x.shape())
302 .unwrap())
303 .into_dimensionality()
304 .unwrap()
305 } else {
306 x.clone()
307 }
308}