use super::{SpatialNeuron, SpatialSynapseStore};
/// Tunable parameters for the activity-driven pruning pass in this module.
///
/// Each field note below describes exactly how the pruning functions consume it.
#[derive(Debug, Clone, Copy)]
pub struct PruningConfig {
    /// Amount passed to `axon.decay` for a neuron that has no outgoing synapse
    /// at or above `activity_threshold` (including neurons with no outgoing
    /// synapses at all).
    pub inactivity_decay: u8,
    /// Multiplied by the axon's length to obtain an extra decay amount
    /// (converted with `as u8`), applied only to neurons with outgoing synapses.
    pub distance_decay_per_unit: f32,
    /// Consecutive below-threshold observations after which a synapse is
    /// reported as prunable.
    pub dormancy_threshold: u8,
    /// Signal magnitude at or above which a synapse counts as active.
    pub activity_threshold: u8,
    /// Rate handed to `axon.retract_toward` for axons that are no longer alive.
    pub retraction_rate: f32,
    /// Amount passed to `axon.boost` for a neuron with at least one active
    /// outgoing synapse.
    pub activity_boost: u8,
}

impl Default for PruningConfig {
    fn default() -> Self {
        Self {
            // Synapse classification.
            activity_threshold: 5,
            dormancy_threshold: 10,
            // Axon health adjustments per cycle.
            inactivity_decay: 1,
            activity_boost: 2,
            // Geometry-dependent rates.
            distance_decay_per_unit: 0.1,
            retraction_rate: 0.1,
        }
    }
}
/// Per-synapse counters of consecutive low-activity observations.
///
/// Counters are addressed by synapse index. Reading an out-of-range index
/// yields zero, and the mutating methods silently ignore out-of-range indices.
#[derive(Debug, Default, Clone)]
pub struct DormancyTracker {
    /// One saturating `u8` counter per synapse index.
    dormancy_counts: Vec<u8>,
}

impl DormancyTracker {
    /// Creates a tracker holding `synapse_count` zeroed counters.
    pub fn new(synapse_count: usize) -> Self {
        let dormancy_counts = vec![0_u8; synapse_count];
        Self { dormancy_counts }
    }

    /// Grows (with zeroed counters) or truncates the tracker to `synapse_count` entries.
    pub fn resize(&mut self, synapse_count: usize) {
        self.dormancy_counts.resize(synapse_count, 0);
    }

    /// Returns the counter at `idx`, or 0 when `idx` is out of range.
    pub fn get(&self, idx: usize) -> u8 {
        self.dormancy_counts.get(idx).map_or(0, |&count| count)
    }

    /// Adds one to the counter at `idx`, saturating at `u8::MAX`; no-op when out of range.
    pub fn increment(&mut self, idx: usize) {
        if let Some(slot) = self.dormancy_counts.get_mut(idx) {
            *slot = slot.saturating_add(1);
        }
    }

    /// Sets the counter at `idx` back to zero; no-op when out of range.
    pub fn reset(&mut self, idx: usize) {
        if let Some(slot) = self.dormancy_counts.get_mut(idx) {
            *slot = 0;
        }
    }

