oldies_nest/
lib.rs

1//! # NEST-RS: NEST Simulator Revival
2//!
3//! Revival of the NEST simulator (https://www.nest-simulator.org/)
4//! NEST = NEural Simulation Tool
5//! Originally created by Marc-Oliver Gewaltig and Markus Diesmann
6//!
7//! NEST is designed for large-scale spiking neural network simulations
8//! with efficient parallelization and precise spike timing.
9//!
10//! Key features:
11//! - Node-based architecture (neurons, devices, connections)
12//! - Precise spike timing with grid/off-grid modes
13//! - Efficient connection management with synapse types
14//! - Built-in parallelization support
15//! - Recording devices (spike detectors, multimeters)
16
17use ndarray::Array1;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use thiserror::Error;
21
22#[derive(Error, Debug)]
23pub enum NestError {
24    #[error("Unknown node model: {0}")]
25    UnknownModel(String),
26    #[error("Node not found: {0}")]
27    NodeNotFound(usize),
28    #[error("Invalid parameter: {0}")]
29    InvalidParameter(String),
30    #[error("Connection error: {0}")]
31    ConnectionError(String),
32    #[error("Simulation error: {0}")]
33    SimulationError(String),
34}
35
36pub type Result<T> = std::result::Result<T, NestError>;
37
38// ============================================================================
39// NODE IDS (NEST's fundamental concept)
40// ============================================================================
41
42/// Global node identifier
43pub type NodeId = usize;
44
45/// Collection of node IDs (like NEST's NodeCollection)
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct NodeCollection {
48    pub ids: Vec<NodeId>,
49}
50
51impl NodeCollection {
52    pub fn new(ids: Vec<NodeId>) -> Self {
53        Self { ids }
54    }
55
56    pub fn len(&self) -> usize {
57        self.ids.len()
58    }
59
60    pub fn is_empty(&self) -> bool {
61        self.ids.is_empty()
62    }
63
64    pub fn first(&self) -> Option<NodeId> {
65        self.ids.first().copied()
66    }
67
68    pub fn last(&self) -> Option<NodeId> {
69        self.ids.last().copied()
70    }
71
72    /// Slice of nodes
73    pub fn slice(&self, start: usize, end: usize) -> Self {
74        Self::new(self.ids[start..end].to_vec())
75    }
76}
77
78impl IntoIterator for NodeCollection {
79    type Item = NodeId;
80    type IntoIter = std::vec::IntoIter<NodeId>;
81
82    fn into_iter(self) -> Self::IntoIter {
83        self.ids.into_iter()
84    }
85}
86
87// ============================================================================
88// NEURON MODELS
89// ============================================================================
90
91/// NEST neuron model types
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub enum NeuronModel {
94    /// Integrate-and-fire with alpha-function PSCs
95    IafPscAlpha(IafPscAlphaParams),
96
97    /// Integrate-and-fire with exponential PSCs
98    IafPscExp(IafPscExpParams),
99
100    /// Integrate-and-fire with delta PSCs (instantaneous)
101    IafPscDelta(IafPscDeltaParams),
102
103    /// Conductance-based IAF
104    IafCondAlpha(IafCondAlphaParams),
105
106    /// Conductance-based with exponential conductances
107    IafCondExp(IafCondExpParams),
108
109    /// Adaptive exponential integrate-and-fire
110    AeifCondAlpha(AeifCondAlphaParams),
111
112    /// Hodgkin-Huxley
113    HhPscAlpha(HhPscAlphaParams),
114
115    /// Izhikevich
116    Izhikevich(IzhikevichParams),
117
118    /// Parrot neuron (repeats input spikes)
119    ParrotNeuron,
120
121    /// Poisson generator
122    PoissonGenerator(PoissonGeneratorParams),
123
124    /// Spike generator
125    SpikeGenerator(SpikeGeneratorParams),
126
127    /// DC generator (constant current)
128    DcGenerator(DcGeneratorParams),
129
130    /// Noise generator
131    NoiseGenerator(NoiseGeneratorParams),
132
133    /// Spike detector (recorder)
134    SpikeDetector,
135
136    /// Multimeter (record state variables)
137    Multimeter(MultimeterParams),
138}
139
140/// Parameters for iaf_psc_alpha
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct IafPscAlphaParams {
143    pub c_m: f64,        // Membrane capacitance (pF)
144    pub tau_m: f64,      // Membrane time constant (ms)
145    pub tau_syn_ex: f64, // Excitatory synaptic time constant (ms)
146    pub tau_syn_in: f64, // Inhibitory synaptic time constant (ms)
147    pub t_ref: f64,      // Refractory period (ms)
148    pub e_l: f64,        // Resting potential (mV)
149    pub v_reset: f64,    // Reset potential (mV)
150    pub v_th: f64,       // Spike threshold (mV)
151    pub i_e: f64,        // External DC current (pA)
152}
153
154impl Default for IafPscAlphaParams {
155    fn default() -> Self {
156        Self {
157            c_m: 250.0,
158            tau_m: 10.0,
159            tau_syn_ex: 2.0,
160            tau_syn_in: 2.0,
161            t_ref: 2.0,
162            e_l: -70.0,
163            v_reset: -70.0,
164            v_th: -55.0,
165            i_e: 0.0,
166        }
167    }
168}
169
170/// Parameters for iaf_psc_exp
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct IafPscExpParams {
173    pub c_m: f64,
174    pub tau_m: f64,
175    pub tau_syn_ex: f64,
176    pub tau_syn_in: f64,
177    pub t_ref: f64,
178    pub e_l: f64,
179    pub v_reset: f64,
180    pub v_th: f64,
181    pub i_e: f64,
182}
183
184impl Default for IafPscExpParams {
185    fn default() -> Self {
186        Self {
187            c_m: 250.0,
188            tau_m: 10.0,
189            tau_syn_ex: 2.0,
190            tau_syn_in: 2.0,
191            t_ref: 2.0,
192            e_l: -70.0,
193            v_reset: -70.0,
194            v_th: -55.0,
195            i_e: 0.0,
196        }
197    }
198}
199
200/// Parameters for iaf_psc_delta
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct IafPscDeltaParams {
203    pub c_m: f64,
204    pub tau_m: f64,
205    pub t_ref: f64,
206    pub e_l: f64,
207    pub v_reset: f64,
208    pub v_th: f64,
209    pub i_e: f64,
210}
211
212impl Default for IafPscDeltaParams {
213    fn default() -> Self {
214        Self {
215            c_m: 250.0,
216            tau_m: 10.0,
217            t_ref: 2.0,
218            e_l: -70.0,
219            v_reset: -70.0,
220            v_th: -55.0,
221            i_e: 0.0,
222        }
223    }
224}
225
226/// Parameters for iaf_cond_alpha
227#[derive(Debug, Clone, Serialize, Deserialize)]
228pub struct IafCondAlphaParams {
229    pub c_m: f64,
230    pub g_l: f64,        // Leak conductance (nS)
231    pub tau_syn_ex: f64,
232    pub tau_syn_in: f64,
233    pub t_ref: f64,
234    pub e_l: f64,
235    pub e_ex: f64,       // Excitatory reversal potential (mV)
236    pub e_in: f64,       // Inhibitory reversal potential (mV)
237    pub v_reset: f64,
238    pub v_th: f64,
239    pub i_e: f64,
240}
241
242impl Default for IafCondAlphaParams {
243    fn default() -> Self {
244        Self {
245            c_m: 250.0,
246            g_l: 16.7,
247            tau_syn_ex: 0.2,
248            tau_syn_in: 2.0,
249            t_ref: 2.0,
250            e_l: -70.0,
251            e_ex: 0.0,
252            e_in: -85.0,
253            v_reset: -70.0,
254            v_th: -55.0,
255            i_e: 0.0,
256        }
257    }
258}
259
260/// Parameters for iaf_cond_exp
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct IafCondExpParams {
263    pub c_m: f64,
264    pub g_l: f64,
265    pub tau_syn_ex: f64,
266    pub tau_syn_in: f64,
267    pub t_ref: f64,
268    pub e_l: f64,
269    pub e_ex: f64,
270    pub e_in: f64,
271    pub v_reset: f64,
272    pub v_th: f64,
273    pub i_e: f64,
274}
275
276impl Default for IafCondExpParams {
277    fn default() -> Self {
278        Self {
279            c_m: 250.0,
280            g_l: 16.7,
281            tau_syn_ex: 0.2,
282            tau_syn_in: 2.0,
283            t_ref: 2.0,
284            e_l: -70.0,
285            e_ex: 0.0,
286            e_in: -85.0,
287            v_reset: -70.0,
288            v_th: -55.0,
289            i_e: 0.0,
290        }
291    }
292}
293
294/// Parameters for aeif_cond_alpha (AdEx)
295#[derive(Debug, Clone, Serialize, Deserialize)]
296pub struct AeifCondAlphaParams {
297    pub c_m: f64,
298    pub g_l: f64,
299    pub tau_syn_ex: f64,
300    pub tau_syn_in: f64,
301    pub t_ref: f64,
302    pub e_l: f64,
303    pub e_ex: f64,
304    pub e_in: f64,
305    pub v_reset: f64,
306    pub v_th: f64,
307    pub v_peak: f64,      // Spike cutoff (mV)
308    pub delta_t: f64,     // Slope factor (mV)
309    pub tau_w: f64,       // Adaptation time constant (ms)
310    pub a: f64,           // Subthreshold adaptation (nS)
311    pub b: f64,           // Spike-triggered adaptation (pA)
312    pub i_e: f64,
313}
314
315impl Default for AeifCondAlphaParams {
316    fn default() -> Self {
317        Self {
318            c_m: 281.0,
319            g_l: 30.0,
320            tau_syn_ex: 0.2,
321            tau_syn_in: 2.0,
322            t_ref: 0.0,
323            e_l: -70.6,
324            e_ex: 0.0,
325            e_in: -85.0,
326            v_reset: -60.0,
327            v_th: -50.4,
328            v_peak: 0.0,
329            delta_t: 2.0,
330            tau_w: 144.0,
331            a: 4.0,
332            b: 80.5,
333            i_e: 0.0,
334        }
335    }
336}
337
338/// Parameters for hh_psc_alpha (Hodgkin-Huxley)
339#[derive(Debug, Clone, Serialize, Deserialize)]
340pub struct HhPscAlphaParams {
341    pub c_m: f64,
342    pub g_na: f64,   // Sodium conductance (nS)
343    pub g_k: f64,    // Potassium conductance (nS)
344    pub g_l: f64,    // Leak conductance (nS)
345    pub e_na: f64,   // Sodium reversal (mV)
346    pub e_k: f64,    // Potassium reversal (mV)
347    pub e_l: f64,    // Leak reversal (mV)
348    pub tau_syn_ex: f64,
349    pub tau_syn_in: f64,
350    pub i_e: f64,
351}
352
353impl Default for HhPscAlphaParams {
354    fn default() -> Self {
355        Self {
356            c_m: 100.0,
357            g_na: 12000.0,
358            g_k: 3600.0,
359            g_l: 30.0,
360            e_na: 50.0,
361            e_k: -77.0,
362            e_l: -54.4,
363            tau_syn_ex: 0.2,
364            tau_syn_in: 2.0,
365            i_e: 0.0,
366        }
367    }
368}
369
370/// Parameters for Izhikevich neuron
371#[derive(Debug, Clone, Serialize, Deserialize)]
372pub struct IzhikevichParams {
373    pub a: f64,
374    pub b: f64,
375    pub c: f64,
376    pub d: f64,
377}
378
379impl Default for IzhikevichParams {
380    fn default() -> Self {
381        // Regular spiking
382        Self {
383            a: 0.02,
384            b: 0.2,
385            c: -65.0,
386            d: 8.0,
387        }
388    }
389}
390
391/// Poisson generator parameters
392#[derive(Debug, Clone, Serialize, Deserialize)]
393pub struct PoissonGeneratorParams {
394    pub rate: f64,   // Firing rate (Hz)
395}
396
397/// Spike generator parameters
398#[derive(Debug, Clone, Serialize, Deserialize)]
399pub struct SpikeGeneratorParams {
400    pub spike_times: Vec<f64>,  // Spike times (ms)
401    pub spike_weights: Vec<f64>,
402}
403
404/// DC generator parameters
405#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct DcGeneratorParams {
407    pub amplitude: f64,  // Current amplitude (pA)
408    pub start: f64,      // Start time (ms)
409    pub stop: f64,       // Stop time (ms)
410}
411
412/// Noise generator parameters
413#[derive(Debug, Clone, Serialize, Deserialize)]
414pub struct NoiseGeneratorParams {
415    pub mean: f64,   // Mean current (pA)
416    pub std: f64,    // Standard deviation (pA)
417    pub dt: f64,     // Update interval (ms)
418}
419
420/// Multimeter parameters
421#[derive(Debug, Clone, Serialize, Deserialize)]
422pub struct MultimeterParams {
423    pub record_from: Vec<String>,  // Variables to record
424    pub interval: f64,             // Recording interval (ms)
425}
426
427// ============================================================================
428// SYNAPSE MODELS
429// ============================================================================
430
431/// NEST synapse model types
432#[derive(Debug, Clone, Serialize, Deserialize)]
433pub enum SynapseModel {
434    /// Static synapse (fixed weight)
435    Static,
436
437    /// STDP synapse
438    StdpSynapse(StdpParams),
439
440    /// Tsodyks-Markram synapse (short-term plasticity)
441    TsodyksMarkramSynapse(TsodyksMarkramParams),
442
443    /// Bernoulli synapse (stochastic release)
444    BernoulliSynapse(BernoulliParams),
445
446    /// Vogels-Sprekeler inhibitory STDP
447    VogelsSprekelerSynapse(VogelsSprekelerParams),
448}
449
450/// STDP parameters
451#[derive(Debug, Clone, Serialize, Deserialize)]
452pub struct StdpParams {
453    pub tau_plus: f64,   // Time constant for potentiation (ms)
454    pub tau_minus: f64,  // Time constant for depression (ms)
455    pub lambda: f64,     // Step size for potentiation
456    pub alpha: f64,      // Asymmetry parameter
457    pub w_max: f64,      // Maximum weight
458    pub mu_plus: f64,    // Weight dependence exponent for LTP
459    pub mu_minus: f64,   // Weight dependence exponent for LTD
460}
461
462impl Default for StdpParams {
463    fn default() -> Self {
464        Self {
465            tau_plus: 20.0,
466            tau_minus: 20.0,
467            lambda: 0.01,
468            alpha: 1.0,
469            w_max: 100.0,
470            mu_plus: 1.0,
471            mu_minus: 1.0,
472        }
473    }
474}
475
476/// Tsodyks-Markram parameters
477#[derive(Debug, Clone, Serialize, Deserialize)]
478pub struct TsodyksMarkramParams {
479    pub u: f64,        // Initial release probability
480    pub tau_rec: f64,  // Recovery time constant (ms)
481    pub tau_fac: f64,  // Facilitation time constant (ms)
482}
483
484impl Default for TsodyksMarkramParams {
485    fn default() -> Self {
486        Self {
487            u: 0.5,
488            tau_rec: 800.0,
489            tau_fac: 0.0,
490        }
491    }
492}
493
494/// Bernoulli synapse parameters
495#[derive(Debug, Clone, Serialize, Deserialize)]
496pub struct BernoulliParams {
497    pub p_transmit: f64,  // Transmission probability
498}
499
500/// Vogels-Sprekeler parameters
501#[derive(Debug, Clone, Serialize, Deserialize)]
502pub struct VogelsSprekelerParams {
503    pub tau: f64,      // Time constant (ms)
504    pub eta: f64,      // Learning rate
505    pub alpha: f64,    // Target rate parameter
506    pub w_max: f64,    // Maximum weight
507}
508
509// ============================================================================
510// CONNECTION SPECIFICATION
511// ============================================================================
512
513/// Connection rule
514#[derive(Debug, Clone, Serialize, Deserialize)]
515pub enum ConnectivityRule {
516    /// All-to-all connection
517    AllToAll,
518
519    /// One-to-one mapping (same indices)
520    OneToOne,
521
522    /// Random connections with fixed indegree
523    FixedIndegree { indegree: usize },
524
525    /// Random connections with fixed outdegree
526    FixedOutdegree { outdegree: usize },
527
528    /// Random connections with fixed total number
529    FixedTotalNumber { n: usize },
530
531    /// Bernoulli (fixed probability)
532    PairwiseBernoulli { p: f64 },
533
534    /// Symmetric (both directions)
535    SymmetricPairwiseBernoulli { p: f64 },
536}
537
538/// Weight distribution
539#[derive(Debug, Clone, Serialize, Deserialize)]
540pub enum WeightDistribution {
541    Constant(f64),
542    Uniform { min: f64, max: f64 },
543    Normal { mean: f64, std: f64 },
544    Lognormal { mu: f64, sigma: f64 },
545}
546
547/// Delay distribution
548#[derive(Debug, Clone, Serialize, Deserialize)]
549pub enum DelayDistribution {
550    Constant(f64),
551    Uniform { min: f64, max: f64 },
552    Normal { mean: f64, std: f64 },
553}
554
555/// Connection specification
556#[derive(Debug, Clone, Serialize, Deserialize)]
557pub struct ConnectionSpec {
558    pub rule: ConnectivityRule,
559    pub weight: WeightDistribution,
560    pub delay: DelayDistribution,
561    pub synapse_model: SynapseModel,
562    pub allow_autapses: bool,
563    pub allow_multapses: bool,
564}
565
566impl Default for ConnectionSpec {
567    fn default() -> Self {
568        Self {
569            rule: ConnectivityRule::AllToAll,
570            weight: WeightDistribution::Constant(1.0),
571            delay: DelayDistribution::Constant(1.0),
572            synapse_model: SynapseModel::Static,
573            allow_autapses: false,
574            allow_multapses: true,
575        }
576    }
577}
578
579// ============================================================================
580// NODE STATE
581// ============================================================================
582
583/// Node state variables
584#[derive(Debug, Clone, Serialize, Deserialize)]
585pub struct NodeState {
586    pub id: NodeId,
587    pub model: String,
588    pub v_m: f64,           // Membrane potential
589    pub last_spike: f64,    // Time of last spike
590    pub refractory_until: f64,
591    /// Additional state variables
592    pub state: HashMap<String, f64>,
593}
594
595/// Connection (edge)
596#[derive(Debug, Clone, Serialize, Deserialize)]
597pub struct Connection {
598    pub source: NodeId,
599    pub target: NodeId,
600    pub weight: f64,
601    pub delay: f64,
602    pub synapse_model: SynapseModel,
603    /// Synapse state (for plastic synapses)
604    pub state: HashMap<String, f64>,
605}
606
607// ============================================================================
608// RECORDING
609// ============================================================================
610
611/// Recorded spike events
612#[derive(Debug, Clone, Serialize, Deserialize)]
613pub struct SpikeData {
614    pub times: Vec<f64>,
615    pub senders: Vec<NodeId>,
616}
617
618impl SpikeData {
619    pub fn new() -> Self {
620        Self {
621            times: vec![],
622            senders: vec![],
623        }
624    }
625
626    pub fn record(&mut self, time: f64, sender: NodeId) {
627        self.times.push(time);
628        self.senders.push(sender);
629    }
630
631    pub fn n_events(&self) -> usize {
632        self.times.len()
633    }
634
635    /// Get spike trains organized by sender
636    pub fn spike_trains(&self) -> HashMap<NodeId, Vec<f64>> {
637        let mut trains: HashMap<NodeId, Vec<f64>> = HashMap::new();
638        for (&time, &sender) in self.times.iter().zip(self.senders.iter()) {
639            trains.entry(sender).or_default().push(time);
640        }
641        trains
642    }
643}
644
645impl Default for SpikeData {
646    fn default() -> Self {
647        Self::new()
648    }
649}
650
651/// Recorded continuous data
652#[derive(Debug, Clone, Serialize, Deserialize)]
653pub struct ContinuousData {
654    pub times: Vec<f64>,
655    pub senders: Vec<NodeId>,
656    pub data: HashMap<String, Vec<f64>>,
657}
658
659// ============================================================================
660// KERNEL (SIMULATION STATE)
661// ============================================================================
662
663/// Simulation parameters
664#[derive(Debug, Clone, Serialize, Deserialize)]
665pub struct KernelParams {
666    pub resolution: f64,     // Time step (ms)
667    pub min_delay: f64,      // Minimum synaptic delay (ms)
668    pub max_delay: f64,      // Maximum synaptic delay (ms)
669    pub rng_seed: u64,       // Random number generator seed
670    pub num_threads: usize,  // Number of threads
671    pub print_time: bool,    // Print simulation time
672}
673
674impl Default for KernelParams {
675    fn default() -> Self {
676        Self {
677            resolution: 0.1,
678            min_delay: 0.1,
679            max_delay: 100.0,
680            rng_seed: 12345,
681            num_threads: 1,
682            print_time: false,
683        }
684    }
685}
686
687/// NEST kernel (simulation state)
688#[derive(Debug, Clone, Serialize, Deserialize)]
689pub struct Kernel {
690    pub params: KernelParams,
691    pub time: f64,
692    next_node_id: NodeId,
693    pub nodes: HashMap<NodeId, NodeState>,
694    pub connections: Vec<Connection>,
695    pub spike_data: HashMap<NodeId, SpikeData>,  // Keyed by detector ID
696}
697
698impl Kernel {
699    pub fn new(params: KernelParams) -> Self {
700        Self {
701            params,
702            time: 0.0,
703            next_node_id: 1,  // NEST node IDs start at 1
704            nodes: HashMap::new(),
705            connections: vec![],
706            spike_data: HashMap::new(),
707        }
708    }
709
710    /// Reset the kernel
711    pub fn reset(&mut self) {
712        self.time = 0.0;
713        self.nodes.clear();
714        self.connections.clear();
715        self.spike_data.clear();
716        self.next_node_id = 1;
717    }
718
719    /// Set kernel parameters
720    pub fn set_params(&mut self, params: KernelParams) {
721        self.params = params;
722    }
723
724    /// Get current simulation time
725    pub fn get_time(&self) -> f64 {
726        self.time
727    }
728}
729
730// ============================================================================
731// NEST API FUNCTIONS
732// ============================================================================
733
734/// Global kernel (NEST uses a singleton pattern)
735static mut KERNEL: Option<Kernel> = None;
736
737/// Initialize the kernel
738pub fn reset_kernel(params: Option<KernelParams>) {
739    unsafe {
740        KERNEL = Some(Kernel::new(params.unwrap_or_default()));
741    }
742}
743
744/// Get kernel reference
745fn get_kernel() -> &'static mut Kernel {
746    unsafe {
747        if KERNEL.is_none() {
748            KERNEL = Some(Kernel::new(KernelParams::default()));
749        }
750        KERNEL.as_mut().unwrap()
751    }
752}
753
754/// Set kernel status
755pub fn set_kernel_status(params: KernelParams) {
756    get_kernel().set_params(params);
757}
758
759/// Get kernel status
760pub fn get_kernel_status() -> KernelParams {
761    get_kernel().params.clone()
762}
763
764/// Create neurons
765pub fn create(model: NeuronModel, n: usize) -> Result<NodeCollection> {
766    let kernel = get_kernel();
767    let mut ids = Vec::with_capacity(n);
768
769    let model_name = model_to_string(&model);
770
771    for _ in 0..n {
772        let id = kernel.next_node_id;
773        kernel.next_node_id += 1;
774
775        let mut state = HashMap::new();
776
777        // Initialize state based on model
778        match &model {
779            NeuronModel::IafPscAlpha(p) => {
780                state.insert("V_m".into(), p.e_l);
781            }
782            NeuronModel::IafPscExp(p) => {
783                state.insert("V_m".into(), p.e_l);
784            }
785            NeuronModel::IafCondAlpha(p) => {
786                state.insert("V_m".into(), p.e_l);
787            }
788            NeuronModel::AeifCondAlpha(p) => {
789                state.insert("V_m".into(), p.e_l);
790                state.insert("w".into(), 0.0);
791            }
792            NeuronModel::HhPscAlpha(p) => {
793                state.insert("V_m".into(), p.e_l);
794                state.insert("n".into(), 0.3);
795                state.insert("m".into(), 0.05);
796                state.insert("h".into(), 0.6);
797            }
798            NeuronModel::Izhikevich(p) => {
799                state.insert("V_m".into(), p.c);
800                state.insert("U_m".into(), p.b * p.c);
801            }
802            NeuronModel::SpikeDetector => {
803                kernel.spike_data.insert(id, SpikeData::new());
804            }
805            _ => {}
806        }
807
808        kernel.nodes.insert(id, NodeState {
809            id,
810            model: model_name.clone(),
811            v_m: state.get("V_m").copied().unwrap_or(-70.0),
812            last_spike: f64::NEG_INFINITY,
813            refractory_until: f64::NEG_INFINITY,
814            state,
815        });
816
817        ids.push(id);
818    }
819
820    Ok(NodeCollection::new(ids))
821}
822
823fn model_to_string(model: &NeuronModel) -> String {
824    match model {
825        NeuronModel::IafPscAlpha(_) => "iaf_psc_alpha".into(),
826        NeuronModel::IafPscExp(_) => "iaf_psc_exp".into(),
827        NeuronModel::IafPscDelta(_) => "iaf_psc_delta".into(),
828        NeuronModel::IafCondAlpha(_) => "iaf_cond_alpha".into(),
829        NeuronModel::IafCondExp(_) => "iaf_cond_exp".into(),
830        NeuronModel::AeifCondAlpha(_) => "aeif_cond_alpha".into(),
831        NeuronModel::HhPscAlpha(_) => "hh_psc_alpha".into(),
832        NeuronModel::Izhikevich(_) => "izhikevich".into(),
833        NeuronModel::ParrotNeuron => "parrot_neuron".into(),
834        NeuronModel::PoissonGenerator(_) => "poisson_generator".into(),
835        NeuronModel::SpikeGenerator(_) => "spike_generator".into(),
836        NeuronModel::DcGenerator(_) => "dc_generator".into(),
837        NeuronModel::NoiseGenerator(_) => "noise_generator".into(),
838        NeuronModel::SpikeDetector => "spike_detector".into(),
839        NeuronModel::Multimeter(_) => "multimeter".into(),
840    }
841}
842
843/// Connect neurons
844pub fn connect(
845    sources: &NodeCollection,
846    targets: &NodeCollection,
847    spec: ConnectionSpec,
848) -> Result<()> {
849    let kernel = get_kernel();
850
851    match spec.rule {
852        ConnectivityRule::AllToAll => {
853            for &src in &sources.ids {
854                for &tgt in &targets.ids {
855                    if !spec.allow_autapses && src == tgt {
856                        continue;
857                    }
858
859                    let weight = sample_weight(&spec.weight);
860                    let delay = sample_delay(&spec.delay);
861
862                    kernel.connections.push(Connection {
863                        source: src,
864                        target: tgt,
865                        weight,
866                        delay,
867                        synapse_model: spec.synapse_model.clone(),
868                        state: HashMap::new(),
869                    });
870                }
871            }
872        }
873
874        ConnectivityRule::OneToOne => {
875            if sources.len() != targets.len() {
876                return Err(NestError::ConnectionError(
877                    "OneToOne requires equal population sizes".into()
878                ));
879            }
880
881            for (&src, &tgt) in sources.ids.iter().zip(targets.ids.iter()) {
882                let weight = sample_weight(&spec.weight);
883                let delay = sample_delay(&spec.delay);
884
885                kernel.connections.push(Connection {
886                    source: src,
887                    target: tgt,
888                    weight,
889                    delay,
890                    synapse_model: spec.synapse_model.clone(),
891                    state: HashMap::new(),
892                });
893            }
894        }
895
896        ConnectivityRule::PairwiseBernoulli { p } => {
897            use std::collections::hash_map::DefaultHasher;
898            use std::hash::{Hash, Hasher};
899
900            for &src in &sources.ids {
901                for &tgt in &targets.ids {
902                    if !spec.allow_autapses && src == tgt {
903                        continue;
904                    }
905
906                    let mut hasher = DefaultHasher::new();
907                    (src, tgt, kernel.time as u64).hash(&mut hasher);
908                    let hash = hasher.finish();
909                    let r = (hash as f64) / (u64::MAX as f64);
910
911                    if r < p {
912                        let weight = sample_weight(&spec.weight);
913                        let delay = sample_delay(&spec.delay);
914
915                        kernel.connections.push(Connection {
916                            source: src,
917                            target: tgt,
918                            weight,
919                            delay,
920                            synapse_model: spec.synapse_model.clone(),
921                            state: HashMap::new(),
922                        });
923                    }
924                }
925            }
926        }
927
928        _ => {
929            // Other rules would require more complex implementation
930        }
931    }
932
933    Ok(())
934}
935
936fn sample_weight(dist: &WeightDistribution) -> f64 {
937    match dist {
938        WeightDistribution::Constant(w) => *w,
939        WeightDistribution::Uniform { min, max } => {
940            // Simple pseudo-random for now
941            (min + max) / 2.0
942        }
943        WeightDistribution::Normal { mean, std: _ } => *mean,
944        WeightDistribution::Lognormal { mu, sigma: _ } => mu.exp(),
945    }
946}
947
948fn sample_delay(dist: &DelayDistribution) -> f64 {
949    match dist {
950        DelayDistribution::Constant(d) => *d,
951        DelayDistribution::Uniform { min, max } => (min + max) / 2.0,
952        DelayDistribution::Normal { mean, std: _ } => *mean,
953    }
954}
955
956/// Run simulation
957pub fn simulate(time: f64) -> Result<()> {
958    let kernel = get_kernel();
959    let dt = kernel.params.resolution;
960    let n_steps = (time / dt).ceil() as usize;
961
962    for _ in 0..n_steps {
963        kernel.time += dt;
964        // Integration would happen here
965    }
966
967    Ok(())
968}
969
970/// Get spike data from spike detector
971pub fn get_spike_data(detector: NodeId) -> Option<SpikeData> {
972    let kernel = get_kernel();
973    kernel.spike_data.get(&detector).cloned()
974}
975
976/// Get node status (parameters)
977pub fn get_status(nodes: &NodeCollection) -> Vec<HashMap<String, f64>> {
978    let kernel = get_kernel();
979    let mut results = vec![];
980
981    for &id in &nodes.ids {
982        if let Some(node) = kernel.nodes.get(&id) {
983            let mut status = node.state.clone();
984            status.insert("V_m".into(), node.v_m);
985            status.insert("t_spike".into(), node.last_spike);
986            results.push(status);
987        }
988    }
989
990    results
991}
992
993/// Set node status
994pub fn set_status(nodes: &NodeCollection, params: HashMap<String, f64>) -> Result<()> {
995    let kernel = get_kernel();
996
997    for &id in &nodes.ids {
998        if let Some(node) = kernel.nodes.get_mut(&id) {
999            for (key, value) in &params {
1000                if key == "V_m" {
1001                    node.v_m = *value;
1002                } else {
1003                    node.state.insert(key.clone(), *value);
1004                }
1005            }
1006        }
1007    }
1008
1009    Ok(())
1010}
1011
1012// ============================================================================
1013// HELPER FUNCTIONS FOR NETWORK CONSTRUCTION
1014// ============================================================================
1015
1016/// Create a balanced random network (Brunel 2000)
1017pub fn balanced_network(
1018    n_exc: usize,
1019    n_inh: usize,
1020    p_conn: f64,
1021    g: f64,         // Inhibitory strength factor
1022    j_exc: f64,     // Excitatory weight (mV)
1023) -> Result<(NodeCollection, NodeCollection)> {
1024    reset_kernel(None);
1025
1026    // Create excitatory neurons
1027    let exc = create(
1028        NeuronModel::IafPscAlpha(IafPscAlphaParams::default()),
1029        n_exc
1030    )?;
1031
1032    // Create inhibitory neurons
1033    let inh = create(
1034        NeuronModel::IafPscAlpha(IafPscAlphaParams::default()),
1035        n_inh
1036    )?;
1037
1038    let j_inh = -g * j_exc;
1039
1040    // E -> E
1041    connect(&exc, &exc, ConnectionSpec {
1042        rule: ConnectivityRule::PairwiseBernoulli { p: p_conn },
1043        weight: WeightDistribution::Constant(j_exc),
1044        delay: DelayDistribution::Constant(1.5),
1045        ..Default::default()
1046    })?;
1047
1048    // E -> I
1049    connect(&exc, &inh, ConnectionSpec {
1050        rule: ConnectivityRule::PairwiseBernoulli { p: p_conn },
1051        weight: WeightDistribution::Constant(j_exc),
1052        delay: DelayDistribution::Constant(1.5),
1053        ..Default::default()
1054    })?;
1055
1056    // I -> E
1057    connect(&inh, &exc, ConnectionSpec {
1058        rule: ConnectivityRule::PairwiseBernoulli { p: p_conn },
1059        weight: WeightDistribution::Constant(j_inh),
1060        delay: DelayDistribution::Constant(1.5),
1061        ..Default::default()
1062    })?;
1063
1064    // I -> I
1065    connect(&inh, &inh, ConnectionSpec {
1066        rule: ConnectivityRule::PairwiseBernoulli { p: p_conn },
1067        weight: WeightDistribution::Constant(j_inh),
1068        delay: DelayDistribution::Constant(1.5),
1069        ..Default::default()
1070    })?;
1071
1072    Ok((exc, inh))
1073}
1074
1075/// Calculate mean firing rate from spike data
1076pub fn mean_firing_rate(data: &SpikeData, n_neurons: usize, duration: f64) -> f64 {
1077    if n_neurons == 0 || duration <= 0.0 {
1078        return 0.0;
1079    }
1080    (data.n_events() as f64) / (n_neurons as f64) / (duration / 1000.0)
1081}
1082
/// Calculate the coefficient of variation (CV) of the inter-spike intervals.
///
/// `spike_train` is expected to be sorted in ascending time order. The CV is
/// the population standard deviation of the ISIs (dividing by the ISI count,
/// not count − 1) divided by their mean ISI. A perfectly regular train
/// therefore gives 0.
///
/// Returns 0.0 whenever the CV is undefined:
/// - fewer than two spikes (no ISI), or
/// - a non-positive mean ISI. This happens when all spikes coincide, where the
///   division would otherwise yield `NaN`, or with an unsorted train, where the
///   result would be a meaningless negative value.
pub fn cv_isi(spike_train: &[f64]) -> f64 {
    if spike_train.len() < 2 {
        return 0.0;
    }

    let isis: Vec<f64> = spike_train.windows(2).map(|w| w[1] - w[0]).collect();
    let count = isis.len() as f64;

    let mean = isis.iter().sum::<f64>() / count;
    if mean <= 0.0 {
        // CV = sd / mean is undefined here; follow the `len < 2` convention.
        return 0.0;
    }

    let variance = isis.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / count;
    variance.sqrt() / mean
}
1100
1101/// Calculate correlation coefficient between spike trains
1102pub fn spike_correlation(
1103    train1: &[f64],
1104    train2: &[f64],
1105    bin_size: f64,
1106    max_time: f64,
1107) -> Array1<f64> {
1108    let n_bins = (max_time / bin_size).ceil() as usize;
1109    let mut hist1: Array1<f64> = Array1::zeros(n_bins);
1110    let mut hist2: Array1<f64> = Array1::zeros(n_bins);
1111
1112    for &t in train1 {
1113        let bin = (t / bin_size).floor() as usize;
1114        if bin < n_bins {
1115            hist1[bin] += 1.0;
1116        }
1117    }
1118
1119    for &t in train2 {
1120        let bin = (t / bin_size).floor() as usize;
1121        if bin < n_bins {
1122            hist2[bin] += 1.0;
1123        }
1124    }
1125
1126    // Cross-correlation (simplified)
1127    hist1 * hist2
1128}
1129
1130// ============================================================================
1131// TESTS
1132// ============================================================================
1133
#[cfg(test)]
mod tests {
    // Only tests that leave the process-global kernel untouched live here.
    //
    // Kernel-backed tests, such as a `balanced_network` creation test, are
    // deliberately absent. The default test harness runs tests in parallel, so
    // they would share that global state. They can be restored by running them
    // serially (e.g. with the `serial_test` crate) or by restructuring the code
    // to avoid global state.
    use super::*;

