use crate::delta::Delta;
use crate::error::Result;
use crate::plasticity_engine::{NeuromodState, PlasticityEngine, PlasticityEngineState};
use serde::{Deserialize, Serialize};
/// Hyper-parameters of an [`EmbeddedSNN`].
///
/// NOTE: this type derives `Serialize`/`Deserialize`; with positional formats (e.g. bincode)
/// the field order is part of the encoding, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddedSNNConfig {
/// Number of neurons (one prototype vector per neuron).
pub num_neurons: usize,
/// Length of each prototype vector. Inputs are compared to prototypes by cosine
/// similarity over the overlapping prefix (extra input elements are ignored).
pub input_dim: usize,
/// Number of outgoing lateral connections per neuron, made to the following neurons in
/// index order (wrapping around); see `EmbeddedSNN::new`.
pub top_k: usize,
/// Step added (multiplied by dopamine) to the weight `i -> j` whenever `i` and `j` spike in
/// the same tick; the result is clamped to `[0, 1]`.
pub stdp_lr: f32,
/// Target firing rate. Every 100th tick each neuron's outgoing weights are scaled by
/// `1 + (target - rate) * 0.01 * serotonin`.
pub homeostasis_target: f32,
/// Rank-based suppression: the neuron at rank `r` (0 = strongest) of `n` is scaled by
/// `1 - (r / n) * competition_strength * norepinephrine`.
pub competition_strength: f32,
/// Every 10th tick each weight is multiplied by `1 - decay_rate * (1 - serotonin)`;
/// weights that fall below 0.01 are pruned.
pub decay_rate: f32,
/// Per-tick trace update: `trace = trace * (1 - trace_decay) + activation`.
pub trace_decay: f32,
/// A neuron spikes when its activation (after competition and lateral spread) is
/// strictly greater than this value.
pub spike_threshold: f32,
}
impl Default for EmbeddedSNNConfig {
fn default() -> Self {
Self {
num_neurons: 100,
input_dim: 2048,
top_k: 10,
stdp_lr: 0.01,
homeostasis_target: 0.1,
competition_strength: 0.5,
decay_rate: 0.001,
trace_decay: 0.1,
spike_threshold: 0.5,
}
}
}
/// Learned, persistable state of an [`EmbeddedSNN`].
///
/// This is what `PlasticityEngine::state`/`restore` serialize with bincode. bincode encodes
/// fields positionally, so reordering or changing these fields invalidates saved states.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SNNState {
// One vector per neuron; the input is matched against it by cosine similarity.
prototypes: Vec<Vec<f32>>,
// Per neuron: outgoing lateral connections as `(target neuron index, weight)`.
// STDP and homeostasis clamp weights to [0, 1]; pruned connections stay listed with 0.0.
weights: Vec<Vec<(usize, f32)>>,
// Per neuron: exponentially decaying sum of past activations.
traces: Vec<f32>,
// Per neuron: exponential moving average of the 0/1 spike indicator.
firing_rates: Vec<f32>,
// Activations computed by the most recent `process` call.
last_activations: Vec<f32>,
}
/// Spiking-network plasticity engine: neurons hold prototype vectors, compete on cosine
/// similarity to each input, and adapt their lateral connections from co-spiking.
pub struct EmbeddedSNN {
    /// Hyper-parameters fixed at construction time.
    config: EmbeddedSNNConfig,
    /// Number of `process` calls so far; drives the periodic homeostasis/decay schedule.
    tick_count: usize,
    /// Latest neuromodulator levels, replaced on every `process` / `sync_neuromod` call.
    neuromod: NeuromodState,
    /// Learned state that is persisted by `state()` / `restore()`.
    state: SNNState,
}
impl EmbeddedSNN {
    /// Weights that decay below this value are pruned (set to 0.0 and reported once).
    const PRUNE_THRESHOLD: f32 = 0.01;
    /// Smoothing factor of the firing-rate exponential moving average.
    const FIRING_RATE_ALPHA: f32 = 0.1;
    /// Base step of the homeostatic weight scaling (further multiplied by serotonin).
    const HOMEOSTASIS_RATE: f32 = 0.01;
    /// Initial weight of every lateral connection created by [`EmbeddedSNN::new`].
    const INITIAL_WEIGHT: f32 = 0.1;

