use ndarray::Array2;
use ndarray_parallel::prelude::*;
use ndarray_rand::{RandomExt, F32};
use rand::distributions::Normal;
use serde_derive::{Deserialize, Serialize};

use crate::activations::{self, Activation};
use crate::layers::Layer;

/// A fully connected layer: `output = activation(input · weights)`.
/// The `input` and `output` of the last forward pass are cached for
/// reuse during backpropagation.
#[derive(Default, Serialize, Deserialize)]
pub struct Dense {
    pub weights: Array2<f32>,
    pub n_input: usize,
    pub n_output: usize,
    pub output: Option<Array2<f32>>,
    pub input: Option<Array2<f32>>,
    pub activation: Activation,
}

#[typetag::serde]
impl Layer for Dense {
    fn new(n_input: usize, n_output: usize, activation: Activation) -> Self {
        // Normal::new(mean, std_dev): weights drawn from a normal
        // distribution with mean -1 and standard deviation 1.
        let weights = Array2::<f32>::random((n_input, n_output), F32(Normal::new(-1., 1.)));
        Dense {
            weights,
            n_input,
            n_output,
            output: None,
            input: None,
            activation,
        }
    }

    fn forward(&mut self, x: Array2<f32>) -> Array2<f32> {
        self.input = Some(x.clone());
        // Pre-activation: z = x · W, computed once instead of in every match arm.
        let z = x.dot(&self.weights);
        self.output = Some(match self.activation {
            Activation::Linear => z,
            Activation::Sigmoid => activations::sigmoid(&z, false),
            Activation::Tanh => activations::tanh(&z, false),
            Activation::Softmax => activations::softmax(&z, false),
        });
        self.output.clone().unwrap()
    }

    fn n_input(&self) -> usize {
        self.n_input
    }

    fn n_output(&self) -> usize {
        self.n_output
    }

    fn output(&self) -> Array2<f32> {
        self.output.clone().unwrap()
    }

    fn input(&self) -> Array2<f32> {
        self.input.clone().unwrap()
    }

    fn weights(&self) -> Array2<f32> {
        self.weights.clone()
    }

    fn set_weights(&mut self, weights: Array2<f32>) {
        self.weights = weights;
    }

    fn backward(&mut self, error: Array2<f32>, lr: f32) -> Array2<f32> {
        // delta = activation'(output) ⊙ errorᵀ. The `true` flag asks each
        // activation helper for its derivative, evaluated on the cached
        // post-activation output.
        let delta = match self.activation {
            Activation::Sigmoid => activations::sigmoid(&self.output(), true) * error.t(),
            // The derivative of the identity is 1, so delta is just the
            // transposed error.
            Activation::Linear => error.t().to_owned(),
            Activation::Tanh => activations::tanh(&self.output(), true) * error.t(),
            Activation::Softmax => activations::softmax(&self.output(), true) * error.t(),
        };
        // Weight update: W += lr * (inputᵀ · delta), scaled in parallel.
        let mut updates = self.input().t().dot(&delta);
        updates.par_mapv_inplace(|v| v * lr);
        // Error propagated to the previous layer: W · deltaᵀ.
        let error_out = self.weights().dot(&delta.t());
        self.weights += &updates;
        error_out
    }
}
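
// ---------------------------------------------------------------------------
// Illustrations (not part of the original file).
//
// The `activations` helpers take an array plus a flag selecting the
// derivative, and `backward` passes them the post-activation output. A
// sigmoid helper matching the calls above could plausibly look like this
// (a hypothetical sketch, not the crate's actual `activations` module):
//
//     pub fn sigmoid(x: &Array2<f32>, derivative: bool) -> Array2<f32> {
//         if derivative {
//             x.mapv(|y| y * (1.0 - y)) // input is already σ(z), so σ' = y(1 − y)
//         } else {
//             x.mapv(|z| 1.0 / (1.0 + (-z).exp()))
//         }
//     }
// ---------------------------------------------------------------------------

// A usage sketch under the same assumptions: one forward/backward pass
// through a small layer, checking the shape conventions implied by the
// transposes in `backward`.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn forward_backward_shapes() {
        // 2 inputs, 1 output, sigmoid activation.
        let mut layer = Dense::new(2, 1, Activation::Sigmoid);

        // Batch of two samples: shape (batch, n_input) = (2, 2).
        let x = array![[0.0_f32, 1.0], [1.0, 0.0]];
        let y_hat = layer.forward(x);
        assert_eq!(y_hat.shape(), &[2, 1]);

        // `backward` transposes `error`, so pass it as (n_output, batch).
        let target = array![[1.0_f32], [0.0]];
        let error = (&target - &y_hat).reversed_axes();
        let upstream = layer.backward(error, 0.01);

        // The propagated error comes back as (n_input, batch).
        assert_eq!(upstream.shape(), &[2, 2]);
    }
}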