1use ndarray::Array1;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use thiserror::Error;
21
/// Errors returned by the simulator API in this module.
#[derive(Error, Debug)]
pub enum NestError {
    /// A model name was not recognized.
    #[error("Unknown node model: {0}")]
    UnknownModel(String),
    /// A node id did not refer to an existing node.
    #[error("Node not found: {0}")]
    NodeNotFound(usize),
    /// A parameter value was rejected; the message says which one and why.
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),
    /// A connection request could not be carried out (e.g. `OneToOne` with
    /// populations of different sizes).
    #[error("Connection error: {0}")]
    ConnectionError(String),
    /// The simulation could not run.
    #[error("Simulation error: {0}")]
    SimulationError(String),
}
35
/// Result type used throughout this module, with [`NestError`] as the error type.
pub type Result<T> = std::result::Result<T, NestError>;

/// Identifier of a node in the kernel; the kernel hands out ids starting at 1.
pub type NodeId = usize;
44
/// An ordered list of node ids, as returned by `create` and consumed by
/// `connect`, `get_status` and `set_status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCollection {
    /// Node ids, in creation (or caller-supplied) order.
    pub ids: Vec<NodeId>,
}
50
51impl NodeCollection {
52 pub fn new(ids: Vec<NodeId>) -> Self {
53 Self { ids }
54 }
55
56 pub fn len(&self) -> usize {
57 self.ids.len()
58 }
59
60 pub fn is_empty(&self) -> bool {
61 self.ids.is_empty()
62 }
63
64 pub fn first(&self) -> Option<NodeId> {
65 self.ids.first().copied()
66 }
67
68 pub fn last(&self) -> Option<NodeId> {
69 self.ids.last().copied()
70 }
71
72 pub fn slice(&self, start: usize, end: usize) -> Self {
74 Self::new(self.ids[start..end].to_vec())
75 }
76}
77
78impl IntoIterator for NodeCollection {
79 type Item = NodeId;
80 type IntoIter = std::vec::IntoIter<NodeId>;
81
82 fn into_iter(self) -> Self::IntoIter {
83 self.ids.into_iter()
84 }
85}
86
/// Node models that `create` can instantiate.
///
/// Neuron variants carry their parameter struct. Device variants (generators and
/// recorders) carry their configuration. `model_to_string` maps each variant to the
/// snake_case name stored in `NodeState::model`. The one-line descriptions below
/// follow the NEST models these variants are named after.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NeuronModel {
    /// Leaky integrate-and-fire, alpha-shaped current synapses (`iaf_psc_alpha`).
    IafPscAlpha(IafPscAlphaParams),

    /// Leaky integrate-and-fire, exponential current synapses (`iaf_psc_exp`).
    IafPscExp(IafPscExpParams),

    /// Leaky integrate-and-fire, delta current synapses (`iaf_psc_delta`).
    IafPscDelta(IafPscDeltaParams),

    /// Leaky integrate-and-fire, alpha-shaped conductance synapses (`iaf_cond_alpha`).
    IafCondAlpha(IafCondAlphaParams),

    /// Leaky integrate-and-fire, exponential conductance synapses (`iaf_cond_exp`).
    IafCondExp(IafCondExpParams),

    /// Adaptive exponential integrate-and-fire, alpha conductances (`aeif_cond_alpha`).
    AeifCondAlpha(AeifCondAlphaParams),

    /// Hodgkin–Huxley neuron, alpha-shaped current synapses (`hh_psc_alpha`).
    HhPscAlpha(HhPscAlphaParams),

    /// Izhikevich two-variable neuron (`izhikevich`).
    Izhikevich(IzhikevichParams),

    /// Relay node that presumably re-emits each incoming spike (`parrot_neuron`).
    ParrotNeuron,

    /// Poisson spike source (`poisson_generator`).
    PoissonGenerator(PoissonGeneratorParams),

    /// Spike source emitting at explicit times (`spike_generator`).
    SpikeGenerator(SpikeGeneratorParams),

    /// Constant current source (`dc_generator`).
    DcGenerator(DcGeneratorParams),

    /// Noisy current source (`noise_generator`).
    NoiseGenerator(NoiseGeneratorParams),

    /// Spike recorder; `create` allocates an empty `SpikeData` buffer for each one.
    SpikeDetector,

    /// Recorder for continuous state variables (`multimeter`).
    Multimeter(MultimeterParams),
}
139
/// Parameters of [`NeuronModel::IafPscAlpha`].
///
/// Field meanings follow the NEST parameters of the same name. NOTE(review): units
/// are not enforced here; the defaults look like NEST's pF / ms / mV / pA — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IafPscAlphaParams {
    /// Membrane capacitance.
    pub c_m: f64,
    /// Membrane time constant.
    pub tau_m: f64,
    /// Excitatory synaptic time constant.
    pub tau_syn_ex: f64,
    /// Inhibitory synaptic time constant.
    pub tau_syn_in: f64,
    /// Refractory period.
    pub t_ref: f64,
    /// Resting potential; `create` also uses it as the initial `V_m`.
    pub e_l: f64,
    /// Reset potential after a spike.
    pub v_reset: f64,
    /// Spike threshold.
    pub v_th: f64,
    /// Constant external input current.
    pub i_e: f64,
}

