1use rand::prelude::*;
2use super::WiringConfig;
3
4pub trait Wiring: Send + Sync {
6 fn units(&self) -> usize;
8
9 fn input_dim(&self) -> Option<usize>;
11
12 fn output_dim(&self) -> Option<usize>;
14
15 fn num_layers(&self) -> usize {
17 1
18 }
19
20 fn get_neurons_of_layer(&self, layer_id: usize) -> Vec<usize> {
22 if layer_id == 0 {
23 (0..self.units()).collect()
24 } else {
25 vec![]
26 }
27 }
28
29 fn is_built(&self) -> bool {
31 self.input_dim().is_some()
32 }
33
34 fn build(&mut self, input_dim: usize);
36
37 fn get_type_of_neuron(&self, neuron_id: usize) -> &'static str {
39 let output_dim = self.output_dim().unwrap_or(0);
40 if neuron_id < output_dim {
41 "motor"
42 } else {
43 "inter"
44 }
45 }
46
47 fn adjacency_matrix(&self) -> &ndarray::Array2<i32>;
49
50 fn sensory_adjacency_matrix(&self) -> Option<&ndarray::Array2<i32>>;
52
53 fn erev_initializer(&self) -> ndarray::Array2<i32> {
55 self.adjacency_matrix().clone()
56 }
57
58 fn sensory_erev_initializer(&self) -> Option<ndarray::Array2<i32>> {
60 self.sensory_adjacency_matrix().map(|m| m.clone())
61 }
62
63 fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32);
65
66 fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32);
68
69 fn synapse_count(&self) -> usize {
71 self.adjacency_matrix().mapv(|x| x.abs() as usize).sum()
72 }
73
74 fn sensory_synapse_count(&self) -> usize {
76 self.sensory_adjacency_matrix()
77 .map(|m| m.mapv(|x| x.abs() as usize).sum())
78 .unwrap_or(0)
79 }
80
81 fn input_required(&self) -> bool {
82 self.sensory_adjacency_matrix().is_some()
83 }
84
85 fn get_config(&self) -> WiringConfig;
87}
88
89#[derive(Clone, Debug)]
91pub struct FullyConnected {
92 units: usize,
93 output_dim: usize,
94 adjacency_matrix: ndarray::Array2<i32>,
95 sensory_adjacency_matrix: Option<ndarray::Array2<i32>>,
96 input_dim: Option<usize>,
97 self_connections: bool,
98 erev_init_seed: u64,
99}
100
101impl FullyConnected {
102 pub fn new(
103 units: usize,
104 output_dim: Option<usize>,
105 erev_init_seed: u64,
106 self_connections: bool,
107 ) -> Self {
108 let output_dim = output_dim.unwrap_or(units);
109 let mut adjacency_matrix = ndarray::Array2::zeros((units, units));
110 let mut rng = StdRng::seed_from_u64(erev_init_seed);
111
112 for src in 0..units {
114 for dest in 0..units {
115 if src == dest && !self_connections {
116 continue;
117 }
118 let polarity: i32 = if rand::random::<f64>() < 0.33 { -1 } else { 1 };
120 adjacency_matrix[[src, dest]] = polarity;
121 }
122 }
123
124 Self {
125 units,
126 output_dim,
127 adjacency_matrix,
128 sensory_adjacency_matrix: None,
129 input_dim: None,
130 self_connections,
131 erev_init_seed,
132 }
133 }
134
135 pub fn get_full_config(&self) -> WiringConfig {
137 WiringConfig {
138 units: self.units,
139 adjacency_matrix: Some(
140 self.adjacency_matrix
141 .outer_iter()
142 .map(|v| v.to_vec())
143 .collect(),
144 ),
145 sensory_adjacency_matrix: self
146 .sensory_adjacency_matrix
147 .as_ref()
148 .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
149 input_dim: self.input_dim,
150 output_dim: Some(self.output_dim),
151 erev_init_seed: Some(self.erev_init_seed),
153 self_connections: Some(self.self_connections),
154 num_inter_neurons: None,
156 num_command_neurons: None,
157 num_motor_neurons: None,
158 sensory_fanout: None,
159 inter_fanout: None,
160 recurrent_command_synapses: None,
161 motor_fanin: None,
162 seed: None,
163 sparsity_level: None,
164 random_seed: None,
165 }
166 }
167
168 pub fn from_config(config: WiringConfig) -> Self {
169 let units = config.units;
170 let adjacency_matrix = if let Some(matrix) = config.adjacency_matrix {
171 ndarray::Array2::from_shape_vec((units, units), matrix.into_iter().flatten().collect())
172 .expect("Invalid adjacency matrix shape")
173 } else {
174 ndarray::Array2::zeros((units, units))
175 };
176
177 let sensory_adjacency_matrix = config.sensory_adjacency_matrix.map(|matrix| {
178 let input_dim = config
179 .input_dim
180 .expect("Input dimension required when sensory matrix exists");
181 ndarray::Array2::from_shape_vec(
182 (input_dim, units),
183 matrix.into_iter().flatten().collect(),
184 )
185 .expect("Invalid sensory adjacency matrix shape")
186 });
187
188 Self {
189 units,
190 output_dim: config.output_dim.unwrap_or(units),
191 adjacency_matrix,
192 sensory_adjacency_matrix,
193 input_dim: config.input_dim,
194 self_connections: true,
195 erev_init_seed: 1111,
196 }
197 }
198}
199
200impl Wiring for FullyConnected {
201 fn units(&self) -> usize {
202 self.units
203 }
204
205 fn input_dim(&self) -> Option<usize> {
206 self.input_dim
207 }
208
209 fn output_dim(&self) -> Option<usize> {
210 Some(self.output_dim)
211 }
212
213 fn build(&mut self, input_dim: usize) {
214 if let Some(existing) = self.input_dim {
215 if existing != input_dim {
216 panic!(
217 "Conflicting input dimensions: expected {}, got {}",
218 existing, input_dim
219 );
220 }
221 return;
222 }
223
224 self.input_dim = Some(input_dim);
225 let mut sensory_matrix = ndarray::Array2::zeros((input_dim, self.units));
226 let mut rng = StdRng::seed_from_u64(self.erev_init_seed);
227
228 for src in 0..input_dim {
229 for dest in 0..self.units {
230 let polarity: i32 = if rng.gen::<f64>() < 0.33 { -1 } else { 1 };
231 sensory_matrix[[src, dest]] = polarity;
232 }
233 }
234 self.sensory_adjacency_matrix = Some(sensory_matrix);
235 }
236
237 fn adjacency_matrix(&self) -> &ndarray::Array2<i32> {
238 &self.adjacency_matrix
239 }
240
241 fn sensory_adjacency_matrix(&self) -> Option<&ndarray::Array2<i32>> {
242 self.sensory_adjacency_matrix.as_ref()
243 }
244
245 fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
246 if src >= self.units || dest >= self.units {
247 panic!(
248 "Invalid synapse: src={}, dest={}, units={}",
249 src, dest, self.units
250 );
251 }
252 if ![-1, 1].contains(&polarity) {
253 panic!("Polarity must be -1 or 1, got {}", polarity);
254 }
255 self.adjacency_matrix[[src, dest]] = polarity;
256 }
257
258 fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
259 let input_dim = self
260 .input_dim
261 .expect("Must build wiring before adding sensory synapses");
262 if src >= input_dim || dest >= self.units {
263 panic!(
264 "Invalid sensory synapse: src={}, dest={}, input_dim={}, units={}",
265 src, dest, input_dim, self.units
266 );
267 }
268 if ![-1, 1].contains(&polarity) {
269 panic!("Polarity must be -1 or 1, got {}", polarity);
270 }
271 self.sensory_adjacency_matrix.as_mut().unwrap()[[src, dest]] = polarity;
272 }
273
274 fn get_config(&self) -> WiringConfig {
275 WiringConfig {
276 units: self.units,
277 adjacency_matrix: Some(
278 self.adjacency_matrix
279 .outer_iter()
280 .map(|v| v.to_vec())
281 .collect(),
282 ),
283 sensory_adjacency_matrix: self
284 .sensory_adjacency_matrix
285 .as_ref()
286 .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
287 input_dim: self.input_dim,
288 output_dim: Some(self.output_dim),
289 erev_init_seed: Some(self.erev_init_seed),
291 self_connections: Some(self.self_connections),
292 num_inter_neurons: None,
294 num_command_neurons: None,
295 num_motor_neurons: None,
296 sensory_fanout: None,
297 inter_fanout: None,
298 recurrent_command_synapses: None,
299 motor_fanin: None,
300 seed: None,
301 sparsity_level: None,
302 random_seed: None,
303 }
304 }
305}