//! rsnn-eta 0.1.0
//!
//! Biologically-inspired ETA estimator using a Recurrent Spiking Neural Network with STDP learning.
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use serde::{Deserialize, Serialize};

use crate::config::NetworkConfig;
use crate::lif::LifNeuron;

/// One directed recurrent connection.
///
/// Instances are stored in [`SnnNetwork::recurrent_synapses`], in the list
/// belonging to the presynaptic (source) neuron, so only the destination is
/// recorded here.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Synapse {
    /// Index of the neuron this connection projects onto.
    pub target: usize,
    /// Signed connection strength. As initialised by `SnnNetwork::new`, it is
    /// non-negative when the source neuron is excitatory and non-positive when
    /// it is inhibitory.
    pub weight: f64,
}

/// Recurrent spiking network: a population of LIF neurons plus sparse input
/// and recurrent connectivity.
///
/// All per-neuron vectors are built by `SnnNetwork::new` with one entry per
/// neuron, in the same neuron order.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SnnNetwork {
    /// The neurons themselves, one `LifNeuron` per network unit.
    pub neurons: Vec<LifNeuron>,
    /// Input connections, indexed by receiving neuron. Each entry is a list of
    /// `(input feature index, weight)` pairs.
    pub input_synapses: Vec<Vec<(usize, f64)>>,
    /// Recurrent connections, indexed by presynaptic (source) neuron.
    pub recurrent_synapses: Vec<Vec<Synapse>>,
    /// `true` for excitatory neurons, `false` for inhibitory ones. This flag
    /// decides the sign of a neuron's outgoing recurrent weights at
    /// initialisation.
    pub is_excitatory: Vec<bool>,
    /// Number of input features the network was constructed for.
    pub input_dim: usize,
    /// Number of neurons the network was constructed with.
    pub num_neurons: usize,
}

impl SnnNetwork {
    /// Builds a randomly initialised network with `config.num_neurons` neurons
    /// that reads `input_dim` input features.
    ///
    /// Every random draw comes from a single [`SmallRng`] seeded with `seed`.
    /// Draws happen in a fixed order: neuron time constants first, then input
    /// synapses, then recurrent synapses.
    ///
    /// * Time constants are log-uniform: a value is drawn uniformly from
    ///   `[ln(tau_min), ln(tau_max))` and exponentiated. If both bounds have
    ///   the same logarithm, every neuron gets that single time constant and
    ///   no random number is consumed for it.
    /// * The first `floor(num_neurons * excitatory_frac)` neurons are marked
    ///   excitatory and the remainder inhibitory.
    ///
    /// # Panics
    ///
    /// Panics (inside `gen_range`) if `tau_min`/`tau_max` form an inverted or
    /// NaN range, or if either initial weight scale is negative or NaN.
    pub fn new(input_dim: usize, config: &NetworkConfig, seed: u64) -> Self {
        let mut rng = SmallRng::seed_from_u64(seed);
        let n = config.num_neurons;

        // The bounds do not change per neuron, so take the logarithms once.
        let log_tau_min = config.tau_min.ln();
        let log_tau_max = config.tau_max.ln();

        let neurons: Vec<LifNeuron> = (0..n)
            .map(|_| {
                // `gen_range` panics on an empty range; a single-point range
                // simply means "use this one time constant for everyone".
                let log_tau = if log_tau_min == log_tau_max {
                    log_tau_min
                } else {
                    rng.gen_range(log_tau_min..log_tau_max)
                };
                LifNeuron::new(log_tau.exp())
            })
            .collect();

        // The float-to-integer cast truncates toward zero, so the excitatory
        // count is rounded down.
        let n_exc = (n as f64 * config.excitatory_frac) as usize;
        let is_excitatory: Vec<bool> = (0..n).map(|i| i < n_exc).collect();

        let input_synapses = Self::init_input_synapses(
            input_dim, n, config.input_sparsity, config.init_input_scale, &mut rng,
        );
        let recurrent_synapses = Self::init_recurrent_synapses(
            n, &is_excitatory, config.recurrent_sparsity, config.init_recurrent_scale, &mut rng,
        );

        Self { neurons, input_synapses, recurrent_synapses, is_excitatory, input_dim, num_neurons: n }
    }

    /// Creates the sparse input connectivity.
    ///
    /// For every neuron and every input feature, a connection is created when
    /// a uniform draw from `[0, 1)` is below `sparsity`. Its weight is drawn
    /// uniformly from `[-scale, scale)`. A `scale` of exactly zero yields
    /// zero-weight connections instead of panicking on the empty range.
    ///
    /// Returns one list per neuron of `(feature index, weight)` pairs, with
    /// feature indices in ascending order.
    fn init_input_synapses(
        input_dim: usize, num_neurons: usize, sparsity: f64, scale: f64, rng: &mut SmallRng,
    ) -> Vec<Vec<(usize, f64)>> {
        (0..num_neurons)
            .map(|_| {
                let mut syns = Vec::new();
                for feat in 0..input_dim {
                    if rng.gen::<f64>() < sparsity {
                        // `-0.0..0.0` is an empty range and would panic.
                        let weight = if scale == 0.0 {
                            0.0
                        } else {
                            rng.gen_range(-scale..scale)
                        };
                        syns.push((feat, weight));
                    }
                }
                syns
            })
            .collect()
    }

