nevermind_neu/layers/
abstract_layer.rs1use std::fmt;
2
3use crate::cpu_params::{CpuParams, ParamsBlob, TypeBuffer};
4use crate::util::{Array2D, Metrics, WithParams};
5
/// Errors a layer can report from its forward/backward passes.
///
/// The type is a plain fieldless enum, so it is cheap to copy and can be compared
/// directly (e.g. `assert_eq!(res.unwrap_err(), LayerError::NotImpl)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerError {
    /// A size check failed.
    /// NOTE(review): presumably an input/shape mismatch — confirm against callers.
    InvalidSize,
    /// Generic failure with no more specific variant.
    OtherError,
    /// The requested pass is not implemented for this layer; the default
    /// `AbstractLayer` pass implementations return this variant.
    NotImpl,
}
12
13impl fmt::Display for LayerError {
14 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
15 match self {
16 LayerError::InvalidSize => {
17 write!(f, "{}", "Invalid size")
18 },
19 LayerError::OtherError => {
20 write!(f, "{}", "OtherError")
21 }
22 _ => {
23 write!(f, "{}", "Other")
24 },
25 }
26 }
27}
28
29
30pub type LayerForwardResult = Result<ParamsBlob, LayerError>;
31pub type LayerBackwardResult = Result<ParamsBlob, LayerError>;
32pub type TrainableBufsIds<'a> = (&'a[i32], &'a[i32]);
33
34pub trait AbstractLayer: WithParams {
35 fn forward_input(&mut self, _input_data: Array2D) -> LayerForwardResult {
37 Err(LayerError::NotImpl)
38 }
39
40 fn forward(&mut self, _input: ParamsBlob) -> LayerForwardResult {
41 Err(LayerError::NotImpl)
42 }
43
44 fn backward(&mut self, _prev_input: ParamsBlob, _input: ParamsBlob) -> LayerBackwardResult {
46 Err(LayerError::NotImpl)
47 }
48
49 fn backward_output(
50 &mut self,
51 _prev_input: ParamsBlob,
52 _expected: Array2D,
53 ) -> LayerBackwardResult {
54 Err(LayerError::NotImpl)
55 }
56
57 fn layer_type(&self) -> &str;
58
59 fn size(&self) -> usize;
60
61 fn set_batch_size(&mut self, batch_size: usize) {
62 let mut lr = self.cpu_params().unwrap();
63 lr.fit_to_batch_size(batch_size);
64 }
65
66 fn metrics(&self) -> Option<&Metrics> {
67 None
68 }
69
70 fn serializable_bufs(&self) -> &[i32] {
71 return &[TypeBuffer::Weights as i32, TypeBuffer::Bias as i32];
72 }
73
74 fn trainable_bufs(&self) -> TrainableBufsIds {
75 (
76 &[TypeBuffer::Weights as i32, TypeBuffer::Bias as i32],
77 &[TypeBuffer::WeightsGrad as i32, TypeBuffer::BiasGrad as i32],
78 )
79 }
80
81 fn cpu_params(&self) -> Option<CpuParams>;
82 fn set_cpu_params(&mut self, lp: CpuParams);
83
84 fn set_input_shape(&mut self, sh: &[usize]);
85
86 fn copy_layer(&self) -> Box<dyn AbstractLayer>;
88
89 fn clone_layer(&self) -> Box<dyn AbstractLayer>;
91}