use ghostflow_core::Tensor;
use crate::deep::layers::Dense;
/// A population of leaky integrate-and-fire (LIF) neurons.
///
/// On every [`forward`](LIFNeuron::forward) step each membrane potential is multiplied by
/// `1 - decay * dt`, integrates `input * dt`, and, once it reaches `threshold`, emits a
/// spike and returns to the reset potential.
#[derive(Debug, Clone)]
pub struct LIFNeuron {
    /// Potential at or above which a neuron spikes.
    threshold: f32,
    /// Leak rate; the potential is multiplied by `1 - decay * dt` each step.
    decay: f32,
    /// Potential a neuron returns to after spiking or on `reset`.
    reset_potential: f32,
    /// Membrane potentials, row-major `[batch, neuron]`.
    ///
    /// Holds exactly `num_neurons` entries after construction or `reset`. A forward pass
    /// with a larger batch resizes it to `batch * num_neurons`.
    membrane_potential: Vec<f32>,
    /// Number of neurons in the population, i.e. the expected input width.
    num_neurons: usize,
}

impl LIFNeuron {
    /// Creates `num_neurons` neurons at rest (potential `0.0`, which is also the reset
    /// potential).
    pub fn new(num_neurons: usize, threshold: f32, decay: f32) -> Self {
        LIFNeuron {
            threshold,
            decay,
            reset_potential: 0.0,
            membrane_potential: vec![0.0f32; num_neurons],
            num_neurons,
        }
    }

    /// Advances every neuron by one time step of length `dt`.
    ///
    /// `input` is read as `[batch, num_neurons]` input currents. Each sample in the batch
    /// keeps its own membrane potentials. These persist across calls with the same batch
    /// size; a call with a different batch size restarts all potentials from rest.
    ///
    /// Returns a `[batch, num_neurons]` tensor holding `1.0` where a neuron spiked and
    /// `0.0` elsewhere.
    ///
    /// # Panics
    ///
    /// Panics if the input width differs from the number of neurons, or if the output
    /// tensor cannot be built.
    pub fn forward(&mut self, input: &Tensor, dt: f32) -> Tensor {
        let batch_size = input.dims()[0];
        let num_neurons = input.dims()[1];
        assert_eq!(
            num_neurons, self.num_neurons,
            "LIFNeuron::forward: input width must equal the number of neurons"
        );
        let input_data = input.data_f32();

        let state_len = batch_size * num_neurons;
        if self.membrane_potential.len() != state_len {
            // Every sample needs its own potentials; sharing one per neuron would let a
            // sample's input leak into the result of the next sample in the batch.
            self.membrane_potential.clear();
            self.membrane_potential.resize(state_len, self.reset_potential);
        }

        // NOTE(review): if decay * dt > 1 this factor is negative and the potential flips sign.
        let leak = 1.0 - self.decay * dt;
        let mut spikes = Vec::with_capacity(state_len);
        for (potential, &current) in self.membrane_potential.iter_mut().zip(input_data.iter()) {
            *potential = *potential * leak + current * dt;
            if *potential >= self.threshold {
                spikes.push(1.0f32);
                *potential = self.reset_potential;
            } else {
                spikes.push(0.0f32);
            }
        }
        Tensor::from_slice(&spikes, &[batch_size, num_neurons]).unwrap()
    }

    /// Returns the population to rest: one potential per neuron, each set to the reset
    /// potential.
    pub fn reset(&mut self) {
        self.membrane_potential.clear();
        self.membrane_potential
            .resize(self.num_neurons, self.reset_potential);
    }
}
/// A population of Izhikevich neurons integrated with forward Euler.
///
/// Each step computes `dv = (0.04 v² + 5 v + 140 − u + I) · dt` and
/// `du = a (b v − u) · dt` from the pre-step values. A neuron whose `v` reaches 30 spikes;
/// after a spike, `v` is set to `c` and `u` is increased by `d`.
#[derive(Debug, Clone)]
pub struct IzhikevichNeuron {
    /// Rate at which the recovery variable `u` relaxes toward `b * v`.
    a: f32,
    /// Coupling of the recovery variable `u` to the membrane potential `v`.
    b: f32,
    /// Value `v` is set to after a spike.
    c: f32,
    /// Amount added to `u` after a spike.
    d: f32,
    /// Membrane potentials, row-major `[batch, neuron]`.
    v: Vec<f32>,
    /// Recovery variables, row-major `[batch, neuron]`.
    u: Vec<f32>,
    /// Number of neurons in the population, i.e. the expected input width.
    num_neurons: usize,
}

