//! renplex/cvnn/layer/conv.rs — complex-valued convolutional layer.

use crate::{act::ComplexActFunc, err::{GradientError, LayerForwardError, LayerInitError}, init::InitMethod, input::{IOShape, IOType}, math::{matrix::Matrix, BasicOperations, Complex}};
use crate::math::matrix::SliceOps;
use super::{CLayer, ComplexDerivatives};
4
/// Layer that computes pure convolution (no pad or strides).
#[derive(Debug)]
pub struct ConvCLayer<T> {
  /// Number of input feature maps (channels) this layer expects.
  input_features_len: usize,
  /// Matrix with shape [number of filters, number of channels];
  /// each element is one 2-D kernel for a (filter, channel) pair.
  kernels: Matrix<Matrix<T>>,
  /// Each output feature map gets its own bias.
  biases: Vec<T>,
  /// Complex activation function applied to every output feature map.
  func: ComplexActFunc
}
15
16impl<T: Complex + BasicOperations<T>> ConvCLayer<T> {
17  /// Checks if the layer was not initialize. 
18  /// 
19  /// # Notes
20  /// 
21  /// This function will soon be deleted.
22  pub fn is_empty(&self) -> bool {
23    if self.kernels.get_shape() == &[0, 0] { true }
24    else { false }
25  }
26
27  /// Says if the layer propagates derivatives, returning a boolean.
28  pub fn propagates(&self) -> bool {
29    true
30  }
31
32  /// Calculates the number of parameters involved in the Layer
33  pub fn params_len(&self) -> (usize, usize) {
34    let mut kernel_params: usize = 0;
35    for kernel in self.kernels.get_body().iter() {
36      let kernel_shape = kernel.get_shape();
37      kernel_params += kernel_shape[0] * kernel_shape[1];
38    }
39
40    let bias_params = self.biases.len();
41
42    (kernel_params, bias_params)
43  }
44
45  /// Gives the input shape of the layer
46  pub fn get_input_shape(&self) -> IOShape {
47    IOShape::Matrix(self.input_features_len)
48  }
49
50  /// Gives the output shape of the layer
51  pub fn get_output_shape(&self) -> IOShape {
52    IOShape::Matrix(self.kernels.get_shape()[0])
53  }
54
55  /// Creates a convolutional layer and returns it initialized.
56  /// 
57  /// # Arguments
58  /// 
59  /// * `input_shape` - an [`IOShape`] related to input shape of the layer.
60  /// * `filters_len` - number of filters in the layer.
61  /// * `kernel_size` - two dimensional size of the kernels (filters). Depth is automatically calculated
62  /// based on the [`IOShape`].
63  /// * `func` - the [`ComplexActFunc`] to be used in the layer.
64  /// * `kernel_method` - method for intiating the kernel values.
65  /// * `seed` - seed for random number generation.
66  pub fn init(
67    input_shape: IOShape,
68    filters_len: usize,
69    kernel_size: [usize; 2],
70    func: ComplexActFunc,
71    kernel_method: InitMethod,
72    seed: &mut u128
73  ) -> Result<Self, LayerInitError> {
74
75    match input_shape {
76      IOShape::Matrix(input_features_len) => {
77        let mut kernels = Vec::new();
78        let mut biases = Vec::new();
79
80        for _filter in 0..filters_len {
81          for _channel in 0..input_features_len {
82            let mut kernel = Vec::new();
83            for _ in 0..kernel_size[0] {
84              for _ in 0..kernel_size[1] {
85                kernel.push(kernel_method.gen(seed));
86              }
87            }
88            /* add a channel */
89            kernels.push(Matrix::from_body(kernel, kernel_size));
90          }
91        }
92
93        for _ in 0..filters_len {
94          biases.push(T::default());
95        }
96
97        let kernels = Matrix::from_body(
98          kernels, 
99          [filters_len, input_features_len]
100        );
101
102        Ok(Self { input_features_len, kernels, biases, func })
103      },
104      _ => { Err(LayerInitError::InvalidInputShape) }
105    }
106  }
107
108  /// Returns a [`Result`] for the [`IOType<T>`] related to the prediction of the layer.
109  /// Error handling is not yet finished.
110  /// 
111  /// # Arguments
112  /// * `input_type` - a reference to a [`IOType<T>`] representing the input features of the layer.
113  pub fn forward(&self, input_type: &IOType<T>) -> Result<IOType<T>, LayerForwardError> {
114    match input_type {
115      IOType::Matrix(input) => {
116        let kernels_shape = self.kernels.get_shape();
117        let channels = kernels_shape[1];
118        let func = &self.func;
119        /* it has to be the same for all */
120        let filters_shape = self.kernels.elm(0, 0).unwrap().get_shape();
121        let input_shape = input[0].get_shape();
122        /* output of the convolutions */
123        let output_shape = [input_shape[0]-(filters_shape[0]-1), input_shape[1]-(filters_shape[1]-1)];
124
125        let input_features_len = input.len();
126        if channels != input_features_len { return Err(LayerForwardError::InvalidInput) }
127
128        let output_features = self.kernels
129          //.rows_as_par_chunks()
130          .rows_as_iter()
131          .zip(self.biases.iter())
132          .map(|(filter, bias)| {
133            let init_matrix = Matrix::from_body(
134              vec![T::default(); output_shape[0]*output_shape[1]], 
135              output_shape
136            );
137
138            let mut output_feature = input
139              .iter()
140              .zip(filter.iter())
141              .fold(init_matrix, |mut acc, (feature, kernel)| {
142                /* going through channels */
143                /* CHANGED HERE FOR COMPLEX CONVOLUTION OR NOT */
144                let convolved_feature = feature.convolution(kernel).unwrap();
145                acc.add_mut(&convolved_feature).unwrap();
146
147                acc
148              });
149
150            output_feature.add_mut_scalar(*bias).unwrap();
151            let output_body = output_feature.get_body_as_mut();
152            T::activate_mut(output_body, func);
153            
154            output_feature
155          }).collect::<Vec<_>>();
156
157        Ok(IOType::Matrix(output_features))
158      },
159      _ => { Err(LayerForwardError::InvalidInput) }
160    }
161  }
162
  /// Return a [`Result`] for the derivatives and conjugate derivatives of the layer.
  ///
  /// Implements Wirtinger-style backpropagation for a valid (no-pad) convolution:
  /// kernel gradients come from convolving each input feature with the loss
  /// gradients; input gradients come from convolving the padded loss gradients
  /// with the flipped kernels.
  /// 
  /// # Arguments
  /// * `previous_act` - a reference to a [`IOType<T>`] representing the input features of the layer.
  /// * `dlda` - gradients from an upper layer.
  /// * `dlda_conj` - conjugate gradients from an upper layer.
  ///
  /// # Panics
  ///
  /// Panics when `previous_act` is not the matrix variant, when `input` is
  /// empty (`input[0]`), or when any internal matrix operation fails
  /// (the `unwrap`s below assume consistent shapes from the forward pass).
  pub fn compute_derivatives(&self, previous_act: &IOType<T>, dlda: Vec<T>, dlda_conj: Vec<T>) -> Result<ComplexDerivatives<T>, GradientError> {
    match previous_act {
      IOType::Matrix(input) => {
        let n_input_features = input.len();
        /* all of the shapes of the inputs should be the same */
        /* because previous layers always decrease the size equally throughout features */
        let input_shape = input[0].get_shape();
        let kernels_shape = self.kernels.get_shape();
        let kernel_shape = self.kernels.elm(0, 0).unwrap().get_shape();
        /* padding amounts that restore the input size after a valid convolution */
        let padx = kernel_shape[0] - 1;
        let pady = kernel_shape[1] - 1;
        let output_shape = [input_shape[0] - padx, input_shape[1] - pady];
        
        let act_func = &self.func;

        /* pre-activation values recomputed from the stored input (no cache kept) */
        let q = self.compute_q(input);

        /*  CHECK IF CALCULATING THE DERIVATIVE OF THE ACTIVATION HERE IS CORRECT */
        /* yap it seems to be correct */
        /* chain rule through the activation: dL/dq = dL/da * da/dq, flattened
        over all filters in the same order compute_q emitted them */
        let mut dlda_dadq = q
          .iter()
          .zip(dlda.into_iter())
          .map(|(elm, dlda_val)| {
            /* dadq * dlda values */
            elm.d_activate(act_func) * dlda_val
          }).collect::<Vec<_>>();

        let mut dlda_conj_da_conj_dq = q
          .iter()
          .zip(dlda_conj.into_iter())
          .map(|(elm, dlda_conj_val)| {
            /* da_conj_dq * dlda_conj values */
            elm.d_conj_activate(act_func).conj() * dlda_conj_val
          }).collect::<Vec<_>>();

        /* q is no longer needed; free it before the per-filter loop */
        drop(q);

        /* all output features have the same size since previous filters have all the same size */
        let out_feat_size = dlda_dadq.len() / kernels_shape[0];

        /* go through filters to update their derivatives */
        /* each loss gradient chunk is related to a single filter */

        /* can be optimized! */
        let mut dldk = Vec::new();
        let mut dldb = Vec::new();
        /* input gradients accumulate contributions from every filter */
        let mut new_dlda = vec![T::default(); input_shape[0] * input_shape[1] * n_input_features];
        let mut new_dlda_conj = vec![T::default(); input_shape[0] * input_shape[1] * n_input_features];
        let filters = self.kernels.rows_as_iter();
        filters.for_each(|filter| {
          /* collecting loss derivatives to matrices */
          /* maybe there is a better way */
          /* NOTE: drain consumes this filter's chunk, so iteration order must
          match the filter order used by compute_q */
          let dlda_dadq_feat = Matrix::from_body(
            dlda_dadq.drain(0..out_feat_size).collect::<Vec<_>>(), 
            output_shape
          );
          let dlda_conj_da_conj_dq_feat = Matrix::from_body(
            dlda_conj_da_conj_dq.drain(0..out_feat_size).collect::<Vec<_>>(), 
            output_shape
          );

          /* kernel + bias gradients for this filter */
          let oper_a = || {
            let dldk_per_filter = input
              .into_iter()
              .flat_map(|feature| {
                /* CHANGED HERE FOR COMPLEX CONVOLUTION OR NOT */
                /* YOU CAN CHECK HERE WHAT HAS GREATER PERFORMANCE: 
                TWO CONVS -> ADD TERMS OR ADD TERMS -> ONE CONV */
                let mut dldk_term1 = feature.convolution(&dlda_dadq_feat).unwrap();
                let dldk_term2 = feature.convolution(&dlda_conj_da_conj_dq_feat).unwrap();
                
                dldk_term1.add_mut(&dldk_term2).unwrap();
                dldk_term1.export_body()
              }).collect::<Vec<_>>();
            
            /* calculate bias derivative */
            let dldb_per_filter = dlda_dadq_feat
              .get_body()
              .add_slice(dlda_conj_da_conj_dq_feat.get_body())
              .unwrap();

            (dldk_per_filter, dldb_per_filter)
          };

          /* flipping kernels for derivatives */
          let fliped_filter = filter
            .iter()
            .map(|kernel| {
              kernel.flip().unwrap()
            })
            .collect::<Vec<_>>();
          /* calculate loss derivative */
          /* full padding so the "transposed" convolution restores input size */
          let dlda_padded = dlda_dadq_feat.clone().pad((padx, pady));
          let dlda_conj_padded = dlda_conj_da_conj_dq_feat.clone().pad((padx, pady));

          /* input gradients for this filter (one chunk per channel) */
          let oper_b = || {
            let new_dlda_acc = fliped_filter
              .iter()
              .flat_map(|flip_kernel| {
                /* CHANGED HERE FOR COMPLEX CONVOLUTION OR NOT */
                dlda_padded
                  .convolution(flip_kernel)
                  .unwrap()
                  .export_body()
              }).collect::<Vec<_>>();
            
            let new_dlda_conj_acc = fliped_filter
              .iter()
              .flat_map(|flip_kernel| {
                /* CHANGED HERE FOR COMPLEX CONVOLUTION OR NOT */
                dlda_conj_padded
                  .convolution(flip_kernel)
                  .unwrap()
                  .export_body()
              }).collect::<Vec<_>>();
            
            (new_dlda_acc, new_dlda_conj_acc)
          };

          /* using rayon join might be a good addition in the future */
          /* be careful because if it is not, you can change in the future for a more memory efficient approach */
          let ((dldk_filter, dldb_filter), (new_dlda_acc, new_dlda_conj_acc)) = (oper_a(), oper_b());
          dldk.extend(dldk_filter);
          /* bias gradient is the sum over all output positions */
          dldb.push(dldb_filter.into_iter().reduce(|acc, elm| { acc + elm }).unwrap());
          new_dlda.add_slice_mut(&new_dlda_acc).unwrap();
          new_dlda_conj.add_slice_mut(&new_dlda_conj_acc).unwrap();
        });

        Ok((dldk, dldb, new_dlda, new_dlda_conj))
      },
      _ => { panic!("Something went terribily wrong.") }
    }
  }
302
303  /// Adjusts the parameters of the layer with negative conjugate.
304  /// 
305  /// # Arguments
306  /// 
307  /// * `dldw` - adjustments on the weights.
308  /// * `dldb` - adjustments on the biases.
309  pub fn neg_conj_adjustment(&mut self, dldw: Vec<T>, dldb: Vec<T>) -> Result<(), GradientError> {
310    let dldw_size = dldw.len();
311    let dldb_size = dldb.len();
312    
313    /* if there is an error it can be here */
314    let (weights, biases) = self.params_len();
315
316    if dldb_size != biases {
317      return Err(GradientError::InconsistentShape)
318    } 
319    if dldw_size != weights {
320      return Err(GradientError::InconsistentShape)
321    }
322
323    self.kernels
324      .get_body_as_mut()
325      .iter_mut()
326      .flat_map(|elm| elm.get_body_as_mut())
327      .zip(dldw.into_iter())
328      .for_each(|(elm, dk)| { *elm -= dk.conj(); });
329
330    self.biases.iter_mut().zip(dldb).for_each(|(bias, db)| {
331      *bias -= db.conj();
332    });
333
334    Ok(())
335  }
336
337  /// Wraps the convolutional layer into the general [`CLayer`] interface.
338  pub fn wrap(self) -> CLayer<T> {
339    CLayer::Convolutional(self)
340  }
341
342  fn compute_q(&self, input: &Vec<Matrix<T>>) -> Vec<T> {
343    /* it has to be the same for all */
344    let filters_shape = self.kernels.elm(0, 0).unwrap().get_shape();
345    let input_shape = input[0].get_shape();
346    /* output of the convolutions */
347    let output_shape = [input_shape[0]-(filters_shape[0]-1), input_shape[1]-(filters_shape[1]-1)];
348
349    let output_features_flat = self.kernels
350      .rows_as_iter()
351      .zip(self.biases.iter())
352      .flat_map(|(filter, bias)| {
353        let init_matrix = Matrix::from_body(
354          vec![T::default(); output_shape[0]*output_shape[1]], 
355          output_shape
356        );
357
358        let mut output_feature = input
359          .iter()
360          .zip(filter.iter())
361          .fold(init_matrix, |mut acc, (feature, kernel)| {
362            /* going through channels */
363            /* CHANGED HERE FOR COMPLEX CONVOLUTION OR NOT */
364            let convolved_feature = feature.convolution(kernel).unwrap();
365            acc.add_mut(&convolved_feature).unwrap();
366
367            acc
368          });
369
370        output_feature.add_mut_scalar(*bias).unwrap();
371        
372        output_feature.export_body()
373      }).collect::<Vec<_>>();
374    
375    output_features_flat
376  }
377}