use crate::error::{MastishkError, validate_dt};
use serde::{Deserialize, Serialize};
/// A rate-coded population whose firing `rate` relaxes exponentially toward
/// `resting_rate` plus whatever input it receives on each tick.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NeuralPopulation {
/// Human-readable label; included in the debug log when the population is
/// added to a `Circuit`.
pub name: String,
/// Current firing rate. The tick methods clamp it to `[0.0, 1.0]` after
/// every update; `new` initialises it to `resting_rate`.
pub rate: f32,
/// Baseline rate the population relaxes to when its input is zero.
pub resting_rate: f32,
/// Relaxation time constant, expressed in the same units as the `dt`
/// passed to the tick methods (the update uses the ratio `dt / tau`).
pub tau: f32,
/// Excitatory/inhibitory flag. NOTE(review): the update rules in this file
/// only store and log this value; the sign of a connection's effect comes
/// from the synapse weight — confirm whether other modules consult it.
pub excitatory: bool,
}
impl NeuralPopulation {
    /// Creates a population whose current rate starts at `resting_rate`.
    ///
    /// `tau` is the relaxation time constant, in the same units as the `dt`
    /// later passed to [`tick`](Self::tick). It is stored as given; a
    /// non-positive or NaN value is tolerated and treated as "relax
    /// instantly" by the update rule.
    #[must_use]
    pub fn new(name: impl Into<String>, resting_rate: f32, tau: f32, excitatory: bool) -> Self {
        Self {
            name: name.into(),
            rate: resting_rate,
            resting_rate,
            tau,
            excitatory,
        }
    }

    /// Advances the population by `dt`, relaxing `rate` toward
    /// `resting_rate + input` (clamped to `[0.0, 1.0]`).
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `validate_dt` when `dt` is rejected
    /// (for example a negative step); the population is left untouched.
    #[inline]
    pub fn tick(&mut self, input: f32, dt: f32) -> Result<(), MastishkError> {
        validate_dt(dt)?;
        self.tick_unchecked(input, dt);
        Ok(())
    }

    /// Same update as [`tick`](Self::tick) but without validating `dt`.
    ///
    /// Callers are expected to have validated `dt` once up front (the circuit
    /// does this before ticking every population).
    #[inline]
    pub(crate) fn tick_unchecked(&mut self, input: f32, dt: f32) {
        // The drive is clamped so the rate can never be pulled outside [0, 1].
        let target = (self.resting_rate + input).clamp(0.0, 1.0);

        // Exact solution of the first-order relaxation over one step:
        // alpha = 1 - exp(-dt / tau), which lies in [0, 1) for tau > 0.
        //
        // `tau` is a public, unvalidated field. A negative, negative-zero or
        // NaN value would make alpha negative, infinite or NaN and either push
        // the rate away from its target or poison it permanently (clamp
        // propagates NaN). Such values are handled as the tau -> 0+ limit:
        // jump straight to the target once any time has elapsed, and do
        // nothing for a zero-length step.
        let alpha = if self.tau > 0.0 {
            1.0 - (-dt / self.tau).exp()
        } else if dt > 0.0 {
            1.0
        } else {
            0.0
        };

        self.rate += (target - self.rate) * alpha;
        // Guards against floating-point overshoot past the valid range.
        self.rate = self.rate.clamp(0.0, 1.0);
    }

    /// Returns how far the current rate sits from the resting rate:
    /// positive above baseline, negative below it.
    #[inline]
    #[must_use]
    pub fn activation(&self) -> f32 {
        self.rate - self.resting_rate
    }
}
/// A directed, weighted connection between two populations of a `Circuit`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Synapse {
/// Index of the source (presynaptic) population in `Circuit::populations`.
pub from: usize,
/// Index of the target (postsynaptic) population in `Circuit::populations`.
pub to: usize,
/// Multiplier applied to the source population's rate to obtain the input
/// delivered to the target. Negative values lower the target's drive.
/// `Circuit::apply_hebbian` keeps updated weights within `[-1.0, 1.0]`.
pub weight: f32,
}
/// A set of neural populations together with the synapses that connect them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Circuit {
/// Populations in insertion order; synapses refer to them by index.
pub populations: Vec<NeuralPopulation>,
/// Connections between populations. Entries whose indices fall outside
/// `populations` are skipped by the tick and learning methods.
pub synapses: Vec<Synapse>,
}
impl Circuit {
    /// Creates a circuit with no populations and no synapses.
    #[must_use]
    pub fn new() -> Self {
        Self {
            populations: Vec::new(),
            synapses: Vec::new(),
        }
    }

    /// Appends `pop` to the circuit and returns the index to use when wiring
    /// synapses to or from it.
    pub fn add_population(&mut self, pop: NeuralPopulation) -> usize {
        let idx = self.populations.len();
        tracing::debug!(name = %pop.name, idx, excitatory = pop.excitatory, "population added");
        self.populations.push(pop);
        idx
    }