    /// Zeroes every counter while keeping the current length.
    pub fn clear(&mut self) {
        self.dormancy_counts.fill(0);
    }
}
/// Tally reported by one `pruning_cycle` call.
#[derive(Debug, Default, Clone)]
pub struct PruningResult {
    /// Synapses whose dormancy counter reached the threshold this cycle.
    /// They are only reported here, not removed from the store.
    pub synapses_pruned: usize,
    /// Non-sensory, non-motor neurons whose axon is not alive after health decay.
    pub axons_depleted: usize,
    /// Synapses whose signal magnitude is below the activity threshold.
    pub synapses_dormant: usize,
    /// Synapses whose signal magnitude is at or above the activity threshold.
    pub synapses_active: usize,
}
/// Advances per-synapse dormancy counters and returns the indices of synapses
/// whose counter has reached `config.dormancy_threshold`.
///
/// A synapse whose signal magnitude is below `config.activity_threshold` has
/// its counter incremented (saturating). Every other synapse has its counter
/// reset to zero. Indices are returned in ascending order. Nothing is removed
/// from `synapses`.
///
/// `dormancy` is first resized to match `synapses.len()`. This way synapses
/// added after the tracker was created are tracked as well.
pub fn identify_prunable_synapses(
    synapses: &SpatialSynapseStore,
    dormancy: &mut DormancyTracker,
    config: &PruningConfig,
) -> Vec<usize> {
    // Without this, synapses beyond the tracker's original length would be
    // silently skipped: `increment` ignores them and `get` always returns 0, so
    // they could never reach the dormancy threshold. New entries start at zero.
    dormancy.resize(synapses.len());

    let mut to_prune = Vec::new();
    for (idx, syn) in synapses.iter().enumerate() {
        if syn.signal.magnitude < config.activity_threshold {
            dormancy.increment(idx);
            if dormancy.get(idx) >= config.dormancy_threshold {
                to_prune.push(idx);
            }
        } else {
            dormancy.reset(idx);
        }
    }
    to_prune
}
/// Adjusts axon health for every non-sensory, non-motor neuron and returns how
/// many of those neurons have an axon that is no longer alive afterwards.
///
/// For the neuron at index `i`, the outgoing synapses come from
/// `synapses.outgoing(i as u32)`:
/// - If at least one has magnitude >= `activity_threshold`, the axon gets
///   `axon.boost(activity_boost)`.
/// - Otherwise, including when there are no outgoing synapses, it gets
///   `axon.decay(inactivity_decay)`.
/// - If it has any outgoing synapses, it also gets an extra
///   `axon.decay(axon_length * distance_decay_per_unit)`, converted with `as u8`.
///
/// Sensory and motor neurons are skipped entirely and never counted.
pub fn decay_axon_health(
    neurons: &mut [SpatialNeuron],
    synapses: &SpatialSynapseStore,
    config: &PruningConfig,
) -> usize {
    let mut dead_axons = 0;
    for (idx, neuron) in neurons.iter_mut().enumerate() {
        let exempt = neuron.nuclei.is_motor() || neuron.nuclei.is_sensory();
        if exempt {
            continue;
        }

        // NOTE(review): `as u32` truncates silently beyond u32::MAX neurons.
        let outgoing = synapses.outgoing(idx as u32);
        let any_active = outgoing
            .iter()
            .any(|s| s.signal.magnitude >= config.activity_threshold);

        if any_active {
            neuron.axon.boost(config.activity_boost);
        } else {
            // Covers both "no outgoing synapses" and "all outgoing synapses quiet".
            neuron.axon.decay(config.inactivity_decay);
        }

        // Length-proportional decay applies only to neurons with outgoing synapses.
        if !outgoing.is_empty() {
            let length = neuron.axon.length(neuron.soma.position);
            neuron
                .axon
                .decay((length * config.distance_decay_per_unit) as u8);
        }

        if !neuron.axon.is_alive() {
            dead_axons += 1;
        }
    }
    dead_axons
}
/// Pulls every axon that is no longer alive back toward its own soma.
///
/// Each dead axon gets `axon.retract_toward(soma_position, config.retraction_rate)`.
/// This module's tests expect a rate of 0.5 to move a terminal at x = 10
/// to x = 5. Living axons are left untouched.
pub fn retract_dead_axons(neurons: &mut [SpatialNeuron], config: &PruningConfig) {
    neurons
        .iter_mut()
        .filter(|neuron| !neuron.axon.is_alive())
        .for_each(|neuron| {
            let soma = neuron.soma.position;
            neuron.axon.retract_toward(soma, config.retraction_rate);
        });
}
pub fn pruning_cycle(
neurons: &mut [SpatialNeuron],
synapses: &SpatialSynapseStore,
dormancy: &mut DormancyTracker,
config: &PruningConfig,
) -> PruningResult {
let mut result = PruningResult::default();
for syn in synapses.iter() {
if syn.signal.magnitude >= config.activity_threshold {
result.synapses_active += 1;
} else {
result.synapses_dormant += 1;
}
}
let to_prune = identify_prunable_synapses(synapses, dormancy, config);
result.synapses_pruned = to_prune.len();
result.axons_depleted = decay_axon_health(neurons, synapses, config);
retract_dead_axons(neurons, config);
result
}
/// Physically removes synapses via `SpatialSynapseStore::prune_dormant` and
/// returns how many were removed.
///
/// `prune_dormant` decides which synapses go; `dormancy` is not consulted.
/// Afterwards the tracker is resized to the store's new length.
///
/// Removing entries shifts the indices of the survivors, and this function
/// cannot see which indices `prune_dormant` dropped. So whenever anything is
/// removed, every dormancy counter is reset rather than attributed to the
/// wrong synapse. Counters are kept only when the store is unchanged in length.
pub fn hard_prune(
    synapses: &mut SpatialSynapseStore,
    dormancy: &mut DormancyTracker,
    neuron_count: usize,
) -> usize {
    let before = synapses.len();
    synapses.prune_dormant(neuron_count);
    let after = synapses.len();
    let removed = before.saturating_sub(after);

    if removed > 0 {
        // Old index -> counter mapping is invalid after compaction; start fresh
        // instead of letting survivors inherit other synapses' dormancy counts.
        dormancy.clear();
    }
    dormancy.resize(after);

    removed
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::spatial::SpatialSynapse;
    use ternary_signal::Signal;

    /// Builds a synapse via `SpatialSynapse::with_signal(0, 1, ..)` carrying a
    /// positive signal of the given `magnitude`.
    fn make_test_synapse(magnitude: u8) -> SpatialSynapse {
        let signal = Signal::positive(magnitude);
        SpatialSynapse::with_signal(0, 1, signal, 100)
    }

    #[test]
    fn test_dormancy_tracker() {
        let mut counters = DormancyTracker::new(3);
        for _ in 0..2 {
            counters.increment(0);
        }
        assert_eq!(counters.get(0), 2);

        counters.reset(0);
        assert_eq!(counters.get(0), 0);
    }

    #[test]
    fn test_identify_prunable() {
        let mut store = SpatialSynapseStore::new(2);
        store.add(make_test_synapse(100));
        store.add(make_test_synapse(0));
        store.rebuild_index(2);

        let mut dormancy = DormancyTracker::new(2);
        let config = PruningConfig {
            dormancy_threshold: 3,
            activity_threshold: 5,
            ..Default::default()
        };

        // First observation: the silent synapse is dormant but not yet over threshold.
        let first_pass = identify_prunable_synapses(&store, &mut dormancy, &config);
        assert!(first_pass.is_empty());

        // Further observations push the silent synapse past the threshold.
        for _ in 0..3 {
            identify_prunable_synapses(&store, &mut dormancy, &config);
        }
        let later_pass = identify_prunable_synapses(&store, &mut dormancy, &config);
        assert_eq!(later_pass.len(), 1);
    }

    #[test]
    fn test_axon_decay_inactive() {
        let mut neurons = vec![SpatialNeuron::pyramidal_at([0.0, 0.0, 0.0])];
        neurons[0].axon.health = 100;

        let store = SpatialSynapseStore::new(1);
        let config = PruningConfig {
            inactivity_decay: 10,
            ..Default::default()
        };

        decay_axon_health(&mut neurons, &store, &config);
        assert_eq!(neurons[0].axon.health, 90);
    }

    #[test]
    fn test_axon_boost_active() {
        let mut neurons = vec![SpatialNeuron::pyramidal_at([0.0, 0.0, 0.0])];
        neurons[0].axon.health = 100;

        let mut store = SpatialSynapseStore::new(2);
        store.add(make_test_synapse(100));
        store.rebuild_index(2);

        // Distance decay disabled so only the activity boost is observed.
        let config = PruningConfig {
            activity_boost: 5,
            activity_threshold: 5,
            distance_decay_per_unit: 0.0,
            ..Default::default()
        };

        decay_axon_health(&mut neurons, &store, &config);
        assert!(neurons[0].axon.health > 100);
    }

    #[test]
    fn test_retract_dead_axon() {
        let mut neurons = vec![SpatialNeuron::pyramidal_at([0.0, 0.0, 0.0])];
        neurons[0].axon.terminal = [10.0, 0.0, 0.0];
        neurons[0].axon.health = 0;

        let config = PruningConfig {
            retraction_rate: 0.5,
            ..Default::default()
        };

        retract_dead_axons(&mut neurons, &config);
        let terminal_x = neurons[0].axon.terminal[0];
        assert!((terminal_x - 5.0).abs() < 0.01);
    }

    #[test]
    fn test_hard_prune() {
        let mut store = SpatialSynapseStore::new(2);
        store.add(make_test_synapse(100));
        store.add(SpatialSynapse::dormant(0, 1, 100));
        store.rebuild_index(2);

        let mut dormancy = DormancyTracker::new(2);
        let removed = hard_prune(&mut store, &mut dormancy, 2);

        assert_eq!(removed, 1);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_pruning_cycle() {
        let mut neurons = vec![
            SpatialNeuron::pyramidal_at([0.0, 0.0, 0.0]),
            SpatialNeuron::pyramidal_at([1.0, 0.0, 0.0]),
        ];

        let mut store = SpatialSynapseStore::new(2);
        store.add(make_test_synapse(100));
        store.add(make_test_synapse(0));
        store.rebuild_index(2);

        let mut dormancy = DormancyTracker::new(2);
        let config = PruningConfig::default();

        let report = pruning_cycle(&mut neurons, &store, &mut dormancy, &config);
        assert_eq!(report.synapses_active, 1);
        assert_eq!(report.synapses_dormant, 1);
    }

    #[test]
    fn test_distance_decay() {
        let mut neurons = vec![SpatialNeuron::pyramidal_at([0.0, 0.0, 0.0])];
        neurons[0].axon.terminal = [10.0, 0.0, 0.0];
        neurons[0].axon.health = 100;

        let mut store = SpatialSynapseStore::new(2);
        store.add(make_test_synapse(100));
        store.rebuild_index(2);

        // Boost disabled so only the length-proportional decay is observed.
        let config = PruningConfig {
            distance_decay_per_unit: 1.0,
            activity_threshold: 5,
            activity_boost: 0,
            ..Default::default()
        };

        decay_axon_health(&mut neurons, &store, &config);
        assert_eq!(neurons[0].axon.health, 90);
    }
}