use rand::{thread_rng, Rng};
use na::DMatrix as Matrix;
use crate::Activation;
/// A single dense layer: a weight matrix, a bias column and an activation
/// function applied element-wise to `weights * input + biases`.
#[derive(Debug, Clone)]
pub struct Layer {
    /// Number of input values; the column count of `weights`.
    pub(crate) input_len: usize,
    /// Number of output values; the row count of `weights` and `biases`.
    pub(crate) output_len: usize,
    /// Total number of tunable parameters, set at construction to
    /// `output_len * input_len + output_len`.
    pub(crate) gene_len: usize,
    /// The activation this layer was constructed with.
    pub(crate) activation: Activation,
    /// Weight matrix of shape `output_len x input_len`.
    weights: Matrix<f64>,
    /// Bias column of shape `output_len x 1`.
    biases: Matrix<f64>,
    /// `input_len x 1` matrix, zero-filled at construction.
    // NOTE(review): never read or updated in the code visible here; presumably
    // meant to cache the layer input/activation. Confirm before relying on it.
    act: Matrix<f64>,
    /// Scalar function obtained from `activation.get_func()` at construction.
    act_func: fn(f64) -> f64,
    /// `output_len x 1` matrix, zero-filled at construction.
    // NOTE(review): never read or updated in the code visible here; presumably
    // meant to cache the pre-activation values. Confirm before relying on it.
    net: Matrix<f64>,
}
impl Layer {
    /// Creates a layer mapping `input_len` inputs to `output_len` outputs.
    ///
    /// Weights (`output_len x input_len`) and biases (`output_len x 1`) are
    /// filled with values from `rand_vec_uniform`. The `act` and `net`
    /// matrices start zero-filled, and the scalar activation function is
    /// looked up once via `activation.get_func()`.
    pub fn new(input_len: usize, output_len: usize, activation: Activation) -> Self {
        let weights = Matrix::from_vec(
            output_len,
            input_len,
            rand_vec_uniform(input_len * output_len),
        );
        let biases = Matrix::from_vec(output_len, 1, rand_vec_uniform(output_len));
        let act_func = activation.get_func();
        Self {
            input_len,
            output_len,
            activation,
            gene_len: output_len * input_len + output_len,
            weights,
            biases,
            act: Matrix::zeros(input_len, 1),
            act_func,
            net: Matrix::zeros(output_len, 1),
        }
    }

    /// Returns every parameter of the layer as one flat vector: all weights
    /// first, then all biases, each in the order exposed by `as_slice`.
    ///
    /// This is the layout [`Layer::set_genes`] expects, so
    /// `layer.set_genes(&layer.genes())` leaves the layer unchanged.
    pub fn genes(&self) -> Vec<f64> {
        // Reserve once and copy straight from the matrix storage instead of
        // building a temporary `Vec` per matrix.
        let mut out = Vec::with_capacity(self.num_genes());
        out.extend_from_slice(self.weights.as_slice());
        out.extend_from_slice(self.biases.as_slice());
        out
    }

    /// Returns the number of parameters: weight count plus bias count.
    pub fn num_genes(&self) -> usize {
        self.weights.len() + self.biases.len()
    }

    /// Computes `act_func(weights * m + biases)` element-wise and returns it.
    ///
    /// `m` must have `input_len` rows for the product to be defined; since
    /// `biases` has a single column, the sum only lines up when `m` is a
    /// single column as well. nalgebra panics on mismatched dimensions.
    // NOTE(review): takes `&mut self` but does not modify the layer; the
    // `act` / `net` fields are not updated here.
    pub(crate) fn forward(&mut self, m: &Matrix<f64>) -> Matrix<f64> {
        let net = &self.weights * m + &self.biases;
        net.apply_into(&self.act_func)
    }

    /// Replaces the weight matrix.
    ///
    /// # Panics
    ///
    /// Panics if `w` does not have the same shape as the current weights.
    pub(crate) fn set_weights(&mut self, w: Matrix<f64>) {
        assert_eq!(
            self.weights.shape(),
            w.shape(),
            "set_weights: (rows, cols) of the new weights must match the layer"
        );
        self.weights = w;
    }