    /// Connects population `from` to population `to` with the given `weight`.
    ///
    /// # Errors
    ///
    /// Returns [`MastishkError::InvalidCircuit`] when either index does not
    /// refer to an existing population, or when `weight` is NaN or infinite.
    /// Nothing is added in either case.
    pub fn add_synapse(
        &mut self,
        from: usize,
        to: usize,
        weight: f32,
    ) -> Result<(), MastishkError> {
        let len = self.populations.len();
        if from >= len || to >= len {
            return Err(MastishkError::InvalidCircuit(format!(
                "synapse {from}->{to} out of bounds (population count: {len})"
            )));
        }
        // A NaN or infinite weight turns the target's input into NaN, and a
        // NaN rate never recovers because `clamp` propagates NaN.
        if !weight.is_finite() {
            return Err(MastishkError::InvalidCircuit(format!(
                "synapse {from}->{to} has non-finite weight {weight}"
            )));
        }
        tracing::debug!(from, to, weight, "synapse added");
        self.synapses.push(Synapse { from, to, weight });
        Ok(())
    }

    /// Advances every population by `dt`, feeding each one the weighted sum
    /// of its presynaptic rates as sampled before the step.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `validate_dt` when `dt` is rejected;
    /// the circuit is left untouched.
    #[inline]
    pub fn tick(&mut self, dt: f32) -> Result<(), MastishkError> {
        validate_dt(dt)?;
        tracing::trace!(
            dt,
            populations = self.populations.len(),
            synapses = self.synapses.len(),
            "ticking circuit"
        );
        // Multiplying by 1.0 is exact, so this matches an ungained update.
        self.step_with_gain_unchecked(1.0, dt);
        Ok(())
    }

    /// Like [`tick`](Self::tick), but scales every synaptic contribution by
    /// `gain` before it is applied. A `gain` of `1.0` is equivalent to `tick`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `validate_dt` when `dt` is rejected;
    /// the circuit is left untouched.
    #[inline]
    pub fn tick_with_gain(&mut self, gain: f32, dt: f32) -> Result<(), MastishkError> {
        validate_dt(dt)?;
        tracing::trace!(
            dt,
            gain,
            populations = self.populations.len(),
            synapses = self.synapses.len(),
            "ticking circuit with gain"
        );
        self.step_with_gain_unchecked(gain, dt);
        Ok(())
    }

    /// Shared update behind `tick` and `tick_with_gain`; `dt` must already
    /// have been validated by the caller.
    fn step_with_gain_unchecked(&mut self, gain: f32, dt: f32) {
        // Accumulate all inputs from the pre-step rates first so the result
        // does not depend on the order in which populations are updated.
        let mut inputs = vec![0.0_f32; self.populations.len()];
        for syn in &self.synapses {
            // `populations` and `synapses` are public (and deserializable), so
            // a synapse may reference a missing population; those are skipped.
            // `inputs` has the same length as `populations`, so `get_mut`
            // doubles as the bounds check on `syn.to`.
            if let (Some(pre), Some(slot)) =
                (self.populations.get(syn.from), inputs.get_mut(syn.to))
            {
                *slot += pre.rate * syn.weight * gain;
            }
        }
        for (pop, input) in self.populations.iter_mut().zip(inputs) {
            pop.tick_unchecked(input, dt);
        }
    }

    /// Applies one Hebbian update to every synapse:
    /// `weight += learning_rate * pre_rate * post_rate`, with the result
    /// clamped to `[-1.0, 1.0]`.
    ///
    /// Synapses whose endpoints do not refer to existing populations are left
    /// unchanged.
    #[inline]
    pub fn apply_hebbian(&mut self, learning_rate: f32) {
        let populations = &self.populations;
        for syn in &mut self.synapses {
            if let (Some(pre), Some(post)) = (populations.get(syn.from), populations.get(syn.to)) {
                syn.weight += learning_rate * pre.rate * post.rate;
                syn.weight = syn.weight.clamp(-1.0, 1.0);
            }
        }
        tracing::trace!(learning_rate, "Hebbian learning applied");
    }
}
impl Default for Circuit {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Positive input must pull the rate above its resting value.
    #[test]
    fn test_population_tick() {
        let mut pop = NeuralPopulation::new("test", 0.2, 0.5, true);
        pop.tick(0.5, 1.0).unwrap();
        assert!(pop.rate > 0.2);
    }

    /// A positive weight raises the target population's rate.
    #[test]
    fn test_circuit_excitation() {
        let mut c = Circuit::new();
        let a = c.add_population(NeuralPopulation::new("A", 0.5, 0.1, true));
        let b = c.add_population(NeuralPopulation::new("B", 0.1, 0.1, true));
        c.add_synapse(a, b, 0.5).unwrap();
        c.tick(0.5).unwrap();
        assert!(c.populations[b].rate > 0.1);
    }

    /// A negative weight lowers the target population's rate.
    #[test]
    fn test_circuit_inhibition() {
        let mut c = Circuit::new();
        let a = c.add_population(NeuralPopulation::new("A", 0.8, 0.1, true));
        let b = c.add_population(NeuralPopulation::new("B", 0.5, 0.1, false));
        c.add_synapse(a, b, -0.5).unwrap();
        c.tick(0.5).unwrap();
        assert!(c.populations[b].rate < 0.5);
    }