    #[test]
    fn test_node_collection() {
        let collection = NodeCollection::new((1..=5).collect());
        assert_eq!(collection.len(), 5);
        assert_eq!(collection.first(), Some(1));
        assert_eq!(collection.last(), Some(5));

        let middle = collection.slice(1, 3);
        assert_eq!(middle.ids, vec![2, 3]);
    }

    #[test]
    fn test_iaf_params() {
        let defaults = IafPscAlphaParams::default();
        assert_eq!(defaults.e_l, -70.0);
        assert_eq!(defaults.tau_m, 10.0);
    }

    #[test]
    fn test_connection_spec() {
        let defaults = ConnectionSpec::default();
        assert!(defaults.allow_multapses, "multapses should be allowed by default");
        assert!(!defaults.allow_autapses, "autapses should be disallowed by default");
    }

    #[test]
    fn test_spike_data() {
        let mut recorder = SpikeData::new();
        for &(time, sender) in [(10.0, 1), (15.0, 2), (20.0, 1)].iter() {
            recorder.record(time, sender);
        }

        assert_eq!(recorder.n_events(), 3);

        let per_sender = recorder.spike_trains();
        assert_eq!(per_sender[&1].len(), 2);
        assert_eq!(per_sender[&2].len(), 1);
    }

    #[test]
    fn test_cv_isi() {
        // Evenly spaced spikes: CV should be ~0.
        let regular: Vec<f64> = (0..10).map(|k| 10.0 * k as f64).collect();
        assert!(cv_isi(&regular) < 0.01, "regular spiking should have CV ~ 0");

        // Unevenly spaced spikes: CV should be large.
        let irregular = [0.0, 5.0, 20.0, 22.0, 50.0];
        assert!(cv_isi(&irregular) > 0.5, "irregular spiking should have a large CV");
    }

    #[test]
    fn test_izhikevich_variants() {
        let regular_spiking = IzhikevichParams::default();
        assert_eq!((regular_spiking.a, regular_spiking.b), (0.02, 0.2));
    }

    #[test]
    fn test_adex_params() {
        let adex = AeifCondAlphaParams::default();
        assert!(adex.tau_w > 0.0, "tau_w must be positive");
        assert!(adex.delta_t > 0.0, "delta_t must be positive");
    }
}