pub mod builtin;
use bitflags::bitflags;
use builtin::*;
use genetic_rs::prelude::{rand, RngExt};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use lazy_static::lazy_static;
use std::{
collections::HashMap,
fmt,
sync::{Arc, RwLock},
};
use crate::NeuronLocation;
/// Builds `ActivationFn` values from function paths.
///
/// Supported forms:
/// - `activation_fn!(path)`: wraps `path` with `NeuronScope::default()`.
/// - `activation_fn!(path, scope)`: wraps `path` with an explicit scope expression.
/// - `activation_fn! { a, b, c }`: an array, each element using the default scope.
/// - `activation_fn! { a => s1, b => s2 }`: an array with an explicit scope per function.
///
/// The `name` of every produced value is `stringify!` of the path exactly as written.
///
/// Note: arms are tried in order, so a comma list of exactly two items
/// (`a, b`) is matched by the `(path, scope)` form, not the list form.
#[macro_export]
macro_rules! activation_fn {
    ($F: path) => {
        $crate::activation::ActivationFn::new(
            ::std::sync::Arc::new($F),
            $crate::activation::NeuronScope::default(),
            stringify!($F),
        )
    };
    ($F: path, $S: expr) => {
        $crate::activation::ActivationFn::new(::std::sync::Arc::new($F), $S, stringify!($F))
    };
    // The recursive calls go through `$crate::` so they resolve even when the
    // caller invoked this macro by path without importing it into scope.
    {$($F: path),*} => {
        [$($crate::activation_fn!($F)),*]
    };
    {$($F: path => $S: expr),*} => {
        [$($crate::activation_fn!($F, $S)),*]
    };
}
lazy_static! {
/// Crate-wide registry of activation functions, shared behind a read/write lock.
///
/// Starts out as `ActivationRegistry::default()`. It is written to by
/// `register_activation` / `batch_register_activation` and read when an
/// `ActivationFn` is deserialized by name.
pub(crate) static ref ACTIVATION_REGISTRY: Arc<RwLock<ActivationRegistry>> = Arc::new(RwLock::new(ActivationRegistry::default()));
}
/// Adds a single activation function to the global [`ACTIVATION_REGISTRY`].
///
/// An existing entry with the same name is replaced.
///
/// # Panics
///
/// Panics if the registry lock is poisoned.
pub fn register_activation(act: ActivationFn) {
    // The write guard is a temporary and is released at the end of this statement.
    ACTIVATION_REGISTRY.write().unwrap().register(act);
}
/// Adds several activation functions to the global [`ACTIVATION_REGISTRY`]
/// while holding the write lock once for the whole batch.
///
/// # Panics
///
/// Panics if the registry lock is poisoned.
pub fn batch_register_activation(acts: impl IntoIterator<Item = ActivationFn>) {
    // The write guard is a temporary and is released at the end of this statement.
    ACTIVATION_REGISTRY.write().unwrap().batch_register(acts);
}
/// A collection of activation functions that can be looked up by name.
pub struct ActivationRegistry {
/// Registered activation functions, keyed by each function's `name`.
pub fns: HashMap<&'static str, ActivationFn>,
}
impl ActivationRegistry {
    /// Registers `activation` under its `name`, replacing any existing entry
    /// that has the same name.
    pub fn register(&mut self, activation: ActivationFn) {
        self.fns.insert(activation.name, activation);
    }

    /// Registers every activation yielded by `activations`, in iteration order.
    ///
    /// If several share a name, the last one wins.
    pub fn batch_register(&mut self, activations: impl IntoIterator<Item = ActivationFn>) {
        activations.into_iter().for_each(|act| self.register(act));
    }

    /// Returns a clone of every registered activation function.
    ///
    /// The order follows the underlying map's iteration order and is unspecified.
    pub fn activations(&self) -> Vec<ActivationFn> {
        self.fns.values().cloned().collect()
    }

    /// Returns the activation functions whose scope contains every flag set in `scope`.
    ///
    /// `NeuronScope::NONE` always yields an empty list.
    pub fn activations_in_scope(&self, scope: NeuronScope) -> Vec<ActivationFn> {
        if scope == NeuronScope::NONE {
            return Vec::new();
        }

        // Filter on the borrowed values first so only the matches are cloned.
        self.fns
            .values()
            .filter(|candidate| candidate.scope.contains(scope))
            .cloned()
            .collect()
    }

    /// Removes every registered activation function.
    pub fn clear(&mut self) {
        self.fns.clear();
    }