    /// Creates a network with random prototypes and a ring-shaped lateral topology.
    ///
    /// Prototype components are drawn as `(U[0,1) - 0.5) * 0.1`. Neuron `i` connects to
    /// neurons `i+1, i+2, …` (modulo `num_neurons`) with weight `0.1`. The fan-out is
    /// `top_k` capped at `num_neurons - 1`, so a neuron never connects to itself and no
    /// connection is listed twice, even when `top_k >= num_neurons`.
    pub fn new(config: EmbeddedSNNConfig) -> Self {
        let num_neurons = config.num_neurons;

        let prototypes: Vec<Vec<f32>> = (0..num_neurons)
            .map(|_| {
                (0..config.input_dim)
                    .map(|_| (rand::random::<f32>() - 0.5) * 0.1)
                    .collect()
            })
            .collect();

        // Without this cap, `(i + k) % num_neurons` wraps back to `i` (self-loop) and then
        // repeats earlier neighbours once `top_k >= num_neurons`.
        let fan_out = config.top_k.min(num_neurons.saturating_sub(1));
        let weights: Vec<Vec<(usize, f32)>> = (0..num_neurons)
            .map(|i| {
                (1..=fan_out)
                    .map(|k| ((i + k) % num_neurons, Self::INITIAL_WEIGHT))
                    .collect()
            })
            .collect();

        Self {
            config,
            state: SNNState {
                prototypes,
                weights,
                traces: vec![0.0; num_neurons],
                firing_rates: vec![0.0; num_neurons],
                last_activations: vec![0.0; num_neurons],
            },
            neuromod: NeuromodState::baseline(),
            tick_count: 0,
        }
    }

    /// Returns, per neuron, the cosine similarity between `input` and its prototype,
    /// clamped below at 0.
    ///
    /// Only the overlapping prefix of `input` and the prototype is compared. A zero-norm
    /// input or prototype yields activation 0.
    fn compute_activations(&self, input: &[f32]) -> Vec<f32> {
        let mut activations = vec![0.0; self.config.num_neurons];
        for (out, proto) in activations.iter_mut().zip(&self.state.prototypes) {
            let (mut dot, mut norm_input, mut norm_proto) = (0.0f32, 0.0f32, 0.0f32);
            for (&x, &p) in input.iter().zip(proto) {
                dot += x * p;
                norm_input += x * x;
                norm_proto += p * p;
            }
            if norm_input > 0.0 && norm_proto > 0.0 {
                *out = (dot / (norm_input.sqrt() * norm_proto.sqrt())).max(0.0);
            }
        }
        activations
    }

    /// Rank-based lateral inhibition, applied in place.
    ///
    /// The neuron at rank `r` (0 = most active) of `n` is scaled by
    /// `1 - (r / n) * competition_strength * norepinephrine`.
    fn apply_competition(&self, activations: &mut [f32]) {
        let strength = self.config.competition_strength * self.neuromod.norepinephrine;
        let len = activations.len() as f32;

        // Descending by activation. `total_cmp` is a total order, so a NaN activation
        // (e.g. from a NaN neuromodulator level) no longer panics the sort.
        let mut order: Vec<usize> = (0..activations.len()).collect();
        order.sort_by(|&a, &b| activations[b].total_cmp(&activations[a]));

        for (rank, &idx) in order.iter().enumerate() {
            let suppression = (rank as f32 / len) * strength;
            activations[idx] *= 1.0 - suppression;
        }
    }

    /// Adds lateral input to `activations` in place.
    ///
    /// Each neuron's activation is propagated along its outgoing weights. All spread values
    /// are computed from the pre-spread activations and then added scaled by acetylcholine.
    fn spread_activation(&self, activations: &mut [f32]) {
        let mut spread = vec![0.0f32; self.config.num_neurons];
        for (&source, neighbors) in activations.iter().zip(&self.state.weights) {
            for &(target, weight) in neighbors {
                spread[target] += source * weight;
            }
        }
        let gain = self.neuromod.acetylcholine;
        for (act, extra) in activations.iter_mut().zip(&spread) {
            *act += extra * gain;
        }
    }

    /// Returns the indices (ascending) of neurons whose activation strictly exceeds
    /// `spike_threshold`.
    fn detect_spikes(&self, activations: &[f32]) -> Vec<usize> {
        let threshold = self.config.spike_threshold;
        activations
            .iter()
            .enumerate()
            .filter(|&(_, &act)| act > threshold)
            .map(|(i, _)| i)
            .collect()
    }

