// rsnn_eta/config.rs

1use serde::{Deserialize, Serialize};
2
// ── Network defaults ──

/// Default for [`NetworkConfig::num_neurons`].
pub const DEFAULT_NUM_NEURONS: usize = 10;
/// Default for [`NetworkConfig::steps_per_tick`].
pub const DEFAULT_STEPS_PER_TICK: u32 = 5;
/// Default for [`NetworkConfig::input_sparsity`].
pub const DEFAULT_INPUT_SPARSITY: f64 = 0.3;
/// Default for [`NetworkConfig::recurrent_sparsity`].
pub const DEFAULT_RECURRENT_SPARSITY: f64 = 0.1;
/// Default for [`NetworkConfig::excitatory_frac`].
pub const DEFAULT_EXCITATORY_FRAC: f64 = 0.8;
/// Default for [`NetworkConfig::init_input_scale`].
pub const DEFAULT_INIT_INPUT_SCALE: f64 = 1.0;
/// Default for [`NetworkConfig::init_recurrent_scale`].
pub const DEFAULT_INIT_RECURRENT_SCALE: f64 = 1.0;
/// Default for [`NetworkConfig::temporal_coding_frac`].
pub const DEFAULT_TEMPORAL_CODING_FRAC: f64 = 0.2;

// ── Neuron defaults ──

/// Default for [`NetworkConfig::v_threshold`].
pub const DEFAULT_V_THRESHOLD: f64 = 1.0;
/// Default for [`NetworkConfig::v_reset`].
pub const DEFAULT_V_RESET: f64 = 0.0;
/// Default for [`NetworkConfig::refractory_steps`].
pub const DEFAULT_REFRACTORY_STEPS: u8 = 1;
/// Default for [`NetworkConfig::tau_min`].
pub const DEFAULT_TAU_MIN: f64 = 3.0;
/// Default for [`NetworkConfig::tau_max`].
pub const DEFAULT_TAU_MAX: f64 = 120.0;

// ── STDP defaults ──

/// Default for [`StdpConfig::eta_stdp`].
pub const DEFAULT_ETA_STDP: f64 = 0.05;
/// Default for [`StdpConfig::a_plus`].
pub const DEFAULT_A_PLUS: f64 = 0.1;
/// Default for [`StdpConfig::a_minus`].
pub const DEFAULT_A_MINUS: f64 = 0.12;
/// Default for [`StdpConfig::w_max`].
pub const DEFAULT_W_MAX: f64 = 1.0;
/// Default for [`StdpConfig::w_min`].
pub const DEFAULT_W_MIN: f64 = -1.0;
/// Default for [`StdpConfig::eligibility_decay`].
pub const DEFAULT_ELIGIBILITY_DECAY: f64 = 0.95;
/// Default for [`StdpConfig::tau_stdp_frac`].
pub const DEFAULT_TAU_STDP_FRAC: f64 = 0.2;
/// Default for [`StdpConfig::soft_bound_power`].
pub const DEFAULT_SOFT_BOUND_POWER: f64 = 1.0;
/// Default for [`StdpConfig::eta_error`].
pub const DEFAULT_ETA_ERROR: f64 = 0.01;

// ── Decoder defaults ──

/// Default for [`DecoderConfig::initial_scale`].
pub const DEFAULT_DECODER_SCALE: f64 = 2.0;
/// Default for [`DecoderConfig::epsilon`].
pub const DEFAULT_DECODER_EPSILON: f64 = 0.01;
/// Default for [`DecoderConfig::scale_lr`].
pub const DEFAULT_DECODER_SCALE_LR: f64 = 0.01;

// ── Tracker defaults ──
// NOTE(review): the consumers of the tracker and EMA defaults are not visible
// in this part of the file; the descriptions below are inferred from the names.

/// Default burn-in length, in ticks (presumably for the tracker — confirm).
pub const DEFAULT_BURN_IN_TICKS: u64 = 10;
/// Default confidence smoothing factor (presumably for the tracker — confirm).
pub const DEFAULT_CONFIDENCE_ALPHA: f64 = 0.1;

// ── EMA defaults ──

