use super::{FromGenome, Recurrent, Stateful};
use crate::{Connection, Genome, Network};
use nalgebra as na;
/// Continuous-time recurrent neural network (CTRNN) advanced by forward-Euler integration.
///
/// Per-neuron quantities (`y`, `θ`, `τ`) are 1×n row matrices and `w` is n×n, where n is
/// the number of neurons; this is how `from_genome` builds them and the shape `step`
/// relies on when copying `y` into its 1×n scratch buffers.
///
/// The serde derives below are compiled only when `serialize` is enabled and
/// `serialize_json` is not. NOTE(review): the JSON round-trip tests in this file run
/// under `serialize_json`, so that path presumably uses a separate impl — confirm.
#[derive(Debug)]
#[cfg_attr(
    all(feature = "serialize", not(feature = "serialize_json")),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct Continuous {
    /// Neuron states (1×n); advanced in place by `step` and zeroed by `flush`.
    #[cfg_attr(
        all(feature = "serialize", not(feature = "serialize_json")),
        serde(with = "crate::serialize::dmatrix")
    )]
    pub y: na::DMatrix<f64>,
    /// Per-neuron bias (1×n) added to the state before the activation function.
    #[cfg_attr(
        all(feature = "serialize", not(feature = "serialize_json")),
        serde(with = "crate::serialize::dmatrix")
    )]
    pub θ: na::DMatrix<f64>,
    /// Per-neuron factor (1×n) that *multiplies* each state derivative in `step`.
    /// Larger values therefore mean faster dynamics, i.e. this acts as the reciprocal
    /// of the textbook CTRNN time constant. `from_genome` initializes every entry to 1.
    #[cfg_attr(
        all(feature = "serialize", not(feature = "serialize_json")),
        serde(with = "crate::serialize::dmatrix")
    )]
    pub τ: na::DMatrix<f64>,
    /// Connection weights (n×n): `w[(i, j)]` is the weight from neuron `i` to neuron `j`,
    /// since `step` forms the row-vector product `σ(y + θ) · w`.
    #[cfg_attr(
        all(feature = "serialize", not(feature = "serialize_json")),
        serde(with = "crate::serialize::dmatrix")
    )]
    pub w: na::DMatrix<f64>,
    /// Half-open index range `(start, end)` of the neurons that receive the `input`
    /// passed to `step`.
    pub sensory: (usize, usize),
    /// Half-open index range `(start, end)` of the neurons whose states `output` returns.
    pub action: (usize, usize),
}
impl Network for Continuous {
    /// Advances the neuron states `y` by `prec` forward-Euler substeps of size `1 / prec`.
    ///
    /// Each substep applies, over the 1×n row vector of states,
    ///
    /// ```text
    /// y += τ ⊙ (σ(y + θ) · w − y + I) / prec
    /// ```
    ///
    /// where `⊙` is the element-wise product and `I` is `input` written into the sensory
    /// index range (zero for every other neuron). `τ` multiplies the derivative, so
    /// larger entries make a neuron respond faster.
    ///
    /// With `prec == 0` no substep runs and the state is left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `input.len()` differs from the width of the sensory range, or if that
    /// range does not lie within the neuron count.
    fn step<F: Fn(f64) -> f64>(&mut self, prec: usize, input: &[f64], σ: F) {
        let n = self.y.ncols();
        let sensory = self.sensory.0..self.sensory.1;
        // Fail with a descriptive message rather than `copy_from_slice`'s generic one.
        assert_eq!(
            input.len(),
            sensory.len(),
            "input length must match the number of sensory neurons"
        );
        // External input as a full-width row: zero everywhere except the sensory neurons.
        let mut m_input = na::DMatrix::zeros(1, n);
        m_input.as_mut_slice()[sensory].copy_from_slice(input);
        let inv = 1. / (prec as f64);
        // Scratch rows reused across substeps so the loop does not allocate.
        let mut temp1 = na::DMatrix::zeros(1, n);
        let mut temp2 = na::DMatrix::zeros(1, n);
        for _ in 0..prec {
            // temp1 = σ(y + θ)
            temp1.copy_from(&self.y);
            temp1 += &self.θ;
            for val in temp1.iter_mut() {
                *val = σ(*val);
            }
            // temp2 = σ(y + θ) · w  (beta = 0, so temp2's previous contents are ignored)
            temp2.gemm(1.0, &temp1, &self.w, 0.0);
            temp2 -= &self.y;
            temp2 += &m_input;
            temp2.component_mul_assign(&self.τ);
            temp2 *= inv;
            self.y += &temp2;
        }
    }

