use super::{FromGenome, Network, Recurrent, Stateful};
use crate::{Connection, Genome};
use nalgebra as na;
/// Matrix-backed recurrent network that has no bias vector: `step` adds only
/// the external input to the state before applying the activation function.
///
/// Node state is a single row vector `y` and connectivity is a dense square
/// weight matrix `w`, so one update is essentially `σ(y + input) · w`.
///
/// Serde (de)serialization is derived only when the `serialize` feature is
/// enabled and `serialize_json` is not; both matrices then go through
/// `crate::serialize::dmatrix`.
/// NOTE(review): field order is part of the serialized layout for
/// non-self-describing formats — reorder fields with care.
#[derive(Debug)]
#[cfg_attr(
all(feature = "serialize", not(feature = "serialize_json")),
derive(serde::Serialize, serde::Deserialize)
)]
pub struct NonBias {
/// Current node activations as a `1 × node_count` row vector; built zeroed
/// by `from_genome` and reset to zeros by `flush`.
#[cfg_attr(
all(feature = "serialize", not(feature = "serialize_json")),
serde(with = "crate::serialize::dmatrix")
)]
pub y: na::DMatrix<f64>,
/// Dense `node_count × node_count` weight matrix: `w[(from, to)]` holds the
/// weight of the enabled connection `from → to`, and is zero where no such
/// connection exists.
#[cfg_attr(
all(feature = "serialize", not(feature = "serialize_json")),
serde(with = "crate::serialize::dmatrix")
)]
pub w: na::DMatrix<f64>,
/// Half-open index range `(start, end)` of the sensory nodes; `step` adds the
/// input slice onto these positions of `y`.
pub sensory: (usize, usize),
/// Half-open index range `(start, end)` of the action nodes; `output` returns
/// this slice of `y`.
pub action: (usize, usize),
}
impl Network for NonBias {
    /// Advances the network by `prec` update iterations.
    ///
    /// Each iteration does the following:
    /// 1. Add `input` onto the sensory slots of the current state.
    /// 2. Apply `σ` element-wise.
    /// 3. Multiply the resulting row by `w` and scale it by `1 / prec`.
    ///
    /// The scaled product becomes the new state. When `prec` is 0 the state is
    /// left untouched.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - `input.len()` differs from the length of the sensory range.
    /// - The sensory range does not fit inside the state row.
    fn step<F: Fn(f64) -> f64>(&mut self, prec: usize, input: &[f64], σ: F) {
        let width = self.y.ncols();
        // NOTE(review): the state is rescaled by 1/prec on every iteration
        // rather than integrated towards a target — confirm this is intended.
        let scale = (prec as f64).recip();

        // External drive laid out over the whole state row; only the sensory
        // slots are non-zero.
        let mut drive = na::DMatrix::<f64>::zeros(1, width);
        drive.as_mut_slice()[self.sensory.0..self.sensory.1].copy_from_slice(input);

        // Scratch rows allocated once and reused by every iteration.
        let mut activated = na::DMatrix::<f64>::zeros(1, width);
        let mut next = na::DMatrix::<f64>::zeros(1, width);

        for _ in 0..prec {
            let pairs = self.y.iter().zip(drive.iter());
            for (slot, (&state, &extra)) in activated.iter_mut().zip(pairs) {
                *slot = σ(state + extra);
            }
            next.gemm(1.0, &activated, &self.w, 0.0);
            next *= scale;
            self.y.copy_from(&next);
        }
    }

    /// Resets the state to a zeroed `1 × ncols` row with the same width.
    fn flush(&mut self) {
        let width = self.y.ncols();
        self.y = na::DMatrix::zeros(1, width);
    }