/// Default EMA smoothing factor (presumably the per-update weight — confirm).
pub const DEFAULT_EMA_ALPHA: f64 = 0.05;
/// Default EMA warm-up count (presumably in updates or ticks — confirm).
pub const DEFAULT_EMA_WARMUP: u64 = 2;
43
44/// RSNN reservoir topology configuration.
45///
46/// Controls neuron count, connectivity sparsity, excitatory/inhibitory ratio,
47/// LIF neuron parameters, and temporal coding injection fraction.
48#[derive(Clone, Debug, Serialize, Deserialize)]
49pub struct NetworkConfig {
50    /// Number of LIF neurons in the reservoir (default: 50).
51    pub num_neurons: usize,
52    pub steps_per_tick: u32,
53    pub input_sparsity: f64,
54    pub recurrent_sparsity: f64,
55    pub excitatory_frac: f64,
56    pub init_input_scale: f64,
57    pub init_recurrent_scale: f64,
58    pub tau_min: f64,
59    pub tau_max: f64,
60    pub v_threshold: f64,
61    pub v_reset: f64,
62    pub refractory_steps: u8,
63    pub temporal_coding_frac: f64,
64}
65
66impl Default for NetworkConfig {
67    fn default() -> Self {
68        Self {
69            num_neurons: DEFAULT_NUM_NEURONS,
70            steps_per_tick: DEFAULT_STEPS_PER_TICK,
71            input_sparsity: DEFAULT_INPUT_SPARSITY,
72            recurrent_sparsity: DEFAULT_RECURRENT_SPARSITY,
73            excitatory_frac: DEFAULT_EXCITATORY_FRAC,
74            init_input_scale: DEFAULT_INIT_INPUT_SCALE,
75            init_recurrent_scale: DEFAULT_INIT_RECURRENT_SCALE,
76            tau_min: DEFAULT_TAU_MIN,
77            tau_max: DEFAULT_TAU_MAX,
78            v_threshold: DEFAULT_V_THRESHOLD,
79            v_reset: DEFAULT_V_RESET,
80            refractory_steps: DEFAULT_REFRACTORY_STEPS,
81            temporal_coding_frac: DEFAULT_TEMPORAL_CODING_FRAC,
82        }
83    }
84}
85
86/// Spike-Timing-Dependent Plasticity (STDP) learning configuration.
87///
88/// Controls the learning rate, LTP/LTD amplitudes, eligibility trace decay,
89/// weight bounds, and error modulation strength.
90#[derive(Clone, Debug, Serialize, Deserialize)]
91pub struct StdpConfig {
92    pub eta_stdp: f64,
93    pub a_plus: f64,
94    pub a_minus: f64,
95    pub w_max: f64,
96    pub w_min: f64,
97    pub eligibility_decay: f64,
98    pub tau_stdp_frac: f64,
99    pub soft_bound_power: f64,
100    pub eta_error: f64,
101}
102
103impl StdpConfig {
104    pub fn tau_stdp(&self, steps_per_tick: u32) -> f64 {
105        (steps_per_tick as f64 * self.tau_stdp_frac).max(2.0)
106    }
107
108    pub fn soft_bound(&self, w: f64) -> f64 {
109        if w >= 0.0 {
110            (self.w_max - w).max(0.0).powf(self.soft_bound_power)
111        } else {
112            (w - self.w_min).abs().max(0.0).powf(self.soft_bound_power)
113        }
114    }
115}
116
117impl Default for StdpConfig {
118    fn default() -> Self {
119        Self {
120            eta_stdp: DEFAULT_ETA_STDP,
121            a_plus: DEFAULT_A_PLUS,
122            a_minus: DEFAULT_A_MINUS,
123            w_max: DEFAULT_W_MAX,
124            w_min: DEFAULT_W_MIN,
125            eligibility_decay: DEFAULT_ELIGIBILITY_DECAY,
126            tau_stdp_frac: DEFAULT_TAU_STDP_FRAC,
127            soft_bound_power: DEFAULT_SOFT_BOUND_POWER,
128            eta_error: DEFAULT_ETA_ERROR,
129        }
130    }
131}
132
133/// Correction factor decoder configuration.
134///
135/// The decoder maps the output neuron's firing rate to a correction factor via
136/// `exp(scale * (rate - 0.5))`. `scale` is learnable and adjusted by the ratio
137/// error signal.
138#[derive(Clone, Debug, Serialize, Deserialize)]
139pub struct DecoderConfig {
140    pub initial_scale: f64,
141    pub epsilon: f64,
142    pub scale_lr: f64,
143}
144
145impl Default for DecoderConfig {
146    fn default() -> Self {
147        Self {
148            initial_scale: DEFAULT_DECODER_SCALE,
149            epsilon: DEFAULT_DECODER_EPSILON,
150            scale_lr: DEFAULT_DECODER_SCALE_LR,
151        }
152    }
153}