use trustformers_core::{
errors::{Result, TrustformersError},
layers::{LayerNorm, Linear},
tensor::{DType, Tensor},
traits::Layer,
};
use super::{
config::{BiologicalConfig, NeuronModel, PlasticityType},
model::BiologicalModelOutput,
};
/// Dynamic per-neuron state for a `SpikingLayer`.
///
/// All tensors are allocated as `[batch, neurons_per_layer]` by
/// `SpikingLayer::init_states`.
#[derive(Debug, Clone)]
pub struct NeuronState {
/// Membrane potential.
pub v_mem: Tensor,
/// Izhikevich recovery variable; `Some` only for `NeuronModel::Izhikevich`.
pub u_recovery: Option<Tensor>,
/// Adaptation current; `Some` only for `NeuronModel::AdaptiveExponentialIF`.
pub adaptation: Option<Tensor>,
/// Remaining refractory time per neuron, counted down by `config.dt`.
pub refractory_time: Tensor,
/// 0/1 spike indicators from the most recent timestep.
pub spikes: Tensor,
}
/// Recurrent synapse state for a `SpikingLayer`.
#[derive(Debug, Clone)]
pub struct SynapticState {
/// Recurrent weight matrix; allocated `[neurons, neurons]` in `init_states`.
pub weights: Tensor,
/// Pre-synaptic spike traces, `[batch, neurons]`.
pub pre_traces: Tensor,
/// Post-synaptic spike traces, `[batch, neurons]`.
pub post_traces: Tensor,
/// Eligibility traces, `[neurons, neurons]`; zeroed at init/reset and not
/// written by any plasticity rule visible in this file.
pub eligibility: Tensor,
}
/// A recurrent layer of spiking neurons with simple on-line plasticity.
#[derive(Debug)]
pub struct SpikingLayer {
/// Layer hyperparameters (cloned from the model config).
pub config: BiologicalConfig,
/// Projects `d_model` inputs to per-neuron input currents.
pub input_projection: Linear,
/// Neuron-to-neuron recurrent projection (constructed without bias).
pub recurrent_projection: Linear,
/// Lazily allocated on the first `forward` call.
pub neuron_states: Option<NeuronState>,
/// Lazily allocated alongside `neuron_states` in `init_states`.
pub synaptic_states: Option<SynapticState>,
/// Normalizes the spike output before it is returned.
pub layer_norm: LayerNorm,
}
impl SpikingLayer {
/// Creates a fresh layer from the shared biological config; neuron and
/// synaptic state remain unallocated until the first forward pass.
pub fn new(config: &BiologicalConfig) -> Result<Self> {
    let neurons = config.neurons_per_layer;
    Ok(Self {
        config: config.clone(),
        input_projection: Linear::new(config.d_model, neurons, config.use_bias),
        recurrent_projection: Linear::new(neurons, neurons, false),
        neuron_states: None,
        synaptic_states: None,
        layer_norm: LayerNorm::new(vec![neurons], 1e-12)?,
    })
}
/// Allocates fresh neuron and synaptic state tensors for a batch of the
/// given size, replacing any previously held state. Learned weights are
/// re-initialized from a scaled normal distribution.
pub fn init_states(&mut self, batch_size: usize) -> Result<()> {
    let n = self.config.neurons_per_layer;

    // Model-specific auxiliary variables: only Izhikevich neurons carry a
    // recovery variable, only AdExp neurons carry an adaptation current.
    let u_recovery = matches!(self.config.neuron_model, NeuronModel::Izhikevich)
        .then(|| Tensor::zeros(&[batch_size, n]))
        .transpose()?;
    let adaptation = matches!(self.config.neuron_model, NeuronModel::AdaptiveExponentialIF)
        .then(|| Tensor::zeros(&[batch_size, n]))
        .transpose()?;

    self.neuron_states = Some(NeuronState {
        v_mem: Tensor::zeros(&[batch_size, n])?,
        u_recovery,
        adaptation,
        refractory_time: Tensor::zeros(&[batch_size, n])?,
        spikes: Tensor::zeros(&[batch_size, n])?,
    });

    // Recurrent synapses: small random weights, zeroed traces/eligibility.
    self.synaptic_states = Some(SynapticState {
        weights: Tensor::randn(&[n, n])?.scalar_mul(self.config.initializer_range)?,
        pre_traces: Tensor::zeros(&[batch_size, n])?,
        post_traces: Tensor::zeros(&[batch_size, n])?,
        eligibility: Tensor::zeros(&[n, n])?,
    });
    Ok(())
}
pub fn forward(&mut self, input: &Tensor) -> Result<Tensor> {
let batch_size = input.shape()[0];
let seq_len = input.shape()[1];
if self.neuron_states.is_none() {
self.init_states(batch_size)?;
}
let mut outputs = Vec::new();
for t in 0..seq_len {
let input_t = input.slice(1, t, t + 1)?.squeeze(1)?;
let output_t = self.forward_timestep(&input_t)?;
outputs.push(output_t);
}
let mut output = outputs[0].clone();
for i in 1..outputs.len() {
output = Tensor::concat(&[output, outputs[i].clone()], 1)?;
}
Ok(output)
}
/// Advances the layer by one timestep: integrates input plus recurrent
/// current with a simplified LIF update, applies a Hebbian-style weight
/// update, and returns the layer-normalized spike vector.
fn forward_timestep(&mut self, input: &Tensor) -> Result<Tensor> {
    // Project the external input into per-neuron current space.
    let input_current = self.input_projection.forward(input.clone())?;
    // Recurrent current driven by the spikes from the previous timestep.
    let recurrent_current = {
        let neuron_states = self.neuron_states.as_ref().ok_or_else(|| {
            TrustformersError::runtime_error(
                "Neuron states not initialized in forward_timestep".to_string(),
            )
        })?;
        self.recurrent_projection.forward(neuron_states.spikes.clone())?
    };
    let total_current = input_current.add(&recurrent_current)?;
    {
        let neuron_states = self.neuron_states.as_mut().ok_or_else(|| {
            TrustformersError::runtime_error(
                "Neuron states not initialized in forward_timestep".to_string(),
            )
        })?;
        // Simplified leaky-integrate-and-fire update. NOTE(review): dt, tau
        // and the unit threshold are hard-coded here and shadow the values
        // in `self.config` used by `update_lif_dynamics` — confirm which
        // path is intended.
        let dt = 0.001;
        let tau = 0.02;
        let decay = neuron_states.v_mem.mul_scalar(1.0 - dt / tau)?;
        let input_term = total_current.mul_scalar(dt / tau)?;
        neuron_states.v_mem = decay.add(&input_term)?;
        let threshold = Tensor::full(1.0, neuron_states.v_mem.shape())?;
        let spike_mask = neuron_states.v_mem.greater(&threshold)?;
        neuron_states.spikes = spike_mask.clone();
        // Soft reset: subtract the (unit) threshold from neurons that spiked.
        let reset_mask = spike_mask.mul_scalar(-1.0)?;
        neuron_states.v_mem = neuron_states.v_mem.add(&reset_mask)?;
    }
    {
        let neuron_states = self.neuron_states.as_mut().ok_or_else(|| {
            TrustformersError::runtime_error(
                "Neuron states not initialized in forward_timestep".to_string(),
            )
        })?;
        let synaptic_states = self.synaptic_states.as_mut().ok_or_else(|| {
            TrustformersError::runtime_error(
                "Synaptic states not initialized in forward_timestep".to_string(),
            )
        })?;
        let learning_rate = 0.001;
        let pre_spike_trace = neuron_states.spikes.mul_scalar(0.1)?;
        let post_spike_trace = neuron_states.spikes.mul_scalar(0.1)?;
        // BUG FIX: the previous code computed pre.matmul(post^T), which for
        // [batch, n] traces yields a [batch, batch] matrix and cannot be
        // added to the [n, n] weight matrix unless batch == n. The Hebbian
        // outer-product accumulated over the batch is pre^T.matmul(post),
        // giving [n, n].
        let weight_update = pre_spike_trace
            .transpose(0, 1)?
            .matmul(&post_spike_trace)?
            .mul_scalar(learning_rate)?;
        synaptic_states.weights = synaptic_states.weights.add(&weight_update)?;
        // Exponentially decay the running spike traces.
        synaptic_states.pre_traces = synaptic_states.pre_traces.mul_scalar(0.99)?;
        synaptic_states.post_traces = synaptic_states.post_traces.mul_scalar(0.99)?;
    }
    let output = {
        let neuron_states = self.neuron_states.as_ref().ok_or_else(|| {
            TrustformersError::runtime_error(
                "Neuron states not initialized in forward_timestep".to_string(),
            )
        })?;
        self.layer_norm.forward(neuron_states.spikes.clone())?
    };
    Ok(output)
}
/// Leaky-integrate-and-fire membrane update using the configured time
/// constants, with additive Gaussian noise, hard reset on spike, and
/// refractory-period bookkeeping.
///
/// NOTE(review): `refractory_time` is tracked but does not gate
/// integration — neurons keep accumulating input while refractory;
/// confirm whether that is intended.
#[allow(dead_code)]
fn update_lif_dynamics(&self, current: &Tensor, states: &mut NeuronState) -> Result<()> {
    let dt = self.config.dt;
    let tau_mem = self.config.tau_mem;
    let v_threshold = self.config.v_threshold;
    let v_reset = self.config.v_reset;
    // Exponential-Euler leak factor for one step of size dt.
    let decay = (-dt / tau_mem).exp();
    // FIX: the input-current term was corrupted by HTML-entity mangling
    // ("&curren" -> "¤"); restore the intended borrow of `current`.
    states.v_mem = states.v_mem.mul_scalar(decay)?.add(&current.mul_scalar(dt / tau_mem)?)?;
    // Additive membrane noise scaled by the configured variance.
    let noise = Tensor::randn_like(&states.v_mem)?.scalar_mul(self.config.noise_variance)?;
    states.v_mem = states.v_mem.add(&noise)?;
    // Spike wherever v_mem exceeds threshold; spikes stored as 0/1 floats.
    let threshold_tensor = Tensor::full(v_threshold, states.v_mem.shape())?;
    let spike_mask = states.v_mem.greater(&threshold_tensor)?;
    states.spikes = spike_mask.to_dtype(DType::F32)?;
    // Hard reset: spiking neurons jump to v_reset, others keep their voltage.
    let reset_tensor = Tensor::full(v_reset, states.v_mem.shape())?;
    let one_tensor = Tensor::ones_like(&states.spikes)?;
    let inverse_spikes = one_tensor.sub(&states.spikes)?;
    states.v_mem =
        states.v_mem.mul(&inverse_spikes)?.add(&reset_tensor.mul(&states.spikes)?)?;
    // Count down refractory timers (clamped at zero), then arm a fresh
    // refractory window for neurons that just spiked.
    states.refractory_time = states.refractory_time.sub_scalar(dt)?;
    states.refractory_time = states.refractory_time.clamp(0.0, f32::INFINITY)?;
    let new_refractory = states.spikes.mul_scalar(self.config.refractory_period)?;
    states.refractory_time = states.refractory_time.add(&new_refractory)?;
    Ok(())
}
/// Euler-integrated two-variable quadratic neuron update in the style of
/// Izhikevich, with a fixed parameter set.
///
/// NOTE(review): `a`, `b`, `c`, `d` are hard-coded rather than read from
/// `self.config`, and `c` is used as the reset value instead of
/// `config.v_reset` (held unused in `_v_reset`) — confirm intent.
#[allow(dead_code)]
fn update_izhikevich_dynamics(&self, current: &Tensor, states: &mut NeuronState) -> Result<()> {
let dt = self.config.dt;
let v_threshold = self.config.v_threshold;
let _v_reset = self.config.v_reset;
// Fixed model parameters: recovery rate, sensitivity, reset value, recovery bump.
let a = 0.02; let b = 0.2; let c = -65.0; let d = 2.0;
let u_recovery = states.u_recovery.as_mut().ok_or_else(|| {
TrustformersError::runtime_error(
"Recovery variable not initialized for Izhikevich model".to_string(),
)
})?;
// dv/dt = 0.04 v^2 + 5 v + 140 - u + I
let v_squared = states.v_mem.pow_scalar(2.0)?;
let dv_dt = v_squared
.mul_scalar(0.04)?
.add(&states.v_mem.mul_scalar(5.0)?)?
.add_scalar(140.0)?
.sub(u_recovery)?
.add(current)?;
states.v_mem = states.v_mem.add(&dv_dt.mul_scalar(dt)?)?;
// du/dt = a (b v - u); note this uses v_mem *after* the voltage update above.
let du_dt = states.v_mem.mul_scalar(b)?.sub(u_recovery)?.mul_scalar(a)?;
*u_recovery = u_recovery.add(&du_dt.mul_scalar(dt)?)?;
// Spike detection against the configured threshold; spikes kept as 0/1 floats.
let threshold_tensor = Tensor::full(v_threshold, states.v_mem.shape())?;
let spike_mask = states.v_mem.greater(&threshold_tensor)?;
states.spikes = spike_mask.to_dtype(DType::F32)?;
// Spiking neurons reset v to c; non-spiking neurons keep their voltage.
let reset_tensor = Tensor::full(c, states.v_mem.shape())?;
let one_tensor = Tensor::ones_like(&states.spikes)?;
let inverse_spikes = one_tensor.sub(&states.spikes)?;
states.v_mem =
states.v_mem.mul(&inverse_spikes)?.add(&reset_tensor.mul(&states.spikes)?)?;
// Spiking neurons also bump the recovery variable by d.
let u_increment = u_recovery.add_scalar(d)?;
*u_recovery = u_recovery.mul(&inverse_spikes)?.add(&u_increment.mul(&states.spikes)?)?;
Ok(())
}
/// Hodgkin–Huxley placeholder: full channel dynamics are not implemented,
/// so this currently delegates to the LIF update.
#[allow(dead_code)]
fn update_hh_dynamics(&self, current: &Tensor, states: &mut NeuronState) -> Result<()> {
self.update_lif_dynamics(current, states)
}
/// Adaptive exponential integrate-and-fire (AdExp) update with hard-coded
/// exponential-term parameters (`delta_t`, `v_t`) and adaptation
/// parameters (`tau_w`, `a`, `b`).
///
/// NOTE(review): with `b = 0.0` the spike-triggered adaptation increment
/// is a no-op; confirm whether a non-zero value was intended.
#[allow(dead_code)]
fn update_adexp_dynamics(&self, current: &Tensor, states: &mut NeuronState) -> Result<()> {
    let dt = self.config.dt;
    let tau_mem = self.config.tau_mem;
    let v_threshold = self.config.v_threshold;
    let v_reset = self.config.v_reset;
    // Exponential sharpness, soft threshold, adaptation tau/gain/spike bump.
    let delta_t = 2.0;
    let v_t = -50.0;
    let tau_w = 30.0;
    let a = 2.0;
    let b = 0.0;
    let adaptation = states.adaptation.as_mut().ok_or_else(|| {
        TrustformersError::runtime_error(
            "Adaptation current not initialized for AdExp model".to_string(),
        )
    })?;
    // Exponential spike-initiation current: delta_t * exp((v - v_t) / delta_t).
    let exp_term = states.v_mem.sub_scalar(v_t)?.div_scalar(delta_t)?.exp()?;
    let exp_current = exp_term.mul_scalar(delta_t)?;
    // dv/dt = (-v + exp_current - w + I) / tau_mem
    // FIX: the input-current term was corrupted by HTML-entity mangling
    // ("&curren" -> "¤"); restore the intended borrow of `current`.
    let dv_dt = states
        .v_mem
        .mul_scalar(-1.0 / tau_mem)?
        .add(&exp_current.mul_scalar(1.0 / tau_mem)?)?
        .sub(&adaptation.mul_scalar(1.0 / tau_mem)?)?
        .add(&current.mul_scalar(1.0 / tau_mem)?)?;
    states.v_mem = states.v_mem.add(&dv_dt.mul_scalar(dt)?)?;
    // dw/dt = (a * v - w) / tau_w; uses v_mem after the voltage update above.
    let dw_dt = states.v_mem.mul_scalar(a)?.sub(adaptation)?.mul_scalar(1.0 / tau_w)?;
    *adaptation = adaptation.add(&dw_dt.mul_scalar(dt)?)?;
    // Spike detection and hard reset to v_reset for spiking neurons.
    let threshold_tensor = Tensor::full(v_threshold, states.v_mem.shape())?;
    let spike_mask = states.v_mem.greater(&threshold_tensor)?;
    states.spikes = spike_mask.to_dtype(DType::F32)?;
    let one_tensor = Tensor::ones_like(&states.spikes)?;
    let inverse_spikes = one_tensor.sub(&states.spikes)?;
    let reset_tensor = Tensor::full(v_reset, states.v_mem.shape())?;
    states.v_mem =
        states.v_mem.mul(&inverse_spikes)?.add(&reset_tensor.mul(&states.spikes)?)?;
    // Spike-triggered adaptation: w <- w + b on spiking neurons only.
    let adaptation_increment = adaptation.add_scalar(b)?;
    *adaptation = adaptation
        .mul(&inverse_spikes)?
        .add(&adaptation_increment.mul(&states.spikes)?)?;
    Ok(())
}
/// Spike-response-model placeholder: not implemented separately, so this
/// currently delegates to the LIF update.
#[allow(dead_code)]
fn update_srm_dynamics(&self, current: &Tensor, states: &mut NeuronState) -> Result<()> {
self.update_lif_dynamics(current, states)
}
/// Decays and refreshes the pre/post spike traces, then dispatches to the
/// weight-update rule selected by `config.plasticity_type`.
#[allow(dead_code)]
fn update_plasticity(
    &self,
    neuron_states: &mut NeuronState,
    synaptic_states: &mut SynapticState,
) -> Result<()> {
    let dt = self.config.dt;
    let lr = self.config.learning_rate;
    // Exponential decay factor for one step of the synaptic trace.
    let trace_decay = (-dt / self.config.tau_syn).exp();
    // Decay both traces, then deposit the current spikes onto them.
    let decayed_pre = synaptic_states.pre_traces.mul_scalar(trace_decay)?;
    let decayed_post = synaptic_states.post_traces.mul_scalar(trace_decay)?;
    synaptic_states.pre_traces = decayed_pre.add(&neuron_states.spikes)?;
    synaptic_states.post_traces = decayed_post.add(&neuron_states.spikes)?;
    // Dispatch on the configured plasticity rule.
    match self.config.plasticity_type {
        PlasticityType::STDP => self.update_stdp_weights(neuron_states, synaptic_states, lr),
        PlasticityType::Hebbian => {
            self.update_hebbian_weights(neuron_states, synaptic_states, lr)
        },
        PlasticityType::AntiHebbian => {
            self.update_anti_hebbian_weights(neuron_states, synaptic_states, lr)
        },
        PlasticityType::Homeostatic => {
            self.update_homeostatic_weights(neuron_states, synaptic_states, lr)
        },
        PlasticityType::Metaplasticity => {
            self.update_metaplasticity_weights(neuron_states, synaptic_states, lr)
        },
    }
}
/// Spike-timing-dependent plasticity: potentiate pre→post pairings (LTP)
/// and depress post→pre pairings (LTD) using the exponential traces.
///
/// NOTE(review): `spikes` and the traces are allocated as
/// `[batch, neurons]` in `init_states`, while the unsqueeze/matmul pattern
/// below reads like a 1-D outer product producing `[n, n]` — confirm the
/// intended tensor rank.
#[allow(dead_code)]
fn update_stdp_weights(
    &self,
    neuron_states: &NeuronState,
    synaptic_states: &mut SynapticState,
    lr: f32,
) -> Result<()> {
    // LTP: post-synaptic spike paired with the pre-synaptic trace.
    let ltp = neuron_states
        .spikes
        .unsqueeze(1)?
        .matmul(&synaptic_states.pre_traces.unsqueeze(0)?)?;
    // LTD: post-synaptic trace paired with the pre-synaptic spike.
    let ltd = synaptic_states
        .post_traces
        .unsqueeze(1)?
        .matmul(&neuron_states.spikes.unsqueeze(0)?)?;
    // FIX: the subtraction operand was corrupted by HTML-entity mangling
    // ("&lt" -> "<"); restore the intended borrow of `ltd`.
    let weight_update = ltp.sub(&ltd)?.mul_scalar(lr)?;
    synaptic_states.weights = synaptic_states.weights.add(&weight_update)?;
    Ok(())
}
/// Hebbian rule: strengthen weights in proportion to the outer product of
/// the current spike vector with itself ("fire together, wire together").
#[allow(dead_code)]
fn update_hebbian_weights(
    &self,
    neuron_states: &NeuronState,
    synaptic_states: &mut SynapticState,
    lr: f32,
) -> Result<()> {
    let spikes = &neuron_states.spikes;
    let coactivity = spikes.unsqueeze(1)?.matmul(&spikes.unsqueeze(0)?)?;
    synaptic_states.weights = synaptic_states.weights.add(&coactivity.mul_scalar(lr)?)?;
    Ok(())
}
/// Anti-Hebbian rule: same coactivity outer product as the Hebbian rule,
/// but applied with a negated learning rate so correlated firing weakens
/// the connection.
#[allow(dead_code)]
fn update_anti_hebbian_weights(
    &self,
    neuron_states: &NeuronState,
    synaptic_states: &mut SynapticState,
    lr: f32,
) -> Result<()> {
    let spikes = &neuron_states.spikes;
    let coactivity = spikes.unsqueeze(1)?.matmul(&spikes.unsqueeze(0)?)?;
    synaptic_states.weights = synaptic_states.weights.add(&coactivity.mul_scalar(-lr)?)?;
    Ok(())
}
/// Homeostatic scaling: nudge every weight opposite to the deviation of
/// the mean firing rate from `config.target_rate`.
#[allow(dead_code)]
fn update_homeostatic_weights(
    &self,
    neuron_states: &NeuronState,
    synaptic_states: &mut SynapticState,
    lr: f32,
) -> Result<()> {
    // Scalar correction: -lr * (mean_rate - target_rate).
    let rate_error = neuron_states.spikes.mean()?.sub_scalar(self.config.target_rate)?;
    let delta = rate_error.mul_scalar(-lr)?.to_scalar()?;
    synaptic_states.weights = synaptic_states.weights.add_scalar(delta)?;
    Ok(())
}
/// Metaplasticity: a fixed 80/20 blend of the STDP and homeostatic rules.
#[allow(dead_code)]
fn update_metaplasticity_weights(
    &self,
    neuron_states: &NeuronState,
    synaptic_states: &mut SynapticState,
    lr: f32,
) -> Result<()> {
    const STDP_SHARE: f32 = 0.8;
    const HOMEOSTATIC_SHARE: f32 = 0.2;
    self.update_stdp_weights(neuron_states, synaptic_states, lr * STDP_SHARE)?;
    self.update_homeostatic_weights(neuron_states, synaptic_states, lr * HOMEOSTATIC_SHARE)
}
/// Zeroes all dynamic state (membrane, timers, spikes, traces) while
/// keeping the learned synaptic weights intact.
pub fn reset_states(&mut self) -> Result<()> {
    if let Some(ns) = self.neuron_states.as_mut() {
        ns.v_mem = Tensor::zeros_like(&ns.v_mem)?;
        ns.refractory_time = Tensor::zeros_like(&ns.refractory_time)?;
        ns.spikes = Tensor::zeros_like(&ns.spikes)?;
        if let Some(u) = ns.u_recovery.as_mut() {
            *u = Tensor::zeros_like(u)?;
        }
        if let Some(w) = ns.adaptation.as_mut() {
            *w = Tensor::zeros_like(w)?;
        }
    }
    if let Some(ss) = self.synaptic_states.as_mut() {
        // Weights are deliberately preserved; only traces are cleared.
        ss.pre_traces = Tensor::zeros_like(&ss.pre_traces)?;
        ss.post_traces = Tensor::zeros_like(&ss.post_traces)?;
        ss.eligibility = Tensor::zeros_like(&ss.eligibility)?;
    }
    Ok(())
}
/// Total learnable parameters across both projections and the layer norm.
pub fn parameter_count(&self) -> usize {
    let counts = [
        self.input_projection.parameter_count(),
        self.recurrent_projection.parameter_count(),
        self.layer_norm.parameter_count(),
    ];
    counts.iter().sum()
}
/// Rough memory-footprint estimate in megabytes, assuming f32 storage.
///
/// Covers the learnable parameters plus, once states are allocated, the
/// per-neuron state buffers and the n×n synaptic weight/eligibility
/// matrices. The previous version omitted the synaptic matrices, which
/// dominate memory for large layers. Batch size is not accounted for.
pub fn memory_usage(&self) -> f32 {
    const BYTES_PER_F32: f32 = 4.0;
    const BYTES_PER_MB: f32 = 1_000_000.0;
    let n = self.config.neurons_per_layer as f32;
    let param_memory = self.parameter_count() as f32 * BYTES_PER_F32 / BYTES_PER_MB;
    // v_mem, refractory_time, spikes, plus one optional auxiliary variable.
    let neuron_memory = if self.neuron_states.is_some() {
        n * 4.0 * BYTES_PER_F32 / BYTES_PER_MB
    } else {
        0.0
    };
    // weights + eligibility are n*n; pre/post traces are per-neuron.
    let synaptic_memory = if self.synaptic_states.is_some() {
        (2.0 * n * n + 2.0 * n) * BYTES_PER_F32 / BYTES_PER_MB
    } else {
        0.0
    };
    param_memory + neuron_memory + synaptic_memory
}
}
/// A stack of `SpikingLayer`s followed by a linear readout back to
/// `d_model` space.
#[derive(Debug)]
pub struct SpikingNeuralNetwork {
/// Network hyperparameters (cloned from the caller's config).
pub config: BiologicalConfig,
/// `config.n_layer` spiking layers applied in sequence.
pub layers: Vec<SpikingLayer>,
/// Projects `neurons_per_layer` activations back to `d_model`.
pub output_projection: Linear,
}
impl SpikingNeuralNetwork {
/// Builds `n_layer` spiking layers plus a readout projection to `d_model`.
pub fn new(config: &BiologicalConfig) -> Result<Self> {
    let layers = (0..config.n_layer)
        .map(|_| SpikingLayer::new(config))
        .collect::<Result<Vec<_>>>()?;
    let output_projection =
        Linear::new(config.neurons_per_layer, config.d_model, config.use_bias);
    Ok(Self {
        config: config.clone(),
        layers,
        output_projection,
    })
}
/// Runs the input through every spiking layer in sequence, then projects
/// back to model space. Also gathers each layer's spike tensor for
/// inspection.
///
/// Fix: spike tensors are concatenated in a single call instead of the
/// previous pairwise fold, which copied O(k^2) elements for k layers.
pub fn forward(&mut self, input: &Tensor) -> Result<BiologicalModelOutput> {
    let mut hidden_states = input.clone();
    let mut all_spike_trains = Vec::with_capacity(self.layers.len());
    for layer in &mut self.layers {
        hidden_states = layer.forward(&hidden_states)?;
        // NOTE(review): `states.spikes` holds only the spikes from the
        // final timestep of each layer, not a full per-timestep train —
        // confirm this matches the `spike_trains` contract.
        if let Some(states) = &layer.neuron_states {
            all_spike_trains.push(states.spikes.clone());
        }
    }
    let output = self.output_projection.forward(hidden_states)?;
    let spike_trains = match all_spike_trains.len() {
        0 => None,
        1 => Some(all_spike_trains.pop().expect("length checked above")),
        _ => Some(Tensor::concat(&all_spike_trains, 2)?),
    };
    Ok(BiologicalModelOutput {
        hidden_states: output,
        spike_trains,
        memory_states: None,
        attention_weights: None,
        capsule_outputs: None,
        dendritic_activations: None,
        plasticity_traces: None,
    })
}
/// Network-level plasticity hook.
///
/// Currently a no-op placeholder: `_targets` is ignored and the per-layer
/// `SpikingLayer::update_plasticity` rules are not invoked from here.
pub fn update_plasticity(&mut self, _targets: &Tensor) -> Result<()> {
Ok(())
}
pub fn reset_states(&mut self) -> Result<()> {
for layer in &mut self.layers {
layer.reset_states()?;
}
Ok(())
}
pub fn parameter_count(&self) -> usize {
self.layers.iter().map(|l| l.parameter_count()).sum::<usize>()
+ self.output_projection.parameter_count()
}
/// Approximate memory footprint in megabytes, assuming 4-byte f32 storage.
pub fn memory_usage(&self) -> f32 {
    let layer_memory: f32 = self.layers.iter().map(|layer| layer.memory_usage()).sum();
    let readout_memory = self.output_projection.parameter_count() as f32 * 4.0 / 1_000_000.0;
    layer_memory + readout_memory
}
}