1use serde::{Deserialize, Serialize};
9use std::time::Duration;
10
/// Identifier used to name a neuron; a plain owned `String`.
pub type NeuronId = String;
13
/// Functional category of a neuron.
///
/// `NeuronType::is_excitatory` groups `Excitatory`, `Sensory` and `Motor` together;
/// only `Excitatory` and `Inhibitory` have a non-zero `NeuronType::cortical_ratio`.
///
/// NOTE(review): serde encodes unit variants with their index in some formats, so append new
/// variants rather than reordering existing ones if serialized data must stay readable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NeuronType {
    Excitatory,
    Inhibitory,
    Sensory,
    Motor,
    Modulatory,
}
28
29impl NeuronType {
30 pub fn is_excitatory(&self) -> bool {
32 matches!(self, NeuronType::Excitatory | NeuronType::Sensory | NeuronType::Motor)
33 }
34
35 pub fn cortical_ratio(&self) -> f64 {
37 match self {
38 NeuronType::Excitatory => 0.80,
39 NeuronType::Inhibitory => 0.20,
40 _ => 0.0,
41 }
42 }
43}
44
/// Dynamic state shared by the spiking-neuron implementations in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NeuronState {
    /// Membrane potential, on the same scale as the `LIFParams` voltages (default -70.0;
    /// presumably mV).
    pub membrane_potential: f64,
    /// Whether the neuron is currently refractory (set by the post-spike reset).
    pub refractory: bool,
    /// Time left in the current refractory period.
    pub refractory_remaining: Duration,
    /// Adaptation variable. Only `AdaptiveLIFNeuron` updates it; it is subtracted from the
    /// input current on every adaptive step.
    pub adaptation: f64,
    /// Input current accumulated by `receive_input`; `step` clears it after use.
    pub input_current: f64,
    /// Time since the last spike. Reset to zero on a spike; the default of 1000 s stands in
    /// for "has not spiked yet".
    pub time_since_spike: Duration,
}
61
62impl Default for NeuronState {
63 fn default() -> Self {
64 Self {
65 membrane_potential: -70.0, refractory: false,
67 refractory_remaining: Duration::ZERO,
68 adaptation: 0.0,
69 input_current: 0.0,
70 time_since_spike: Duration::from_secs(1000), }
72 }
73}
74
75pub trait SpikingNeuron: Send + Sync {
77 fn id(&self) -> &NeuronId;
79
80 fn neuron_type(&self) -> NeuronType;
82
83 fn state(&self) -> &NeuronState;
85
86 fn state_mut(&mut self) -> &mut NeuronState;
88
89 fn step(&mut self, dt: Duration) -> bool;
91
92 fn receive_input(&mut self, current: f64);
94
95 fn reset(&mut self);
97
98 fn can_spike(&self) -> bool {
100 !self.state().refractory
101 }
102
103 fn membrane_potential(&self) -> f64 {
105 self.state().membrane_potential
106 }
107}
108
/// Parameters of a leaky integrate-and-fire (LIF) neuron.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LIFParams {
    /// Membrane time constant in milliseconds (`step` divides the step length in ms by it).
    pub tau_m: f64,
    /// Membrane resistance: multiplies `input_current` to give the voltage drive.
    pub r_m: f64,
    /// Resting potential the membrane relaxes towards (presumably mV).
    pub v_rest: f64,
    /// Spike threshold: a step that ends at or above this potential fires.
    pub v_thresh: f64,
    /// Potential the membrane is set to immediately after a spike.
    pub v_reset: f64,
    /// Refractory period after a spike, in milliseconds.
    pub t_ref: f64,
}
125
126impl Default for LIFParams {
127 fn default() -> Self {
128 Self {
129 tau_m: 20.0, r_m: 10.0, v_rest: -70.0, v_thresh: -55.0, v_reset: -75.0, t_ref: 2.0, }
136 }
137}
138
/// Leaky integrate-and-fire neuron, integrated with a forward-Euler update in `step`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LIFNeuron {
    /// Identifier returned by `SpikingNeuron::id`.
    id: NeuronId,
    /// Functional type returned by `SpikingNeuron::neuron_type`.
    neuron_type: NeuronType,
    /// Model parameters (read via `LIFNeuron::params`).
    params: LIFParams,
    /// Dynamic state advanced by `step`.
    state: NeuronState,
}
147
148impl LIFNeuron {
149 pub fn new(id: NeuronId, neuron_type: NeuronType, params: LIFParams) -> Self {
151 Self {
152 id,
153 neuron_type,
154 state: NeuronState {
155 membrane_potential: params.v_rest,
156 ..Default::default()
157 },
158 params,
159 }
160 }
161
162 pub fn with_defaults(id: NeuronId, neuron_type: NeuronType) -> Self {
164 Self::new(id, neuron_type, LIFParams::default())
165 }
166
167 pub fn params(&self) -> &LIFParams {
169 &self.params
170 }
171}
172
173impl SpikingNeuron for LIFNeuron {
174 fn id(&self) -> &NeuronId {
175 &self.id
176 }
177
178 fn neuron_type(&self) -> NeuronType {
179 self.neuron_type
180 }
181
182 fn state(&self) -> &NeuronState {
183 &self.state
184 }
185
186 fn state_mut(&mut self) -> &mut NeuronState {
187 &mut self.state
188 }
189
190 fn step(&mut self, dt: Duration) -> bool {
191 let dt_ms = dt.as_secs_f64() * 1000.0;
192
193 if self.state.refractory {
195 if self.state.refractory_remaining > dt {
196 self.state.refractory_remaining -= dt;
197 self.state.time_since_spike += dt;
198 self.state.input_current = 0.0;
199 return false;
200 } else {
201 self.state.refractory = false;
202 self.state.refractory_remaining = Duration::ZERO;
203 }
204 }
205
206 let dv = dt_ms / self.params.tau_m
208 * (-(self.state.membrane_potential - self.params.v_rest)
209 + self.params.r_m * self.state.input_current);
210
211 self.state.membrane_potential += dv;
212 self.state.time_since_spike += dt;
213
214 self.state.input_current = 0.0;
216
217 if self.state.membrane_potential >= self.params.v_thresh {
219 self.reset();
220 return true;
221 }
222
223 false
224 }
225
226 fn receive_input(&mut self, current: f64) {
227 self.state.input_current += current;
228 }
229
230 fn reset(&mut self) {
231 self.state.membrane_potential = self.params.v_reset;
232 self.state.refractory = true;
233 self.state.refractory_remaining = Duration::from_secs_f64(self.params.t_ref / 1000.0);
234 self.state.time_since_spike = Duration::ZERO;
235 }
236}
237
/// LIF neuron with an additional adaptation variable `w` (stored in `NeuronState::adaptation`).
///
/// Each step applies `dw = dt / tau_w * (a * (v - v_rest) - w)`, subtracts the updated `w`
/// from the input current, and adds `b` to `w` whenever the neuron spikes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveLIFNeuron {
    /// Underlying LIF neuron holding the id, type, parameters and state.
    base: LIFNeuron,
    /// Adaptation time constant, in milliseconds (same scale as `dt` in `step`).
    tau_w: f64,
    /// Subthreshold coupling of the adaptation variable to `v - v_rest`.
    a: f64,
    /// Amount added to the adaptation variable after each spike.
    b: f64,
}
249
250impl AdaptiveLIFNeuron {
251 pub fn new(id: NeuronId, neuron_type: NeuronType, params: LIFParams) -> Self {
253 Self {
254 base: LIFNeuron::new(id, neuron_type, params),
255 tau_w: 100.0, a: 0.01, b: 0.5, }
259 }
260
261 pub fn with_adaptation(
263 id: NeuronId,
264 neuron_type: NeuronType,
265 params: LIFParams,
266 tau_w: f64,
267 a: f64,
268 b: f64,
269 ) -> Self {
270 Self {
271 base: LIFNeuron::new(id, neuron_type, params),
272 tau_w,
273 a,
274 b,
275 }
276 }
277}
278
279impl SpikingNeuron for AdaptiveLIFNeuron {
280 fn id(&self) -> &NeuronId {
281 self.base.id()
282 }
283
284 fn neuron_type(&self) -> NeuronType {
285 self.base.neuron_type()
286 }
287
288 fn state(&self) -> &NeuronState {
289 self.base.state()
290 }
291
292 fn state_mut(&mut self) -> &mut NeuronState {
293 self.base.state_mut()
294 }
295
296 fn step(&mut self, dt: Duration) -> bool {
297 let dt_ms = dt.as_secs_f64() * 1000.0;
298
299 let v = self.base.state.membrane_potential;
301 let v_rest = self.base.params.v_rest;
302 let w = self.base.state.adaptation;
303
304 let dw = dt_ms / self.tau_w * (self.a * (v - v_rest) - w);
306 self.base.state.adaptation += dw;
307
308 self.base.state.input_current -= self.base.state.adaptation;
310
311 let spiked = self.base.step(dt);
313
314 if spiked {
316 self.base.state.adaptation += self.b;
317 }
318
319 spiked
320 }
321
322 fn receive_input(&mut self, current: f64) {
323 self.base.receive_input(current);
324 }
325
326 fn reset(&mut self) {
327 self.base.reset();
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lif_neuron_creation() {
        let neuron = LIFNeuron::with_defaults("n1".to_string(), NeuronType::Excitatory);

        assert_eq!(neuron.id(), "n1");
        assert_eq!(neuron.neuron_type(), NeuronType::Excitatory);
        assert!((neuron.membrane_potential() - (-70.0)).abs() < 0.01);
    }

    #[test]
    fn test_lif_neuron_spike() {
        let mut neuron = LIFNeuron::with_defaults("n1".to_string(), NeuronType::Excitatory);

        // Input is consumed every step, so it is re-delivered after each non-spiking step.
        neuron.receive_input(10.0);

        let dt = Duration::from_millis(1);
        let mut spiked = false;
        for _ in 0..100 {
            if neuron.step(dt) {
                spiked = true;
                break;
            }
            neuron.receive_input(10.0);
        }

        assert!(spiked);
        assert!(neuron.state().refractory);
    }

    #[test]
    fn test_lif_refractory_period() {
        let mut neuron = LIFNeuron::with_defaults("n1".to_string(), NeuronType::Excitatory);

        // Start above threshold so the first step fires.
        neuron.state_mut().membrane_potential = -50.0;
        let spiked = neuron.step(Duration::from_millis(1));
        assert!(spiked);

        assert!(neuron.state().refractory);
        assert!(!neuron.can_spike());

        // Even a large input must not cause a spike while refractory.
        neuron.receive_input(100.0);
        let spiked = neuron.step(Duration::from_millis(1));
        assert!(!spiked);
    }

    #[test]
    fn test_neuron_type_properties() {
        assert!(NeuronType::Excitatory.is_excitatory());
        assert!(!NeuronType::Inhibitory.is_excitatory());
        assert!(NeuronType::Sensory.is_excitatory());
    }

    #[test]
    fn test_adaptive_lif_adaptation() {
        let mut neuron = AdaptiveLIFNeuron::new(
            "n1".to_string(),
            NeuronType::Excitatory,
            LIFParams::default(),
        );

        // A fresh neuron carries no adaptation.
        assert!((neuron.state().adaptation - 0.0).abs() < 0.01);

        let dt = Duration::from_millis(1);
        let mut spikes = 0;
        for _ in 0..10 {
            neuron.receive_input(5.0);
            if neuron.step(dt) {
                spikes += 1;
            }
        }

        // Sustained drive makes the neuron fire within 10 ms. The subthreshold coupling `a`
        // and the per-spike increment `b` both push the adaptation variable above zero.
        assert!(spikes >= 1);
        assert!(neuron.state().adaptation > 0.0);
    }
}