impl Default for IafPscAlphaParams {
    fn default() -> Self {
        Self {
            c_m: 250.0,
            tau_m: 10.0,
            tau_syn_ex: 2.0,
            tau_syn_in: 2.0,
            t_ref: 2.0,
            e_l: -70.0,
            v_reset: -70.0,
            v_th: -55.0,
            i_e: 0.0,
        }
    }
}
169
/// Parameters of [`NeuronModel::IafPscExp`].
///
/// Same fields and defaults as [`IafPscAlphaParams`]; meanings follow the NEST
/// parameters of the same name (units not enforced — NOTE(review): confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IafPscExpParams {
    /// Membrane capacitance.
    pub c_m: f64,
    /// Membrane time constant.
    pub tau_m: f64,
    /// Excitatory synaptic time constant.
    pub tau_syn_ex: f64,
    /// Inhibitory synaptic time constant.
    pub tau_syn_in: f64,
    /// Refractory period.
    pub t_ref: f64,
    /// Resting potential; `create` also uses it as the initial `V_m`.
    pub e_l: f64,
    /// Reset potential after a spike.
    pub v_reset: f64,
    /// Spike threshold.
    pub v_th: f64,
    /// Constant external input current.
    pub i_e: f64,
}

impl Default for IafPscExpParams {
    fn default() -> Self {
        Self {
            c_m: 250.0,
            tau_m: 10.0,
            tau_syn_ex: 2.0,
            tau_syn_in: 2.0,
            t_ref: 2.0,
            e_l: -70.0,
            v_reset: -70.0,
            v_th: -55.0,
            i_e: 0.0,
        }
    }
}
199
/// Parameters of [`NeuronModel::IafPscDelta`].
///
/// Like [`IafPscAlphaParams`] but without synaptic time constants. Meanings follow the
/// NEST parameters of the same name (units not enforced — NOTE(review): confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IafPscDeltaParams {
    /// Membrane capacitance.
    pub c_m: f64,
    /// Membrane time constant.
    pub tau_m: f64,
    /// Refractory period.
    pub t_ref: f64,
    /// Resting potential.
    pub e_l: f64,
    /// Reset potential after a spike.
    pub v_reset: f64,
    /// Spike threshold.
    pub v_th: f64,
    /// Constant external input current.
    pub i_e: f64,
}

impl Default for IafPscDeltaParams {
    fn default() -> Self {
        Self {
            c_m: 250.0,
            tau_m: 10.0,
            t_ref: 2.0,
            e_l: -70.0,
            v_reset: -70.0,
            v_th: -55.0,
            i_e: 0.0,
        }
    }
}
225
/// Parameters of [`NeuronModel::IafCondAlpha`].
///
/// Conductance-based: a leak conductance plus excitatory/inhibitory reversal
/// potentials. Meanings follow the NEST parameters of the same name (units not
/// enforced — NOTE(review): confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IafCondAlphaParams {
    /// Membrane capacitance.
    pub c_m: f64,
    /// Leak conductance.
    pub g_l: f64,
    /// Excitatory synaptic time constant.
    pub tau_syn_ex: f64,
    /// Inhibitory synaptic time constant.
    pub tau_syn_in: f64,
    /// Refractory period.
    pub t_ref: f64,
    /// Leak reversal potential; `create` also uses it as the initial `V_m`.
    pub e_l: f64,
    /// Excitatory reversal potential.
    pub e_ex: f64,
    /// Inhibitory reversal potential.
    pub e_in: f64,
    /// Reset potential after a spike.
    pub v_reset: f64,
    /// Spike threshold.
    pub v_th: f64,
    /// Constant external input current.
    pub i_e: f64,
}

impl Default for IafCondAlphaParams {
    fn default() -> Self {
        Self {
            c_m: 250.0,
            g_l: 16.7,
            tau_syn_ex: 0.2,
            tau_syn_in: 2.0,
            t_ref: 2.0,
            e_l: -70.0,
            e_ex: 0.0,
            e_in: -85.0,
            v_reset: -70.0,
            v_th: -55.0,
            i_e: 0.0,
        }
    }
}
259
/// Parameters of [`NeuronModel::IafCondExp`].
///
/// Same fields and defaults as [`IafCondAlphaParams`]; meanings follow the NEST
/// parameters of the same name (units not enforced — NOTE(review): confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IafCondExpParams {
    /// Membrane capacitance.
    pub c_m: f64,
    /// Leak conductance.
    pub g_l: f64,
    /// Excitatory synaptic time constant.
    pub tau_syn_ex: f64,
    /// Inhibitory synaptic time constant.
    pub tau_syn_in: f64,
    /// Refractory period.
    pub t_ref: f64,
    /// Leak reversal potential.
    pub e_l: f64,
    /// Excitatory reversal potential.
    pub e_ex: f64,
    /// Inhibitory reversal potential.
    pub e_in: f64,
    /// Reset potential after a spike.
    pub v_reset: f64,
    /// Spike threshold.
    pub v_th: f64,
    /// Constant external input current.
    pub i_e: f64,
}

