use rand_distr::{Distribution, Normal};
use serde::{Deserialize, Serialize};
/// A plastic synapse carrying a weight and two exponentially decaying traces.
///
/// Both traces are advanced by `EpropSynapse::update`; only `filtered_trace` enters the
/// weight change computed there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpropSynapse {
/// Synaptic weight; `update` clamps it to `[-10.0, 10.0]`.
pub weight: f32,
/// Fast trace: decays with `tau_e` and is incremented by the pseudo-derivative on a
/// presynaptic spike. Not used in the weight change by `update`.
pub eligibility_trace: f32,
/// Slow trace: decays with `tau_slow`, incremented like `eligibility_trace`; `update`
/// multiplies it by `lr * learning_signal` to form the weight change.
pub filtered_trace: f32,
/// Time constant of `eligibility_trace`, in the same units as the `dt` passed to `update`.
pub tau_e: f32,
/// Time constant of `filtered_trace`; `new` sets it to `2 * tau_e`.
pub tau_slow: f32,
}
impl EpropSynapse {
    /// Builds a synapse holding `initial_weight` with both traces at zero.
    ///
    /// The slow trace time constant is derived as twice `tau_e`.
    pub fn new(initial_weight: f32, tau_e: f32) -> Self {
        let tau_slow = tau_e * 2.0;
        Self {
            tau_e,
            tau_slow,
            weight: initial_weight,
            eligibility_trace: 0.0,
            filtered_trace: 0.0,
        }
    }

    /// Advances both traces by one step of length `dt` and applies the weight change.
    ///
    /// Each trace first decays exponentially (`tau_e` for the eligibility trace, `tau_slow`
    /// for the filtered trace). If `pre_spike` is set, `pseudo_derivative` is then added to
    /// both. Finally the weight moves by `lr * filtered_trace * learning_signal` and is
    /// clamped to `[-10.0, 10.0]`.
    pub fn update(
        &mut self,
        pre_spike: bool,
        pseudo_derivative: f32,
        learning_signal: f32,
        dt: f32,
        lr: f32,
    ) {
        let fast_decay = (-dt / self.tau_e).exp();
        let slow_decay = (-dt / self.tau_slow).exp();
        self.eligibility_trace *= fast_decay;
        self.filtered_trace *= slow_decay;

        if pre_spike {
            for trace in self.traces_mut() {
                *trace += pseudo_derivative;
            }
        }

        // Plasticity is driven by the filtered (slow) trace only.
        let updated = self.weight + lr * self.filtered_trace * learning_signal;
        self.weight = updated.clamp(-10.0, 10.0);
    }

    /// Zeroes both traces; the weight is left untouched.
    pub fn reset_traces(&mut self) {
        for trace in self.traces_mut() {
            *trace = 0.0;
        }
    }

    /// Mutable handles to the two traces, in the order (eligibility, filtered).
    fn traces_mut(&mut self) -> [&mut f32; 2] {
        [&mut self.eligibility_trace, &mut self.filtered_trace]
    }
}
/// Leaky integrate-and-fire neuron state used by the e-prop network.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpropLIF {
/// Current membrane potential.
pub membrane: f32,
/// A spike is emitted when `membrane >= threshold` after an integration step.
pub threshold: f32,
/// Membrane time constant, in the same units as the `dt` passed to `step`.
pub tau_mem: f32,
/// Remaining refractory steps; `step` decrements it by one per call while it is > 0.
pub refractory: u32,
/// Value loaded into `refractory` after a spike (`new` uses 2).
pub refractory_period: u32,
/// Resting potential; `new` starts the membrane here and `reset` restores it.
pub v_rest: f32,
/// Potential the membrane is set to after a spike and held at while refractory.
pub v_reset: f32,
}
impl EpropLIF {
    /// Creates a neuron at rest.
    ///
    /// The membrane starts at `v_rest`, the post-spike reset value `v_reset` equals `v_rest`,
    /// and the refractory period is 2 steps.
    pub fn new(v_rest: f32, threshold: f32, tau_mem: f32) -> Self {
        Self {
            membrane: v_rest,
            threshold,
            tau_mem,
            refractory: 0,
            refractory_period: 2,
            v_rest,
            v_reset: v_rest,
        }
    }