    /// Creates the sparse recurrent connectivity.
    ///
    /// For every ordered pair `(pre, post)` with `pre != post`, a connection
    /// is created when a uniform draw from `[0, 1)` is below `sparsity`. No
    /// random number is consumed for the skipped `pre == post` pair. The
    /// magnitude is drawn uniformly from `[0, scale)` and the sign follows
    /// `is_excitatory[pre]`: positive for excitatory sources, negative
    /// otherwise. A `scale` of exactly zero yields zero-magnitude connections
    /// instead of panicking on the empty range.
    ///
    /// Returns one list per presynaptic neuron, with targets in ascending
    /// order. `is_excitatory` must have at least `num_neurons` entries.
    fn init_recurrent_synapses(
        num_neurons: usize, is_excitatory: &[bool], sparsity: f64, scale: f64, rng: &mut SmallRng,
    ) -> Vec<Vec<Synapse>> {
        (0..num_neurons)
            .map(|pre| {
                let mut syns = Vec::new();
                for post in 0..num_neurons {
                    // Short-circuit keeps self-connections out without
                    // spending a random draw on them.
                    if post != pre && rng.gen::<f64>() < sparsity {
                        // `0.0..0.0` is an empty range and would panic.
                        let w = if scale == 0.0 {
                            0.0
                        } else {
                            rng.gen_range(0.0..scale)
                        };
                        let weight = if is_excitatory[pre] { w } else { -w };
                        syns.push(Synapse { target: post, weight });
                    }
                }
                syns
            })
            .collect()
    }

    /// Returns the index of the output neuron, which is the last neuron
    /// (`num_neurons - 1`).
    ///
    /// The network must contain at least one neuron. With `num_neurons == 0`
    /// the subtraction underflows: it panics when overflow checks are enabled
    /// and wraps around otherwise.
    pub fn output_neuron(&self) -> usize {
        self.num_neurons - 1
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Input dimensionality used by every fixture network.
    const INPUT_DIM: usize = 4;
    /// Fixed RNG seed used by every fixture network.
    const SEED: u64 = 42;

    fn default_config() -> NetworkConfig {
        NetworkConfig::default()
    }

    /// Builds the default configuration together with a network made from it.
    fn fixture() -> (NetworkConfig, SnnNetwork) {
        let cfg = default_config();
        let net = SnnNetwork::new(INPUT_DIM, &cfg, SEED);
        (cfg, net)
    }

    /// Every per-neuron collection has one entry per configured neuron.
    #[test]
    fn test_network_dimensions() {
        let (cfg, net) = fixture();
        let expected_len = cfg.num_neurons;
        assert_eq!(net.neurons.len(), expected_len);
        assert_eq!(net.input_synapses.len(), expected_len);
        assert_eq!(net.recurrent_synapses.len(), expected_len);
        assert_eq!(net.is_excitatory.len(), expected_len);
        assert_eq!(net.input_dim, INPUT_DIM);
    }

    /// The excitatory count equals the truncated configured fraction.
    #[test]
    fn test_excitatory_inhibitory_ratio() {
        let (cfg, net) = fixture();
        let expected = (cfg.num_neurons as f64 * cfg.excitatory_frac) as usize;
        let n_exc = net.is_excitatory.iter().copied().filter(|&flag| flag).count();
        assert_eq!(n_exc, expected);
    }

    /// No neuron projects onto itself.
    #[test]
    fn test_no_self_connections() {
        let (_, net) = fixture();
        let edges = net
            .recurrent_synapses
            .iter()
            .enumerate()
            .flat_map(|(pre, syns)| syns.iter().map(move |syn| (pre, syn.target)));
        for (pre, target) in edges {
            assert_ne!(target, pre, "self-connection found at neuron {pre}");
        }
    }

    /// Outgoing weights carry the sign dictated by the source neuron's type.
    #[test]
    fn test_weight_sign_constraints() {
        let (_, net) = fixture();
        let per_neuron = net.recurrent_synapses.iter().zip(net.is_excitatory.iter());
        for (pre, (syns, &excitatory)) in per_neuron.enumerate() {
            for syn in syns {
                if excitatory {
                    assert!(syn.weight >= 0.0, "excitatory neuron {pre} has negative weight");
                } else {
                    assert!(syn.weight <= 0.0, "inhibitory neuron {pre} has positive weight");
                }
            }
        }
    }

    /// The realised input connection density is close to the configured one.
    #[test]
    fn test_input_sparsity_approximate() {
        let (cfg, net) = fixture();
        let actual: usize = net.input_synapses.iter().map(Vec::len).sum();
        let total_possible = cfg.num_neurons * INPUT_DIM;
        let ratio = actual as f64 / total_possible as f64;
        let deviation = (ratio - cfg.input_sparsity).abs();
        assert!(deviation < 0.15,
            "input sparsity {ratio} far from target {}", cfg.input_sparsity);
    }

    /// The output neuron is the last one.
    #[test]
    fn test_output_neuron_index() {
        let (cfg, net) = fixture();
        let last = cfg.num_neurons - 1;
        assert_eq!(net.output_neuron(), last);
    }
}