    #[test]
    fn test_serde_roundtrip() {
        let c = Circuit::new();
        let json = serde_json::to_string(&c).unwrap();
        let c2: Circuit = serde_json::from_str(&json).unwrap();
        assert_eq!(c2.populations.len(), 0);
    }

    #[test]
    fn test_negative_dt_rejected() {
        let mut pop = NeuralPopulation::new("test", 0.2, 0.5, true);
        assert!(pop.tick(0.0, -1.0).is_err());

        let mut c = Circuit::new();
        c.add_population(NeuralPopulation::new("A", 0.5, 0.1, true));
        assert!(c.tick(-0.5).is_err());
    }

    #[test]
    fn test_activation() {
        let mut pop = NeuralPopulation::new("test", 0.2, 0.5, true);
        assert!(pop.activation().abs() < f32::EPSILON);
        pop.tick(0.5, 1.0).unwrap();
        assert!(pop.activation() > 0.0);
    }

    /// Ticking a circuit with no populations must succeed as a no-op.
    #[test]
    fn test_empty_circuit_tick() {
        let mut c = Circuit::new();
        c.tick(1.0).unwrap();
    }

    #[test]
    fn test_out_of_bounds_synapse_rejected() {
        let mut c = Circuit::new();
        c.add_population(NeuralPopulation::new("A", 0.5, 0.1, true));
        assert!(c.add_synapse(0, 99, 0.5).is_err());
        assert!(c.add_synapse(99, 0, 0.5).is_err());
    }

    #[test]
    fn test_tick_with_gain_amplifies() {
        let mut c1 = Circuit::new();
        let a1 = c1.add_population(NeuralPopulation::new("A", 0.5, 0.1, true));
        let b1 = c1.add_population(NeuralPopulation::new("B", 0.1, 0.1, true));
        c1.add_synapse(a1, b1, 0.5).unwrap();
        let mut c2 = c1.clone();
        c1.tick(0.5).unwrap();
        c2.tick_with_gain(2.0, 0.5).unwrap();
        assert!(c2.populations[b1].rate > c1.populations[b1].rate);
    }

    #[test]
    fn test_tick_with_gain_one_equals_tick() {
        let mut c1 = Circuit::new();
        let a = c1.add_population(NeuralPopulation::new("A", 0.5, 0.1, true));
        let b = c1.add_population(NeuralPopulation::new("B", 0.1, 0.1, true));
        c1.add_synapse(a, b, 0.5).unwrap();
        let mut c2 = c1.clone();
        c1.tick(0.5).unwrap();
        c2.tick_with_gain(1.0, 0.5).unwrap();
        assert!((c1.populations[b].rate - c2.populations[b].rate).abs() < f32::EPSILON);
    }

    #[test]
    fn test_tick_with_gain_negative_dt_rejected() {
        let mut c = Circuit::new();
        assert!(c.tick_with_gain(1.0, -1.0).is_err());
    }

    /// Co-active populations strengthen the synapse between them.
    #[test]
    fn test_hebbian_strengthens_coactive_synapse() {
        let mut c = Circuit::new();
        let a = c.add_population(NeuralPopulation::new("A", 0.5, 0.1, true));
        let b = c.add_population(NeuralPopulation::new("B", 0.5, 0.1, true));
        c.add_synapse(a, b, 0.5).unwrap();
        c.apply_hebbian(0.1);
        // 0.5 + 0.1 * 0.5 * 0.5 = 0.525
        assert!((c.synapses[0].weight - 0.525).abs() < 1e-6);
    }

    /// Weights saturate at the [-1, 1] bounds regardless of learning rate.
    #[test]
    fn test_hebbian_clamps_weight() {
        let mut c = Circuit::new();
        let a = c.add_population(NeuralPopulation::new("A", 0.5, 0.1, true));
        let b = c.add_population(NeuralPopulation::new("B", 0.5, 0.1, true));
        c.add_synapse(a, b, 0.5).unwrap();

        c.apply_hebbian(100.0);
        assert!((c.synapses[0].weight - 1.0).abs() < f32::EPSILON);

        c.apply_hebbian(-100.0);
        assert!((c.synapses[0].weight + 1.0).abs() < f32::EPSILON);
    }

    /// Synapses pointing at missing populations are ignored, not a panic.
    #[test]
    fn test_hebbian_skips_dangling_synapse() {
        let mut c = Circuit::new();
        c.add_population(NeuralPopulation::new("A", 0.5, 0.1, true));
        // Bypass `add_synapse` validation through the public field.
        c.synapses.push(Synapse {
            from: 0,
            to: 7,
            weight: 0.25,
        });
        c.apply_hebbian(0.5);
        assert!((c.synapses[0].weight - 0.25).abs() < f32::EPSILON);
        c.tick(0.1).unwrap();
    }
}