// convolutions_rs/convolutions.rs

1//! Module that contains classical convolutions, as used f.e. in convolutional neural networks.
2//!
3//! More can be read here:
4//! - <https://towardsdatascience.com/a-comprehensive-guide-to-convolutional-neural-networks-the-eli5-way-3bd2b1164a53?gi=f4a37beea40b>
5
6use crate::{DataRepresentation, Padding};
7use ndarray::*;
8use num_traits::Float;
9
/// Rust implementation of a convolutional layer.
///
/// The kernel is stored in PyTorch layout, i.e. with dimensions (in that order)
/// (out channels, in channels, kernel height, kernel width) — this is the order
/// `conv2d` indexes the weight shape in. Use `new_tf` to construct the layer
/// from TensorFlow-ordered weights instead.
pub struct ConvolutionLayer<F: Float> {
    /// Weight matrix of the kernel, shape (out channels, in channels, kernel height, kernel width)
    pub(in crate) kernel: Array4<F>,
    // Optional bias; `add_bias` asserts it has one entry per output channel
    pub(in crate) bias: Option<Array1<F>>,
    // Step between adjacent receptive fields; must be > 0 (checked in `new`)
    pub(in crate) stride: usize,
    // Padding scheme: Same or Valid
    pub(in crate) padding: Padding,
}
21
22impl<F: 'static + Float + std::ops::AddAssign> ConvolutionLayer<F> {
23    /// Creates new convolution layer.
24    /// The weights are given in Pytorch layout.
25    /// (out channels, in channels, kernel height, kernel width)
26    /// Bias: (output height * output width, 1)
27    pub fn new(
28        weights: Array4<F>,
29        bias_array: Option<Array1<F>>,
30        stride: usize,
31        padding: Padding,
32    ) -> ConvolutionLayer<F> {
33        assert!(stride > 0, "Stride of 0 passed");
34        ConvolutionLayer {
35            kernel: weights,
36            bias: bias_array,
37            stride,
38            padding,
39        }
40    }
41
42    /// Creates new convolution layer. The weights are given in
43    /// Tensorflow layout.
44    /// (kernel height, kernel width, in channels, out channels)
45    pub fn new_tf(
46        weights: Array4<F>,
47        bias_array: Option<Array1<F>>,
48        stride: usize,
49        padding: Padding,
50    ) -> ConvolutionLayer<F> {
51        let permuted_view = weights.view().permuted_axes([3, 2, 0, 1]);
52        // Hack to fix the memory layout, permuted axes makes a
53        // col major array / non-contiguous array from weights
54        let permuted_array: Array4<F> =
55            Array::from_shape_vec(permuted_view.dim(), permuted_view.iter().copied().collect())
56                .unwrap();
57        ConvolutionLayer::new(permuted_array, bias_array, stride, padding)
58    }
59
60    /// Analog to conv2d.
61    pub fn convolve(&self, image: &DataRepresentation<F>) -> DataRepresentation<F> {
62        conv2d(
63            &self.kernel,
64            self.bias.as_ref(),
65            image,
66            self.padding,
67            self.stride,
68        )
69    }
70}
71
/// Computes TensorFlow-style "SAME" padding amounts for one convolution.
///
/// Returns `(pad_along_height, pad_along_width, pad_bottom, pad_top,
/// pad_right, pad_left)`. Yes, bottom/top and right/left are deliberately
/// swapped in the return tuple relative to the usual order — per the original
/// author this makes the result conform to the PyTorch implementation.
///
/// # Panics
/// Panics if `stride` is zero (division/remainder by zero).
pub(in crate) fn get_padding_size(
    input_h: usize,
    input_w: usize,
    stride: usize,
    kernel_h: usize,
    kernel_w: usize,
) -> (usize, usize, usize, usize, usize, usize) {
    // `saturating_sub` avoids the usize-underflow panic that `(a - b).max(0)`
    // hits whenever the subtrahend exceeds the minuend (e.g. stride > kernel);
    // `.max(0)` was a no-op on unsigned integers anyway.
    let pad_along_height = if input_h % stride == 0 {
        kernel_h.saturating_sub(stride)
    } else {
        kernel_h.saturating_sub(input_h % stride)
    };
    let pad_along_width = if input_w % stride == 0 {
        kernel_w.saturating_sub(stride)
    } else {
        kernel_w.saturating_sub(input_w % stride)
    };

    // Split the total padding in two; when it is odd the extra pixel goes to
    // the bottom/right half.
    let pad_top = pad_along_height / 2;
    let pad_bottom = pad_along_height - pad_top;
    let pad_left = pad_along_width / 2;
    let pad_right = pad_along_width - pad_left;

    (
        pad_along_height,
        pad_along_width,
        pad_bottom,
        pad_top,
        pad_right,
        pad_left,
    )
}
110
111pub(in crate) fn im2col_ref<'a, T, F: 'a + Float>(
112    im_arr: T,
113    ker_height: usize,
114    ker_width: usize,
115    im_height: usize,
116    im_width: usize,
117    im_channel: usize,
118    stride: usize,
119) -> Array2<F>
120where
121    // Args:
122    //   im_arr: image matrix to be translated into columns, (C,H,W)
123    //   ker_height: filter height (hh)
124    //   ker_width: filter width (ww)
125    //   im_height: image height
126    //   im_width: image width
127    //
128    // Returns:
129    //   col: (new_h*new_w,hh*ww*C) matrix, each column is a cube that will convolve with a filter
130    //         new_h = (H-hh) // stride + 1, new_w = (W-ww) // stride + 1
131    T: AsArray<'a, F, Ix3>,
132{
133    let im2d_arr: ArrayView3<F> = im_arr.into();
134    let new_h = (im_height - ker_height) / stride + 1;
135    let new_w = (im_width - ker_width) / stride + 1;
136    let mut cols_img: Array2<F> =
137        Array::zeros((new_h * new_w, im_channel * ker_height * ker_width));
138    let mut cont = 0_usize;
139    for i in 1..new_h + 1 {
140        for j in 1..new_w + 1 {
141            let patch = im2d_arr.slice(s![
142                ..,
143                (i - 1) * stride..((i - 1) * stride + ker_height),
144                (j - 1) * stride..((j - 1) * stride + ker_width),
145            ]);
146            let patchrow_unwrap: Array1<F> = Array::from_iter(patch.map(|a| *a));
147
148            cols_img.row_mut(cont).assign(&patchrow_unwrap);
149            cont += 1;
150        }
151    }
152    cols_img
153}
154
/// Performs a convolution on the given image data using this layers parameters.
/// We always convolve on flattened images and expect the input array in im2col
/// style format.
///
/// Read more here:
/// - <https://leonardoaraujosantos.gitbook.io/artificial-inteligence/machine_learning/deep_learning/convolution_layer/making_faster>
///
/// Input:
/// -----------------------------------------------
/// - kernel_weights: weights of shape (F, C, HH, WW)
/// - im2d: Input data of shape (C, H, W)
/// -----------------------------------------------
/// - 'stride': The number of pixels between adjacent receptive fields in the
///     horizontal and vertical directions, must be int
/// - 'pad': "Same" or "Valid"
///
/// Returns:
/// -----------------------------------------------
/// - out: Output data, of shape (F, H', W')
pub fn conv2d<'a, T, V, F: 'static + Float + std::ops::AddAssign>(
    kernel_weights: T,
    bias: Option<&Array1<F>>,
    im2d: V,
    padding: Padding,
    stride: usize,
) -> DataRepresentation<F>
where
    // This trait bound ensures that kernel and im2d can be passed as owned array or view.
    // AsArray just ensures that im2d can be converted to an array view via ".into()".
    // Read more here: https://docs.rs/ndarray/0.12.1/ndarray/trait.AsArray.html
    V: AsArray<'a, F, Ix3>,
    T: AsArray<'a, F, Ix4>,
{
    // Initialisations
    let im2d_arr: ArrayView3<F> = im2d.into();
    let kernel_weights_arr: ArrayView4<F> = kernel_weights.into();
    let im_col: Array2<F>; // output of fn: im2col_ref()
    let new_im_height: usize;
    let new_im_width: usize;
    // Kernel layout is (F, C, HH, WW) per the doc above.
    // NOTE(review): weight_shape[1] is the *input* channel count C, so the
    // name `num_channels_out` is misleading; the `as usize` casts are also
    // redundant (shape() already yields usize) — kept as-is.
    let weight_shape = kernel_weights_arr.shape();
    let num_filters = weight_shape[0] as usize;
    let num_channels_out = weight_shape[1] as usize;
    let kernel_height = weight_shape[2] as usize;
    let kernel_width = weight_shape[3] as usize;

    // Dimensions: C, H, W
    let im_channel = im2d_arr.len_of(Axis(0));
    let im_height = im2d_arr.len_of(Axis(1));
    let im_width = im2d_arr.len_of(Axis(2));

    // Calculate output shapes H', W' for two types of Padding
    if padding == Padding::Same {
        // https://mmuratarat.github.io/2019-01-17/implementing-padding-schemes-of-tensorflow-in-python
        // H' = H / stride
        // W' = W / stride

        // ceil() requires float math; the results are exact small integers,
        // so the round-trip through f32 is safe for realistic image sizes.
        let h_float = im_height as f32;
        let w_float = im_width as f32;
        let stride_float = stride as f32;

        let new_im_height_float = (h_float / stride_float).ceil();
        let new_im_width_float = (w_float / stride_float).ceil();

        new_im_height = new_im_height_float as usize;
        new_im_width = new_im_width_float as usize;
    } else {
        // H' =  ((H - HH) / stride ) + 1
        // W' =  ((W - WW) / stride ) + 1
        new_im_height = ((im_height - kernel_height) / stride) + 1;
        new_im_width = ((im_width - kernel_width) / stride) + 1;
    };

    // weights.reshape(F, HH*WW*C)
    // The factor order differs from the actual (C, HH, WW) flattening order,
    // but only the product matters for the reshape; this matches the column
    // layout produced by im2col_ref below.
    let filter_col = kernel_weights_arr
        .into_shape((num_filters, kernel_height * kernel_width * num_channels_out))
        .unwrap();

    // fn:im2col() for different Paddings
    if padding == Padding::Same {
        // https://mmuratarat.github.io/2019-01-17/implementing-padding-schemes-of-tensorflow-in-python
        // NOTE(review): get_padding_size returns its halves in swapped order
        // (..., pad_bottom, pad_top, pad_right, pad_left), and the bindings
        // here swap them back by name — the net effect is that when the total
        // padding is odd, the extra pixel goes to the top/left slice bound.
        // Per the comment in get_padding_size this matches PyTorch; confirm
        // before "simplifying" either side.
        let (pad_num_h, pad_num_w, pad_top, pad_bottom, pad_left, pad_right) =
            get_padding_size(im_height, im_width, stride, kernel_height, kernel_width);
        // Zero-padded copy of the input. `num_channels_out` here is the
        // kernel's C (see NOTE above), which must equal im_channel for a
        // valid kernel/image pair.
        let mut im2d_arr_pad: Array3<F> = Array::zeros((
            num_channels_out,
            im_height + pad_num_h,
            im_width + pad_num_w,
        ));
        let pad_bottom_int = (im_height + pad_num_h) - pad_bottom;
        let pad_right_int = (im_width + pad_num_w) - pad_right;
        // https://github.com/rust-ndarray/ndarray/issues/823
        im2d_arr_pad
            .slice_mut(s![.., pad_top..pad_bottom_int, pad_left..pad_right_int])
            .assign(&im2d_arr);

        let im_height_pad = im2d_arr_pad.len_of(Axis(1));
        let im_width_pad = im2d_arr_pad.len_of(Axis(2));

        im_col = im2col_ref(
            im2d_arr_pad.view(),
            kernel_height,
            kernel_width,
            im_height_pad,
            im_width_pad,
            im_channel,
            stride,
        );
    } else {
        im_col = im2col_ref(
            im2d_arr,
            kernel_height,
            kernel_width,
            im_height,
            im_width,
            im_channel,
            stride,
        );
    };
    // (H'*W', C*HH*WW) x (C*HH*WW, F) -> (H'*W', F), then reshape/permute to
    // the channel-first output layout (F, H', W').
    let filter_transpose = filter_col.t();
    let mul = im_col.dot(&filter_transpose);
    let output = mul
        .into_shape((new_im_height, new_im_width, num_filters))
        .unwrap()
        .permuted_axes([2, 0, 1]);

    add_bias(&output, bias)
}
281
282pub(in crate) fn add_bias<F>(x: &Array3<F>, bias: Option<&Array1<F>>) -> Array3<F>
283where
284    F: 'static + Float + std::ops::AddAssign,
285{
286    if let Some(bias_array) = bias {
287        assert!(
288            bias_array.shape()[0] == x.shape()[0],
289            "Bias array has the wrong shape {:?} for vec of shape {:?}",
290            bias_array.shape(),
291            x.shape()
292        );
293        // Yes this is really necessary. Broadcasting with ndarray-rust
294        // starts at the right side of the shape, so we have to add
295        // the axes by hand (else it thinks that it should compare the
296        // output width and the bias channels).
297        (x + &bias_array
298            .clone()
299            .insert_axis(Axis(1))
300            .insert_axis(Axis(2))
301            .broadcast(x.shape())
302            .unwrap())
303            .into_dimensionality()
304            .unwrap()
305    } else {
306        x.clone()
307    }
308}