    /// Resets every neuron state in `y` to zero; `θ`, `τ`, `w` and the ranges are untouched.
    fn flush(&mut self) {
        // Zero in place instead of allocating a replacement matrix of the same shape.
        self.y.fill(0.0);
    }

    /// Returns the current states of the action neurons, `y[action.0..action.1]`.
    fn output(&self) -> &[f64] {
        &self.y.as_slice()[self.action.0..self.action.1]
    }
}
// Empty impls: no trait methods are overridden here, so any behavior comes from the
// traits' own default definitions. `Continuous` does keep state in `y` across `step`
// calls (reset by `flush`) and allows arbitrary, including cyclic, connections through
// its full n×n `w`. NOTE(review): presumably these traits tag exactly those properties —
// confirm against their definitions.
impl Recurrent for Continuous {}
impl Stateful for Continuous {}
impl<C: Connection, G: Genome<C>> FromGenome<C, G> for Continuous {
    /// Builds a network with one neuron per genome node.
    ///
    /// Every enabled connection `from → to` sets `w[(from, to)]` to its weight. If
    /// several enabled connections share the same endpoints, the last one in
    /// `genome.connections()` order wins. States `y` and biases `θ` start at zero, and
    /// every `τ` entry starts at one. The sensory and action ranges are copied from
    /// the genome.
    ///
    /// # Panics
    ///
    /// Panics if an enabled connection references a node index `>= genome.node_count()`.
    fn from_genome(genome: &G) -> Self {
        let n = genome.node_count();
        // Write straight into the matrix. 2-D indexing bounds-checks row and column
        // separately, so an out-of-range `to` cannot silently land in the next row
        // (which a flattened `from * n + to` index would allow).
        let mut w = na::DMatrix::<f64>::zeros(n, n);
        for c in genome.connections().iter().filter(|c| c.enabled()) {
            w[(c.from(), c.to())] = c.weight();
        }
        let sensory = genome.sensory();
        let action = genome.action();
        Self {
            y: na::DMatrix::zeros(1, n),
            θ: na::DMatrix::zeros(1, n),
            τ: na::DMatrix::from_element(1, n, 1.0),
            w,
            sensory: (sensory.start, sensory.end),
            action: (action.start, action.end),
        }
    }
}
// Tests for `Continuous`:
// - JSON serialization round-trips, compiled only with the `serialize_json` feature;
// - the weight and range mapping performed by `from_genome`.
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        activate, assert_f64_approx, assert_matrix_approx,
        genome::{self, InnoGen, WConnection},
        random::default_rng,
    };
    // NOTE(review): `Float` appears to be used only for `.abs()` in the
    // `serialize_json`-gated tests; without that feature it may be reported as an
    // unused import — confirm.
    use rand_distr::{num_traits::Float, Distribution, Uniform};
    /// Builds a network with random states, biases, rates and weights, round-trips it
    /// through `to_str`/`from_str`, and checks every field survives. Matrices are
    /// compared approximately; the ranges are compared exactly.
    #[test]
    #[cfg(feature = "serialize_json")]
    fn test_ctrnn_serialization_deserialization() {
        use crate::serialize::SerializeFile;
        let n_neurons = 10;
        let mut rng = default_rng();
        let dist = Uniform::new(-10., 10.).unwrap();
        let mut y_data = vec![0.0; n_neurons];
        let mut theta_data = vec![0.0; n_neurons];
        let mut tau_data = vec![0.0; n_neurons];
        let mut w_data = vec![0.0; n_neurons * n_neurons];
        for i in 0..n_neurons {
            y_data[i] = dist.sample(&mut rng);
            theta_data[i] = dist.sample(&mut rng);
            // Keep every τ strictly positive (at least 0.1).
            tau_data[i] = dist.sample(&mut rng).abs() + 0.1;
            for j in 0..n_neurons {
                w_data[i * n_neurons + j] = dist.sample(&mut rng);
            }
        }
        let original = Continuous {
            y: na::DMatrix::from_row_slice(1, n_neurons, &y_data),
            θ: na::DMatrix::from_row_slice(1, n_neurons, &theta_data),
            τ: na::DMatrix::from_row_slice(1, n_neurons, &tau_data),
            w: na::DMatrix::from_row_slice(n_neurons, n_neurons, &w_data),
            sensory: (0, 2),
            action: (3, 5),
        };
        let serialized = original.to_str().expect("Failed to serialize");
        let deserialized = Continuous::from_str(&serialized).expect("Failed to deserialize");
        assert_matrix_approx!(original.y.as_slice(), deserialized.y.as_slice());
        assert_matrix_approx!(original.θ.as_slice(), deserialized.θ.as_slice());
        assert_matrix_approx!(original.τ.as_slice(), deserialized.τ.as_slice());
        assert_matrix_approx!(original.w.as_slice(), deserialized.w.as_slice());
        assert_eq!(original.sensory, deserialized.sensory);
        assert_eq!(original.action, deserialized.action);
    }
    /// Round-trips a random network, then drives the original and the deserialized
    /// copy with identical random sensory input for 500 steps. After every step their
    /// action outputs must be approximately equal.
    #[test]
    #[cfg(feature = "serialize_json")]
    fn test_ctrnn_behavioral_equivalence() {
        use crate::serialize::SerializeFile;
        let n_neurons = 10;
        let mut rng = default_rng();
        let dist = Uniform::new(-10., 10.).unwrap();
        let mut y_data = vec![0.0; n_neurons];
        let mut θ_data = vec![0.0; n_neurons];
        let mut τ_data = vec![0.0; n_neurons];
        let mut w_data = vec![0.0; n_neurons * n_neurons];
        for i in 0..n_neurons {
            y_data[i] = dist.sample(&mut rng);
            θ_data[i] = dist.sample(&mut rng);
            // Keep every τ strictly positive (at least 0.1).
            τ_data[i] = dist.sample(&mut rng).abs() + 0.1;
            for j in 0..n_neurons {
                w_data[i * n_neurons + j] = dist.sample(&mut rng);
            }
        }
        let mut original = Continuous {
            y: na::DMatrix::from_row_slice(1, n_neurons, &y_data),
            θ: na::DMatrix::from_row_slice(1, n_neurons, &θ_data),
            τ: na::DMatrix::from_row_slice(1, n_neurons, &τ_data),
            w: na::DMatrix::from_row_slice(n_neurons, n_neurons, &w_data),
            sensory: (0, 2),
            action: (3, 5),
        };
        let mut deserialized =
            Continuous::from_str(&original.to_str().expect("Failed to serialize"))
                .expect("Failed to deserialize");
        let precision = 10;
        let n_steps = 500;
        for __ in 0..n_steps {
            // Two inputs, matching the sensory range (0, 2).
            let input: Vec<f64> = (0..2).map(|_| dist.sample(&mut rng)).collect();
            original.step(precision, &input, activate::steep_sigmoid);
            deserialized.step(precision, &input, activate::steep_sigmoid);
            let original_output = original.output();
            let deserialized_output = deserialized.output();
            assert_matrix_approx!(original_output, deserialized_output);
        }
    }
    /// Checks that `from_genome` copies every enabled connection weight to
    /// `w[(from, to)]`, leaves every bias θ at zero, and copies the genome's sensory
    /// and action ranges.
    #[test]
    fn test_from_genome() {
        type C = WConnection;
        let mut inno = InnoGen::new(0);
        // NOTE(review): presumably `new(2, 2)` means 2 sensory and 2 action nodes — confirm.
        let (mut genome, _) = genome::Recurrent::<C>::new(2, 2);
        genome.push_connection(C::new(0, 3, &mut inno));
        genome.push_connection(C::new(0, 1, &mut inno));
        // Duplicate 0 → 1 connection. `from_genome` keeps the last enabled one, so the
        // per-connection check below only holds if the duplicates' weights agree or one
        // of them is disabled/dropped. NOTE(review): confirm how `C::new` /
        // `push_connection` handle this.
        genome.push_connection(C::new(0, 1, &mut inno));
        let nn = Continuous::from_genome(&genome);
        for c in genome.connections() {
            if c.enabled() {
                assert_f64_approx!(nn.w[(c.from(), c.to())], c.weight());
            }
        }
        for i in 0..genome.node_count() {
            assert_f64_approx!(nn.θ[(0, i)], 0.)
        }
        assert_eq!(
            (nn.sensory.0, nn.sensory.1),
            (genome.sensory().start, genome.sensory().end)
        );
        assert_eq!(
            (nn.action.0, nn.action.1),
            (genome.action().start, genome.action().end)
        );
    }
}