pub mod dense_layer;
pub mod empty_layer;

pub use dense_layer::DenseLayer;
pub use empty_layer::EmptyLayer;

use crate::backend::Backend;
use crate::optimizer::Optimizer;
use ndarray::Dimension;
use serde::{Deserialize, Serialize};

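/// A layer in a neural network, generic over the compute backend `B`.
///
/// `forward` maps an input tensor to an output tensor, and `backward` maps
/// the gradient with respect to that output back to a gradient with respect
/// to the input. `update` gives an [`Optimizer`] the chance to adjust any
/// parameters the layer holds; the default implementation is a no-op for
/// layers without parameters.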
pub trait Layer<B: Backend> {
    type Input: Dimension;
    type Output: Dimension;

    fn forward(&mut self, input: &B::Tensor<Self::Input>) -> B::Tensor<Self::Output>;
    fn backward(&mut self, grad_output: &B::Tensor<Self::Output>) -> B::Tensor<Self::Input>;
    fn update<O: Optimizer<B>>(&mut self, _optimizer: &mut O) {}
}

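/// Two layers applied back to back: the output of `layer1` feeds `layer2`,
/// so `layer1`'s output dimension must match `layer2`'s input dimension.
///
/// Longer chains are expressed by nesting `Sequential`s, which is what the
/// `Layers!` macro builds.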
#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "L1: Serialize, L2: Serialize",
    deserialize = "L1: Deserialize<'de>, L2: Deserialize<'de>"
))]
pub struct Sequential<L1, L2> {
    pub layer1: L1,
    pub layer2: L2,
}

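// `Sequential` is itself a `Layer`, so it can be nested arbitrarily. The
// bounds below enforce the chaining rule at compile time: the intermediate
// dimension `D2` is both `L1::Output` and `L2::Input`.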
impl<L1, L2, B, D1, D2, D3> Layer<B> for Sequential<L1, L2>
where
    B: Backend,
    L1: Layer<B, Input = D1, Output = D2>,
    L2: Layer<B, Input = D2, Output = D3>,
    D1: Dimension,
    D2: Dimension,
    D3: Dimension,
{
    type Input = D1;
    type Output = D3;

    fn forward(&mut self, input: &B::Tensor<D1>) -> B::Tensor<D3> {
        let out = self.layer1.forward(input);
        self.layer2.forward(&out)
    }

    fn backward(&mut self, grad_output: &B::Tensor<D3>) -> B::Tensor<D1> {
        // Gradients flow in reverse order: layer2 first, then layer1.
        let grad = self.layer2.backward(grad_output);
        self.layer1.backward(&grad)
    }

    fn update<O: Optimizer<B>>(&mut self, optimizer: &mut O) {
        self.layer1.update(optimizer);
        self.layer2.update(optimizer);
    }
}

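/// Chains two layers into a [`Sequential`]; this is the shorthand the
/// `Layers!` macro expands to.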
pub fn seq<L1, L2>(layer1: L1, layer2: L2) -> Sequential<L1, L2> {
    Sequential { layer1, layer2 }
}

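/// Builds a right-nested chain of [`Sequential`] layers from a comma-separated
/// list of layer expressions, so `Layers!(a, b, c)` expands to
/// `seq(a, seq(b, c))`.
///
/// A rough usage sketch (the `DenseLayer::new(inputs, outputs)` constructor
/// and the `batch` tensor shown here are assumptions for illustration, not
/// part of this module):
///
/// ```ignore
/// let mut model = Layers!(
///     DenseLayer::new(784, 128),
///     DenseLayer::new(128, 10)
/// );
/// let logits = model.forward(&batch);
/// ```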
#[macro_export]
macro_rules! Layers {
    ($layer:expr) => { $layer };
    ($layer1:expr, $layer2:expr) => { $crate::layer::seq($layer1, $layer2) };
    ($layer1:expr, $layer2:expr, $($rest:expr),+) => {
        $crate::layer::seq($layer1, $crate::Layers!($layer2, $($rest),+))
    };
}