//! # neat 1.0.1
//!
//! Crate for working with NEAT in Rust.
//!
//! Documentation: see the items below, in particular the [`builtin`] module.
/// Contains some builtin activation functions ([`sigmoid`], [`relu`], etc.)
pub mod builtin;

use bitflags::bitflags;
use builtin::*;

use genetic_rs::prelude::{rand, RngExt};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use lazy_static::lazy_static;
use std::{
    collections::HashMap,
    fmt,
    sync::{Arc, RwLock},
};

use crate::NeuronLocation;

/// Creates an [`ActivationFn`] object from a function.
///
/// Accepts either a single path (`activation_fn!(sigmoid)`), a path plus a
/// [`NeuronScope`] expression (`activation_fn!(sigmoid, NeuronScope::HIDDEN)`),
/// or a braced list of either form, which expands to an array.
///
/// The function's path (via `stringify!`) becomes its registered name, which is
/// also used for serialization.
#[macro_export]
macro_rules! activation_fn {
    ($F: path) => {
        // `::std` and `$crate` keep the expansion working regardless of what
        // names are in scope at the call site.
        $crate::activation::ActivationFn::new(::std::sync::Arc::new($F), $crate::activation::NeuronScope::default(), stringify!($F))
    };

    ($F: path, $S: expr) => {
        $crate::activation::ActivationFn::new(::std::sync::Arc::new($F), $S, stringify!($F))
    };

    // `$(,)?` permits an optional trailing comma in list form.
    {$($F: path),* $(,)?} => {
        // Recurse through `$crate` so callers need not have imported the macro
        // by name with `#[macro_use]` or a `use`.
        [$($crate::activation_fn!($F)),*]
    };

    {$($F: path => $S: expr),* $(,)?} => {
        [$($crate::activation_fn!($F, $S)),*]
    }
}

lazy_static! {
    /// A static activation registry for use in deserialization.
    ///
    /// Guarded by an [`RwLock`] so registration (writes) and deserialization
    /// lookups (reads) can happen concurrently from multiple threads.
    /// Initialized lazily with the builtin functions via
    /// [`ActivationRegistry::default`].
    pub(crate) static ref ACTIVATION_REGISTRY: Arc<RwLock<ActivationRegistry>> = Arc::new(RwLock::new(ActivationRegistry::default()));
}

/// Register an activation function to the registry.
///
/// # Panics
///
/// Panics if the global registry lock has been poisoned by a panicking writer.
pub fn register_activation(act: ActivationFn) {
    ACTIVATION_REGISTRY.write().unwrap().register(act);
}

/// Registers multiple activation functions to the registry at once.
///
/// # Panics
///
/// Panics if the global registry lock has been poisoned by a panicking writer.
pub fn batch_register_activation(acts: impl IntoIterator<Item = ActivationFn>) {
    ACTIVATION_REGISTRY.write().unwrap().batch_register(acts);
}

/// A registry of the different possible activation functions.
///
/// Functions are keyed by their `'static` name (as produced by the
/// [`activation_fn!`] macro), so registering two functions with the same name
/// replaces the first.
pub struct ActivationRegistry {
    /// The currently-registered activation functions, keyed by name.
    pub fns: HashMap<&'static str, ActivationFn>,
}

impl ActivationRegistry {
    /// Registers an activation function, keyed by its name.
    ///
    /// Any previously-registered function with the same name is replaced.
    pub fn register(&mut self, activation: ActivationFn) {
        self.fns.insert(activation.name, activation);
    }

    /// Registers multiple activation functions at once.
    pub fn batch_register(&mut self, activations: impl IntoIterator<Item = ActivationFn>) {
        for act in activations {
            self.register(act);
        }
    }

    /// Gets a Vec of all the activation functions registered. Use [fns][ActivationRegistry::fns] if you only need an iterator.
    pub fn activations(&self) -> Vec<ActivationFn> {
        self.fns.values().cloned().collect()
    }

    /// Gets all activation functions that are valid for a scope.
    ///
    /// Returns an empty Vec when `scope` is [`NeuronScope::NONE`], since the
    /// empty set is `contain`ed by every scope and would otherwise match all.
    pub fn activations_in_scope(&self, scope: NeuronScope) -> Vec<ActivationFn> {
        if scope == NeuronScope::NONE {
            return Vec::new();
        }

        // Filter before cloning so only the matching entries pay for a clone.
        self.fns
            .values()
            .filter(|a| a.scope.contains(scope))
            .cloned()
            .collect()
    }

    /// Clears all existing values in the activation registry.
    pub fn clear(&mut self) {
        self.fns.clear();
    }