    /// Replaces the bias column.
    ///
    /// # Panics
    ///
    /// Panics if `b` does not have the same shape as the current biases.
    pub(crate) fn set_biases(&mut self, b: Matrix<f64>) {
        assert_eq!(
            self.biases.shape(),
            b.shape(),
            "set_biases: (rows, cols) of the new biases must match the layer"
        );
        self.biases = b;
    }

    /// Overwrites weights and biases from a flat gene vector laid out as
    /// produced by [`Layer::genes`]: `output_len * input_len` weights
    /// followed by `output_len` biases.
    ///
    /// # Panics
    ///
    /// Panics if `genes.len()` differs from
    /// `output_len * input_len + output_len`. The length is checked before
    /// anything is modified, so a rejected vector leaves the layer untouched.
    // `&Vec<f64>` is kept on purpose so the public signature stays exactly as
    // callers already use it.
    #[allow(clippy::ptr_arg)]
    pub fn set_genes(&mut self, genes: &Vec<f64>) {
        let w_end = self.output_len * self.input_len;
        assert_eq!(
            genes.len(),
            w_end + self.output_len,
            "set_genes: gene vector length does not match the layer's gene count"
        );

        // Build both matrices before assigning either one, so the layer is
        // never left with new weights but old biases.
        let (weight_genes, bias_genes) = genes.split_at(w_end);
        let weights: Matrix<f64> =
            Matrix::from_vec(self.output_len, self.input_len, weight_genes.to_vec());
        let biases: Matrix<f64> = Matrix::from_vec(self.output_len, 1, bias_genes.to_vec());

        self.set_weights(weights);
        self.set_biases(biases);
    }
}
/// Returns `length` random values, each computed as
/// `rng.gen::<f64>() * 2.0 - 1.0` from the thread-local generator.
///
/// With `gen::<f64>()` sampling the half-open unit interval (see the `rand`
/// documentation), every element lies in `[-1.0, 1.0)`.
fn rand_vec_uniform(length: usize) -> Vec<f64> {
    let mut rng = thread_rng();
    let mut samples = Vec::with_capacity(length);
    for _ in 0..length {
        let unit: f64 = rng.gen::<f64>();
        // Stretch and shift the unit sample so it is centred on zero.
        samples.push(unit * 2.0 - 1.0);
    }
    samples
}
#[cfg(test)]
mod tests {
    extern crate round;
    use super::*;
    use round::*;

    /// A 3 -> 1 layer with known weights yields the expected dot product.
    #[test]
    fn layer_forward1() {
        let mut l = Layer::new(3, 1, Activation::Relu);
        let w = Matrix::from_vec(1, 3, vec![0.2, 0.4, 0.8]);
        l.set_weights(w);
        let b = Matrix::from_vec(1, 1, vec![0.0]);
        l.set_biases(b);
        let input = Matrix::from_vec(3, 1, vec![0.2, 0.4, 0.8]);
        let output = l.forward(&input);
        assert_eq!(output.ncols(), 1);
        assert_eq!(output.nrows(), 1);
        assert_eq!(round(output.row(0)[0], 2), 0.84);
    }

    /// A 3 -> 3 layer of all-ones weights sums the input in every output row.
    #[test]
    fn layer_forward2() {
        let mut l = Layer::new(3, 3, Activation::Relu);
        let w = Matrix::from_vec(3, 3, vec![1.0; 9]);
        l.set_weights(w);
        let b = Matrix::from_vec(3, 1, vec![0.0; 3]);
        l.set_biases(b);
        let input = Matrix::from_vec(3, 1, vec![1.0; 3]);
        let output = l.forward(&input);
        assert_eq!(output, Matrix::from_vec(3, 1, vec![3.0, 3.0, 3.0]));
    }

    /// Weights of the matching shape are accepted.
    #[test]
    fn layer_set_weights1() {
        let mut l = Layer::new(3, 1, Activation::Relu);
        let w: Matrix<f64> = Matrix::from_vec(1, 3, vec![1.0; 3]);
        l.set_weights(w);
    }

