1use crate::neuron::flags;
8
/// Thermal stage of a synapse, stored in the two low bits of the packed
/// maturity byte (see `maturity::encode` / `ThermalState::from_maturity`).
///
/// `maturity::increment` promotes Hot -> Warm -> Cool -> Cold once the
/// counter reaches the matching threshold; `maturity::decrement` demotes
/// Cool -> Warm -> Hot when the counter is already 0. Cold is never changed
/// by either function.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum ThermalState {
    /// Bit pattern `0b00`; the stage `Synapse::new` starts in.
    Hot = 0b00,
    /// Bit pattern `0b01`.
    Warm = 0b01,
    /// Bit pattern `0b10`.
    Cool = 0b10,
    /// Bit pattern `0b11`; terminal — increment/decrement leave it unchanged.
    Cold = 0b11,
}
24
25impl ThermalState {
26 #[inline]
27 pub fn from_maturity(maturity: u8) -> Self {
28 match maturity & 0b11 {
29 0b00 => Self::Hot,
30 0b01 => Self::Warm,
31 0b10 => Self::Cool,
32 0b11 => Self::Cold,
33 _ => unreachable!(),
34 }
35 }
36}
37
38pub mod maturity {
44 use super::ThermalState;
45
46 pub const HOT_TO_WARM: u8 = 8;
48 pub const WARM_TO_COOL: u8 = 24;
49 pub const COOL_TO_COLD: u8 = 63;
50
51 #[inline]
53 pub fn encode(state: ThermalState, counter: u8) -> u8 {
54 (counter.min(63) << 2) | (state as u8)
55 }
56
57 #[inline]
59 pub fn state(m: u8) -> ThermalState {
60 ThermalState::from_maturity(m)
61 }
62
63 #[inline]
65 pub fn counter(m: u8) -> u8 {
66 m >> 2
67 }
68
69 #[inline]
72 pub fn increment(m: u8) -> u8 {
73 let s = state(m);
74 let c = counter(m);
75
76 if s == ThermalState::Cold {
77 return m; }
79
80 let new_c = c.saturating_add(1).min(63);
81 let new_state = match s {
82 ThermalState::Hot if new_c >= HOT_TO_WARM => ThermalState::Warm,
83 ThermalState::Warm if new_c >= WARM_TO_COOL => ThermalState::Cool,
84 ThermalState::Cool if new_c >= COOL_TO_COLD => ThermalState::Cold,
85 other => other,
86 };
87
88 if new_state != s {
90 encode(new_state, 0)
91 } else {
92 encode(s, new_c)
93 }
94 }
95
96 #[inline]
99 pub fn decrement(m: u8) -> u8 {
100 let s = state(m);
101 let c = counter(m);
102
103 if s == ThermalState::Cold {
104 return m; }
106
107 if c == 0 {
108 let new_state = match s {
110 ThermalState::Cool => ThermalState::Warm,
111 ThermalState::Warm => ThermalState::Hot,
112 ThermalState::Hot => return 0x00, ThermalState::Cold => return m,
114 };
115 let new_c = match new_state {
117 ThermalState::Warm => WARM_TO_COOL / 2,
118 ThermalState::Hot => HOT_TO_WARM / 2,
119 _ => 0,
120 };
121 encode(new_state, new_c)
122 } else {
123 encode(s, c - 1)
124 }
125 }
126
127 #[inline]
129 pub fn is_dead(m: u8) -> bool {
130 m == 0x00
131 }
132}
133
/// One outgoing connection, stored in the `SynapseStore` row of its source
/// neuron (the source index is therefore not stored in the struct).
///
/// `#[repr(C)]` fixes the field order; the fields total 8 bytes, which the
/// `synapse_size` test asserts.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct Synapse {
    /// Index of the neuron this synapse projects to.
    pub target: u16,
    /// Signed weight. `Synapse::new` takes the sign from the source neuron's
    /// inhibitory flag and caps the magnitude at 127; `Synapse::frozen` stores
    /// the given value verbatim.
    pub weight: i8,
    /// Delay, clamped to `1..=8` by both constructors.
    /// NOTE(review): unit not shown here (presumably simulation steps) — confirm.
    pub delay: u8,
    /// Initialized to 0 by both constructors and not otherwise touched here.
    /// NOTE(review): presumably an eligibility trace for a learning rule — confirm.
    pub eligibility: i8,
    /// Packed maturity byte (counter in bits 2..=7, `ThermalState` in bits
    /// 0..=1, see `maturity::encode`); `0x00` means dead.
    pub maturity: u8,
    /// Unused; zeroed by the constructors. Brings the struct to 8 bytes.
    pub _reserved: [u8; 2],
}
156
157impl Synapse {
158 pub fn new(target: u16, weight_magnitude: u8, delay: u8, source_flags: u8) -> Self {
163 let signed_weight = if flags::is_inhibitory(source_flags) {
164 -(weight_magnitude.min(127) as i8)
165 } else {
166 weight_magnitude.min(127) as i8
167 };
168
169 Self {
170 target,
171 weight: signed_weight,
172 delay: delay.max(1).min(8),
173 eligibility: 0,
174 maturity: maturity::encode(ThermalState::Hot, 4), _reserved: [0; 2],
176 }
177 }
178
179 pub fn frozen(target: u16, weight: i8, delay: u8) -> Self {
181 Self {
182 target,
183 weight,
184 delay: delay.max(1).min(8),
185 eligibility: 0,
186 maturity: maturity::encode(ThermalState::Cold, 63),
187 _reserved: [0; 2],
188 }
189 }
190
191 #[inline]
193 pub fn thermal_state(&self) -> ThermalState {
194 maturity::state(self.maturity)
195 }
196
197 #[inline]
199 pub fn is_dead(&self) -> bool {
200 maturity::is_dead(self.maturity)
201 }
202
203 #[inline]
205 pub fn increment_maturity(&mut self) {
206 self.maturity = maturity::increment(self.maturity);
207 }
208
209 #[inline]
211 pub fn decrement_maturity(&mut self) {
212 self.maturity = maturity::decrement(self.maturity);
213 }
214}
215
/// Outgoing synapses of every neuron in compressed-sparse-row (CSR) layout.
///
/// Expected invariant: `row_ptr` has `n_neurons + 1` non-decreasing entries
/// starting at 0, and the synapses whose source is neuron `i` are
/// `synapses[row_ptr[i] as usize..row_ptr[i + 1] as usize]`. Both fields are
/// public, so code writing them directly must preserve this.
pub struct SynapseStore {
    /// Row offsets into `synapses`; length `n_neurons + 1`.
    pub row_ptr: Vec<u32>,
    /// All synapses, grouped contiguously by source neuron in ascending order.
    pub synapses: Vec<Synapse>,
}
228
229impl SynapseStore {
230 pub fn empty(n_neurons: u32) -> Self {
232 Self {
233 row_ptr: vec![0; (n_neurons + 1) as usize],
234 synapses: Vec::new(),
235 }
236 }
237
238 pub fn from_edges(n_neurons: u32, mut edges: Vec<(u32, Synapse)>) -> Self {
242 edges.sort_unstable_by_key(|(src, _)| *src);
243
244 let n = n_neurons as usize;
245 let mut row_ptr = vec![0u32; n + 1];
246
247 for (src, _) in &edges {
249 let idx = (*src as usize).min(n - 1);
250 row_ptr[idx + 1] += 1;
251 }
252
253 for i in 1..=n {
255 row_ptr[i] += row_ptr[i - 1];
256 }
257
258 let synapses: Vec<Synapse> = edges.into_iter().map(|(_, syn)| syn).collect();
259
260 Self { row_ptr, synapses }
261 }
262
263 #[inline]
265 pub fn outgoing(&self, neuron: u32) -> &[Synapse] {
266 let start = self.row_ptr[neuron as usize] as usize;
267 let end = self.row_ptr[neuron as usize + 1] as usize;
268 &self.synapses[start..end]
269 }
270
271 #[inline]
273 pub fn outgoing_mut(&mut self, neuron: u32) -> &mut [Synapse] {
274 let start = self.row_ptr[neuron as usize] as usize;
275 let end = self.row_ptr[neuron as usize + 1] as usize;
276 &mut self.synapses[start..end]
277 }
278
279 #[inline]
281 pub fn total_synapses(&self) -> usize {
282 self.synapses.len()
283 }
284
285 #[inline]
287 pub fn n_neurons(&self) -> u32 {
288 (self.row_ptr.len().saturating_sub(1)) as u32
289 }
290
291 pub fn prune_dead(&mut self) -> usize {
294 let n = self.n_neurons() as usize;
295 let mut new_synapses = Vec::with_capacity(self.synapses.len());
296 let mut new_row_ptr = vec![0u32; n + 1];
297 let mut pruned = 0usize;
298
299 for i in 0..n {
300 let start = self.row_ptr[i] as usize;
301 let end = self.row_ptr[i + 1] as usize;
302
303 for syn in &self.synapses[start..end] {
304 if syn.is_dead() {
305 pruned += 1;
306 } else {
307 new_synapses.push(*syn);
308 }
309 }
310 new_row_ptr[i + 1] = new_synapses.len() as u32;
311 }
312
313 self.synapses = new_synapses;
314 self.row_ptr = new_row_ptr;
315 pruned
316 }
317
318 pub fn add_synapse(&mut self, source: u32, syn: Synapse) {
321 let idx = source as usize;
322 let insert_pos = self.row_ptr[idx + 1] as usize;
323
324 self.synapses.insert(insert_pos, syn);
325
326 for ptr in &mut self.row_ptr[(idx + 1)..] {
328 *ptr += 1;
329 }
330 }
331
332 pub fn extend(&mut self, count: usize) {
336 let last_ptr = *self.row_ptr.last().unwrap_or(&0);
337 for _ in 0..count {
338 self.row_ptr.push(last_ptr);
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Source flags of an excitatory, regular-spiking neuron.
    fn excitatory_flags() -> u8 {
        crate::neuron::flags::encode(false, crate::neuron::NeuronProfile::RegularSpiking)
    }

    /// Source flags of an inhibitory, fast-spiking neuron.
    fn inhibitory_flags() -> u8 {
        crate::neuron::flags::encode(true, crate::neuron::NeuronProfile::FastSpiking)
    }

    /// Applies `maturity::increment` to `m` exactly `times` times.
    fn increment_n(m: u8, times: u8) -> u8 {
        (0..times).fold(m, |acc, _| maturity::increment(acc))
    }

    #[test]
    fn synapse_size() {
        assert_eq!(std::mem::size_of::<Synapse>(), 8);
    }

    #[test]
    fn maturity_lifecycle() {
        let start = maturity::encode(ThermalState::Hot, 0);
        assert_eq!(maturity::state(start), ThermalState::Hot);
        assert_eq!(maturity::counter(start), 0);

        let warm = increment_n(start, maturity::HOT_TO_WARM);
        assert_eq!(maturity::state(warm), ThermalState::Warm);
        assert_eq!(maturity::counter(warm), 0);

        let cool = increment_n(warm, maturity::WARM_TO_COOL);
        assert_eq!(maturity::state(cool), ThermalState::Cool);

        let cold = increment_n(cool, maturity::COOL_TO_COLD);
        assert_eq!(maturity::state(cold), ThermalState::Cold);

        // Cold is terminal: one more increment is a no-op.
        assert_eq!(maturity::increment(cold), cold);
    }

    #[test]
    fn maturity_death() {
        let hot_empty = maturity::encode(ThermalState::Hot, 0);
        assert!(maturity::is_dead(maturity::decrement(hot_empty)));
    }

    #[test]
    fn maturity_demotion() {
        let demoted = maturity::decrement(maturity::encode(ThermalState::Cool, 0));
        assert_eq!(maturity::state(demoted), ThermalState::Warm);
        assert!(maturity::counter(demoted) > 0);
    }

    #[test]
    fn dale_law_excitatory() {
        let syn = Synapse::new(42, 100, 2, excitatory_flags());
        assert!(syn.weight > 0);
    }

    #[test]
    fn dale_law_inhibitory() {
        let syn = Synapse::new(42, 100, 2, inhibitory_flags());
        assert!(syn.weight < 0);
    }

    #[test]
    fn csr_basic() {
        let exc = excitatory_flags();
        let store = SynapseStore::from_edges(
            3,
            vec![
                (0, Synapse::new(1, 50, 1, exc)),
                (0, Synapse::new(2, 30, 1, exc)),
                (1, Synapse::new(0, 40, 1, exc)),
            ],
        );

        let row_lengths: Vec<usize> = (0..3).map(|n| store.outgoing(n).len()).collect();
        assert_eq!(row_lengths, vec![2, 1, 0]);
        assert_eq!(store.total_synapses(), 3);
    }

    #[test]
    fn csr_prune_dead() {
        let exc = excitatory_flags();
        let mut store = SynapseStore::empty(2);

        store.add_synapse(0, Synapse::new(1, 50, 1, exc));
        let dead = Synapse {
            maturity: 0x00,
            ..Synapse::new(1, 30, 1, exc)
        };
        store.add_synapse(0, dead);
        assert_eq!(store.total_synapses(), 2);

        assert_eq!(store.prune_dead(), 1);
        assert_eq!(store.total_synapses(), 1);
    }
}