mod error;
mod evaluate;
mod utils;
pub use error::{
Error, IndexOutOfBoundsError, MismatchedLengthsError, MutationError, NotEnoughInputsError,
};
use num_traits::Float;
#[cfg(all(feature = "serde", feature = "json"))]
use serde::de::DeserializeOwned;
#[cfg(all(feature = "serde", feature = "json"))]
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::iter;
use std::ops::{Index, Range};
#[cfg(all(feature = "serde", feature = "json"))]
use std::path::Path;
use crate::activation::Activation;
#[cfg(all(feature = "serde", feature = "json"))]
use crate::encoding::{self, CommonMetadata, Extra};
#[cfg(feature = "serde")]
use crate::encoding::{EncodingVersion, MetadataVersion, PortableCGE, WithRecurrentState};
use crate::gene::*;
use crate::stack::Stack;
use evaluate::Inputs;
/// Positional metadata recorded for one neuron of a genome.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NeuronInfo {
    /// Half-open range of genome indices covered by the neuron: it starts at the neuron gene
    /// itself and ends one past the last gene of its subgenome.
    subgenome_range: Range<usize>,
    /// Number of neurons that enclose this one; top-level neurons have depth 0.
    depth: usize,
}

impl NeuronInfo {
    /// Bundles a subgenome range and a depth into a `NeuronInfo`.
    fn new(subgenome_range: Range<usize>, depth: usize) -> Self {
        NeuronInfo {
            subgenome_range,
            depth,
        }
    }

    /// Returns a copy of the range of genome indices occupied by the neuron's subgenome.
    pub fn subgenome_range(&self) -> Range<usize> {
        self.subgenome_range.start..self.subgenome_range.end
    }