    /// Strengthens every existing connection `i -> j` where both `i` and `j` spiked.
    ///
    /// Each such weight gains `stdp_lr * dopamine` and is clamped to `[0, 1]`. One `merge`
    /// delta keyed `weight_{i}_{j}` is emitted per touched connection. Pairs without an
    /// existing connection are not created.
    fn apply_stdp(&mut self, spiking: &[usize]) -> Vec<Delta> {
        let mut deltas = Vec::new();
        let lr = self.config.stdp_lr * self.neuromod.dopamine;
        for &i in spiking {
            for &j in spiking {
                if i == j {
                    continue;
                }
                if let Some(conn) = self.state.weights[i].iter_mut().find(|(n, _)| *n == j) {
                    conn.1 = (conn.1 + lr).clamp(0.0, 1.0);
                    deltas.push(Delta::merge(
                        format!("weight_{}_{}", i, j),
                        conn.1.to_le_bytes().to_vec(),
                        "snn_stdp",
                        conn.1,
                        None,
                    ));
                }
            }
        }
        deltas
    }

    /// Scales each neuron's outgoing weights toward the target firing rate.
    ///
    /// Weights are multiplied by `1 + (target - rate) * 0.01 * serotonin` and then clamped
    /// to `[0, 1]`.
    ///
    /// NOTE(review): this adjusts *outgoing* weights, which mainly changes neighbours'
    /// input rather than the neuron's own rate. Confirm this is the intended homeostasis.
    fn apply_homeostasis(&mut self) {
        let target = self.config.homeostasis_target;
        let rate = Self::HOMEOSTASIS_RATE * self.neuromod.serotonin;
        for (firing_rate, neighbors) in self
            .state
            .firing_rates
            .iter()
            .zip(self.state.weights.iter_mut())
        {
            let error = target - *firing_rate;
            for (_, weight) in neighbors.iter_mut() {
                *weight = (*weight * (1.0 + error * rate)).clamp(0.0, 1.0);
            }
        }
    }

    /// Multiplicatively decays all live weights and prunes those below the threshold.
    ///
    /// The decay factor is `decay_rate * (1 - serotonin)`, clamped to `[0, 1]` so that an
    /// out-of-range serotonin level cannot turn decay into growth. A weight that drops below
    /// [`Self::PRUNE_THRESHOLD`] is set to 0.0, and a single `delete` delta is emitted for it.
    /// Already-pruned weights are skipped, so they are not reported again on every pass.
    fn apply_decay(&mut self) -> Vec<Delta> {
        let mut deltas = Vec::new();
        let decay = (self.config.decay_rate * (1.0 - self.neuromod.serotonin)).clamp(0.0, 1.0);
        for (i, neighbors) in self.state.weights.iter_mut().enumerate() {
            for (j, weight) in neighbors.iter_mut() {
                if *weight == 0.0 {
                    // Already pruned (STDP may revive it later); nothing new to report.
                    continue;
                }
                *weight *= 1.0 - decay;
                if *weight < Self::PRUNE_THRESHOLD {
                    deltas.push(Delta::delete(
                        format!("weight_{}_{}", i, j),
                        "snn_decay",
                        None,
                    ));
                    *weight = 0.0;
                }
            }
        }
        deltas
    }

    /// Updates every trace as `trace * (1 - trace_decay) + activation`.
    fn update_traces(&mut self, activations: &[f32]) {
        let keep = 1.0 - self.config.trace_decay;
        for (trace, &act) in self.state.traces.iter_mut().zip(activations) {
            *trace = *trace * keep + act;
        }
    }