    /// Advances the neuron by one step of length `dt` with a constant `input`.
    ///
    /// The membrane relaxes exponentially, with time constant `tau_mem`, toward
    /// `v_rest + input`. This is the exact solution of
    /// `tau_mem * dv/dt = -(v - v_rest) + input` over the step.
    ///
    /// Returns `(spike, pseudo_derivative)`:
    /// - `pseudo_derivative` is `max(0, 1 - |v - threshold|)`, evaluated on the membrane
    ///   before any reset.
    /// - On a spike (`v >= threshold`) the membrane is set to `v_reset` and the refractory
    ///   counter is loaded with `refractory_period`.
    /// - While refractory, the call only counts the refractory period down, holds the
    ///   membrane at `v_reset`, and returns `(false, 0.0)`.
    pub fn step(&mut self, input: f32, dt: f32) -> (bool, f32) {
        if self.refractory > 0 {
            self.refractory -= 1;
            self.membrane = self.v_reset;
            return (false, 0.0);
        }

        let decay = (-dt / self.tau_mem).exp();
        // Leak toward the resting potential, not toward 0. Leaking toward 0 would push a
        // neuron with v_rest < threshold < 0 (e.g. -70 / -55) above threshold with no input.
        self.membrane =
            self.v_rest + (self.membrane - self.v_rest) * decay + input * (1.0 - decay);

        let pseudo_derivative = (1.0 - (self.membrane - self.threshold).abs()).max(0.0);
        let spike = self.membrane >= self.threshold;
        if spike {
            self.membrane = self.v_reset;
            self.refractory = self.refractory_period;
        }
        (spike, pseudo_derivative)
    }

