// renplex/cvnn/layer/conv.rs
use crate::{act::ComplexActFunc, err::{GradientError, LayerForwardError, LayerInitError}, init::InitMethod, input::{IOShape, IOType}, math::{matrix::Matrix, BasicOperations, Complex}};
use crate::math::matrix::SliceOps;
use super::{CLayer, ComplexDerivatives};

/// Complex-valued convolutional layer.
///
/// Holds one 2-D kernel per (filter, channel) pair plus one bias per filter,
/// and applies `func` as the complex activation after the convolution sum.
#[derive(Debug)]
pub struct ConvCLayer<T> {
  // number of input feature maps (channels) this layer expects
  input_features_len: usize,
  // kernel bank with shape [filters, channels]; each element is a 2-D kernel
  kernels: Matrix<Matrix<T>>,
  // one bias per filter (i.e. per output feature map)
  biases: Vec<T>,
  // complex activation applied element-wise to each output feature
  func: ComplexActFunc
}
15
16impl<T: Complex + BasicOperations<T>> ConvCLayer<T> {
17 pub fn is_empty(&self) -> bool {
23 if self.kernels.get_shape() == &[0, 0] { true }
24 else { false }
25 }
26
27 pub fn propagates(&self) -> bool {
29 true
30 }
31
32 pub fn params_len(&self) -> (usize, usize) {
34 let mut kernel_params: usize = 0;
35 for kernel in self.kernels.get_body().iter() {
36 let kernel_shape = kernel.get_shape();
37 kernel_params += kernel_shape[0] * kernel_shape[1];
38 }
39
40 let bias_params = self.biases.len();
41
42 (kernel_params, bias_params)
43 }
44
45 pub fn get_input_shape(&self) -> IOShape {
47 IOShape::Matrix(self.input_features_len)
48 }
49
50 pub fn get_output_shape(&self) -> IOShape {
52 IOShape::Matrix(self.kernels.get_shape()[0])
53 }
54
55 pub fn init(
67 input_shape: IOShape,
68 filters_len: usize,
69 kernel_size: [usize; 2],
70 func: ComplexActFunc,
71 kernel_method: InitMethod,
72 seed: &mut u128
73 ) -> Result<Self, LayerInitError> {
74
75 match input_shape {
76 IOShape::Matrix(input_features_len) => {
77 let mut kernels = Vec::new();
78 let mut biases = Vec::new();
79
80 for _filter in 0..filters_len {
81 for _channel in 0..input_features_len {
82 let mut kernel = Vec::new();
83 for _ in 0..kernel_size[0] {
84 for _ in 0..kernel_size[1] {
85 kernel.push(kernel_method.gen(seed));
86 }
87 }
88 kernels.push(Matrix::from_body(kernel, kernel_size));
90 }
91 }
92
93 for _ in 0..filters_len {
94 biases.push(T::default());
95 }
96
97 let kernels = Matrix::from_body(
98 kernels,
99 [filters_len, input_features_len]
100 );
101
102 Ok(Self { input_features_len, kernels, biases, func })
103 },
104 _ => { Err(LayerInitError::InvalidInputShape) }
105 }
106 }
107
108 pub fn forward(&self, input_type: &IOType<T>) -> Result<IOType<T>, LayerForwardError> {
114 match input_type {
115 IOType::Matrix(input) => {
116 let kernels_shape = self.kernels.get_shape();
117 let channels = kernels_shape[1];
118 let func = &self.func;
119 let filters_shape = self.kernels.elm(0, 0).unwrap().get_shape();
121 let input_shape = input[0].get_shape();
122 let output_shape = [input_shape[0]-(filters_shape[0]-1), input_shape[1]-(filters_shape[1]-1)];
124
125 let input_features_len = input.len();
126 if channels != input_features_len { return Err(LayerForwardError::InvalidInput) }
127
128 let output_features = self.kernels
129 .rows_as_iter()
131 .zip(self.biases.iter())
132 .map(|(filter, bias)| {
133 let init_matrix = Matrix::from_body(
134 vec![T::default(); output_shape[0]*output_shape[1]],
135 output_shape
136 );
137
138 let mut output_feature = input
139 .iter()
140 .zip(filter.iter())
141 .fold(init_matrix, |mut acc, (feature, kernel)| {
142 let convolved_feature = feature.convolution(kernel).unwrap();
145 acc.add_mut(&convolved_feature).unwrap();
146
147 acc
148 });
149
150 output_feature.add_mut_scalar(*bias).unwrap();
151 let output_body = output_feature.get_body_as_mut();
152 T::activate_mut(output_body, func);
153
154 output_feature
155 }).collect::<Vec<_>>();
156
157 Ok(IOType::Matrix(output_features))
158 },
159 _ => { Err(LayerForwardError::InvalidInput) }
160 }
161 }
162
163 pub fn compute_derivatives(&self, previous_act: &IOType<T>, dlda: Vec<T>, dlda_conj: Vec<T>) -> Result<ComplexDerivatives<T>, GradientError> {
170 match previous_act {
171 IOType::Matrix(input) => {
172 let n_input_features = input.len();
173 let input_shape = input[0].get_shape();
176 let kernels_shape = self.kernels.get_shape();
177 let kernel_shape = self.kernels.elm(0, 0).unwrap().get_shape();
178 let padx = kernel_shape[0] - 1;
179 let pady = kernel_shape[1] - 1;
180 let output_shape = [input_shape[0] - padx, input_shape[1] - pady];
181
182 let act_func = &self.func;
183
184 let q = self.compute_q(input);
185
186 let mut dlda_dadq = q
189 .iter()
190 .zip(dlda.into_iter())
191 .map(|(elm, dlda_val)| {
192 elm.d_activate(act_func) * dlda_val
194 }).collect::<Vec<_>>();
195
196 let mut dlda_conj_da_conj_dq = q
197 .iter()
198 .zip(dlda_conj.into_iter())
199 .map(|(elm, dlda_conj_val)| {
200 elm.d_conj_activate(act_func).conj() * dlda_conj_val
202 }).collect::<Vec<_>>();
203
204 drop(q);
205
206 let out_feat_size = dlda_dadq.len() / kernels_shape[0];
208
209 let mut dldk = Vec::new();
214 let mut dldb = Vec::new();
215 let mut new_dlda = vec![T::default(); input_shape[0] * input_shape[1] * n_input_features];
216 let mut new_dlda_conj = vec![T::default(); input_shape[0] * input_shape[1] * n_input_features];
217 let filters = self.kernels.rows_as_iter();
218 filters.for_each(|filter| {
219 let dlda_dadq_feat = Matrix::from_body(
222 dlda_dadq.drain(0..out_feat_size).collect::<Vec<_>>(),
223 output_shape
224 );
225 let dlda_conj_da_conj_dq_feat = Matrix::from_body(
226 dlda_conj_da_conj_dq.drain(0..out_feat_size).collect::<Vec<_>>(),
227 output_shape
228 );
229
230 let oper_a = || {
231 let dldk_per_filter = input
232 .into_iter()
233 .flat_map(|feature| {
234 let mut dldk_term1 = feature.convolution(&dlda_dadq_feat).unwrap();
238 let dldk_term2 = feature.convolution(&dlda_conj_da_conj_dq_feat).unwrap();
239
240 dldk_term1.add_mut(&dldk_term2).unwrap();
241 dldk_term1.export_body()
242 }).collect::<Vec<_>>();
243
244 let dldb_per_filter = dlda_dadq_feat
246 .get_body()
247 .add_slice(dlda_conj_da_conj_dq_feat.get_body())
248 .unwrap();
249
250 (dldk_per_filter, dldb_per_filter)
251 };
252
253 let fliped_filter = filter
255 .iter()
256 .map(|kernel| {
257 kernel.flip().unwrap()
258 })
259 .collect::<Vec<_>>();
260 let dlda_padded = dlda_dadq_feat.clone().pad((padx, pady));
262 let dlda_conj_padded = dlda_conj_da_conj_dq_feat.clone().pad((padx, pady));
263
264 let oper_b = || {
265 let new_dlda_acc = fliped_filter
266 .iter()
267 .flat_map(|flip_kernel| {
268 dlda_padded
270 .convolution(flip_kernel)
271 .unwrap()
272 .export_body()
273 }).collect::<Vec<_>>();
274
275 let new_dlda_conj_acc = fliped_filter
276 .iter()
277 .flat_map(|flip_kernel| {
278 dlda_conj_padded
280 .convolution(flip_kernel)
281 .unwrap()
282 .export_body()
283 }).collect::<Vec<_>>();
284
285 (new_dlda_acc, new_dlda_conj_acc)
286 };
287
288 let ((dldk_filter, dldb_filter), (new_dlda_acc, new_dlda_conj_acc)) = (oper_a(), oper_b());
291 dldk.extend(dldk_filter);
292 dldb.push(dldb_filter.into_iter().reduce(|acc, elm| { acc + elm }).unwrap());
293 new_dlda.add_slice_mut(&new_dlda_acc).unwrap();
294 new_dlda_conj.add_slice_mut(&new_dlda_conj_acc).unwrap();
295 });
296
297 Ok((dldk, dldb, new_dlda, new_dlda_conj))
298 },
299 _ => { panic!("Something went terribily wrong.") }
300 }
301 }
302
303 pub fn neg_conj_adjustment(&mut self, dldw: Vec<T>, dldb: Vec<T>) -> Result<(), GradientError> {
310 let dldw_size = dldw.len();
311 let dldb_size = dldb.len();
312
313 let (weights, biases) = self.params_len();
315
316 if dldb_size != biases {
317 return Err(GradientError::InconsistentShape)
318 }
319 if dldw_size != weights {
320 return Err(GradientError::InconsistentShape)
321 }
322
323 self.kernels
324 .get_body_as_mut()
325 .iter_mut()
326 .flat_map(|elm| elm.get_body_as_mut())
327 .zip(dldw.into_iter())
328 .for_each(|(elm, dk)| { *elm -= dk.conj(); });
329
330 self.biases.iter_mut().zip(dldb).for_each(|(bias, db)| {
331 *bias -= db.conj();
332 });
333
334 Ok(())
335 }
336
337 pub fn wrap(self) -> CLayer<T> {
339 CLayer::Convolutional(self)
340 }
341
342 fn compute_q(&self, input: &Vec<Matrix<T>>) -> Vec<T> {
343 let filters_shape = self.kernels.elm(0, 0).unwrap().get_shape();
345 let input_shape = input[0].get_shape();
346 let output_shape = [input_shape[0]-(filters_shape[0]-1), input_shape[1]-(filters_shape[1]-1)];
348
349 let output_features_flat = self.kernels
350 .rows_as_iter()
351 .zip(self.biases.iter())
352 .flat_map(|(filter, bias)| {
353 let init_matrix = Matrix::from_body(
354 vec![T::default(); output_shape[0]*output_shape[1]],
355 output_shape
356 );
357
358 let mut output_feature = input
359 .iter()
360 .zip(filter.iter())
361 .fold(init_matrix, |mut acc, (feature, kernel)| {
362 let convolved_feature = feature.convolution(kernel).unwrap();
365 acc.add_mut(&convolved_feature).unwrap();
366
367 acc
368 });
369
370 output_feature.add_mut_scalar(*bias).unwrap();
371
372 output_feature.export_body()
373 }).collect::<Vec<_>>();
374
375 output_features_flat
376 }
377}