ncps/wirings/base.rs
1use rand::prelude::*;
2use super::WiringConfig;
3
4/// Core trait defining connectivity patterns for Neural Circuit Policy networks.
5///
6/// The `Wiring` trait specifies how neurons connect to each other and to external inputs.
7/// It defines the sparse connectivity structure that makes NCPs parameter-efficient and
8/// interpretable compared to fully-connected RNNs.
9///
10/// # Lifecycle
11///
12/// 1. **Create** the wiring with desired structure (e.g., `AutoNCP::new(...)`)
13/// 2. **Build** with input dimension: `wiring.build(input_dim)` - this creates the sensory connections
14/// 3. **Use** with an RNN layer: `CfC::with_wiring(input_size, wiring, device)`
15///
16/// # Key Matrices
17///
18/// - **Adjacency Matrix** `[units × units]`: Internal neuron-to-neuron connections
19/// - Values: +1 (excitatory), -1 (inhibitory), 0 (no connection)
20/// - **Sensory Adjacency Matrix** `[input_dim × units]`: Input-to-neuron connections
21/// - Only available after calling `.build()`
22///
23/// # Example
24///
25/// ```rust
26/// use ncps::wirings::{AutoNCP, Wiring};
27///
28/// let mut wiring = AutoNCP::new(32, 8, 0.5, 42);
29///
30/// // Check state before building
31/// assert!(!wiring.is_built());
32/// assert_eq!(wiring.input_dim(), None);
33///
34/// // Build with input dimension
35/// wiring.build(16);
36///
37/// // Now fully configured
38/// assert!(wiring.is_built());
39/// assert_eq!(wiring.input_dim(), Some(16));
40/// assert_eq!(wiring.units(), 32);
41/// assert_eq!(wiring.output_dim(), Some(8));
42/// ```
43///
44/// # Implementors
45///
46/// - [`super::AutoNCP`]: Automatic NCP configuration (recommended)
47/// - [`super::NCP`]: Manual NCP configuration with full control
48/// - [`FullyConnected`]: Dense connectivity (no sparsity)
49/// - [`super::Random`]: Random sparse connectivity
50pub trait Wiring: Send + Sync {
51 /// Returns the total number of neurons (hidden units) in this wiring.
52 ///
53 /// This is the size of the hidden state tensor: `[batch, units]`.
54 /// For NCP wirings, this equals `inter_neurons + command_neurons + motor_neurons`.
55 fn units(&self) -> usize;
56
57 /// Returns the input dimension (number of input features), or `None` if not yet built.
58 ///
59 /// This is only available after calling [`.build(input_dim)`](Wiring::build).
60 /// The sensory adjacency matrix will have shape `[input_dim, units]`.
61 fn input_dim(&self) -> Option<usize>;
62
63 /// Returns the output dimension (number of motor neurons).
64 ///
65 /// For NCP wirings, this is the number of motor neurons. The RNN output
66 /// will be projected to this size if it differs from `units()`.
67 ///
68 /// Returns `None` if output_dim equals units (no projection needed).
69 fn output_dim(&self) -> Option<usize>;
70
71 /// Returns the number of logical layers in this wiring.
72 ///
73 /// - `FullyConnected`: 1 layer
74 /// - `NCP`/`AutoNCP`: 3 layers (inter, command, motor)
75 fn num_layers(&self) -> usize {
76 1
77 }
78
79 /// Returns the neuron IDs belonging to a specific layer.
80 ///
81 /// For NCP wirings:
82 /// - Layer 0: Inter neurons (feature processing)
83 /// - Layer 1: Command neurons (integration/decision)
84 /// - Layer 2: Motor neurons (output)
85 ///
86 /// # Panics
87 ///
88 /// Panics if `layer_id >= num_layers()`.
89 fn get_neurons_of_layer(&self, layer_id: usize) -> Vec<usize> {
90 if layer_id == 0 {
91 (0..self.units()).collect()
92 } else {
93 vec![]
94 }
95 }
96
97 /// Returns `true` if the wiring has been built (input dimension is set).
98 ///
99 /// A wiring must be built before it can be used with an RNN layer.
100 /// Call [`.build(input_dim)`](Wiring::build) to build the wiring.
101 fn is_built(&self) -> bool {
102 self.input_dim().is_some()
103 }
104
105 /// Builds the wiring by setting the input dimension and creating sensory connections.
106 ///
107 /// **This method must be called before using the wiring with an RNN layer.**
108 ///
109 /// The build process:
110 /// 1. Sets the input dimension
111 /// 2. Creates the sensory adjacency matrix `[input_dim × units]`
112 /// 3. Establishes connections from inputs to the first layer of neurons
113 ///
114 /// # Arguments
115 ///
116 /// * `input_dim` - Number of input features per timestep
117 ///
118 /// # Panics
119 ///
120 /// Panics if called twice with different input dimensions:
121 /// ```should_panic
122 /// use ncps::wirings::{AutoNCP, Wiring};
123 ///
124 /// let mut wiring = AutoNCP::new(32, 8, 0.5, 42);
125 /// wiring.build(16);
126 /// wiring.build(32); // Panics! Different input_dim
127 /// ```
128 ///
129 /// Calling with the same dimension twice is safe (no-op).
130 fn build(&mut self, input_dim: usize);
131
132 /// Returns the type of a neuron by its ID.
133 ///
134 /// # Returns
135 ///
136 /// - `"motor"`: Output neurons (IDs 0..motor_neurons)
137 /// - `"command"`: Integration neurons (NCP only)
138 /// - `"inter"`: Feature processing neurons
139 ///
140 /// # Example
141 ///
142 /// ```rust
143 /// use ncps::wirings::{AutoNCP, Wiring};
144 ///
145 /// let mut wiring = AutoNCP::new(32, 8, 0.5, 42);
146 /// wiring.build(16);
147 ///
148 /// // First 8 neurons are motor neurons (output)
149 /// assert_eq!(wiring.get_type_of_neuron(0), "motor");
150 /// assert_eq!(wiring.get_type_of_neuron(7), "motor");
151 ///
152 /// // Higher IDs are command or inter neurons
153 /// let neuron_type = wiring.get_type_of_neuron(15);
154 /// assert!(neuron_type == "command" || neuron_type == "inter");
155 /// ```
156 fn get_type_of_neuron(&self, neuron_id: usize) -> &'static str {
157 let output_dim = self.output_dim().unwrap_or(0);
158 if neuron_id < output_dim {
159 "motor"
160 } else {
161 "inter"
162 }
163 }
164
165 /// Returns the internal adjacency matrix representing neuron-to-neuron synapses.
166 ///
167 /// Shape: `[units × units]`
168 ///
169 /// Values:
170 /// - `+1`: Excitatory synapse (increases activation)
171 /// - `-1`: Inhibitory synapse (decreases activation)
172 /// - `0`: No connection
173 ///
174 /// The matrix is indexed as `[source, destination]`, meaning `matrix[[i, j]]`
175 /// represents a synapse from neuron `i` to neuron `j`.
176 fn adjacency_matrix(&self) -> &ndarray::Array2<i32>;
177
178 /// Returns the sensory adjacency matrix (input-to-neuron connections).
179 ///
180 /// Shape: `[input_dim × units]`
181 ///
182 /// Returns `None` before [`.build()`](Wiring::build) is called.
183 ///
184 /// The matrix is indexed as `[input_feature, neuron]`, meaning `matrix[[i, j]]`
185 /// represents a synapse from input feature `i` to neuron `j`.
186 fn sensory_adjacency_matrix(&self) -> Option<&ndarray::Array2<i32>>;
187
188 /// Returns the reversal potential initializer (same as adjacency matrix).
189 ///
190 /// Used internally by LTC cells for biologically-plausible dynamics.
191 fn erev_initializer(&self) -> ndarray::Array2<i32> {
192 self.adjacency_matrix().clone()
193 }
194
195 /// Returns the sensory reversal potential initializer.
196 ///
197 /// Used internally by LTC cells for biologically-plausible dynamics.
198 fn sensory_erev_initializer(&self) -> Option<ndarray::Array2<i32>> {
199 self.sensory_adjacency_matrix().map(|m| m.clone())
200 }
201
202 /// Adds or modifies an internal synapse between two neurons.
203 ///
204 /// # Arguments
205 ///
206 /// * `src` - Source neuron ID (0..units)
207 /// * `dest` - Destination neuron ID (0..units)
208 /// * `polarity` - Synapse type: +1 (excitatory) or -1 (inhibitory)
209 ///
210 /// # Panics
211 ///
212 /// - Panics if `src >= units` or `dest >= units`
213 /// - Panics if `polarity` is not +1 or -1
214 fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32);
215
216 /// Adds or modifies a sensory synapse from an input feature to a neuron.
217 ///
218 /// # Arguments
219 ///
220 /// * `src` - Input feature index (0..input_dim)
221 /// * `dest` - Destination neuron ID (0..units)
222 /// * `polarity` - Synapse type: +1 (excitatory) or -1 (inhibitory)
223 ///
224 /// # Panics
225 ///
226 /// - Panics if wiring is not built (call `.build()` first)
227 /// - Panics if `src >= input_dim` or `dest >= units`
228 /// - Panics if `polarity` is not +1 or -1
229 fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32);
230
231 /// Returns the total number of internal synapses (non-zero entries in adjacency matrix).
232 ///
233 /// This is a measure of model complexity/sparsity. Lower values indicate sparser networks.
234 fn synapse_count(&self) -> usize {
235 self.adjacency_matrix().mapv(|x| x.abs() as usize).sum()
236 }
237
238 /// Returns the total number of sensory synapses (input-to-neuron connections).
239 ///
240 /// Returns 0 if the wiring hasn't been built yet.
241 fn sensory_synapse_count(&self) -> usize {
242 self.sensory_adjacency_matrix()
243 .map(|m| m.mapv(|x| x.abs() as usize).sum())
244 .unwrap_or(0)
245 }
246
247 /// Returns `true` if this wiring requires external input (has sensory connections).
248 ///
249 /// Always returns `true` after `.build()` is called.
250 fn input_required(&self) -> bool {
251 self.sensory_adjacency_matrix().is_some()
252 }
253
254 /// Creates a serializable configuration for this wiring.
255 ///
256 /// Used for saving/loading models. See [`WiringConfig`] for details.
257 fn get_config(&self) -> WiringConfig;
258}
259
260/// Fully connected (dense) wiring structure.
261///
262/// Every neuron connects to every other neuron (and optionally to itself).
263/// This provides a baseline comparison for NCP's sparse connectivity.
264///
265/// # When to Use
266///
267/// - **Baseline comparison**: Compare NCP performance against dense networks
268/// - **Maximum expressiveness**: When sparsity is not a concern
269/// - **Debugging**: Simpler structure for testing
270///
271/// # Sparsity
272///
273/// `FullyConnected` has **no sparsity** - the adjacency matrix is fully populated.
274/// For a network with `N` units, this means `N²` internal synapses (or `N²-N` without self-connections).
275///
276/// # Example
277///
278/// ```rust
279/// use ncps::wirings::{FullyConnected, Wiring};
280///
281/// // Create a fully-connected wiring with 32 neurons, 8 outputs
282/// let mut wiring = FullyConnected::new(
283/// 32, // units (total neurons)
284/// Some(8), // output_dim (motor neurons)
285/// 42, // seed for reproducibility
286/// true, // self_connections allowed
287/// );
288///
289/// // Build with input dimension
290/// wiring.build(16);
291///
292/// // Check connectivity
293/// println!("Total synapses: {}", wiring.synapse_count()); // 32*32 = 1024
294/// println!("Sensory synapses: {}", wiring.sensory_synapse_count()); // 16*32 = 512
295/// ```
296///
297/// # Comparison with NCP
298///
299/// | Aspect | FullyConnected | NCP |
300/// |--------|----------------|-----|
301/// | Synapses | O(N²) | O(N) to O(N log N) |
302/// | Interpretability | Low | High |
303/// | Parameters | More | Fewer |
304/// | Structure | Single layer | 4-layer biological |
305#[derive(Clone, Debug)]
306pub struct FullyConnected {
307 units: usize,
308 output_dim: usize,
309 adjacency_matrix: ndarray::Array2<i32>,
310 sensory_adjacency_matrix: Option<ndarray::Array2<i32>>,
311 input_dim: Option<usize>,
312 self_connections: bool,
313 erev_init_seed: u64,
314}
315
316impl FullyConnected {
317 pub fn new(
318 units: usize,
319 output_dim: Option<usize>,
320 erev_init_seed: u64,
321 self_connections: bool,
322 ) -> Self {
323 let output_dim = output_dim.unwrap_or(units);
324 let mut adjacency_matrix = ndarray::Array2::zeros((units, units));
325 let mut rng = StdRng::seed_from_u64(erev_init_seed);
326
327 // Initialize synapses
328 for src in 0..units {
329 for dest in 0..units {
330 if src == dest && !self_connections {
331 continue;
332 }
333 // 2/3 chance of excitatory, 1/3 inhibitory
334 let polarity: i32 = if rand::random::<f64>() < 0.33 { -1 } else { 1 };
335 adjacency_matrix[[src, dest]] = polarity;
336 }
337 }
338
339 Self {
340 units,
341 output_dim,
342 adjacency_matrix,
343 sensory_adjacency_matrix: None,
344 input_dim: None,
345 self_connections,
346 erev_init_seed,
347 }
348 }
349
350 /// Get configuration for serialization
351 pub fn get_full_config(&self) -> WiringConfig {
352 WiringConfig {
353 units: self.units,
354 adjacency_matrix: Some(
355 self.adjacency_matrix
356 .outer_iter()
357 .map(|v| v.to_vec())
358 .collect(),
359 ),
360 sensory_adjacency_matrix: self
361 .sensory_adjacency_matrix
362 .as_ref()
363 .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
364 input_dim: self.input_dim,
365 output_dim: Some(self.output_dim),
366 // FullyConnected-specific fields
367 erev_init_seed: Some(self.erev_init_seed),
368 self_connections: Some(self.self_connections),
369 // Other fields not used by FullyConnected
370 num_inter_neurons: None,
371 num_command_neurons: None,
372 num_motor_neurons: None,
373 sensory_fanout: None,
374 inter_fanout: None,
375 recurrent_command_synapses: None,
376 motor_fanin: None,
377 seed: None,
378 sparsity_level: None,
379 random_seed: None,
380 }
381 }
382
383 pub fn from_config(config: WiringConfig) -> Self {
384 let units = config.units;
385 let adjacency_matrix = if let Some(matrix) = config.adjacency_matrix {
386 ndarray::Array2::from_shape_vec((units, units), matrix.into_iter().flatten().collect())
387 .expect("Invalid adjacency matrix shape")
388 } else {
389 ndarray::Array2::zeros((units, units))
390 };
391
392 let sensory_adjacency_matrix = config.sensory_adjacency_matrix.map(|matrix| {
393 let input_dim = config
394 .input_dim
395 .expect("Input dimension required when sensory matrix exists");
396 ndarray::Array2::from_shape_vec(
397 (input_dim, units),
398 matrix.into_iter().flatten().collect(),
399 )
400 .expect("Invalid sensory adjacency matrix shape")
401 });
402
403 Self {
404 units,
405 output_dim: config.output_dim.unwrap_or(units),
406 adjacency_matrix,
407 sensory_adjacency_matrix,
408 input_dim: config.input_dim,
409 self_connections: true,
410 erev_init_seed: 1111,
411 }
412 }
413}
414
415impl Wiring for FullyConnected {
416 fn units(&self) -> usize {
417 self.units
418 }
419
420 fn input_dim(&self) -> Option<usize> {
421 self.input_dim
422 }
423
424 fn output_dim(&self) -> Option<usize> {
425 Some(self.output_dim)
426 }
427
428 fn build(&mut self, input_dim: usize) {
429 if let Some(existing) = self.input_dim {
430 if existing != input_dim {
431 panic!(
432 "Conflicting input dimensions: expected {}, got {}",
433 existing, input_dim
434 );
435 }
436 return;
437 }
438
439 self.input_dim = Some(input_dim);
440 let mut sensory_matrix = ndarray::Array2::zeros((input_dim, self.units));
441 let mut rng = StdRng::seed_from_u64(self.erev_init_seed);
442
443 for src in 0..input_dim {
444 for dest in 0..self.units {
445 let polarity: i32 = if rng.gen::<f64>() < 0.33 { -1 } else { 1 };
446 sensory_matrix[[src, dest]] = polarity;
447 }
448 }
449 self.sensory_adjacency_matrix = Some(sensory_matrix);
450 }
451
452 fn adjacency_matrix(&self) -> &ndarray::Array2<i32> {
453 &self.adjacency_matrix
454 }
455
456 fn sensory_adjacency_matrix(&self) -> Option<&ndarray::Array2<i32>> {
457 self.sensory_adjacency_matrix.as_ref()
458 }
459
460 fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
461 if src >= self.units || dest >= self.units {
462 panic!(
463 "Invalid synapse: src={}, dest={}, units={}",
464 src, dest, self.units
465 );
466 }
467 if ![-1, 1].contains(&polarity) {
468 panic!("Polarity must be -1 or 1, got {}", polarity);
469 }
470 self.adjacency_matrix[[src, dest]] = polarity;
471 }
472
473 fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
474 let input_dim = self
475 .input_dim
476 .expect("Must build wiring before adding sensory synapses");
477 if src >= input_dim || dest >= self.units {
478 panic!(
479 "Invalid sensory synapse: src={}, dest={}, input_dim={}, units={}",
480 src, dest, input_dim, self.units
481 );
482 }
483 if ![-1, 1].contains(&polarity) {
484 panic!("Polarity must be -1 or 1, got {}", polarity);
485 }
486 self.sensory_adjacency_matrix.as_mut().unwrap()[[src, dest]] = polarity;
487 }
488
489 fn get_config(&self) -> WiringConfig {
490 WiringConfig {
491 units: self.units,
492 adjacency_matrix: Some(
493 self.adjacency_matrix
494 .outer_iter()
495 .map(|v| v.to_vec())
496 .collect(),
497 ),
498 sensory_adjacency_matrix: self
499 .sensory_adjacency_matrix
500 .as_ref()
501 .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
502 input_dim: self.input_dim,
503 output_dim: Some(self.output_dim),
504 // FullyConnected-specific fields
505 erev_init_seed: Some(self.erev_init_seed),
506 self_connections: Some(self.self_connections),
507 // Other fields not used by FullyConnected
508 num_inter_neurons: None,
509 num_command_neurons: None,
510 num_motor_neurons: None,
511 sensory_fanout: None,
512 inter_fanout: None,
513 recurrent_command_synapses: None,
514 motor_fanin: None,
515 seed: None,
516 sparsity_level: None,
517 random_seed: None,
518 }
519 }
520}