Skip to main content

meuron/layer/
mod.rs

1pub mod dense_layer;
2pub mod empty_layer;
3
4pub use dense_layer::DenseLayer;
5pub use empty_layer::EmptyLayer;
6
7use crate::backend::Backend;
8use crate::optimizer::Optimizer;
9use ndarray::Dimension;
10use serde::{Deserialize, Serialize};
11
12pub trait Layer<B: Backend> {
13    type Input: Dimension;
14    type Output: Dimension;
15
16    fn forward(&mut self, input: &B::Tensor<Self::Input>) -> B::Tensor<Self::Output>;
17    fn backward(&mut self, grad_output: &B::Tensor<Self::Output>) -> B::Tensor<Self::Input>;
18    fn update<O: Optimizer<B>>(&mut self, _optimizer: &mut O) {}
19}
20
21#[derive(Serialize, Deserialize)]
22#[serde(bound(
23    serialize = "L1: Serialize, L2: Serialize",
24    deserialize = "L1: Deserialize<'de>, L2: Deserialize<'de>"
25))]
26pub struct Sequential<L1, L2> {
27    pub layer1: L1,
28    pub layer2: L2,
29}
30
31impl<L1, L2, B, D1, D2, D3> Layer<B> for Sequential<L1, L2>
32where
33    B: Backend,
34    L1: Layer<B, Input = D1, Output = D2>,
35    L2: Layer<B, Input = D2, Output = D3>,
36    D1: Dimension,
37    D2: Dimension,
38    D3: Dimension,
39{
40    type Input = D1;
41    type Output = D3;
42
43    fn forward(&mut self, input: &B::Tensor<D1>) -> B::Tensor<D3> {
44        let out = self.layer1.forward(input);
45        self.layer2.forward(&out)
46    }
47
48    fn backward(&mut self, grad_output: &B::Tensor<D3>) -> B::Tensor<D1> {
49        let grad = self.layer2.backward(grad_output);
50        self.layer1.backward(&grad)
51    }
52
53    fn update<O: Optimizer<B>>(&mut self, optimizer: &mut O) {
54        self.layer1.update(optimizer);
55        self.layer2.update(optimizer);
56    }
57}
58
59pub fn seq<L1, L2>(layer1: L1, layer2: L2) -> Sequential<L1, L2> {
60    Sequential { layer1, layer2 }
61}
62
/// Chains one or more layers into right-nested `Sequential`s via `layer::seq`.
///
/// `Layers!(a)` is just `a`, `Layers!(a, b)` is `seq(a, b)`, and
/// `Layers!(a, b, c)` is `seq(a, seq(b, c))`. A trailing comma is accepted.
#[macro_export]
macro_rules! Layers {
    // A single layer (optionally followed by a trailing comma) is returned as-is.
    ($layer:expr $(,)?) => { $layer };
    // Two or more: pair the head with the recursively chained tail, so nesting is to
    // the right exactly as in the original three-arm form.
    ($first:expr, $($rest:expr),+ $(,)?) => {
        $crate::layer::seq($first, $crate::Layers!($($rest),+))
    };
}