use rand::{seq::SliceRandom, Rng};
use crate::{genes::Connection, genome::Genome};
use super::{MutationError, MutationResult, Mutations};
impl Mutations {
    /// Tries to add one new feed-forward connection to `genome`.
    ///
    /// Start candidates are the genome's input and hidden nodes; end candidates are its
    /// hidden and output nodes. Both candidate lists are shuffled with `rng` (start list
    /// first, then end list) and scanned in that shuffled order. The first pair that
    /// passes every check below is connected and the function returns immediately:
    ///
    /// * the two nodes compare as different,
    /// * `genome.feed_forward` does not already contain a connection between their ids,
    /// * `genome.would_form_cycle` reports `false` for the pair.
    ///
    /// # Errors
    ///
    /// Returns [`MutationError::CouldNotAddFeedForwardConnection`] when no pair of
    /// candidates passes all checks.
    ///
    /// # Panics
    ///
    /// Panics if `genome.feed_forward.insert` returns `false` for the chosen pair.
    pub fn add_connection(genome: &mut Genome, rng: &mut impl Rng) -> MutationResult {
        // Keep this shuffle order (start list, then end list): it fixes how `rng` is consumed.
        let mut sources: Vec<_> = genome.inputs.iter().chain(genome.hidden.iter()).collect();
        sources.shuffle(rng);

        let mut targets: Vec<_> = genome.hidden.iter().chain(genome.outputs.iter()).collect();
        targets.shuffle(rng);

        for start_node in sources {
            for &end_node in &targets {
                // Hidden nodes appear in both lists; never pair a node with itself.
                if end_node == start_node {
                    continue;
                }

                // The probe carries a placeholder weight of 0.0.
                // NOTE(review): presumably membership ignores the weight — confirm against
                // `Connection`'s equality/hash implementation.
                let probe = Connection::new(start_node.id, 0.0, end_node.id);
                if genome.feed_forward.contains(&probe) {
                    continue;
                }

                if genome.would_form_cycle(start_node, end_node) {
                    continue;
                }

                // The membership check above just reported "absent", so the insert is
                // expected to report a fresh entry.
                assert!(genome.feed_forward.insert(Connection::new(
                    start_node.id,
                    Connection::weight_perturbation(0.0, 0.1, rng),
                    end_node.id,
                )));
                return Ok(());
            }
        }

        Err(MutationError::CouldNotAddFeedForwardConnection)
    }
}
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use crate::{Genome, MutationError, Mutations, Parameters};

    /// One `add_connection` call on an uninitialized default genome succeeds and leaves
    /// exactly one feed-forward connection behind.
    #[test]
    fn add_random_connection() {
        let mut rng = thread_rng();
        let mut genome = Genome::uninitialized(&Parameters::default());

        let outcome = Mutations::add_connection(&mut genome, &mut rng);

        assert!(outcome.is_ok());
        assert_eq!(genome.feed_forward.len(), 1);
    }

    /// After one successful `add_connection`, a second call on the same default genome is
    /// expected to fail with `CouldNotAddFeedForwardConnection` and add nothing.
    #[test]
    fn dont_add_same_connection_twice() {
        let mut rng = thread_rng();
        let mut genome = Genome::uninitialized(&Parameters::default());

        Mutations::add_connection(&mut genome, &mut rng).expect("add_connection");

        match Mutations::add_connection(&mut genome, &mut rng) {
            Err(error) => assert_eq!(error, MutationError::CouldNotAddFeedForwardConnection),
            // The test treats a second success as impossible.
            Ok(_) => unreachable!(),
        }

        assert_eq!(genome.feed_forward.len(), 1);
    }
}