oldies_brian/
lib.rs

1//! # Brian-RS: Brian Spiking Neural Network Simulator Revival
2//!
3//! Revival of the Brian simulator (http://briansimulator.org/)
4//! Originally created by Romain Brette and Dan Goodman (2007)
5//!
6//! Brian uses equation-based model definitions with natural mathematical syntax.
7//! This crate provides:
8//! - Equation parser for differential equations
//! - Multiple neuron models (LIF, AdEx, Izhikevich)
10//! - Synapse models (exponential, alpha, STDP)
11//! - Network topology and connectivity
12//! - Spike monitors and state monitors
13
14use ndarray::{Array1, Array2};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use thiserror::Error;
18
19#[derive(Error, Debug)]
20pub enum BrianError {
21    #[error("Parse error: {0}")]
22    ParseError(String),
23    #[error("Simulation error: {0}")]
24    SimulationError(String),
25    #[error("Invalid equation: {0}")]
26    EquationError(String),
27    #[error("Unit mismatch: expected {expected}, got {got}")]
28    UnitError { expected: String, got: String },
29}
30
31pub type Result<T> = std::result::Result<T, BrianError>;
32
33// ============================================================================
34// UNITS SYSTEM (Brian's signature feature)
35// ============================================================================
36
37/// Physical units with SI prefixes
38#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
39pub enum Unit {
40    // Time
41    Second,
42    Millisecond,  // ms
43    Microsecond,  // us
44
45    // Voltage
46    Volt,
47    Millivolt,    // mV
48
49    // Current
50    Ampere,
51    Nanoampere,   // nA
52    Picoampere,   // pA
53
54    // Conductance
55    Siemens,
56    Nanosiemens,  // nS
57    Microsiemens, // uS
58
59    // Capacitance
60    Farad,
61    Picofarad,    // pF
62
63    // Resistance
64    Ohm,
65    Megaohm,      // MOhm
66    Gigaohm,      // GOhm
67
68    // Frequency
69    Hertz,
70
71    // Dimensionless
72    Dimensionless,
73}
74
75impl Unit {
76    /// Convert to SI base units
77    pub fn to_si_factor(&self) -> f64 {
78        match self {
79            Unit::Second => 1.0,
80            Unit::Millisecond => 1e-3,
81            Unit::Microsecond => 1e-6,
82            Unit::Volt => 1.0,
83            Unit::Millivolt => 1e-3,
84            Unit::Ampere => 1.0,
85            Unit::Nanoampere => 1e-9,
86            Unit::Picoampere => 1e-12,
87            Unit::Siemens => 1.0,
88            Unit::Nanosiemens => 1e-9,
89            Unit::Microsiemens => 1e-6,
90            Unit::Farad => 1.0,
91            Unit::Picofarad => 1e-12,
92            Unit::Ohm => 1.0,
93            Unit::Megaohm => 1e6,
94            Unit::Gigaohm => 1e9,
95            Unit::Hertz => 1.0,
96            Unit::Dimensionless => 1.0,
97        }
98    }
99}
100
101/// Quantity with value and unit
102#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
103pub struct Quantity {
104    pub value: f64,
105    pub unit: Unit,
106}
107
108impl Quantity {
109    pub fn new(value: f64, unit: Unit) -> Self {
110        Self { value, unit }
111    }
112
113    /// Convert to SI base units
114    pub fn to_si(&self) -> f64 {
115        self.value * self.unit.to_si_factor()
116    }
117}
118
119// ============================================================================
120// EQUATION SYSTEM
121// ============================================================================
122
123/// Differential equation: dv/dt = expr
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct DifferentialEquation {
126    pub variable: String,
127    pub expression: String,
128    pub unit: Unit,
129    /// Method: euler, rk2, rk4, exponential_euler
130    pub method: IntegrationMethod,
131}
132
133/// Algebraic equation: v = expr (computed each timestep)
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct AlgebraicEquation {
136    pub variable: String,
137    pub expression: String,
138    pub unit: Unit,
139}
140
141/// Threshold condition for spike generation
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct ThresholdCondition {
144    pub condition: String,  // e.g., "v > v_thresh"
145}
146
147/// Reset equations after spike
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct ResetEquations {
150    pub equations: Vec<String>,  // e.g., ["v = v_reset", "w += b"]
151}
152
153/// Refractory period specification
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub enum RefractorySpec {
156    Duration(Quantity),           // Fixed duration
157    Condition(String),            // Until condition is met
158}
159
160/// Integration methods
161#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
162pub enum IntegrationMethod {
163    Euler,
164    ExponentialEuler,  // For linear ODEs
165    RungeKutta2,
166    RungeKutta4,
167    Heun,
168    Milstein,  // For SDEs
169    ExactSolution,  // For analytically solvable equations
170}
171
172/// Complete neuron equations
173#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct NeuronEquations {
175    pub differential: Vec<DifferentialEquation>,
176    pub algebraic: Vec<AlgebraicEquation>,
177    pub threshold: Option<ThresholdCondition>,
178    pub reset: Option<ResetEquations>,
179    pub refractory: Option<RefractorySpec>,
180    pub parameters: HashMap<String, Quantity>,
181}
182
183// ============================================================================
184// NEURON MODELS
185// ============================================================================
186
187/// Leaky Integrate-and-Fire neuron
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct LIFNeuron {
190    pub tau_m: f64,      // Membrane time constant (ms)
191    pub v_rest: f64,     // Resting potential (mV)
192    pub v_reset: f64,    // Reset potential (mV)
193    pub v_thresh: f64,   // Spike threshold (mV)
194    pub r_m: f64,        // Membrane resistance (MOhm)
195    pub tau_ref: f64,    // Refractory period (ms)
196}
197
198impl Default for LIFNeuron {
199    fn default() -> Self {
200        Self {
201            tau_m: 10.0,
202            v_rest: -65.0,
203            v_reset: -65.0,
204            v_thresh: -50.0,
205            r_m: 10.0,
206            tau_ref: 2.0,
207        }
208    }
209}
210
211impl LIFNeuron {
212    pub fn to_equations(&self) -> NeuronEquations {
213        NeuronEquations {
214            differential: vec![
215                DifferentialEquation {
216                    variable: "v".into(),
217                    expression: format!(
218                        "(({} - v) + {} * I) / {}",
219                        self.v_rest, self.r_m, self.tau_m
220                    ),
221                    unit: Unit::Millivolt,
222                    method: IntegrationMethod::ExponentialEuler,
223                },
224            ],
225            algebraic: vec![],
226            threshold: Some(ThresholdCondition {
227                condition: format!("v > {}", self.v_thresh),
228            }),
229            reset: Some(ResetEquations {
230                equations: vec![format!("v = {}", self.v_reset)],
231            }),
232            refractory: Some(RefractorySpec::Duration(
233                Quantity::new(self.tau_ref, Unit::Millisecond)
234            )),
235            parameters: HashMap::new(),
236        }
237    }
238}
239
240/// Adaptive Exponential Integrate-and-Fire (AdEx)
241#[derive(Debug, Clone, Serialize, Deserialize)]
242pub struct AdExNeuron {
243    pub c_m: f64,        // Membrane capacitance (pF)
244    pub g_l: f64,        // Leak conductance (nS)
245    pub e_l: f64,        // Leak reversal (mV)
246    pub v_t: f64,        // Spike initiation threshold (mV)
247    pub delta_t: f64,    // Slope factor (mV)
248    pub tau_w: f64,      // Adaptation time constant (ms)
249    pub a: f64,          // Subthreshold adaptation (nS)
250    pub b: f64,          // Spike-triggered adaptation (pA)
251    pub v_reset: f64,    // Reset potential (mV)
252    pub v_peak: f64,     // Spike cutoff (mV)
253}
254
255impl Default for AdExNeuron {
256    fn default() -> Self {
257        Self {
258            c_m: 281.0,
259            g_l: 30.0,
260            e_l: -70.6,
261            v_t: -50.4,
262            delta_t: 2.0,
263            tau_w: 144.0,
264            a: 4.0,
265            b: 80.5,
266            v_reset: -70.6,
267            v_peak: 20.0,
268        }
269    }
270}
271
272impl AdExNeuron {
273    pub fn to_equations(&self) -> NeuronEquations {
274        NeuronEquations {
275            differential: vec![
276                DifferentialEquation {
277                    variable: "v".into(),
278                    expression: format!(
279                        "(-{} * (v - {}) + {} * {} * exp((v - {}) / {}) - w + I) / {}",
280                        self.g_l, self.e_l, self.g_l, self.delta_t,
281                        self.v_t, self.delta_t, self.c_m
282                    ),
283                    unit: Unit::Millivolt,
284                    method: IntegrationMethod::Euler,
285                },
286                DifferentialEquation {
287                    variable: "w".into(),
288                    expression: format!(
289                        "({} * (v - {}) - w) / {}",
290                        self.a, self.e_l, self.tau_w
291                    ),
292                    unit: Unit::Picoampere,
293                    method: IntegrationMethod::Euler,
294                },
295            ],
296            algebraic: vec![],
297            threshold: Some(ThresholdCondition {
298                condition: format!("v > {}", self.v_peak),
299            }),
300            reset: Some(ResetEquations {
301                equations: vec![
302                    format!("v = {}", self.v_reset),
303                    format!("w += {}", self.b),
304                ],
305            }),
306            refractory: None,
307            parameters: HashMap::new(),
308        }
309    }
310}
311
312/// Izhikevich simple model
313#[derive(Debug, Clone, Serialize, Deserialize)]
314pub struct IzhikevichNeuron {
315    pub a: f64,  // Recovery time scale
316    pub b: f64,  // Recovery sensitivity
317    pub c: f64,  // Reset potential (mV)
318    pub d: f64,  // Recovery reset
319}
320
321impl IzhikevichNeuron {
322    /// Regular spiking (RS) - typical excitatory cortical neuron
323    pub fn regular_spiking() -> Self {
324        Self { a: 0.02, b: 0.2, c: -65.0, d: 8.0 }
325    }
326
327    /// Intrinsically bursting (IB)
328    pub fn intrinsically_bursting() -> Self {
329        Self { a: 0.02, b: 0.2, c: -55.0, d: 4.0 }
330    }
331
332    /// Chattering (CH)
333    pub fn chattering() -> Self {
334        Self { a: 0.02, b: 0.2, c: -50.0, d: 2.0 }
335    }
336
337    /// Fast spiking (FS) - inhibitory interneuron
338    pub fn fast_spiking() -> Self {
339        Self { a: 0.1, b: 0.2, c: -65.0, d: 2.0 }
340    }
341
342    /// Low-threshold spiking (LTS)
343    pub fn low_threshold_spiking() -> Self {
344        Self { a: 0.02, b: 0.25, c: -65.0, d: 2.0 }
345    }
346
347    pub fn to_equations(&self) -> NeuronEquations {
348        NeuronEquations {
349            differential: vec![
350                DifferentialEquation {
351                    variable: "v".into(),
352                    expression: "0.04 * v * v + 5.0 * v + 140.0 - u + I".into(),
353                    unit: Unit::Millivolt,
354                    method: IntegrationMethod::Euler,
355                },
356                DifferentialEquation {
357                    variable: "u".into(),
358                    expression: format!("{} * ({} * v - u)", self.a, self.b),
359                    unit: Unit::Dimensionless,
360                    method: IntegrationMethod::Euler,
361                },
362            ],
363            algebraic: vec![],
364            threshold: Some(ThresholdCondition {
365                condition: "v >= 30.0".into(),
366            }),
367            reset: Some(ResetEquations {
368                equations: vec![
369                    format!("v = {}", self.c),
370                    format!("u += {}", self.d),
371                ],
372            }),
373            refractory: None,
374            parameters: HashMap::new(),
375        }
376    }
377}
378
379// ============================================================================
380// SYNAPSE MODELS
381// ============================================================================
382
383/// Synapse model types
384#[derive(Debug, Clone, Serialize, Deserialize)]
385pub enum SynapseModel {
386    /// Instantaneous (delta function)
387    Delta { weight: f64 },
388
389    /// Exponential decay: g(t) = w * exp(-t/tau)
390    Exponential {
391        weight: f64,
392        tau: f64,  // ms
393    },
394
395    /// Alpha function: g(t) = w * (t/tau) * exp(1 - t/tau)
396    Alpha {
397        weight: f64,
398        tau: f64,  // ms
399    },
400
401    /// Difference of exponentials
402    DualExponential {
403        weight: f64,
404        tau_rise: f64,  // ms
405        tau_decay: f64, // ms
406    },
407
408    /// NMDA with voltage-dependent Mg block
409    NMDA {
410        weight: f64,
411        tau_rise: f64,
412        tau_decay: f64,
413        mg_concentration: f64,  // mM
414    },
415
416    /// Short-term plasticity (Tsodyks-Markram)
417    STP {
418        weight: f64,
419        u_se: f64,     // Initial utilization
420        tau_rec: f64,  // Recovery time constant (ms)
421        tau_fac: f64,  // Facilitation time constant (ms)
422    },
423}
424
425/// Spike-Timing-Dependent Plasticity
426#[derive(Debug, Clone, Serialize, Deserialize)]
427pub struct STDPRule {
428    pub tau_pre: f64,   // Pre-synaptic trace time constant (ms)
429    pub tau_post: f64,  // Post-synaptic trace time constant (ms)
430    pub a_plus: f64,    // LTP amplitude
431    pub a_minus: f64,   // LTD amplitude
432    pub w_max: f64,     // Maximum weight
433    pub w_min: f64,     // Minimum weight
434}
435
436impl Default for STDPRule {
437    fn default() -> Self {
438        Self {
439            tau_pre: 20.0,
440            tau_post: 20.0,
441            a_plus: 0.01,
442            a_minus: 0.012,  // Slightly stronger LTD
443            w_max: 1.0,
444            w_min: 0.0,
445        }
446    }
447}
448
449// ============================================================================
450// NEURON GROUP
451// ============================================================================
452
453/// A group of neurons sharing the same equations
454#[derive(Debug, Clone, Serialize, Deserialize)]
455pub struct NeuronGroup {
456    pub name: String,
457    pub n: usize,
458    pub equations: NeuronEquations,
459    pub method: IntegrationMethod,
460    /// State variables for all neurons
461    pub state: HashMap<String, Array1<f64>>,
462    /// Last spike time for each neuron (-inf if never spiked)
463    pub last_spike: Array1<f64>,
464    /// Is neuron currently in refractory period?
465    pub refractory_until: Array1<f64>,
466}
467
468impl NeuronGroup {
469    pub fn new(name: &str, n: usize, equations: NeuronEquations) -> Self {
470        let mut state = HashMap::new();
471
472        // Initialize state variables
473        for eq in &equations.differential {
474            state.insert(eq.variable.clone(), Array1::zeros(n));
475        }
476
477        Self {
478            name: name.to_string(),
479            n,
480            equations,
481            method: IntegrationMethod::Euler,
482            state,
483            last_spike: Array1::from_elem(n, f64::NEG_INFINITY),
484            refractory_until: Array1::from_elem(n, f64::NEG_INFINITY),
485        }
486    }
487
488    pub fn set_initial(&mut self, variable: &str, values: Array1<f64>) -> Result<()> {
489        if let Some(state) = self.state.get_mut(variable) {
490            if values.len() != self.n {
491                return Err(BrianError::SimulationError(
492                    format!("Expected {} values, got {}", self.n, values.len())
493                ));
494            }
495            *state = values;
496            Ok(())
497        } else {
498            Err(BrianError::SimulationError(
499                format!("Unknown variable: {}", variable)
500            ))
501        }
502    }
503}
504
505// ============================================================================
506// SYNAPSES
507// ============================================================================
508
509/// Synapse connections between neuron groups
510#[derive(Debug, Clone, Serialize, Deserialize)]
511pub struct Synapses {
512    pub name: String,
513    pub source: String,      // Source NeuronGroup name
514    pub target: String,      // Target NeuronGroup name
515    pub model: SynapseModel,
516    pub plasticity: Option<STDPRule>,
517    /// Sparse connectivity: (source_idx, target_idx)
518    pub connections: Vec<(usize, usize)>,
519    /// Weights (same length as connections)
520    pub weights: Vec<f64>,
521    /// Delays in ms (same length as connections)
522    pub delays: Vec<f64>,
523}
524
525impl Synapses {
526    pub fn new(name: &str, source: &str, target: &str, model: SynapseModel) -> Self {
527        Self {
528            name: name.to_string(),
529            source: source.to_string(),
530            target: target.to_string(),
531            model,
532            plasticity: None,
533            connections: vec![],
534            weights: vec![],
535            delays: vec![],
536        }
537    }
538
539    /// Connect all-to-all
540    pub fn connect_all_to_all(&mut self, n_source: usize, n_target: usize, weight: f64, delay: f64) {
541        for i in 0..n_source {
542            for j in 0..n_target {
543                self.connections.push((i, j));
544                self.weights.push(weight);
545                self.delays.push(delay);
546            }
547        }
548    }
549
550    /// Connect with probability p
551    pub fn connect_random(&mut self, n_source: usize, n_target: usize, p: f64, weight: f64, delay: f64) {
552        use std::collections::hash_map::DefaultHasher;
553        use std::hash::{Hash, Hasher};
554
555        for i in 0..n_source {
556            for j in 0..n_target {
557                let mut hasher = DefaultHasher::new();
558                (i, j).hash(&mut hasher);
559                let hash = hasher.finish();
560                let r = (hash as f64) / (u64::MAX as f64);
561
562                if r < p {
563                    self.connections.push((i, j));
564                    self.weights.push(weight);
565                    self.delays.push(delay);
566                }
567            }
568        }
569    }
570
571    /// One-to-one mapping
572    pub fn connect_one_to_one(&mut self, n: usize, weight: f64, delay: f64) {
573        for i in 0..n {
574            self.connections.push((i, i));
575            self.weights.push(weight);
576            self.delays.push(delay);
577        }
578    }
579}
580
581// ============================================================================
582// INPUT DEVICES
583// ============================================================================
584
585/// Poisson spike generator
586#[derive(Debug, Clone, Serialize, Deserialize)]
587pub struct PoissonGroup {
588    pub name: String,
589    pub n: usize,
590    pub rates: Array1<f64>,  // Hz
591}
592
593impl PoissonGroup {
594    pub fn new(name: &str, n: usize, rate: f64) -> Self {
595        Self {
596            name: name.to_string(),
597            n,
598            rates: Array1::from_elem(n, rate),
599        }
600    }
601
602    pub fn new_heterogeneous(name: &str, rates: Array1<f64>) -> Self {
603        let n = rates.len();
604        Self {
605            name: name.to_string(),
606            n,
607            rates,
608        }
609    }
610}
611
612/// Spike generator from predetermined spike times
613#[derive(Debug, Clone, Serialize, Deserialize)]
614pub struct SpikeGeneratorGroup {
615    pub name: String,
616    pub n: usize,
617    /// Spike times: (neuron_idx, time_ms)
618    pub spike_times: Vec<(usize, f64)>,
619}
620
621impl SpikeGeneratorGroup {
622    pub fn new(name: &str, n: usize) -> Self {
623        Self {
624            name: name.to_string(),
625            n,
626            spike_times: vec![],
627        }
628    }
629
630    pub fn add_spikes(&mut self, indices: &[usize], times: &[f64]) {
631        for (&i, &t) in indices.iter().zip(times.iter()) {
632            self.spike_times.push((i, t));
633        }
634        // Sort by time
635        self.spike_times.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
636    }
637}
638
639/// Timed array for time-varying input
640#[derive(Debug, Clone, Serialize, Deserialize)]
641pub struct TimedArray {
642    pub name: String,
643    pub times: Array1<f64>,   // ms
644    pub values: Array2<f64>,  // (time_points, neurons)
645}
646
647// ============================================================================
648// MONITORS
649// ============================================================================
650
651/// Record spike times
652#[derive(Debug, Clone, Serialize, Deserialize)]
653pub struct SpikeMonitor {
654    pub source: String,
655    /// Recorded spikes: (neuron_idx, time_ms)
656    pub spikes: Vec<(usize, f64)>,
657    /// Spike counts per neuron
658    pub counts: Vec<usize>,
659}
660
661impl SpikeMonitor {
662    pub fn new(source: &str, n: usize) -> Self {
663        Self {
664            source: source.to_string(),
665            spikes: vec![],
666            counts: vec![0; n],
667        }
668    }
669
670    pub fn record_spike(&mut self, idx: usize, time: f64) {
671        self.spikes.push((idx, time));
672        if idx < self.counts.len() {
673            self.counts[idx] += 1;
674        }
675    }
676
677    /// Get spike trains for each neuron
678    pub fn spike_trains(&self) -> HashMap<usize, Vec<f64>> {
679        let mut trains: HashMap<usize, Vec<f64>> = HashMap::new();
680        for &(idx, time) in &self.spikes {
681            trains.entry(idx).or_default().push(time);
682        }
683        trains
684    }
685
686    /// Calculate firing rate in Hz
687    pub fn mean_rate(&self, duration_ms: f64) -> f64 {
688        if self.counts.is_empty() || duration_ms <= 0.0 {
689            return 0.0;
690        }
691        let total_spikes: usize = self.counts.iter().sum();
692        (total_spikes as f64) / (self.counts.len() as f64) / (duration_ms / 1000.0)
693    }
694}
695
696/// Record state variable over time
697#[derive(Debug, Clone, Serialize, Deserialize)]
698pub struct StateMonitor {
699    pub source: String,
700    pub variables: Vec<String>,
701    pub record_indices: Vec<usize>,  // Which neurons to record
702    pub dt: f64,                     // Recording timestep (ms)
703    /// Recorded values: variable -> (times, values[neuron][time])
704    pub data: HashMap<String, (Vec<f64>, Vec<Vec<f64>>)>,
705}
706
707impl StateMonitor {
708    pub fn new(source: &str, variables: &[&str], indices: &[usize], dt: f64) -> Self {
709        let mut data = HashMap::new();
710        for var in variables {
711            data.insert(var.to_string(), (vec![], vec![vec![]; indices.len()]));
712        }
713
714        Self {
715            source: source.to_string(),
716            variables: variables.iter().map(|s| s.to_string()).collect(),
717            record_indices: indices.to_vec(),
718            dt,
719            data,
720        }
721    }
722
723    pub fn record(&mut self, variable: &str, time: f64, values: &Array1<f64>) {
724        if let Some((times, data)) = self.data.get_mut(variable) {
725            if times.is_empty() || time >= times.last().unwrap() + self.dt {
726                times.push(time);
727                for (i, &idx) in self.record_indices.iter().enumerate() {
728                    if idx < values.len() {
729                        data[i].push(values[idx]);
730                    }
731                }
732            }
733        }
734    }
735}
736
737/// Population rate monitor
738#[derive(Debug, Clone, Serialize, Deserialize)]
739pub struct PopulationRateMonitor {
740    pub source: String,
741    pub bin_size: f64,  // ms
742    pub times: Vec<f64>,
743    pub rates: Vec<f64>,  // Hz
744}
745
746// ============================================================================
747// NETWORK
748// ============================================================================
749
750/// Complete Brian network
751#[derive(Debug, Clone, Serialize, Deserialize)]
752pub struct Network {
753    pub neuron_groups: HashMap<String, NeuronGroup>,
754    pub synapses: HashMap<String, Synapses>,
755    pub poisson_groups: HashMap<String, PoissonGroup>,
756    pub spike_generators: HashMap<String, SpikeGeneratorGroup>,
757    pub spike_monitors: HashMap<String, SpikeMonitor>,
758    pub state_monitors: HashMap<String, StateMonitor>,
759    pub dt: f64,  // Timestep in ms
760    pub t: f64,   // Current time in ms
761}
762
763impl Network {
764    pub fn new(dt: f64) -> Self {
765        Self {
766            neuron_groups: HashMap::new(),
767            synapses: HashMap::new(),
768            poisson_groups: HashMap::new(),
769            spike_generators: HashMap::new(),
770            spike_monitors: HashMap::new(),
771            state_monitors: HashMap::new(),
772            dt,
773            t: 0.0,
774        }
775    }
776
777    pub fn add_neuron_group(&mut self, group: NeuronGroup) {
778        self.neuron_groups.insert(group.name.clone(), group);
779    }
780
781    pub fn add_synapses(&mut self, synapses: Synapses) {
782        self.synapses.insert(synapses.name.clone(), synapses);
783    }
784
785    pub fn add_poisson_group(&mut self, group: PoissonGroup) {
786        self.poisson_groups.insert(group.name.clone(), group);
787    }
788
789    pub fn add_spike_monitor(&mut self, monitor: SpikeMonitor) {
790        self.spike_monitors.insert(monitor.source.clone(), monitor);
791    }
792
793    pub fn add_state_monitor(&mut self, monitor: StateMonitor) {
794        self.state_monitors.insert(
795            format!("{}_state", monitor.source),
796            monitor
797        );
798    }
799
800    /// Run simulation for given duration
801    pub fn run(&mut self, duration: f64) -> Result<()> {
802        let n_steps = (duration / self.dt).ceil() as usize;
803
804        for _ in 0..n_steps {
805            self.step()?;
806        }
807
808        Ok(())
809    }
810
811    /// Single simulation step
812    fn step(&mut self) -> Result<()> {
813        // Update time
814        self.t += self.dt;
815
816        // For now, basic Euler integration (placeholder for full implementation)
817        for (_name, group) in &mut self.neuron_groups {
818            // Simple integration of state variables would go here
819            // This is a skeleton - full implementation would parse and evaluate equations
820            let _n = group.n;
821        }
822
823        Ok(())
824    }
825}
826
827// ============================================================================
828// BRIAN SCRIPT PARSER (simplified)
829// ============================================================================
830
831/// Parse Brian-style equations
832pub fn parse_equations(text: &str) -> Result<NeuronEquations> {
833    let mut differential = vec![];
834    let mut algebraic = vec![];
835
836    for line in text.lines() {
837        let line = line.trim();
838        if line.is_empty() || line.starts_with('#') {
839            continue;
840        }
841
842        // Differential equation: dv/dt = expr : unit
843        if line.starts_with('d') && line.contains("/dt") {
844            let parts: Vec<&str> = line.split('=').collect();
845            if parts.len() >= 2 {
846                let var_part = parts[0].trim();
847                let var = var_part
848                    .trim_start_matches('d')
849                    .split("/dt")
850                    .next()
851                    .unwrap_or("")
852                    .trim();
853
854                let expr_parts: Vec<&str> = parts[1].split(':').collect();
855                let expr = expr_parts[0].trim();
856
857                differential.push(DifferentialEquation {
858                    variable: var.to_string(),
859                    expression: expr.to_string(),
860                    unit: Unit::Dimensionless,
861                    method: IntegrationMethod::Euler,
862                });
863            }
864        }
865        // Algebraic equation: v = expr : unit
866        else if line.contains('=') && !line.contains("/dt") {
867            let parts: Vec<&str> = line.split('=').collect();
868            if parts.len() >= 2 {
869                let var = parts[0].trim();
870                let expr_parts: Vec<&str> = parts[1].split(':').collect();
871                let expr = expr_parts[0].trim();
872
873                algebraic.push(AlgebraicEquation {
874                    variable: var.to_string(),
875                    expression: expr.to_string(),
876                    unit: Unit::Dimensionless,
877                });
878            }
879        }
880    }
881
882    Ok(NeuronEquations {
883        differential,
884        algebraic,
885        threshold: None,
886        reset: None,
887        refractory: None,
888        parameters: HashMap::new(),
889    })
890}
891
892// ============================================================================
893// STANDARD MODELS
894// ============================================================================
895
896/// Create a balanced E/I network (Brunel 2000)
897pub fn brunel_network(
898    n_exc: usize,
899    n_inh: usize,
900    g: f64,      // Relative inhibitory strength
901    eta: f64,    // External rate relative to threshold
902    dt: f64,
903) -> Network {
904    let mut network = Network::new(dt);
905
906    // LIF parameters
907    let lif = LIFNeuron::default();
908
909    // Excitatory neurons
910    let mut exc = NeuronGroup::new("E", n_exc, lif.to_equations());
911    exc.set_initial("v", Array1::from_elem(n_exc, -65.0)).ok();
912    network.add_neuron_group(exc);
913
914    // Inhibitory neurons
915    let mut inh = NeuronGroup::new("I", n_inh, lif.to_equations());
916    inh.set_initial("v", Array1::from_elem(n_inh, -65.0)).ok();
917    network.add_neuron_group(inh);
918
919    // Synapses
920    let w_exc = 0.1;  // mV
921    let w_inh = -g * w_exc;
922    let _delay = 1.5;  // ms
923
924    let p_conn = 0.1;  // Connection probability
925
926    // E -> E
927    let mut ee = Synapses::new("EE", "E", "E", SynapseModel::Delta { weight: w_exc });
928    ee.connect_random(n_exc, n_exc, p_conn, w_exc, 1.5);
929    network.add_synapses(ee);
930
931    // E -> I
932    let mut ei = Synapses::new("EI", "E", "I", SynapseModel::Delta { weight: w_exc });
933    ei.connect_random(n_exc, n_inh, p_conn, w_exc, 1.5);
934    network.add_synapses(ei);
935
936    // I -> E
937    let mut ie = Synapses::new("IE", "I", "E", SynapseModel::Delta { weight: w_inh });
938    ie.connect_random(n_inh, n_exc, p_conn, w_inh, 1.5);
939    network.add_synapses(ie);
940
941    // I -> I
942    let mut ii = Synapses::new("II", "I", "I", SynapseModel::Delta { weight: w_inh });
943    ii.connect_random(n_inh, n_inh, p_conn, w_inh, 1.5);
944    network.add_synapses(ii);
945
946    // External Poisson input
947    let nu_thresh = lif.v_thresh / (lif.r_m * lif.tau_m);  // Threshold rate
948    let nu_ext = eta * nu_thresh * 1000.0;  // Hz
949
950    network.add_poisson_group(PoissonGroup::new("ext_E", n_exc, nu_ext));
951    network.add_poisson_group(PoissonGroup::new("ext_I", n_inh, nu_ext));
952
953    // Monitors
954    network.add_spike_monitor(SpikeMonitor::new("E", n_exc));
955    network.add_spike_monitor(SpikeMonitor::new("I", n_inh));
956
957    network
958}
959
960/// CUBA (Current-based) network from Brian examples
961pub fn cuba_network(n: usize, dt: f64) -> Network {
962    let n_exc = (0.8 * n as f64) as usize;
963    let n_inh = n - n_exc;
964
965    brunel_network(n_exc, n_inh, 5.0, 2.0, dt)
966}
967
968/// COBA (Conductance-based) LIF network
969pub fn coba_network(_n: usize, dt: f64) -> Network {
970    // Simplified implementation
971    Network::new(dt)
972}
973
974// ============================================================================
975// TESTS
976// ============================================================================
977
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lif_equations() {
        let lif = LIFNeuron::default();
        let eqs = lif.to_equations();

        assert_eq!(eqs.differential.len(), 1);
        assert_eq!(eqs.differential[0].variable, "v");
        assert!(eqs.threshold.is_some());
        assert!(eqs.reset.is_some());
    }

    #[test]
    fn test_adex_equations() {
        let adex = AdExNeuron::default();
        let eqs = adex.to_equations();

        assert_eq!(eqs.differential.len(), 2);
        assert_eq!(eqs.differential[0].variable, "v");
        assert_eq!(eqs.differential[1].variable, "w");
    }

    #[test]
    fn test_izhikevich_types() {
        let rs = IzhikevichNeuron::regular_spiking();
        let fs = IzhikevichNeuron::fast_spiking();

        assert!(rs.a < fs.a);  // FS has faster recovery
    }

    #[test]
    fn test_izhikevich_equations() {
        let eqs = IzhikevichNeuron::regular_spiking().to_equations();

        assert_eq!(eqs.differential.len(), 2);
        assert_eq!(eqs.differential[1].variable, "u");
        assert_eq!(eqs.threshold.unwrap().condition, "v >= 30.0");
        assert_eq!(eqs.reset.unwrap().equations, vec!["v = -65", "u += 8"]);
    }

    #[test]
    fn test_unit_conversion() {
        let q = Quantity::new(10.0, Unit::Millisecond);
        assert!((q.to_si() - 0.01).abs() < 1e-15);
        assert_eq!(Unit::Megaohm.to_si_factor(), 1e6);
        assert_eq!(Unit::Picofarad.to_si_factor(), 1e-12);
        assert_eq!(Unit::Dimensionless.to_si_factor(), 1.0);
    }

    #[test]
    fn test_synapse_connectivity() {
        let mut syn = Synapses::new("test", "A", "B", SynapseModel::Delta { weight: 1.0 });
        syn.connect_all_to_all(3, 4, 1.0, 1.0);

        assert_eq!(syn.connections.len(), 12);  // 3 * 4
    }

    #[test]
    fn test_one_to_one_and_empty_random() {
        let mut syn = Synapses::new("s", "A", "A", SynapseModel::Delta { weight: 1.0 });
        syn.connect_one_to_one(5, 0.5, 2.0);
        assert_eq!(syn.connections, (0..5).map(|i| (i, i)).collect::<Vec<_>>());
        assert_eq!(syn.weights, vec![0.5; 5]);
        assert_eq!(syn.delays, vec![2.0; 5]);

        let mut none = Synapses::new("n", "A", "B", SynapseModel::Delta { weight: 1.0 });
        none.connect_random(10, 10, 0.0, 1.0, 1.0);
        assert!(none.connections.is_empty());
        assert!(none.weights.is_empty());
    }

    #[test]
    fn test_neuron_group() {
        let lif = LIFNeuron::default();
        let mut group = NeuronGroup::new("test", 100, lif.to_equations());

        assert_eq!(group.n, 100);
        assert!(group.state.contains_key("v"));

        group.set_initial("v", Array1::from_elem(100, -70.0)).unwrap();
        assert_eq!(group.state["v"][0], -70.0);
    }

    #[test]
    fn test_set_initial_errors() {
        let mut group = NeuronGroup::new("test", 3, LIFNeuron::default().to_equations());

        assert!(group.set_initial("v", Array1::zeros(2)).is_err());
        assert!(group.set_initial("nope", Array1::zeros(3)).is_err());
        // Failed calls leave the state untouched.
        assert_eq!(group.state["v"].len(), 3);
        assert_eq!(group.state["v"][0], 0.0);
    }

    #[test]
    fn test_spike_monitor() {
        let mut monitor = SpikeMonitor::new("test", 10);
        monitor.record_spike(0, 10.0);
        monitor.record_spike(0, 20.0);
        monitor.record_spike(1, 15.0);

        assert_eq!(monitor.counts[0], 2);
        assert_eq!(monitor.counts[1], 1);
        assert_eq!(monitor.spikes.len(), 3);
    }

    #[test]
    fn test_spike_trains_and_rate() {
        let mut monitor = SpikeMonitor::new("test", 10);
        monitor.record_spike(0, 10.0);
        monitor.record_spike(1, 15.0);
        monitor.record_spike(0, 20.0);

        let trains = monitor.spike_trains();
        assert_eq!(trains[&0], vec![10.0, 20.0]);
        assert_eq!(trains[&1], vec![15.0]);

        // 3 spikes / 10 neurons / 1 s = 0.3 Hz
        assert!((monitor.mean_rate(1000.0) - 0.3).abs() < 1e-12);
        assert_eq!(monitor.mean_rate(0.0), 0.0);
    }

    #[test]
    fn test_state_monitor_sampling() {
        let mut mon = StateMonitor::new("G", &["v"], &[0, 2], 0.1);
        let values = Array1::from(vec![1.0, 2.0, 3.0]);

        mon.record("v", 0.0, &values);
        mon.record("v", 0.05, &values); // less than dt after the last sample: skipped
        mon.record("v", 0.1, &values);
        mon.record("w", 0.2, &values); // unknown variable: ignored

        let (times, traces) = &mon.data["v"];
        assert_eq!(times, &vec![0.0, 0.1]);
        assert_eq!(traces[0], vec![1.0, 1.0]);
        assert_eq!(traces[1], vec![3.0, 3.0]);
    }

    #[test]
    fn test_spike_generator_sorted() {
        let mut spikes_in = SpikeGeneratorGroup::new("g", 3);
        spikes_in.add_spikes(&[2, 0, 1], &[5.0, 1.0, 3.0]);

        assert_eq!(spikes_in.spike_times, vec![(0, 1.0), (1, 3.0), (2, 5.0)]);
    }

    #[test]
    fn test_network_run_advances_time() {
        let mut net = Network::new(0.1);
        net.run(1.0).unwrap();

        assert!((net.t - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_parse_equations() {
        let text = r#"
            dv/dt = (v_rest - v) / tau : volt
            dw/dt = a * (v - v_rest) : amp
        "#;

        let eqs = parse_equations(text).unwrap();
        assert_eq!(eqs.differential.len(), 2);
    }

    #[test]
    fn test_parse_algebraic_and_comments() {
        let text = "
            # a comment line
            dv/dt = (I - v) / tau : volt
            I = 2 * x : amp
        ";

        let eqs = parse_equations(text).unwrap();
        assert_eq!(eqs.differential.len(), 1);
        assert_eq!(eqs.differential[0].variable, "v");
        assert_eq!(eqs.differential[0].expression, "(I - v) / tau");
        assert_eq!(eqs.algebraic.len(), 1);
        assert_eq!(eqs.algebraic[0].variable, "I");
        assert_eq!(eqs.algebraic[0].expression, "2 * x");
    }

    #[test]
    fn test_brunel_network() {
        let net = brunel_network(80, 20, 5.0, 2.0, 0.1);

        assert!(net.neuron_groups.contains_key("E"));
        assert!(net.neuron_groups.contains_key("I"));
        assert_eq!(net.neuron_groups["E"].n, 80);
        assert_eq!(net.neuron_groups["I"].n, 20);
    }

    #[test]
    fn test_stdp_rule() {
        let stdp = STDPRule::default();

        assert!(stdp.a_minus > stdp.a_plus);  // Slight LTD dominance
        assert_eq!(stdp.tau_pre, stdp.tau_post);
    }
}