impl IzhikevichNeuron {
    /// Membrane potential every neuron starts from (and returns to on `reset`).
    const INITIAL_V: f32 = -65.0;
    /// Potential at or above which a neuron is considered to have spiked.
    const SPIKE_PEAK: f32 = 30.0;

    /// Creates `num_neurons` neurons with `v = -65` and `u = b * v`.
    pub fn new(num_neurons: usize, a: f32, b: f32, c: f32, d: f32) -> Self {
        IzhikevichNeuron {
            a,
            b,
            c,
            d,
            v: vec![Self::INITIAL_V; num_neurons],
            u: vec![b * Self::INITIAL_V; num_neurons],
            num_neurons,
        }
    }

    /// Advances every neuron by one Euler step of length `dt`.
    ///
    /// `input` is read as `[batch, num_neurons]` input currents. Each sample keeps its own
    /// `v`/`u` state. That state persists across calls with the same batch size; a call
    /// with a different batch size restarts from the initial state.
    ///
    /// Returns a `[batch, num_neurons]` tensor holding `1.0` where a neuron spiked and
    /// `0.0` elsewhere.
    ///
    /// # Panics
    ///
    /// Panics if the input width differs from the number of neurons, or if the output
    /// tensor cannot be built.
    pub fn forward(&mut self, input: &Tensor, dt: f32) -> Tensor {
        let batch_size = input.dims()[0];
        let num_neurons = input.dims()[1];
        assert_eq!(
            num_neurons, self.num_neurons,
            "IzhikevichNeuron::forward: input width must equal the number of neurons"
        );
        let input_data = input.data_f32();

        let state_len = batch_size * num_neurons;
        if self.v.len() != state_len {
            // One state per sample, so samples in a batch cannot influence each other.
            self.reset_state(state_len);
        }

        let (a, b, c, d) = (self.a, self.b, self.c, self.d);
        let mut spikes = Vec::with_capacity(state_len);
        for ((v, u), &current) in self
            .v
            .iter_mut()
            .zip(self.u.iter_mut())
            .zip(input_data.iter())
        {
            // Both derivatives are taken from the pre-step `v` and `u`.
            let dv = (0.04 * *v * *v + 5.0 * *v + 140.0 - *u + current) * dt;
            let du = a * (b * *v - *u) * dt;
            *v += dv;
            *u += du;
            if *v >= Self::SPIKE_PEAK {
                spikes.push(1.0f32);
                *v = c;
                *u += d;
            } else {
                spikes.push(0.0f32);
            }
        }
        Tensor::from_slice(&spikes, &[batch_size, num_neurons]).unwrap()
    }

    /// Returns every neuron to its initial state (`v = -65`, `u = b * v`), one state per
    /// neuron.
    pub fn reset(&mut self) {
        self.reset_state(self.num_neurons);
    }

    /// Reinitialises `v` and `u` to `len` entries of the initial state.
    fn reset_state(&mut self, len: usize) {
        let initial_u = self.b * Self::INITIAL_V;
        self.v.clear();
        self.v.resize(len, Self::INITIAL_V);
        self.u.clear();
        self.u.resize(len, initial_u);
    }
}
/// A dense layer whose weighted output drives a population of LIF neurons.
pub struct SpikingLayer {
    /// Synaptic weights applied to the incoming spikes.
    weights: Dense,
    /// Post-synaptic LIF population, one neuron per output unit.
    neurons: LIFNeuron,
}

impl SpikingLayer {
    /// Builds a layer mapping `input_size` inputs onto `output_size` LIF neurons that
    /// share the given `threshold` and `decay`.
    pub fn new(input_size: usize, output_size: usize, threshold: f32, decay: f32) -> Self {
        let weights = Dense::new(input_size, output_size);
        let neurons = LIFNeuron::new(output_size, threshold, decay);
        Self { weights, neurons }
    }

    /// Projects `spikes` through the dense weights, then advances the neurons by `dt`
    /// with the result as input current. Returns the neurons' output spikes.
    pub fn forward(&mut self, spikes: &Tensor, dt: f32, training: bool) -> Tensor {
        let currents = self.weights.forward(spikes, training);
        let fired = self.neurons.forward(&currents, dt);
        fired
    }

    /// Clears the membrane state of this layer's neurons.
    pub fn reset(&mut self) {
        self.neurons.reset()
    }
}
/// A feed-forward stack of [`SpikingLayer`]s simulated for a fixed number of time steps.
pub struct SpikingNeuralNetwork {
    /// Layers in order from input to output.
    layers: Vec<SpikingLayer>,
    /// Number of simulation steps per forward pass.
    num_timesteps: usize,
}

