1use serde::{Deserialize, Serialize};
2
// --- Network topology / initialisation (consumed by `NetworkConfig::default`) ---

/// Default for `NetworkConfig::num_neurons`.
pub const DEFAULT_NUM_NEURONS: usize = 10;
/// Default for `NetworkConfig::steps_per_tick`.
pub const DEFAULT_STEPS_PER_TICK: u32 = 5;
/// Default for `NetworkConfig::input_sparsity`.
pub const DEFAULT_INPUT_SPARSITY: f64 = 0.3;
/// Default for `NetworkConfig::recurrent_sparsity`.
pub const DEFAULT_RECURRENT_SPARSITY: f64 = 0.1;
/// Default for `NetworkConfig::excitatory_frac`.
pub const DEFAULT_EXCITATORY_FRAC: f64 = 0.8;
/// Default for `NetworkConfig::init_input_scale`.
pub const DEFAULT_INIT_INPUT_SCALE: f64 = 1.0;
/// Default for `NetworkConfig::init_recurrent_scale`.
pub const DEFAULT_INIT_RECURRENT_SCALE: f64 = 1.0;
/// Default for `NetworkConfig::temporal_coding_frac`.
pub const DEFAULT_TEMPORAL_CODING_FRAC: f64 = 0.2;

// --- Neuron dynamics (also consumed by `NetworkConfig::default`) ---

/// Default for `NetworkConfig::v_threshold`.
pub const DEFAULT_V_THRESHOLD: f64 = 1.0;
/// Default for `NetworkConfig::v_reset`.
pub const DEFAULT_V_RESET: f64 = 0.0;
/// Default for `NetworkConfig::refractory_steps`.
pub const DEFAULT_REFRACTORY_STEPS: u8 = 1;
/// Default for `NetworkConfig::tau_min`.
pub const DEFAULT_TAU_MIN: f64 = 3.0;
/// Default for `NetworkConfig::tau_max`.
pub const DEFAULT_TAU_MAX: f64 = 120.0;

// --- Plasticity (consumed by `StdpConfig::default`) ---

/// Default for `StdpConfig::eta_stdp`.
pub const DEFAULT_ETA_STDP: f64 = 0.05;
/// Default for `StdpConfig::a_plus`.
pub const DEFAULT_A_PLUS: f64 = 0.10;
/// Default for `StdpConfig::a_minus`.
pub const DEFAULT_A_MINUS: f64 = 0.12;
/// Default for `StdpConfig::w_max` (upper bound used by `StdpConfig::soft_bound`).
pub const DEFAULT_W_MAX: f64 = 1.0;
/// Default for `StdpConfig::w_min` (lower bound used by `StdpConfig::soft_bound`).
pub const DEFAULT_W_MIN: f64 = -1.0;
/// Default for `StdpConfig::eligibility_decay`.
pub const DEFAULT_ELIGIBILITY_DECAY: f64 = 0.95;
/// Default for `StdpConfig::tau_stdp_frac` (multiplied by steps-per-tick in `tau_stdp`).
pub const DEFAULT_TAU_STDP_FRAC: f64 = 0.2;
/// Default for `StdpConfig::soft_bound_power` (exponent in `soft_bound`).
pub const DEFAULT_SOFT_BOUND_POWER: f64 = 1.0;
/// Default for `StdpConfig::eta_error`.
pub const DEFAULT_ETA_ERROR: f64 = 0.01;

// --- Decoder (consumed by `DecoderConfig::default`) ---

/// Default for `DecoderConfig::initial_scale`.
pub const DEFAULT_DECODER_SCALE: f64 = 2.0;
/// Default for `DecoderConfig::epsilon`.
pub const DEFAULT_DECODER_EPSILON: f64 = 0.01;
/// Default for `DecoderConfig::scale_lr`.
pub const DEFAULT_DECODER_SCALE_LR: f64 = 0.01;

// --- Run-time defaults ---
// None of the config `Default` impls use these; their consumers live
// elsewhere. NOTE(review): confirm the exact semantics at those call sites.

