// ncps_rust/wirings/ncp.rs

1use super::base::Wiring;
2use super::WiringConfig;
3use ndarray::Array2;
4use rand::prelude::*;
5
6/// Neural Circuit Policy wiring structure
7/// Implements a 4-layer architecture: sensories -> inter -> command -> motor
8#[derive(Clone, Debug)]
9pub struct NCP {
10    units: usize,
11    adjacency_matrix: Array2<i32>,
12    sensory_adjacency_matrix: Option<Array2<i32>>,
13    input_dim: Option<usize>,
14    num_inter_neurons: usize,
15    num_command_neurons: usize,
16    num_motor_neurons: usize,
17    sensory_fanout: usize,
18    inter_fanout: usize,
19    recurrent_command_synapses: usize,
20    motor_fanin: usize,
21    motor_neurons: Vec<usize>,
22    command_neurons: Vec<usize>,
23    inter_neurons: Vec<usize>,
24    sensory_neurons: Vec<usize>,
25    rng: StdRng,
26}
27
28impl NCP {
29    pub fn new(
30        inter_neurons: usize,
31        command_neurons: usize,
32        motor_neurons: usize,
33        sensory_fanout: usize,
34        inter_fanout: usize,
35        recurrent_command_synapses: usize,
36        motor_fanin: usize,
37        seed: u64,
38    ) -> Self {
39        let units = inter_neurons + command_neurons + motor_neurons;
40
41        // Validate parameters
42        if motor_fanin > command_neurons {
43            panic!(
44                "Motor fanin {} exceeds number of command neurons {}",
45                motor_fanin, command_neurons
46            );
47        }
48        if sensory_fanout > inter_neurons {
49            panic!(
50                "Sensory fanout {} exceeds number of inter neurons {}",
51                sensory_fanout, inter_neurons
52            );
53        }
54        if inter_fanout > command_neurons {
55            panic!(
56                "Inter fanout {} exceeds number of command neurons {}",
57                inter_fanout, command_neurons
58            );
59        }
60
61        // Neuron IDs: [0..motor ... command ... inter]
62        let motor_neuron_ids: Vec<usize> = (0..motor_neurons).collect();
63        let command_neuron_ids: Vec<usize> =
64            (motor_neurons..motor_neurons + command_neurons).collect();
65        let inter_neuron_ids: Vec<usize> = (motor_neurons + command_neurons..units).collect();
66
67        let adjacency_matrix = Array2::zeros((units, units));
68        let rng = StdRng::seed_from_u64(seed);
69
70        Self {
71            units,
72            adjacency_matrix,
73            sensory_adjacency_matrix: None,
74            input_dim: None,
75            num_inter_neurons: inter_neurons,
76            num_command_neurons: command_neurons,
77            num_motor_neurons: motor_neurons,
78            sensory_fanout,
79            inter_fanout,
80            recurrent_command_synapses,
81            motor_fanin,
82            motor_neurons: motor_neuron_ids,
83            command_neurons: command_neuron_ids,
84            inter_neurons: inter_neuron_ids,
85            sensory_neurons: vec![],
86            rng,
87        }
88    }
89
90    fn build_sensory_to_inter_layer(&mut self) {
91        let input_dim = self.input_dim.unwrap();
92        self.sensory_neurons = (0..input_dim).collect();
93
94        // Clone the neuron list to avoid borrow issues
95        let inter_neurons = self.inter_neurons.clone();
96        let sensory_neurons = self.sensory_neurons.clone();
97        let mut unreachable_inter: Vec<usize> = inter_neurons.clone();
98
99        // Connect each sensory neuron to exactly sensory_fanout inter neurons
100        for &src in &sensory_neurons {
101            let selected: Vec<_> = inter_neurons
102                .choose_multiple(&mut self.rng, self.sensory_fanout)
103                .cloned()
104                .collect();
105
106            for &dest in &selected {
107                if let Some(pos) = unreachable_inter.iter().position(|&x| x == dest) {
108                    unreachable_inter.remove(pos);
109                }
110                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
111                self.add_sensory_synapse(src, dest, polarity);
112            }
113        }
114
115        // Connect any unreachable inter neurons
116        let mean_inter_fanin = (input_dim * self.sensory_fanout / self.num_inter_neurons)
117            .max(1)
118            .min(input_dim);
119
120        for &dest in &unreachable_inter {
121            let selected: Vec<_> = sensory_neurons
122                .choose_multiple(&mut self.rng, mean_inter_fanin)
123                .cloned()
124                .collect();
125
126            for &src in &selected {
127                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
128                self.add_sensory_synapse(src, dest, polarity);
129            }
130        }
131    }
132
133    fn build_inter_to_command_layer(&mut self) {
134        // Clone the neuron lists to avoid borrow issues
135        let inter_neurons = self.inter_neurons.clone();
136        let command_neurons = self.command_neurons.clone();
137        let mut unreachable_command: Vec<usize> = command_neurons.clone();
138
139        // Connect inter neurons to command neurons
140        for &src in &inter_neurons {
141            let selected: Vec<_> = command_neurons
142                .choose_multiple(&mut self.rng, self.inter_fanout)
143                .cloned()
144                .collect();
145
146            for &dest in &selected {
147                if let Some(pos) = unreachable_command.iter().position(|&x| x == dest) {
148                    unreachable_command.remove(pos);
149                }
150                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
151                self.add_synapse(src, dest, polarity);
152            }
153        }
154
155        // Connect any unreachable command neurons
156        let mean_command_fanin = (self.num_inter_neurons * self.inter_fanout
157            / self.num_command_neurons)
158            .max(1)
159            .min(self.num_inter_neurons);
160
161        for &dest in &unreachable_command {
162            let selected: Vec<_> = inter_neurons
163                .choose_multiple(&mut self.rng, mean_command_fanin)
164                .cloned()
165                .collect();
166
167            for &src in &selected {
168                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
169                self.add_synapse(src, dest, polarity);
170            }
171        }
172    }
173
174    fn build_recurrent_command_layer(&mut self) {
175        for _ in 0..self.recurrent_command_synapses {
176            let src = *self.command_neurons.choose(&mut self.rng).unwrap();
177            let dest = *self.command_neurons.choose(&mut self.rng).unwrap();
178            let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
179            self.add_synapse(src, dest, polarity);
180        }
181    }
182
183    fn build_command_to_motor_layer(&mut self) {
184        // Clone the neuron lists to avoid borrow issues
185        let motor_neurons = self.motor_neurons.clone();
186        let command_neurons = self.command_neurons.clone();
187        let mut unreachable_command: Vec<usize> = command_neurons.clone();
188
189        // Connect command neurons to motor neurons
190        for &dest in &motor_neurons {
191            let selected: Vec<_> = command_neurons
192                .choose_multiple(&mut self.rng, self.motor_fanin)
193                .cloned()
194                .collect();
195
196            for &src in &selected {
197                if let Some(pos) = unreachable_command.iter().position(|&x| x == src) {
198                    unreachable_command.remove(pos);
199                }
200                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
201                self.add_synapse(src, dest, polarity);
202            }
203        }
204
205        // Connect any unreachable command neurons
206        let mean_command_fanout = (self.num_motor_neurons * self.motor_fanin
207            / self.num_command_neurons)
208            .max(1)
209            .min(self.num_motor_neurons);
210
211        for &src in &unreachable_command {
212            let selected: Vec<_> = motor_neurons
213                .choose_multiple(&mut self.rng, mean_command_fanout)
214                .cloned()
215                .collect();
216
217            for &dest in &selected {
218                let polarity: i32 = if self.rng.gen::<bool>() { 1 } else { -1 };
219                self.add_synapse(src, dest, polarity);
220            }
221        }
222    }
223
224    pub fn from_config(config: WiringConfig) -> Self {
225        // Parse config to reconstruct NCP
226        let units = config.units;
227        let adjacency_matrix = if let Some(matrix) = config.adjacency_matrix {
228            Array2::from_shape_vec((units, units), matrix.into_iter().flatten().collect())
229                .expect("Invalid adjacency matrix shape")
230        } else {
231            Array2::zeros((units, units))
232        };
233
234        let sensory_adjacency_matrix = config.sensory_adjacency_matrix.map(|matrix| {
235            let input_dim = config
236                .input_dim
237                .expect("Input dimension required when sensory matrix exists");
238            Array2::from_shape_vec((input_dim, units), matrix.into_iter().flatten().collect())
239                .expect("Invalid sensory adjacency matrix shape")
240        });
241
242        // This would need additional info stored in config to reconstruct properly
243        // For now, create a basic NCP structure
244        let output_dim = config.output_dim.unwrap_or(1);
245        let inter_and_command = units - output_dim;
246        let command_neurons = (inter_and_command as f64 * 0.4).ceil() as usize;
247        let inter_neurons = inter_and_command - command_neurons;
248
249        NCP::new(
250            inter_neurons,
251            command_neurons,
252            output_dim,
253            6,     // Default sensory_fanout
254            6,     // Default inter_fanout
255            4,     // Default recurrent_command_synapses
256            6,     // Default motor_fanin
257            22222, // Default seed
258        )
259    }
260}
261
262impl Wiring for NCP {
263    fn units(&self) -> usize {
264        self.units
265    }
266
267    fn input_dim(&self) -> Option<usize> {
268        self.input_dim
269    }
270
271    fn output_dim(&self) -> Option<usize> {
272        Some(self.num_motor_neurons)
273    }
274
275    fn num_layers(&self) -> usize {
276        3
277    }
278
279    fn get_neurons_of_layer(&self, layer_id: usize) -> Vec<usize> {
280        match layer_id {
281            0 => self.inter_neurons.clone(),
282            1 => self.command_neurons.clone(),
283            2 => self.motor_neurons.clone(),
284            _ => panic!("Unknown layer {}", layer_id),
285        }
286    }
287
288    fn get_type_of_neuron(&self, neuron_id: usize) -> &'static str {
289        if neuron_id < self.num_motor_neurons {
290            "motor"
291        } else if neuron_id < self.num_motor_neurons + self.num_command_neurons {
292            "command"
293        } else {
294            "inter"
295        }
296    }
297
298    fn build(&mut self, input_dim: usize) {
299        if let Some(existing) = self.input_dim {
300            if existing != input_dim {
301                panic!(
302                    "Conflicting input dimensions: expected {}, got {}",
303                    existing, input_dim
304                );
305            }
306            return;
307        }
308
309        self.input_dim = Some(input_dim);
310        self.sensory_adjacency_matrix = Some(Array2::zeros((input_dim, self.units)));
311
312        self.build_sensory_to_inter_layer();
313        self.build_inter_to_command_layer();
314        self.build_recurrent_command_layer();
315        self.build_command_to_motor_layer();
316    }
317
318    fn adjacency_matrix(&self) -> &Array2<i32> {
319        &self.adjacency_matrix
320    }
321
322    fn sensory_adjacency_matrix(&self) -> Option<&Array2<i32>> {
323        self.sensory_adjacency_matrix.as_ref()
324    }
325
326    fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
327        if src >= self.units || dest >= self.units {
328            panic!(
329                "Invalid synapse: src={}, dest={}, units={}",
330                src, dest, self.units
331            );
332        }
333        if ![-1, 1].contains(&polarity) {
334            panic!("Polarity must be -1 or 1, got {}", polarity);
335        }
336        self.adjacency_matrix[[src, dest]] = polarity;
337    }
338
339    fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
340        let input_dim = self
341            .input_dim
342            .expect("Must build wiring before adding sensory synapses");
343        if src >= input_dim || dest >= self.units {
344            panic!(
345                "Invalid sensory synapse: src={}, dest={}, input_dim={}, units={}",
346                src, dest, input_dim, self.units
347            );
348        }
349        if ![-1, 1].contains(&polarity) {
350            panic!("Polarity must be -1 or 1, got {}", polarity);
351        }
352        self.sensory_adjacency_matrix.as_mut().unwrap()[[src, dest]] = polarity;
353    }
354
355    fn get_config(&self) -> WiringConfig {
356        WiringConfig {
357            units: self.units,
358            adjacency_matrix: Some(
359                self.adjacency_matrix
360                    .outer_iter()
361                    .map(|v| v.to_vec())
362                    .collect(),
363            ),
364            sensory_adjacency_matrix: self
365                .sensory_adjacency_matrix
366                .as_ref()
367                .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
368            input_dim: self.input_dim,
369            output_dim: Some(self.num_motor_neurons),
370            // NCP-specific fields
371            num_inter_neurons: Some(self.num_inter_neurons),
372            num_command_neurons: Some(self.num_command_neurons),
373            num_motor_neurons: Some(self.num_motor_neurons),
374            sensory_fanout: Some(self.sensory_fanout),
375            inter_fanout: Some(self.inter_fanout),
376            recurrent_command_synapses: Some(self.recurrent_command_synapses),
377            motor_fanin: Some(self.motor_fanin),
378            seed: None, // NCP uses rng internally
379            // Other fields not used by NCP
380            erev_init_seed: None,
381            self_connections: None,
382            sparsity_level: None,
383            random_seed: None,
384        }
385    }
386}
387
388/// AutoNCP provides an easier way to create NCP wiring
389#[derive(Clone, Debug)]
390pub struct AutoNCP {
391    ncp: NCP,
392    output_size: usize,
393    sparsity_level: f64,
394    seed: u64,
395}
396
397impl AutoNCP {
398    pub fn new(units: usize, output_size: usize, sparsity_level: f64, seed: u64) -> Self {
399        if output_size >= units - 2 {
400            panic!(
401                "Output size {} must be less than units-2 ({})",
402                output_size,
403                units - 2
404            );
405        }
406        if sparsity_level < 0.0 || sparsity_level > 0.9 {
407            panic!(
408                "Sparsity level must be between 0.0 and 0.9, got {}",
409                sparsity_level
410            );
411        }
412
413        let density_level = 1.0 - sparsity_level;
414        let inter_and_command_neurons = units - output_size;
415        let command_neurons = ((inter_and_command_neurons as f64 * 0.4).ceil() as usize).max(1);
416        let inter_neurons = inter_and_command_neurons - command_neurons;
417
418        let sensory_fanout = ((inter_neurons as f64 * density_level).ceil() as usize).max(1);
419        let inter_fanout = ((command_neurons as f64 * density_level).ceil() as usize).max(1);
420        let recurrent_command_synapses =
421            ((command_neurons as f64 * density_level * 2.0).ceil() as usize).max(1);
422        let motor_fanin = ((command_neurons as f64 * density_level).ceil() as usize).max(1);
423
424        let ncp = NCP::new(
425            inter_neurons,
426            command_neurons,
427            output_size,
428            sensory_fanout,
429            inter_fanout,
430            recurrent_command_synapses,
431            motor_fanin,
432            seed,
433        );
434
435        Self {
436            ncp,
437            output_size,
438            sparsity_level,
439            seed,
440        }
441    }
442}
443
444impl Wiring for AutoNCP {
445    fn units(&self) -> usize {
446        self.ncp.units()
447    }
448
449    fn input_dim(&self) -> Option<usize> {
450        self.ncp.input_dim()
451    }
452
453    fn output_dim(&self) -> Option<usize> {
454        Some(self.output_size)
455    }
456
457    fn num_layers(&self) -> usize {
458        self.ncp.num_layers()
459    }
460
461    fn get_neurons_of_layer(&self, layer_id: usize) -> Vec<usize> {
462        self.ncp.get_neurons_of_layer(layer_id)
463    }
464
465    fn get_type_of_neuron(&self, neuron_id: usize) -> &'static str {
466        self.ncp.get_type_of_neuron(neuron_id)
467    }
468
469    fn build(&mut self, input_dim: usize) {
470        self.ncp.build(input_dim)
471    }
472
473    fn is_built(&self) -> bool {
474        self.ncp.is_built()
475    }
476
477    fn adjacency_matrix(&self) -> &Array2<i32> {
478        self.ncp.adjacency_matrix()
479    }
480
481    fn sensory_adjacency_matrix(&self) -> Option<&Array2<i32>> {
482        self.ncp.sensory_adjacency_matrix()
483    }
484
485    fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
486        self.ncp.add_synapse(src, dest, polarity)
487    }
488
489    fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
490        self.ncp.add_sensory_synapse(src, dest, polarity)
491    }
492
493    fn get_config(&self) -> WiringConfig {
494        // Get the underlying NCP config and add AutoNCP-specific fields
495        let mut config = self.ncp.get_config();
496        config.output_dim = Some(self.output_size);
497        config.sparsity_level = Some(self.sparsity_level);
498        config.seed = Some(self.seed);
499        config
500    }
501}