// neurons/feedback.rs

1// Copyright (C) 2024 Hallvard Høyland Lavik
2
3use std::collections::HashMap;
4
5use crate::{activation, assert_eq_shape, network, optimizer, tensor};
6
/// How multiple contributions (skip connections, coupled weights) are combined
/// into a single value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Accumulation {
    /// Element-wise addition of every contribution.
    Add,
    /// Element-wise subtraction of every contribution.
    Subtract,
    /// Element-wise multiplication of every contribution.
    Multiply,
    /// Keep only the last contribution, discarding the rest.
    Overwrite,
    /// Element-wise mean over all contributions.
    Mean,
    // TODO: Expand?
}
16
17impl std::fmt::Display for Accumulation {
18    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
19        match self {
20            Accumulation::Add => write!(f, "additive"),
21            Accumulation::Subtract => write!(f, "subtractive"),
22            Accumulation::Multiply => write!(f, "multiplicative"),
23            Accumulation::Overwrite => write!(f, "overwrite"),
24            Accumulation::Mean => write!(f, "mean"),
25            #[allow(unreachable_patterns)]
26            _ => unimplemented!("Accumulation method not implemented."),
27        }
28    }
29}
30
/// A simplified layer definition used for defining feedback blocks.
///
/// # Dense
///
/// * `nodes` - The number of nodes in the layer.
/// * `activation` - The activation function of the layer.
/// * `bias` - Whether the layer should include a bias.
/// * `dropout` - The dropout rate of the layer.
///
/// # Convolution
///
/// * `filters` - The number of filters in the layer.
/// * `activation` - The activation function of the layer.
/// * `kernel` - The kernel size of the layer.
/// * `stride` - The stride of the layer.
/// * `padding` - The padding of the layer.
/// * `dilation` - The dilation of the layer.
/// * `dropout` - The dropout rate of the layer.
///
/// # Deconvolution
///
/// * `filters` - The number of filters in the layer.
/// * `activation` - The activation function of the layer.
/// * `kernel` - The kernel size of the layer.
/// * `stride` - The stride of the layer.
/// * `padding` - The padding of the layer.
/// * `dropout` - The dropout rate of the layer.
///
/// NOTE(review): the Deconvolution field meanings above are presumed to mirror
/// Convolution minus `dilation` — confirm against the consuming builder.
///
/// # Maxpool
///
/// * `kernel` - The pool size of the layer.
/// * `stride` - The stride of the layer.
pub enum Layer {
    Dense(usize, activation::Activation, bool, Option<f32>),
    Convolution(
        usize,
        activation::Activation,
        (usize, usize),
        (usize, usize),
        (usize, usize),
        (usize, usize),
        Option<f32>,
    ),
    Deconvolution(
        usize,
        activation::Activation,
        (usize, usize),
        (usize, usize),
        (usize, usize),
        Option<f32>,
    ),
    Maxpool((usize, usize), (usize, usize)),
}
75
/// A feedback block.
///
/// # Attributes
///
/// * `inputs` - The number of inputs to the block.
/// * `outputs` - The number of outputs from the block.
/// * `optimizer` - The optimizer used for training the block.
/// * `flatten` - Whether the block should flatten the output.
/// * `layers` - The layers of the block.
/// * `connect` - The (skip) connections between layers.
/// * `accumulation` - How skip connections and coupled weights are combined.
/// * `coupled` - The coupled layers of the block.
///
/// # Notes
///
/// * The `inputs` should match the `outputs`, to allow for feedback looping.
/// * TODO: Add support for differing input and output shapes, projecting differences internally.
#[derive(Clone)]
pub struct Feedback {
    pub(crate) inputs: tensor::Shape,
    pub(crate) outputs: tensor::Shape,
    pub(crate) optimizer: optimizer::Optimizer,
    pub(crate) flatten: bool,
    pub layers: Vec<network::Layer>,
    /// Skip connections stored as `{to: [from, ...]}` layer indices (see `create`).
    connect: HashMap<usize, Vec<usize>>,
    pub(crate) accumulation: Accumulation,
    /// Groups of layer indices that share weights across loops (see `create`).
    coupled: Vec<Vec<usize>>,
}
103
104impl std::fmt::Display for Feedback {
105    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
106        write!(f, "Feedback (\n")?;
107        write!(f, "\t\t\t{} -> {}\n", self.inputs, self.outputs)?;
108
109        // let optimizer: String = self
110        //     .optimizer
111        //     .to_string()
112        //     .lines()
113        //     .map(|line| format!("\t\t{}", line))
114        //     .collect::<Vec<String>>()
115        //     .join("\n");
116        // write!(f, "\t\t\toptimizer: (\n{}\n", optimizer)?;
117
118        write!(f, "\t\t\tlayers: (\n")?;
119        for (i, layer) in self.layers.iter().enumerate() {
120            match layer {
121                network::Layer::Dense(layer) => {
122                    write!(
123                        f,
124                        "\t\t\t\t{}: Dense{} ({} -> {})\n",
125                        i, layer.activation, layer.inputs, layer.outputs
126                    )?;
127                }
128                network::Layer::Convolution(layer) => {
129                    write!(
130                        f,
131                        "\t\t\t\t{}: Convolution{} ({} -> {})\n",
132                        i, layer.activation, layer.inputs, layer.outputs
133                    )?;
134                }
135                network::Layer::Deconvolution(layer) => {
136                    write!(
137                        f,
138                        "\t\t\t\t{}: Decovolution{} ({} -> {})\n",
139                        i, layer.activation, layer.inputs, layer.outputs
140                    )?;
141                }
142                network::Layer::Maxpool(layer) => {
143                    write!(
144                        f,
145                        "\t\t\t\t{}: Maxpool ({} -> {})\n",
146                        i, layer.inputs, layer.outputs
147                    )?;
148                }
149                network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
150            }
151        }
152        write!(f, "\t\t\t)\n")?;
153        if !self.coupled.is_empty() {
154            write!(f, "\t\t\tcoupled: (\n")?;
155            for coupled in self.coupled.iter() {
156                write!(f, "\t\t\t\t{:?}\n", coupled)?;
157            }
158            write!(f, "\t\t\t\taccumulation: {}\n", self.accumulation)?;
159            write!(f, "\t\t\t)\n")?;
160        }
161        if !self.connect.is_empty() {
162            write!(f, "\t\t\tconnections: (\n")?;
163            write!(f, "\t\t\t\taccumulation: {}\n", self.accumulation)?;
164
165            let mut entries: Vec<(&usize, &Vec<usize>)> = self.connect.iter().collect();
166            entries.sort_by_key(|&(to, _)| to);
167            for (to, from) in entries.iter() {
168                write!(f, "\t\t\t\t{:?}.input -> {}.input\n", from, to)?;
169            }
170            write!(f, "\t\t\t)\n")?;
171        }
172        write!(f, "\t\t\tflatten: {}\n", self.flatten)?;
173        write!(f, "\t\t)")?;
174        Ok(())
175    }
176}
177
178impl Feedback {
179    /// Create a new feedback block.
180    ///
181    /// # Arguments
182    ///
183    /// * `layers` - The layers of the block.
184    /// * `loops` - The number of loops the block should perform.
185    /// * `inskips` - Whether the block should include input-to-input skip connections.
186    /// * `outskips` - Whether the block should include output-to-output skip connections.
187    /// * `accumulation` - The accumulation method of the block.
188    pub fn create(
189        mut layers: Vec<network::Layer>,
190        loops: usize,
191        inskips: bool,
192        outskips: bool,
193        accumulation: Accumulation,
194    ) -> Self {
195        assert!(loops > 0, "Feedback block should loop at least once.");
196        let inputs = match layers.first().unwrap() {
197            network::Layer::Dense(dense) => dense.inputs.clone(),
198            network::Layer::Convolution(convolution) => convolution.inputs.clone(),
199            network::Layer::Deconvolution(deconvolution) => deconvolution.inputs.clone(),
200            network::Layer::Maxpool(maxpool) => maxpool.inputs.clone(),
201            network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
202        };
203        let outputs = match layers.last().unwrap() {
204            network::Layer::Dense(dense) => dense.outputs.clone(),
205            network::Layer::Convolution(convolution) => convolution.outputs.clone(),
206            network::Layer::Deconvolution(deconvolution) => deconvolution.outputs.clone(),
207            network::Layer::Maxpool(maxpool) => maxpool.outputs.clone(),
208            network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
209        };
210        assert_eq_shape!(inputs, outputs);
211
212        let length = layers.len();
213
214        // Extend the layers `loops` times.
215        let _layers = layers.clone();
216        for _ in 1..loops {
217            layers.extend(_layers.clone());
218        }
219
220        // Define the coupled layers.
221        let mut coupled: Vec<Vec<usize>> = Vec::new();
222        for layer in 0..length {
223            let mut coupling = Vec::new();
224            for i in 0..loops {
225                coupling.push(layer + i * length);
226            }
227            coupled.push(coupling);
228        }
229
230        // Define the skip connections.
231        let mut connect: HashMap<usize, Vec<usize>> = HashMap::new();
232        if inskips || outskips {
233            let mut outputs = Vec::new();
234            for i in 1..loops {
235                if inskips {
236                    // {to: from}
237                    connect.insert(i * length, vec![0]);
238                }
239                if outskips {
240                    outputs.push(i * length);
241                }
242            }
243            if outskips {
244                // {to: from}
245                connect.insert(loops * length, outputs);
246            }
247        }
248
249        Feedback {
250            inputs,
251            outputs,
252            optimizer: optimizer::SGD::create(0.1, None),
253            flatten: false,
254            layers,
255            connect,
256            accumulation,
257            coupled,
258        }
259    }
260
261    /// Set the `optimizer::Optimizer` function of the network.
262    ///
263    /// # Arguments
264    ///
265    /// * `optimizer` - The reference to the network optimizer, to copy the values from.
266    pub fn copy_optimizer(&mut self, mut optimizer: optimizer::Optimizer) {
267        let mut vectors: Vec<Vec<Vec<tensor::Tensor>>> = Vec::new();
268        for layer in self.layers.iter().rev() {
269            match layer {
270                network::Layer::Dense(layer) => {
271                    let (output, input) = match &layer.weights.shape {
272                        tensor::Shape::Double(output, input) => (*output, *input),
273                        _ => panic!("Expected Dense shape"),
274                    };
275                    vectors.push(vec![vec![
276                        tensor::Tensor::double(vec![vec![0.0; input]; output]),
277                        if layer.bias.is_some() {
278                            tensor::Tensor::single(vec![0.0; output])
279                        } else {
280                            tensor::Tensor::single(vec![])
281                        },
282                    ]]);
283                }
284                network::Layer::Convolution(layer) => {
285                    let (ch, kh, kw) = match layer.kernels[0].shape {
286                        tensor::Shape::Triple(ch, he, wi) => (ch, he, wi),
287                        _ => panic!("Expected Convolution shape"),
288                    };
289                    vectors.push(vec![
290                        vec![
291                            tensor::Tensor::triple(vec![vec![vec![0.0; kw]; kh]; ch]),
292                            // TODO: Add bias term here.
293                        ];
294                        layer.kernels.len()
295                    ]);
296                }
297                network::Layer::Deconvolution(layer) => {
298                    let (ch, kh, kw) = match layer.kernels[0].shape {
299                        tensor::Shape::Triple(ch, he, wi) => (ch, he, wi),
300                        _ => panic!("Expected Convolution shape"),
301                    };
302                    vectors.push(vec![
303                        vec![
304                            tensor::Tensor::triple(vec![vec![vec![0.0; kw]; kh]; ch]),
305                            // TODO: Add bias term here.
306                        ];
307                        layer.kernels.len()
308                    ]);
309                }
310                network::Layer::Maxpool(_) => {
311                    vectors.push(vec![vec![tensor::Tensor::single(vec![0.0; 0])]])
312                }
313                _ => unimplemented!("Feedback blocks not yet implemented."),
314            }
315        }
316
317        // Validate the optimizers' parameters.
318        // Override to default values if wrongly set.
319        optimizer.validate(vectors);
320
321        self.optimizer = optimizer;
322    }
323
324    /// Count the number of parameters.
325    /// Only counts the parameters of the first loop, as the rest are identical (coupled).
326    pub fn parameters(&self) -> usize {
327        let mut parameters = 0;
328        for idx in 0..self.coupled.len() {
329            parameters += match &self.layers[idx] {
330                network::Layer::Dense(dense) => dense.parameters(),
331                network::Layer::Convolution(convolution) => convolution.parameters(),
332                network::Layer::Deconvolution(deconvolution) => deconvolution.parameters(),
333                network::Layer::Maxpool(_) => 0,
334                network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
335            };
336        }
337        parameters
338    }
339
340    pub fn training(&mut self, train: bool) {
341        self.layers.iter_mut().for_each(|layer| match layer {
342            network::Layer::Dense(layer) => layer.training = train,
343            network::Layer::Convolution(layer) => layer.training = train,
344            network::Layer::Deconvolution(layer) => layer.training = train,
345            network::Layer::Maxpool(_) => {}
346            network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
347        });
348    }
349
350    /// Compute the forward pass of the feedback block for the given input, including all
351    /// intermediate pre- and post-activation values.
352    ///
353    /// # Arguments
354    ///
355    /// * `input` - The input data (x).
356    ///
357    /// # Returns
358    ///
359    /// * Unactivated tensor to be used for neighbouring layers when backpropagating.
360    /// * Activated tensor to be used for neighbouring layers when backpropagating.
361    /// * Maxpool tensor to be used for neighbouring layers when backpropagating.
362    /// * Intermediate unactivated tensors (nested).
363    /// * Intermediate activated tensors (nested).
364    pub fn forward(
365        &self,
366        input: &tensor::Tensor,
367    ) -> (
368        tensor::Tensor,
369        tensor::Tensor,
370        tensor::Tensor,
371        tensor::Tensor,
372        tensor::Tensor,
373    ) {
374        let mut unactivated = Vec::with_capacity(self.layers.len());
375        let mut activated = Vec::with_capacity(self.layers.len() + 1);
376        let mut maxpools = Vec::with_capacity(self.layers.len());
377
378        activated.push(input.clone());
379
380        for (i, layer) in self.layers.iter().enumerate() {
381            let mut x = activated.last().unwrap().clone();
382
383            // Check if the layer should account for a skip connection.
384            if self.connect.contains_key(&i) {
385                match self.accumulation {
386                    Accumulation::Add => {
387                        for idx in self.connect.get(&i).unwrap() {
388                            x.add_inplace(&activated[*idx]);
389                        }
390                    }
391                    Accumulation::Subtract => {
392                        for idx in self.connect.get(&i).unwrap() {
393                            x.sub_inplace(&activated[*idx]);
394                        }
395                    }
396                    Accumulation::Multiply => {
397                        for idx in self.connect.get(&i).unwrap() {
398                            x.mul_inplace(&activated[*idx]);
399                        }
400                    }
401                    Accumulation::Overwrite => {
402                        x = activated[*self.connect.get(&i).unwrap().last().unwrap()].clone();
403                    }
404                    Accumulation::Mean => {
405                        let mut _x: Vec<&tensor::Tensor> = Vec::new();
406                        for idx in self.connect.get(&i).unwrap() {
407                            _x.push(&activated[*idx]);
408                        }
409                        x.mean_inplace(&_x);
410                    }
411                    #[allow(unreachable_patterns)]
412                    _ => unimplemented!("Accumulation method not implemented."),
413                }
414            }
415
416            let (pre, post, max) = match layer {
417                network::Layer::Dense(layer) => {
418                    assert_eq_shape!(layer.inputs, x.shape);
419                    let (pre, post) = layer.forward(&x);
420                    (pre, post, None)
421                }
422                network::Layer::Convolution(layer) => {
423                    assert_eq_shape!(layer.inputs, x.shape);
424                    let (pre, post) = layer.forward(&x);
425                    (pre, post, None)
426                }
427                network::Layer::Deconvolution(layer) => {
428                    assert_eq_shape!(layer.inputs, x.shape);
429                    let (pre, post) = layer.forward(&x);
430                    (pre, post, None)
431                }
432                network::Layer::Maxpool(layer) => {
433                    assert_eq_shape!(layer.inputs, x.shape);
434                    let (pre, post, max) = layer.forward(&x);
435                    (pre, post, Some(max))
436                }
437                network::Layer::Feedback(_) => panic!("Nested feedback blocks are not supported."),
438            };
439
440            unactivated.push(pre);
441            activated.push(post);
442            maxpools.push(max);
443        }
444
445        let mut last = activated.pop().unwrap();
446
447        // Check if the last layer should account for a skip connection.
448        if self.connect.contains_key(&self.layers.len()) {
449            let i = self.layers.len();
450            match self.accumulation {
451                Accumulation::Add => {
452                    for idx in self.connect.get(&i).unwrap() {
453                        last.add_inplace(&activated[*idx]);
454                    }
455                }
456                Accumulation::Subtract => {
457                    for idx in self.connect.get(&i).unwrap() {
458                        last.sub_inplace(&activated[*idx]);
459                    }
460                }
461                Accumulation::Multiply => {
462                    for idx in self.connect.get(&i).unwrap() {
463                        last.mul_inplace(&activated[*idx]);
464                    }
465                }
466                Accumulation::Overwrite => {
467                    last = activated[*self.connect.get(&i).unwrap().last().unwrap()].clone();
468                }
469                Accumulation::Mean => {
470                    let mut _x: Vec<&tensor::Tensor> = Vec::new();
471                    for idx in self.connect.get(&i).unwrap() {
472                        _x.push(&activated[*idx]);
473                    }
474                    last.mean_inplace(&_x);
475                }
476                #[allow(unreachable_patterns)]
477                _ => unimplemented!("Accumulation method not implemented."),
478            }
479        }
480
481        // Flattening the last output if specified.
482        if self.flatten {
483            activated.push(last.flatten());
484        } else {
485            activated.push(last);
486        }
487
488        (
489            unactivated[0].clone(),
490            activated[activated.len() - 1].clone(),
491            tensor::Tensor::nestedoptional(maxpools),
492            tensor::Tensor::nested(unactivated),
493            tensor::Tensor::nested(activated),
494        )
495    }
496
    /// Applies the backward pass of the layer to the gradient vector.
    ///
    /// # Arguments
    ///
    /// * `gradient` - The gradient tensor::Tensor to the layer.
    /// * `inbetween` - The intermediate tensors of the forward pass.
    ///   `inbetween[0]` holds the nested unactivated tensors and `inbetween[1]`
    ///   the nested activated tensors, as produced by `forward`.
    ///
    /// # Returns
    ///
    /// The input-, weight- and bias gradient of the layer.
    /// Weight/bias gradients are nested in reverse layer order (last layer first),
    /// matching the iteration below — `update` relies on this ordering.
    pub fn backward(
        &self,
        gradient: &tensor::Tensor,
        inbetween: &Vec<tensor::Tensor>,
    ) -> (tensor::Tensor, tensor::Tensor, Option<tensor::Tensor>) {
        // We need to un-nest the input and output tensors (see `forward`).
        let unactivated = inbetween[0].unnested();
        let activated = inbetween[1].unnested();

        // `gradients[0]` is the incoming (output-side) gradient; each iteration
        // pushes the gradient flowing into the next-earlier layer.
        let mut gradients: Vec<tensor::Tensor> = vec![gradient.clone()];
        let mut weight_gradients: Vec<tensor::Tensor> = Vec::new();
        let mut bias_gradients: Vec<Option<tensor::Tensor>> = Vec::new();

        // Invert the skip-connection map for the backward direction.
        let mut connect: HashMap<usize, Vec<usize>> = HashMap::new();
        for (key, value) in self.connect.iter() {
            for idx in value.iter() {
                // {to: from} -> {from: [to1, ...]}
                if connect.contains_key(idx) {
                    connect.get_mut(idx).unwrap().push(*key);
                } else {
                    connect.insert(*idx, vec![*key]);
                }
            }
        }

        self.layers.iter().rev().enumerate().for_each(|(i, layer)| {
            // `idx` is the layer's forward-order index.
            let idx = self.layers.len() - i - 1;

            // `activated[idx]` is layer idx's input; `unactivated[idx]` its
            // pre-activation output (see `forward`).
            let input: &tensor::Tensor = &activated[idx];
            let output: &tensor::Tensor = &unactivated[idx];

            // Check for skip connections.
            // Add the gradient of the skip connection to the current gradient.
            if connect.contains_key(&idx) {
                for j in connect[&idx].iter() {
                    let mut idx = *j;
                    if j == &self.layers.len() {
                        // If the skip connection is the last layer (i.e., output);
                        // * Account for this by using the output gradient (i.e., gradients[0]).
                        // * Equivalent to `layers.len() - (layers.len() - 1) - 1 = 0` (below).
                        idx = idx - 1;
                    }
                    // Map forward index -> position in the reverse-ordered `gradients`.
                    let gradient = gradients[self.layers.len() - idx - 1].clone();
                    gradients.last_mut().unwrap().add_inplace(&gradient);
                }
                // TODO: Handle accumulation methods.
                // NOTE(review): gradients are always summed here, even when the
                // forward accumulation was subtract/multiply/overwrite/mean —
                // presumably an approximation; confirm intended behaviour.
            }

            let (gradient, wg, bg) = match layer {
                network::Layer::Dense(layer) => {
                    layer.backward(&gradients.last().unwrap(), input, output)
                }
                network::Layer::Convolution(layer) => {
                    layer.backward(&gradients.last().unwrap(), input, output)
                }
                network::Layer::Deconvolution(layer) => {
                    layer.backward(&gradients.last().unwrap(), input, output)
                }
                _ => panic!("Unsupported layer type."),
            };

            gradients.push(gradient);
            weight_gradients.push(wg);
            bias_gradients.push(bg);
        });

        return (
            gradients.last().unwrap().clone(),
            tensor::Tensor::nested(weight_gradients),
            Some(tensor::Tensor::nestedoptional(bias_gradients)),
        );
    }
579
    /// Apply the optimizer to every layer's weights and biases, then re-sync
    /// the coupled (weight-shared) layers using the block's accumulation method.
    ///
    /// # Arguments
    ///
    /// * `stepnr` - The current optimization step number.
    /// * `weight_gradients` - Nested weight gradients in reverse layer order (see `backward`).
    /// * `bias_gradients` - Nested optional bias gradients in reverse layer order (see `backward`).
    pub fn update(
        &mut self,
        stepnr: i32,
        weight_gradients: &mut tensor::Tensor,
        bias_gradients: &mut tensor::Tensor,
    ) {
        let mut weight_gradients = weight_gradients.unnested();
        let mut bias_gradients = bias_gradients.unnestedoptional();

        // Update the weights and biases of the layers.
        // Layers are visited in reverse so index `i` lines up with the
        // reverse-ordered gradient vectors produced by `backward`.
        self.layers
            .iter_mut()
            .rev()
            .enumerate()
            .for_each(|(i, layer)| match layer {
                network::Layer::Dense(layer) => {
                    self.optimizer.update(
                        i,
                        0,
                        false,
                        stepnr,
                        &mut layer.weights,
                        &mut weight_gradients[i],
                    );

                    if let Some(bias) = &mut layer.bias {
                        self.optimizer.update(
                            i,
                            0,
                            true,
                            stepnr,
                            bias,
                            &mut bias_gradients[i].as_mut().unwrap(),
                        )
                    }
                }
                network::Layer::Convolution(layer) => {
                    // Each filter is updated individually; `f` distinguishes them
                    // within the optimizer's per-layer state.
                    for (f, (filter, gradient)) in layer
                        .kernels
                        .iter_mut()
                        .zip(weight_gradients[i].quadruple_to_vec_triple().iter_mut())
                        .enumerate()
                    {
                        self.optimizer.update(i, f, false, stepnr, filter, gradient);
                        // TODO: Add bias term here.
                    }
                }
                network::Layer::Deconvolution(layer) => {
                    for (f, (filter, gradient)) in layer
                        .kernels
                        .iter_mut()
                        .zip(weight_gradients[i].quadruple_to_vec_triple().iter_mut())
                        .enumerate()
                    {
                        self.optimizer.update(i, f, false, stepnr, filter, gradient);
                        // TODO: Add bias term here.
                    }
                }
                network::Layer::Maxpool(_) => {}
                network::Layer::Feedback(_) => panic!("Feedback layers are not supported."),
            });

        // Couple respective layers.
        // Iterates through `self.coupled` and updates the weights and biases to match.
        for couple in self.coupled.iter() {
            let mut count: f32 = 0.0;
            let mut weights: Vec<tensor::Tensor> = Vec::new();
            let mut biases: Vec<tensor::Tensor> = Vec::new();

            // Add the weights and biases of the coupled layers.
            for idx in couple.iter() {
                match &self.layers[*idx] {
                    network::Layer::Dense(layer) => {
                        weights.push(layer.weights.clone());
                        if let Some(bias) = &layer.bias {
                            biases.push(bias.clone());
                        }
                    }
                    network::Layer::Convolution(layer) => {
                        weights.push(tensor::Tensor::nested(layer.kernels.clone()));
                    }
                    network::Layer::Deconvolution(layer) => {
                        weights.push(tensor::Tensor::nested(layer.kernels.clone()));
                    }
                    _ => continue,
                }
                // `count` is only used for the `Mean` accumulation below.
                count += 1.0;
            }

            // Fold every coupled copy into the first one.
            let mut weight: tensor::Tensor = weights.remove(0);
            let mut bias: Option<tensor::Tensor> = if biases.is_empty() {
                None
            } else {
                Some(biases.remove(0))
            };
            match self.accumulation {
                Accumulation::Add => {
                    for w in weights.iter() {
                        weight.add_inplace(w);
                    }
                    if let Some(bias) = &mut bias {
                        for b in biases.iter() {
                            bias.add_inplace(b);
                        }
                    }
                }
                Accumulation::Multiply => {
                    for w in weights.iter() {
                        weight.mul_inplace(w);
                    }
                    if let Some(bias) = &mut bias {
                        for b in biases.iter() {
                            bias.mul_inplace(b);
                        }
                    }
                }
                Accumulation::Subtract => {
                    for w in weights.iter() {
                        weight.sub_inplace(w);
                    }
                    if let Some(bias) = &mut bias {
                        for b in biases.iter() {
                            bias.sub_inplace(b);
                        }
                    }
                }
                Accumulation::Mean => {
                    // Sum, then divide by the number of contributing layers.
                    for w in weights.iter() {
                        weight.add_inplace(w);
                    }
                    if let Some(bias) = &mut bias {
                        for b in biases.iter() {
                            bias.add_inplace(b);
                        }
                    }
                    weight.div_scalar_inplace(count);
                    if let Some(b) = &mut bias {
                        b.div_scalar_inplace(count);
                    }
                }
                Accumulation::Overwrite => {
                    // Do nothing?
                    unimplemented!("Overwrite accumulation is not implemented.")
                }
            }

            // Update the weights and biases of the coupled layers.
            // Every member of the couple receives the same accumulated values.
            for i in couple.iter() {
                match &mut self.layers[*i] {
                    network::Layer::Dense(layer) => {
                        layer.weights = weight.clone();
                        if let Some(b) = &mut layer.bias {
                            *b = bias.clone().unwrap();
                        }
                    }
                    network::Layer::Convolution(layer) => {
                        layer.kernels = weight.unnested();
                    }
                    network::Layer::Deconvolution(layer) => {
                        layer.kernels = weight.unnested();
                    }
                    _ => continue,
                }
            }
        }
    }
746}
747
748#[cfg(test)]
749mod tests {
750    use super::*;
751    use crate::{activation, assert_eq_data, assert_eq_shape, dense, network, tensor};
752
    // Two 2->2 dense layers, looped twice with input skip connections:
    // expects 4 layers total, 2 coupling groups, and one inskip entry {2: [0]}.
    #[test]
    fn test_feedback_create() {
        let layers = vec![
            network::Layer::Dense(dense::Dense::create(
                tensor::Shape::Single(2),
                tensor::Shape::Single(2),
                &activation::Activation::ReLU,
                false,
                None,
            )),
            network::Layer::Dense(dense::Dense::create(
                tensor::Shape::Single(2),
                tensor::Shape::Single(2),
                &activation::Activation::ReLU,
                false,
                None,
            )),
        ];
        let feedback = Feedback::create(layers.clone(), 2, true, false, Accumulation::Add);

        assert_eq!(feedback.inputs, tensor::Shape::Single(2));
        assert_eq!(feedback.outputs, tensor::Shape::Single(2));
        assert_eq!(feedback.layers.len(), 4); // 2 loops of 2 layers each
        assert_eq!(feedback.coupled.len(), 2);
        assert_eq!(feedback.connect.len(), 1);
    }
779
780    // #[test]
781    // fn test_feedback_copy_optimizer() {
782    //     let layers = vec![network::Layer::Dense(dense::Dense::create(
783    //         tensor::Shape::Single(2),
784    //         tensor::Shape::Single(2),
785    //         &activation::Activation::ReLU,
786    //         false,
787    //         None,
788    //     ))];
789    //     let mut feedback = Feedback::create(layers.clone(), 1, false, false, Accumulation::Add);
790    //     let optimizer = optimizer::SGD::create(0.1, None);
791    //     feedback.copy_optimizer(optimizer.clone());
792
793    //     assert_eq!(feedback.optimizer, optimizer);
794    // }
795
796    #[test]
797    fn test_feedback_parameters() {
798        let layers = vec![network::Layer::Dense(dense::Dense::create(
799            tensor::Shape::Single(3),
800            tensor::Shape::Single(3),
801            &activation::Activation::ReLU,
802            true,
803            None,
804        ))];
805        let feedback = Feedback::create(layers.clone(), 1, false, false, Accumulation::Add);
806
807        assert_eq!(feedback.parameters(), 12); // 9 weights + 3 biases
808    }
809
810    #[test]
811    fn test_feedback_training() {
812        let layers = vec![network::Layer::Dense(dense::Dense::create(
813            tensor::Shape::Single(3),
814            tensor::Shape::Single(3),
815            &activation::Activation::ReLU,
816            true,
817            None,
818        ))];
819        let mut feedback = Feedback::create(layers.clone(), 1, false, false, Accumulation::Add);
820        feedback.training(true);
821
822        for layer in feedback.layers.iter() {
823            if let network::Layer::Dense(layer) = layer {
824                assert!(layer.training);
825            }
826        }
827    }
828
829    #[test]
830    fn test_feedback_forward() {
831        let mut layer = dense::Dense::create(
832            tensor::Shape::Single(3),
833            tensor::Shape::Single(3),
834            &activation::Activation::ReLU,
835            true,
836            None,
837        );
838        layer.weights = tensor::Tensor::double(vec![vec![1.0; 3]; 3]);
839        layer.bias = Some(tensor::Tensor::single(vec![0.0; 3]));
840        let layers = vec![network::Layer::Dense(layer)];
841        let feedback = Feedback::create(layers.clone(), 1, false, false, Accumulation::Add);
842        let input = tensor::Tensor::single(vec![-1.0, 2.0, 3.0]);
843
844        let (unactivated, activated, maxpool, intermediate_unactivated, intermediate_activated) =
845            feedback.forward(&input);
846
847        assert_eq_shape!(unactivated.shape, tensor::Shape::Single(3));
848        assert_eq_shape!(activated.shape, tensor::Shape::Single(3));
849        assert_eq_shape!(maxpool.shape, tensor::Shape::Nested(1));
850        assert_eq_shape!(
851            intermediate_unactivated.shape,
852            tensor::Tensor::nested(vec![tensor::Tensor::single(vec![1.0; 3]),]).shape
853        );
854        assert_eq_shape!(
855            intermediate_activated.shape,
856            tensor::Tensor::nested(vec![
857                tensor::Tensor::single(vec![1.0; 3]),
858                tensor::Tensor::single(vec![1.0; 3]),
859            ])
860            .shape
861        );
862
863        // Check actual values
864        let expected_unactivated = tensor::Tensor::single(vec![4.0; 3]);
865        let expected_activated = tensor::Tensor::single(vec![4.0; 3]);
866        assert_eq_data!(unactivated.data, expected_unactivated.data);
867        assert_eq_data!(activated.data, expected_activated.data);
868    }
869
870    #[test]
871    fn test_feedback_backward() {
872        let mut layer = dense::Dense::create(
873            tensor::Shape::Single(3),
874            tensor::Shape::Single(3),
875            &activation::Activation::ReLU,
876            true,
877            None,
878        );
879        layer.weights = tensor::Tensor::double(vec![vec![1.0; 3]; 3]);
880        layer.bias = Some(tensor::Tensor::single(vec![0.0; 3]));
881        let layers = vec![network::Layer::Dense(layer)];
882        let feedback = Feedback::create(layers.clone(), 1, false, false, Accumulation::Add);
883        let input = tensor::Tensor::single(vec![1.0, 2.0, 3.0]);
884        let (_, _, _, intermediate_unactivated, intermediate_activated) = feedback.forward(&input);
885        let gradient = tensor::Tensor::single(vec![0.1, 0.2, 0.3]);
886
887        let (input_gradient, weight_gradient, bias_gradient) = feedback.backward(
888            &gradient,
889            &vec![intermediate_unactivated, intermediate_activated],
890        );
891
892        assert_eq_shape!(input_gradient.shape, tensor::Shape::Single(3));
893        assert_eq!(
894            weight_gradient.shape,
895            tensor::Tensor::nested(vec![tensor::Tensor::double(vec![vec![1.0; 3]; 2]),]).shape
896        );
897        assert_eq!(
898            bias_gradient.clone().unwrap().shape,
899            tensor::Tensor::nested(vec![tensor::Tensor::single(vec![1.0; 3]),]).shape
900        );
901
902        // Check actual values
903        let expected_input_gradient = tensor::Tensor::single(vec![0.6, 0.6, 0.6]);
904        let expected_weight_gradient = tensor::Tensor::nested(vec![tensor::Tensor::double(vec![
905            vec![0.1 * 1.0, 0.1 * 2.0, 0.1 * 3.0],
906            vec![0.2 * 1.0, 0.2 * 2.0, 0.2 * 3.0],
907            vec![0.3 * 1.0, 0.3 * 2.0, 0.3 * 3.0],
908        ])]);
909        let expected_bias_gradient = tensor::Tensor::single(vec![0.1, 0.2, 0.3]);
910
911        assert_eq_data!(input_gradient.data, expected_input_gradient.data);
912        assert_eq_data!(
913            weight_gradient.unnested()[0].data,
914            expected_weight_gradient.unnested()[0].data
915        );
916        assert_eq_data!(
917            bias_gradient.clone().unwrap().unnestedoptional()[0]
918                .clone()
919                .unwrap()
920                .data,
921            expected_bias_gradient.data
922        );
923    }
924
925    #[test]
926    fn test_feedback_update() {
927        let layers = vec![network::Layer::Dense(dense::Dense::create(
928            tensor::Shape::Single(3),
929            tensor::Shape::Single(3),
930            &activation::Activation::ReLU,
931            true,
932            None,
933        ))];
934        let mut weight_gradient = tensor::Tensor::nested(vec![
935            tensor::Tensor::double(vec![
936                vec![0.1, 0.2, 0.3],
937                vec![0.4, 0.5, 0.6],
938                vec![0.7, 0.8, 0.9],
939            ]),
940            tensor::Tensor::double(vec![
941                vec![0.1, 0.2, 0.3],
942                vec![0.7, 0.8, 0.9],
943                vec![0.4, 0.5, 0.6],
944            ]),
945            tensor::Tensor::double(vec![
946                vec![0.7, 0.8, 0.9],
947                vec![0.1, 0.2, 0.3],
948                vec![0.4, 0.5, 0.6],
949            ]),
950        ]);
951        let mut bias_gradient = tensor::Tensor::nestedoptional(vec![
952            Some(tensor::Tensor::single(vec![0.1, 0.2, 0.3])),
953            Some(tensor::Tensor::single(vec![0.5, 0.7, 1.0])),
954            Some(tensor::Tensor::single(vec![1.1, 1.2, 0.3])),
955        ]);
956
957        for accumulation in vec![
958            Accumulation::Add,
959            Accumulation::Subtract,
960            Accumulation::Multiply,
961            // Accumulation::Overwrite,
962            Accumulation::Mean,
963        ] {
964            let mut feedback =
965                Feedback::create(layers.clone(), 3, false, false, accumulation.clone());
966            feedback.update(1, &mut weight_gradient, &mut bias_gradient);
967
968            let (weight, bias) = match &feedback.layers[0] {
969                network::Layer::Dense(layer) => (layer.weights.clone(), layer.bias.clone()),
970                _ => panic!("Invalid layer type"),
971            };
972
973            // Check if weights and biases have been updated
974            for i in 0..3 {
975                match &feedback.layers[i] {
976                    network::Layer::Dense(layer) => {
977                        assert_eq_data!(layer.weights.data, weight.data);
978                        if let Some(bias) = &bias {
979                            assert_eq_data!(layer.bias.clone().unwrap().data, bias.data);
980                        } else {
981                            panic!("Should have bias!");
982                        }
983                    }
984                    _ => panic!("Invalid layer type"),
985                }
986            }
987        }
988    }
989}