//! Crate root for `meuron`: a small neural-network library generic over a
//! compute backend. Re-exports the common building blocks (activations,
//! costs, initializers, layers, optimizers, training utilities) and defines
//! the top-level [`NeuralNetwork`] type and the `NetworkType!` helper macro.
1pub mod activation;
2pub mod backend;
3pub mod cost;
4pub mod initializer;
5pub mod layer;
6pub mod metric;
7pub mod optimizer;
8pub mod serialization;
9pub mod train;
10
11pub use activation::{ReLU, Sigmoid, Softmax, Tanh};
12pub use backend::DefaultBackend;
13pub use cost::{BinaryCrossEntropy, CrossEntropy, MSE};
14pub use initializer::{Constant, HeNormal, Initializer, XavierUniform, Zeros};
15pub use layer::DenseLayer;
16pub use metric::classification::accuracy;
17pub use optimizer::{SGD, SGDMomentum};
18pub use train::{PrintCallback, TrainCallback, TrainOptions};
19
20use crate::backend::Backend;
21use crate::layer::Layer;
22use ndarray::{Array, Dimension};
23use std::marker::PhantomData;
24
/// A feed-forward neural network: a layer stack `L` paired with a cost
/// function `C`, generic over the compute backend `B`
/// (defaulting to [`DefaultBackend`]).
///
/// `L` is typically a single layer or a nested `Sequential<…>` chain —
/// see the `NetworkType!` macro for spelling such types.
pub struct NeuralNetwork<L, C, B: Backend = DefaultBackend>
where
    L: Layer<B>,
{
    /// The layer (or layer chain) driving the forward/backward passes.
    pub layers: L,
    /// Cost function. Not consumed by any method in this file — presumably
    /// used by the `train` module; verify against `train`'s implementation.
    pub cost: C,
    /// Zero-sized marker tying the network to backend `B` without storing one.
    pub(crate) _backend: PhantomData<B>,
}
33
34impl<L, C, B> NeuralNetwork<L, C, B>
35where
36    B: Backend,
37    L: Layer<B>,
38{
39    pub fn new(layers: L, cost: C) -> Self {
40        NeuralNetwork {
41            layers,
42            cost,
43            _backend: PhantomData,
44        }
45    }
46
47    pub fn predict<D>(&mut self, input: Array<f32, D>) -> Array<f32, L::Output>
48    where
49        D: Dimension,
50        Array<f32, D>: Into<B::Tensor<L::Input>>,
51    {
52        B::to_array(&self.layers.forward(&input.into()))
53    }
54
55    #[inline]
56    pub fn forward<I>(&mut self, input: I) -> B::Tensor<L::Output>
57    where
58        I: Into<B::Tensor<L::Input>>,
59    {
60        self.layers.forward(&input.into())
61    }
62
63    #[inline]
64    pub fn backward<G>(&mut self, grad_output: G) -> B::Tensor<L::Input>
65    where
66        G: Into<B::Tensor<L::Output>>,
67    {
68        self.layers.backward(&grad_output.into())
69    }
70}
71
/// Expands a comma-separated list of layer types into the right-nested
/// `Sequential<…>` type used to chain layers two at a time.
///
/// `NetworkType!(A, B, C)` expands to
/// `Sequential<A, Sequential<B, C>>`; a single type expands to itself.
/// A trailing comma is accepted.
///
/// Note: the original had a redundant dedicated two-type arm; the recursive
/// arm plus the one-type base case produce the identical expansion, so the
/// macro is collapsed to two arms.
#[macro_export]
macro_rules! NetworkType {
    ($l:ty $(,)?) => { $l };
    ($l1:ty, $($rest:ty),+ $(,)?) => {
        $crate::layer::Sequential<$l1, $crate::NetworkType!($($rest),+)>
    };
}