    /// Returns the neuron to rest: membrane at `v_rest`, refractory counter cleared.
    pub fn reset(&mut self) {
        self.membrane = self.v_rest;
        self.refractory = 0;
    }
}
/// Strategy for turning a hidden neuron's projected error into its learning signal
/// (applied by `LearningSignal::compute`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LearningSignal {
/// Multiplies the error by one global scale factor, identical for every neuron.
Symmetric(f32),
/// Multiplies the error by `feedback[neuron_idx]`; neurons past the end get 0.0.
/// NOTE(review): presumably fixed random feedback weights — confirm at the construction site.
Random { feedback: Vec<f32> },
/// Computed exactly like `Random`, using `buffer[neuron_idx]` (0.0 past the end).
/// NOTE(review): nothing visible here adapts `buffer` — confirm the intended behavior.
Adaptive { buffer: Vec<f32> },
}
impl LearningSignal {
pub fn compute(&self, neuron_idx: usize, error: f32) -> f32 {
match self {
LearningSignal::Symmetric(scale) => error * scale,
LearningSignal::Random { feedback } => {
if neuron_idx < feedback.len() {
error * feedback[neuron_idx]
} else {
0.0
}
}
LearningSignal::Adaptive { buffer } => {
if neuron_idx < buffer.len() {
error * buffer[neuron_idx]
} else {
0.0
}
}
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpropNetwork {
pub input_size: usize,
pub hidden_size: usize,
pub output_size: usize,
pub neurons: Vec<EpropLIF>,
pub input_synapses: Vec<Vec<EpropSynapse>>,
pub recurrent_synapses: Vec<Vec<EpropSynapse>>,
pub readout: Vec<Vec<f32>>,
pub learning_signal: LearningSignal,
spike_buffer: Vec<bool>,
pseudo_derivatives: Vec<f32>,
}
impl EpropNetwork {
pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
let mut rng = rand::thread_rng();
let neurons = (0..hidden_size)
.map(|_| EpropLIF::new(-70.0, -55.0, 20.0))
.collect();
let input_scale = (2.0 / input_size as f32).sqrt();
let normal = Normal::new(0.0, input_scale as f64).unwrap();
let input_synapses = (0..input_size)
.map(|_| {
(0..hidden_size)
.map(|_| {
let weight = normal.sample(&mut rng) as f32;
EpropSynapse::new(weight, 20.0)
})
.collect()
})
.collect();
let recurrent_scale = (1.0 / hidden_size as f32).sqrt();
let recurrent_normal = Normal::new(0.0, recurrent_scale as f64).unwrap();
let recurrent_synapses = (0..hidden_size)
.map(|_| {
(0..hidden_size)
.map(|_| {
let weight = recurrent_normal.sample(&mut rng) as f32;
EpropSynapse::new(weight, 20.0)
})
.collect()
})
.collect();
let readout_scale = (1.0 / hidden_size as f32).sqrt();
let readout_normal = Normal::new(0.0, readout_scale as f64).unwrap();
let readout = (0..hidden_size)
.map(|_| {
(0..output_size)
.map(|_| readout_normal.sample(&mut rng) as f32)
.collect()
})
.collect();
Self {
input_size,
hidden_size,
output_size,
neurons,
input_synapses,
recurrent_synapses,
readout,
learning_signal: LearningSignal::Symmetric(1.0),
spike_buffer: vec![false; hidden_size],
pseudo_derivatives: vec![0.0; hidden_size],
}
}
pub fn forward(&mut self, input: &[f32], dt: f32) -> Vec<f32> {
assert_eq!(input.len(), self.input_size, "Input size mismatch");
let mut currents = vec![0.0; self.hidden_size];
for (i, &inp) in input.iter().enumerate() {
if inp > 0.5 {
for (j, synapse) in self.input_synapses[i].iter().enumerate() {
currents[j] += synapse.weight;
}
}
}
for (i, &spike) in self.spike_buffer.iter().enumerate() {
if spike {
for (j, synapse) in self.recurrent_synapses[i].iter().enumerate() {
currents[j] += synapse.weight;
}
}
}
for (i, neuron) in self.neurons.iter_mut().enumerate() {
let (spike, pseudo_deriv) = neuron.step(currents[i], dt);
self.spike_buffer[i] = spike;
self.pseudo_derivatives[i] = pseudo_deriv;
}
let mut output = vec![0.0; self.output_size];
for (i, &spike) in self.spike_buffer.iter().enumerate() {
if spike {
for (j, weight) in self.readout[i].iter().enumerate() {
output[j] += weight;
}
}
}
output
}
pub fn backward(&mut self, error: &[f32], learning_rate: f32, dt: f32) {
assert_eq!(error.len(), self.output_size, "Error size mismatch");
let mut learning_signals = vec![0.0; self.hidden_size];
for i in 0..self.hidden_size {
let mut signal = 0.0;
for j in 0..self.output_size {
signal += error[j] * self.readout[i][j];
}
learning_signals[i] = self.learning_signal.compute(i, signal);
}
for i in 0..self.input_size {
for j in 0..self.hidden_size {
let pre_spike = false; self.input_synapses[i][j].update(
pre_spike,
self.pseudo_derivatives[j],
learning_signals[j],
dt,
learning_rate,
);
}
}
for i in 0..self.hidden_size {
for j in 0..self.hidden_size {
let pre_spike = self.spike_buffer[i];
self.recurrent_synapses[i][j].update(
pre_spike,
self.pseudo_derivatives[j],
learning_signals[j],
dt,
learning_rate,
);
}
}
for i in 0..self.hidden_size {
if self.spike_buffer[i] {
for j in 0..self.output_size {
self.readout[i][j] += learning_rate * error[j];
}
}
}
}
pub fn online_step(&mut self, input: &[f32], target: &[f32], dt: f32, lr: f32) {
let output = self.forward(input, dt);
let error: Vec<f32> = target
.iter()
.zip(output.iter())
.map(|(t, o)| t - o)
.collect();
self.backward(&error, lr, dt);
}
pub fn reset(&mut self) {
for neuron in &mut self.neurons {
neuron.reset();
}
for synapses in &mut self.input_synapses {
for synapse in synapses {
synapse.reset_traces();
}
}
for synapses in &mut self.recurrent_synapses {
for synapse in synapses {
synapse.reset_traces();
}
}
self.spike_buffer.fill(false);
self.pseudo_derivatives.fill(0.0);
}
pub fn num_synapses(&self) -> usize {
let input_synapses = self.input_size * self.hidden_size;
let recurrent_synapses = self.hidden_size * self.hidden_size;
let readout_synapses = self.hidden_size * self.output_size;
input_synapses + recurrent_synapses + readout_synapses
}
pub fn memory_footprint(&self) -> usize {
let synapse_size = std::mem::size_of::<EpropSynapse>();
let neuron_size = std::mem::size_of::<EpropLIF>();
let readout_size = std::mem::size_of::<f32>();
let input_mem = self.input_size * self.hidden_size * synapse_size;
let recurrent_mem = self.hidden_size * self.hidden_size * synapse_size;
let readout_mem = self.hidden_size * self.output_size * readout_size;
let neuron_mem = self.hidden_size * neuron_size;
input_mem + recurrent_mem + readout_mem + neuron_mem
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_synapse_creation() {
        let synapse = EpropSynapse::new(0.5, 20.0);
        assert_eq!(synapse.weight, 0.5);
        assert_eq!(synapse.eligibility_trace, 0.0);
        assert_eq!(synapse.filtered_trace, 0.0);
        assert_eq!(synapse.tau_e, 20.0);
        assert_eq!(synapse.tau_slow, 40.0);
    }

    #[test]
    fn test_trace_decay() {
        let mut synapse = EpropSynapse::new(0.5, 20.0);
        synapse.eligibility_trace = 1.0;
        synapse.filtered_trace = 1.0;
        // dt == tau_e: fast trace decays by e^-1; slow trace (tau_slow = 2 * tau_e) by e^-0.5.
        synapse.update(false, 0.0, 0.0, 20.0, 0.0);
        assert!((synapse.eligibility_trace - 0.368).abs() < 0.01);
        assert!((synapse.filtered_trace - 0.607).abs() < 0.01);
        // A zero learning rate must leave the weight untouched.
        assert_eq!(synapse.weight, 0.5);
    }

    #[test]
    fn test_lif_spike_generation() {
        let mut neuron = EpropLIF::new(-70.0, -55.0, 20.0);
        for _ in 0..50 {
            let (spike, _) = neuron.step(100.0, 1.0);
            if spike {
                assert_eq!(neuron.membrane, neuron.v_reset);
                return;
            }
        }
        panic!("Neuron did not spike with strong sustained input");
    }

    #[test]
    fn test_lif_refractory_period() {
        let mut neuron = EpropLIF::new(-70.0, -55.0, 20.0);
        // Without this assertion the refractory check below could pass vacuously.
        let spiked = (0..50).any(|_| neuron.step(100.0, 1.0).0);
        assert!(spiked, "Neuron did not spike with strong sustained input");
        let (spike2, pseudo2) = neuron.step(100.0, 1.0);
        assert!(!spike2, "Should be in refractory period");
        assert_eq!(pseudo2, 0.0);
    }

    #[test]
    fn test_pseudo_derivative() {
        let mut neuron = EpropLIF::new(-70.0, -55.0, 20.0);
        neuron.membrane = -55.5;
        let (_, pseudo_deriv) = neuron.step(0.0, 1.0);
        assert!(
            (0.0..=1.0).contains(&pseudo_deriv),
            "pseudo_deriv={}",
            pseudo_deriv
        );
    }

    #[test]
    fn test_network_creation() {
        let network = EpropNetwork::new(10, 100, 2);
        assert_eq!(network.input_size, 10);
        assert_eq!(network.hidden_size, 100);
        assert_eq!(network.output_size, 2);
        assert_eq!(network.neurons.len(), 100);
        assert_eq!(network.input_synapses.len(), 10);
        assert!(network.input_synapses.iter().all(|row| row.len() == 100));
        assert_eq!(network.recurrent_synapses.len(), 100);
        assert!(network.recurrent_synapses.iter().all(|row| row.len() == 100));
        assert_eq!(network.readout.len(), 100);
        assert!(network.readout.iter().all(|row| row.len() == 2));
    }

    #[test]
    fn test_network_forward() {
        let mut network = EpropNetwork::new(10, 50, 2);
        let input = vec![1.0; 10];
        let output = network.forward(&input, 1.0);
        assert_eq!(output.len(), 2);
        assert!(output.iter().all(|o| o.is_finite()));
    }

    #[test]
    fn test_network_memory_footprint() {
        let network = EpropNetwork::new(100, 500, 10);
        let footprint = network.memory_footprint();
        let num_synapses = network.num_synapses();
        let bytes_per_synapse = footprint / num_synapses;
        assert!(
            (10..=20).contains(&bytes_per_synapse),
            "bytes_per_synapse={}",
            bytes_per_synapse
        );
    }

    #[test]
    fn test_online_learning() {
        let mut network = EpropNetwork::new(10, 50, 2);
        let input = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let target = vec![1.0, 0.0];
        for _ in 0..10 {
            network.online_step(&input, &target, 1.0, 0.01);
        }
        // Plastic weights must stay finite and inside the synapse clamp range.
        let in_range = |w: f32| w.is_finite() && (-10.0..=10.0).contains(&w);
        for row in network
            .input_synapses
            .iter()
            .chain(&network.recurrent_synapses)
        {
            assert!(row.iter().all(|s| in_range(s.weight)));
        }
        assert!(network.readout.iter().flatten().all(|w| w.is_finite()));
    }

    #[test]
    fn test_network_reset() {
        let mut network = EpropNetwork::new(10, 50, 2);
        let input = vec![1.0; 10];
        network.forward(&input, 1.0);
        network.reset();
        for neuron in &network.neurons {
            assert_eq!(neuron.membrane, neuron.v_rest);
            assert_eq!(neuron.refractory, 0);
        }
        for synapse in network
            .input_synapses
            .iter()
            .chain(&network.recurrent_synapses)
            .flatten()
        {
            assert_eq!(synapse.eligibility_trace, 0.0);
            assert_eq!(synapse.filtered_trace, 0.0);
        }
    }
}