Skip to main content

ncps_rust/wirings/
random.rs

1use super::base::Wiring;
2use super::WiringConfig;
3use ndarray::Array2;
4use rand::prelude::*;
5
6/// Random sparsity wiring structure
7#[derive(Clone, Debug)]
8pub struct Random {
9    units: usize,
10    output_dim: usize,
11    adjacency_matrix: Array2<i32>,
12    sensory_adjacency_matrix: Option<Array2<i32>>,
13    input_dim: Option<usize>,
14    sparsity_level: f64,
15    random_seed: u64,
16}
17
18impl Random {
19    pub fn new(
20        units: usize,
21        output_dim: Option<usize>,
22        sparsity_level: f64,
23        random_seed: u64,
24    ) -> Self {
25        if sparsity_level < 0.0 || sparsity_level >= 1.0 {
26            panic!(
27                "Sparsity level must be in range [0, 1), got {}",
28                sparsity_level
29            );
30        }
31
32        let output_dim = output_dim.unwrap_or(units);
33        let mut adjacency_matrix = Array2::zeros((units, units));
34        let mut rng = StdRng::seed_from_u64(random_seed);
35
36        // Calculate number of synapses
37        let total_possible = units * units;
38        let num_synapses = (total_possible as f64 * (1.0 - sparsity_level)).round() as usize;
39
40        // Create all possible synapse pairs
41        let mut all_synapses: Vec<(usize, usize)> = Vec::with_capacity(total_possible);
42        for src in 0..units {
43            for dest in 0..units {
44                all_synapses.push((src, dest));
45            }
46        }
47
48        // Randomly select synapses
49        let selected: Vec<_> = all_synapses
50            .choose_multiple(&mut rng, num_synapses)
51            .cloned()
52            .collect();
53
54        for (src, dest) in selected {
55            let polarity: i32 = if rng.gen::<f64>() < 0.33 { -1 } else { 1 };
56            adjacency_matrix[[src, dest]] = polarity;
57        }
58
59        Self {
60            units,
61            output_dim,
62            adjacency_matrix,
63            sensory_adjacency_matrix: None,
64            input_dim: None,
65            sparsity_level,
66            random_seed,
67        }
68    }
69
70    pub fn from_config(config: WiringConfig) -> Self {
71        Self::new(
72            config.units,
73            config.output_dim,
74            config.sparsity_level.unwrap_or(0.5),
75            config.random_seed.unwrap_or(1111),
76        )
77    }
78}
79
80impl Wiring for Random {
81    fn units(&self) -> usize {
82        self.units
83    }
84
85    fn input_dim(&self) -> Option<usize> {
86        self.input_dim
87    }
88
89    fn output_dim(&self) -> Option<usize> {
90        Some(self.output_dim)
91    }
92
93    fn build(&mut self, input_dim: usize) {
94        if let Some(existing) = self.input_dim {
95            if existing != input_dim {
96                panic!(
97                    "Conflicting input dimensions: expected {}, got {}",
98                    existing, input_dim
99                );
100            }
101            return;
102        }
103
104        self.input_dim = Some(input_dim);
105        let mut sensory_matrix = Array2::zeros((input_dim, self.units));
106        let mut rng = StdRng::seed_from_u64(self.random_seed);
107
108        let total_possible = input_dim * self.units;
109        let num_sensory_synapses =
110            (total_possible as f64 * (1.0 - self.sparsity_level)).round() as usize;
111
112        let mut all_sensory_synapses: Vec<(usize, usize)> = Vec::with_capacity(total_possible);
113        for src in 0..input_dim {
114            for dest in 0..self.units {
115                all_sensory_synapses.push((src, dest));
116            }
117        }
118
119        let selected: Vec<_> = all_sensory_synapses
120            .choose_multiple(&mut rng, num_sensory_synapses)
121            .cloned()
122            .collect();
123
124        for (src, dest) in selected {
125            let polarity: i32 = if rng.gen::<f64>() < 0.33 { -1 } else { 1 };
126            sensory_matrix[[src, dest]] = polarity;
127        }
128
129        self.sensory_adjacency_matrix = Some(sensory_matrix);
130    }
131
132    fn adjacency_matrix(&self) -> &Array2<i32> {
133        &self.adjacency_matrix
134    }
135
136    fn sensory_adjacency_matrix(&self) -> Option<&Array2<i32>> {
137        self.sensory_adjacency_matrix.as_ref()
138    }
139
140    fn add_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
141        if src >= self.units || dest >= self.units {
142            panic!(
143                "Invalid synapse: src={}, dest={}, units={}",
144                src, dest, self.units
145            );
146        }
147        if ![-1, 1].contains(&polarity) {
148            panic!("Polarity must be -1 or 1, got {}", polarity);
149        }
150        self.adjacency_matrix[[src, dest]] = polarity;
151    }
152
153    fn add_sensory_synapse(&mut self, src: usize, dest: usize, polarity: i32) {
154        let input_dim = self
155            .input_dim
156            .expect("Must build wiring before adding sensory synapses");
157        if src >= input_dim || dest >= self.units {
158            panic!(
159                "Invalid sensory synapse: src={}, dest={}, input_dim={}, units={}",
160                src, dest, input_dim, self.units
161            );
162        }
163        if ![-1, 1].contains(&polarity) {
164            panic!("Polarity must be -1 or 1, got {}", polarity);
165        }
166        self.sensory_adjacency_matrix.as_mut().unwrap()[[src, dest]] = polarity;
167    }
168
169    fn get_config(&self) -> WiringConfig {
170        WiringConfig {
171            units: self.units,
172            adjacency_matrix: Some(
173                self.adjacency_matrix
174                    .outer_iter()
175                    .map(|v| v.to_vec())
176                    .collect(),
177            ),
178            sensory_adjacency_matrix: self
179                .sensory_adjacency_matrix
180                .as_ref()
181                .map(|m| m.outer_iter().map(|v| v.to_vec()).collect()),
182            input_dim: self.input_dim,
183            output_dim: Some(self.output_dim),
184            // Random-specific fields
185            sparsity_level: Some(self.sparsity_level),
186            random_seed: Some(self.random_seed),
187            // Other fields not used by Random
188            erev_init_seed: None,
189            self_connections: None,
190            num_inter_neurons: None,
191            num_command_neurons: None,
192            num_motor_neurons: None,
193            sensory_fanout: None,
194            inter_fanout: None,
195            recurrent_command_synapses: None,
196            motor_fanin: None,
197            seed: None,
198        }
199    }
200}