impl Default for IafCondExpParams {
    fn default() -> Self {
        Self {
            c_m: 250.0,
            g_l: 16.7,
            tau_syn_ex: 0.2,
            tau_syn_in: 2.0,
            t_ref: 2.0,
            e_l: -70.0,
            e_ex: 0.0,
            e_in: -85.0,
            v_reset: -70.0,
            v_th: -55.0,
            i_e: 0.0,
        }
    }
}
293
/// Parameters of [`NeuronModel::AeifCondAlpha`] (adaptive exponential
/// integrate-and-fire).
///
/// Meanings follow the NEST parameters of the same name (units not enforced —
/// NOTE(review): confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AeifCondAlphaParams {
    /// Membrane capacitance.
    pub c_m: f64,
    /// Leak conductance.
    pub g_l: f64,
    /// Excitatory synaptic time constant.
    pub tau_syn_ex: f64,
    /// Inhibitory synaptic time constant.
    pub tau_syn_in: f64,
    /// Refractory period.
    pub t_ref: f64,
    /// Leak reversal potential; `create` also uses it as the initial `V_m`.
    pub e_l: f64,
    /// Excitatory reversal potential.
    pub e_ex: f64,
    /// Inhibitory reversal potential.
    pub e_in: f64,
    /// Reset potential after a spike.
    pub v_reset: f64,
    /// Threshold of the exponential term.
    pub v_th: f64,
    /// Potential at which a spike is registered.
    pub v_peak: f64,
    /// Slope factor of the exponential term.
    pub delta_t: f64,
    /// Adaptation time constant (the `w` state variable `create` initializes to 0).
    pub tau_w: f64,
    /// Subthreshold adaptation coupling.
    pub a: f64,
    /// Spike-triggered adaptation increment.
    pub b: f64,
    /// Constant external input current.
    pub i_e: f64,
}

impl Default for AeifCondAlphaParams {
    fn default() -> Self {
        Self {
            c_m: 281.0,
            g_l: 30.0,
            tau_syn_ex: 0.2,
            tau_syn_in: 2.0,
            t_ref: 0.0,
            e_l: -70.6,
            e_ex: 0.0,
            e_in: -85.0,
            v_reset: -60.0,
            v_th: -50.4,
            v_peak: 0.0,
            delta_t: 2.0,
            tau_w: 144.0,
            a: 4.0,
            b: 80.5,
            i_e: 0.0,
        }
    }
}
337
/// Parameters of [`NeuronModel::HhPscAlpha`] (Hodgkin–Huxley).
///
/// Meanings follow the NEST parameters of the same name (units not enforced —
/// NOTE(review): confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HhPscAlphaParams {
    /// Membrane capacitance.
    pub c_m: f64,
    /// Peak sodium conductance.
    pub g_na: f64,
    /// Peak potassium conductance.
    pub g_k: f64,
    /// Leak conductance.
    pub g_l: f64,
    /// Sodium reversal potential.
    pub e_na: f64,
    /// Potassium reversal potential.
    pub e_k: f64,
    /// Leak reversal potential; `create` also uses it as the initial `V_m`.
    pub e_l: f64,
    /// Excitatory synaptic time constant.
    pub tau_syn_ex: f64,
    /// Inhibitory synaptic time constant.
    pub tau_syn_in: f64,
    /// Constant external input current.
    pub i_e: f64,
}

impl Default for HhPscAlphaParams {
    fn default() -> Self {
        Self {
            c_m: 100.0,
            g_na: 12000.0,
            g_k: 3600.0,
            g_l: 30.0,
            e_na: 50.0,
            e_k: -77.0,
            e_l: -54.4,
            tau_syn_ex: 0.2,
            tau_syn_in: 2.0,
            i_e: 0.0,
        }
    }
}
369
/// Parameters of [`NeuronModel::Izhikevich`].
///
/// `create` initializes `V_m` to `c` and `U_m` to `b * c`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IzhikevichParams {
    /// Time scale of the recovery variable `U_m`.
    pub a: f64,
    /// Sensitivity of the recovery variable to `V_m`.
    pub b: f64,
    /// Reset value of `V_m` after a spike (also its initial value).
    pub c: f64,
    /// Reset increment of `U_m` after a spike.
    pub d: f64,
}

impl Default for IzhikevichParams {
    fn default() -> Self {
        Self {
            // a=0.02, b=0.2, c=-65, d=8: presumably Izhikevich's "regular spiking"
            // set (the unit test names it `rs`) — NOTE(review): confirm.
            a: 0.02,
            b: 0.2,
            c: -65.0,
            d: 8.0,
        }
    }
}
390
/// Configuration of [`NeuronModel::PoissonGenerator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoissonGeneratorParams {
    /// Mean emission rate (presumably spikes/s as in NEST — confirm).
    pub rate: f64,
}

/// Configuration of [`NeuronModel::SpikeGenerator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpikeGeneratorParams {
    /// Times at which spikes are emitted.
    pub spike_times: Vec<f64>,
    /// Per-spike weights. NOTE(review): presumably parallel to `spike_times`
    /// (same length); nothing here checks that.
    pub spike_weights: Vec<f64>,
}

/// Configuration of [`NeuronModel::DcGenerator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DcGeneratorParams {
    /// Current amplitude.
    pub amplitude: f64,
    /// Time at which the current switches on.
    pub start: f64,
    /// Time at which the current switches off.
    pub stop: f64,
}

/// Configuration of [`NeuronModel::NoiseGenerator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NoiseGeneratorParams {
    /// Mean current.
    pub mean: f64,
    /// Standard deviation of the current.
    pub std: f64,
    /// Interval at which a new current value is drawn.
    pub dt: f64,
}

