1use super::base::Wiring;
2use super::WiringConfig;
3use ndarray::Array2;
4use rand::prelude::*;
5
6#[derive(Clone, Debug)]
61pub struct Random {
62 units: usize,
63 output_dim: usize,
64 adjacency_matrix: Array2<i32>,
65 sensory_adjacency_matrix: Option<Array2<i32>>,
66 input_dim: Option<usize>,
67 sparsity_level: f64,
68 random_seed: u64,
69}
70
71impl Random {
72 pub fn new(
73 units: usize,
74 output_dim: Option<usize>,
75 sparsity_level: f64,
76 random_seed: u64,
77 ) -> Self {
78 if sparsity_level < 0.0 || sparsity_level >= 1.0 {
79 panic!(
80 "Sparsity level must be in range [0, 1), got {}",
81 sparsity_level
82 );
83 }
84
85 let output_dim = output_dim.unwrap_or(units);
86 let mut adjacency_matrix = Array2::zeros((units, units));
87 let mut rng = StdRng::seed_from_u64(random_seed);
88
89 let total_possible = units * units;
91 let num_synapses = (total_possible as f64 * (1.0 - sparsity_level)).round() as usize;
92
93 let mut all_synapses: Vec<(usize, usize)> = Vec::with_capacity(total_possible);
95 for src in 0..units {
96 for dest in 0..units {
97 all_synapses.push((src, dest));
98 }
99 }
100
101 let selected: Vec<_> = all_synapses
103 .choose_multiple(&mut rng, num_synapses)
104 .cloned()
105 .collect();
106
107 for (src, dest) in selected {
108 let polarity: i32 = if rng.gen::<f64>() < 0.33 { -1 } else { 1 };
109 adjacency_matrix[[src, dest]] = polarity;
110 }
111
112 Self {
113 units,
114 output_dim,
115 adjacency_matrix,
116 sensory_adjacency_matrix: None,
117 input_dim: None,
118 sparsity_level,
119 random_seed,
120 }
121 }
122
123 pub fn from_config(config: WiringConfig) -> Self {
124 Self::new(
125 config.units,
126 config.output_dim,
127 config.sparsity_level.unwrap_or(0.5),
128 config.random_seed.unwrap_or(1111),
129 )
130 }
131}
132
133impl Wiring for Random {
134 fn units(&self) -> usize {
135 self.units
136 }
137
138 fn input_dim(&self) -> Option<usize> {
139 self.input_dim
140 }
141
142 fn output_dim(&self) -> Option<usize> {
143 Some(self.output_dim)
144 }
145
146 fn build(&mut self, input_dim: usize) {
147 if let Some(existing) = self.input_dim {
148 if existing != input_dim {
149 panic!(
150 "Conflicting input dimensions: expected {}, got {}",
151 existing, input_dim
152 );
153 }
154 return;
155 }
156
157 self.input_dim = Some(input_dim);
158 let mut sensory_matrix = Array2::zeros((input_dim, self.units));
159 let mut rng = StdRng::seed_from_u64(self.random_seed);
160
161 let total_possible = input_dim * self.units;
162 let num_sensory_synapses =
163 (total_possible as f64 * (1.0 - self.sparsity_level)).round() as usize;
164
165 let mut all_sensory_synapses: Vec<(usize, usize)> = Vec::with_capacity(total_possible);
166 for src in 0..input_dim {
167 for dest in 0..self.units {
168 all_sensory_synapses.push((src, dest));
169 }
170 }
171
172 let selected: Vec<_> = all_sensory_synapses
173 .choose_multiple(&mut rng, num_sensory_synapses)
174 .cloned()
175 .collect();
176
177 for (src, dest) in selected {
178 let polarity: i32 = if rng.gen::<f64>() < 0.33 { -1 } else { 1 };
179 sensory_matrix[[src, dest]] = polarity;
180 }
181
182 self.sensory_adjacency_matrix = Some(sensory_matrix);
183 }
184
185 fn adjacency_matrix(&self) -> &Array2<i32> {
186 &self.adjacency_matrix
187 }
188
189 fn sensory_adjacency_matrix(&self) -> Option<&Array2<i32>> {
190 self.sensory_adjacency_matrix.as_ref()
191 }
192
193 fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
194 if src >= self.units || dest >= self.units {
195 panic!(
196 "Invalid synapse: src={}, dest={}, units={}",
197 src, dest, self.units
198 );
199 }
200 if ![-1, 1].contains(&polarity) {
201 panic!("Polarity must be -1 or 1, got {}", polarity);
202 }
203 self.adjacency_matrix[[src, dest]] = polarity;
204 }
205
206 fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
207 let input_dim = self
208 .input_dim
209 .expect("Must build wiring before adding sensory synapses");
210 if src >= input_dim || dest >= self.units {
211 panic!(
212 "Invalid sensory synapse: src={}, dest={}, input_dim={}, units={}",
213 src, dest, input_dim, self.units
214 );
215 }
216 if ![-1, 1].contains(&polarity) {
217 panic!("Polarity must be -1 or 1, got {}", polarity);
218 }
219 self.sensory_adjacency_matrix.as_mut().unwrap()[[src, dest]] = polarity;
220 }
221
222 fn get_config(&self) -> WiringConfig {
223 WiringConfig {
224 units: self.units,
225 adjacency_matrix: Some(
226 self.adjacency_matrix
227 .outer_iter()
228 .map(|v| v.to_vec())
229 .collect(),
230 ),
231 sensory_adjacency_matrix: self
232 .sensory_adjacency_matrix
233 .as_ref()
234 .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
235 input_dim: self.input_dim,
236 output_dim: Some(self.output_dim),
237 sparsity_level: Some(self.sparsity_level),
239 random_seed: Some(self.random_seed),
240 erev_init_seed: None,
242 self_connections: None,
243 num_inter_neurons: None,
244 num_command_neurons: None,
245 num_motor_neurons: None,
246 sensory_fanout: None,
247 inter_fanout: None,
248 recurrent_command_synapses: None,
249 motor_fanin: None,
250 seed: None,
251 }
252 }
253}