use super::{FromGenome, Network};
use crate::{Connection, Genome};
use core::ops::Range;
/// A network built directly from a genome's connection list.
///
/// Node values live in a flat `state` buffer (with a parallel `bias` buffer),
/// addressed by node index. `sensory` marks the slice of `state` that `step`
/// overwrites with its input, and `action` marks the slice returned by
/// `output`.
#[derive(Debug)]
#[cfg_attr(
all(feature = "serialize", not(feature = "serialize_json")),
derive(serde::Serialize, serde::Deserialize),
serde(bound(
serialize = "C: serde::Serialize",
deserialize = "C: serde::Deserialize<'de>"
))
)]
pub struct Simple<C: Connection> {
/// Connections copied from the genome by `from_genome`; `step` applies
/// them one after another in this order.
pub(crate) connections: Vec<C>,
/// Per-node bias, read at a connection's source node (`c.from()`) in
/// `step`. `from_genome` sizes it to the genome's node count, all zeros.
pub(crate) bias: Vec<f64>,
// NOTE(review): `skip_serializing` only affects serialization. Without
// `default`, the derived `Deserialize` still expects `state`, `sensory`
// and `action`, so a serialize -> deserialize round trip likely fails with
// a missing-field error. Omitting `sensory`/`action` also drops the
// network's input/output layout. Confirm the intended wire format.
/// Per-node values. `step` writes the input into the `sensory` slice and
/// accumulates connection contributions in place; `flush` zeroes it.
#[cfg_attr(
all(feature = "serialize", not(feature = "serialize_json")),
serde(skip_serializing)
)]
pub(crate) state: Vec<f64>,
/// Index range of `state` that receives the input in `step`.
#[cfg_attr(
all(feature = "serialize", not(feature = "serialize_json")),
serde(skip_serializing)
)]
pub(crate) sensory: Range<usize>,
/// Index range of `state` exposed by `output`.
#[cfg_attr(
all(feature = "serialize", not(feature = "serialize_json")),
serde(skip_serializing)
)]
pub(crate) action: Range<usize>,
}
impl<C: Connection> Network for Simple<C> {
    /// Advances the network by `prec` passes over its connections.
    ///
    /// The sensory slice of `state` is first overwritten with `input`. Then,
    /// `prec` times, every *enabled* connection is applied in storage order:
    /// `state[to] += σ((bias[from] + state[from]) * weight)`.
    /// Updates happen in place, so a connection sees values already written by
    /// earlier connections in the same pass.
    ///
    /// # Panics
    ///
    /// Panics if `input.len()` differs from the sensory range length, or if a
    /// connection refers to a node index outside `state`/`bias`.
    fn step<F: Fn(f64) -> f64>(&mut self, prec: usize, input: &[f64], σ: F) {
        debug_assert_eq!(input.len(), self.sensory.len());
        self.state[self.sensory.start..self.sensory.end].copy_from_slice(input);
        for _ in 0..prec {
            // Disabled connections are kept for structural bookkeeping (they are
            // copied from the genome), but they must not be expressed when the
            // network is evaluated.
            for c in self.connections.iter().filter(|c| c.enabled()) {
                let (from, to) = (c.from(), c.to());
                let delta = σ((self.bias[from] + self.state[from]) * c.weight());
                self.state[to] += delta;
            }
        }
    }

    /// Resets every node value to zero, keeping the buffer's length and
    /// allocation.
    fn flush(&mut self) {
        self.state.fill(0.0);
    }

    /// Returns the action (output) slice of the node values.
    fn output(&self) -> &[f64] {
        &self.state[self.action.start..self.action.end]
    }
}
impl<C: Connection, G: Genome<C>> FromGenome<C, G> for Simple<C> {
    /// Builds a network from `genome`.
    ///
    /// Every connection is cloned as-is (enabled or not). The bias and state
    /// buffers both get one zero entry per genome node, and the sensory/action
    /// ranges are taken directly from the genome.
    fn from_genome(genome: &G) -> Self {
        let nodes = genome.node_count();
        Self {
            connections: genome.connections().to_vec(),
            bias: vec![0.0; nodes],
            state: vec![0.0; nodes],
            sensory: genome.sensory(),
            action: genome.action(),
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::genome::{self, connection::BWConnection, WConnection};
    use eevee_macros::fn_matrix;
    fn_matrix! {
        C: WConnection | BWConnection,
        G: genome::Recurrent<C>,
        // Every genome connection, enabled or not, is carried into the network.
        #[test]
        fn test_preserves_all_connections() {
            let (genome, _) = G::new(2, 2);
            let nn = Simple::from_genome(&genome);
            assert_eq!(nn.connections.len(), genome.connections().len());
        }
        // Disabling a gene must be reflected in the built network.
        #[test]
        fn test_preserves_enabled_status() {
            let (mut genome, _) = G::new(2, 2);
            // Fail loudly instead of silently skipping the disable step.
            genome
                .connections_mut()
                .first_mut()
                .expect("G::new(2, 2) should produce at least one connection")
                .disable();
            let nn = Simple::from_genome(&genome);
            assert!(!nn.connections[0].enabled());
            assert!(nn.connections[1].enabled());
            // Every connection keeps exactly its genome-side status.
            for (net_conn, gene) in nn.connections.iter().zip(genome.connections()) {
                assert_eq!(net_conn.enabled(), gene.enabled());
            }
        }
        // Node buffers are sized to the genome's node count.
        #[test]
        fn test_state_size() {
            let (genome, _) = G::new(3, 2);
            let nn = Simple::from_genome(&genome);
            assert_eq!(nn.state.len(), genome.node_count());
            assert_eq!(nn.bias.len(), genome.node_count());
        }
        // The sensory range is taken verbatim from the genome.
        #[test]
        fn test_input_range() {
            let (genome, _) = G::new(3, 2);
            let nn = Simple::from_genome(&genome);
            assert_eq!(nn.sensory, genome.sensory());
        }
        // The action range is taken verbatim from the genome and sizes `output`.
        #[test]
        fn test_output_range() {
            let (genome, _) = G::new(3, 2);
            let nn = Simple::from_genome(&genome);
            assert_eq!(nn.action, genome.action());
            assert_eq!(nn.output().len(), genome.action().len());
        }
        // A freshly built network starts with all-zero state.
        #[test]
        fn test_initial_state() {
            let (genome, _) = G::new(2, 2);
            let nn = Simple::from_genome(&genome);
            assert!(nn.state.iter().all(|&x| x == 0.0));
        }
        // `flush` zeroes the state without changing its length.
        #[test]
        fn test_flush_clears_state() {
            let (genome, _) = G::new(2, 2);
            let mut nn = Simple::from_genome(&genome);
            nn.step(2, &[1.0, 0.5], |x| x);
            assert!(nn.state.iter().any(|&x| x != 0.0));
            nn.flush();
            assert_eq!(nn.state.len(), genome.node_count());
            assert!(nn.state.iter().all(|&x| x == 0.0));
        }
    }
}