Skip to main content

ncps/wirings/
base.rs

1use rand::prelude::*;
2use super::WiringConfig;
3
4/// Core trait defining connectivity patterns for Neural Circuit Policy networks.
5///
6/// The `Wiring` trait specifies how neurons connect to each other and to external inputs.
7/// It defines the sparse connectivity structure that makes NCPs parameter-efficient and
8/// interpretable compared to fully-connected RNNs.
9///
10/// # Lifecycle
11///
12/// 1. **Create** the wiring with desired structure (e.g., `AutoNCP::new(...)`)
13/// 2. **Build** with input dimension: `wiring.build(input_dim)` - this creates the sensory connections
14/// 3. **Use** with an RNN layer: `CfC::with_wiring(input_size, wiring, device)`
15///
16/// # Key Matrices
17///
18/// - **Adjacency Matrix** `[units × units]`: Internal neuron-to-neuron connections
19///   - Values: +1 (excitatory), -1 (inhibitory), 0 (no connection)
20/// - **Sensory Adjacency Matrix** `[input_dim × units]`: Input-to-neuron connections
21///   - Only available after calling `.build()`
22///
23/// # Example
24///
25/// ```rust
26/// use ncps::wirings::{AutoNCP, Wiring};
27///
28/// let mut wiring = AutoNCP::new(32, 8, 0.5, 42);
29///
30/// // Check state before building
31/// assert!(!wiring.is_built());
32/// assert_eq!(wiring.input_dim(), None);
33///
34/// // Build with input dimension
35/// wiring.build(16);
36///
37/// // Now fully configured
38/// assert!(wiring.is_built());
39/// assert_eq!(wiring.input_dim(), Some(16));
40/// assert_eq!(wiring.units(), 32);
41/// assert_eq!(wiring.output_dim(), Some(8));
42/// ```
43///
44/// # Implementors
45///
46/// - [`super::AutoNCP`]: Automatic NCP configuration (recommended)
47/// - [`super::NCP`]: Manual NCP configuration with full control
48/// - [`FullyConnected`]: Dense connectivity (no sparsity)
49/// - [`super::Random`]: Random sparse connectivity
50pub trait Wiring: Send + Sync {
51    /// Returns the total number of neurons (hidden units) in this wiring.
52    ///
53    /// This is the size of the hidden state tensor: `[batch, units]`.
54    /// For NCP wirings, this equals `inter_neurons + command_neurons + motor_neurons`.
55    fn units(&self) -> usize;
56
57    /// Returns the input dimension (number of input features), or `None` if not yet built.
58    ///
59    /// This is only available after calling [`.build(input_dim)`](Wiring::build).
60    /// The sensory adjacency matrix will have shape `[input_dim, units]`.
61    fn input_dim(&self) -> Option<usize>;
62
63    /// Returns the output dimension (number of motor neurons).
64    ///
65    /// For NCP wirings, this is the number of motor neurons. The RNN output
66    /// will be projected to this size if it differs from `units()`.
67    ///
68    /// Returns `None` if output_dim equals units (no projection needed).
69    fn output_dim(&self) -> Option<usize>;
70
71    /// Returns the number of logical layers in this wiring.
72    ///
73    /// - `FullyConnected`: 1 layer
74    /// - `NCP`/`AutoNCP`: 3 layers (inter, command, motor)
75    fn num_layers(&self) -> usize {
76        1
77    }
78
79    /// Returns the neuron IDs belonging to a specific layer.
80    ///
81    /// For NCP wirings:
82    /// - Layer 0: Inter neurons (feature processing)
83    /// - Layer 1: Command neurons (integration/decision)
84    /// - Layer 2: Motor neurons (output)
85    ///
86    /// # Panics
87    ///
88    /// Panics if `layer_id >= num_layers()`.
89    fn get_neurons_of_layer(&self, layer_id: usize) -> Vec<usize> {
90        if layer_id == 0 {
91            (0..self.units()).collect()
92        } else {
93            vec![]
94        }
95    }
96
97    /// Returns `true` if the wiring has been built (input dimension is set).
98    ///
99    /// A wiring must be built before it can be used with an RNN layer.
100    /// Call [`.build(input_dim)`](Wiring::build) to build the wiring.
101    fn is_built(&self) -> bool {
102        self.input_dim().is_some()
103    }
104
105    /// Builds the wiring by setting the input dimension and creating sensory connections.
106    ///
107    /// **This method must be called before using the wiring with an RNN layer.**
108    ///
109    /// The build process:
110    /// 1. Sets the input dimension
111    /// 2. Creates the sensory adjacency matrix `[input_dim × units]`
112    /// 3. Establishes connections from inputs to the first layer of neurons
113    ///
114    /// # Arguments
115    ///
116    /// * `input_dim` - Number of input features per timestep
117    ///
118    /// # Panics
119    ///
120    /// Panics if called twice with different input dimensions:
121    /// ```should_panic
122    /// use ncps::wirings::{AutoNCP, Wiring};
123    ///
124    /// let mut wiring = AutoNCP::new(32, 8, 0.5, 42);
125    /// wiring.build(16);
126    /// wiring.build(32);  // Panics! Different input_dim
127    /// ```
128    ///
129    /// Calling with the same dimension twice is safe (no-op).
130    fn build(&mut self, input_dim: usize);
131
132    /// Returns the type of a neuron by its ID.
133    ///
134    /// # Returns
135    ///
136    /// - `"motor"`: Output neurons (IDs 0..motor_neurons)
137    /// - `"command"`: Integration neurons (NCP only)
138    /// - `"inter"`: Feature processing neurons
139    ///
140    /// # Example
141    ///
142    /// ```rust
143    /// use ncps::wirings::{AutoNCP, Wiring};
144    ///
145    /// let mut wiring = AutoNCP::new(32, 8, 0.5, 42);
146    /// wiring.build(16);
147    ///
148    /// // First 8 neurons are motor neurons (output)
149    /// assert_eq!(wiring.get_type_of_neuron(0), "motor");
150    /// assert_eq!(wiring.get_type_of_neuron(7), "motor");
151    ///
152    /// // Higher IDs are command or inter neurons
153    /// let neuron_type = wiring.get_type_of_neuron(15);
154    /// assert!(neuron_type == "command" || neuron_type == "inter");
155    /// ```
156    fn get_type_of_neuron(&self, neuron_id: usize) -> &'static str {
157        let output_dim = self.output_dim().unwrap_or(0);
158        if neuron_id < output_dim {
159            "motor"
160        } else {
161            "inter"
162        }
163    }
164
165    /// Returns the internal adjacency matrix representing neuron-to-neuron synapses.
166    ///
167    /// Shape: `[units × units]`
168    ///
169    /// Values:
170    /// - `+1`: Excitatory synapse (increases activation)
171    /// - `-1`: Inhibitory synapse (decreases activation)
172    /// - `0`: No connection
173    ///
174    /// The matrix is indexed as `[source, destination]`, meaning `matrix[[i, j]]`
175    /// represents a synapse from neuron `i` to neuron `j`.
176    fn adjacency_matrix(&self) -> &ndarray::Array2<i32>;
177
178    /// Returns the sensory adjacency matrix (input-to-neuron connections).
179    ///
180    /// Shape: `[input_dim × units]`
181    ///
182    /// Returns `None` before [`.build()`](Wiring::build) is called.
183    ///
184    /// The matrix is indexed as `[input_feature, neuron]`, meaning `matrix[[i, j]]`
185    /// represents a synapse from input feature `i` to neuron `j`.
186    fn sensory_adjacency_matrix(&self) -> Option<&ndarray::Array2<i32>>;
187
188    /// Returns the reversal potential initializer (same as adjacency matrix).
189    ///
190    /// Used internally by LTC cells for biologically-plausible dynamics.
191    fn erev_initializer(&self) -> ndarray::Array2<i32> {
192        self.adjacency_matrix().clone()
193    }
194
195    /// Returns the sensory reversal potential initializer.
196    ///
197    /// Used internally by LTC cells for biologically-plausible dynamics.
198    fn sensory_erev_initializer(&self) -> Option<ndarray::Array2<i32>> {
199        self.sensory_adjacency_matrix().map(|m| m.clone())
200    }
201
202    /// Adds or modifies an internal synapse between two neurons.
203    ///
204    /// # Arguments
205    ///
206    /// * `src` - Source neuron ID (0..units)
207    /// * `dest` - Destination neuron ID (0..units)
208    /// * `polarity` - Synapse type: +1 (excitatory) or -1 (inhibitory)
209    ///
210    /// # Panics
211    ///
212    /// - Panics if `src >= units` or `dest >= units`
213    /// - Panics if `polarity` is not +1 or -1
214    fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32);
215
216    /// Adds or modifies a sensory synapse from an input feature to a neuron.
217    ///
218    /// # Arguments
219    ///
220    /// * `src` - Input feature index (0..input_dim)
221    /// * `dest` - Destination neuron ID (0..units)
222    /// * `polarity` - Synapse type: +1 (excitatory) or -1 (inhibitory)
223    ///
224    /// # Panics
225    ///
226    /// - Panics if wiring is not built (call `.build()` first)
227    /// - Panics if `src >= input_dim` or `dest >= units`
228    /// - Panics if `polarity` is not +1 or -1
229    fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32);
230
231    /// Returns the total number of internal synapses (non-zero entries in adjacency matrix).
232    ///
233    /// This is a measure of model complexity/sparsity. Lower values indicate sparser networks.
234    fn synapse_count(&self) -> usize {
235        self.adjacency_matrix().mapv(|x| x.abs() as usize).sum()
236    }
237
238    /// Returns the total number of sensory synapses (input-to-neuron connections).
239    ///
240    /// Returns 0 if the wiring hasn't been built yet.
241    fn sensory_synapse_count(&self) -> usize {
242        self.sensory_adjacency_matrix()
243            .map(|m| m.mapv(|x| x.abs() as usize).sum())
244            .unwrap_or(0)
245    }
246
247    /// Returns `true` if this wiring requires external input (has sensory connections).
248    ///
249    /// Always returns `true` after `.build()` is called.
250    fn input_required(&self) -> bool {
251        self.sensory_adjacency_matrix().is_some()
252    }
253
254    /// Creates a serializable configuration for this wiring.
255    ///
256    /// Used for saving/loading models. See [`WiringConfig`] for details.
257    fn get_config(&self) -> WiringConfig;
258}
259
260/// Fully connected (dense) wiring structure.
261///
262/// Every neuron connects to every other neuron (and optionally to itself).
263/// This provides a baseline comparison for NCP's sparse connectivity.
264///
265/// # When to Use
266///
267/// - **Baseline comparison**: Compare NCP performance against dense networks
268/// - **Maximum expressiveness**: When sparsity is not a concern
269/// - **Debugging**: Simpler structure for testing
270///
271/// # Sparsity
272///
273/// `FullyConnected` has **no sparsity** - the adjacency matrix is fully populated.
274/// For a network with `N` units, this means `N²` internal synapses (or `N²-N` without self-connections).
275///
276/// # Example
277///
278/// ```rust
279/// use ncps::wirings::{FullyConnected, Wiring};
280///
281/// // Create a fully-connected wiring with 32 neurons, 8 outputs
282/// let mut wiring = FullyConnected::new(
283///     32,        // units (total neurons)
284///     Some(8),   // output_dim (motor neurons)
285///     42,        // seed for reproducibility
286///     true,      // self_connections allowed
287/// );
288///
289/// // Build with input dimension
290/// wiring.build(16);
291///
292/// // Check connectivity
293/// println!("Total synapses: {}", wiring.synapse_count());  // 32*32 = 1024
294/// println!("Sensory synapses: {}", wiring.sensory_synapse_count());  // 16*32 = 512
295/// ```
296///
297/// # Comparison with NCP
298///
299/// | Aspect | FullyConnected | NCP |
300/// |--------|----------------|-----|
301/// | Synapses | O(N²) | O(N) to O(N log N) |
302/// | Interpretability | Low | High |
303/// | Parameters | More | Fewer |
304/// | Structure | Single layer | 4-layer biological |
305#[derive(Clone, Debug)]
306pub struct FullyConnected {
307    units: usize,
308    output_dim: usize,
309    adjacency_matrix: ndarray::Array2<i32>,
310    sensory_adjacency_matrix: Option<ndarray::Array2<i32>>,
311    input_dim: Option<usize>,
312    self_connections: bool,
313    erev_init_seed: u64,
314}
315
316impl FullyConnected {
317    pub fn new(
318        units: usize,
319        output_dim: Option<usize>,
320        erev_init_seed: u64,
321        self_connections: bool,
322    ) -> Self {
323        let output_dim = output_dim.unwrap_or(units);
324        let mut adjacency_matrix = ndarray::Array2::zeros((units, units));
325        let mut rng = StdRng::seed_from_u64(erev_init_seed);
326
327        // Initialize synapses
328        for src in 0..units {
329            for dest in 0..units {
330                if src == dest && !self_connections {
331                    continue;
332                }
333                // 2/3 chance of excitatory, 1/3 inhibitory
334                let polarity: i32 = if rand::random::<f64>() < 0.33 { -1 } else { 1 };
335                adjacency_matrix[[src, dest]] = polarity;
336            }
337        }
338
339        Self {
340            units,
341            output_dim,
342            adjacency_matrix,
343            sensory_adjacency_matrix: None,
344            input_dim: None,
345            self_connections,
346            erev_init_seed,
347        }
348    }
349
350    /// Get configuration for serialization
351    pub fn get_full_config(&self) -> WiringConfig {
352        WiringConfig {
353            units: self.units,
354            adjacency_matrix: Some(
355                self.adjacency_matrix
356                    .outer_iter()
357                    .map(|v| v.to_vec())
358                    .collect(),
359            ),
360            sensory_adjacency_matrix: self
361                .sensory_adjacency_matrix
362                .as_ref()
363                .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
364            input_dim: self.input_dim,
365            output_dim: Some(self.output_dim),
366            // FullyConnected-specific fields
367            erev_init_seed: Some(self.erev_init_seed),
368            self_connections: Some(self.self_connections),
369            // Other fields not used by FullyConnected
370            num_inter_neurons: None,
371            num_command_neurons: None,
372            num_motor_neurons: None,
373            sensory_fanout: None,
374            inter_fanout: None,
375            recurrent_command_synapses: None,
376            motor_fanin: None,
377            seed: None,
378            sparsity_level: None,
379            random_seed: None,
380        }
381    }
382
383    pub fn from_config(config: WiringConfig) -> Self {
384        let units = config.units;
385        let adjacency_matrix = if let Some(matrix) = config.adjacency_matrix {
386            ndarray::Array2::from_shape_vec((units, units), matrix.into_iter().flatten().collect())
387                .expect("Invalid adjacency matrix shape")
388        } else {
389            ndarray::Array2::zeros((units, units))
390        };
391
392        let sensory_adjacency_matrix = config.sensory_adjacency_matrix.map(|matrix| {
393            let input_dim = config
394                .input_dim
395                .expect("Input dimension required when sensory matrix exists");
396            ndarray::Array2::from_shape_vec(
397                (input_dim, units),
398                matrix.into_iter().flatten().collect(),
399            )
400            .expect("Invalid sensory adjacency matrix shape")
401        });
402
403        Self {
404            units,
405            output_dim: config.output_dim.unwrap_or(units),
406            adjacency_matrix,
407            sensory_adjacency_matrix,
408            input_dim: config.input_dim,
409            self_connections: true,
410            erev_init_seed: 1111,
411        }
412    }
413}
414
415impl Wiring for FullyConnected {
416    fn units(&self) -> usize {
417        self.units
418    }
419
420    fn input_dim(&self) -> Option<usize> {
421        self.input_dim
422    }
423
424    fn output_dim(&self) -> Option<usize> {
425        Some(self.output_dim)
426    }
427
428    fn build(&mut self, input_dim: usize) {
429        if let Some(existing) = self.input_dim {
430            if existing != input_dim {
431                panic!(
432                    "Conflicting input dimensions: expected {}, got {}",
433                    existing, input_dim
434                );
435            }
436            return;
437        }
438
439        self.input_dim = Some(input_dim);
440        let mut sensory_matrix = ndarray::Array2::zeros((input_dim, self.units));
441        let mut rng = StdRng::seed_from_u64(self.erev_init_seed);
442
443        for src in 0..input_dim {
444            for dest in 0..self.units {
445                let polarity: i32 = if rng.gen::<f64>() < 0.33 { -1 } else { 1 };
446                sensory_matrix[[src, dest]] = polarity;
447            }
448        }
449        self.sensory_adjacency_matrix = Some(sensory_matrix);
450    }
451
452    fn adjacency_matrix(&self) -> &ndarray::Array2<i32> {
453        &self.adjacency_matrix
454    }
455
456    fn sensory_adjacency_matrix(&self) -> Option<&ndarray::Array2<i32>> {
457        self.sensory_adjacency_matrix.as_ref()
458    }
459
460    fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
461        if src >= self.units || dest >= self.units {
462            panic!(
463                "Invalid synapse: src={}, dest={}, units={}",
464                src, dest, self.units
465            );
466        }
467        if ![-1, 1].contains(&polarity) {
468            panic!("Polarity must be -1 or 1, got {}", polarity);
469        }
470        self.adjacency_matrix[[src, dest]] = polarity;
471    }
472
473    fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
474        let input_dim = self
475            .input_dim
476            .expect("Must build wiring before adding sensory synapses");
477        if src >= input_dim || dest >= self.units {
478            panic!(
479                "Invalid sensory synapse: src={}, dest={}, input_dim={}, units={}",
480                src, dest, input_dim, self.units
481            );
482        }
483        if ![-1, 1].contains(&polarity) {
484            panic!("Polarity must be -1 or 1, got {}", polarity);
485        }
486        self.sensory_adjacency_matrix.as_mut().unwrap()[[src, dest]] = polarity;
487    }
488
489    fn get_config(&self) -> WiringConfig {
490        WiringConfig {
491            units: self.units,
492            adjacency_matrix: Some(
493                self.adjacency_matrix
494                    .outer_iter()
495                    .map(|v| v.to_vec())
496                    .collect(),
497            ),
498            sensory_adjacency_matrix: self
499                .sensory_adjacency_matrix
500                .as_ref()
501                .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
502            input_dim: self.input_dim,
503            output_dim: Some(self.output_dim),
504            // FullyConnected-specific fields
505            erev_init_seed: Some(self.erev_init_seed),
506            self_connections: Some(self.self_connections),
507            // Other fields not used by FullyConnected
508            num_inter_neurons: None,
509            num_command_neurons: None,
510            num_motor_neurons: None,
511            sensory_fanout: None,
512            inter_fanout: None,
513            recurrent_command_synapses: None,
514            motor_fanin: None,
515            seed: None,
516            sparsity_level: None,
517            random_seed: None,
518        }
519    }
520}