    /// Updates each neuron's firing-rate moving average with this tick's 0/1 spike flag.
    fn update_firing_rates(&mut self, spiking: &[usize]) {
        let alpha = Self::FIRING_RATE_ALPHA;
        // A membership mask keeps this O(n + m) instead of calling `contains` per neuron.
        let mut spiked = vec![false; self.state.firing_rates.len()];
        for &i in spiking {
            if let Some(flag) = spiked.get_mut(i) {
                *flag = true;
            }
        }
        for (rate, &fired) in self.state.firing_rates.iter_mut().zip(&spiked) {
            let spike = if fired { 1.0 } else { 0.0 };
            *rate = *rate * (1.0 - alpha) + spike * alpha;
        }
    }
}
impl PlasticityEngine for EmbeddedSNN {
fn process(&mut self, activation: &[f32], neuromod: &NeuromodState) -> Result<Vec<Delta>> {
self.neuromod = neuromod.clone();
self.tick_count += 1;
let mut activations = self.compute_activations(activation);
self.apply_competition(&mut activations);
self.spread_activation(&mut activations);
let spiking = self.detect_spikes(&activations);
let mut deltas = self.apply_stdp(&spiking);
if self.tick_count % 100 == 0 {
self.apply_homeostasis();
}
if self.tick_count % 10 == 0 {
let decay_deltas = self.apply_decay();
deltas.extend(decay_deltas);
}
self.update_traces(&activations);
self.update_firing_rates(&spiking);
self.state.last_activations = activations;
Ok(deltas)
}
fn sync_neuromod(&mut self, neuromod: &NeuromodState) {
self.neuromod = neuromod.clone();
}
fn state(&self) -> PlasticityEngineState {
PlasticityEngineState {
engine_type: "EmbeddedSNN".to_string(),
neuromod: self.neuromod.clone(),
custom_state: bincode::serialize(&self.state).unwrap_or_default(),
}
}
fn restore(&mut self, state: &PlasticityEngineState) -> Result<()> {
if state.engine_type != "EmbeddedSNN" {
return Err(crate::error::Error::InvalidState(
"Wrong engine type".to_string(),
));
}
self.neuromod = state.neuromod.clone();
self.state = bincode::deserialize(&state.custom_state)
.map_err(|e| crate::error::Error::Deserialization(e.to_string()))?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction allocates one prototype and one connection list per neuron.
    #[test]
    fn test_create_snn() {
        let config = EmbeddedSNNConfig::default();
        let num_neurons = config.num_neurons;
        let input_dim = config.input_dim;
        let snn = EmbeddedSNN::new(config);
        assert_eq!(snn.state.prototypes.len(), num_neurons);
        assert_eq!(snn.state.weights.len(), num_neurons);
        assert!(snn.state.prototypes.iter().all(|p| p.len() == input_dim));
    }

    /// One tick advances the clock and refreshes per-neuron state.
    ///
    /// The previous check (`!deltas.is_empty() || tick_count < 10`) was always true after a
    /// single tick, so it asserted nothing.
    #[test]
    fn test_process_activation() {
        let config = EmbeddedSNNConfig::default();
        let num_neurons = config.num_neurons;
        let mut snn = EmbeddedSNN::new(config);
        let input = vec![0.5; 2048];
        let neuromod = NeuromodState::baseline();
        snn.process(&input, &neuromod).unwrap();
        assert_eq!(snn.tick_count, 1);
        assert_eq!(snn.state.last_activations.len(), num_neurons);
        assert!(snn
            .state
            .firing_rates
            .iter()
            .all(|r| (0.0..=1.0).contains(r)));
    }

    /// Neurons whose prototypes align with the input should spike together and emit STDP
    /// weight deltas.
    #[test]
    fn test_stdp_strengthening() {
        let config = EmbeddedSNNConfig {
            num_neurons: 10,
            spike_threshold: 0.1,
            stdp_lr: 0.1,
            ..Default::default()
        };
        let mut snn = EmbeddedSNN::new(config);
        for proto in &mut snn.state.prototypes[0..3] {
            for val in proto.iter_mut() {
                *val = 0.01;
            }
        }
        let input = vec![1.0; 2048];
        let mut neuromod = NeuromodState::baseline();
        neuromod.dopamine = 1.0;
        let deltas = snn.process(&input, &neuromod).unwrap();
        let weight_updates = deltas
            .iter()
            .filter(|d| d.key.starts_with("weight_"))
            .count();
        assert!(
            weight_updates > 0,
            "Expected STDP weight updates but got none. Total deltas: {}",
            deltas.len()
        );
    }

    /// A snapshot restored into a fresh engine with the same config reproduces the network.
    #[test]
    fn test_state_round_trip() {
        let config = EmbeddedSNNConfig {
            num_neurons: 8,
            input_dim: 16,
            top_k: 3,
            ..Default::default()
        };
        let source = EmbeddedSNN::new(config.clone());
        let mut target = EmbeddedSNN::new(config);
        let saved = source.state();
        target.restore(&saved).unwrap();
        assert_eq!(target.state.prototypes, source.state.prototypes);
        assert_eq!(target.state.weights, source.state.weights);
    }

    /// Snapshots tagged with another engine type are rejected.
    #[test]
    fn test_restore_rejects_wrong_engine_type() {
        let config = EmbeddedSNNConfig {
            num_neurons: 4,
            input_dim: 8,
            top_k: 2,
            ..Default::default()
        };
        let mut snn = EmbeddedSNN::new(config);
        let mut saved = snn.state();
        saved.engine_type = "SomethingElse".to_string();
        assert!(snn.restore(&saved).is_err());
    }
}