/// Burn-in length in ticks (presumably ticks skipped before learning/scoring — confirm).
pub const DEFAULT_BURN_IN_TICKS: u64 = 10;
/// Smoothing factor for a confidence estimate (presumably an EMA weight — confirm).
pub const DEFAULT_CONFIDENCE_ALPHA: f64 = 0.1;
/// Smoothing factor for an exponential moving average.
pub const DEFAULT_EMA_ALPHA: f64 = 0.05;
/// Warm-up sample count for the EMA (presumably samples before it is trusted — confirm).
pub const DEFAULT_EMA_WARMUP: u64 = 2;
43
44#[derive(Clone, Debug, Serialize, Deserialize)]
49pub struct NetworkConfig {
50 pub num_neurons: usize,
52 pub steps_per_tick: u32,
53 pub input_sparsity: f64,
54 pub recurrent_sparsity: f64,
55 pub excitatory_frac: f64,
56 pub init_input_scale: f64,
57 pub init_recurrent_scale: f64,
58 pub tau_min: f64,
59 pub tau_max: f64,
60 pub v_threshold: f64,
61 pub v_reset: f64,
62 pub refractory_steps: u8,
63 pub temporal_coding_frac: f64,
64}
65
66impl Default for NetworkConfig {
67 fn default() -> Self {
68 Self {
69 num_neurons: DEFAULT_NUM_NEURONS,
70 steps_per_tick: DEFAULT_STEPS_PER_TICK,
71 input_sparsity: DEFAULT_INPUT_SPARSITY,
72 recurrent_sparsity: DEFAULT_RECURRENT_SPARSITY,
73 excitatory_frac: DEFAULT_EXCITATORY_FRAC,
74 init_input_scale: DEFAULT_INIT_INPUT_SCALE,
75 init_recurrent_scale: DEFAULT_INIT_RECURRENT_SCALE,
76 tau_min: DEFAULT_TAU_MIN,
77 tau_max: DEFAULT_TAU_MAX,
78 v_threshold: DEFAULT_V_THRESHOLD,
79 v_reset: DEFAULT_V_RESET,
80 refractory_steps: DEFAULT_REFRACTORY_STEPS,
81 temporal_coding_frac: DEFAULT_TEMPORAL_CODING_FRAC,
82 }
83 }
84}
85
86#[derive(Clone, Debug, Serialize, Deserialize)]
91pub struct StdpConfig {
92 pub eta_stdp: f64,
93 pub a_plus: f64,
94 pub a_minus: f64,
95 pub w_max: f64,
96 pub w_min: f64,
97 pub eligibility_decay: f64,
98 pub tau_stdp_frac: f64,
99 pub soft_bound_power: f64,
100 pub eta_error: f64,
101}
102
103impl StdpConfig {
104 pub fn tau_stdp(&self, steps_per_tick: u32) -> f64 {
105 (steps_per_tick as f64 * self.tau_stdp_frac).max(2.0)
106 }
107
108 pub fn soft_bound(&self, w: f64) -> f64 {
109 if w >= 0.0 {
110 (self.w_max - w).max(0.0).powf(self.soft_bound_power)
111 } else {
112 (w - self.w_min).abs().max(0.0).powf(self.soft_bound_power)
113 }
114 }
115}
116
117impl Default for StdpConfig {
118 fn default() -> Self {
119 Self {
120 eta_stdp: DEFAULT_ETA_STDP,
121 a_plus: DEFAULT_A_PLUS,
122 a_minus: DEFAULT_A_MINUS,
123 w_max: DEFAULT_W_MAX,
124 w_min: DEFAULT_W_MIN,
125 eligibility_decay: DEFAULT_ELIGIBILITY_DECAY,
126 tau_stdp_frac: DEFAULT_TAU_STDP_FRAC,
127 soft_bound_power: DEFAULT_SOFT_BOUND_POWER,
128 eta_error: DEFAULT_ETA_ERROR,
129 }
130 }
131}
132
133#[derive(Clone, Debug, Serialize, Deserialize)]
139pub struct DecoderConfig {
140 pub initial_scale: f64,
141 pub epsilon: f64,
142 pub scale_lr: f64,
143}
144
145impl Default for DecoderConfig {
146 fn default() -> Self {
147 Self {
148 initial_scale: DEFAULT_DECODER_SCALE,
149 epsilon: DEFAULT_DECODER_EPSILON,
150 scale_lr: DEFAULT_DECODER_SCALE_LR,
151 }
152 }
153}