    /// Picks one activation function valid for `scope`, using `rng` to choose
    /// an index among the matches.
    ///
    /// # Panics
    ///
    /// Panics if no registered activation function matches `scope`
    /// (which is always the case for `NeuronScope::NONE`).
    pub fn random_activation_in_scope(
        &self,
        scope: NeuronScope,
        rng: &mut impl rand::Rng,
    ) -> ActivationFn {
        let mut candidates = self.activations_in_scope(scope);
        assert!(
            !candidates.is_empty(),
            "no activation functions registered for scope {:?}",
            scope
        );

        let picked = rng.random_range(0..candidates.len());
        // The Vec is already owned, so move the pick out rather than cloning it.
        candidates.swap_remove(picked)
    }
}
impl Default for ActivationRegistry {
fn default() -> Self {
let mut s = Self {
fns: HashMap::new(),
};
s.batch_register(activation_fn! {
sigmoid => NeuronScope::HIDDEN | NeuronScope::OUTPUT,
relu => NeuronScope::HIDDEN | NeuronScope::OUTPUT,
linear_activation => NeuronScope::INPUT | NeuronScope::HIDDEN | NeuronScope::OUTPUT,
f32::tanh => NeuronScope::HIDDEN | NeuronScope::OUTPUT
});
s
}
}
/// Something that can act as an activation function: maps one `f32` to another.
pub trait Activation {
/// Applies the activation to `n` and returns the result.
fn activate(&self, n: f32) -> f32;
}
/// Any `Fn(f32) -> f32` (plain function or closure) is usable as an [`Activation`].
impl<F> Activation for F
where
    F: Fn(f32) -> f32,
{
    /// Calls the wrapped function with `n`.
    fn activate(&self, n: f32) -> f32 {
        self(n)
    }
}
/// A named activation function together with the neuron scope it applies to.
///
/// Cloning shares the underlying function through the `Arc`.
#[derive(Clone)]
pub struct ActivationFn {
/// The function itself, shareable across threads.
pub func: Arc<dyn Activation + Send + Sync>,
/// The kinds of neurons this function may be used on.
pub scope: NeuronScope,
/// The identifying name: used as the registry key, for equality, for `Debug`
/// output, and as the serialized form.
pub name: &'static str,
}
impl ActivationFn {
pub fn new(
func: Arc<dyn Activation + Send + Sync>,
scope: NeuronScope,
name: &'static str,
) -> Self {
Self { func, name, scope }
}
}
impl fmt::Debug for ActivationFn {
    /// Writes only the activation's name; the function and scope are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name)
    }
}
impl PartialEq for ActivationFn {
    /// Two activation functions are equal when their names match; the
    /// function and the scope are not compared.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.name, other.name);
        lhs == rhs
    }
}
/// Serializes an activation function as its name string only; the function
/// and the scope are not written.
#[cfg(feature = "serde")]
impl Serialize for ActivationFn {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(self.name)
}
}
/// Deserializes an activation function from its name by looking it up in the
/// global [`ACTIVATION_REGISTRY`].
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for ActivationFn {
    /// Reads a string and returns a clone of the registered activation with
    /// that name, or a custom deserialization error if none is registered.
    ///
    /// # Panics
    ///
    /// Panics if the registry lock is poisoned.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let name = String::deserialize(deserializer)?;
        let registry = ACTIVATION_REGISTRY.read().unwrap();

        match registry.fns.get(name.as_str()) {
            Some(found) => Ok(found.clone()),
            None => Err(serde::de::Error::custom(format!(
                "Activation function {name} not found"
            ))),
        }
    }
}
bitflags! {
/// Bit set describing which kinds of neurons an activation function may be used on.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct NeuronScope: u8 {
/// Input-layer neurons.
const INPUT = 0b001;
/// Hidden-layer neurons.
const HIDDEN = 0b010;
/// Output-layer neurons.
const OUTPUT = 0b100;
/// No neurons at all (the empty set).
const NONE = 0b000;
}
}
impl Default for NeuronScope {
/// The default scope is `HIDDEN` only.
fn default() -> Self {
Self::HIDDEN
}
}
/// Maps a neuron location (owned or borrowed) to the single scope flag for its layer.
impl<L> From<L> for NeuronScope
where
    L: AsRef<NeuronLocation>,
{
    fn from(value: L) -> Self {
        let location: &NeuronLocation = value.as_ref();
        match location {
            NeuronLocation::Input(_) => Self::INPUT,
            NeuronLocation::Hidden(_) => Self::HIDDEN,
            NeuronLocation::Output(_) => Self::OUTPUT,
        }
    }
}