/// Configuration of [`NeuronModel::Multimeter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimeterParams {
    /// Names of the quantities to record; presumably keys of `NodeState::state`
    /// such as `"V_m"` — NOTE(review): confirm.
    pub record_from: Vec<String>,
    /// Sampling interval.
    pub interval: f64,
}
426
/// Synapse type attached to each [`Connection`].
///
/// NOTE(review): no code in this file applies plasticity or stochastic
/// transmission yet; the variant and its parameters are only stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SynapseModel {
    /// Fixed-weight synapse.
    Static,

    /// Spike-timing-dependent plasticity.
    StdpSynapse(StdpParams),

    /// Short-term depression/facilitation (Tsodyks–Markram).
    TsodyksMarkramSynapse(TsodyksMarkramParams),

    /// Synapse that transmits each spike with a fixed probability.
    BernoulliSynapse(BernoulliParams),

    /// Inhibitory plasticity rule of Vogels & Sprekeler.
    VogelsSprekelerSynapse(VogelsSprekelerParams),
}
449
/// Parameters of [`SynapseModel::StdpSynapse`]; meanings follow NEST's
/// `stdp_synapse` parameters of the same name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StdpParams {
    /// Time constant of the potentiation window.
    pub tau_plus: f64,
    /// Time constant of the depression window.
    pub tau_minus: f64,
    /// Learning rate.
    pub lambda: f64,
    /// Depression-to-potentiation ratio.
    pub alpha: f64,
    /// Upper weight bound.
    pub w_max: f64,
    /// Weight-dependence exponent for potentiation.
    pub mu_plus: f64,
    /// Weight-dependence exponent for depression.
    pub mu_minus: f64,
}

impl Default for StdpParams {
    fn default() -> Self {
        Self {
            tau_plus: 20.0,
            tau_minus: 20.0,
            lambda: 0.01,
            alpha: 1.0,
            w_max: 100.0,
            mu_plus: 1.0,
            mu_minus: 1.0,
        }
    }
}
475
/// Parameters of [`SynapseModel::TsodyksMarkramSynapse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TsodyksMarkramParams {
    /// Utilization of synaptic efficacy.
    pub u: f64,
    /// Recovery (depression) time constant.
    pub tau_rec: f64,
    /// Facilitation time constant (the default 0.0 presumably disables
    /// facilitation — confirm).
    pub tau_fac: f64,
}

impl Default for TsodyksMarkramParams {
    fn default() -> Self {
        Self {
            u: 0.5,
            tau_rec: 800.0,
            tau_fac: 0.0,
        }
    }
}
493
/// Parameters of [`SynapseModel::BernoulliSynapse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BernoulliParams {
    /// Probability that a spike is transmitted.
    pub p_transmit: f64,
}

/// Parameters of [`SynapseModel::VogelsSprekelerSynapse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VogelsSprekelerParams {
    /// Time constant of the spike traces.
    pub tau: f64,
    /// Learning rate.
    pub eta: f64,
    /// Constant depression term (sets the target rate).
    pub alpha: f64,
    /// Upper weight bound.
    pub w_max: f64,
}
508
/// How `connect` pairs nodes of the source and target collections.
///
/// NOTE(review): `connect` implements only `AllToAll`, `OneToOne` and
/// `PairwiseBernoulli`. The other rules currently create no connections.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConnectivityRule {
    /// Every source connects to every target.
    AllToAll,

    /// The i-th source connects to the i-th target; collections must be equal length.
    OneToOne,

    /// Each target receives exactly `indegree` randomly chosen sources.
    FixedIndegree { indegree: usize },

    /// Each source projects to exactly `outdegree` randomly chosen targets.
    FixedOutdegree { outdegree: usize },

    /// Exactly `n` connections between randomly chosen pairs.
    FixedTotalNumber { n: usize },

    /// Each (source, target) pair is connected independently with probability `p`.
    PairwiseBernoulli { p: f64 },

    /// Like `PairwiseBernoulli`, but a connected pair is connected in both directions.
    SymmetricPairwiseBernoulli { p: f64 },
}
537
/// Distribution of connection weights.
///
/// NOTE(review): `sample_weight` currently returns a fixed central value per
/// variant (constant / midpoint / mean / `exp(mu)`), not a random draw.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightDistribution {
    Constant(f64),
    Uniform { min: f64, max: f64 },
    Normal { mean: f64, std: f64 },
    Lognormal { mu: f64, sigma: f64 },
}

/// Distribution of connection delays.
///
/// NOTE(review): `sample_delay` currently returns a fixed central value per
/// variant (constant / midpoint / mean), not a random draw.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DelayDistribution {
    Constant(f64),
    Uniform { min: f64, max: f64 },
    Normal { mean: f64, std: f64 },
}
554
/// Everything `connect` needs to know about a projection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionSpec {
    /// Which (source, target) pairs get connected.
    pub rule: ConnectivityRule,
    /// Weight distribution for new connections.
    pub weight: WeightDistribution,
    /// Delay distribution for new connections.
    pub delay: DelayDistribution,
    /// Synapse model cloned into every new connection.
    pub synapse_model: SynapseModel,
    /// Whether a node may connect to itself.
    pub allow_autapses: bool,
    /// Whether a pair may be connected more than once.
    /// NOTE(review): not read by `connect` in this file.
    pub allow_multapses: bool,
}