    /// Weights of the wrong shape are rejected by `set_weights`.
    #[test]
    #[should_panic]
    fn layer_set_weights2() {
        let mut l = Layer::new(3, 1, Activation::Relu);
        // The matrix itself must be well-formed (3 * 4 = 12 elements);
        // otherwise the panic comes from `from_vec` and `set_weights` is
        // never reached.
        let w: Matrix<f64> = Matrix::from_vec(3, 4, vec![1.0; 12]);
        l.set_weights(w);
    }

    /// Biases of the matching shape are accepted.
    #[test]
    fn layer_set_biases() {
        let mut l = Layer::new(3, 1, Activation::Relu);
        let b: Matrix<f64> = Matrix::from_vec(1, 1, vec![0.0]);
        l.set_biases(b);
    }

    /// Biases of the wrong shape are rejected by `set_biases`.
    #[test]
    #[should_panic]
    fn layer_set_biases_wrong_shape() {
        let mut l = Layer::new(3, 1, Activation::Relu);
        let b: Matrix<f64> = Matrix::from_vec(2, 1, vec![0.0; 2]);
        l.set_biases(b);
    }

    // Placeholder kept from the original suite: it has no assertions yet, so
    // it is ignored rather than reported as a pass.
    #[test]
    #[ignore = "placeholder: no assertions yet"]
    fn layer_enc_fit() {}

    /// `genes` returns one value per weight and per bias.
    #[test]
    fn layer_genes() {
        let l = Layer::new(3, 1, Activation::Relu);
        let genes = l.genes();
        assert_eq!(genes.len(), 4);
    }

    /// `gene_len` is `output_len * input_len + output_len`.
    #[test]
    fn layer_gene_len() {
        let l = Layer::new(3, 1, Activation::Relu);
        assert_eq!(l.gene_len, 4);
        let l = Layer::new(3, 3, Activation::Relu);
        assert_eq!(l.gene_len, 12);
    }

    /// `num_genes` agrees with `gene_len` and with the length of `genes`.
    #[test]
    fn layer_num_genes() {
        let l = Layer::new(3, 3, Activation::Relu);
        assert_eq!(l.num_genes(), l.gene_len);
        assert_eq!(l.num_genes(), l.genes().len());
    }

    /// `set_genes` followed by `genes` round-trips. Distinct values are used
    /// so a reordering of the parameters would be detected.
    #[test]
    fn layer_set_genes() {
        let mut l = Layer::new(3, 1, Activation::Relu);
        let genes: Vec<f64> = (1..=4).map(f64::from).collect();
        l.set_genes(&genes);
        assert_eq!(l.genes(), genes);

        let mut l = Layer::new(3, 3, Activation::Relu);
        let genes: Vec<f64> = (1..=12).map(f64::from).collect();
        l.set_genes(&genes);
        assert_eq!(l.genes(), genes);
    }

    /// The leading genes become weights and the trailing ones biases.
    #[test]
    fn layer_set_genes_layout() {
        let mut l = Layer::new(3, 1, Activation::Relu);
        l.set_genes(&vec![1.0, 2.0, 3.0, 4.0]);
        let input = Matrix::from_vec(3, 1, vec![1.0; 3]);
        // 1 + 2 + 3 from the weights, plus a bias of 4.
        assert_eq!(l.forward(&input), Matrix::from_vec(1, 1, vec![10.0]));
    }

    /// A gene vector that is too short is rejected.
    #[test]
    #[should_panic]
    fn layer_set_genes_too_short() {
        let mut l = Layer::new(3, 1, Activation::Relu);
        l.set_genes(&vec![1.0; 2]);
    }

    /// `rand_vec_uniform` returns the requested number of values in [-1, 1).
    #[test]
    fn rand_vec_uniform_len_and_range() {
        assert!(rand_vec_uniform(0).is_empty());
        let v = rand_vec_uniform(256);
        assert_eq!(v.len(), 256);
        assert!(v.iter().all(|&x| x >= -1.0 && x < 1.0));
    }
}