1use super::base::Wiring;
2use super::WiringConfig;
3use ndarray::Array2;
4use rand::prelude::*;
5
6#[derive(Clone, Debug)]
8pub struct Random {
9 units: usize,
10 output_dim: usize,
11 adjacency_matrix: Array2<i32>,
12 sensory_adjacency_matrix: Option<Array2<i32>>,
13 input_dim: Option<usize>,
14 sparsity_level: f64,
15 random_seed: u64,
16}
17
18impl Random {
19 pub fn new(
20 units: usize,
21 output_dim: Option<usize>,
22 sparsity_level: f64,
23 random_seed: u64,
24 ) -> Self {
25 if sparsity_level < 0.0 || sparsity_level >= 1.0 {
26 panic!(
27 "Sparsity level must be in range [0, 1), got {}",
28 sparsity_level
29 );
30 }
31
32 let output_dim = output_dim.unwrap_or(units);
33 let mut adjacency_matrix = Array2::zeros((units, units));
34 let mut rng = StdRng::seed_from_u64(random_seed);
35
36 let total_possible = units * units;
38 let num_synapses = (total_possible as f64 * (1.0 - sparsity_level)).round() as usize;
39
40 let mut all_synapses: Vec<(usize, usize)> = Vec::with_capacity(total_possible);
42 for src in 0..units {
43 for dest in 0..units {
44 all_synapses.push((src, dest));
45 }
46 }
47
48 let selected: Vec<_> = all_synapses
50 .choose_multiple(&mut rng, num_synapses)
51 .cloned()
52 .collect();
53
54 for (src, dest) in selected {
55 let polarity: i32 = if rng.gen::<f64>() < 0.33 { -1 } else { 1 };
56 adjacency_matrix[[src, dest]] = polarity;
57 }
58
59 Self {
60 units,
61 output_dim,
62 adjacency_matrix,
63 sensory_adjacency_matrix: None,
64 input_dim: None,
65 sparsity_level,
66 random_seed,
67 }
68 }
69
70 pub fn from_config(config: WiringConfig) -> Self {
71 Self::new(
72 config.units,
73 config.output_dim,
74 config.sparsity_level.unwrap_or(0.5),
75 config.random_seed.unwrap_or(1111),
76 )
77 }
78}
79
80impl Wiring for Random {
81 fn units(&self) -> usize {
82 self.units
83 }
84
85 fn input_dim(&self) -> Option<usize> {
86 self.input_dim
87 }
88
89 fn output_dim(&self) -> Option<usize> {
90 Some(self.output_dim)
91 }
92
93 fn build(&mut self, input_dim: usize) {
94 if let Some(existing) = self.input_dim {
95 if existing != input_dim {
96 panic!(
97 "Conflicting input dimensions: expected {}, got {}",
98 existing, input_dim
99 );
100 }
101 return;
102 }
103
104 self.input_dim = Some(input_dim);
105 let mut sensory_matrix = Array2::zeros((input_dim, self.units));
106 let mut rng = StdRng::seed_from_u64(self.random_seed);
107
108 let total_possible = input_dim * self.units;
109 let num_sensory_synapses =
110 (total_possible as f64 * (1.0 - self.sparsity_level)).round() as usize;
111
112 let mut all_sensory_synapses: Vec<(usize, usize)> = Vec::with_capacity(total_possible);
113 for src in 0..input_dim {
114 for dest in 0..self.units {
115 all_sensory_synapses.push((src, dest));
116 }
117 }
118
119 let selected: Vec<_> = all_sensory_synapses
120 .choose_multiple(&mut rng, num_sensory_synapses)
121 .cloned()
122 .collect();
123
124 for (src, dest) in selected {
125 let polarity: i32 = if rng.gen::<f64>() < 0.33 { -1 } else { 1 };
126 sensory_matrix[[src, dest]] = polarity;
127 }
128
129 self.sensory_adjacency_matrix = Some(sensory_matrix);
130 }
131
132 fn adjacency_matrix(&self) -> &Array2<i32> {
133 &self.adjacency_matrix
134 }
135
136 fn sensory_adjacency_matrix(&self) -> Option<&Array2<i32>> {
137 self.sensory_adjacency_matrix.as_ref()
138 }
139
140 fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
141 if src >= self.units || dest >= self.units {
142 panic!(
143 "Invalid synapse: src={}, dest={}, units={}",
144 src, dest, self.units
145 );
146 }
147 if ![-1, 1].contains(&polarity) {
148 panic!("Polarity must be -1 or 1, got {}", polarity);
149 }
150 self.adjacency_matrix[[src, dest]] = polarity;
151 }
152
153 fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
154 let input_dim = self
155 .input_dim
156 .expect("Must build wiring before adding sensory synapses");
157 if src >= input_dim || dest >= self.units {
158 panic!(
159 "Invalid sensory synapse: src={}, dest={}, input_dim={}, units={}",
160 src, dest, input_dim, self.units
161 );
162 }
163 if ![-1, 1].contains(&polarity) {
164 panic!("Polarity must be -1 or 1, got {}", polarity);
165 }
166 self.sensory_adjacency_matrix.as_mut().unwrap()[[src, dest]] = polarity;
167 }
168
169 fn get_config(&self) -> WiringConfig {
170 WiringConfig {
171 units: self.units,
172 adjacency_matrix: Some(
173 self.adjacency_matrix
174 .outer_iter()
175 .map(|v| v.to_vec())
176 .collect(),
177 ),
178 sensory_adjacency_matrix: self
179 .sensory_adjacency_matrix
180 .as_ref()
181 .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
182 input_dim: self.input_dim,
183 output_dim: Some(self.output_dim),
184 sparsity_level: Some(self.sparsity_level),
186 random_seed: Some(self.random_seed),
187 erev_init_seed: None,
189 self_connections: None,
190 num_inter_neurons: None,
191 num_command_neurons: None,
192 num_motor_neurons: None,
193 sensory_fanout: None,
194 inter_fanout: None,
195 recurrent_command_synapses: None,
196 motor_fanin: None,
197 seed: None,
198 }
199 }
200}