use crate::{NervousSystemError, Result};
use rand::Rng;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BTSPSynapse {
weight: f32,
eligibility_trace: f32,
tau_btsp: f32,
min_weight: f32,
max_weight: f32,
ltp_rate: f32,
ltd_rate: f32,
}
impl BTSPSynapse {
pub fn new(initial_weight: f32, tau_btsp: f32) -> Result<Self> {
if !(0.0..=1.0).contains(&initial_weight) {
return Err(NervousSystemError::InvalidWeight(initial_weight));
}
if tau_btsp <= 0.0 {
return Err(NervousSystemError::InvalidTimeConstant(tau_btsp));
}
Ok(Self {
weight: initial_weight,
eligibility_trace: 0.0,
tau_btsp,
min_weight: 0.0,
max_weight: 1.0,
ltp_rate: 0.1, ltd_rate: 0.05, })
}
pub fn with_rates(
initial_weight: f32,
tau_btsp: f32,
ltp_rate: f32,
ltd_rate: f32,
) -> Result<Self> {
let mut synapse = Self::new(initial_weight, tau_btsp)?;
synapse.ltp_rate = ltp_rate;
synapse.ltd_rate = ltd_rate;
Ok(synapse)
}
#[inline]
pub fn update(&mut self, presynaptic_active: bool, plateau_signal: bool, dt: f32) {
self.eligibility_trace *= (-dt / self.tau_btsp).exp();
if presynaptic_active {
self.eligibility_trace += 1.0;
}
if plateau_signal && self.eligibility_trace > 0.01 {
let delta = if self.weight < 0.5 {
self.ltp_rate } else {
-self.ltd_rate };
self.weight += delta * self.eligibility_trace;
self.weight = self.weight.clamp(self.min_weight, self.max_weight);
}
}
#[inline]
pub fn weight(&self) -> f32 {
self.weight
}
#[inline]
pub fn eligibility_trace(&self) -> f32 {
self.eligibility_trace
}
#[inline]
pub fn forward(&self, input: f32) -> f32 {
self.weight * input
}
}
#[derive(Debug, Clone)]
pub struct PlateauDetector {
threshold: f32,
window: f32,
}
impl PlateauDetector {
pub fn new(threshold: f32, window: f32) -> Self {
Self { threshold, window }
}
#[inline]
pub fn detect(&self, postsynaptic_activity: f32) -> bool {
postsynaptic_activity > self.threshold
}
#[inline]
pub fn detect_error(&self, predicted: f32, actual: f32) -> bool {
(predicted - actual).abs() > self.threshold
}
}
#[derive(Debug, Clone)]
pub struct BTSPLayer {
synapses: Vec<BTSPSynapse>,
plateau_detector: PlateauDetector,
activity: f32,
}
impl BTSPLayer {
pub fn new(size: usize, tau: f32) -> Self {
let mut rng = rand::thread_rng();
let synapses = (0..size)
.map(|_| {
let weight = rng.gen_range(0.0..0.1); BTSPSynapse::new(weight, tau).unwrap()
})
.collect();
Self {
synapses,
plateau_detector: PlateauDetector::new(0.7, 100.0),
activity: 0.0,
}
}
#[inline]
pub fn forward(&self, input: &[f32]) -> f32 {
debug_assert_eq!(input.len(), self.synapses.len());
self.synapses
.iter()
.zip(input.iter())
.map(|(synapse, &x)| synapse.forward(x))
.sum()
}
pub fn learn(&mut self, input: &[bool], plateau: bool, dt: f32) {
debug_assert_eq!(input.len(), self.synapses.len());
for (synapse, &active) in self.synapses.iter_mut().zip(input.iter()) {
synapse.update(active, plateau, dt);
}
}
pub fn one_shot_associate(&mut self, pattern: &[f32], target: f32) {
debug_assert_eq!(pattern.len(), self.synapses.len());
let current = self.forward(pattern);
let error = target - current;
let sum_squared: f32 = pattern.iter().map(|&x| x * x).sum();
if sum_squared < 1e-8 {
return; }
for (synapse, &input_val) in self.synapses.iter_mut().zip(pattern.iter()) {
if input_val.abs() > 0.01 {
synapse.eligibility_trace = input_val;
let delta = error * input_val / sum_squared;
synapse.weight += delta;
synapse.weight = synapse.weight.clamp(0.0, 1.0);
}
}
self.activity = target;
}
pub fn size(&self) -> usize {
self.synapses.len()
}
pub fn weights(&self) -> Vec<f32> {
self.synapses.iter().map(|s| s.weight()).collect()
}
}
#[derive(Debug, Clone)]
pub struct BTSPAssociativeMemory {
layers: Vec<BTSPLayer>,
input_size: usize,
output_size: usize,
}
impl BTSPAssociativeMemory {
pub fn new(input_size: usize, output_size: usize) -> Self {
let tau = 2000.0; let layers = (0..output_size)
.map(|_| BTSPLayer::new(input_size, tau))
.collect();
Self {
layers,
input_size,
output_size,
}
}
pub fn store_one_shot(&mut self, key: &[f32], value: &[f32]) -> Result<()> {
if key.len() != self.input_size {
return Err(NervousSystemError::DimensionMismatch {
expected: self.input_size,
actual: key.len(),
});
}
if value.len() != self.output_size {
return Err(NervousSystemError::DimensionMismatch {
expected: self.output_size,
actual: value.len(),
});
}
for (layer, &target) in self.layers.iter_mut().zip(value.iter()) {
layer.one_shot_associate(key, target);
}
Ok(())
}
pub fn retrieve(&self, query: &[f32]) -> Result<Vec<f32>> {
if query.len() != self.input_size {
return Err(NervousSystemError::DimensionMismatch {
expected: self.input_size,
actual: query.len(),
});
}
Ok(self
.layers
.iter()
.map(|layer| layer.forward(query))
.collect())
}
pub fn store_batch(&mut self, pairs: &[(&[f32], &[f32])]) -> Result<()> {
for (key, value) in pairs {
self.store_one_shot(key, value)?;
}
Ok(())
}
pub fn dimensions(&self) -> (usize, usize) {
(self.input_size, self.output_size)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_synapse_creation() {
let synapse = BTSPSynapse::new(0.5, 2000.0).unwrap();
assert_eq!(synapse.weight(), 0.5);
assert_eq!(synapse.eligibility_trace(), 0.0);
assert!(BTSPSynapse::new(-0.1, 2000.0).is_err());
assert!(BTSPSynapse::new(1.1, 2000.0).is_err());
assert!(BTSPSynapse::new(0.5, -100.0).is_err());
}
#[test]
fn test_eligibility_trace_decay() {
let mut synapse = BTSPSynapse::new(0.5, 1000.0).unwrap();
synapse.update(true, false, 10.0);
let trace1 = synapse.eligibility_trace();
assert!(trace1 > 0.9);
for _ in 0..100 {
synapse.update(false, false, 10.0);
}
let trace2 = synapse.eligibility_trace();
assert!(trace2 < 0.4 && trace2 > 0.3);
}
#[test]
fn test_bidirectional_plasticity() {
let mut weak = BTSPSynapse::new(0.2, 2000.0).unwrap();
weak.eligibility_trace = 1.0; weak.update(false, true, 10.0); assert!(weak.weight() > 0.2);
let mut strong = BTSPSynapse::new(0.8, 2000.0).unwrap();
strong.eligibility_trace = 1.0;
strong.update(false, true, 10.0); assert!(strong.weight() < 0.8); }
#[test]
fn test_layer_forward() {
let layer = BTSPLayer::new(10, 2000.0);
let input = vec![0.5; 10];
let output = layer.forward(&input);
assert!(output >= 0.0); }
#[test]
fn test_one_shot_learning() {
let mut layer = BTSPLayer::new(100, 2000.0);
let pattern = vec![0.1; 100];
let target = 0.8;
layer.one_shot_associate(&pattern, target);
let output = layer.forward(&pattern);
let error = (output - target).abs();
assert!(
error < 0.6,
"One-shot learning failed: error = {}, output = {}",
error,
output
);
}
#[test]
fn test_one_shot_multiple_patterns() {
let mut layer = BTSPLayer::new(50, 2000.0);
let pattern1 = vec![1.0; 50];
let pattern2 = vec![0.5; 50];
layer.one_shot_associate(&pattern1, 1.0);
layer.one_shot_associate(&pattern2, 0.5);
let out1 = layer.forward(&pattern1);
let out2 = layer.forward(&pattern2);
assert!((out1 - 1.0).abs() < 0.5, "out1: {}", out1);
assert!((out2 - 0.5).abs() < 0.5, "out2: {}", out2);
}
#[test]
fn test_associative_memory() {
let mut memory = BTSPAssociativeMemory::new(10, 5);
let key = vec![0.5; 10];
let value = vec![0.1, 0.2, 0.3, 0.4, 0.5];
memory.store_one_shot(&key, &value).unwrap();
let retrieved = memory.retrieve(&key).unwrap();
for (expected, actual) in value.iter().zip(retrieved.iter()) {
assert!(
(expected - actual).abs() < 0.35,
"expected: {}, actual: {}",
expected,
actual
);
}
}
#[test]
fn test_associative_memory_batch() {
let mut memory = BTSPAssociativeMemory::new(8, 4);
let key1 = vec![1.0; 8];
let val1 = vec![0.1; 4];
let key2 = vec![0.5; 8];
let val2 = vec![0.9; 4];
memory
.store_batch(&[(&key1, &val1), (&key2, &val2)])
.unwrap();
let ret1 = memory.retrieve(&key1).unwrap();
let ret2 = memory.retrieve(&key2).unwrap();
assert_eq!(
ret1.len(),
4,
"Retrieved vector should have correct dimension"
);
assert_eq!(
ret2.len(),
4,
"Retrieved vector should have correct dimension"
);
for &v in &ret1 {
assert!(v.is_finite(), "value should be finite: {}", v);
}
for &v in &ret2 {
assert!(v.is_finite(), "value should be finite: {}", v);
}
}
#[test]
fn test_dimension_mismatch() {
let mut memory = BTSPAssociativeMemory::new(5, 3);
let wrong_key = vec![0.5; 10]; let value = vec![0.1; 3];
assert!(memory.store_one_shot(&wrong_key, &value).is_err());
let key = vec![0.5; 5];
let wrong_value = vec![0.1; 10];
assert!(memory.store_one_shot(&key, &wrong_value).is_err());
}
#[test]
fn test_plateau_detector() {
let detector = PlateauDetector::new(0.7, 100.0);
assert!(detector.detect(0.8));
assert!(!detector.detect(0.5));
assert!(detector.detect_error(0.0, 1.0));
assert!(!detector.detect_error(0.5, 0.6));
}
#[test]
fn test_retention_over_time() {
let mut layer = BTSPLayer::new(50, 2000.0);
let pattern = vec![0.7; 50];
layer.one_shot_associate(&pattern, 0.9);
let immediate = layer.forward(&pattern);
let input_inactive = vec![false; 50];
for _ in 0..100 {
layer.learn(&input_inactive, false, 10.0);
}
let after_delay = layer.forward(&pattern);
assert!((immediate - after_delay).abs() < 0.1);
}
#[test]
fn test_synapse_performance() {
let mut synapse = BTSPSynapse::new(0.5, 2000.0).unwrap();
for _ in 0..1000 {
synapse.update(true, false, 1.0);
}
let start = std::time::Instant::now();
for _ in 0..1_000_000 {
synapse.update(true, false, 1.0);
}
let elapsed = start.elapsed();
assert!(elapsed.as_millis() < 100);
}
}