/// Builds a value with `Default::default()` and then assigns the listed fields.
///
/// - `new_t!(Ty, a = 1, b = 2)` builds a `Ty`.
/// - `new_t!(a = 1, b = 2)` builds a `T`, where `T` is whatever that name refers
///   to at the call site (e.g. a type parameter or alias in scope there).
///
/// A trailing comma is accepted in both forms.
#[macro_export]
macro_rules! new_t {
    ($t:ty, $($k:ident = $v:expr),+ $(,)?) => {{
        let mut c = <$t>::default();
        $(c.$k = $v;)+
        c
    }};
    // Fully qualify the recursive call so the shorthand also works when the caller
    // reached this macro by path (e.g. `some_crate::new_t!`) without importing it.
    ($($k:ident = $v:expr),+ $(,)?) => {
        $crate::new_t!(T, $($k = $v,)+)
    };
}
/// Asserts that two `f64` values differ by less than `f64::EPSILON`.
///
/// Each operand is evaluated exactly once (like `assert_eq!`), so side-effecting
/// expressions are safe. An optional third argument is appended to the panic message.
///
/// NOTE(review): `f64::EPSILON` is used as an *absolute* tolerance, so for operands
/// whose magnitude is well above 1 this is effectively an exact-equality check.
#[macro_export]
macro_rules! assert_f64_approx {
    ($l:expr, $r:expr) => {
        // Bind by reference once; the message then reuses the same values.
        match (&$l, &$r) {
            (l, r) => assert!(
                // `f64::abs(..)` (not `.abs()`) pins the difference to f64, which also
                // lets two bare float literals type-check.
                f64::abs(*l - *r) < f64::EPSILON,
                "assertion failed: {} !~ {}",
                l,
                r
            ),
        }
    };
    ($l:expr, $r:expr, $msg:expr) => {
        match (&$l, &$r) {
            (l, r) => assert!(
                f64::abs(*l - *r) < f64::EPSILON,
                "assertion failed: {} !~ {}: {}",
                l,
                r,
                $msg
            ),
        }
    };
}
/// Asserts that two sequences of `f64` have the same length and are element-wise
/// approximately equal (via `assert_f64_approx!`).
///
/// Each operand is evaluated exactly once. It needs `len()` and an `iter()` yielding
/// `f64` references (e.g. `Vec<f64>`, `[f64; N]`, `&[f64]`). On mismatch the panic
/// message names the first differing index.
#[macro_export]
macro_rules! assert_matrix_approx {
    ($a:expr, $b:expr) => {{
        // Bind once so side-effecting or expensive operands are not re-evaluated,
        // and wrap in a block so the macro works in expression position too.
        match (&$a, &$b) {
            (a, b) => {
                assert_eq!(a.len(), b.len(), "Matrices have different lengths");
                for (i, (l, r)) in a.iter().zip(b.iter()).enumerate() {
                    $crate::assert_f64_approx!(l, r, format!("differs at [{i}]"));
                }
            }
        }
    }};
}
/// Takes ownership of a value, applies each `{.method(..)}` step to it in order,
/// and yields the mutated value.
///
/// Example: `normalized!(v; {.sort()} {.dedup()})` sorts, then dedups, `v`.
#[macro_export]
macro_rules! normalized {
    ($value:expr; $({ . $($step:tt)+ })+) => {{
        let mut normalized_value = $value;
        $(
            normalized_value.$($step)+;
        )+
        normalized_value
    }};
}
/// Asserts that `$l` equals at least one candidate in `[...]` once both sides have
/// been passed through the same `{.method(..)}` steps (see `normalized!`).
///
/// Each side is copied with `to_owned()` before normalizing. An optional trailing
/// message replaces the default panic text, which is `"{:?} not in {:?}"` formatted
/// from the un-normalized value and candidates.
///
/// Example: `assert_some_normalized!(v, [a, b]; {.sort()} {.dedup()});`
///
/// NOTE(review): the default message re-evaluates `$l` and the candidates (only on
/// failure), so prefer side-effect-free operands. An empty candidate list likely
/// fails type inference.
#[macro_export]
macro_rules! assert_some_normalized {
    ($l:expr, [$($r:expr),* $(,)?]; $({.$($norm:tt)+})+, $msg:expr) => {{
        let target = $crate::normalized!($l.to_owned(); $({.$($norm)+})+);
        let found = [$($r,)*].into_iter().any(|candidate| {
            target == $crate::normalized!(candidate.to_owned(); $({.$($norm)+})+)
        });
        assert!(found, "{}", $msg)
    }};
    ($l:expr, [$($r:expr),* $(,)?]; $({.$($norm:tt)+})+) => {
        $crate::assert_some_normalized!(
            $l,
            [$($r),*];
            $({.$($norm)+})+,
            format!("{:?} not in {:?}", $l, [$($r,)*])
        )
    };
}
/// Generates `mutate_param` and `param_diff` for a type with one `f64` field per
/// listed parameter.
///
/// `mutate_param!([Weight, Bias]: [p_weight, p_bias])` declares a local
/// `ParamEvent { Weight, Bias }` (via `events!`) and expects on `Self`:
/// - fields named after each event via `paste`'s `:lower` (`weight`, `bias`),
///   which are assigned `f64` values;
/// - `PARAM_REPLACE_PROBABILITY: u64` (compared against `next_u64()`),
///   `PARAM_STD` (std-dev of the sampled normal), and `PARAM_PERTURB_FAC`.
///
/// There must be exactly one `u64` entry in the second list per event; it is
/// handed to `EventKind::pick`. Presumably these are relative weights — confirm
/// against `pick`.
///
/// NOTE(review): `:lower` turns `ResponseRate` into `responserate`, not
/// `response_rate`; `:snake` may be intended for multi-word parameters.
#[macro_export]
macro_rules! mutate_param {
    ([$($evt:ident),+]: [$($prob:expr),+]) => {
        ::paste::paste! {
            /// Picks one parameter (if any) and either replaces it with a fresh
            /// N(0, PARAM_STD) sample or perturbs it by PARAM_PERTURB_FAC times that sample.
            fn mutate_param(&mut self, rng: &mut impl rand::RngCore) {
                use $crate::random::EventKind;
                use rand::Rng;
                $crate::events!(Param[$($evt),*]);
                // A length mismatch between events and probabilities is a compile error.
                const PARAM_PROBABILITIES: [u64; ParamEvent::COUNT] = [$($prob),*];
                if let Some(evt) = ParamEvent::pick(rng, PARAM_PROBABILITIES) {
                    // Assuming `next_u64` is uniform, this is true with probability
                    // PARAM_REPLACE_PROBABILITY / 2^64.
                    let replace = rng.next_u64() < Self::PARAM_REPLACE_PROBABILITY;
                    let v: f64 = rng.sample(
                        rand_distr::Normal::new(0., Self::PARAM_STD)
                            .expect("PARAM_STD must be positive"),
                    );
                    match evt {
                        // `events!` declares each variant exactly as `$evt`, so match on
                        // `$evt` itself rather than a re-cased identifier.
                        $(ParamEvent::$evt => {
                            self.[<$evt:lower>] = if replace {
                                v
                            } else {
                                self.[<$evt:lower>] + Self::PARAM_PERTURB_FAC * v
                            };
                        })*
                    }
                }
            }

            /// Sum of absolute per-parameter differences between `self` and `other`.
            ///
            /// Absolute values keep opposing deltas from cancelling, and make the
            /// result symmetric and non-negative.
            fn param_diff(&self, other: &Self) -> f64 {
                [$((self.[<$evt:lower>] - other.[<$evt:lower>]).abs()),*].iter().sum()
            }
        }
    };
}
/// Counts a non-empty, comma-separated list of identifiers at compile time.
///
/// Expands to an integer-literal sum (`1 + 1 + ...`), so the resulting type is
/// inferred from the use site (e.g. `usize` when used as an array length).
#[macro_export]
macro_rules! count {
    ($_:ident $(, $rest:ident)*) => {
        // Each trailing identifier re-enters this arm alone and contributes `1`.
        1 $(+ $crate::count!($rest))*
    };
}
/// Declares consecutive `const` items of type `$t`, numbered from 0 in order.
///
/// `iota!(usize, A, B, C)` declares `A = 0`, `B = 0 + 1`, `C = 0 + 1 + 1`.
/// A trailing comma is optional, and an empty list (`iota!(usize,)`) declares nothing.
#[macro_export]
macro_rules! iota {
    // Emit the head constant, then recurse on the tail. The tail's paired values are
    // placeholders: every remaining name is re-seeded with `$value + 1`, and only the
    // first of them is consumed as the next `$value`.
    (@inner $t:ty, $name:ident, $value:expr, $($rest:ident, $_placeholder:expr),*) => {
        const $name: $t = $value;
        $crate::iota!(@inner $t, $($rest, $value + 1),*);
    };
    (@inner $t:ty, $name:ident, $value:expr) => {
        const $name: $t = $value;
    };
    // No names at all (only reachable from an empty public invocation).
    (@inner $t:ty $(,)?) => {};
    // Names are comma-separated; the previous pattern `$($name:ident,)*` demanded a
    // comma after every name, so `iota!(usize, A, B)` failed to match.
    ($t:ty, $($name:ident),* $(,)?) => {
        $crate::iota!(@inner $t, $($name, 0),*);
    };
}
/// Declares `pub enum <Scope>Event` with the listed variants and implements
/// `$crate::random::EventKind` for it.
///
/// `events!(Param[Weight, Bias])` expands to `pub enum ParamEvent { Weight, Bias }`.
/// - `COUNT` is the number of variants.
/// - `variants()` returns them in declaration order.
/// - `idx()` returns each variant's 0-based position in that order.
///
/// A trailing comma inside the brackets is accepted.
#[macro_export]
macro_rules! events {
    ($scope:ident[$($evt:ident),+ $(,)?]) => {
        ::paste::paste! {
            #[derive(Debug, Clone, Copy)]
            pub enum [<$scope Event>] {
                $($evt,)*
            }
            impl $crate::random::EventKind for [<$scope Event>] {
                const COUNT: usize = $crate::count!($($evt),+);
                fn variants() -> [Self; Self::COUNT] {
                    [$(Self::$evt),*]
                }
                fn idx(&self) -> usize {
                    // Fieldless enum with no explicit discriminants: they are 0, 1, 2, ...
                    // in declaration order, i.e. exactly the variant's index.
                    *self as usize
                }
            }
        }
    };
}
// Compile-and-smoke tests for the `fn_matrix!` proc macro from `eevee_macros`.
//
// Judging by the names called in `test_matrix_expansion`, `fn_matrix!` emits one
// copy of the template function per combination of the listed alternatives. Each
// copy is named `<fn>_<type names lowercased, in declaration order>`, and generic
// arguments are dropped from the name (`Recurrent<WConnection>` -> `recurrent`).
// NOTE(review): confirm this naming scheme against `eevee_macros` itself.
#[cfg(test)]
mod test {
use crate::genome::{connection::BWConnection, Genome, Recurrent, WConnection};
use crate::network::{Continuous, FromGenome, NonBias};
use eevee_macros::fn_matrix;
// Test 1: Simple non-parametrized types
// `G` has a single concrete choice; `NN` has two, so two functions are expected.
fn_matrix! {
G: Recurrent<WConnection>,
NN: Continuous | NonBias,
fn test_func() -> NN {
let (g, _) = G::new(2, 2);
NN::from_genome(&g)
}
}
// Test 2: Parametrized generic with forward reference (C substituted into Recurrent<C>)
// `C` has two choices that are substituted into `G`, so two functions are expected.
fn_matrix! {
C: WConnection | BWConnection,
G: Recurrent<C>,
NN: Continuous,
fn test_func_param() -> NN {
let (g, _) = G::new(2, 2);
NN::from_genome(&g)
}
}
// Calls every expected expansion. The test passes if each generated function exists
// (i.e. compiles) and runs without panicking; the returned networks are discarded.
#[test]
fn test_matrix_expansion() {
// Test 1: Simple non-parametrized types (G with concrete type, NN generic)
let _ = test_func_recurrent_continuous();
let _ = test_func_recurrent_nonbias();
// Test 2: Parametrized generic (C substituted into G: Recurrent<C>, NN concrete)
let _ = test_func_param_wconnection_recurrent_continuous();
let _ = test_func_param_bwconnection_recurrent_continuous();
}
}