impl SpikingNeuralNetwork {
    /// Builds one layer for every consecutive pair in `layer_sizes`.
    ///
    /// For example, `vec![784, 256, 10]` yields layers `784 -> 256` and `256 -> 10`.
    /// Fewer than two sizes produce a network without layers, which
    /// [`forward`](Self::forward) rejects.
    pub fn new(layer_sizes: Vec<usize>, threshold: f32, decay: f32, num_timesteps: usize) -> Self {
        // `windows` also copes with an empty list, where `len() - 1` would underflow.
        let layers = layer_sizes
            .windows(2)
            .map(|pair| SpikingLayer::new(pair[0], pair[1], threshold, decay))
            .collect();
        SpikingNeuralNetwork {
            layers,
            num_timesteps,
        }
    }

    /// Resets all layers, then presents `input` unchanged at each of `num_timesteps` steps.
    ///
    /// Returns a `[batch, output_size]` tensor counting how many times each output neuron
    /// spiked over the whole simulation.
    ///
    /// # Panics
    ///
    /// Panics if the network has no layers, or if a tensor cannot be built.
    pub fn forward(&mut self, input: &Tensor, dt: f32, training: bool) -> Tensor {
        for layer in &mut self.layers {
            layer.reset();
        }

        let batch_size = input.dims()[0];
        // Read right after `reset`, when the state holds exactly one entry per neuron.
        let output_size = match self.layers.last() {
            Some(last) => last.neurons.membrane_potential.len(),
            None => panic!("SpikingNeuralNetwork::forward: the network has no layers"),
        };
        let (first, rest) = self
            .layers
            .split_first_mut()
            .expect("layers is non-empty: checked above");

        let mut spike_counts = vec![0.0f32; batch_size * output_size];
        for _ in 0..self.num_timesteps {
            // Feed the input straight into the first layer instead of cloning it each step.
            let mut spikes = first.forward(input, dt, training);
            for layer in rest.iter_mut() {
                spikes = layer.forward(&spikes, dt, training);
            }
            let spike_data = spikes.data_f32();
            for (count, &s) in spike_counts.iter_mut().zip(spike_data.iter()) {
                *count += s;
            }
        }
        // `spikes` is scoped to the loop body, so the output shape is spelled out here.
        Tensor::from_slice(&spike_counts, &[batch_size, output_size]).unwrap()
    }
}
/// A fully connected layer whose weights are learned with pair-based spike-timing-dependent
/// plasticity (STDP).
#[derive(Debug, Clone)]
pub struct STDPLayer {
    /// `weights[j][i]` connects input `i` to output `j`.
    ///
    /// Initialised in `[0, 0.1)`. Potentiation caps a weight at `1.0` and depression
    /// floors it at `0.0`.
    weights: Vec<Vec<f32>>,
    /// Time of the most recent spike of each input, `None` until it first spikes.
    pre_spike_times: Vec<Option<f32>>,
    /// Time of the most recent spike of each output, `None` until it first spikes.
    post_spike_times: Vec<Option<f32>>,
    /// Potentiation amplitude.
    a_plus: f32,
    /// Depression amplitude.
    a_minus: f32,
    /// Decay constant of the potentiation curve.
    tau_plus: f32,
    /// Decay constant of the depression curve.
    tau_minus: f32,
    input_size: usize,
    output_size: usize,
}

impl STDPLayer {
    /// Spike-time differences at or beyond this value produce no weight change.
    const STDP_WINDOW: f32 = 100.0;

    /// Creates a layer with weights drawn uniformly from `[0, 0.1)`.
    pub fn new(input_size: usize, output_size: usize) -> Self {
        use rand::prelude::*;
        let mut rng = thread_rng();
        let weights = (0..output_size)
            .map(|_| {
                (0..input_size)
                    .map(|_| rng.gen::<f32>() * 0.1)
                    .collect::<Vec<f32>>()
            })
            .collect();
        STDPLayer {
            weights,
            pre_spike_times: vec![None; input_size],
            post_spike_times: vec![None; output_size],
            a_plus: 0.01,
            a_minus: 0.01,
            tau_plus: 20.0,
            tau_minus: 20.0,
            input_size,
            output_size,
        }
    }

