use std::collections::HashMap;

use crate::{activation, assert_eq_shape, network, optimizer, tensor};

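/// How multiple incoming tensors are combined into one value.
///
/// The same accumulation rule is used for skip connections during the forward
/// and backward passes, and for synchronising coupled (weight-shared) layers
/// in `update`.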
#[derive(Clone)]
pub enum Accumulation {
    Add,
    Subtract,
    Multiply,
    Overwrite,
    Mean,
}

impl std::fmt::Display for Accumulation {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Accumulation::Add => write!(f, "additive"),
            Accumulation::Subtract => write!(f, "subtractive"),
            Accumulation::Multiply => write!(f, "multiplicative"),
            Accumulation::Overwrite => write!(f, "overwrite"),
            Accumulation::Mean => write!(f, "mean"),
            #[allow(unreachable_patterns)]
            _ => unimplemented!("Accumulation method not implemented."),
        }
    }
}

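/// Layer blueprints used when describing the contents of a feedback block.
///
/// The tuple fields mirror the builder arguments of the corresponding layer
/// types (activation function, kernel/stride/padding sizes, optional dropout
/// rate, and so on); see the respective layer modules for exact semantics.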
pub enum Layer {
    Dense(usize, activation::Activation, bool, Option<f32>),
    Convolution(
        usize,
        activation::Activation,
        (usize, usize),
        (usize, usize),
        (usize, usize),
        (usize, usize),
        Option<f32>,
    ),
    Deconvolution(
        usize,
        activation::Activation,
        (usize, usize),
        (usize, usize),
        (usize, usize),
        Option<f32>,
    ),
    Maxpool((usize, usize), (usize, usize)),
}

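/// A feedback block: a sequence of layers unrolled `loops` times, where the
/// unrolled copies share (are "coupled" to) the same weights and the block's
/// output shape must equal its input shape.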
#[derive(Clone)]
pub struct Feedback {
    pub(crate) inputs: tensor::Shape,
    pub(crate) outputs: tensor::Shape,
    pub(crate) optimizer: optimizer::Optimizer,
    pub(crate) flatten: bool,
    pub layers: Vec<network::Layer>,
    /// Skip connections: target layer index -> indices of source activations.
    connect: HashMap<usize, Vec<usize>>,
    pub(crate) accumulation: Accumulation,
    /// Groups of layer indices whose weights are kept identical across loops.
    coupled: Vec<Vec<usize>>,
}

impl std::fmt::Display for Feedback {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Feedback (\n")?;
        write!(f, "\t\t\t{} -> {}\n", self.inputs, self.outputs)?;

        write!(f, "\t\t\tlayers: (\n")?;
        for (i, layer) in self.layers.iter().enumerate() {
            match layer {
                network::Layer::Dense(layer) => {
                    write!(
                        f,
                        "\t\t\t\t{}: Dense{} ({} -> {})\n",
                        i, layer.activation, layer.inputs, layer.outputs
                    )?;
                }
                network::Layer::Convolution(layer) => {
                    write!(
                        f,
                        "\t\t\t\t{}: Convolution{} ({} -> {})\n",
                        i, layer.activation, layer.inputs, layer.outputs
                    )?;
                }
                network::Layer::Deconvolution(layer) => {
                    write!(
                        f,
                        "\t\t\t\t{}: Deconvolution{} ({} -> {})\n",
                        i, layer.activation, layer.inputs, layer.outputs
                    )?;
                }
                network::Layer::Maxpool(layer) => {
                    write!(
                        f,
                        "\t\t\t\t{}: Maxpool ({} -> {})\n",
                        i, layer.inputs, layer.outputs
                    )?;
                }
                network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
            }
        }
        write!(f, "\t\t\t)\n")?;
        if !self.coupled.is_empty() {
            write!(f, "\t\t\tcoupled: (\n")?;
            for coupled in self.coupled.iter() {
                write!(f, "\t\t\t\t{:?}\n", coupled)?;
            }
            write!(f, "\t\t\t\taccumulation: {}\n", self.accumulation)?;
            write!(f, "\t\t\t)\n")?;
        }
        if !self.connect.is_empty() {
            write!(f, "\t\t\tconnections: (\n")?;
            write!(f, "\t\t\t\taccumulation: {}\n", self.accumulation)?;

            let mut entries: Vec<(&usize, &Vec<usize>)> = self.connect.iter().collect();
            entries.sort_by_key(|&(to, _)| to);
            for (to, from) in entries.iter() {
                write!(f, "\t\t\t\t{:?}.input -> {}.input\n", from, to)?;
            }
            write!(f, "\t\t\t)\n")?;
        }
        write!(f, "\t\t\tflatten: {}\n", self.flatten)?;
        write!(f, "\t\t)")?;
        Ok(())
    }
}

impl Feedback {
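    /// Build a feedback block from `layers`, unrolled `loops` times.
    ///
    /// `inskips` adds a skip connection from the block input to the start of
    /// every loop iteration; `outskips` accumulates each iteration's output
    /// into the block output. `accumulation` selects how skipped tensors are
    /// combined. Panics if `loops == 0` or if the first layer's input shape
    /// differs from the last layer's output shape.
    ///
    /// A minimal sketch, mirroring the helpers used in the tests below:
    ///
    /// ```ignore
    /// let dense = dense::Dense::create(
    ///     tensor::Shape::Single(2),
    ///     tensor::Shape::Single(2),
    ///     &activation::Activation::ReLU,
    ///     false,
    ///     None,
    /// );
    /// // Two weight-coupled loops with an input skip into the second loop.
    /// let block = Feedback::create(
    ///     vec![network::Layer::Dense(dense)],
    ///     2,
    ///     true,
    ///     false,
    ///     Accumulation::Add,
    /// );
    /// ```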
    pub fn create(
        mut layers: Vec<network::Layer>,
        loops: usize,
        inskips: bool,
        outskips: bool,
        accumulation: Accumulation,
    ) -> Self {
        assert!(loops > 0, "Feedback block should loop at least once.");
        let inputs = match layers.first().unwrap() {
            network::Layer::Dense(dense) => dense.inputs.clone(),
            network::Layer::Convolution(convolution) => convolution.inputs.clone(),
            network::Layer::Deconvolution(deconvolution) => deconvolution.inputs.clone(),
            network::Layer::Maxpool(maxpool) => maxpool.inputs.clone(),
            network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
        };
        let outputs = match layers.last().unwrap() {
            network::Layer::Dense(dense) => dense.outputs.clone(),
            network::Layer::Convolution(convolution) => convolution.outputs.clone(),
            network::Layer::Deconvolution(deconvolution) => deconvolution.outputs.clone(),
            network::Layer::Maxpool(maxpool) => maxpool.outputs.clone(),
            network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
        };
        assert_eq_shape!(inputs, outputs);

        let length = layers.len();

        // Unroll the block: append a copy of the layers for every extra loop.
        let _layers = layers.clone();
        for _ in 1..loops {
            layers.extend(_layers.clone());
        }

        // Couple each layer with its unrolled copies so their weights stay shared.
        let mut coupled: Vec<Vec<usize>> = Vec::new();
        for layer in 0..length {
            let mut coupling = Vec::new();
            for i in 0..loops {
                coupling.push(layer + i * length);
            }
            coupled.push(coupling);
        }

        // Wire up the requested skip connections between loop iterations.
        let mut connect: HashMap<usize, Vec<usize>> = HashMap::new();
        if inskips || outskips {
            let mut outputs = Vec::new();
            for i in 1..loops {
                if inskips {
                    // Feed the block input into the start of every later loop.
                    connect.insert(i * length, vec![0]);
                }
                if outskips {
                    outputs.push(i * length);
                }
            }
            if outskips {
                // Accumulate every intermediate loop output into the block output.
                connect.insert(loops * length, outputs);
            }
        }

        Feedback {
            inputs,
            outputs,
            optimizer: optimizer::SGD::create(0.1, None),
            flatten: false,
            layers,
            connect,
            accumulation,
            coupled,
        }
    }

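    /// Replace this block's optimizer, pre-validating it against zero-filled
    /// gradient buffers shaped like each layer's weights and biases.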
    pub fn copy_optimizer(&mut self, mut optimizer: optimizer::Optimizer) {
        let mut vectors: Vec<Vec<Vec<tensor::Tensor>>> = Vec::new();
        for layer in self.layers.iter().rev() {
            match layer {
                network::Layer::Dense(layer) => {
                    let (output, input) = match &layer.weights.shape {
                        tensor::Shape::Double(output, input) => (*output, *input),
                        _ => panic!("Expected Dense shape"),
                    };
                    vectors.push(vec![vec![
                        tensor::Tensor::double(vec![vec![0.0; input]; output]),
                        if layer.bias.is_some() {
                            tensor::Tensor::single(vec![0.0; output])
                        } else {
                            tensor::Tensor::single(vec![])
                        },
                    ]]);
                }
                network::Layer::Convolution(layer) => {
                    let (ch, kh, kw) = match layer.kernels[0].shape {
                        tensor::Shape::Triple(ch, he, wi) => (ch, he, wi),
                        _ => panic!("Expected Convolution shape"),
                    };
                    vectors.push(vec![
                        vec![
                            tensor::Tensor::triple(vec![vec![vec![0.0; kw]; kh]; ch]),
                        ];
                        layer.kernels.len()
                    ]);
                }
                network::Layer::Deconvolution(layer) => {
                    let (ch, kh, kw) = match layer.kernels[0].shape {
                        tensor::Shape::Triple(ch, he, wi) => (ch, he, wi),
                        _ => panic!("Expected Deconvolution shape"),
                    };
                    vectors.push(vec![
                        vec![
                            tensor::Tensor::triple(vec![vec![vec![0.0; kw]; kh]; ch]),
                        ];
                        layer.kernels.len()
                    ]);
                }
                network::Layer::Maxpool(_) => {
                    vectors.push(vec![vec![tensor::Tensor::single(vec![0.0; 0])]])
                }
                _ => unimplemented!("Feedback blocks not yet implemented."),
            }
        }

        optimizer.validate(vectors);

        self.optimizer = optimizer;
    }

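    /// Number of trainable parameters in the block. Only the first copy of
    /// each coupled layer is counted, since the unrolled copies share weights.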
    pub fn parameters(&self) -> usize {
        let mut parameters = 0;
        for idx in 0..self.coupled.len() {
            parameters += match &self.layers[idx] {
                network::Layer::Dense(dense) => dense.parameters(),
                network::Layer::Convolution(convolution) => convolution.parameters(),
                network::Layer::Deconvolution(deconvolution) => deconvolution.parameters(),
                network::Layer::Maxpool(_) => 0,
                network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
            };
        }
        parameters
    }

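    /// Set the `training` flag on every layer in the block.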
    pub fn training(&mut self, train: bool) {
        self.layers.iter_mut().for_each(|layer| match layer {
            network::Layer::Dense(layer) => layer.training = train,
            network::Layer::Convolution(layer) => layer.training = train,
            network::Layer::Deconvolution(layer) => layer.training = train,
            network::Layer::Maxpool(_) => {}
            network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
        });
    }

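    /// Run the unrolled block on `input`, applying skip connections with the
    /// configured accumulation rule.
    ///
    /// Returns a five-tuple of
    /// `(first pre-activation, final output, maxpool indices,
    ///   all pre-activations, all activations)`,
    /// where the last three are nested tensors holding the per-layer
    /// intermediates needed by `backward`.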
    pub fn forward(
        &self,
        input: &tensor::Tensor,
    ) -> (
        tensor::Tensor,
        tensor::Tensor,
        tensor::Tensor,
        tensor::Tensor,
        tensor::Tensor,
    ) {
        let mut unactivated = Vec::with_capacity(self.layers.len());
        let mut activated = Vec::with_capacity(self.layers.len() + 1);
        let mut maxpools = Vec::with_capacity(self.layers.len());

        activated.push(input.clone());

        for (i, layer) in self.layers.iter().enumerate() {
            let mut x = activated.last().unwrap().clone();

            // Accumulate any skip connections targeting this layer's input.
            if self.connect.contains_key(&i) {
                match self.accumulation {
                    Accumulation::Add => {
                        for idx in self.connect.get(&i).unwrap() {
                            x.add_inplace(&activated[*idx]);
                        }
                    }
                    Accumulation::Subtract => {
                        for idx in self.connect.get(&i).unwrap() {
                            x.sub_inplace(&activated[*idx]);
                        }
                    }
                    Accumulation::Multiply => {
                        for idx in self.connect.get(&i).unwrap() {
                            x.mul_inplace(&activated[*idx]);
                        }
                    }
                    Accumulation::Overwrite => {
                        x = activated[*self.connect.get(&i).unwrap().last().unwrap()].clone();
                    }
                    Accumulation::Mean => {
                        let mut _x: Vec<&tensor::Tensor> = Vec::new();
                        for idx in self.connect.get(&i).unwrap() {
                            _x.push(&activated[*idx]);
                        }
                        x.mean_inplace(&_x);
                    }
                    #[allow(unreachable_patterns)]
                    _ => unimplemented!("Accumulation method not implemented."),
                }
            }

            let (pre, post, max) = match layer {
                network::Layer::Dense(layer) => {
                    assert_eq_shape!(layer.inputs, x.shape);
                    let (pre, post) = layer.forward(&x);
                    (pre, post, None)
                }
                network::Layer::Convolution(layer) => {
                    assert_eq_shape!(layer.inputs, x.shape);
                    let (pre, post) = layer.forward(&x);
                    (pre, post, None)
                }
                network::Layer::Deconvolution(layer) => {
                    assert_eq_shape!(layer.inputs, x.shape);
                    let (pre, post) = layer.forward(&x);
                    (pre, post, None)
                }
                network::Layer::Maxpool(layer) => {
                    assert_eq_shape!(layer.inputs, x.shape);
                    let (pre, post, max) = layer.forward(&x);
                    (pre, post, Some(max))
                }
                network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
            };

            unactivated.push(pre);
            activated.push(post);
            maxpools.push(max);
        }

        let mut last = activated.pop().unwrap();

        // Skip connections may also target the block output (`outskips`).
        if self.connect.contains_key(&self.layers.len()) {
            let i = self.layers.len();
            match self.accumulation {
                Accumulation::Add => {
                    for idx in self.connect.get(&i).unwrap() {
                        last.add_inplace(&activated[*idx]);
                    }
                }
                Accumulation::Subtract => {
                    for idx in self.connect.get(&i).unwrap() {
                        last.sub_inplace(&activated[*idx]);
                    }
                }
                Accumulation::Multiply => {
                    for idx in self.connect.get(&i).unwrap() {
                        last.mul_inplace(&activated[*idx]);
                    }
                }
                Accumulation::Overwrite => {
                    last = activated[*self.connect.get(&i).unwrap().last().unwrap()].clone();
                }
                Accumulation::Mean => {
                    let mut _x: Vec<&tensor::Tensor> = Vec::new();
                    for idx in self.connect.get(&i).unwrap() {
                        _x.push(&activated[*idx]);
                    }
                    last.mean_inplace(&_x);
                }
                #[allow(unreachable_patterns)]
                _ => unimplemented!("Accumulation method not implemented."),
            }
        }

        if self.flatten {
            activated.push(last.flatten());
        } else {
            activated.push(last);
        }

        (
            unactivated[0].clone(),
            activated[activated.len() - 1].clone(),
            tensor::Tensor::nestedoptional(maxpools),
            tensor::Tensor::nested(unactivated),
            tensor::Tensor::nested(activated),
        )
    }

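    /// Backpropagate `gradient` through the unrolled block.
    ///
    /// `inbetween` is `[pre-activations, activations]` as returned by
    /// `forward`. Returns the gradient with respect to the block input,
    /// together with nested weight and (optional) bias gradients, ordered
    /// from the last layer to the first.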
    pub fn backward(
        &self,
        gradient: &tensor::Tensor,
        inbetween: &Vec<tensor::Tensor>,
    ) -> (tensor::Tensor, tensor::Tensor, Option<tensor::Tensor>) {
        let unactivated = inbetween[0].unnested();
        let activated = inbetween[1].unnested();

        let mut gradients: Vec<tensor::Tensor> = vec![gradient.clone()];
        let mut weight_gradients: Vec<tensor::Tensor> = Vec::new();
        let mut bias_gradients: Vec<Option<tensor::Tensor>> = Vec::new();

        // Invert the skip-connection map: source layer -> target layers.
        let mut connect: HashMap<usize, Vec<usize>> = HashMap::new();
        for (key, value) in self.connect.iter() {
            for idx in value.iter() {
                if connect.contains_key(idx) {
                    connect.get_mut(idx).unwrap().push(*key);
                } else {
                    connect.insert(*idx, vec![*key]);
                }
            }
        }

        self.layers.iter().rev().enumerate().for_each(|(i, layer)| {
            let idx = self.layers.len() - i - 1;

            let input: &tensor::Tensor = &activated[idx];
            let output: &tensor::Tensor = &unactivated[idx];

            // Add the gradients flowing back through skip connections that
            // originate at this layer's input.
            if connect.contains_key(&idx) {
                for j in connect[&idx].iter() {
                    let mut idx = *j;
                    if j == &self.layers.len() {
                        idx = idx - 1;
                    }
                    let gradient = gradients[self.layers.len() - idx - 1].clone();
                    gradients.last_mut().unwrap().add_inplace(&gradient);
                }
            }

            let (gradient, wg, bg) = match layer {
                network::Layer::Dense(layer) => {
                    layer.backward(gradients.last().unwrap(), input, output)
                }
                network::Layer::Convolution(layer) => {
                    layer.backward(gradients.last().unwrap(), input, output)
                }
                network::Layer::Deconvolution(layer) => {
                    layer.backward(gradients.last().unwrap(), input, output)
                }
                _ => panic!("Unsupported layer type."),
            };

            gradients.push(gradient);
            weight_gradients.push(wg);
            bias_gradients.push(bg);
        });

        (
            gradients.last().unwrap().clone(),
            tensor::Tensor::nested(weight_gradients),
            Some(tensor::Tensor::nestedoptional(bias_gradients)),
        )
    }

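    /// Apply one optimizer step using the given gradients (ordered from the
    /// last layer to the first, as produced by `backward`), then re-apply the
    /// accumulation rule across each coupled group so the unrolled copies end
    /// up with identical weights.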
    pub fn update(
        &mut self,
        stepnr: i32,
        weight_gradients: &mut tensor::Tensor,
        bias_gradients: &mut tensor::Tensor,
    ) {
        let mut weight_gradients = weight_gradients.unnested();
        let mut bias_gradients = bias_gradients.unnestedoptional();

        self.layers
            .iter_mut()
            .rev()
            .enumerate()
            .for_each(|(i, layer)| match layer {
                network::Layer::Dense(layer) => {
                    self.optimizer.update(
                        i,
                        0,
                        false,
                        stepnr,
                        &mut layer.weights,
                        &mut weight_gradients[i],
                    );

                    if let Some(bias) = &mut layer.bias {
                        self.optimizer.update(
                            i,
                            0,
                            true,
                            stepnr,
                            bias,
                            bias_gradients[i].as_mut().unwrap(),
                        )
                    }
                }
                network::Layer::Convolution(layer) => {
                    for (f, (filter, gradient)) in layer
                        .kernels
                        .iter_mut()
                        .zip(weight_gradients[i].quadruple_to_vec_triple().iter_mut())
                        .enumerate()
                    {
                        self.optimizer.update(i, f, false, stepnr, filter, gradient);
                    }
                }
                network::Layer::Deconvolution(layer) => {
                    for (f, (filter, gradient)) in layer
                        .kernels
                        .iter_mut()
                        .zip(weight_gradients[i].quadruple_to_vec_triple().iter_mut())
                        .enumerate()
                    {
                        self.optimizer.update(i, f, false, stepnr, filter, gradient);
                    }
                }
                network::Layer::Maxpool(_) => {}
                network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
            });

        // Re-synchronise each coupled group: combine the updated weights with
        // the configured accumulation rule and write the result back to every
        // copy.
        for couple in self.coupled.iter() {
            let mut count: f32 = 0.0;
            let mut weights: Vec<tensor::Tensor> = Vec::new();
            let mut biases: Vec<tensor::Tensor> = Vec::new();

            for idx in couple.iter() {
                match &self.layers[*idx] {
                    network::Layer::Dense(layer) => {
                        weights.push(layer.weights.clone());
                        if let Some(bias) = &layer.bias {
                            biases.push(bias.clone());
                        }
                    }
                    network::Layer::Convolution(layer) => {
                        weights.push(tensor::Tensor::nested(layer.kernels.clone()));
                    }
                    network::Layer::Deconvolution(layer) => {
                        weights.push(tensor::Tensor::nested(layer.kernels.clone()));
                    }
                    _ => continue,
                }
                count += 1.0;
            }

            let mut weight: tensor::Tensor = weights.remove(0);
            let mut bias: Option<tensor::Tensor> = if biases.is_empty() {
                None
            } else {
                Some(biases.remove(0))
            };
            match self.accumulation {
                Accumulation::Add => {
                    for w in weights.iter() {
                        weight.add_inplace(w);
                    }
                    if let Some(bias) = &mut bias {
                        for b in biases.iter() {
                            bias.add_inplace(b);
                        }
                    }
                }
                Accumulation::Multiply => {
                    for w in weights.iter() {
                        weight.mul_inplace(w);
                    }
                    if let Some(bias) = &mut bias {
                        for b in biases.iter() {
                            bias.mul_inplace(b);
                        }
                    }
                }
                Accumulation::Subtract => {
                    for w in weights.iter() {
                        weight.sub_inplace(w);
                    }
                    if let Some(bias) = &mut bias {
                        for b in biases.iter() {
                            bias.sub_inplace(b);
                        }
                    }
                }
                Accumulation::Mean => {
                    for w in weights.iter() {
                        weight.add_inplace(w);
                    }
                    if let Some(bias) = &mut bias {
                        for b in biases.iter() {
                            bias.add_inplace(b);
                        }
                    }
                    weight.div_scalar_inplace(count);
                    if let Some(b) = &mut bias {
                        b.div_scalar_inplace(count);
                    }
                }
                Accumulation::Overwrite => {
                    unimplemented!("Overwrite accumulation is not implemented.")
                }
            }

            // Write the combined weights back to every layer in the group.
            for i in couple.iter() {
                match &mut self.layers[*i] {
                    network::Layer::Dense(layer) => {
                        layer.weights = weight.clone();
                        if let Some(b) = &mut layer.bias {
                            *b = bias.clone().unwrap();
                        }
                    }
                    network::Layer::Convolution(layer) => {
                        layer.kernels = weight.unnested();
                    }
                    network::Layer::Deconvolution(layer) => {
                        layer.kernels = weight.unnested();
                    }
                    _ => continue,
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{activation, assert_eq_data, assert_eq_shape, dense, network, tensor};

    #[test]
    fn test_feedback_create() {
        let layers = vec![
            network::Layer::Dense(dense::Dense::create(
                tensor::Shape::Single(2),
                tensor::Shape::Single(2),
                &activation::Activation::ReLU,
                false,
                None,
            )),
            network::Layer::Dense(dense::Dense::create(
                tensor::Shape::Single(2),
                tensor::Shape::Single(2),
                &activation::Activation::ReLU,
                false,
                None,
            )),
        ];
        let feedback = Feedback::create(layers.clone(), 2, true, false, Accumulation::Add);

        assert_eq!(feedback.inputs, tensor::Shape::Single(2));
        assert_eq!(feedback.outputs, tensor::Shape::Single(2));
        assert_eq!(feedback.layers.len(), 4);
        assert_eq!(feedback.coupled.len(), 2);
        assert_eq!(feedback.connect.len(), 1);
    }

    #[test]
    fn test_feedback_parameters() {
        let layers = vec![network::Layer::Dense(dense::Dense::create(
            tensor::Shape::Single(3),
            tensor::Shape::Single(3),
            &activation::Activation::ReLU,
            true,
            None,
        ))];
        let feedback = Feedback::create(layers.clone(), 1, false, false, Accumulation::Add);

        // 3x3 weights + 3 biases.
        assert_eq!(feedback.parameters(), 12);
    }

    #[test]
    fn test_feedback_training() {
        let layers = vec![network::Layer::Dense(dense::Dense::create(
            tensor::Shape::Single(3),
            tensor::Shape::Single(3),
            &activation::Activation::ReLU,
            true,
            None,
        ))];
        let mut feedback = Feedback::create(layers.clone(), 1, false, false, Accumulation::Add);
        feedback.training(true);

        for layer in feedback.layers.iter() {
            if let network::Layer::Dense(layer) = layer {
                assert!(layer.training);
            }
        }
    }

    #[test]
    fn test_feedback_forward() {
        let mut layer = dense::Dense::create(
            tensor::Shape::Single(3),
            tensor::Shape::Single(3),
            &activation::Activation::ReLU,
            true,
            None,
        );
        layer.weights = tensor::Tensor::double(vec![vec![1.0; 3]; 3]);
        layer.bias = Some(tensor::Tensor::single(vec![0.0; 3]));
        let layers = vec![network::Layer::Dense(layer)];
        let feedback = Feedback::create(layers.clone(), 1, false, false, Accumulation::Add);
        let input = tensor::Tensor::single(vec![-1.0, 2.0, 3.0]);

        let (unactivated, activated, maxpool, intermediate_unactivated, intermediate_activated) =
            feedback.forward(&input);

        assert_eq_shape!(unactivated.shape, tensor::Shape::Single(3));
        assert_eq_shape!(activated.shape, tensor::Shape::Single(3));
        assert_eq_shape!(maxpool.shape, tensor::Shape::Nested(1));
        assert_eq_shape!(
            intermediate_unactivated.shape,
            tensor::Tensor::nested(vec![tensor::Tensor::single(vec![1.0; 3])]).shape
        );
        assert_eq_shape!(
            intermediate_activated.shape,
            tensor::Tensor::nested(vec![
                tensor::Tensor::single(vec![1.0; 3]),
                tensor::Tensor::single(vec![1.0; 3]),
            ])
            .shape
        );

        // Each output is the sum of the inputs: -1.0 + 2.0 + 3.0 = 4.0.
        let expected_unactivated = tensor::Tensor::single(vec![4.0; 3]);
        let expected_activated = tensor::Tensor::single(vec![4.0; 3]);
        assert_eq_data!(unactivated.data, expected_unactivated.data);
        assert_eq_data!(activated.data, expected_activated.data);
    }

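    // A sketch of an additional check (not part of the original suite): with
    // two loops and `inskips`, the block input is added back into the second
    // loop, so for all-ones weights the first pass maps [1, 2, 3] to [6, 6, 6],
    // the skip turns that into [7, 8, 9], and the second pass yields 24 per
    // output.
    #[test]
    fn test_feedback_forward_inskips() {
        let mut layer = dense::Dense::create(
            tensor::Shape::Single(3),
            tensor::Shape::Single(3),
            &activation::Activation::ReLU,
            true,
            None,
        );
        layer.weights = tensor::Tensor::double(vec![vec![1.0; 3]; 3]);
        layer.bias = Some(tensor::Tensor::single(vec![0.0; 3]));
        let layers = vec![network::Layer::Dense(layer)];
        let feedback = Feedback::create(layers.clone(), 2, true, false, Accumulation::Add);
        let input = tensor::Tensor::single(vec![1.0, 2.0, 3.0]);

        let (unactivated, activated, _, _, _) = feedback.forward(&input);

        // First pass: 1 + 2 + 3 = 6. Second pass: (6 + 1) + (6 + 2) + (6 + 3) = 24.
        assert_eq_data!(unactivated.data, tensor::Tensor::single(vec![6.0; 3]).data);
        assert_eq_data!(activated.data, tensor::Tensor::single(vec![24.0; 3]).data);
    }
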
    #[test]
    fn test_feedback_backward() {
        let mut layer = dense::Dense::create(
            tensor::Shape::Single(3),
            tensor::Shape::Single(3),
            &activation::Activation::ReLU,
            true,
            None,
        );
        layer.weights = tensor::Tensor::double(vec![vec![1.0; 3]; 3]);
        layer.bias = Some(tensor::Tensor::single(vec![0.0; 3]));
        let layers = vec![network::Layer::Dense(layer)];
        let feedback = Feedback::create(layers.clone(), 1, false, false, Accumulation::Add);
        let input = tensor::Tensor::single(vec![1.0, 2.0, 3.0]);
        let (_, _, _, intermediate_unactivated, intermediate_activated) = feedback.forward(&input);
        let gradient = tensor::Tensor::single(vec![0.1, 0.2, 0.3]);

        let (input_gradient, weight_gradient, bias_gradient) = feedback.backward(
            &gradient,
            &vec![intermediate_unactivated, intermediate_activated],
        );

        assert_eq_shape!(input_gradient.shape, tensor::Shape::Single(3));
        assert_eq!(
            weight_gradient.shape,
            tensor::Tensor::nested(vec![tensor::Tensor::double(vec![vec![1.0; 3]; 2])]).shape
        );
        assert_eq!(
            bias_gradient.clone().unwrap().shape,
            tensor::Tensor::nested(vec![tensor::Tensor::single(vec![1.0; 3])]).shape
        );

        // Input gradient: W^T * g = [0.1 + 0.2 + 0.3; 3]; weight gradient: g * x^T.
        let expected_input_gradient = tensor::Tensor::single(vec![0.6, 0.6, 0.6]);
        let expected_weight_gradient = tensor::Tensor::nested(vec![tensor::Tensor::double(vec![
            vec![0.1 * 1.0, 0.1 * 2.0, 0.1 * 3.0],
            vec![0.2 * 1.0, 0.2 * 2.0, 0.2 * 3.0],
            vec![0.3 * 1.0, 0.3 * 2.0, 0.3 * 3.0],
        ])]);
        let expected_bias_gradient = tensor::Tensor::single(vec![0.1, 0.2, 0.3]);

        assert_eq_data!(input_gradient.data, expected_input_gradient.data);
        assert_eq_data!(
            weight_gradient.unnested()[0].data,
            expected_weight_gradient.unnested()[0].data
        );
        assert_eq_data!(
            bias_gradient.clone().unwrap().unnestedoptional()[0]
                .clone()
                .unwrap()
                .data,
            expected_bias_gradient.data
        );
    }

    #[test]
    fn test_feedback_update() {
        let layers = vec![network::Layer::Dense(dense::Dense::create(
            tensor::Shape::Single(3),
            tensor::Shape::Single(3),
            &activation::Activation::ReLU,
            true,
            None,
        ))];
        let mut weight_gradient = tensor::Tensor::nested(vec![
            tensor::Tensor::double(vec![
                vec![0.1, 0.2, 0.3],
                vec![0.4, 0.5, 0.6],
                vec![0.7, 0.8, 0.9],
            ]),
            tensor::Tensor::double(vec![
                vec![0.1, 0.2, 0.3],
                vec![0.7, 0.8, 0.9],
                vec![0.4, 0.5, 0.6],
            ]),
            tensor::Tensor::double(vec![
                vec![0.7, 0.8, 0.9],
                vec![0.1, 0.2, 0.3],
                vec![0.4, 0.5, 0.6],
            ]),
        ]);
        let mut bias_gradient = tensor::Tensor::nestedoptional(vec![
            Some(tensor::Tensor::single(vec![0.1, 0.2, 0.3])),
            Some(tensor::Tensor::single(vec![0.5, 0.7, 1.0])),
            Some(tensor::Tensor::single(vec![1.1, 1.2, 0.3])),
        ]);

        // `Accumulation::Overwrite` is omitted: it is unimplemented in `update`.
        for accumulation in vec![
            Accumulation::Add,
            Accumulation::Subtract,
            Accumulation::Multiply,
            Accumulation::Mean,
        ] {
            let mut feedback =
                Feedback::create(layers.clone(), 3, false, false, accumulation.clone());
            feedback.update(1, &mut weight_gradient, &mut bias_gradient);

            let (weight, bias) = match &feedback.layers[0] {
                network::Layer::Dense(layer) => (layer.weights.clone(), layer.bias.clone()),
                _ => panic!("Invalid layer type"),
            };

            // All coupled copies must end up with identical weights and biases.
            for i in 0..3 {
                match &feedback.layers[i] {
                    network::Layer::Dense(layer) => {
                        assert_eq_data!(layer.weights.data, weight.data);
                        if let Some(bias) = &bias {
                            assert_eq_data!(layer.bias.clone().unwrap().data, bias.data);
                        } else {
                            panic!("Should have bias!");
                        }
                    }
                    _ => panic!("Invalid layer type"),
                }
            }
        }
    }
}