/// All-to-all, static synapses, weight 1.0, delay 1.0, no autapses, multapses allowed.
impl Default for ConnectionSpec {
    fn default() -> Self {
        Self {
            rule: ConnectivityRule::AllToAll,
            weight: WeightDistribution::Constant(1.0),
            delay: DelayDistribution::Constant(1.0),
            synapse_model: SynapseModel::Static,
            allow_autapses: false,
            allow_multapses: true,
        }
    }
}
578
/// Per-node record stored in [`Kernel::nodes`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeState {
    /// The node's id (also its key in `Kernel::nodes`).
    pub id: NodeId,
    /// Snake_case model name produced by `model_to_string`.
    pub model: String,
    /// Membrane potential; `create` copies it from `state["V_m"]`, or uses -70.0
    /// when the model has none. `get_status`/`set_status` treat this field as the
    /// authoritative `"V_m"`.
    pub v_m: f64,
    /// Time of the most recent spike; `f64::NEG_INFINITY` until the node spikes.
    pub last_spike: f64,
    /// End of the current refractory period; `f64::NEG_INFINITY` initially.
    pub refractory_until: f64,
    /// Model-specific state variables and user-set values, keyed by name.
    pub state: HashMap<String, f64>,
}
594
/// A single synapse from `source` to `target`, as created by `connect`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Connection {
    /// Presynaptic node id.
    pub source: NodeId,
    /// Postsynaptic node id.
    pub target: NodeId,
    /// Synaptic weight (negative values are used for inhibition in `balanced_network`).
    pub weight: f64,
    /// Transmission delay.
    pub delay: f64,
    /// Synapse model (and its parameters) for this connection.
    pub synapse_model: SynapseModel,
    /// Per-synapse state variables; `connect` starts it empty.
    pub state: HashMap<String, f64>,
}
606
/// Spike events recorded by a spike detector, stored as two parallel vectors.
///
/// `times[i]` and `senders[i]` describe the same event; `record` pushes to both.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpikeData {
    /// Event times, in recording order.
    pub times: Vec<f64>,
    /// Id of the node that emitted each event.
    pub senders: Vec<NodeId>,
}
617
618impl SpikeData {
619 pub fn new() -> Self {
620 Self {
621 times: vec![],
622 senders: vec![],
623 }
624 }
625
626 pub fn record(&mut self, time: f64, sender: NodeId) {
627 self.times.push(time);
628 self.senders.push(sender);
629 }
630
631 pub fn n_events(&self) -> usize {
632 self.times.len()
633 }
634
635 pub fn spike_trains(&self) -> HashMap<NodeId, Vec<f64>> {
637 let mut trains: HashMap<NodeId, Vec<f64>> = HashMap::new();
638 for (&time, &sender) in self.times.iter().zip(self.senders.iter()) {
639 trains.entry(sender).or_default().push(time);
640 }
641 trains
642 }
643}
644
645impl Default for SpikeData {
646 fn default() -> Self {
647 Self::new()
648 }
649}
650
/// Sampled continuous quantities (e.g. from a multimeter).
///
/// NOTE(review): not produced anywhere in this file. Presumably `times[i]` and
/// `senders[i]` describe sample `i`, and each `data` vector (keyed by quantity
/// name) is parallel to them — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContinuousData {
    pub times: Vec<f64>,
    pub senders: Vec<NodeId>,
    pub data: HashMap<String, Vec<f64>>,
}
658
/// Global simulation settings, read and written through
/// `get_kernel_status` / `set_kernel_status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelParams {
    /// Simulation time step; `simulate` advances time in steps of this size.
    pub resolution: f64,
    /// Presumably the smallest allowed synaptic delay (not enforced by `connect`).
    pub min_delay: f64,
    /// Presumably the largest allowed synaptic delay (not enforced by `connect`).
    pub max_delay: f64,
    /// Seed for pseudo-random decisions.
    pub rng_seed: u64,
    /// Requested worker threads. NOTE(review): all code in this file is single-threaded.
    pub num_threads: usize,
    /// Presumably toggles progress printing during `simulate`.
    pub print_time: bool,
}

