//! gann/layers/mod.rs — layer and loss traits plus re-exports of the
//! concrete layer/loss implementations in this directory.

mod activation_functions;
mod basic_dense_layer;
mod basic_loss;
mod dense_layers;
mod leaky_relu_layer;
mod loss_functions;
mod losses;
mod softmax_layer;

pub use activation_functions::ActivationFn;
pub use basic_dense_layer::BasicDenseLayer;
pub use basic_loss::BasicLoss;
pub use dense_layers::{LinearLayer, SigmoidLayer, SwishLayer, TanhLayer};
pub use leaky_relu_layer::LeakyReLULayer;
pub use loss_functions::LossFn;
pub use losses::{CrossEntropyLoss, MSELoss};
pub use softmax_layer::SoftMaxLayer;

use crate::la::Vector;
use rand;
21
22
23pub trait FillLayerWith {
24  /// recursively fill weights and biases with random values
25  fn fill_with_rng<Rng>(&mut self, rng: &mut Rng)
26  where
27    Rng: rand::Rng;
28
29  /// recursively fill weights and biases with values from the functor,
30  /// for more info see implementations
31  fn fill_with<F>(&mut self, f: F)
32  where
33    F: FnMut() -> f32;
34
35  /// recursively fill weights and biases with the value
36  fn fill(&mut self, value: f32);
37
38  /// recursively fill various params with values from the functor,
39  /// for more info see implementations
40  fn fill_params_with<F>(&mut self, f: F)
41  where
42    F: FnMut() -> f32;
43
44  /// recursively fill various params with specific value,
45  /// for more info see implementations
46  fn fill_params(&mut self, value: f32);
47}
48
49
50pub trait RecursiveLayer<const INPUTS: usize, const OUTPUTS: usize> {
51  /// apply input data (or output of previous layer) by the layer and pass it to the next layer
52  fn infer(&mut self, input: &Vector<INPUTS>) -> Vector<OUTPUTS>;
53
54  /// recursive back propagation training
55  ///
56  /// # Arguments
57  /// - `input` - input data or output of previous layer
58  /// - `desired` - desired output of last layer (output of the ANN)
59  /// - `eta` - learning rate
60  ///
61  /// # Return values
62  /// downstream gradient and loss
63  fn train(
64    &mut self,
65    input: &Vector<INPUTS>,
66    desired: &Vector<OUTPUTS>,
67    eta: f32,
68  ) -> (Vector<INPUTS>, f32);
69}