    /// Computes `output[b][j] = Σ_i weights[j][i] * input_spikes[b][i]`.
    ///
    /// Returns a `[batch, output_size]` tensor.
    ///
    /// # Panics
    ///
    /// Panics if `input_spikes` does not hold `batch * input_size` values, or if the
    /// output tensor cannot be built.
    pub fn forward(&self, input_spikes: &Tensor) -> Tensor {
        let batch_size = input_spikes.dims()[0];
        let spike_data = input_spikes.data_f32();
        assert_eq!(
            spike_data.len(),
            batch_size * self.input_size,
            "STDPLayer::forward: input must have shape [batch, input_size]"
        );
        let mut output = Vec::with_capacity(batch_size * self.output_size);
        for b in 0..batch_size {
            let sample = &spike_data[b * self.input_size..(b + 1) * self.input_size];
            for row in &self.weights {
                output.push(row.iter().zip(sample).fold(0.0f32, |acc, (w, s)| acc + w * s));
            }
        }
        Tensor::from_slice(&output, &[batch_size, self.output_size]).unwrap()
    }

    /// Records this step's spikes (values `> 0.5`) at `current_time` and applies STDP.
    ///
    /// - Potentiation: a spiking output strengthens its weights from inputs that spiked
    ///   earlier.
    /// - Depression: a spiking input weakens its weights to outputs that spiked earlier.
    ///
    /// Only time differences in `(0, 100)` count. Neurons that have never spiked are
    /// ignored rather than treated as having spiked at time zero.
    ///
    /// # Panics
    ///
    /// Panics if `pre_spikes` is not `input_size` long or `post_spikes` is not
    /// `output_size` long.
    pub fn update_weights(&mut self, pre_spikes: &[f32], post_spikes: &[f32], current_time: f32) {
        assert_eq!(
            pre_spikes.len(),
            self.input_size,
            "STDPLayer::update_weights: pre_spikes must have one entry per input"
        );
        assert_eq!(
            post_spikes.len(),
            self.output_size,
            "STDPLayer::update_weights: post_spikes must have one entry per output"
        );

        for (last, &spike) in self.pre_spike_times.iter_mut().zip(pre_spikes) {
            if spike > 0.5 {
                *last = Some(current_time);
            }
        }
        for (last, &spike) in self.post_spike_times.iter_mut().zip(post_spikes) {
            if spike > 0.5 {
                *last = Some(current_time);
            }
        }

        // Potentiation: an input fired before an output that fires now.
        for (row, &post) in self.weights.iter_mut().zip(post_spikes) {
            if post > 0.5 {
                for (w, last_pre) in row.iter_mut().zip(&self.pre_spike_times) {
                    if let Some(t_pre) = *last_pre {
                        let dt = current_time - t_pre;
                        if dt > 0.0 && dt < Self::STDP_WINDOW {
                            let dw = self.a_plus * (-dt / self.tau_plus).exp();
                            *w = (*w + dw).min(1.0);
                        }
                    }
                }
            }
        }

        // Depression: an output fired before an input that fires now.
        for (row, last_post) in self.weights.iter_mut().zip(&self.post_spike_times) {
            if let Some(t_post) = *last_post {
                let dt = current_time - t_post;
                if dt > 0.0 && dt < Self::STDP_WINDOW {
                    let dw = -self.a_minus * (-dt / self.tau_minus).exp();
                    for (w, &pre) in row.iter_mut().zip(pre_spikes) {
                        if pre > 0.5 {
                            *w = (*w + dw).max(0.0);
                        }
                    }
                }
            }
        }
    }
}
/// A liquid state machine: a randomly connected recurrent reservoir of LIF neurons driven
/// by the input, read out by a dense layer.
pub struct LiquidStateMachine {
    /// One single-neuron LIF unit per reservoir neuron.
    reservoir: Vec<LIFNeuron>,
    /// Recurrent weights: `connections[n][m]` scales neuron `m`'s spike into neuron `n`.
    connections: Vec<Vec<f32>>,
    /// Input weights: `input_weights[n][i]` scales input feature `i` into neuron `n`.
    input_weights: Vec<Vec<f32>>,
    /// Maps the final reservoir spike pattern to the output.
    readout: Dense,
    input_size: usize,
    num_neurons: usize,
}