impl Default for KernelParams {
    fn default() -> Self {
        Self {
            resolution: 0.1,
            min_delay: 0.1,
            max_delay: 100.0,
            rng_seed: 12345,
            num_threads: 1,
            print_time: false,
        }
    }
}
686
/// The simulation kernel: all nodes, connections, recordings and the clock.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Kernel {
    /// Global settings.
    pub params: KernelParams,
    /// Current simulation time.
    pub time: f64,
    /// Id handed to the next created node; starts at 1.
    next_node_id: NodeId,
    /// All nodes, keyed by id.
    pub nodes: HashMap<NodeId, NodeState>,
    /// All connections, in creation order.
    pub connections: Vec<Connection>,
    /// Spike buffers, keyed by the id of the spike detector that owns them.
    pub spike_data: HashMap<NodeId, SpikeData>,
}
697
698impl Kernel {
699 pub fn new(params: KernelParams) -> Self {
700 Self {
701 params,
702 time: 0.0,
703 next_node_id: 1, nodes: HashMap::new(),
705 connections: vec![],
706 spike_data: HashMap::new(),
707 }
708 }
709
710 pub fn reset(&mut self) {
712 self.time = 0.0;
713 self.nodes.clear();
714 self.connections.clear();
715 self.spike_data.clear();
716 self.next_node_id = 1;
717 }
718
719 pub fn set_params(&mut self, params: KernelParams) {
721 self.params = params;
722 }
723
724 pub fn get_time(&self) -> f64 {
726 self.time
727 }
728}
729
730static mut KERNEL: Option<Kernel> = None;
736
737pub fn reset_kernel(params: Option<KernelParams>) {
739 unsafe {
740 KERNEL = Some(Kernel::new(params.unwrap_or_default()));
741 }
742}
743
744fn get_kernel() -> &'static mut Kernel {
746 unsafe {
747 if KERNEL.is_none() {
748 KERNEL = Some(Kernel::new(KernelParams::default()));
749 }
750 KERNEL.as_mut().unwrap()
751 }
752}
753
754pub fn set_kernel_status(params: KernelParams) {
756 get_kernel().set_params(params);
757}
758
759pub fn get_kernel_status() -> KernelParams {
761 get_kernel().params.clone()
762}
763
764pub fn create(model: NeuronModel, n: usize) -> Result<NodeCollection> {
766 let kernel = get_kernel();
767 let mut ids = Vec::with_capacity(n);
768
769 let model_name = model_to_string(&model);
770
771 for _ in 0..n {
772 let id = kernel.next_node_id;
773 kernel.next_node_id += 1;
774
775 let mut state = HashMap::new();
776
777 match &model {
779 NeuronModel::IafPscAlpha(p) => {
780 state.insert("V_m".into(), p.e_l);
781 }
782 NeuronModel::IafPscExp(p) => {
783 state.insert("V_m".into(), p.e_l);
784 }
785 NeuronModel::IafCondAlpha(p) => {
786 state.insert("V_m".into(), p.e_l);
787 }
788 NeuronModel::AeifCondAlpha(p) => {
789 state.insert("V_m".into(), p.e_l);
790 state.insert("w".into(), 0.0);
791 }
792 NeuronModel::HhPscAlpha(p) => {
793 state.insert("V_m".into(), p.e_l);
794 state.insert("n".into(), 0.3);
795 state.insert("m".into(), 0.05);
796 state.insert("h".into(), 0.6);
797 }
798 NeuronModel::Izhikevich(p) => {
799 state.insert("V_m".into(), p.c);
800 state.insert("U_m".into(), p.b * p.c);
801 }
802 NeuronModel::SpikeDetector => {
803 kernel.spike_data.insert(id, SpikeData::new());
804 }
805 _ => {}
806 }
807
808 kernel.nodes.insert(id, NodeState {
809 id,
810 model: model_name.clone(),
811 v_m: state.get("V_m").copied().unwrap_or(-70.0),
812 last_spike: f64::NEG_INFINITY,
813 refractory_until: f64::NEG_INFINITY,
814 state,
815 });
816
817 ids.push(id);
818 }
819
820 Ok(NodeCollection::new(ids))
821}
822
823fn model_to_string(model: &NeuronModel) -> String {
824 match model {
825 NeuronModel::IafPscAlpha(_) => "iaf_psc_alpha".into(),
826 NeuronModel::IafPscExp(_) => "iaf_psc_exp".into(),
827 NeuronModel::IafPscDelta(_) => "iaf_psc_delta".into(),
828 NeuronModel::IafCondAlpha(_) => "iaf_cond_alpha".into(),
829 NeuronModel::IafCondExp(_) => "iaf_cond_exp".into(),
830 NeuronModel::AeifCondAlpha(_) => "aeif_cond_alpha".into(),
831 NeuronModel::HhPscAlpha(_) => "hh_psc_alpha".into(),
832 NeuronModel::Izhikevich(_) => "izhikevich".into(),
833 NeuronModel::ParrotNeuron => "parrot_neuron".into(),
834 NeuronModel::PoissonGenerator(_) => "poisson_generator".into(),
835 NeuronModel::SpikeGenerator(_) => "spike_generator".into(),
836 NeuronModel::DcGenerator(_) => "dc_generator".into(),
837 NeuronModel::NoiseGenerator(_) => "noise_generator".into(),
838 NeuronModel::SpikeDetector => "spike_detector".into(),
839 NeuronModel::Multimeter(_) => "multimeter".into(),
840 }
841}
842
843pub fn connect(
845 sources: &NodeCollection,
846 targets: &NodeCollection,
847 spec: ConnectionSpec,
848) -> Result<()> {
849 let kernel = get_kernel();
850
851 match spec.rule {
852 ConnectivityRule::AllToAll => {
853 for &src in &sources.ids {
854 for &tgt in &targets.ids {
855 if !spec.allow_autapses && src == tgt {
856 continue;
857 }
858
859 let weight = sample_weight(&spec.weight);
860 let delay = sample_delay(&spec.delay);
861
862 kernel.connections.push(Connection {
863 source: src,
864 target: tgt,
865 weight,
866 delay,
867 synapse_model: spec.synapse_model.clone(),
868 state: HashMap::new(),
869 });
870 }
871 }
872 }
873
874 ConnectivityRule::OneToOne => {
875 if sources.len() != targets.len() {
876 return Err(NestError::ConnectionError(
877 "OneToOne requires equal population sizes".into()
878 ));
879 }
880
881 for (&src, &tgt) in sources.ids.iter().zip(targets.ids.iter()) {
882 let weight = sample_weight(&spec.weight);
883 let delay = sample_delay(&spec.delay);
884
885 kernel.connections.push(Connection {
886 source: src,
887 target: tgt,
888 weight,
889 delay,
890 synapse_model: spec.synapse_model.clone(),
891 state: HashMap::new(),
892 });
893 }
894 }
895
896 ConnectivityRule::PairwiseBernoulli { p } => {
897 use std::collections::hash_map::DefaultHasher;
898 use std::hash::{Hash, Hasher};
899
900 for &src in &sources.ids {
901 for &tgt in &targets.ids {
902 if !spec.allow_autapses && src == tgt {
903 continue;
904 }
905
906 let mut hasher = DefaultHasher::new();
907 (src, tgt, kernel.time as u64).hash(&mut hasher);
908 let hash = hasher.finish();
909 let r = (hash as f64) / (u64::MAX as f64);
910
911 if r < p {
912 let weight = sample_weight(&spec.weight);
913 let delay = sample_delay(&spec.delay);
914
915 kernel.connections.push(Connection {
916 source: src,
917 target: tgt,
918 weight,
919 delay,
920 synapse_model: spec.synapse_model.clone(),
921 state: HashMap::new(),
922 });
923 }
924 }
925 }
926 }
927
928 _ => {
929 }
931 }
932
933 Ok(())
934}
935
936fn sample_weight(dist: &WeightDistribution) -> f64 {
937 match dist {
938 WeightDistribution::Constant(w) => *w,
939 WeightDistribution::Uniform { min, max } => {
940 (min + max) / 2.0
942 }
943 WeightDistribution::Normal { mean, std: _ } => *mean,
944 WeightDistribution::Lognormal { mu, sigma: _ } => mu.exp(),
945 }
946}
947
948fn sample_delay(dist: &DelayDistribution) -> f64 {
949 match dist {
950 DelayDistribution::Constant(d) => *d,
951 DelayDistribution::Uniform { min, max } => (min + max) / 2.0,
952 DelayDistribution::Normal { mean, std: _ } => *mean,
953 }
954}
955
956pub fn simulate(time: f64) -> Result<()> {
958 let kernel = get_kernel();
959 let dt = kernel.params.resolution;
960 let n_steps = (time / dt).ceil() as usize;
961
962 for _ in 0..n_steps {
963 kernel.time += dt;
964 }
966
967 Ok(())
968}
969
970pub fn get_spike_data(detector: NodeId) -> Option<SpikeData> {
972 let kernel = get_kernel();
973 kernel.spike_data.get(&detector).cloned()
974}
975
976pub fn get_status(nodes: &NodeCollection) -> Vec<HashMap<String, f64>> {
978 let kernel = get_kernel();
979 let mut results = vec![];
980
981 for &id in &nodes.ids {
982 if let Some(node) = kernel.nodes.get(&id) {
983 let mut status = node.state.clone();
984 status.insert("V_m".into(), node.v_m);
985 status.insert("t_spike".into(), node.last_spike);
986 results.push(status);
987 }
988 }
989
990 results
991}
992
993pub fn set_status(nodes: &NodeCollection, params: HashMap<String, f64>) -> Result<()> {
995 let kernel = get_kernel();
996
997 for &id in &nodes.ids {
998 if let Some(node) = kernel.nodes.get_mut(&id) {
999 for (key, value) in ¶ms {
1000 if key == "V_m" {
1001 node.v_m = *value;
1002 } else {
1003 node.state.insert(key.clone(), *value);
1004 }
1005 }
1006 }
1007 }
1008
1009 Ok(())
1010}
1011
1012pub fn balanced_network(
1018 n_exc: usize,
1019 n_inh: usize,
1020 p_conn: f64,
1021 g: f64, j_exc: f64, ) -> Result<(NodeCollection, NodeCollection)> {
1024 reset_kernel(None);
1025
1026 let exc = create(
1028 NeuronModel::IafPscAlpha(IafPscAlphaParams::default()),
1029 n_exc
1030 )?;
1031
1032 let inh = create(
1034 NeuronModel::IafPscAlpha(IafPscAlphaParams::default()),
1035 n_inh
1036 )?;
1037
1038 let j_inh = -g * j_exc;
1039
1040 connect(&exc, &exc, ConnectionSpec {
1042 rule: ConnectivityRule::PairwiseBernoulli { p: p_conn },
1043 weight: WeightDistribution::Constant(j_exc),
1044 delay: DelayDistribution::Constant(1.5),
1045 ..Default::default()
1046 })?;
1047
1048 connect(&exc, &inh, ConnectionSpec {
1050 rule: ConnectivityRule::PairwiseBernoulli { p: p_conn },
1051 weight: WeightDistribution::Constant(j_exc),
1052 delay: DelayDistribution::Constant(1.5),
1053 ..Default::default()
1054 })?;
1055
1056 connect(&inh, &exc, ConnectionSpec {
1058 rule: ConnectivityRule::PairwiseBernoulli { p: p_conn },
1059 weight: WeightDistribution::Constant(j_inh),
1060 delay: DelayDistribution::Constant(1.5),
1061 ..Default::default()
1062 })?;
1063
1064 connect(&inh, &inh, ConnectionSpec {
1066 rule: ConnectivityRule::PairwiseBernoulli { p: p_conn },
1067 weight: WeightDistribution::Constant(j_inh),
1068 delay: DelayDistribution::Constant(1.5),
1069 ..Default::default()
1070 })?;
1071
1072 Ok((exc, inh))
1073}
1074
1075pub fn mean_firing_rate(data: &SpikeData, n_neurons: usize, duration: f64) -> f64 {
1077 if n_neurons == 0 || duration <= 0.0 {
1078 return 0.0;
1079 }
1080 (data.n_events() as f64) / (n_neurons as f64) / (duration / 1000.0)
1081}
1082
/// Coefficient of variation (CV) of the inter-spike intervals (ISIs) of
/// `spike_train`.
///
/// The CV is the population standard deviation of the ISIs divided by their mean.
/// A perfectly regular train has CV 0. `spike_train` is expected to be sorted in
/// ascending order.
///
/// Returns 0.0 in two cases:
/// * fewer than two spikes (there is no interval to measure);
/// * a mean ISI that is not positive (all spikes coincide, or the input is
///   unsorted). Here the ratio would otherwise be NaN, infinite or negative.
pub fn cv_isi(spike_train: &[f64]) -> f64 {
    if spike_train.len() < 2 {
        return 0.0;
    }

    // Rebuilt on demand so the intervals never need to be collected into a Vec.
    let isis = || spike_train.windows(2).map(|pair| pair[1] - pair[0]);
    let count = (spike_train.len() - 1) as f64;

    let mean = isis().sum::<f64>() / count;
    if mean <= 0.0 {
        return 0.0;
    }

    let variance = isis().map(|isi| (isi - mean).powi(2)).sum::<f64>() / count;
    variance.sqrt() / mean
}
1100
1101pub fn spike_correlation(
1103 train1: &[f64],
1104 train2: &[f64],
1105 bin_size: f64,
1106 max_time: f64,
1107) -> Array1<f64> {
1108 let n_bins = (max_time / bin_size).ceil() as usize;
1109 let mut hist1: Array1<f64> = Array1::zeros(n_bins);
1110 let mut hist2: Array1<f64> = Array1::zeros(n_bins);
1111
1112 for &t in train1 {
1113 let bin = (t / bin_size).floor() as usize;
1114 if bin < n_bins {
1115 hist1[bin] += 1.0;
1116 }
1117 }
1118
1119 for &t in train2 {
1120 let bin = (t / bin_size).floor() as usize;
1121 if bin < n_bins {
1122 hist2[bin] += 1.0;
1123 }
1124 }
1125
1126 hist1 * hist2
1128}
1129
#[cfg(test)]
mod tests {
    use super::*;

