Skip to main content

neuropool/
synapse.rs

//! Synapse data structures — 8-byte compact synapses with CSR storage.
//!
//! Each synapse is 8 bytes: target (2) + weight (1) + delay (1) +
//! eligibility (1) + maturity (1) + reserved (2). Synapses are stored in CSR
//! (Compressed Sparse Row) format for cache-friendly iteration during spike
//! propagation.
6
7use crate::neuron::flags;
8
/// Thermal state of a synapse, encoded in bits 0-1 of the maturity byte.
///
/// Mirrors the thermogram thermal lifecycle but at the individual synapse level.
/// Each discriminant equals the two-bit pattern stored in the maturity byte, so
/// `state as u8` yields the encoded bits directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum ThermalState {
    /// Newly formed, high plasticity, vulnerable to pruning
    Hot = 0b00,
    /// Established, moderate plasticity
    Warm = 0b01,
    /// Proven, low plasticity, protected from weakening
    Cool = 0b10,
    /// Frozen, genome-level, immutable
    Cold = 0b11,
}

impl ThermalState {
    /// Every state, indexed by its two-bit encoding.
    const BY_BITS: [Self; 4] = [Self::Hot, Self::Warm, Self::Cool, Self::Cold];

    /// Decode the state stored in bits 0-1 of `maturity`.
    ///
    /// Bits 2-7 (the reinforcement counter) are masked off, so every possible
    /// byte maps to exactly one state and this never panics.
    #[inline]
    pub fn from_maturity(maturity: u8) -> Self {
        Self::BY_BITS[usize::from(maturity & 0b11)]
    }
}
37
38/// Maturity byte encoding utilities.
39///
40/// Layout: `[counter:6][state:2]`
41/// - Bits 0-1: ThermalState
42/// - Bits 2-7: Reinforcement counter (0-63)
43pub mod maturity {
44    use super::ThermalState;
45
46    /// Promotion thresholds (counter value at which state advances)
47    pub const HOT_TO_WARM: u8 = 8;
48    pub const WARM_TO_COOL: u8 = 24;
49    pub const COOL_TO_COLD: u8 = 63;
50
51    /// Encode maturity byte from state and counter
52    #[inline]
53    pub fn encode(state: ThermalState, counter: u8) -> u8 {
54        (counter.min(63) << 2) | (state as u8)
55    }
56
57    /// Decode state from maturity byte
58    #[inline]
59    pub fn state(m: u8) -> ThermalState {
60        ThermalState::from_maturity(m)
61    }
62
63    /// Decode counter from maturity byte
64    #[inline]
65    pub fn counter(m: u8) -> u8 {
66        m >> 2
67    }
68
69    /// Increment counter, promoting state if threshold reached.
70    /// Returns new maturity byte.
71    #[inline]
72    pub fn increment(m: u8) -> u8 {
73        let s = state(m);
74        let c = counter(m);
75
76        if s == ThermalState::Cold {
77            return m; // Frozen, no change
78        }
79
80        let new_c = c.saturating_add(1).min(63);
81        let new_state = match s {
82            ThermalState::Hot if new_c >= HOT_TO_WARM => ThermalState::Warm,
83            ThermalState::Warm if new_c >= WARM_TO_COOL => ThermalState::Cool,
84            ThermalState::Cool if new_c >= COOL_TO_COLD => ThermalState::Cold,
85            other => other,
86        };
87
88        // Reset counter on promotion
89        if new_state != s {
90            encode(new_state, 0)
91        } else {
92            encode(s, new_c)
93        }
94    }
95
96    /// Decrement counter. If counter reaches 0 in HOT state, synapse is dead.
97    /// Returns new maturity byte. Dead synapses have maturity == 0x00.
98    #[inline]
99    pub fn decrement(m: u8) -> u8 {
100        let s = state(m);
101        let c = counter(m);
102
103        if s == ThermalState::Cold {
104            return m; // Frozen, no change
105        }
106
107        if c == 0 {
108            // Demote state
109            let new_state = match s {
110                ThermalState::Cool => ThermalState::Warm,
111                ThermalState::Warm => ThermalState::Hot,
112                ThermalState::Hot => return 0x00, // DEAD — maturity 0 means prunable
113                ThermalState::Cold => return m,
114            };
115            // Start counter at half the promotion threshold for the new (lower) state
116            let new_c = match new_state {
117                ThermalState::Warm => WARM_TO_COOL / 2,
118                ThermalState::Hot => HOT_TO_WARM / 2,
119                _ => 0,
120            };
121            encode(new_state, new_c)
122        } else {
123            encode(s, c - 1)
124        }
125    }
126
127    /// Check if a synapse is dead (HOT state with counter 0)
128    #[inline]
129    pub fn is_dead(m: u8) -> bool {
130        m == 0x00
131    }
132}
133
134/// A single synapse — 8 bytes, repr(C) for binary persistence.
135///
136/// Weight sign is constrained by Dale's Law: excitatory source neurons produce
137/// positive weights only, inhibitory neurons produce negative weights only.
138/// This is enforced at creation time, not checked in the hot path.
139#[derive(Clone, Copy, Debug)]
140#[repr(C)]
141pub struct Synapse {
142    /// Post-synaptic neuron index (max 65535)
143    pub target: u16,
144    /// Signed weight (-127..+127). Sign constrained by Dale's Law.
145    pub weight: i8,
146    /// Axonal delay in ticks (1-8). 0 is invalid.
147    pub delay: u8,
148    /// Eligibility trace: STDP tag, decays toward 0 each tick.
149    /// Positive = causal (pre before post), negative = anti-causal.
150    pub eligibility: i8,
151    /// Thermal lifecycle: bits 0-1 = state, bits 2-7 = reinforcement counter
152    pub maturity: u8,
153    /// Reserved for future use (neuromodulator sensitivity, branch tag)
154    pub _reserved: [u8; 2],
155}
156
157impl Synapse {
158    /// Create a new synapse respecting Dale's Law.
159    ///
160    /// `source_flags` is the flags byte of the pre-synaptic neuron.
161    /// Weight magnitude is provided; sign is determined by neuron type.
162    pub fn new(target: u16, weight_magnitude: u8, delay: u8, source_flags: u8) -> Self {
163        let signed_weight = if flags::is_inhibitory(source_flags) {
164            -(weight_magnitude.min(127) as i8)
165        } else {
166            weight_magnitude.min(127) as i8
167        };
168
169        Self {
170            target,
171            weight: signed_weight,
172            delay: delay.max(1).min(8),
173            eligibility: 0,
174            maturity: maturity::encode(ThermalState::Hot, 4), // Start at HOT with some initial counter
175            _reserved: [0; 2],
176        }
177    }
178
179    /// Create a frozen (COLD) synapse for genome-level connectivity.
180    pub fn frozen(target: u16, weight: i8, delay: u8) -> Self {
181        Self {
182            target,
183            weight,
184            delay: delay.max(1).min(8),
185            eligibility: 0,
186            maturity: maturity::encode(ThermalState::Cold, 63),
187            _reserved: [0; 2],
188        }
189    }
190
191    /// Current thermal state
192    #[inline]
193    pub fn thermal_state(&self) -> ThermalState {
194        maturity::state(self.maturity)
195    }
196
197    /// Whether this synapse is dead and should be pruned
198    #[inline]
199    pub fn is_dead(&self) -> bool {
200        maturity::is_dead(self.maturity)
201    }
202
203    /// Increment maturity (toward promotion)
204    #[inline]
205    pub fn increment_maturity(&mut self) {
206        self.maturity = maturity::increment(self.maturity);
207    }
208
209    /// Decrement maturity (toward demotion/death)
210    #[inline]
211    pub fn decrement_maturity(&mut self) {
212        self.maturity = maturity::decrement(self.maturity);
213    }
214}
215
216/// Compressed Sparse Row synapse storage.
217///
218/// All outgoing synapses for neuron `i` are at indices `row_ptr[i]..row_ptr[i+1]`
219/// in the `synapses` array. This gives cache-friendly iteration during spike
220/// propagation — all targets of a spiking neuron are contiguous in memory.
221pub struct SynapseStore {
222    /// Index into `synapses` for each neuron. Length = n_neurons + 1.
223    /// `row_ptr[i]` = start index, `row_ptr[i+1]` = end index (exclusive).
224    pub row_ptr: Vec<u32>,
225    /// All synapses, grouped contiguously by source neuron.
226    pub synapses: Vec<Synapse>,
227}
228
229impl SynapseStore {
230    /// Create empty store for `n` neurons (no connections).
231    pub fn empty(n_neurons: u32) -> Self {
232        Self {
233            row_ptr: vec![0; (n_neurons + 1) as usize],
234            synapses: Vec::new(),
235        }
236    }
237
238    /// Build CSR from a list of (source_neuron, Synapse) pairs.
239    ///
240    /// The pairs do NOT need to be sorted — this function sorts them internally.
241    pub fn from_edges(n_neurons: u32, mut edges: Vec<(u32, Synapse)>) -> Self {
242        edges.sort_unstable_by_key(|(src, _)| *src);
243
244        let n = n_neurons as usize;
245        let mut row_ptr = vec![0u32; n + 1];
246
247        // Count synapses per neuron
248        for (src, _) in &edges {
249            let idx = (*src as usize).min(n - 1);
250            row_ptr[idx + 1] += 1;
251        }
252
253        // Prefix sum
254        for i in 1..=n {
255            row_ptr[i] += row_ptr[i - 1];
256        }
257
258        let synapses: Vec<Synapse> = edges.into_iter().map(|(_, syn)| syn).collect();
259
260        Self { row_ptr, synapses }
261    }
262
263    /// Get outgoing synapses for a given source neuron.
264    #[inline]
265    pub fn outgoing(&self, neuron: u32) -> &[Synapse] {
266        let start = self.row_ptr[neuron as usize] as usize;
267        let end = self.row_ptr[neuron as usize + 1] as usize;
268        &self.synapses[start..end]
269    }
270
271    /// Get mutable outgoing synapses for a given source neuron.
272    #[inline]
273    pub fn outgoing_mut(&mut self, neuron: u32) -> &mut [Synapse] {
274        let start = self.row_ptr[neuron as usize] as usize;
275        let end = self.row_ptr[neuron as usize + 1] as usize;
276        &mut self.synapses[start..end]
277    }
278
279    /// Total number of synapses across all neurons.
280    #[inline]
281    pub fn total_synapses(&self) -> usize {
282        self.synapses.len()
283    }
284
285    /// Number of neurons this store covers.
286    #[inline]
287    pub fn n_neurons(&self) -> u32 {
288        (self.row_ptr.len().saturating_sub(1)) as u32
289    }
290
291    /// Remove dead synapses (maturity == 0x00) and rebuild CSR.
292    /// Returns count of pruned synapses.
293    pub fn prune_dead(&mut self) -> usize {
294        let n = self.n_neurons() as usize;
295        let mut new_synapses = Vec::with_capacity(self.synapses.len());
296        let mut new_row_ptr = vec![0u32; n + 1];
297        let mut pruned = 0usize;
298
299        for i in 0..n {
300            let start = self.row_ptr[i] as usize;
301            let end = self.row_ptr[i + 1] as usize;
302
303            for syn in &self.synapses[start..end] {
304                if syn.is_dead() {
305                    pruned += 1;
306                } else {
307                    new_synapses.push(*syn);
308                }
309            }
310            new_row_ptr[i + 1] = new_synapses.len() as u32;
311        }
312
313        self.synapses = new_synapses;
314        self.row_ptr = new_row_ptr;
315        pruned
316    }
317
318    /// Add a synapse from `source` to the given target. Rebuilds the CSR row.
319    /// This is expensive — batch additions and rebuild when possible.
320    pub fn add_synapse(&mut self, source: u32, syn: Synapse) {
321        let idx = source as usize;
322        let insert_pos = self.row_ptr[idx + 1] as usize;
323
324        self.synapses.insert(insert_pos, syn);
325
326        // Update row_ptr for all neurons after source
327        for ptr in &mut self.row_ptr[(idx + 1)..] {
328            *ptr += 1;
329        }
330    }
331
332    /// Extend the store to accommodate additional neurons (with no synapses).
333    ///
334    /// Used when dynamically spawning neurons from templates.
335    pub fn extend(&mut self, count: usize) {
336        let last_ptr = *self.row_ptr.last().unwrap_or(&0);
337        for _ in 0..count {
338            self.row_ptr.push(last_ptr);
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Flags byte for an excitatory (regular-spiking) source neuron.
    fn exc_flags() -> u8 {
        crate::neuron::flags::encode(false, crate::neuron::NeuronProfile::RegularSpiking)
    }

    /// Flags byte for an inhibitory (fast-spiking) source neuron.
    fn inh_flags() -> u8 {
        crate::neuron::flags::encode(true, crate::neuron::NeuronProfile::FastSpiking)
    }

    #[test]
    fn synapse_size() {
        assert_eq!(std::mem::size_of::<Synapse>(), 8);
    }

    #[test]
    fn maturity_lifecycle() {
        // Start at HOT with counter 0
        let mut m = maturity::encode(ThermalState::Hot, 0);
        assert_eq!(maturity::state(m), ThermalState::Hot);
        assert_eq!(maturity::counter(m), 0);

        // Increment to HOT->WARM promotion
        for _ in 0..maturity::HOT_TO_WARM {
            m = maturity::increment(m);
        }
        assert_eq!(maturity::state(m), ThermalState::Warm);
        assert_eq!(maturity::counter(m), 0); // Reset on promotion

        // Increment to WARM->COOL promotion
        for _ in 0..maturity::WARM_TO_COOL {
            m = maturity::increment(m);
        }
        assert_eq!(maturity::state(m), ThermalState::Cool);

        // Increment to COOL->COLD promotion
        for _ in 0..maturity::COOL_TO_COLD {
            m = maturity::increment(m);
        }
        assert_eq!(maturity::state(m), ThermalState::Cold);

        // Cold is frozen — increment has no effect
        let m2 = maturity::increment(m);
        assert_eq!(m2, m);
    }

    #[test]
    fn maturity_death() {
        let m = maturity::encode(ThermalState::Hot, 0);
        let dead = maturity::decrement(m);
        assert!(maturity::is_dead(dead));
    }

    #[test]
    fn maturity_hot_counter_one_dies() {
        // Counting a HOT synapse down to zero produces the dead sentinel (0x00).
        let m = maturity::encode(ThermalState::Hot, 1);
        assert!(!maturity::is_dead(m));
        assert!(maturity::is_dead(maturity::decrement(m)));
    }

    #[test]
    fn maturity_demotion() {
        // Start at COOL with counter 0
        let m = maturity::encode(ThermalState::Cool, 0);
        let demoted = maturity::decrement(m);
        assert_eq!(maturity::state(demoted), ThermalState::Warm);
        // Should start with some counter to avoid immediate further demotion
        assert!(maturity::counter(demoted) > 0);
    }

    #[test]
    fn maturity_warm_demotes_to_hot() {
        let demoted = maturity::decrement(maturity::encode(ThermalState::Warm, 0));
        assert_eq!(maturity::state(demoted), ThermalState::Hot);
        assert_eq!(maturity::counter(demoted), maturity::HOT_TO_WARM / 2);
    }

    #[test]
    fn cold_synapse_ignores_decrement() {
        let original = Synapse::frozen(3, 20, 4);
        let mut syn = original;
        syn.decrement_maturity();
        assert_eq!(syn.maturity, original.maturity);
        assert_eq!(syn.thermal_state(), ThermalState::Cold);
    }

    #[test]
    fn dale_law_excitatory() {
        let syn = Synapse::new(42, 100, 2, exc_flags());
        assert!(syn.weight > 0);
    }

    #[test]
    fn dale_law_inhibitory() {
        let syn = Synapse::new(42, 100, 2, inh_flags());
        assert!(syn.weight < 0);
    }

    #[test]
    fn synapse_new_clamps_weight_and_delay() {
        let exc = Synapse::new(0, 255, 0, exc_flags());
        assert_eq!(exc.weight, 127);
        assert_eq!(exc.delay, 1);

        let inh = Synapse::new(0, 255, 200, inh_flags());
        assert_eq!(inh.weight, -127);
        assert_eq!(inh.delay, 8);
    }

    #[test]
    fn csr_basic() {
        let exc = exc_flags();
        let edges = vec![
            (0, Synapse::new(1, 50, 1, exc)),
            (0, Synapse::new(2, 30, 1, exc)),
            (1, Synapse::new(0, 40, 1, exc)),
        ];
        let store = SynapseStore::from_edges(3, edges);

        assert_eq!(store.outgoing(0).len(), 2);
        assert_eq!(store.outgoing(1).len(), 1);
        assert_eq!(store.outgoing(2).len(), 0);
        assert_eq!(store.total_synapses(), 3);
    }

    #[test]
    fn csr_unsorted_edges_and_add() {
        let exc = exc_flags();
        let edges = vec![
            (2, Synapse::new(0, 10, 1, exc)),
            (0, Synapse::new(1, 10, 1, exc)),
            (2, Synapse::new(1, 10, 1, exc)),
        ];
        let mut store = SynapseStore::from_edges(3, edges);
        assert_eq!(store.row_ptr, vec![0, 1, 1, 3]);

        // Adding to the empty middle row shifts every later offset by one.
        store.add_synapse(1, Synapse::new(2, 10, 1, exc));
        assert_eq!(store.row_ptr, vec![0, 1, 2, 4]);
        assert_eq!(store.outgoing(1).len(), 1);
        assert_eq!(store.outgoing(1)[0].target, 2);
    }

    #[test]
    fn csr_prune_dead() {
        let mut store = SynapseStore::empty(2);
        let exc = exc_flags();

        // Add some synapses
        store.add_synapse(0, Synapse::new(1, 50, 1, exc));
        let mut dead_syn = Synapse::new(1, 30, 1, exc);
        dead_syn.maturity = 0x00; // Dead
        store.add_synapse(0, dead_syn);

        assert_eq!(store.total_synapses(), 2);
        let pruned = store.prune_dead();
        assert_eq!(pruned, 1);
        assert_eq!(store.total_synapses(), 1);
    }

    #[test]
    fn csr_extend_adds_empty_rows() {
        let mut store = SynapseStore::empty(1);
        store.add_synapse(0, Synapse::new(0, 10, 1, exc_flags()));
        store.extend(2);
        assert_eq!(store.n_neurons(), 3);
        assert_eq!(store.outgoing(0).len(), 1);
        assert!(store.outgoing(2).is_empty());
    }
}