use crate::serialize::SerializeFile;
use nalgebra as na;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Identifier string exposed as `SerializeFile::SERIALIZER_ID` by the blanket
/// serde_json-backed implementation in this file.
const SERIALIZER_ID: &str = "json-1";
/// Blanket [`SerializeFile`] implementation backed by `serde_json`.
///
/// Any type that is `Serialize` and can be deserialized for every lifetime
/// gets JSON string conversion through this impl.
impl<T> SerializeFile for T
where
    T: Serialize + for<'de> Deserialize<'de>,
{
    const SERIALIZER_ID: &'static str = SERIALIZER_ID;

    /// Encodes `self` as a JSON string; serde_json failures are boxed.
    fn to_str(&self) -> Result<String, Box<dyn std::error::Error>> {
        serde_json::to_string(self).map_err(Into::into)
    }

    /// Decodes a value from the JSON text `s`; serde_json failures are boxed.
    fn from_str(s: &str) -> Result<Self, Box<dyn std::error::Error>> {
        let value = serde_json::from_str(s)?;
        Ok(value)
    }
}
/// Serializes a dynamic matrix as a flat sequence of `u64` values.
///
/// Each element of the matrix's backing slice is written as its raw bit
/// pattern (`f64::to_bits`), in slice order. The matrix shape is not
/// recorded; the deserializer chosen for the field has to reconstruct it.
fn serialize_matrix<S: Serializer>(matrix: &na::DMatrix<f64>, ser: S) -> Result<S::Ok, S::Error> {
    let raw: Vec<u64> = matrix
        .as_slice()
        .iter()
        .copied()
        .map(f64::to_bits)
        .collect();
    raw.serialize(ser)
}
/// Reads a sequence of `u64` bit patterns and rebuilds it as a `1 x N` matrix.
///
/// Every value is converted back with `f64::from_bits`; `N` is the number of
/// values in the sequence.
fn deserialize_matrix_flat<'de, D: Deserializer<'de>>(de: D) -> Result<na::DMatrix<f64>, D::Error> {
    let raw = Vec::<u64>::deserialize(de)?;
    let values: Vec<f64> = raw.into_iter().map(f64::from_bits).collect();
    let columns = values.len();
    Ok(na::DMatrix::from_vec(1, columns, values))
}
fn deserialize_matrix_square<'de, D: Deserializer<'de>>(
de: D,
) -> Result<na::DMatrix<f64>, D::Error> {
Vec::<u64>::deserialize(de).map(|v| {
let float_data: Vec<f64> = v.into_iter().map(f64::from_bits).collect();
let n = (float_data.len() as f64).sqrt() as usize;
debug_assert_eq!(n * n, float_data.len(), "non-square weight vec");
na::DMatrix::from_vec(n, n, float_data)
})
}
fn deserialize_connections<'de, C: crate::Connection + Deserialize<'de>, D: Deserializer<'de>>(
de: D,
) -> Result<Vec<C>, D::Error> {
Vec::<C>::deserialize(de)
}
// Generates hand-written serde `Serialize` / `Deserialize` impls for a type
// that is imported from elsewhere in the crate.
//
// Invocation shape:
//   - optional attributes to place on the generated module,
//   - a `use <path>;` that brings the type into the generated module,
//   - the type name followed by a brace-enclosed list of `field: Type` pairs,
//     each optionally carrying attributes (e.g. `#[serde(...)]`),
//   - optional extra items that are emitted verbatim inside the module.
//
// Expansion: a private module named after the type in snake_case (built with
// `paste`) containing
//   - `Ref`, deriving `Serialize`, holding a reference to every listed field,
//   - `Data`, deriving `Deserialize`, owning every listed field,
//   - `Serialize` / `Deserialize` impls for the type that go through `Ref`
//     and `Data` respectively.
// Field attributes are forwarded to both helper structs. The type is rebuilt
// with a struct literal made of exactly the listed fields, so the list has to
// name every field of the type.
macro_rules! json_impl {
// Arm 1: non-generic type.
(
$(#[$mod_attr:meta])*
use $use_path:path;
$Type:ident {
$($(#[$attr:meta])* $field:ident : $ftype:ty),* $(,)?
}
$($extra:item)*
) => {
::paste::paste! {
$(#[$mod_attr])*
mod [<$Type:snake>] {
use super::*;
use $use_path;
// Borrowed mirror of the listed fields, used only for serialization.
#[derive(Serialize)]
struct Ref<'a> {
$($(#[$attr])* $field: &'a $ftype,)*
}
// Owned mirror of the listed fields, used only for deserialization.
#[derive(Deserialize)]
struct Data {
$($(#[$attr])* $field: $ftype,)*
}
impl Serialize for $Type {
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
// Borrow each field into `Ref` and let the derived impl do the work.
Ref { $($field: &self.$field,)* }.serialize(s)
}
}
impl<'de> Deserialize<'de> for $Type {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
// Parse into `Data`, then move each field into the real type.
let v = Data::deserialize(d)?;
Ok($Type { $($field: v.$field,)* })
}
}
$($extra)*
}
}
};
// Arm 2: type with exactly one generic parameter carrying a single path
// bound. The generated impls additionally require the parameter to be
// `Serialize` (for serialization) or `for<'de2> Deserialize<'de2>` (for
// deserialization).
(
$(#[$mod_attr:meta])*
use $use_path:path;
$Type:ident < $GP:ident : $Bound:path > {
$($(#[$attr:meta])* $field:ident : $ftype:ty),* $(,)?
}
$($extra:item)*
) => {
::paste::paste! {
$(#[$mod_attr])*
mod [<$Type:snake>] {
use super::*;
use $use_path;
// Borrowed mirror of the listed fields, used only for serialization.
#[derive(Serialize)]
struct Ref<'a, $GP: $Bound + Serialize> {
$($(#[$attr])* $field: &'a $ftype,)*
}
// Owned mirror of the listed fields, used only for deserialization.
#[derive(Deserialize)]
struct Data<$GP: $Bound + for<'de2> Deserialize<'de2>> {
$($(#[$attr])* $field: $ftype,)*
}
impl<$GP: $Bound + Serialize> Serialize for $Type<$GP> {
fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
// Borrow each field into `Ref` and let the derived impl do the work.
Ref { $($field: &self.$field,)* }.serialize(s)
}
}
impl<'de, $GP: $Bound + for<'de2> Deserialize<'de2>> Deserialize<'de>
for $Type<$GP>
{
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
// Parse into `Data`, then move each field into the real type.
let v = Data::deserialize(d)?;
Ok($Type { $($field: v.$field,)* })
}
}
$($extra)*
}
}
};
}
// Serde support for `Continuous`.
// `y`, `θ` and `τ` are written as sequences of f64 bit patterns and read back
// as 1 x N matrices; `w` is written the same way but read back as a square
// matrix. `sensory` and `action` use serde's default handling for tuples.
json_impl! {
use crate::network::Continuous;
Continuous {
#[serde(serialize_with = "serialize_matrix", deserialize_with = "deserialize_matrix_flat")]
y: na::DMatrix<f64>,
#[serde(serialize_with = "serialize_matrix", deserialize_with = "deserialize_matrix_flat")]
θ: na::DMatrix<f64>,
#[serde(serialize_with = "serialize_matrix", deserialize_with = "deserialize_matrix_flat")]
τ: na::DMatrix<f64>,
#[serde(serialize_with = "serialize_matrix", deserialize_with = "deserialize_matrix_square")]
w: na::DMatrix<f64>,
sensory: (usize, usize),
action: (usize, usize),
}
}
// Serde support for `NonBias`.
// `y` is written as a sequence of f64 bit patterns and read back as a 1 x N
// matrix; `w` is written the same way but read back as a square matrix.
// `sensory` and `action` use serde's default handling for tuples.
json_impl! {
use crate::network::NonBias;
NonBias {
#[serde(serialize_with = "serialize_matrix", deserialize_with = "deserialize_matrix_flat")]
y: na::DMatrix<f64>,
#[serde(serialize_with = "serialize_matrix", deserialize_with = "deserialize_matrix_square")]
w: na::DMatrix<f64>,
sensory: (usize, usize),
action: (usize, usize),
}
}
/// Serde support for `Simple<C>`, written by hand rather than through
/// `json_impl!` because only part of the struct is stored.
///
/// Only `connections` and `bias` are serialized. On deserialization `state`
/// is zero-filled to the length of `bias`, and `sensory` / `action` are set
/// to the empty range `0..0`.
mod simple {
    use super::*;
    use crate::network::Simple;

    /// Borrowed view of the serialized fields.
    #[derive(Serialize)]
    struct Ref<'a, C: crate::Connection + Serialize> {
        connections: &'a Vec<C>,
        bias: &'a Vec<f64>,
    }

    /// Owned form of the serialized fields.
    #[derive(Deserialize)]
    struct Data<C: crate::Connection + for<'de2> Deserialize<'de2>> {
        #[serde(deserialize_with = "deserialize_connections")]
        connections: Vec<C>,
        bias: Vec<f64>,
    }

    impl<C> Serialize for Simple<C>
    where
        C: crate::Connection + Serialize,
    {
        fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
            let view = Ref {
                connections: &self.connections,
                bias: &self.bias,
            };
            view.serialize(s)
        }
    }

    impl<'de, C> Deserialize<'de> for Simple<C>
    where
        C: crate::Connection + for<'de2> Deserialize<'de2>,
    {
        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
            let Data { connections, bias } = Data::deserialize(d)?;
            // Fields that are not stored get fixed defaults.
            let state = vec![0.; bias.len()];
            Ok(Simple {
                connections,
                bias,
                state,
                sensory: 0..0,
                action: 0..0,
            })
        }
    }
}
// Serde support for `WConnection`; every listed field uses serde's default
// handling for its type.
json_impl! {
use crate::genome::connection::WConnection;
WConnection {
inno: usize,
from: usize,
to: usize,
weight: f64,
enabled: bool,
}
}
// Serde support for `BWConnection`; every listed field uses serde's default
// handling for its type.
json_impl! {
use crate::genome::connection::BWConnection;
BWConnection {
inno: usize,
from: usize,
to: usize,
bias: f64,
weight: f64,
enabled: bool,
}
}
// Serde support for `Recurrent<C>`, using the generic arm of `json_impl!`.
// `connections` is read through `deserialize_connections`; the remaining
// fields use serde's default handling for `usize`.
json_impl! {
use crate::genome::recurrent::Recurrent;
Recurrent<C: crate::Connection> {
sensory: usize,
action: usize,
node_count: usize,
#[serde(deserialize_with = "deserialize_connections")]
connections: Vec<C>,
}
}
#[cfg(test)]
mod test {
    use crate::{
        activate, assert_matrix_approx,
        network::{Continuous, Network},
        random::default_rng,
        SerializeFile,
    };
    use nalgebra as na;
    use rand_distr::{Distribution, Uniform};

    /// A `Continuous` network that went through `to_str` / `from_str` must
    /// produce the same outputs as the original when both are stepped with
    /// identical inputs.
    #[test]
    fn test_ctrnn_behavioral_equivalence() {
        let n = 10;
        let mut rng = default_rng();
        let dist = Uniform::new(-10f64, 10.).unwrap();
        // Pulls `count` samples from `dist`, advancing the shared RNG.
        let mut draw = |count: usize| -> Vec<f64> {
            (0..count).map(|_| dist.sample(&mut rng)).collect()
        };

        let y_data = draw(n);
        let θ_data = draw(n);
        // Shift the magnitudes so every τ entry is at least 0.1.
        let τ_data: Vec<f64> = draw(n).into_iter().map(|t| t.abs() + 0.1).collect();
        let w_data = draw(n * n);

        let mut original = Continuous {
            y: na::DMatrix::from_row_slice(1, n, &y_data),
            θ: na::DMatrix::from_row_slice(1, n, &θ_data),
            τ: na::DMatrix::from_row_slice(1, n, &τ_data),
            w: na::DMatrix::from_row_slice(n, n, &w_data),
            sensory: (0, 2),
            action: (3, 5),
        };

        let serialized = original.to_str().expect("Failed to serialize");
        let mut deserialized = Continuous::from_str(&serialized).expect("Failed to deserialize");

        for _ in 0..500 {
            let input = draw(2);
            original.step(10, &input, activate::steep_sigmoid);
            deserialized.step(10, &input, activate::steep_sigmoid);
            assert_matrix_approx!(original.output(), deserialized.output());
        }
    }
}