    // These tests deliberately avoid the kernel-level functions (`create`,
    // `connect`, `simulate`, ...). Those share one global `static mut` kernel,
    // and the test harness runs tests on several threads by default.

    #[test]
    fn test_node_collection() {
        let nodes = NodeCollection::new(vec![1, 2, 3, 4, 5]);
        assert_eq!(nodes.len(), 5);
        assert_eq!(nodes.first(), Some(1));
        assert_eq!(nodes.last(), Some(5));

        let slice = nodes.slice(1, 3);
        assert_eq!(slice.ids, vec![2, 3]);
    }

    #[test]
    fn test_node_collection_empty() {
        let nodes = NodeCollection::new(vec![]);
        assert!(nodes.is_empty());
        assert_eq!(nodes.first(), None);
        assert_eq!(nodes.last(), None);
    }

    #[test]
    fn test_node_collection_into_iter() {
        let ids: Vec<NodeId> = NodeCollection::new(vec![3, 1, 2]).into_iter().collect();
        assert_eq!(ids, vec![3, 1, 2]);
    }

    #[test]
    fn test_iaf_params() {
        let params = IafPscAlphaParams::default();
        assert_eq!(params.tau_m, 10.0);
        assert_eq!(params.e_l, -70.0);
    }

    #[test]
    fn test_connection_spec() {
        let spec = ConnectionSpec::default();
        assert!(!spec.allow_autapses);
        assert!(spec.allow_multapses);
    }

