nevermind_neu/layers/
abstract_layer.rs

1use std::fmt;
2
3use crate::cpu_params::{CpuParams, ParamsBlob, TypeBuffer};
4use crate::util::{Array2D, Metrics, WithParams};
5
/// Errors that a layer operation can report.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LayerError {
    /// Reported when a size is invalid (displayed as "Invalid size").
    InvalidSize,
    /// Any other, unspecified failure.
    OtherError,
    /// The requested operation is not implemented by the layer; this is what
    /// the default method bodies of `AbstractLayer` return.
    NotImpl,
}

impl fmt::Display for LayerError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive on purpose (no `_` arm): adding a variant must force an
        // explicit decision about its message instead of silently printing a
        // generic text.
        let message = match self {
            LayerError::InvalidSize => "Invalid size",
            LayerError::OtherError => "OtherError",
            // Previously fell into the catch-all and was printed as "Other",
            // which was indistinguishable in practice from `OtherError`.
            LayerError::NotImpl => "Not implemented",
        };
        f.write_str(message)
    }
}

// Lets `LayerError` be used with `Box<dyn std::error::Error>` and `?`.
impl std::error::Error for LayerError {}
28
29
/// Result of a forward pass: a `ParamsBlob` on success, a `LayerError` on failure.
pub type LayerForwardResult = Result<ParamsBlob, LayerError>;
/// Result of a backward pass: a `ParamsBlob` on success, a `LayerError` on failure.
pub type LayerBackwardResult = Result<ParamsBlob, LayerError>;
/// Pair of buffer-id slices as returned by `AbstractLayer::trainable_bufs`.
/// In the default implementation the first slice holds the `Weights`/`Bias`
/// ids and the second the `WeightsGrad`/`BiasGrad` ids.
pub type TrainableBufsIds<'a> = (&'a[i32], &'a[i32]);
33
34pub trait AbstractLayer: WithParams {
35    // for signature for input layers
36    fn forward_input(&mut self, _input_data: Array2D) -> LayerForwardResult {
37        Err(LayerError::NotImpl)
38    }
39
40    fn forward(&mut self, _input: ParamsBlob) -> LayerForwardResult {
41        Err(LayerError::NotImpl)
42    }
43
44    /// returns out_values and array of weights
45    fn backward(&mut self, _prev_input: ParamsBlob, _input: ParamsBlob) -> LayerBackwardResult {
46        Err(LayerError::NotImpl)
47    }
48
49    fn backward_output(
50        &mut self,
51        _prev_input: ParamsBlob,
52        _expected: Array2D,
53    ) -> LayerBackwardResult {
54        Err(LayerError::NotImpl)
55    }
56
57    fn layer_type(&self) -> &str;
58
59    fn size(&self) -> usize;
60
61    fn set_batch_size(&mut self, batch_size: usize) {
62        let mut lr = self.cpu_params().unwrap();
63        lr.fit_to_batch_size(batch_size);
64    }
65
66    fn metrics(&self) -> Option<&Metrics> {
67        None
68    }
69
70    fn serializable_bufs(&self) -> &[i32] {
71        return &[TypeBuffer::Weights as i32, TypeBuffer::Bias as i32];
72    }
73
74    fn trainable_bufs(&self) -> TrainableBufsIds {
75        (
76            &[TypeBuffer::Weights as i32, TypeBuffer::Bias as i32],
77            &[TypeBuffer::WeightsGrad as i32, TypeBuffer::BiasGrad as i32],
78        )
79    }
80
81    fn cpu_params(&self) -> Option<CpuParams>;
82    fn set_cpu_params(&mut self, lp: CpuParams);
83
84    fn set_input_shape(&mut self, sh: &[usize]);
85
86    // Do copy layer memory(ws, output, ...)
87    fn copy_layer(&self) -> Box<dyn AbstractLayer>;
88
89    // Do copy only Rc
90    fn clone_layer(&self) -> Box<dyn AbstractLayer>;
91}