impl LiquidStateMachine {
    /// Creates a reservoir of `reservoir_size` LIF neurons (threshold 1.0, decay 0.1).
    ///
    /// Each recurrent connection exists with probability `connectivity` and has a weight
    /// in `[-0.1, 0.1)`. Every input feature is connected to every neuron.
    pub fn new(input_size: usize, reservoir_size: usize, output_size: usize, connectivity: f32) -> Self {
        use rand::prelude::*;
        let mut rng = thread_rng();
        let reservoir = (0..reservoir_size)
            .map(|_| LIFNeuron::new(1, 1.0, 0.1))
            .collect();

        let mut connections = vec![vec![0.0f32; reservoir_size]; reservoir_size];
        for row in connections.iter_mut() {
            for weight in row.iter_mut() {
                if rng.gen::<f32>() < connectivity {
                    *weight = rng.gen::<f32>() * 0.2 - 0.1;
                }
            }
        }

        // Without an input projection the reservoir would never see the input.
        // NOTE(review): uniform [-1, 1) is a heuristic scale — tune against threshold/decay.
        let input_weights = (0..reservoir_size)
            .map(|_| {
                (0..input_size)
                    .map(|_| rng.gen::<f32>() * 2.0 - 1.0)
                    .collect::<Vec<f32>>()
            })
            .collect();

        LiquidStateMachine {
            reservoir,
            connections,
            input_weights,
            readout: Dense::new(reservoir_size, output_size),
            input_size,
            num_neurons: reservoir_size,
        }
    }

    /// Simulates each sample for `num_steps` steps of length `dt` and reads out the final
    /// reservoir spike pattern.
    ///
    /// `input` is read as `[batch, input_size]` and held constant during a sample's
    /// simulation. Before each sample, the reservoir is reset to rest. Each step, every
    /// neuron receives its input drive plus recurrent input from the previous step's
    /// spikes.
    ///
    /// Returns the readout's output for the `[batch, reservoir_size]` final spike states.
    ///
    /// # Panics
    ///
    /// Panics if `input` does not hold `batch * input_size` values, or if a tensor cannot
    /// be built.
    pub fn forward(&mut self, input: &Tensor, num_steps: usize, dt: f32, training: bool) -> Tensor {
        let batch_size = input.dims()[0];
        let input_data = input.data_f32();
        assert_eq!(
            input_data.len(),
            batch_size * self.input_size,
            "LiquidStateMachine::forward: input must have shape [batch, input_size]"
        );

        let n = self.num_neurons;
        let mut states = vec![0.0f32; batch_size * n];
        let mut previous = vec![0.0f32; n];
        let mut current = vec![0.0f32; n];
        for b in 0..batch_size {
            let sample = &input_data[b * self.input_size..(b + 1) * self.input_size];
            let drive: Vec<f32> = self
                .input_weights
                .iter()
                .map(|row| row.iter().zip(sample).map(|(w, x)| w * x).sum::<f32>())
                .collect();

            // Each sample starts from rest so samples cannot influence each other.
            for neuron in &mut self.reservoir {
                neuron.reset();
            }
            previous.fill(0.0);

            for _ in 0..num_steps {
                // Synchronous update: every neuron sees only the previous step's spikes.
                for (idx, next) in current.iter_mut().enumerate() {
                    let recurrent: f32 = self.connections[idx]
                        .iter()
                        .zip(&previous)
                        .map(|(w, s)| w * s)
                        .sum();
                    let input_tensor =
                        Tensor::from_slice(&[drive[idx] + recurrent], &[1, 1]).unwrap();
                    let spike = self.reservoir[idx].forward(&input_tensor, dt);
                    *next = spike.data_f32()[0];
                }
                std::mem::swap(&mut previous, &mut current);
            }
            states[b * n..(b + 1) * n].copy_from_slice(&previous);
        }

        let state_tensor = Tensor::from_slice(&states, &[batch_size, n]).unwrap();
        self.readout.forward(&state_tensor, training)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One LIF step keeps the `[batch, neurons]` shape of its input.
    #[test]
    fn test_lif_neuron() {
        let mut lif = LIFNeuron::new(10, 1.0, 0.1);
        let currents = Tensor::from_slice(&[0.5f32; 10], &[1, 10]).unwrap();
        let fired = lif.forward(&currents, 0.1);
        assert_eq!(fired.dims(), &[1, 10]);
    }

    /// A 784-256-10 network produces one spike count per output neuron.
    #[test]
    fn test_spiking_network() {
        let mut network = SpikingNeuralNetwork::new(vec![784, 256, 10], 1.0, 0.1, 100);
        let pixels = Tensor::from_slice(&[0.5f32; 784], &[1, 784]).unwrap();
        let counts = network.forward(&pixels, 0.1, false);
        assert_eq!(counts.dims()[1], 10);
    }

    /// One Izhikevich step keeps the `[batch, neurons]` shape of its input.
    #[test]
    fn test_izhikevich_neuron() {
        let mut izh = IzhikevichNeuron::new(10, 0.02, 0.2, -65.0, 8.0);
        let currents = Tensor::from_slice(&[10.0f32; 10], &[1, 10]).unwrap();
        let fired = izh.forward(&currents, 0.5);
        assert_eq!(fired.dims(), &[1, 10]);
    }
}