    /// Returns the nesting depth of the neuron.
    pub fn depth(&self) -> usize {
        self.depth
    }
}
/// A network stored as a linear genome of genes, together with metadata derived from that
/// genome.
///
/// All fields other than `genome`, `activation`, and `stack` are recomputed by
/// `rebuild_metadata` and kept up to date by the mutation methods.
#[derive(Clone, Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub struct Network<T: Float> {
/// The genes of the network, in genome order.
genome: Vec<Gene<T>>,
/// The activation handed to the evaluator on every call to `evaluate`.
activation: Activation,
/// One greater than the largest neuron ID seen so far; used as the ID of the next neuron
/// created by `add_subnetwork`.
next_neuron_id: usize,
/// Subgenome range and depth of every neuron, keyed by neuron ID.
neuron_info: HashMap<NeuronId, NeuronInfo>,
/// For each gene (by genome index), the ID of the neuron that directly encloses it, or `None`
/// for genes with no enclosing neuron.
gene_parents: Vec<Option<NeuronId>>,
/// IDs of the neurons referenced by recurrent jumpers, without duplicates, in order of first
/// appearance in the genome.
recurrent_state_ids: Vec<NeuronId>,
/// One greater than the largest input ID used in the genome, or 0 if there are no input genes.
num_inputs: usize,
/// Number of outputs of the network, as computed by `rebuild_metadata`.
num_outputs: usize,
/// Evaluation stack; cleared at the start of `evaluate`, and its contents are returned as the
/// output slice.
stack: Stack<T>,
}
impl<T: Float> Network<T> {
pub fn new(genome: Vec<Gene<T>>, activation: Activation) -> Result<Self, Error> {
let mut network = Self {
genome,
activation,
next_neuron_id: 0,
neuron_info: HashMap::new(),
gene_parents: Vec::new(),
recurrent_state_ids: Vec::new(),
num_inputs: 0,
num_outputs: 0,
stack: Stack::new(),
};
network.rebuild_metadata()?;
Ok(network)
}
/// Loads a network from the string `s` by delegating to `encoding::load_str`.
///
/// On success the returned tuple holds the network, its `CommonMetadata`, and the `Extra<E>`
/// value produced by the decoder. `with_state` is forwarded unchanged.
///
/// Only compiled when both the `serde` and `json` features are enabled.
// NOTE(review): `with_state` presumably selects whether stored recurrent state is restored;
// confirm against `encoding::load_str`.
#[cfg(all(feature = "serde", feature = "json"))]
pub fn load_str<'a, E>(
s: &'a str,
with_state: WithRecurrentState,
) -> Result<(Self, CommonMetadata, Extra<E>), encoding::Error>
where
T: Deserialize<'a>,
E: Deserialize<'a>,
{
encoding::load_str(s, with_state)
}
/// Loads a network from the file at `path` by delegating to `encoding::load_file`.
///
/// On success the returned tuple holds the network, its `CommonMetadata`, and the `Extra<E>`
/// value produced by the decoder. `with_state` is forwarded unchanged.
///
/// Only compiled when both the `serde` and `json` features are enabled.
// NOTE(review): `with_state` presumably selects whether stored recurrent state is restored;
// confirm against `encoding::load_file`.
#[cfg(all(feature = "serde", feature = "json"))]
pub fn load_file<E, P>(
path: P,
with_state: WithRecurrentState,
) -> Result<(Self, CommonMetadata, Extra<E>), encoding::Error>
where
T: DeserializeOwned,
E: DeserializeOwned,
P: AsRef<Path>,
{
encoding::load_file(path, with_state)
}
/// Serializes the network, together with `metadata` and `extra`, into a string.
///
/// The network is first converted with `to_serializable` and the result is passed to
/// `encoding::to_string`; any error from the encoder is returned unchanged.
///
/// Only compiled when both the `serde` and `json` features are enabled.
#[cfg(all(feature = "serde", feature = "json"))]
pub fn to_string<E, M>(
    &self,
    metadata: M,
    extra: E,
    with_state: WithRecurrentState,
) -> Result<String, encoding::Error>
where
    T: Serialize,
    E: Serialize,
    M: MetadataVersion<T, E>,
{
    let portable = self.to_serializable(metadata, extra, with_state);
    encoding::to_string(portable)
}
/// Serializes the network, together with `metadata` and `extra`, and writes it to `path`.
///
/// The network is first converted with `to_serializable`; the result, `path`, and
/// `create_dirs` are then handed to `encoding::to_file`, whose error (if any) is returned
/// unchanged.
///
/// Only compiled when both the `serde` and `json` features are enabled.
#[cfg(all(feature = "serde", feature = "json"))]
pub fn to_file<E, M, P>(
    &self,
    metadata: M,
    extra: E,
    with_state: WithRecurrentState,
    path: P,
    create_dirs: bool,
) -> Result<(), encoding::Error>
where
    T: Serialize,
    E: Serialize,
    M: MetadataVersion<T, E>,
    P: AsRef<Path>,
{
    let portable = self.to_serializable(metadata, extra, with_state);
    encoding::to_file(portable, path, create_dirs)
}
/// Converts this network into a `PortableCGE` by calling `M::Data::new` with the network,
/// `metadata`, `extra`, and `with_state`.
///
/// Only compiled when the `serde` feature is enabled.
#[cfg(feature = "serde")]
pub fn to_serializable<E, M>(
&self,
metadata: M,
extra: E,
with_state: WithRecurrentState,
) -> PortableCGE<T, E>
where
M: MetadataVersion<T, E>,
{
M::Data::new(self, metadata, extra, with_state)
}
/// Validates the genome and recomputes every piece of metadata derived from it.
///
/// On success `neuron_info`, `gene_parents`, `recurrent_state_ids`, `next_neuron_id`,
/// `num_inputs`, and `num_outputs` all reflect the current genome.
///
/// # Errors
///
/// - `Error::EmptyGenome` if the genome has no genes.
/// - `Error::InvalidInputCount` if a neuron declares zero inputs.
/// - `Error::NonNeuronOutput` if a non-neuron gene has no enclosing neuron.
/// - `Error::DuplicateNeuronId` if two neurons share an ID.
/// - `Error::NotEnoughInputs` if the genome ends while a neuron is still missing inputs.
/// - `Error::InvalidJumperSource` if a jumper refers to a neuron ID that is not in the genome.
/// - `Error::InvalidForwardJumper` if a forward jumper's source is not deeper than the
///   jumper's parent neuron.
/// - `Error::Arithmetic` if a counter or ID computation overflows.
///
/// `neuron_info` and `gene_parents` are cleared before validation starts, so they are left
/// partially rebuilt when an error is returned.
#[deny(clippy::integer_arithmetic, clippy::as_conversions)]
fn rebuild_metadata(&mut self) -> Result<(), Error> {
// A neuron whose inputs have not all been seen yet.
struct StoppingPoint {
// Value of the running counter right after the neuron gene itself was counted; the neuron
// is complete once the counter returns to this value.
counter: isize,
id: NeuronId,
// Genome index of the neuron gene.
start_index: usize,
// Number of neurons that were still open when this neuron was encountered.
depth: usize,
}
// Deferred check for a forward jumper; it can only run once every neuron's depth is known.
struct ForwardJumperCheck {
jumper_index: usize,
parent_depth: usize,
source_id: NeuronId,
}
// Deferred check that a recurrent jumper's source neuron exists.
struct RecurrentJumperCheck {
jumper_index: usize,
source_id: NeuronId,
}
if self.genome.is_empty() {
return Err(Error::EmptyGenome);
}
// Running count of values that are not yet consumed as inputs: every gene adds one and every
// neuron removes as many as it has inputs. Its final value is the number of outputs.
let mut counter = 0isize;
let neuron_info = &mut self.neuron_info;
neuron_info.clear();
let gene_parents = &mut self.gene_parents;
gene_parents.clear();
// Stack of neurons that are still waiting for inputs; the last entry is the innermost one.
let mut stopping_points: Vec<StoppingPoint> = Vec::new();
let mut forward_jumper_checks: Vec<ForwardJumperCheck> = Vec::new();
let mut recurrent_jumper_checks: Vec<RecurrentJumperCheck> = Vec::new();
let mut max_input_id = None;
let mut max_neuron_id = None;
for (i, gene) in self.genome.iter().enumerate() {
// The innermost open neuron is the parent of this gene; the number of open neurons is the
// gene's depth.
let parent = stopping_points.last().map(|p| p.id);
let depth = stopping_points.len();
counter = counter.checked_add(1).ok_or(Error::Arithmetic)?;
gene_parents.push(parent);
if let Gene::Neuron(neuron) = gene {
stopping_points.push(StoppingPoint {
counter,
id: neuron.id(),
start_index: i,
depth,
});
if neuron.num_inputs() == 0 {
return Err(Error::InvalidInputCount(i, neuron.id()));
}
// A failed conversion is propagated through `?` as an `Error`.
let num_inputs = isize::try_from(neuron.num_inputs())?;
counter = counter.checked_sub(num_inputs).ok_or(Error::Arithmetic)?;
max_neuron_id = max_neuron_id
.or(Some(0))
.map(|max_id| max_id.max(neuron.id().as_usize()));
} else {
// Non-neuron genes must be inputs of some neuron.
if parent.is_none() {
return Err(Error::NonNeuronOutput(i));
}
match gene {
Gene::ForwardJumper(forward) => {
// `parent` is `Some`, so at least one neuron is open and `depth >= 1`.
let parent_depth = depth.checked_sub(1).unwrap();
forward_jumper_checks.push(ForwardJumperCheck {
jumper_index: i,
parent_depth,
source_id: forward.source_id(),
});
}
Gene::RecurrentJumper(recurrent) => {
recurrent_jumper_checks.push(RecurrentJumperCheck {
jumper_index: i,
source_id: recurrent.source_id(),
});
}
Gene::Input(input) => {
max_input_id = max_input_id
.or(Some(0))
.map(|max_id| max_id.max(input.id().as_usize()));
}
_ => {}
}
// Close every neuron whose inputs are now all accounted for. Several nested neurons can
// finish on the same gene, hence the loop.
while !stopping_points.is_empty()
&& stopping_points.last().unwrap().counter == counter
{
let stop = stopping_points.pop().unwrap();
if let Some(existing) = neuron_info.get(&stop.id) {
let existing_index = existing.subgenome_range().start;
return Err(Error::DuplicateNeuronId(
existing_index,
stop.start_index,
stop.id,
));
}
// The subgenome ends just past the current gene.
let end_index = i.checked_add(1).unwrap();
let subgenome_range = stop.start_index..end_index;
neuron_info.insert(stop.id, NeuronInfo::new(subgenome_range, stop.depth));
}
}
}
// Any neuron still open at the end of the genome is missing inputs; report the innermost one.
if let Some(stop) = stopping_points.last() {
return Err(Error::NotEnoughInputs(stop.start_index, stop.id));
}
for check in forward_jumper_checks {
if let Some(source_info) = neuron_info.get(&check.source_id) {
// The source must be strictly deeper than the jumper's parent.
if check.parent_depth >= source_info.depth() {
return Err(Error::InvalidForwardJumper(check.jumper_index));
}
} else {
return Err(Error::InvalidJumperSource(
check.jumper_index,
check.source_id,
));
}
}
for check in recurrent_jumper_checks {
if !neuron_info.contains_key(&check.source_id) {
return Err(Error::InvalidJumperSource(
check.jumper_index,
check.source_id,
));
}
}
self.update_recurrent_state_ids();
// The first gene of a valid genome is a neuron (otherwise `NonNeuronOutput` was returned
// above), so `max_neuron_id` is `Some` here.
self.next_neuron_id = max_neuron_id
.unwrap()
.checked_add(1)
.ok_or(Error::Arithmetic)?;
self.num_inputs = match max_input_id {
Some(id) => id.checked_add(1).ok_or(Error::Arithmetic)?,
None => 0,
};
self.num_outputs = usize::try_from(counter).unwrap();
Ok(())
}
/// Evaluates the network on `inputs` and returns a slice holding the resulting outputs.
///
/// The evaluation stack is emptied first, the whole genome is evaluated, and afterwards each
/// neuron's current value is moved into its previous value (see `update_stored_values`).
///
/// # Errors
///
/// Returns a [`NotEnoughInputsError`] when `inputs` holds fewer than `num_inputs()` values.
/// Surplus inputs are accepted.
pub fn evaluate(&mut self, inputs: &[T]) -> Result<&[T], NotEnoughInputsError> {
    let expected = self.num_inputs;
    let provided = inputs.len();
    if provided < expected {
        return Err(NotEnoughInputsError::new(expected, provided));
    }

    self.stack.clear();
    let whole_genome = 0..self.genome.len();
    evaluate::evaluate_slice(
        &mut self.genome,
        whole_genome,
        Inputs(inputs),
        &mut self.stack,
        false,
        &self.neuron_info,
        self.activation,
    );
    update_stored_values(&mut self.genome);
    Ok(self.stack.as_slice())
}
/// Resets the stored previous value of every neuron in the genome to zero.
pub fn clear_state(&mut self) {
    self.genome.iter_mut().for_each(|gene| {
        if let Gene::Neuron(neuron) = gene {
            *neuron.mut_previous_value() = T::zero();
        }
    });
}
/// Recomputes `recurrent_state_ids` from the genome.
///
/// The result lists the source neuron of every recurrent jumper exactly once, in the order in
/// which the sources first appear while scanning the genome.
fn update_recurrent_state_ids(&mut self) {
    let mut seen = HashSet::new();
    let ordered_ids = &mut self.recurrent_state_ids;
    ordered_ids.clear();

    let sources = self.genome.iter().filter_map(|gene| match gene {
        Gene::RecurrentJumper(jumper) => Some(jumper.source_id()),
        _ => None,
    });
    for source in sources {
        // `insert` returns `true` only the first time an ID is seen.
        if seen.insert(source) {
            ordered_ids.push(source);
        }
    }
}
/// Returns the number of entries in the recurrent state, i.e. the number of distinct neurons
/// referenced by recurrent jumpers.
pub fn recurrent_state_len(&self) -> usize {
self.recurrent_state_ids.len()
}
/// Returns an iterator over the recurrent state: the previous value of each neuron listed in
/// `recurrent_state_ids`, in that order.
pub fn recurrent_state(&self) -> impl Iterator<Item = T> + '_ {
    self.recurrent_state_ids.iter().copied().map(move |id| {
        let source = self.get_neuron(id).unwrap();
        source.previous_value()
    })
}
/// Returns the IDs of the neurons whose previous values make up the recurrent state, in
/// recurrent-state order.
pub fn recurrent_state_ids(&self) -> &[NeuronId] {
    self.recurrent_state_ids.as_slice()
}
/// Calls `f` once per recurrent-state entry, passing the entry's index and a mutable reference
/// to the previous value of the corresponding neuron.
pub fn map_recurrent_state<F: FnMut(usize, &mut T)>(&mut self, mut f: F) {
    for (position, &id) in self.recurrent_state_ids.iter().enumerate() {
        let neuron = utils::get_mut_neuron(id, &self.neuron_info, &mut self.genome).unwrap();
        f(position, neuron.mut_previous_value());
    }
}
/// Overwrites the whole recurrent state with the values in `state`.
///
/// # Errors
///
/// Returns a [`MismatchedLengthsError`], leaving the network untouched, if `state` does not
/// have exactly `recurrent_state_len()` elements.
pub fn set_recurrent_state(&mut self, state: &[T]) -> Result<(), MismatchedLengthsError> {
    if state.len() != self.recurrent_state_ids.len() {
        return Err(MismatchedLengthsError);
    }
    self.map_recurrent_state(|i, slot| *slot = state[i]);
    Ok(())
}
/// Sets the recurrent-state entry at `index` to `value`.
///
/// # Errors
///
/// Returns an [`IndexOutOfBoundsError`] if `index` is not a valid recurrent-state index.
pub fn set_recurrent_state_at(
    &mut self,
    index: usize,
    value: T,
) -> Result<(), IndexOutOfBoundsError> {
    match self.recurrent_state_ids.get(index).copied() {
        Some(id) => {
            let neuron = utils::get_mut_neuron(id, &self.neuron_info, &mut self.genome).unwrap();
            *neuron.mut_previous_value() = value;
            Ok(())
        }
        None => Err(IndexOutOfBoundsError),
    }
}
/// Returns the genes of the network in genome order.
pub fn genome(&self) -> &[Gene<T>] {
    self.genome.as_slice()
}
/// Returns the number of genes in the genome.
// NOTE(review): no `is_empty` is provided, presumably because `new` rejects empty genomes.
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.genome.len()
}
/// Returns the gene at genome position `index`, or `None` if `index` is out of bounds.
pub fn get(&self, index: usize) -> Option<&Gene<T>> {
self.genome.get(index)
}
/// Returns the activation currently used by the network.
pub fn activation(&self) -> Activation {
self.activation
}
/// Replaces the network's activation with `new` and clears the stored neuron state.
pub fn set_activation(&mut self, new: Activation) {
    self.activation = new;
    self.clear_state();
}
/// Returns the number of inputs `evaluate` requires: one greater than the largest input ID
/// used in the genome, or 0 if the genome has no input genes.
pub fn num_inputs(&self) -> usize {
self.num_inputs
}
/// Returns the number of outputs of the network.
pub fn num_outputs(&self) -> usize {
self.num_outputs
}
/// Returns the number of neurons in the network.
pub fn num_neurons(&self) -> usize {
self.neuron_info.len()
}
/// Returns `true` if the network has a neuron with ID `id`.
pub fn contains(&self, id: NeuronId) -> bool {
self.neuron_info.contains_key(&id)
}
/// Looks up the neuron with ID `id` by delegating to `utils::get_neuron`.
// NOTE(review): presumably returns `None` when no neuron has this ID; confirm in `utils`.
pub fn get_neuron(&self, id: NeuronId) -> Option<&Neuron<T>> {
utils::get_neuron(id, &self.neuron_info, &self.genome)
}
/// Mutable counterpart of `get_neuron`, delegating to `utils::get_mut_neuron`.
// NOTE(review): presumably returns `None` when no neuron has this ID; confirm in `utils`.
pub(crate) fn get_mut_neuron(&mut self, id: NeuronId) -> Option<&mut Neuron<T>> {
utils::get_mut_neuron(id, &self.neuron_info, &mut self.genome)
}
/// Returns an iterator over the IDs of all neurons in the network.
///
/// The IDs come from a `HashMap`, so their order is unspecified.
pub fn neuron_ids(&self) -> impl Iterator<Item = NeuronId> + '_ {
    self.neuron_info.keys().copied()
}
/// Returns the `NeuronInfo` of the neuron with ID `id`, or `None` if there is no such neuron.
pub fn neuron_info(&self, id: NeuronId) -> Option<&NeuronInfo> {
self.neuron_info.get(&id)
}
/// Returns the map from neuron ID to `NeuronInfo` for every neuron in the network.
pub fn neuron_info_map(&self) -> &HashMap<NeuronId, NeuronInfo> {
&self.neuron_info
}
/// Returns the parent of the gene at genome position `index`.
///
/// The outer `Option` is `None` when `index` is out of bounds; the inner `Option` is `None`
/// when the gene has no parent neuron.
pub fn parent_of(&self, index: usize) -> Option<Option<NeuronId>> {
    self.gene_parents.get(index).copied()
}
/// Returns the parent neuron ID of every gene, indexed by genome position; genes without a
/// parent neuron are `None`.
pub fn parents(&self) -> &[Option<NeuronId>] {
    self.gene_parents.as_slice()
}
/// Returns the ID that will be given to the next neuron created by `add_subnetwork`.
pub fn next_neuron_id(&self) -> NeuronId {
NeuronId::new(self.next_neuron_id)
}
pub fn weights(&self) -> impl Iterator<Item = T> + '_ {
self.genome.iter().map(Gene::weight)
}
pub fn mut_weights(&mut self) -> impl Iterator<Item = &mut T> {
self.clear_state();
self.genome.iter_mut().map(Gene::mut_weight)
}
/// Replaces the weight of every gene with the corresponding element of `weights` and clears
/// the stored neuron state.
///
/// # Errors
///
/// Returns a [`MismatchedLengthsError`], leaving the network untouched, if `weights` does not
/// have exactly `len()` elements.
pub fn set_weights(&mut self, weights: &[T]) -> Result<(), MismatchedLengthsError> {
    if weights.len() != self.len() {
        return Err(MismatchedLengthsError);
    }
    for (slot, &value) in self.mut_weights().zip(weights) {
        *slot = value;
    }
    Ok(())
}
/// Adds a single non-neuron gene as an input of the neuron `parent`.
///
/// This is `add_genes` with one gene and no new neuron; see that method for the validation
/// rules and possible errors.
pub fn add_non_neuron<G: Into<NonNeuronGene<T>>>(
    &mut self,
    parent: NeuronId,
    gene: G,
) -> Result<(), MutationError> {
    let single = vec![gene.into()];
    self.add_genes(parent, None, single)?;
    Ok(())
}
/// Adds several non-neuron genes as inputs of the neuron `parent`.
///
/// This is `add_genes` without a new neuron; see that method for the validation rules and
/// possible errors.
pub fn add_non_neurons(
    &mut self,
    parent: NeuronId,
    genes: Vec<NonNeuronGene<T>>,
) -> Result<(), MutationError> {
    self.add_genes(parent, None, genes)?;
    Ok(())
}
/// Adds a new neuron with the given `weight` as an input of `parent`, using `inputs` as the
/// new neuron's own inputs, and returns the ID assigned to the new neuron.
///
/// See `add_genes` for the validation rules and possible errors.
pub fn add_subnetwork(
    &mut self,
    parent: NeuronId,
    weight: T,
    inputs: Vec<NonNeuronGene<T>>,
) -> Result<NeuronId, MutationError> {
    let new_id = self.add_genes(parent, Some(weight), inputs)?;
    // `add_genes` always yields an ID when a subnetwork weight is supplied.
    Ok(new_id.unwrap())
}
/// Shared implementation of the `add_*` mutations: inserts `genes` directly after the gene of
/// the neuron `parent`.
///
/// If `subnetwork_weight` is `Some(w)`, a new neuron with weight `w` and ID `next_neuron_id` is
/// inserted as an input of `parent`, and `genes` become the inputs of that new neuron.
/// Otherwise `genes` become direct inputs of `parent`.
///
/// On success all derived metadata is updated in place, the stored neuron state is cleared,
/// and the ID of the new neuron (if one was created) is returned.
///
/// # Errors
///
/// - `MutationError::Empty` if `genes` is empty.
/// - `MutationError::InvalidParent` if `parent` is not a neuron of this network.
/// - `MutationError::Arithmetic` if computing the next neuron ID or the new input count
///   overflows.
/// - `MutationError::InvalidForwardJumper` if a forward jumper points at the new neuron, or at
///   a neuron that is not deeper than the jumper's parent.
/// - `MutationError::InvalidJumperSource` if a jumper's source neuron does not exist.
///
/// Every error is returned before the network is modified.
///
/// # Panics
///
/// Panics if one of the internal `checked_add(..).unwrap()` length or depth computations
/// overflows.
fn add_genes(
&mut self,
parent: NeuronId,
subnetwork_weight: Option<T>,
genes: Vec<NonNeuronGene<T>>,
) -> Result<Option<NeuronId>, MutationError> {
if genes.is_empty() {
return Err(MutationError::Empty);
}
let parent_info = self
.neuron_info
.get(&parent)
.ok_or(MutationError::InvalidParent)?;
let parent_index = parent_info.subgenome_range().start;
// New genes go immediately after the parent's neuron gene.
let new_sequence_index = parent_index.checked_add(1).unwrap();
let new_neuron_id = subnetwork_weight.map(|_| NeuronId::new(self.next_neuron_id));
let parent_of_new_inputs = if let Some(id) = new_neuron_id {
id
} else {
parent
};
let new_neuron_depth = parent_info.depth().checked_add(1).unwrap();
// Total number of genes inserted into the genome, counting the new neuron gene if any.
let added_len = if new_neuron_id.is_some() {
genes.len().checked_add(1).unwrap()
} else {
genes.len()
};
let parent_neuron = self[parent_index].as_neuron().unwrap();
// With a subnetwork the parent gains one input (the new neuron); otherwise it gains every
// added gene.
let new_parent_num_inputs = if new_neuron_id.is_some() {
parent_neuron.num_inputs().checked_add(1).unwrap()
} else {
parent_neuron.num_inputs().checked_add(genes.len()).unwrap()
};
let new_next_neuron_id = if let Some(id) = new_neuron_id {
id.as_usize()
.checked_add(1)
.ok_or(MutationError::Arithmetic)?
} else {
self.next_neuron_id
};
let mut new_num_inputs = self.num_inputs;
let mut added_recurrent_jumper = false;
// Validate every new gene before touching the network.
{
let ref_self = &*self;
for gene in &genes {
match gene {
NonNeuronGene::Input(input) => {
new_num_inputs = new_num_inputs.max(
input
.id()
.as_usize()
.checked_add(1)
.ok_or(MutationError::Arithmetic)?,
);
}
NonNeuronGene::ForwardJumper(forward) => {
let points_to_new_neuron = if let Some(id) = new_neuron_id {
forward.source_id() == id
} else {
false
};
// A forward jumper inside the new neuron may not use that neuron as its source.
if points_to_new_neuron {
return Err(MutationError::InvalidForwardJumper);
}
if let Some(info) = ref_self.neuron_info.get(&forward.source_id()) {
// Depth of the neuron that will directly contain the jumper.
let mut parent_depth = ref_self[parent].depth();
if new_neuron_id.is_some() {
parent_depth = parent_depth.checked_add(1).unwrap();
}
if parent_depth >= info.depth() {
return Err(MutationError::InvalidForwardJumper);
}
} else {
return Err(MutationError::InvalidJumperSource);
}
}
NonNeuronGene::RecurrentJumper(recurrent) => {
let points_to_new_neuron = if let Some(id) = new_neuron_id {
recurrent.source_id() == id
} else {
false
};
// Unlike forward jumpers, a recurrent jumper may refer to the neuron being created.
if !(points_to_new_neuron
|| ref_self.neuron_info.contains_key(&recurrent.source_id()))
{
return Err(MutationError::InvalidJumperSource);
}
added_recurrent_jumper = true;
}
NonNeuronGene::Bias(_) => {}
}
}
}
// From here on the mutation cannot fail. Shift the ranges of neurons located after the
// insertion point and extend the ranges of neurons that contain it.
for info in self.neuron_info.values_mut() {
if info.subgenome_range.start >= new_sequence_index {
info.subgenome_range.start += added_len;
info.subgenome_range.end += added_len;
} else if info.subgenome_range.contains(&new_sequence_index) {
info.subgenome_range.end += added_len;
}
}
if let Some(id) = new_neuron_id {
let new_info = NeuronInfo::new(
new_sequence_index..new_sequence_index + added_len,
new_neuron_depth,
);
self.neuron_info.insert(id, new_info);
}
self.genome[parent_index]
.as_mut_neuron()
.unwrap()
.set_num_inputs(new_parent_num_inputs);
let genes_len = genes.len();
self.genome.splice(
new_sequence_index..new_sequence_index,
genes.into_iter().map(Into::into),
);
self.gene_parents.splice(
new_sequence_index..new_sequence_index,
iter::repeat(Some(parent_of_new_inputs)).take(genes_len),
);
if added_recurrent_jumper {
self.update_recurrent_state_ids();
}
// The new neuron gene is inserted in front of the genes that were just added, so that they
// become its inputs.
if let Some(weight) = subnetwork_weight {
let num_inputs = genes_len;
self.genome.insert(
new_sequence_index,
Neuron::new(new_neuron_id.unwrap(), num_inputs, weight).into(),
);
self.gene_parents.insert(new_sequence_index, Some(parent));
}
self.num_inputs = new_num_inputs;
self.next_neuron_id = new_next_neuron_id;
self.clear_state();
Ok(new_neuron_id)
}
/// Removes the non-neuron gene at genome position `index` and returns it.
///
/// On success the parent's input count, the neuron ranges, `num_inputs`, the gene parents, and
/// (when a recurrent jumper was removed) the recurrent state IDs are updated, and the stored
/// neuron state is cleared.
///
/// # Errors
///
/// - `MutationError::RemoveInvalidIndex` if `index` is out of bounds.
/// - `MutationError::RemoveNeuron` if the gene at `index` is a neuron.
/// - `MutationError::RemoveOnlyInput` if the gene is the only input of its parent neuron.
///
/// Every error is returned before the network is modified.
pub fn remove_non_neuron(&mut self, index: usize) -> Result<Gene<T>, MutationError> {
if let Some(removed_gene) = self.genome.get(index) {
if removed_gene.is_neuron() {
return Err(MutationError::RemoveNeuron);
}
// Non-neuron genes always have a parent in a validated genome.
let parent_id = self.gene_parents[index].unwrap();
let parent = self.get_mut_neuron(parent_id).unwrap();
let num_inputs = parent.num_inputs();
if num_inputs == 1 {
return Err(MutationError::RemoveOnlyInput);
}
parent.set_num_inputs(num_inputs.checked_sub(1).unwrap());
// Shift the ranges of neurons located after the removed gene and shrink the ranges of neurons
// that contain it.
for info in self.neuron_info.values_mut() {
if info.subgenome_range.start > index {
info.subgenome_range.start = info.subgenome_range.start.checked_sub(1).unwrap();
info.subgenome_range.end = info.subgenome_range.end.checked_sub(1).unwrap();
} else if info.subgenome_range.contains(&index) {
info.subgenome_range.end = info.subgenome_range.end.checked_sub(1).unwrap();
}
}
// Recompute the largest input ID while ignoring the gene about to be removed.
let mut new_max_input_id = None;
for (i, gene) in self.genome.iter().enumerate() {
if let Gene::Input(input) = gene {
if i != index {
new_max_input_id = new_max_input_id
.or(Some(0))
.map(|max_id| max_id.max(input.id().as_usize()));
}
}
}
self.num_inputs = new_max_input_id
.map(|id| id.checked_add(1).unwrap())
.unwrap_or(0);
self.clear_state();
self.gene_parents.remove(index);
let removed = self.genome.remove(index);
if removed.is_recurrent_jumper() {
self.update_recurrent_state_ids();
}
Ok(removed)
} else {
Err(MutationError::RemoveInvalidIndex)
}
}
/// Returns an iterator over the genome positions of the genes that `remove_non_neuron` can
/// remove: non-neuron genes whose parent neuron has more than one input.
pub fn get_valid_removals(&self) -> impl Iterator<Item = usize> + '_ {
    self.genome
        .iter()
        .zip(&self.gene_parents)
        .enumerate()
        .filter(move |(_, (gene, parent))| {
            // The parent is only looked up for non-neuron genes, which always have one.
            !gene.is_neuron() && self.get_neuron(parent.unwrap()).unwrap().num_inputs() > 1
        })
        .map(|(position, _)| position)
}
/// Returns an iterator over the IDs of the neurons that may be the source of a forward jumper
/// whose parent neuron has depth `parent_depth`, i.e. the neurons strictly deeper than that.
///
/// The IDs come from a `HashMap`, so their order is unspecified.
pub fn get_valid_forward_jumper_sources(
    &self,
    parent_depth: usize,
) -> impl Iterator<Item = NeuronId> + '_ {
    self.neuron_info
        .iter()
        .filter(move |(_, info)| info.depth() > parent_depth)
        .map(|(&id, _)| id)
}
}
/// Indexes the network by genome position, yielding the gene at that position.
///
/// # Panics
///
/// Panics if `idx` is out of bounds for the genome.
impl<T: Float> Index<usize> for Network<T> {
type Output = Gene<T>;
fn index(&self, idx: usize) -> &Self::Output {
&self.genome[idx]
}
}
/// Indexes the network by neuron ID, yielding the `NeuronInfo` of that neuron.
///
/// # Panics
///
/// Panics if the network has no neuron with ID `idx`.
impl<T: Float> Index<NeuronId> for Network<T> {
type Output = NeuronInfo;
fn index(&self, idx: NeuronId) -> &Self::Output {
&self.neuron_info[&idx]
}
}
/// Moves every neuron's current value into its previous value and resets the current value to
/// `None`.
///
/// # Panics
///
/// Panics if any neuron in `genome` has no current value.
fn update_stored_values<T: Float>(genome: &mut [Gene<T>]) {
    for gene in genome.iter_mut() {
        let neuron = match gene {
            Gene::Neuron(neuron) => neuron,
            _ => continue,
        };
        let latest = neuron
            .current_value()
            .expect("neuron's current value is not set");
        *neuron.mut_previous_value() = latest;
        neuron.set_current_value(None);
    }
}
#[cfg(test)]
pub(crate) mod tests {
use rand::prelude::*;
use super::*;
use crate::encoding::Metadata;
/// Builds the path `<crate manifest dir>/<folder>/<file_name>` used to locate test data.
fn get_file_path(folder: &str, file_name: &str) -> String {
    [env!("CARGO_MANIFEST_DIR"), folder, file_name].join("/")
}
/// Creates a bias gene with value 1.0, converted into the requested gene type `G`.
fn bias<G: From<Bias<f64>>>() -> G {
    G::from(Bias::new(1.0))
}
/// Creates an input gene for input `id` with weight 1.0, converted into the requested gene
/// type `G`.
fn input<G: From<Input<f64>>>(id: usize) -> G {
    G::from(Input::new(InputId::new(id), 1.0))
}
/// Creates a neuron gene with the given `id` and `num_inputs` and weight 1.0, converted into
/// the requested gene type `G`.
fn neuron<G: From<Neuron<f64>>>(id: usize, num_inputs: usize) -> G {
    G::from(Neuron::new(NeuronId::new(id), num_inputs, 1.0))
}
/// Creates a forward jumper gene with source neuron `source_id` and weight 1.0, converted into
/// the requested gene type `G`.
fn forward<G: From<ForwardJumper<f64>>>(source_id: usize) -> G {
    G::from(ForwardJumper::new(NeuronId::new(source_id), 1.0))
}
/// Creates a recurrent jumper gene with source neuron `source_id` and weight 1.0, converted
/// into the requested gene type `G`.
fn recurrent<G: From<RecurrentJumper<f64>>>(source_id: usize) -> G {
    G::from(RecurrentJumper::new(NeuronId::new(source_id), 1.0))
}
/// Asserts that the network's output count equals its number of depth-0 neurons.
fn check_num_outputs(network: &Network<f64>) {
    let top_level_neurons = network
        .neuron_info
        .values()
        .filter(|info| info.depth == 0)
        .count();
    assert_eq!(network.num_outputs(), top_level_neurons);
}
/// Asserts that the previous value of every neuron in the network is zero.
fn check_state_is_cleared(network: &Network<f64>) {
    for gene in network.genome() {
        if !gene.is_neuron() {
            continue;
        }
        let neuron = gene.as_neuron().unwrap();
        assert_eq!(0.0, neuron.previous_value());
    }
}
/// Checks `num_inputs` and `num_outputs` for several small genomes.
#[test]
fn test_inputs_outputs() {
// No input genes at all.
let genome = vec![neuron(0, 1), bias()];
let net = Network::new(genome, Activation::Linear).unwrap();
assert_eq!(0, net.num_inputs());
assert_eq!(1, net.num_outputs());
check_num_outputs(&net);
let genome2 = vec![neuron(0, 2), input(0), bias()];
let net2 = Network::new(genome2, Activation::Linear).unwrap();
assert_eq!(1, net2.num_inputs());
assert_eq!(1, net2.num_outputs());
check_num_outputs(&net2);
// The input count follows the largest input ID (2), even though ID 1 is unused.
let genome3 = vec![neuron(0, 3), input(0), bias(), input(2)];
let net3 = Network::new(genome3, Activation::Linear).unwrap();
assert_eq!(3, net3.num_inputs());
assert_eq!(1, net3.num_outputs());
check_num_outputs(&net3);
// Two top-level neurons give two outputs.
let genome4 = vec![neuron(0, 2), input(0), bias(), neuron(1, 1), input(1)];
let net4 = Network::new(genome4, Activation::Linear).unwrap();
assert_eq!(2, net4.num_inputs());
assert_eq!(2, net4.num_outputs());
check_num_outputs(&net4);
}
/// Changing the activation after an evaluation must clear the stored neuron state.
#[test]
fn test_set_activation() {
    let mut network =
        Network::new(vec![neuron(0, 2), bias(), input(0)], Activation::Linear).unwrap();
    let _ = network.evaluate(&[1.0; 1]);
    network.set_activation(Activation::Relu);
    check_state_is_cleared(&network);
}
/// `set_weights` rejects slices of the wrong length without changing anything, and otherwise
/// replaces every weight and clears the stored neuron state.
#[test]
fn test_set_weights() {
let genome = vec![neuron(0, 2), bias(), input(0)];
let mut net = Network::new(genome, Activation::Linear).unwrap();
let _ = net.evaluate(&[1.0; 1]);
// Too few and too many weights are both rejected.
assert!(net.set_weights(&[]).is_err());
assert!(net.set_weights(&[1.0, 2.0, 3.0, 4.0]).is_err());
// The failed calls left the original weights in place.
assert_eq!(&[1.0; 3][..], net.weights().collect::<Vec<_>>().as_slice());
net.set_weights(&[5.0, 6.0, 7.0]).unwrap();
assert_eq!(
&[5.0, 6.0, 7.0][..],
net.weights().collect::<Vec<_>>().as_slice()
);
check_state_is_cleared(&net);
}
/// Merely obtaining the mutable weight iterator must clear the stored neuron state.
#[test]
fn test_mut_weights() {
    let mut network =
        Network::new(vec![neuron(0, 2), bias(), input(0)], Activation::Linear).unwrap();
    let _ = network.evaluate(&[1.0; 1]);
    // Walk the iterator without modifying any weight.
    network.mut_weights().for_each(drop);
    check_state_is_cleared(&network);
}
/// Exercises the recurrent-state accessors on a genome whose recurrent jumpers refer to
/// neurons 0 and 2 (in that order of first appearance).
#[test]
fn test_recurrent_state() {
let genome = vec![
neuron(0, 2),
recurrent(0),
neuron(1, 3),
neuron(2, 1),
recurrent(2),
recurrent(0),
recurrent(2),
];
let mut net = Network::new(genome, Activation::Linear).unwrap();
// Duplicate sources are counted once.
assert_eq!(2, net.recurrent_state_len());
assert_eq!(
&[0.0, 0.0][..],
net.recurrent_state().collect::<Vec<_>>().as_slice()
);
// Wrong-length states are rejected and leave the state untouched.
assert!(net.set_recurrent_state(&[]).is_err());
assert!(net.set_recurrent_state(&[1.0, 2.0, 3.0]).is_err());
assert_eq!(
&[0.0, 0.0][..],
net.recurrent_state().collect::<Vec<_>>().as_slice()
);
net.set_recurrent_state(&[2.0, 3.0]).unwrap();
assert_eq!(
&[2.0, 3.0][..],
net.recurrent_state().collect::<Vec<_>>().as_slice()
);
// Entry 0 belongs to neuron 0 (genome index 0), entry 1 to neuron 2 (genome index 3).
assert_eq!(2.0, net[0].as_neuron().unwrap().previous_value());
assert_eq!(3.0, net[3].as_neuron().unwrap().previous_value());
assert!(net.set_recurrent_state_at(2, 1.0).is_err());
net.set_recurrent_state_at(1, 5.0).unwrap();
assert_eq!(5.0, net[3].as_neuron().unwrap().previous_value());
}
/// Restoring a previously captured recurrent state must reproduce the same next output.
#[test]
fn test_save_load_recurrent_state() {
let (mut net, _, _) = Network::<f64>::load_file::<(), _>(
get_file_path("test_data", "test_network_recurrent.cge"),
WithRecurrentState(false),
)
.unwrap();
let _output = net.evaluate(&[]).unwrap();
// Capture the state after one evaluation, then record the output of the next one.
let saved = net.recurrent_state().collect::<Vec<_>>();
let output2 = net.evaluate(&[]).unwrap().to_vec();
// Reset, restore the captured state, and expect the same output again.
net.clear_state();
net.set_recurrent_state(&saved).unwrap();
let output3 = net.evaluate(&[]).unwrap().to_vec();
assert_eq!(output2, output3);
}
/// Checks the neuron info, gene parents, and recurrent state IDs computed for the
/// multi-output test network loaded from disk.
#[test]
fn test_rebuild_metadata() {
let (net, _, _) = Network::<f64>::load_file::<(), _>(
get_file_path("test_data", "test_network_multi_output.cge"),
WithRecurrentState(false),
)
.unwrap();
// Expected (subgenome range, depth) per neuron ID.
let mut expected_neuron_info = HashMap::new();
expected_neuron_info.insert(NeuronId::new(0), NeuronInfo::new(0..5, 0));
expected_neuron_info.insert(NeuronId::new(1), NeuronInfo::new(1..4, 1));
expected_neuron_info.insert(NeuronId::new(2), NeuronInfo::new(5..9, 0));
expected_neuron_info.insert(NeuronId::new(3), NeuronInfo::new(9..14, 0));
expected_neuron_info.insert(NeuronId::new(4), NeuronInfo::new(11..14, 1));
assert_eq!(expected_neuron_info, net.neuron_info);
// Expected parent of each gene, by genome index; `None` marks the top-level neurons.
let expected_parents = vec![
None,
Some(NeuronId::new(0)),
Some(NeuronId::new(1)),
Some(NeuronId::new(1)),
Some(NeuronId::new(0)),
None,
Some(NeuronId::new(2)),
Some(NeuronId::new(2)),
Some(NeuronId::new(2)),
None,
Some(NeuronId::new(3)),
Some(NeuronId::new(3)),
Some(NeuronId::new(4)),
Some(NeuronId::new(4)),
];
assert_eq!(expected_parents, net.gene_parents);
assert!(net.recurrent_state_ids.is_empty());
}
/// After `clear_state`, a recurrent network must behave as if freshly loaded.
#[test]
fn test_clear_state() {
let (mut net, _, _) = Network::<f64>::load_file::<(), _>(
get_file_path("test_data", "test_network_recurrent.cge"),
WithRecurrentState(false),
)
.unwrap();
let output = net.evaluate(&[]).unwrap().to_vec();
// The second evaluation sees stored state, so its output differs from the first.
let output2 = net.evaluate(&[]).unwrap().to_vec();
assert_ne!(output, output2);
net.clear_state();
let output3 = net.evaluate(&[]).unwrap().to_vec();
assert_eq!(output, output3);
}
/// The next neuron ID must be one greater than the largest neuron ID in the genome.
#[test]
fn test_next_neuron_id() {
    let dense_ids = Network::new(
        vec![neuron(0, 2), input(1), neuron(1, 1), bias()],
        Activation::Linear,
    )
    .unwrap();
    assert_eq!(2, dense_ids.next_neuron_id);

    // IDs need not start at zero.
    let sparse_ids = Network::new(vec![neuron(2, 1), input(1)], Activation::Linear).unwrap();
    assert_eq!(3, sparse_ids.next_neuron_id);
}
/// Well-formed genomes, including one with forward and recurrent jumpers, are accepted.
#[test]
fn test_validate_valid() {
    let simple = vec![neuron(0, 2), input(0), bias()];
    assert!(Network::new(simple, Activation::Linear).is_ok());

    let with_jumpers = vec![
        neuron(0, 5),
        input(0),
        bias(),
        forward(1),
        recurrent(1),
        neuron(1, 1),
        input(1),
    ];
    assert!(Network::new(with_jumpers, Activation::Linear).is_ok());
}
/// An empty genome is rejected with `Error::EmptyGenome`.
#[test]
fn test_validate_empty() {
    let outcome = Network::<f64>::new(Vec::new(), Activation::Linear);
    assert_eq!(outcome.unwrap_err(), Error::EmptyGenome);
}
/// A neuron declaring zero inputs is rejected, reporting its genome index and ID.
#[test]
fn test_validate_invalid_input_count() {
    let outcome = Network::new(
        vec![neuron(0, 1), neuron(2, 0), input(0)],
        Activation::Linear,
    );
    assert_eq!(
        outcome.unwrap_err(),
        Error::InvalidInputCount(1, NeuronId::new(2))
    );
}
/// Genomes that end before a neuron has received all of its inputs are rejected.
#[test]
fn test_validate_not_enough_inputs() {
    // The outer neuron expects two inputs but only receives the inner neuron.
    let nested = Network::new(
        vec![neuron(0, 2), neuron(2, 1), input(0)],
        Activation::Linear,
    );
    assert_eq!(
        nested.unwrap_err(),
        Error::NotEnoughInputs(0, NeuronId::new(0))
    );

    // A lone neuron has no inputs at all.
    let lone = Network::new(vec![neuron(1, 1)], Activation::Linear);
    assert_eq!(
        lone.unwrap_err(),
        Error::NotEnoughInputs(0, NeuronId::new(1))
    );

    // Three inputs declared, two provided.
    let short = Network::new(vec![neuron(2, 3), bias(), input(0)], Activation::Linear);
    assert_eq!(
        short.unwrap_err(),
        Error::NotEnoughInputs(0, NeuronId::new(2))
    );
}
/// Two neurons sharing an ID are rejected; the error carries the genome indices of both.
#[test]
fn test_validate_duplicate_neuron_id() {
    // The inner neuron (index 2) is completed and recorded before the outer one (index 0).
    let nested = Network::new(
        vec![neuron(1, 2), input(1), neuron(1, 1), bias()],
        Activation::Linear,
    );
    assert_eq!(
        nested.unwrap_err(),
        Error::DuplicateNeuronId(2, 0, NeuronId::new(1))
    );

    let nested_twice = Network::new(
        vec![
            neuron(0, 2),
            input(1),
            neuron(1, 2),
            bias(),
            neuron(1, 1),
            input(0),
        ],
        Activation::Linear,
    );
    assert_eq!(
        nested_twice.unwrap_err(),
        Error::DuplicateNeuronId(4, 2, NeuronId::new(1))
    );
}
/// Non-neuron genes that are not the input of any neuron are rejected.
#[test]
fn test_validate_non_neuron_output() {
    // A genome consisting of a single non-neuron gene.
    for lone_gene in vec![bias(), input(0), forward(1), recurrent(1)] {
        let outcome = Network::new(vec![lone_gene], Activation::Linear);
        assert_eq!(outcome.unwrap_err(), Error::NonNeuronOutput(0));
    }

    // A surplus gene after a complete neuron.
    let trailing = Network::new(
        vec![neuron(0, 2), input(1), bias(), input(0)],
        Activation::Linear,
    );
    assert_eq!(trailing.unwrap_err(), Error::NonNeuronOutput(3));

    // A non-neuron gene in front of the first neuron.
    let leading = Network::new(vec![bias(), neuron(0, 1), input(0)], Activation::Linear);
    assert_eq!(leading.unwrap_err(), Error::NonNeuronOutput(0));
}
/// Jumpers whose source neuron does not exist are rejected.
#[test]
fn test_validate_invalid_jumper_source() {
    let bad_forward = Network::new(
        vec![neuron(0, 1), forward(3), neuron(1, 1), bias()],
        Activation::Linear,
    );
    assert_eq!(
        bad_forward.unwrap_err(),
        Error::InvalidJumperSource(1, NeuronId::new(3))
    );

    let bad_recurrent = Network::new(vec![neuron(0, 1), recurrent(2)], Activation::Linear);
    assert_eq!(
        bad_recurrent.unwrap_err(),
        Error::InvalidJumperSource(1, NeuronId::new(2))
    );
}
/// Forward jumpers whose source is not deeper than their parent neuron are rejected.
#[test]
fn test_validate_invalid_forward_jumper() {
    // A forward jumper pointing at its own parent.
    let self_reference = Network::new(vec![neuron(0, 1), forward(0)], Activation::Linear);
    assert_eq!(
        self_reference.unwrap_err(),
        Error::InvalidForwardJumper(1)
    );

    // The jumper's parent (neuron 3, depth 2) is deeper than the source (neuron 1, depth 1).
    let shallow_source = Network::new(
        vec![
            neuron(0, 2),
            neuron(1, 1),
            input(0),
            neuron(2, 1),
            neuron(3, 1),
            forward(1),
        ],
        Activation::Linear,
    );
    assert_eq!(
        shallow_source.unwrap_err(),
        Error::InvalidForwardJumper(5)
    );
}
/// An input count too large to fit in an `isize` is reported as an arithmetic error.
#[test]
fn test_validate_extreme_neuron_input_count() {
    // `usize::MAX / 2 + 1` is one more than `isize::MAX`.
    let outcome = Network::new(
        vec![neuron(usize::MAX, (usize::MAX / 2) + 1)],
        Activation::Linear,
    );
    assert_eq!(outcome.unwrap_err(), Error::Arithmetic);
}
/// A neuron ID of `usize::MAX` leaves no room for a next neuron ID and is reported as an
/// arithmetic error.
#[test]
fn test_validate_extreme_neuron_id() {
    let outcome = Network::new(vec![neuron(usize::MAX, 1), bias()], Activation::Linear);
    assert_eq!(outcome.unwrap_err(), Error::Arithmetic);
}
/// An input ID of `usize::MAX` makes the input count overflow and is reported as an arithmetic
/// error.
#[test]
fn test_validate_extreme_input_id() {
    let outcome = Network::new(vec![neuron(0, 1), input(usize::MAX)], Activation::Linear);
    assert_eq!(outcome.unwrap_err(), Error::Arithmetic);
}
/// Builds a network from `genome`, applies `mutate`, and asserts both that it fails with
/// `expected` and that the failed mutation left the network completely unchanged.
fn run_invalid_mutation_test<F: Fn(&mut Network<f64>) -> Result<(), MutationError>>(
    genome: Vec<Gene<f64>>,
    mutate: F,
    expected: MutationError,
) {
    let mut net = Network::new(genome, Activation::Linear).unwrap();
    let snapshot = net.clone();
    let outcome = mutate(&mut net);
    assert_eq!(Err(expected), outcome);
    assert_eq!(snapshot, net);
}
/// Adding a gene to a parent neuron that does not exist fails with `InvalidParent`.
#[test]
fn test_mutate_invalid_parent() {
    run_invalid_mutation_test(
        vec![neuron(0, 1), input(0)],
        |net| net.add_non_neuron(NeuronId::new(1), input::<NonNeuronGene<_>>(1)),
        MutationError::InvalidParent,
    );
}
/// Adding a forward or recurrent jumper whose source neuron does not exist fails with
/// `InvalidJumperSource`.
#[test]
fn test_mutate_invalid_jumper_source() {
    let dangling_jumpers: [NonNeuronGene<_>; 2] = [forward(1), recurrent(1)];
    for jumper in dangling_jumpers.iter() {
        run_invalid_mutation_test(
            vec![neuron(0, 1), input(0)],
            |net| net.add_non_neuron(NeuronId::new(0), jumper.clone()),
            MutationError::InvalidJumperSource,
        );
    }
}
/// Mutations that would create an invalid forward jumper fail with `InvalidForwardJumper`.
#[test]
fn test_mutate_invalid_forward_jumper() {
// A forward jumper pointing at its own parent.
run_invalid_mutation_test(
vec![neuron(0, 1), input(0)],
|net| {
let new_gene: NonNeuronGene<_> = forward(0);
net.add_non_neuron(NeuronId::new(0), new_gene)
},
MutationError::InvalidForwardJumper,
);
// A forward jumper pointing at a neuron shallower than its parent.
run_invalid_mutation_test(
vec![neuron(0, 2), neuron(1, 1), input(1), input(0)],
|net| {
let new_gene: NonNeuronGene<_> = forward(0);
net.add_non_neuron(NeuronId::new(1), new_gene)
},
MutationError::InvalidForwardJumper,
);
// A forward jumper inside a new subnetwork pointing at the subnetwork's own neuron.
run_invalid_mutation_test(
vec![neuron(0, 1), input(0)],
|net| {
let inputs = vec![forward(net.next_neuron_id().as_usize())];
net.add_subnetwork(NeuronId::new(0), 1.0, inputs)
.map(|_| ())
},
MutationError::InvalidForwardJumper,
);
// Inside the new subnetwork the jumper's parent has the same depth as the source neuron.
run_invalid_mutation_test(
vec![neuron(0, 1), neuron(1, 1), input(0)],
|net| {
let inputs = vec![forward(1)];
net.add_subnetwork(NeuronId::new(0), 1.0, inputs)
.map(|_| ())
},
MutationError::InvalidForwardJumper,
);
}
/// Adding an empty list of genes fails with `Empty`, with or without a new subnetwork.
#[test]
fn test_mutate_empty() {
    run_invalid_mutation_test(
        vec![neuron(0, 1), input(0)],
        |net| {
            net.add_subnetwork(NeuronId::new(0), 1.0, Vec::new())
                .map(drop)
        },
        MutationError::Empty,
    );
    run_invalid_mutation_test(
        vec![neuron(0, 1), input(0)],
        |net| net.add_non_neurons(NeuronId::new(0), Vec::new()),
        MutationError::Empty,
    );
}
/// Removing a gene at an out-of-bounds index fails with `RemoveInvalidIndex`.
#[test]
fn test_mutate_remove_invalid_index() {
    run_invalid_mutation_test(
        vec![neuron(0, 1), input(0)],
        |net| net.remove_non_neuron(2).map(drop),
        MutationError::RemoveInvalidIndex,
    );
}
/// Removing a neuron gene through `remove_non_neuron` fails with `RemoveNeuron`.
#[test]
fn test_mutate_remove_neuron() {
    run_invalid_mutation_test(
        vec![neuron(0, 1), neuron(1, 1), input(0)],
        |net| net.remove_non_neuron(1).map(drop),
        MutationError::RemoveNeuron,
    );
}
/// Removing the only input of a neuron fails with `RemoveOnlyInput`.
#[test]
fn test_mutate_remove_only_input() {
    run_invalid_mutation_test(
        vec![neuron(0, 1), bias()],
        |net| net.remove_non_neuron(1).map(drop),
        MutationError::RemoveOnlyInput,
    );
    // The removed gene is the sole input of the nested neuron 1.
    run_invalid_mutation_test(
        vec![neuron(0, 2), neuron(1, 1), input(0), bias()],
        |net| net.remove_non_neuron(2).map(drop),
        MutationError::RemoveOnlyInput,
    );
}
/// Mutations whose bookkeeping would overflow `usize` fail with `Arithmetic`.
#[test]
fn test_mutate_arithmetic() {
// The new neuron would get ID `usize::MAX`, leaving no valid next neuron ID.
run_invalid_mutation_test(
vec![neuron(usize::MAX - 1, 1), input(0)],
|net| {
let inputs = vec![bias()];
net.add_subnetwork(NeuronId::new(usize::MAX - 1), 1.0, inputs)
.map(|_| ())
},
MutationError::Arithmetic,
);
// An input ID of `usize::MAX` would make the input count overflow.
run_invalid_mutation_test(
vec![neuron(0, 1), input(0)],
|net| {
let new_gene: NonNeuronGene<_> = input(usize::MAX);
net.add_non_neuron(NeuronId::new(0), new_gene).map(|_| ())
},
MutationError::Arithmetic,
);
}
/// Asserts that the metadata maintained incrementally by a mutation equals the metadata
/// obtained by rebuilding it from scratch from the mutated genome.
fn check_mutated_metadata(net: &mut Network<f64>) {
// Snapshot the incrementally updated metadata before it is overwritten by the rebuild.
let mutated_neuron_info = net.neuron_info.clone();
let mutated_gene_parents = net.gene_parents.clone();
let mutated_recurrent_state_ids = net.recurrent_state_ids.clone();
let mutated_num_inputs = net.num_inputs();
let num_outputs = net.num_outputs;
let mutated_next_neuron_id = net.next_neuron_id();
assert_eq!(Ok(()), net.rebuild_metadata());
assert_eq!(net.neuron_info, mutated_neuron_info);
assert_eq!(net.gene_parents, mutated_gene_parents);
assert_eq!(net.recurrent_state_ids, mutated_recurrent_state_ids);
assert_eq!(net.num_inputs(), mutated_num_inputs);
assert_eq!(num_outputs, net.num_outputs());
assert_eq!(net.next_neuron_id(), mutated_next_neuron_id);
}
/// Builds a network from `start_genome`, applies `mutate`, and checks the result.
///
/// After the mutation the stored neuron state must be cleared, the genome must equal
/// `end_genome`, `num_inputs` and `next_neuron_id` must match the expected values, the output
/// count must be unchanged, the network must still evaluate, and the incrementally updated
/// metadata must match a full rebuild.
fn run_mutation_test<O, F: Fn(&mut Network<f64>) -> O>(
start_genome: Vec<Gene<f64>>,
mutate: F,
end_genome: Vec<Gene<f64>>,
expected_num_inputs: usize,
expected_next_neuron_id: NeuronId,
) {
let mut network = Network::new(start_genome, Activation::Linear).unwrap();
let old_num_outputs = network.num_outputs();
// Evaluate once before mutating, presumably so that the state-clearing check below is
// meaningful.
let _ = network.evaluate(&[1.0; 10]).unwrap().to_vec();
let _ = mutate(&mut network);
check_state_is_cleared(&network);
assert_eq!(end_genome, network.genome());
assert_eq!(expected_num_inputs, network.num_inputs());
assert_eq!(old_num_outputs, network.num_outputs());
assert_eq!(expected_next_neuron_id, network.next_neuron_id());
assert!(network.evaluate(&[1.0; 10]).is_ok());
check_mutated_metadata(&mut network);
}
/// Adding a single input gene inserts it directly after the parent neuron gene and raises the
/// parent's input count and the network's input count.
#[test]
fn test_add_non_neuron() {
run_mutation_test(
vec![neuron(0, 1), neuron(1, 1), input(0)],
|net| {
let new_gene: NonNeuronGene<_> = input(1);
net.add_non_neuron(NeuronId::new(0), new_gene).unwrap();
},
vec![neuron(0, 2), input(1), neuron(1, 1), input(0)],
2,
NeuronId::new(2),
);
}
/// Adds four non-neuron genes (bias, input, forward jumper, recurrent jumper) to neuron 0
/// in one call and checks the resulting genome, an input count of 1 and a next neuron ID
/// of 2.
#[test]
fn test_add_non_neurons() {
    let start: Vec<Gene<f64>> = vec![neuron(0, 1), neuron(1, 1), input(0)];
    // The four genes are expected right after the parent, in the order they were passed.
    let expected: Vec<Gene<f64>> = vec![
        neuron(0, 5),
        bias(),
        input(0),
        forward(1),
        recurrent(0),
        neuron(1, 1),
        input(0),
    ];

    run_mutation_test(
        start,
        |net| {
            let parent = NeuronId::new(0);
            net.add_non_neurons(parent, vec![bias(), input(0), forward(1), recurrent(0)])
                .unwrap();
        },
        expected,
        1,
        NeuronId::new(2),
    );
}
/// Adds a subnetwork under neuron 0 in two differently sized networks and checks the
/// resulting genomes, input counts and next neuron IDs.
///
/// In both cases one of the subnetwork's inputs is a recurrent jumper whose source is the
/// network's next neuron ID at the time of the call; the expected genomes show the new
/// neuron carrying that same ID.
#[test]
fn test_add_subnetwork() {
    // Case 1: three existing neurons; the expected new neuron has ID 3.
    let start: Vec<Gene<f64>> = vec![neuron(0, 1), neuron(1, 1), neuron(2, 1), input(0)];
    let expected: Vec<Gene<f64>> = vec![
        neuron(0, 2),
        neuron(3, 4),
        bias(),
        input(1),
        forward(2),
        recurrent(3),
        neuron(1, 1),
        neuron(2, 1),
        input(0),
    ];
    run_mutation_test(
        start,
        |net| {
            // Read the ID before mutating so the jumper can reference it.
            let upcoming_id = net.next_neuron_id().as_usize();
            let subnetwork_inputs = vec![bias(), input(1), forward(2), recurrent(upcoming_id)];
            net.add_subnetwork(NeuronId::new(0), 1.0, subnetwork_inputs)
                .unwrap();
        },
        expected,
        2,
        NeuronId::new(4),
    );

    // Case 2: four existing neurons; the expected new neuron has ID 4.
    let start: Vec<Gene<f64>> = vec![
        neuron(0, 1),
        neuron(1, 1),
        neuron(2, 1),
        neuron(3, 1),
        input(0),
    ];
    let expected: Vec<Gene<f64>> = vec![
        neuron(0, 2),
        neuron(4, 4),
        bias(),
        input(1),
        forward(3),
        recurrent(4),
        neuron(1, 1),
        neuron(2, 1),
        neuron(3, 1),
        input(0),
    ];
    run_mutation_test(
        start,
        |net| {
            let upcoming_id = net.next_neuron_id().as_usize();
            let subnetwork_inputs = vec![bias(), input(1), forward(3), recurrent(upcoming_id)];
            net.add_subnetwork(NeuronId::new(0), 1.0, subnetwork_inputs)
                .unwrap();
        },
        expected,
        2,
        NeuronId::new(5),
    );
}
/// Removes one non-neuron gene in four different networks, checking both the gene returned
/// by `remove_non_neuron` and the resulting genome, input count and next neuron ID.
#[test]
fn test_remove_non_neuron() {
    // Case 1: drop `input(3)` at index 1; one input is expected afterwards.
    run_mutation_test(
        vec![neuron(0, 2), input(3), input(0)],
        |net| {
            let removed = net.remove_non_neuron(1).unwrap();
            assert_eq!(input::<Gene<_>>(3), removed);
        },
        vec![neuron(0, 1), input(0)],
        1,
        NeuronId::new(1),
    );

    // Case 2: drop `input(0)` at index 3, leaving `input(1)`; two inputs are expected.
    run_mutation_test(
        vec![neuron(0, 1), neuron(1, 2), input(1), input(0)],
        |net| {
            let removed = net.remove_non_neuron(3).unwrap();
            assert_eq!(input::<Gene<_>>(0), removed);
        },
        vec![neuron(0, 1), neuron(1, 1), input(1)],
        2,
        NeuronId::new(2),
    );

    // Case 3: drop the recurrent jumper at index 3; four inputs are expected.
    run_mutation_test(
        vec![neuron(0, 2), neuron(1, 1), input(3), recurrent(0)],
        |net| {
            let removed = net.remove_non_neuron(3).unwrap();
            assert_eq!(recurrent::<Gene<_>>(0), removed);
        },
        vec![neuron(0, 1), neuron(1, 1), input(3)],
        4,
        NeuronId::new(2),
    );

    // Case 4: drop the only input gene; zero inputs are expected.
    run_mutation_test(
        vec![neuron(0, 3), input(0), neuron(1, 1), bias(), bias()],
        |net| {
            let removed = net.remove_non_neuron(1).unwrap();
            assert_eq!(input::<Gene<_>>(0), removed);
        },
        vec![neuron(0, 2), neuron(1, 1), bias(), bias()],
        0,
        NeuronId::new(2),
    );
}
/// Chains several additions and one removal on a minimal network and checks the final
/// genome, an input count of 2 and a next neuron ID of 3.
#[test]
fn test_multiple_mutations() {
    let start: Vec<Gene<f64>> = vec![neuron(0, 1), bias()];
    let expected: Vec<Gene<f64>> = vec![
        neuron(0, 6),
        forward(2),
        forward(1),
        neuron(1, 5),
        bias(),
        neuron(2, 2),
        recurrent(1),
        input(0),
        input(0),
        recurrent(0),
        bias(),
        input(1),
        input(0),
        bias(),
    ];

    run_mutation_test(
        start,
        |net| {
            // Pins the output type of the generic gene constructors to `NonNeuronGene`.
            let typed = |g: NonNeuronGene<_>| g;
            let root = NeuronId::new(0);

            net.add_non_neuron(root, typed(input(0))).unwrap();
            net.add_non_neuron(root, typed(input(1))).unwrap();

            let first = net
                .add_subnetwork(root, 1.0, vec![recurrent(0), bias()])
                .unwrap();
            net.add_non_neuron(first, typed(input(0))).unwrap();

            let second = net
                .add_subnetwork(first, 1.0, vec![input(0), input(2)])
                .unwrap();

            // Locate the first input gene with ID 2 in the genome and remove it again.
            let input_2_index = net
                .genome()
                .iter()
                .position(|candidate| {
                    matches!(candidate, Gene::Input(found) if found.id() == InputId::new(2))
                })
                .unwrap();
            net.remove_non_neuron(input_2_index).unwrap();

            net.add_non_neuron(second, typed(recurrent(first.as_usize())))
                .unwrap();
            net.add_non_neuron(first, typed(bias())).unwrap();
            net.add_non_neuron(root, typed(forward(first.as_usize())))
                .unwrap();
            net.add_non_neuron(root, typed(forward(second.as_usize())))
                .unwrap();
        },
        expected,
        2,
        NeuronId::new(3),
    );
}
/// Checks `get_valid_removals` on a fixed genome and then removes genes until it reports
/// none left.
///
/// The initial set of removable indices must be `[1, 4, 7]`. Afterwards the first reported
/// index is removed repeatedly; each removal must succeed and the network must still
/// evaluate. Once no removals remain, the stored metadata is compared against a rebuild.
#[test]
fn test_get_valid_removals() {
    let genome = vec![
        neuron(0, 3),
        input(0),
        neuron(1, 1),
        neuron(2, 2),
        input(1),
        neuron(3, 1),
        bias(),
        forward(2),
    ];
    let mut net = Network::new(genome, Activation::Linear).unwrap();
    assert_eq!(
        vec![1, 4, 7],
        net.get_valid_removals().collect::<Vec<_>>()
    );
    loop {
        // Collect first so the borrow of `net` ends before the network is mutated.
        let removals = net.get_valid_removals().collect::<Vec<_>>();
        match removals.first() {
            Some(&index) => {
                net.remove_non_neuron(index).unwrap();
                // The genome only ever references inputs 0 and 1, so two values suffice.
                // A failed evaluation is reported instead of being silently discarded.
                assert!(net.evaluate(&[2.0, 3.0]).is_ok());
            }
            None => {
                check_mutated_metadata(&mut net);
                break;
            }
        }
    }
}
/// Checks `get_valid_forward_jumper_sources` for neuron 1 of a fixed genome.
///
/// The reported sources must be exactly neurons 3, 4 and 5. Adding a forward jumper from
/// each reported source to neuron 1 must succeed, and adding one from any other neuron ID
/// below the number of known neurons must fail.
#[test]
fn test_get_valid_forward_jumper_sources() {
    let genome = vec![
        neuron(0, 2),
        neuron(1, 1),
        bias(),
        neuron(2, 2),
        neuron(3, 1),
        neuron(4, 1),
        bias(),
        neuron(5, 1),
        bias(),
    ];
    let mut net = Network::new(genome, Activation::Linear).unwrap();
    let parent_id = NeuronId::new(1);
    let parent_depth = net[parent_id].depth();
    let valid_sources = net
        .get_valid_forward_jumper_sources(parent_depth)
        .collect::<Vec<_>>();

    // Exactly these three neurons are expected, in any order.
    for expected in [3, 4, 5].iter().map(|&raw| NeuronId::new(raw)) {
        assert!(valid_sources.contains(&expected));
    }
    assert_eq!(3, valid_sources.len());

    // Every reported source must be accepted as a forward jumper source.
    for source in &valid_sources {
        let jumper: NonNeuronGene<_> = forward(source.as_usize());
        assert!(net.add_non_neuron(parent_id, jumper).is_ok());
    }

    // Every other neuron ID must be rejected.
    let neuron_count = net.neuron_info.len();
    for raw_id in (0..neuron_count).filter(|&raw| !valid_sources.contains(&NeuronId::new(raw))) {
        let jumper: NonNeuronGene<_> = forward(raw_id);
        assert!(net.add_non_neuron(parent_id, jumper).is_err());
    }
}
fn get_random_weight() -> f64 {
(1e3 * rand::thread_rng().gen_range(-1.0f64..=1.0)).round() / 1e3
}
/// Builds a random non-neuron gene with a random weight.
///
/// The gene kind is chosen uniformly among bias, input (ID in `0..10`), forward jumper and
/// recurrent jumper. Jumper sources are picked from the network's current neuron IDs plus
/// its next neuron ID, so a jumper may point at a neuron that does not exist yet.
fn get_random_non_neuron(net: &mut Network<f64>) -> NonNeuronGene<f64> {
    let mut rng = rand::thread_rng();

    // Candidate jumper sources: all existing neurons and the ID of the next one.
    let mut source_ids = net.neuron_ids().collect::<Vec<_>>();
    source_ids.push(net.next_neuron_id());

    let kind = rng.gen_range(0i32..=3);
    let mut gene: NonNeuronGene<_> = match kind {
        0 => bias(),
        1 => input(rng.gen_range(0..10)),
        2 => {
            let picked = source_ids.choose(&mut rng).unwrap();
            forward(picked.as_usize())
        }
        3 => {
            let picked = source_ids.choose(&mut rng).unwrap();
            recurrent(picked.as_usize())
        }
        _ => unreachable!(),
    };

    // Overwrite the constructor's weight (or bias value) with a random one.
    let random_weight = get_random_weight();
    match &mut gene {
        NonNeuronGene::Bias(b) => *b.mut_value() = random_weight,
        NonNeuronGene::Input(i) => *i.mut_weight() = random_weight,
        NonNeuronGene::ForwardJumper(f) => *f.mut_weight() = random_weight,
        NonNeuronGene::RecurrentJumper(r) => *r.mut_weight() = random_weight,
    }
    gene
}
/// Tries to add one random non-neuron gene under `parent`.
///
/// The outcome of the mutation is deliberately ignored.
fn add_random_non_neuron(net: &mut Network<f64>, parent: NeuronId) {
    let gene = get_random_non_neuron(net);
    let _ = net.add_non_neuron(parent, gene);
}
/// Tries to add between zero and two random non-neuron genes under `parent` in one call.
///
/// The outcome of the mutation is deliberately ignored.
fn add_random_non_neurons(net: &mut Network<f64>, parent: NeuronId) {
    let mut rng = rand::thread_rng();
    let count = rng.gen_range(0..=2);

    let mut genes = Vec::new();
    for _ in 0..count {
        genes.push(get_random_non_neuron(net));
    }

    let _ = net.add_non_neurons(parent, genes);
}
/// Tries to add a subnetwork with a random weight and between zero and three random
/// non-neuron input genes under `parent`.
///
/// The outcome of the mutation is deliberately ignored.
fn add_random_subnetwork(net: &mut Network<f64>, parent: NeuronId) {
    let mut rng = rand::thread_rng();
    let num_inputs = rng.gen_range(0..=3);

    let mut inputs = Vec::new();
    for _ in 0..num_inputs {
        inputs.push(get_random_non_neuron(net));
    }

    // The weight is drawn after the input genes have been generated.
    let weight = get_random_weight();
    let _ = net.add_subnetwork(parent, weight, inputs);
}
/// Tries to remove the gene at a random index.
///
/// The index is drawn from `0..=genome.len()`, so it may address a neuron or be one past
/// the end of the genome. The outcome of the removal is deliberately ignored.
fn remove_random_gene(net: &mut Network<f64>) {
    let mut rng = rand::thread_rng();
    let genome_len = net.genome().len();
    let index = (0..=genome_len).choose(&mut rng).unwrap();
    let _ = net.remove_non_neuron(index);
}
/// Grows a network from `initial` through random mutations and checks it stays consistent.
///
/// `MUTATION_COUNT` random mutations are attempted. After each one the stored metadata is
/// compared against a rebuild, the network must evaluate successfully with ten inputs, and
/// its state is cleared. Afterwards the output count must be unchanged, the network must
/// survive a round trip through its string encoding, and it is finally written to a file
/// under the path returned by `get_file_path("test_output", ...)`.
fn build_random_network(initial: Vec<Gene<f64>>) {
    const MUTATION_COUNT: usize = 200;

    let mut network = Network::new(initial, Activation::Linear).unwrap();
    let initial_outputs = network.num_outputs();
    let mut rng = rand::thread_rng();

    for _ in 0..MUTATION_COUNT {
        // The range is inclusive of the next neuron ID, so the parent may not exist.
        let max_parent = network.next_neuron_id().as_usize();
        let parent = NeuronId::new((0..=max_parent).choose(&mut rng).unwrap());

        let mutation = rng.gen_range(0..=3);
        match mutation {
            0 => add_random_non_neuron(&mut network, parent),
            1 => add_random_non_neurons(&mut network, parent),
            2 => add_random_subnetwork(&mut network, parent),
            3 => remove_random_gene(&mut network),
            _ => unreachable!(),
        }

        check_mutated_metadata(&mut network);
        let evaluated = network.evaluate(&[1.0; 10]);
        assert!(evaluated.is_ok());
        network.clear_state();
    }
    assert_eq!(initial_outputs, network.num_outputs());

    // Round trip through the string encoding, including the recurrent state.
    let encoded = network
        .to_string(Metadata::new(None), (), WithRecurrentState(true))
        .unwrap();
    let (reloaded, _, _) =
        Network::<f64>::load_str::<()>(&encoded, WithRecurrentState(true)).unwrap();

    // Clear the stack and state before comparing with the reloaded network.
    network.stack.clear();
    network.clear_state();
    assert_eq!(reloaded, network);

    network.evaluate(&[1.0; 10]).unwrap();

    let file_name = format!("random_{}_output_network.cge", network.num_outputs());
    let path = get_file_path("test_output", &file_name);
    // NOTE(review): the final `true` presumably permits overwriting an existing file; confirm.
    network
        .to_file(
            Metadata::new(Some("A randomly-generated network.".into())),
            (),
            WithRecurrentState(true),
            path,
            true,
        )
        .unwrap();
}
/// Runs `build_random_network` ten times each for starting networks of one, two and three
/// neurons, where neuron `i` has the single input gene `input(i)`.
#[test]
fn test_random() {
    const ROUNDS: usize = 10;
    const MAX_NEURONS: usize = 3;

    for _ in 0..ROUNDS {
        // Grow the starting genome by one neuron/input pair per step and test each size.
        let mut genome: Vec<Gene<f64>> = Vec::new();
        for id in 0..MAX_NEURONS {
            genome.push(neuron(id, 1));
            genome.push(input(id));
            build_random_network(genome.clone());
        }
    }
}
}