    /// Returns the action-node slice of the current state.
    ///
    /// # Panics
    ///
    /// Panics if the action range lies outside the state row or has
    /// `start > end`.
    fn output(&self) -> &[f64] {
        let (first, last) = self.action;
        &self.y.as_slice()[first..last]
    }
}
// Marker-style impls: both bodies are empty, so `NonBias` relies entirely on
// whatever default items the `Recurrent` and `Stateful` traits provide.
// NOTE(review): presumably these tag the network as allowing recurrent
// connections and as carrying state between `step` calls (`y` persists until
// `flush`) — confirm against the trait definitions.
impl Recurrent for NonBias {}
impl Stateful for NonBias {}
impl<C: Connection, G: Genome<C>> FromGenome<C, G> for NonBias {
    /// Builds a network from `genome`.
    ///
    /// The state starts as a zeroed `1 × node_count` row. Each enabled
    /// connection `from → to` stores its weight at `w[(from, to)]`. The row
    /// product `y · w` therefore sums the weighted activations arriving at each
    /// column `to`.
    ///
    /// Disabled connections are skipped. If several enabled connections share
    /// the same endpoints, the last one wins.
    ///
    /// # Panics
    ///
    /// Panics if an enabled connection references a node index
    /// `>= genome.node_count()`.
    fn from_genome(genome: &G) -> Self {
        let cols = genome.node_count();
        // Fill the matrix in place with 2-D indexing. A flat `from * cols + to`
        // offset would silently alias an out-of-range `to` into the next row.
        // It would also need an extra row-major -> column-major copy via
        // `from_row_slice`.
        let mut w = na::DMatrix::<f64>::zeros(cols, cols);
        for c in genome.connections().iter().filter(|c| c.enabled()) {
            w[(c.from(), c.to())] = c.weight();
        }
        let sensory = genome.sensory();
        let action = genome.action();
        Self {
            y: na::DMatrix::zeros(1, cols),
            w,
            sensory: (sensory.start, sensory.end),
            action: (action.start, action.end),
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::genome::{self, connection::BWConnection, WConnection};
    use eevee_macros::fn_matrix;

    // NOTE(review): `fn_matrix!` presumably stamps out each test once per
    // connection type listed for `C`, with `G` bound accordingly — confirm in
    // `eevee_macros`.
    fn_matrix! {
        C: WConnection | BWConnection,
        G: genome::Recurrent<C>,
        // `y` is a single row and `w` is square, both sized by the node count.
        #[test]
        fn test_matrix_structure() {
            let (genome, _) = G::new(3, 2);
            let nn = NonBias::from_genome(&genome);
            let cols = genome.node_count();
            assert_eq!(nn.y.ncols(), cols);
            assert_eq!(nn.y.nrows(), 1);
            assert_eq!(nn.w.ncols(), cols);
            assert_eq!(nn.w.nrows(), cols);
        }
        // The sensory/action ranges are copied verbatim from the genome.
        #[test]
        fn test_bounds() {
            let (genome, _) = G::new(4, 3);
            let nn = NonBias::from_genome(&genome);
            assert_eq!(nn.sensory.0, genome.sensory().start);
            assert_eq!(nn.sensory.1, genome.sensory().end);
            assert_eq!(nn.action.0, genome.action().start);
            assert_eq!(nn.action.1, genome.action().end);
            assert_eq!(nn.output().len(), genome.action().len());
        }
        // Disabling a connection must leave a zero at its own (from, to) cell.
        // nalgebra stores matrices column-major, so the cell is addressed with
        // 2-D indexing instead of a row-major flat offset into `as_slice()`.
        #[test]
        fn test_disabled_connections_excluded() {
            let (mut genome, _) = G::new(2, 2);
            let disabled = genome.connections_mut().first_mut().map(|c| {
                c.disable();
                (c.from(), c.to())
            });
            let nn = NonBias::from_genome(&genome);
            if let Some((from, to)) = disabled {
                assert_eq!(nn.w[(from, to)], 0.0);
            }
        }
        // Every enabled connection's weight lands at w[(from, to)].
        #[test]
        fn test_enabled_connections_placed() {
            let (genome, _) = G::new(2, 2);
            let nn = NonBias::from_genome(&genome);
            for c in genome.connections().iter().filter(|c| c.enabled()) {
                assert_eq!(nn.w[(c.from(), c.to())], c.weight());
            }
        }
    }
}