//! Example: build a 1x1x2 `Genome`, wire its connections by hand, and propagate an input through it.
use perestroika::{
errors::PerestroikaError, genome::Genome, genome_builder::GenomeBuilder, node_gene::DepthType,
};
use rand_chacha::ChaCha8Rng;
/// Walks through hand-wiring a small `Genome` and pushing a single input through it.
///
/// The walkthrough:
/// 1. builds a genome shaped 1x1x2;
/// 2. connects the input node to the hidden node, then shows three `create_connection` calls
///    that are rejected (duplicate connection, deeper-to-shallower, same layer);
/// 3. connects the hidden node to both output nodes, prints the genome, propagates `0.5`
///    and checks the two output values.
///
/// # Errors
///
/// Propagates the `PerestroikaError` of any builder, connection or propagation call that is
/// expected to succeed but does not.
fn main() -> Result<(), PerestroikaError> {
    // Shape 1x1x2: one input node, one hidden node and two output nodes.
    let shape = vec![1, 1, 2];
    let mut genome: Genome<ChaCha8Rng> = GenomeBuilder::new().with_shape(&shape)?.build()?;

    // First mutation: connect node 0 (input) to node 1 (hidden) with a weight of 0.666.
    genome.create_connection(0, 1, 0.666, true, 1, false)?;

    // Requesting the same 0 -> 1 connection a second time is refused, even though the
    // remaining arguments differ from the first call.
    assert!(matches!(
        genome.create_connection(0, 1, 0.999, false, 9000, false),
        Err(PerestroikaError::ConnectionGeneAlreadyExists)
    ));

    // A connection has to run from a shallower node to a deeper one.
    // Hidden (node 1) -> input (node 0) points the wrong way:
    assert!(matches!(
        genome.create_connection(1, 0, 0.0, false, 0, false),
        Err(PerestroikaError::ConnectionGeneSourceDeeperThanTarget {
            source: DepthType::Hidden(1),
            target: DepthType::Input,
        })
    ));

    // ...and two nodes of the same layer (both outputs here) cannot be linked either:
    assert!(matches!(
        genome.create_connection(2, 3, 0.0, false, 0, false),
        Err(PerestroikaError::ConnectionGeneSourceAndTargetSameDepth {
            depth: DepthType::Output
        })
    ));

    // Wire the hidden node to each output node: weight 0.5 to node 2, weight 0.1 to node 3.
    genome.create_connection(1, 2, 0.5, true, 1, false)?;
    genome.create_connection(1, 3, 0.1, true, 2, false)?;

    // The resulting topology:
    //
    //              .-> (2)
    //  (0) -> (1) <
    //              '-> (3)
    println!("{genome:#?}");

    // Feed a single input value through the network.
    let inputs = vec![0.5];
    let p = genome.propagate(&inputs)?;

    // Every node uses an Identity activation function, so a node hands on exactly what it
    // receives; only the connection weights change the value.
    //
    //   node 0 (input):  receives 0.5, activation leaves it at 0.5.
    //   node 1 (hidden): pulls 0.5 * 0.666 = 0.333, activation leaves it at 0.333.
    //   node 2 (output): pulls 0.333 * 0.5 = 0.1665.
    //   node 3 (output): pulls 0.333 * 0.1 = 0.0333.
    println!("Propagation results: {p:?}.");

    // NOTE(review): this compares floating-point values for exact equality; confirm the
    // arithmetic is bit-exact for the numeric type `propagate` returns, or switch to a
    // tolerance-based comparison.
    let expected = vec![0.1665, 0.0333];
    assert_eq!(p, expected);

    Ok(())
}