    /// Fetches a random activation fn that applies to the provided scope.
    ///
    /// # Panics
    ///
    /// Panics if there are no activation functions registered for the given scope.
    pub fn random_activation_in_scope(
        &self,
        scope: NeuronScope,
        rng: &mut impl rand::Rng,
    ) -> ActivationFn {
        let activations = self.activations_in_scope(scope);

        assert!(
            !activations.is_empty(),
            "no activation functions registered for scope {:?}",
            scope
        );

        activations[rng.random_range(0..activations.len())].clone()
    }
}

impl Default for ActivationRegistry {
    /// Builds a registry pre-populated with the builtin activation functions
    /// ([`sigmoid`], [`relu`], [`linear_activation`], and `f32::tanh`).
    fn default() -> Self {
        let mut registry = Self {
            fns: HashMap::new(),
        };

        // TODO add a way to disable this
        registry.batch_register(activation_fn! {
            sigmoid => NeuronScope::HIDDEN | NeuronScope::OUTPUT,
            relu => NeuronScope::HIDDEN | NeuronScope::OUTPUT,
            linear_activation => NeuronScope::INPUT | NeuronScope::HIDDEN | NeuronScope::OUTPUT,
            f32::tanh => NeuronScope::HIDDEN | NeuronScope::OUTPUT
        });

        registry
    }
}

/// A trait that represents an activation method.
///
/// A blanket impl covers every `Fn(f32) -> f32`, so plain functions and
/// closures can be used directly.
pub trait Activation {
    /// The activation function. Maps a neuron's input value to its output.
    fn activate(&self, n: f32) -> f32;
}

/// Blanket impl: any `f32 -> f32` function or closure is an [`Activation`].
impl<F: Fn(f32) -> f32> Activation for F {
    fn activate(&self, n: f32) -> f32 {
        self(n)
    }
}

/// An activation function object that implements [`fmt::Debug`] and is [`Send`]
#[derive(Clone)]
pub struct ActivationFn {
    /// The actual activation function. Shared via [`Arc`] so clones are cheap.
    pub func: Arc<dyn Activation + Send + Sync>,

    /// The scope defining where the activation function can appear.
    pub scope: NeuronScope,

    /// The name of the activation function, used for debugging and serialization.
    /// Also serves as the registry key and the basis for equality comparisons.
    pub name: &'static str,
}

impl ActivationFn {
    /// Creates a new [`ActivationFn`] from its parts. Usually invoked through
    /// the [`activation_fn!`] macro rather than directly.
    pub fn new(
        func: Arc<dyn Activation + Send + Sync>,
        scope: NeuronScope,
        name: &'static str,
    ) -> Self {
        Self { func, scope, name }
    }
}

impl fmt::Debug for ActivationFn {
    /// Formats as the function's registered name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name)
    }
}

impl PartialEq for ActivationFn {
    /// Two activation functions are equal when their names match; the
    /// underlying function objects are never compared.
    fn eq(&self, other: &Self) -> bool {
        str::eq(self.name, other.name)
    }
}

#[cfg(feature = "serde")]
impl Serialize for ActivationFn {
    /// Serializes only the function's name; the function body itself cannot be
    /// serialized. Deserialization resolves the name through the global
    /// activation registry.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_str(self.name)
    }
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for ActivationFn {
    /// Deserializes an activation function by reading its name and looking it
    /// up in the global activation registry; fails if the name was never
    /// registered.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let name = String::deserialize(deserializer)?;

        ACTIVATION_REGISTRY
            .read()
            .unwrap()
            .fns
            .get(name.as_str())
            .cloned()
            .ok_or_else(|| {
                serde::de::Error::custom(format!("Activation function {name} not found"))
            })
    }
}

bitflags! {
    /// Specifies where an activation function can occur
    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
    pub struct NeuronScope: u8 {
        /// Whether the activation can be applied to the input layer.
        const INPUT = 0b001;

        /// Whether the activation can be applied to the hidden layer.
        const HIDDEN = 0b010;

        /// Whether the activation can be applied to the output layer.
        const OUTPUT = 0b100;

        /// The empty scope: the activation function will not be randomly
        /// placed anywhere.
        const NONE = 0b000;
    }
}

impl Default for NeuronScope {
    fn default() -> Self {
        Self::HIDDEN
    }
}

impl<L: AsRef<NeuronLocation>> From<L> for NeuronScope {
    fn from(value: L) -> Self {
        match value.as_ref() {
            NeuronLocation::Input(_) => Self::INPUT,
            NeuronLocation::Hidden(_) => Self::HIDDEN,
            NeuronLocation::Output(_) => Self::OUTPUT,
        }
    }
}