    #[test]
    fn test_spike_data() {
        let mut data = SpikeData::new();
        data.record(10.0, 1);
        data.record(15.0, 2);
        data.record(20.0, 1);

        assert_eq!(data.n_events(), 3);

        let trains = data.spike_trains();
        assert_eq!(trains[&1].len(), 2);
        assert_eq!(trains[&2].len(), 1);
    }

    #[test]
    fn test_mean_firing_rate() {
        let mut data = SpikeData::new();
        data.record(1.0, 1);
        data.record(2.0, 1);
        data.record(3.0, 2);

        // 3 spikes / 2 neurons / 1 s.
        assert_eq!(mean_firing_rate(&data, 2, 1000.0), 1.5);
        assert_eq!(mean_firing_rate(&data, 0, 1000.0), 0.0);
        assert_eq!(mean_firing_rate(&data, 2, 0.0), 0.0);
    }

    #[test]
    fn test_cv_isi() {
        // Perfectly regular train: every ISI is 10, so the CV is 0.
        let regular: Vec<f64> = (0..10).map(|i| i as f64 * 10.0).collect();
        let cv = cv_isi(&regular);
        assert!(cv < 0.01);

        // ISIs 5, 15, 2, 28 vary strongly.
        let irregular = vec![0.0, 5.0, 20.0, 22.0, 50.0];
        let cv = cv_isi(&irregular);
        assert!(cv > 0.5);
    }

    #[test]
    fn test_cv_isi_too_short() {
        assert_eq!(cv_isi(&[]), 0.0);
        assert_eq!(cv_isi(&[3.0]), 0.0);
    }

    #[test]
    fn test_spike_correlation_coincidences() {
        // Histograms [1, 1, 0] and [1, 0, 1]: only bin 0 has spikes in both.
        let product = spike_correlation(&[0.5, 1.5], &[0.7, 2.5], 1.0, 3.0);
        assert_eq!(product.to_vec(), vec![1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_izhikevich_variants() {
        let rs = IzhikevichParams::default();
        assert_eq!(rs.a, 0.02);
        assert_eq!(rs.b, 0.2);
    }

    #[test]
    fn test_adex_params() {
        let adex = AeifCondAlphaParams::default();
        assert!(adex.delta_t > 0.0);
        assert!(adex.tau_w > 0.0);
    }
}