Skip to main content

ncps_rust/wirings/
base.rs

1use rand::prelude::*;
2use super::WiringConfig;
3
4/// Base trait for wiring configurations in Neural Circuit Policies
5pub trait Wiring: Send + Sync {
6    /// Returns the number of neurons in this wiring
7    fn units(&self) -> usize;
8
9    /// Returns the input dimension (number of sensory neurons)
10    fn input_dim(&self) -> Option<usize>;
11
12    /// Returns the output dimension (number of motor neurons)
13    fn output_dim(&self) -> Option<usize>;
14
15    /// Returns the number of layers in this wiring
16    fn num_layers(&self) -> usize {
17        1
18    }
19
20    /// Returns the neuron IDs for a specific layer
21    fn get_neurons_of_layer(&self, layer_id: usize) -> Vec<usize> {
22        if layer_id == 0 {
23            (0..self.units()).collect()
24        } else {
25            vec![]
26        }
27    }
28
29    /// Check if the wiring has been built (input dimension is set)
30    fn is_built(&self) -> bool {
31        self.input_dim().is_some()
32    }
33
34    /// Build the wiring with the given input dimension
35    fn build(&mut self, input_dim: usize);
36
37    /// Get type of a neuron (motor, inter, command, etc.)
38    fn get_type_of_neuron(&self, neuron_id: usize) -> &'static str {
39        let output_dim = self.output_dim().unwrap_or(0);
40        if neuron_id < output_dim {
41            "motor"
42        } else {
43            "inter"
44        }
45    }
46
47    /// Returns the adjacency matrix representing synapses between neurons
48    fn adjacency_matrix(&self) -> &ndarray::Array2<i32>;
49
50    /// Returns the sensory adjacency matrix (synapses from inputs to neurons)
51    fn sensory_adjacency_matrix(&self) -> Option<&ndarray::Array2<i32>>;
52
53    /// Initialize the adjacency matrix (erev)
54    fn erev_initializer(&self) -> ndarray::Array2<i32> {
55        self.adjacency_matrix().clone()
56    }
57
58    /// Initialize the sensory adjacency matrix (erev)
59    fn sensory_erev_initializer(&self) -> Option<ndarray::Array2<i32>> {
60        self.sensory_adjacency_matrix().map(|m| m.clone())
61    }
62
63    /// Add a synapse between neurons
64    fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32);
65
66    /// Add a sensory synapse from input feature to neuron
67    fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32);
68
69    /// Count total internal synapses
70    fn synapse_count(&self) -> usize {
71        self.adjacency_matrix().mapv(|x| x.abs() as usize).sum()
72    }
73
74    /// Count total sensory synapses
75    fn sensory_synapse_count(&self) -> usize {
76        self.sensory_adjacency_matrix()
77            .map(|m| m.mapv(|x| x.abs() as usize).sum())
78            .unwrap_or(0)
79    }
80
81    fn input_required(&self) -> bool {
82        self.sensory_adjacency_matrix().is_some()
83    }
84
85    /// Create a serialization config for this wiring
86    fn get_config(&self) -> WiringConfig;
87}
88
89/// Fully connected wiring structure
90#[derive(Clone, Debug)]
91pub struct FullyConnected {
92    units: usize,
93    output_dim: usize,
94    adjacency_matrix: ndarray::Array2<i32>,
95    sensory_adjacency_matrix: Option<ndarray::Array2<i32>>,
96    input_dim: Option<usize>,
97    self_connections: bool,
98    erev_init_seed: u64,
99}
100
101impl FullyConnected {
102    pub fn new(
103        units: usize,
104        output_dim: Option<usize>,
105        erev_init_seed: u64,
106        self_connections: bool,
107    ) -> Self {
108        let output_dim = output_dim.unwrap_or(units);
109        let mut adjacency_matrix = ndarray::Array2::zeros((units, units));
110        let mut rng = StdRng::seed_from_u64(erev_init_seed);
111
112        // Initialize synapses
113        for src in 0..units {
114            for dest in 0..units {
115                if src == dest && !self_connections {
116                    continue;
117                }
118                // 2/3 chance of excitatory, 1/3 inhibitory
119                let polarity: i32 = if rand::random::<f64>() < 0.33 { -1 } else { 1 };
120                adjacency_matrix[[src, dest]] = polarity;
121            }
122        }
123
124        Self {
125            units,
126            output_dim,
127            adjacency_matrix,
128            sensory_adjacency_matrix: None,
129            input_dim: None,
130            self_connections,
131            erev_init_seed,
132        }
133    }
134
135    /// Get configuration for serialization
136    pub fn get_full_config(&self) -> WiringConfig {
137        WiringConfig {
138            units: self.units,
139            adjacency_matrix: Some(
140                self.adjacency_matrix
141                    .outer_iter()
142                    .map(|v| v.to_vec())
143                    .collect(),
144            ),
145            sensory_adjacency_matrix: self
146                .sensory_adjacency_matrix
147                .as_ref()
148                .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
149            input_dim: self.input_dim,
150            output_dim: Some(self.output_dim),
151            // FullyConnected-specific fields
152            erev_init_seed: Some(self.erev_init_seed),
153            self_connections: Some(self.self_connections),
154            // Other fields not used by FullyConnected
155            num_inter_neurons: None,
156            num_command_neurons: None,
157            num_motor_neurons: None,
158            sensory_fanout: None,
159            inter_fanout: None,
160            recurrent_command_synapses: None,
161            motor_fanin: None,
162            seed: None,
163            sparsity_level: None,
164            random_seed: None,
165        }
166    }
167
168    pub fn from_config(config: WiringConfig) -> Self {
169        let units = config.units;
170        let adjacency_matrix = if let Some(matrix) = config.adjacency_matrix {
171            ndarray::Array2::from_shape_vec((units, units), matrix.into_iter().flatten().collect())
172                .expect("Invalid adjacency matrix shape")
173        } else {
174            ndarray::Array2::zeros((units, units))
175        };
176
177        let sensory_adjacency_matrix = config.sensory_adjacency_matrix.map(|matrix| {
178            let input_dim = config
179                .input_dim
180                .expect("Input dimension required when sensory matrix exists");
181            ndarray::Array2::from_shape_vec(
182                (input_dim, units),
183                matrix.into_iter().flatten().collect(),
184            )
185            .expect("Invalid sensory adjacency matrix shape")
186        });
187
188        Self {
189            units,
190            output_dim: config.output_dim.unwrap_or(units),
191            adjacency_matrix,
192            sensory_adjacency_matrix,
193            input_dim: config.input_dim,
194            self_connections: true,
195            erev_init_seed: 1111,
196        }
197    }
198}
199
200impl Wiring for FullyConnected {
201    fn units(&self) -> usize {
202        self.units
203    }
204
205    fn input_dim(&self) -> Option<usize> {
206        self.input_dim
207    }
208
209    fn output_dim(&self) -> Option<usize> {
210        Some(self.output_dim)
211    }
212
213    fn build(&mut self, input_dim: usize) {
214        if let Some(existing) = self.input_dim {
215            if existing != input_dim {
216                panic!(
217                    "Conflicting input dimensions: expected {}, got {}",
218                    existing, input_dim
219                );
220            }
221            return;
222        }
223
224        self.input_dim = Some(input_dim);
225        let mut sensory_matrix = ndarray::Array2::zeros((input_dim, self.units));
226        let mut rng = StdRng::seed_from_u64(self.erev_init_seed);
227
228        for src in 0..input_dim {
229            for dest in 0..self.units {
230                let polarity: i32 = if rng.gen::<f64>() < 0.33 { -1 } else { 1 };
231                sensory_matrix[[src, dest]] = polarity;
232            }
233        }
234        self.sensory_adjacency_matrix = Some(sensory_matrix);
235    }
236
237    fn adjacency_matrix(&self) -> &ndarray::Array2<i32> {
238        &self.adjacency_matrix
239    }
240
241    fn sensory_adjacency_matrix(&self) -> Option<&ndarray::Array2<i32>> {
242        self.sensory_adjacency_matrix.as_ref()
243    }
244
245    fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
246        if src >= self.units || dest >= self.units {
247            panic!(
248                "Invalid synapse: src={}, dest={}, units={}",
249                src, dest, self.units
250            );
251        }
252        if ![-1, 1].contains(&polarity) {
253            panic!("Polarity must be -1 or 1, got {}", polarity);
254        }
255        self.adjacency_matrix[[src, dest]] = polarity;
256    }
257
258    fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
259        let input_dim = self
260            .input_dim
261            .expect("Must build wiring before adding sensory synapses");
262        if src >= input_dim || dest >= self.units {
263            panic!(
264                "Invalid sensory synapse: src={}, dest={}, input_dim={}, units={}",
265                src, dest, input_dim, self.units
266            );
267        }
268        if ![-1, 1].contains(&polarity) {
269            panic!("Polarity must be -1 or 1, got {}", polarity);
270        }
271        self.sensory_adjacency_matrix.as_mut().unwrap()[[src, dest]] = polarity;
272    }
273
274    fn get_config(&self) -> WiringConfig {
275        WiringConfig {
276            units: self.units,
277            adjacency_matrix: Some(
278                self.adjacency_matrix
279                    .outer_iter()
280                    .map(|v| v.to_vec())
281                    .collect(),
282            ),
283            sensory_adjacency_matrix: self
284                .sensory_adjacency_matrix
285                .as_ref()
286                .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
287            input_dim: self.input_dim,
288            output_dim: Some(self.output_dim),
289            // FullyConnected-specific fields
290            erev_init_seed: Some(self.erev_init_seed),
291            self_connections: Some(self.self_connections),
292            // Other fields not used by FullyConnected
293            num_inter_neurons: None,
294            num_command_neurons: None,
295            num_motor_neurons: None,
296            sensory_fanout: None,
297            inter_fanout: None,
298            recurrent_command_synapses: None,
299            motor_fanin: None,
300            seed: None,
301            sparsity_level: None,
302            random_seed: None,
303        }
304    }
305}