Skip to main content

ncps/wirings/
ncp.rs

1use super::base::Wiring;
2use super::WiringConfig;
3use ndarray::Array2;
4use rand::prelude::*;
5
/// Neural Circuit Policy (NCP) wiring with biologically-inspired 4-layer architecture.
///
/// NCPs implement sparse, structured connectivity patterns inspired by the nervous system
/// of *C. elegans*. This architecture provides:
///
/// - **Parameter efficiency**: Fewer synapses than fully-connected networks
/// - **Interpretability**: Clear information flow through defined layers
/// - **Biological plausibility**: Excitatory/inhibitory synapse types
///
/// # Architecture
///
/// NCPs organize information flow into 4 functional stages. The sensory stage is the
/// external input features themselves; only the inter, command and motor stages are
/// neurons of this wiring, which is why `num_layers()` reports 3:
///
/// ```text
/// Sensory Inputs ──► Inter Neurons ──► Command Neurons ──► Motor Neurons
///    (input)         (processing)      (integration)       (output)
///                                           │
///                                           └──► Recurrent connections
/// ```
///
/// ## Layer Descriptions
///
/// | Layer | Role | Connectivity |
/// |-------|------|--------------|
/// | **Sensory** | External inputs | → Inter (via `sensory_fanout`) |
/// | **Inter** | Feature extraction | → Command (via `inter_fanout`) |
/// | **Command** | Decision/integration | → Motor + self (recurrent) |
/// | **Motor** | Output neurons | ← Command (via `motor_fanin`) |
///
/// # Connectivity Parameters
///
/// - `sensory_fanout`: How many inter neurons each input connects to
/// - `inter_fanout`: How many command neurons each inter neuron connects to
/// - `recurrent_command_synapses`: Number of random command→command draws. Source and
///   destination are drawn independently, so draws may repeat or form self-connections.
///   The number of distinct recurrent synapses can therefore be lower.
/// - `motor_fanin`: How many command neurons connect to each motor neuron
///
/// Inter and command neurons that random fan-out selection missed are afterwards given
/// connections from a few randomly chosen neurons of the preceding stage. Command
/// neurons that feed no motor neuron are likewise given a few random motor targets.
/// Every synapse gets a polarity of `+1` or `-1`, chosen uniformly at random from the
/// seeded RNG.
///
/// # Example
///
/// ```rust
/// use ncps::wirings::{NCP, Wiring};
///
/// // Create NCP with explicit layer sizes
/// let mut wiring = NCP::new(
///     10,  // inter_neurons: feature processing
///     8,   // command_neurons: integration layer
///     4,   // motor_neurons: output size
///     4,   // sensory_fanout: each input → 4 inter neurons
///     4,   // inter_fanout: each inter → 4 command neurons
///     6,   // recurrent_command_synapses
///     4,   // motor_fanin: each motor ← 4 command neurons
///     42,  // seed for reproducibility
/// );
///
/// // Must build before use
/// wiring.build(16);  // 16 input features
///
/// // Total neurons = inter + command + motor = 22
/// assert_eq!(wiring.units(), 22);
/// assert_eq!(wiring.output_dim(), Some(4));
/// ```
///
/// # Neuron ID Layout
///
/// Neurons are assigned IDs in this order:
/// ```text
/// [0..motor) [motor..motor+command) [motor+command..units)
///   Motor        Command                Inter
/// ```
///
/// # When to Use
///
/// Use `NCP` directly when you need fine-grained control over:
/// - Exact layer sizes
/// - Connectivity density (fanout/fanin parameters)
/// - Recurrent connection count
///
/// For automatic parameter selection, use [`AutoNCP`] instead.
///
/// # Panics
///
/// The constructor panics if constraints are violated:
/// - `motor_fanin > command_neurons`
/// - `sensory_fanout > inter_neurons`
/// - `inter_fanout > command_neurons`
///
/// `build` panics if called again with an input dimension different from the one the
/// wiring was first built with.
#[derive(Clone, Debug)]
pub struct NCP {
    /// Total neuron count: inter + command + motor.
    units: usize,
    /// `units × units` synapse polarities indexed `[src, dest]`; 0 means no synapse.
    adjacency_matrix: Array2<i32>,
    /// `input_dim × units` input→neuron polarities; `None` until `build` runs.
    sensory_adjacency_matrix: Option<Array2<i32>>,
    /// Number of input features; `None` until `build` runs.
    input_dim: Option<usize>,
    num_inter_neurons: usize,
    num_command_neurons: usize,
    num_motor_neurons: usize,
    sensory_fanout: usize,
    inter_fanout: usize,
    recurrent_command_synapses: usize,
    motor_fanin: usize,
    /// Neuron IDs of each layer (see "Neuron ID Layout" above).
    motor_neurons: Vec<usize>,
    command_neurons: Vec<usize>,
    inter_neurons: Vec<usize>,
    /// Input feature indices `0..input_dim`; filled in during `build`.
    sensory_neurons: Vec<usize>,
    /// Seeded in `new`; every random connectivity choice in `build` draws from it.
    rng: StdRng,
}
109
110impl NCP {
111    pub fn new(
112        inter_neurons: usize,
113        command_neurons: usize,
114        motor_neurons: usize,
115        sensory_fanout: usize,
116        inter_fanout: usize,
117        recurrent_command_synapses: usize,
118        motor_fanin: usize,
119        seed: u64,
120    ) -> Self {
121        let units = inter_neurons + command_neurons + motor_neurons;
122
123        // Validate parameters
124        if motor_fanin > command_neurons {
125            panic!(
126                "Motor fanin {} exceeds number of command neurons {}",
127                motor_fanin, command_neurons
128            );
129        }
130        if sensory_fanout > inter_neurons {
131            panic!(
132                "Sensory fanout {} exceeds number of inter neurons {}",
133                sensory_fanout, inter_neurons
134            );
135        }
136        if inter_fanout > command_neurons {
137            panic!(
138                "Inter fanout {} exceeds number of command neurons {}",
139                inter_fanout, command_neurons
140            );
141        }
142
143        // Neuron IDs: [0..motor ... command ... inter]
144        let motor_neuron_ids: Vec<usize> = (0..motor_neurons).collect();
145        let command_neuron_ids: Vec<usize> =
146            (motor_neurons..motor_neurons + command_neurons).collect();
147        let inter_neuron_ids: Vec<usize> = (motor_neurons + command_neurons..units).collect();
148
149        let adjacency_matrix = Array2::zeros((units, units));
150        let rng = StdRng::seed_from_u64(seed);
151
152        Self {
153            units,
154            adjacency_matrix,
155            sensory_adjacency_matrix: None,
156            input_dim: None,
157            num_inter_neurons: inter_neurons,
158            num_command_neurons: command_neurons,
159            num_motor_neurons: motor_neurons,
160            sensory_fanout,
161            inter_fanout,
162            recurrent_command_synapses,
163            motor_fanin,
164            motor_neurons: motor_neuron_ids,
165            command_neurons: command_neuron_ids,
166            inter_neurons: inter_neuron_ids,
167            sensory_neurons: vec![],
168            rng,
169        }
170    }
171
172    fn build_sensory_to_inter_layer(&mut self) {
173        let input_dim = self.input_dim.unwrap();
174        self.sensory_neurons = (0..input_dim).collect();
175
176        // Clone the neuron list to avoid borrow issues
177        let inter_neurons = self.inter_neurons.clone();
178        let sensory_neurons = self.sensory_neurons.clone();
179        let mut unreachable_inter: Vec<usize> = inter_neurons.clone();
180
181        // Connect each sensory neuron to exactly sensory_fanout inter neurons
182        for &src in &sensory_neurons {
183            let selected: Vec<_> = inter_neurons
184                .choose_multiple(&mut self.rng, self.sensory_fanout)
185                .cloned()
186                .collect();
187
188            for &dest in &selected {
189                if let Some(pos) = unreachable_inter.iter().position(|&x| x == dest) {
190                    unreachable_inter.remove(pos);
191                }
192                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
193                self.add_sensory_synapse(src, dest, polarity);
194            }
195        }
196
197        // Connect any unreachable inter neurons
198        let mean_inter_fanin = (input_dim * self.sensory_fanout / self.num_inter_neurons)
199            .max(1)
200            .min(input_dim);
201
202        for &dest in &unreachable_inter {
203            let selected: Vec<_> = sensory_neurons
204                .choose_multiple(&mut self.rng, mean_inter_fanin)
205                .cloned()
206                .collect();
207
208            for &src in &selected {
209                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
210                self.add_sensory_synapse(src, dest, polarity);
211            }
212        }
213    }
214
215    fn build_inter_to_command_layer(&mut self) {
216        // Clone the neuron lists to avoid borrow issues
217        let inter_neurons = self.inter_neurons.clone();
218        let command_neurons = self.command_neurons.clone();
219        let mut unreachable_command: Vec<usize> = command_neurons.clone();
220
221        // Connect inter neurons to command neurons
222        for &src in &inter_neurons {
223            let selected: Vec<_> = command_neurons
224                .choose_multiple(&mut self.rng, self.inter_fanout)
225                .cloned()
226                .collect();
227
228            for &dest in &selected {
229                if let Some(pos) = unreachable_command.iter().position(|&x| x == dest) {
230                    unreachable_command.remove(pos);
231                }
232                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
233                self.add_synapse(src, dest, polarity);
234            }
235        }
236
237        // Connect any unreachable command neurons
238        let mean_command_fanin = (self.num_inter_neurons * self.inter_fanout
239            / self.num_command_neurons)
240            .max(1)
241            .min(self.num_inter_neurons);
242
243        for &dest in &unreachable_command {
244            let selected: Vec<_> = inter_neurons
245                .choose_multiple(&mut self.rng, mean_command_fanin)
246                .cloned()
247                .collect();
248
249            for &src in &selected {
250                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
251                self.add_synapse(src, dest, polarity);
252            }
253        }
254    }
255
256    fn build_recurrent_command_layer(&mut self) {
257        for _ in 0..self.recurrent_command_synapses {
258            let src = *self.command_neurons.choose(&mut self.rng).unwrap();
259            let dest = *self.command_neurons.choose(&mut self.rng).unwrap();
260            let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
261            self.add_synapse(src, dest, polarity);
262        }
263    }
264
265    fn build_command_to_motor_layer(&mut self) {
266        // Clone the neuron lists to avoid borrow issues
267        let motor_neurons = self.motor_neurons.clone();
268        let command_neurons = self.command_neurons.clone();
269        let mut unreachable_command: Vec<usize> = command_neurons.clone();
270
271        // Connect command neurons to motor neurons
272        for &dest in &motor_neurons {
273            let selected: Vec<_> = command_neurons
274                .choose_multiple(&mut self.rng, self.motor_fanin)
275                .cloned()
276                .collect();
277
278            for &src in &selected {
279                if let Some(pos) = unreachable_command.iter().position(|&x| x == src) {
280                    unreachable_command.remove(pos);
281                }
282                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
283                self.add_synapse(src, dest, polarity);
284            }
285        }
286
287        // Connect any unreachable command neurons
288        let mean_command_fanout = (self.num_motor_neurons * self.motor_fanin
289            / self.num_command_neurons)
290            .max(1)
291            .min(self.num_motor_neurons);
292
293        for &src in &unreachable_command {
294            let selected: Vec<_> = motor_neurons
295                .choose_multiple(&mut self.rng, mean_command_fanout)
296                .cloned()
297                .collect();
298
299            for &dest in &selected {
300                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
301                self.add_synapse(src, dest, polarity);
302            }
303        }
304    }
305
306    pub fn from_config(config: WiringConfig) -> Self {
307        // Parse config to reconstruct NCP
308        let units = config.units;
309        let adjacency_matrix = if let Some(matrix) = config.adjacency_matrix {
310            Array2::from_shape_vec((units, units), matrix.into_iter().flatten().collect())
311                .expect("Invalid adjacency matrix shape")
312        } else {
313            Array2::zeros((units, units))
314        };
315
316        let sensory_adjacency_matrix = config.sensory_adjacency_matrix.map(|matrix| {
317            let input_dim = config
318                .input_dim
319                .expect("Input dimension required when sensory matrix exists");
320            Array2::from_shape_vec((input_dim, units), matrix.into_iter().flatten().collect())
321                .expect("Invalid sensory adjacency matrix shape")
322        });
323
324        // This would need additional info stored in config to reconstruct properly
325        // For now, create a basic NCP structure
326        let output_dim = config.output_dim.unwrap_or(1);
327        let inter_and_command = units - output_dim;
328        let command_neurons = (inter_and_command as f64 * 0.4).ceil() as usize;
329        let inter_neurons = inter_and_command - command_neurons;
330
331        NCP::new(
332            inter_neurons,
333            command_neurons,
334            output_dim,
335            6,     // Default sensory_fanout
336            6,     // Default inter_fanout
337            4,     // Default recurrent_command_synapses
338            6,     // Default motor_fanin
339            22222, // Default seed
340        )
341    }
342}
343
344impl Wiring for NCP {
345    fn units(&self) -> usize {
346        self.units
347    }
348
349    fn input_dim(&self) -> Option<usize> {
350        self.input_dim
351    }
352
353    fn output_dim(&self) -> Option<usize> {
354        Some(self.num_motor_neurons)
355    }
356
357    fn num_layers(&self) -> usize {
358        3
359    }
360
361    fn get_neurons_of_layer(&self, layer_id: usize) -> Vec<usize> {
362        match layer_id {
363            0 => self.inter_neurons.clone(),
364            1 => self.command_neurons.clone(),
365            2 => self.motor_neurons.clone(),
366            _ => panic!("Unknown layer {}", layer_id),
367        }
368    }
369
370    fn get_type_of_neuron(&self, neuron_id: usize) -> &'static str {
371        if neuron_id < self.num_motor_neurons {
372            "motor"
373        } else if neuron_id < self.num_motor_neurons + self.num_command_neurons {
374            "command"
375        } else {
376            "inter"
377        }
378    }
379
380    fn build(&mut self, input_dim: usize) {
381        if let Some(existing) = self.input_dim {
382            if existing != input_dim {
383                panic!(
384                    "Conflicting input dimensions: expected {}, got {}",
385                    existing, input_dim
386                );
387            }
388            return;
389        }
390
391        self.input_dim = Some(input_dim);
392        self.sensory_adjacency_matrix = Some(Array2::zeros((input_dim, self.units)));
393
394        self.build_sensory_to_inter_layer();
395        self.build_inter_to_command_layer();
396        self.build_recurrent_command_layer();
397        self.build_command_to_motor_layer();
398    }
399
400    fn adjacency_matrix(&self) -> &Array2<i32> {
401        &self.adjacency_matrix
402    }
403
404    fn sensory_adjacency_matrix(&self) -> Option<&Array2<i32>> {
405        self.sensory_adjacency_matrix.as_ref()
406    }
407
408    fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
409        if src >= self.units || dest >= self.units {
410            panic!(
411                "Invalid synapse: src={}, dest={}, units={}",
412                src, dest, self.units
413            );
414        }
415        if ![-1, 1].contains(&polarity) {
416            panic!("Polarity must be -1 or 1, got {}", polarity);
417        }
418        self.adjacency_matrix[[src, dest]] = polarity;
419    }
420
421    fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
422        let input_dim = self
423            .input_dim
424            .expect("Must build wiring before adding sensory synapses");
425        if src >= input_dim || dest >= self.units {
426            panic!(
427                "Invalid sensory synapse: src={}, dest={}, input_dim={}, units={}",
428                src, dest, input_dim, self.units
429            );
430        }
431        if ![-1, 1].contains(&polarity) {
432            panic!("Polarity must be -1 or 1, got {}", polarity);
433        }
434        self.sensory_adjacency_matrix.as_mut().unwrap()[[src, dest]] = polarity;
435    }
436
437    fn get_config(&self) -> WiringConfig {
438        WiringConfig {
439            units: self.units,
440            adjacency_matrix: Some(
441                self.adjacency_matrix
442                    .outer_iter()
443                    .map(|v| v.to_vec())
444                    .collect(),
445            ),
446            sensory_adjacency_matrix: self
447                .sensory_adjacency_matrix
448                .as_ref()
449                .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
450            input_dim: self.input_dim,
451            output_dim: Some(self.num_motor_neurons),
452            // NCP-specific fields
453            num_inter_neurons: Some(self.num_inter_neurons),
454            num_command_neurons: Some(self.num_command_neurons),
455            num_motor_neurons: Some(self.num_motor_neurons),
456            sensory_fanout: Some(self.sensory_fanout),
457            inter_fanout: Some(self.inter_fanout),
458            recurrent_command_synapses: Some(self.recurrent_command_synapses),
459            motor_fanin: Some(self.motor_fanin),
460            seed: None, // NCP uses rng internally
461            // Other fields not used by NCP
462            erev_init_seed: None,
463            self_connections: None,
464            sparsity_level: None,
465            random_seed: None,
466        }
467    }
468}
469
/// Automatic NCP configuration with simplified parameters.
///
/// `AutoNCP` is the **recommended way** to create NCP wirings. It automatically
/// calculates layer sizes and connectivity based on just a few high-level parameters.
///
/// # Simplified Interface
///
/// Instead of the 8 arguments of [`NCP::new`], you only need 4:
///
/// | Parameter | Description |
/// |-----------|-------------|
/// | `units` | Total number of neurons (hidden state size) |
/// | `output_size` | Number of motor neurons (output dimension) |
/// | `sparsity_level` | Fraction of connections to remove (0.0 - 0.9) |
/// | `seed` | Random seed for reproducibility |
///
/// # How Auto-Configuration Works
///
/// Given your parameters, AutoNCP:
///
/// 1. **Allocates neurons**: `command_neurons = ceil(0.4 × (units - output_size))` (at
///    least 1) and the remaining non-motor neurons become inter neurons, i.e. roughly
///    a 60/40 inter/command split
/// 2. **Sets connectivity**: Based on `density = 1.0 - sparsity_level`. Each value is
///    rounded up and is at least 1:
///    - `sensory_fanout = inter_neurons × density`
///    - `inter_fanout = command_neurons × density`
///    - `motor_fanin = command_neurons × density`
///    - `recurrent_command_synapses = command_neurons × density × 2`
///
/// # Example
///
/// ```rust
/// use ncps::wirings::{AutoNCP, Wiring};
///
/// // Create with automatic configuration
/// let mut wiring = AutoNCP::new(
///     32,    // units: total neurons
///     8,     // output_size: motor neurons
///     0.5,   // sparsity_level: 50% connections removed
///     42,    // seed
/// );
///
/// wiring.build(16);  // 16 input features
///
/// // Check auto-calculated structure
/// assert_eq!(wiring.units(), 32);
/// assert_eq!(wiring.output_dim(), Some(8));
/// assert_eq!(wiring.num_layers(), 3);  // inter, command, motor
/// ```
///
/// # Sparsity Level Guide
///
/// | Sparsity | Effect | Use Case |
/// |----------|--------|----------|
/// | 0.0 | Dense connections | Maximum expressiveness |
/// | 0.3-0.5 | Moderate sparsity | **Recommended starting point** |
/// | 0.7-0.9 | Very sparse | Edge deployment, interpretability |
///
/// # Constraints
///
/// - `output_size < units - 2`, which leaves at least 3 neurons for the inter and
///   command layers
/// - `sparsity_level` must be in `[0.0, 0.9]`
///
/// # Panics
///
/// ```should_panic
/// use ncps::wirings::AutoNCP;
///
/// // Panics: output_size too large
/// let wiring = AutoNCP::new(10, 9, 0.5, 42);
/// ```
///
/// ```should_panic
/// use ncps::wirings::AutoNCP;
///
/// // Panics: sparsity_level out of range
/// let wiring = AutoNCP::new(32, 8, 0.95, 42);
/// ```
#[derive(Clone, Debug)]
pub struct AutoNCP {
    /// The wrapped wiring; most `Wiring` calls are delegated to it.
    ncp: NCP,
    /// Number of motor neurons, reported as the output dimension.
    output_size: usize,
    /// Kept so `get_config` can report the original parameters.
    sparsity_level: f64,
    /// Kept so `get_config` can report the original parameters.
    seed: u64,
}
553
554impl AutoNCP {
555    pub fn new(units: usize, output_size: usize, sparsity_level: f64, seed: u64) -> Self {
556        if output_size >= units - 2 {
557            panic!(
558                "Output size {} must be less than units-2 ({})",
559                output_size,
560                units - 2
561            );
562        }
563        if sparsity_level < 0.0 || sparsity_level > 0.9 {
564            panic!(
565                "Sparsity level must be between 0.0 and 0.9, got {}",
566                sparsity_level
567            );
568        }
569
570        let density_level = 1.0 - sparsity_level;
571        let inter_and_command_neurons = units - output_size;
572        let command_neurons = ((inter_and_command_neurons as f64 * 0.4).ceil() as usize).max(1);
573        let inter_neurons = inter_and_command_neurons - command_neurons;
574
575        let sensory_fanout = ((inter_neurons as f64 * density_level).ceil() as usize).max(1);
576        let inter_fanout = ((command_neurons as f64 * density_level).ceil() as usize).max(1);
577        let recurrent_command_synapses =
578            ((command_neurons as f64 * density_level * 2.0).ceil() as usize).max(1);
579        let motor_fanin = ((command_neurons as f64 * density_level).ceil() as usize).max(1);
580
581        let ncp = NCP::new(
582            inter_neurons,
583            command_neurons,
584            output_size,
585            sensory_fanout,
586            inter_fanout,
587            recurrent_command_synapses,
588            motor_fanin,
589            seed,
590        );
591
592        Self {
593            ncp,
594            output_size,
595            sparsity_level,
596            seed,
597        }
598    }
599}
600
601impl Wiring for AutoNCP {
602    fn units(&self) -> usize {
603        self.ncp.units()
604    }
605
606    fn input_dim(&self) -> Option<usize> {
607        self.ncp.input_dim()
608    }
609
610    fn output_dim(&self) -> Option<usize> {
611        Some(self.output_size)
612    }
613
614    fn num_layers(&self) -> usize {
615        self.ncp.num_layers()
616    }
617
618    fn get_neurons_of_layer(&self, layer_id: usize) -> Vec<usize> {
619        self.ncp.get_neurons_of_layer(layer_id)
620    }
621
622    fn get_type_of_neuron(&self, neuron_id: usize) -> &'static str {
623        self.ncp.get_type_of_neuron(neuron_id)
624    }
625
626    fn build(&mut self, input_dim: usize) {
627        self.ncp.build(input_dim)
628    }
629
630    fn is_built(&self) -> bool {
631        self.ncp.is_built()
632    }
633
634    fn adjacency_matrix(&self) -> &Array2<i32> {
635        self.ncp.adjacency_matrix()
636    }
637
638    fn sensory_adjacency_matrix(&self) -> Option<&Array2<i32>> {
639        self.ncp.sensory_adjacency_matrix()
640    }
641
642    fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
643        self.ncp.add_synapse(src, dest, polarity)
644    }
645
646    fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
647        self.ncp.add_sensory_synapse(src, dest, polarity)
648    }
649
650    fn get_config(&self) -> WiringConfig {
651        // Get the underlying NCP config and add AutoNCP-specific fields
652        let mut config = self.ncp.get_config();
653        config.output_dim = Some(self.output_size);
654        config.sparsity_level = Some(self.sparsity_level);
655        config.seed = Some(self.seed);
656        config
657    }
658}