use std::{
collections::HashSet,
hash::{Hash, Hasher},
};
use crate::{
genes::{Connection, Genes, Id, Node},
parameters::Structure,
};
use rand::{rngs::SmallRng, seq::SliceRandom, thread_rng, Rng, SeedableRng};
use seahash::SeaHasher;
use serde::{Deserialize, Serialize};
mod compatibility_distance;
pub use compatibility_distance::CompatibilityDistance;
/// A genome made of node genes grouped by role and connection genes grouped by kind.
///
/// Equality, hashing, cloning and (de)serialization are all derived field by field.
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Genome {
/// Input node genes; `Genome::new` creates them with `Node::input`.
pub inputs: Genes<Node>,
/// Hidden node genes; left at their default by `Genome::new`.
pub hidden: Genes<Node>,
/// Output node genes; `Genome::new` creates them with `Node::output`.
pub outputs: Genes<Node>,
/// Feed-forward connection genes; the only connections `would_form_cycle` walks.
pub feed_forward: Genes<Connection>,
/// Recurrent connection genes; kept apart from `feed_forward`.
pub recurrent: Genes<Connection>,
}
impl Genome {
pub fn new(structure: &Structure) -> Self {
let mut seed_hasher = SeaHasher::new();
structure.number_of_inputs.hash(&mut seed_hasher);
structure.number_of_outputs.hash(&mut seed_hasher);
structure.seed.hash(&mut seed_hasher);
let mut rng = SmallRng::seed_from_u64(seed_hasher.finish());
Genome {
inputs: (0..structure.number_of_inputs)
.map(|order| Node::input(Id(rng.gen::<u64>()), order))
.collect(),
outputs: (0..structure.number_of_outputs)
.map(|order| {
Node::output(Id(rng.gen::<u64>()), order, structure.outputs_activation)
})
.collect(),
..Default::default()
}
}
pub fn nodes(&self) -> impl Iterator<Item = &Node> {
self.inputs
.iter()
.chain(self.hidden.iter())
.chain(self.outputs.iter())
}
/// Reports whether any of the three node sets contains a node matching `id`.
///
/// The lookup uses a throwaway input node built from `id` with order `0`.
/// NOTE(review): this presumably relies on `Node` equality/hashing looking
/// only at the id — confirm against the `Node` implementation.
pub fn contains(&self, id: Id) -> bool {
    let probe = Node::input(id, 0);
    // Checked in the order inputs, hidden, outputs; stops at the first hit.
    [&self.inputs, &self.hidden, &self.outputs]
        .iter()
        .any(|genes| genes.contains(&probe))
}
/// Iterates over every connection gene: feed-forward ones first, then recurrent ones.
pub fn connections(&self) -> impl Iterator<Item = &Connection> {
    let feed_forward = self.feed_forward.iter();
    let recurrent = self.recurrent.iter();
    feed_forward.chain(recurrent)
}
/// Adds the initial feed-forward connections to the genome.
///
/// The inputs are shuffled and the first
/// `ceil(percent_of_connected_inputs * number_of_inputs)` of them are each
/// connected to every output. Weights are obtained from
/// `Connection::weight_perturbation(0.0, 0.1, rng)`.
///
/// # Panics
///
/// Panics if the thread-local generator cannot seed the `SmallRng`, or if
/// `feed_forward.insert` returns `false` for one of the new connections
/// (presumably because an equal connection is already present — verify
/// against `Genes::insert`).
pub fn init(&mut self, structure: &Structure) {
    let rng = &mut SmallRng::from_rng(thread_rng()).unwrap();

    let connected_count = (structure.percent_of_connected_inputs
        * structure.number_of_inputs as f64)
        .ceil() as usize;

    let mut shuffled_inputs: Vec<_> = self.inputs.iter().collect();
    shuffled_inputs.shuffle(rng);

    for input in shuffled_inputs.into_iter().take(connected_count) {
        for output in self.outputs.iter() {
            let weight = Connection::weight_perturbation(0.0, 0.1, rng);
            let newly_added = self
                .feed_forward
                .insert(Connection::new(input.id, weight, output.id));
            assert!(newly_added);
        }
    }
}
/// Returns the total number of connection genes, feed-forward plus recurrent.
pub fn len(&self) -> usize {
    let feed_forward_count = self.feed_forward.len();
    let recurrent_count = self.recurrent.len();
    feed_forward_count + recurrent_count
}
/// Returns `true` when the genome has neither feed-forward nor recurrent connections.
pub fn is_empty(&self) -> bool {
    if !self.feed_forward.is_empty() {
        return false;
    }
    self.recurrent.is_empty()
}
/// Produces an offspring genome by crossing `other` into `self`.
///
/// Feed-forward connections, recurrent connections and hidden nodes are each
/// combined with the corresponding `Genes::cross_in`. Input and output nodes
/// are cloned from `self` unchanged.
///
/// # Panics
///
/// Panics if the thread-local generator cannot seed the `SmallRng`.
pub fn cross_in(&self, other: &Self) -> Self {
    let mut rng = SmallRng::from_rng(thread_rng()).unwrap();

    // All three crossings draw from the same generator, so their order is
    // kept: feed-forward, recurrent, hidden.
    let feed_forward = self.feed_forward.cross_in(&other.feed_forward, &mut rng);
    let recurrent = self.recurrent.cross_in(&other.recurrent, &mut rng);
    let hidden = self.hidden.cross_in(&other.hidden, &mut rng);

    // Every field is given explicitly, so no `..Default::default()` is needed;
    // it would only build and immediately drop a default genome.
    Genome {
        inputs: self.inputs.clone(),
        hidden,
        outputs: self.outputs.clone(),
        feed_forward,
        recurrent,
    }
}
/// Checks whether a feed-forward connection from `start_node` to `end_node`
/// would close a cycle.
///
/// Performs a depth-first walk over the existing feed-forward connections,
/// starting at `end_node`, and returns `true` as soon as a connection leading
/// back into `start_node` is found. Recurrent connections are not followed.
///
/// NOTE(review): when `start_node` and `end_node` are the same node this
/// returns `true` only if a path back to that node already exists; callers
/// that must reject a new self-connection need their own check.
pub fn would_form_cycle(&self, start_node: &Node, end_node: &Node) -> bool {
    let mut to_visit = vec![end_node.id];
    let mut visited = HashSet::new();

    while let Some(node) = to_visit.pop() {
        // `insert` returns `false` for an id that was already expanded, so a
        // separate `contains` lookup is unnecessary.
        if !visited.insert(node) {
            continue;
        }
        for connection in self
            .feed_forward
            .iter()
            .filter(|connection| connection.input == node)
        {
            if connection.output == start_node.id {
                return true;
            }
            to_visit.push(connection.output);
        }
    }

    false
}
/// Returns `true` if some connection (feed-forward or recurrent) ends in
/// `node` and starts at a node other than `exclude`.
pub fn has_alternative_input(&self, node: Id, exclude: Id) -> bool {
    self.connections()
        .any(|connection| connection.output == node && connection.input != exclude)
}
/// Returns `true` if some connection (feed-forward or recurrent) starts at
/// `node` and ends in a node other than `exclude`.
pub fn has_alternative_output(&self, node: Id, exclude: Id) -> bool {
    self.connections()
        .any(|connection| connection.input == node && connection.output != exclude)
}
pub fn dot(genome: &Self) -> String {
let mut dot = "digraph {\n".to_owned();
dot.push_str("\tgraph [splines=curved ranksep=8]\n");
dot.push_str("\tsubgraph cluster_inputs {\n");
dot.push_str("\t\tgraph [label=\"Inputs\"]\n");
dot.push_str("\t\tnode [color=\"#D6B656\", fillcolor=\"#FFF2CC\", style=\"filled\"]\n");
dot.push_str("\n");
for node in genome.inputs.iter() {
dot.push_str(&format!(
"\t\t{} [label={:?}];\n",
node.id.0, node.activation
));
}
dot.push_str("\t}\n");
dot.push_str("\tsubgraph hidden {\n");
dot.push_str("\t\tgraph [label=\"Hidden\" rank=\"same\"]\n");
dot.push_str("\t\tnode [color=\"#6C8EBF\", fillcolor=\"#DAE8FC\", style=\"filled\"]\n");
dot.push_str("\n");
for node in genome.hidden.iter() {
dot.push_str(&format!(
"\t\t{} [label={:?}];\n",
node.id.0, node.activation
));
}
dot.push_str("\t}\n");
dot.push_str("\tsubgraph cluster_outputs {\n");
dot.push_str("\t\tgraph [label=\"Outputs\" labelloc=\"b\"]\n");
dot.push_str("\t\tnode [color=\"#9673A6\", fillcolor=\"#E1D5E7\", style=\"filled\"]\n");
dot.push_str("\n");
for node in genome.outputs.iter() {
dot.push_str(&format!(
"\t\t{} [label={:?}];\n",
node.id.0, node.activation
));
}
dot.push_str("\t}\n");
dot.push_str("\n");
dot.push_str("\tsubgraph feedforward_connections {\n");
dot.push_str("\n");
for connection in genome.feed_forward.iter() {
dot.push_str(&format!(
"\t\t{0} -> {1} [label=\"\" arrowsize={3:?} penwidth={3:?} tooltip={2:?} labeltooltip={2:?}];\n",
connection.input.0,
connection.output.0,
connection.weight,
connection.weight.abs() * 0.95 + 0.05
));
}
dot.push_str("\t}\n");
dot.push_str("\tsubgraph recurrent_connections {\n");
dot.push_str("\t\tedge [color=\"#FF8000\"]\n");
dot.push_str("\n");
for connection in genome.recurrent.iter() {
dot.push_str(&format!(
"\t\t{0} -> {1} [label=\"\" arrowsize={3:?} penwidth={3:?} tooltip={2:?} labeltooltip={2:?}];\n",
connection.input.0,
connection.output.0,
connection.weight,
connection.weight.abs() * 0.95 + 0.05
));
}
dot.push_str("\t}\n");
dot.push_str("}\n");
dot
}
}
#[cfg(test)]
mod tests {
use std::hash::{Hash, Hasher};
use rand::thread_rng;
use seahash::SeaHasher;
use super::Genome;
use crate::{
genes::{Activation, Connection, Genes, Id, Node},
Mutations, Parameters, Structure,
};
/// Two inputs feed the single output, so excluding one of them still leaves
/// the other as an alternative input.
#[test]
fn find_alternative_input() {
    let inputs = vec![Node::input(Id(0), 0), Node::input(Id(1), 1)];
    let outputs = vec![Node::output(Id(2), 0, Activation::Linear)];
    let feed_forward = vec![
        Connection::new(Id(0), 1.0, Id(2)),
        Connection::new(Id(1), 1.0, Id(2)),
    ];

    let genome = Genome {
        inputs: Genes(inputs.into_iter().collect()),
        outputs: Genes(outputs.into_iter().collect()),
        feed_forward: Genes(feed_forward.into_iter().collect()),
        ..Default::default()
    };

    assert!(genome.has_alternative_input(Id(2), Id(1)));
}
/// The output's only incoming connection comes from the excluded node, so no
/// alternative input exists.
#[test]
fn find_no_alternative_input() {
    let inputs = vec![Node::input(Id(0), 0)];
    let outputs = vec![Node::output(Id(1), 0, Activation::Linear)];
    let feed_forward = vec![Connection::new(Id(0), 1.0, Id(1))];

    let genome = Genome {
        inputs: Genes(inputs.into_iter().collect()),
        outputs: Genes(outputs.into_iter().collect()),
        feed_forward: Genes(feed_forward.into_iter().collect()),
        ..Default::default()
    };

    assert!(!genome.has_alternative_input(Id(1), Id(0)));
}
/// The input feeds two outputs, so excluding one of them still leaves the
/// other as an alternative output.
#[test]
fn find_alternative_output() {
    let inputs = vec![Node::input(Id(0), 0)];
    let outputs = vec![
        Node::output(Id(2), 0, Activation::Linear),
        Node::output(Id(1), 0, Activation::Linear),
    ];
    let feed_forward = vec![
        Connection::new(Id(0), 1.0, Id(1)),
        Connection::new(Id(0), 1.0, Id(2)),
    ];

    let genome = Genome {
        inputs: Genes(inputs.into_iter().collect()),
        outputs: Genes(outputs.into_iter().collect()),
        feed_forward: Genes(feed_forward.into_iter().collect()),
        ..Default::default()
    };

    assert!(genome.has_alternative_output(Id(0), Id(1)));
}
/// The input's only outgoing connection goes to the excluded node, so no
/// alternative output exists.
#[test]
fn find_no_alternative_output() {
    let inputs = vec![Node::input(Id(0), 0)];
    let outputs = vec![Node::output(Id(1), 0, Activation::Linear)];
    let feed_forward = vec![Connection::new(Id(0), 1.0, Id(1))];

    let genome = Genome {
        inputs: Genes(inputs.into_iter().collect()),
        outputs: Genes(outputs.into_iter().collect()),
        feed_forward: Genes(feed_forward.into_iter().collect()),
        ..Default::default()
    };

    assert!(!genome.has_alternative_output(Id(0), Id(1)));
}
/// Crossing a genome that has two added nodes into one that has a single
/// added node must yield an offspring with one hidden node and three
/// feed-forward connections.
#[test]
fn crossover() {
    let parameters = Parameters::default();
    let mut genome_0 = Genome::initialized(&parameters);
    let mut genome_1 = Genome::initialized(&parameters);
    let rng = &mut thread_rng();

    // `genome_0` gets one extra node, `genome_1` gets two.
    Mutations::add_node(&Activation::all(), &mut genome_0, rng);
    Mutations::add_node(&Activation::all(), &mut genome_1, rng);
    Mutations::add_node(&Activation::all(), &mut genome_1, rng);

    let offspring = genome_0.cross_in(&genome_1);

    assert_eq!(offspring.hidden.len(), 1);
    assert_eq!(offspring.feed_forward.len(), 3);
}
/// A connection from an input to an output of a freshly initialized genome
/// must not be reported as closing a cycle.
#[test]
fn detect_no_cycle() {
    let parameters = Parameters::default();
    let genome = Genome::initialized(&parameters);

    let input = genome.inputs.iter().next().unwrap();
    let output = genome.outputs.iter().next().unwrap();

    // `input` and `output` are already references; no extra borrow is needed.
    assert!(!genome.would_form_cycle(input, output));
}
/// A connection from an output back to an input of a freshly initialized
/// genome must be reported as closing a cycle.
#[test]
fn detect_cycle() {
    let parameters = Parameters::default();
    let genome = Genome::initialized(&parameters);

    let input = genome.inputs.iter().next().unwrap();
    let output = genome.outputs.iter().next().unwrap();

    // `input` and `output` are already references; no extra borrow is needed.
    assert!(genome.would_form_cycle(output, input));
}
/// The parents differ only in the direction of the connection between the two
/// hidden nodes; the offspring must never contain a connection together with
/// its reverse.
#[test]
fn crossover_no_cycle() {
    let inputs = vec![Node::input(Id(0), 0)];
    let outputs = vec![Node::output(Id(1), 0, Activation::Linear)];
    let hidden = vec![
        Node::hidden(Id(2), Activation::Tanh),
        Node::hidden(Id(3), Activation::Tanh),
    ];
    let shared_connections = vec![
        Connection::new(Id(0), 1.0, Id(2)),
        Connection::new(Id(2), 1.0, Id(1)),
        Connection::new(Id(0), 1.0, Id(3)),
        Connection::new(Id(3), 1.0, Id(1)),
    ];

    let mut genome_0 = Genome {
        inputs: Genes(inputs.into_iter().collect()),
        outputs: Genes(outputs.into_iter().collect()),
        hidden: Genes(hidden.into_iter().collect()),
        feed_forward: Genes(shared_connections.into_iter().collect()),
        ..Default::default()
    };
    // Clone before the parents diverge.
    let mut genome_1 = genome_0.clone();

    genome_0
        .feed_forward
        .insert(Connection::new(Id(2), 1.0, Id(3)));
    genome_1
        .feed_forward
        .insert(Connection::new(Id(3), 1.0, Id(2)));

    let offspring = genome_0.cross_in(&genome_1);

    for forward in offspring.feed_forward.iter() {
        let has_reverse = offspring.feed_forward.iter().any(|candidate| {
            forward.input == candidate.output && forward.output == candidate.input
        });
        assert!(!has_reverse);
    }
}
/// Two genomes built from the same genes, with the inputs supplied in a
/// different order, must compare equal and hash to the same value.
#[test]
fn hash_genome() {
    /// Hashes `genome` with a fresh `SeaHasher`.
    fn digest(genome: &Genome) -> u64 {
        let mut hasher = SeaHasher::new();
        genome.hash(&mut hasher);
        hasher.finish()
    }

    // Everything except the order of `inputs` is identical between the two genomes.
    let build = |inputs: Vec<Node>| Genome {
        inputs: Genes(inputs.into_iter().collect()),
        outputs: Genes(
            vec![Node::output(Id(2), 0, Activation::Linear)]
                .into_iter()
                .collect(),
        ),
        feed_forward: Genes(
            vec![Connection::new(Id(0), 1.0, Id(1))]
                .into_iter()
                .collect(),
        ),
        ..Default::default()
    };

    let genome_0 = build(vec![Node::input(Id(1), 0), Node::input(Id(0), 0)]);
    let genome_1 = build(vec![Node::input(Id(0), 0), Node::input(Id(1), 0)]);

    assert_eq!(genome_0, genome_1);

    let genome_0_hash = digest(&genome_0);
    let genome_1_hash = digest(&genome_1);
    assert_eq!(genome_0_hash, genome_1_hash);
}
/// Renders a small genome (one input, one hidden node, one output, two
/// feed-forward connections and one recurrent connection) and compares the
/// result with the expected DOT text.
#[test]
fn create_dot_from_genome() {
    let genome = Genome {
        inputs: Genes(vec![Node::input(Id(0), 0)].iter().cloned().collect()),
        outputs: Genes(
            vec![Node::output(Id(1), 0, Activation::Linear)]
                .iter()
                .cloned()
                .collect(),
        ),
        hidden: Genes(
            vec![Node::hidden(Id(2), Activation::Tanh)]
                .iter()
                .cloned()
                .collect(),
        ),
        feed_forward: Genes(
            vec![
                Connection::new(Id(0), 0.25795942718883524, Id(2)),
                Connection::new(Id(2), -0.09736946507786626, Id(1)),
            ]
            .iter()
            .cloned()
            .collect(),
        ),
        recurrent: Genes(
            vec![Connection::new(Id(1), 0.19777863112749228, Id(2))]
                .iter()
                .cloned()
                .collect(),
        ),
        ..Default::default()
    };

    // Every line break is spelled out as `\n`, including the blank lines that
    // `Genome::dot` emits after each `node [...]` / `edge [...]` attribute
    // line, after the outputs cluster and after the feed-forward subgraph
    // header, so the expected text cannot lose them to whitespace stripping.
    let dot = concat!(
        "digraph {\n",
        "\tgraph [splines=curved ranksep=8]\n",
        "\tsubgraph cluster_inputs {\n",
        "\t\tgraph [label=\"Inputs\"]\n",
        "\t\tnode [color=\"#D6B656\", fillcolor=\"#FFF2CC\", style=\"filled\"]\n",
        "\n",
        "\t\t0 [label=Linear];\n",
        "\t}\n",
        "\tsubgraph hidden {\n",
        "\t\tgraph [label=\"Hidden\" rank=\"same\"]\n",
        "\t\tnode [color=\"#6C8EBF\", fillcolor=\"#DAE8FC\", style=\"filled\"]\n",
        "\n",
        "\t\t2 [label=Tanh];\n",
        "\t}\n",
        "\tsubgraph cluster_outputs {\n",
        "\t\tgraph [label=\"Outputs\" labelloc=\"b\"]\n",
        "\t\tnode [color=\"#9673A6\", fillcolor=\"#E1D5E7\", style=\"filled\"]\n",
        "\n",
        "\t\t1 [label=Linear];\n",
        "\t}\n",
        "\n",
        "\tsubgraph feedforward_connections {\n",
        "\n",
        "\t\t0 -> 2 [label=\"\" arrowsize=0.29506145582939347 penwidth=0.29506145582939347 tooltip=0.25795942718883524 labeltooltip=0.25795942718883524];\n",
        "\t\t2 -> 1 [label=\"\" arrowsize=0.14250099182397294 penwidth=0.14250099182397294 tooltip=-0.09736946507786626 labeltooltip=-0.09736946507786626];\n",
        "\t}\n",
        "\tsubgraph recurrent_connections {\n",
        "\t\tedge [color=\"#FF8000\"]\n",
        "\n",
        "\t\t1 -> 2 [label=\"\" arrowsize=0.23788969957111766 penwidth=0.23788969957111766 tooltip=0.19777863112749228 labeltooltip=0.19777863112749228];\n",
        "\t}\n",
        "}\n",
    );

    assert_eq!(Genome::dot(&genome), dot);
}
/// Smoke test: mutates a 10-input / 10-output genome 1000 times and prints
/// its DOT rendering. It has no assertions; it only has to run to completion.
#[test]
fn print_big_dot() {
    // The same pool is used for activation changes and for newly added nodes.
    let activation_pool = vec![
        Activation::Linear,
        Activation::Sigmoid,
        Activation::Tanh,
        Activation::Gaussian,
        Activation::Step,
        Activation::Sine,
        Activation::Cosine,
        Activation::Inverse,
        Activation::Absolute,
        Activation::Relu,
    ];

    let parameters = Parameters {
        structure: Structure {
            number_of_inputs: 10,
            number_of_outputs: 10,
            percent_of_connected_inputs: 0.2,
            ..Default::default()
        },
        mutations: vec![
            Mutations::ChangeWeights {
                chance: 1.0,
                percent_perturbed: 0.5,
                standard_deviation: 0.1,
            },
            Mutations::ChangeActivation {
                chance: 0.05,
                activation_pool: activation_pool.clone(),
            },
            Mutations::AddNode {
                chance: 0.005,
                activation_pool,
            },
            Mutations::AddConnection { chance: 0.01 },
            Mutations::AddRecurrentConnection { chance: 0.01 },
        ],
    };

    let mut genome = Genome::initialized(&parameters);
    for _ in 0..1000 {
        genome.mutate(&parameters);
    }

    print!("{}", Genome::dot(&genome));
}
}