1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
//! Wired CfC Cell Implementation
//!
//! Multi-layer CfC cell that respects NCP wiring structure.
//! Creates separate CfC cells for each layer of the wiring, following the
//! connectivity patterns defined by the adjacency matrices.
use crate::cells::{CfCCell, CfcMode};
use crate::wirings::Wiring;
use burn::module::Module;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use ndarray::Array2;
/// Wired CfC Cell - Multi-layer CfC respecting NCP wiring structure
///
/// This cell creates separate CfC cells for each layer of the wiring,
/// with appropriate sparsity masks derived from the adjacency matrices.
///
/// The hidden state handled by `forward` is one `[batch, state_size]` tensor
/// made of the per-layer states concatenated along dimension 1 in layer order
/// (layer 0 first). Each layer's output becomes the input of the next layer.
#[derive(Module, Debug)]
pub struct WiredCfCCell<B: Backend> {
    /// The layers of CfC cells, one per wiring layer, stored in layer order
    #[module(child_list)]
    layers: Vec<CfCCell<B>>,
    /// Total number of neurons (state size), taken from the wiring's `units()`
    #[module(skip)]
    state_size: usize,
    /// Motor (output) size: the wiring's `output_dim()`, or `state_size` if unset
    #[module(skip)]
    motor_size: usize,
    /// Sensory (input) size: the wiring's `input_dim()`
    #[module(skip)]
    sensory_size: usize,
    /// Neuron count of each layer, used to split the hidden state in `forward`
    #[module(skip)]
    layer_sizes: Vec<usize>,
}
impl<B: Backend> WiredCfCCell<B> {
/// Create a new WiredCfCCell with a given wiring
///
/// # Arguments
/// * `wiring` - The wiring configuration (must be built)
/// * `device` - The device for tensor operations
/// * `mode` - The CfC operating mode (default, pure, or no_gate)
pub fn new(wiring: &dyn Wiring, device: &B::Device, mode: CfcMode) -> Self {
if !wiring.is_built() {
panic!(
"Wiring error! Unknown number of input features. \
Please build the wiring first by calling wiring.build(input_size)."
);
}
let num_layers = wiring.num_layers();
let input_dim = wiring.input_dim().unwrap();
let state_size = wiring.units();
let motor_size = wiring.output_dim().unwrap_or(state_size);
let mut layers: Vec<CfCCell<B>> = Vec::with_capacity(num_layers);
// For fully_connected-like case, we create one CfC cell per layer
// Each layer's CfC takes as input: previous layer output (or sensory input)
// and its own hidden state (which is the layer's neurons)
// For the wiring-based CfC, we need to create sparsity masks that constrain
// the connections based on the wiring adjacency matrix
for l in 0..num_layers {
let hidden_units = wiring.get_neurons_of_layer(l);
let num_hidden = hidden_units.len();
// The input to this CfC cell is:
// - For layer 0: sensory inputs + layer 0 hidden state
// - For layer N: layer N-1 output + layer N hidden state
// But CfCCell already handles concatenation internally, so we just need
// to tell it the input size (previous layer output size or sensory size)
let prev_layer_size = if l == 0 {
input_dim
} else {
wiring.get_neurons_of_layer(l - 1).len()
};
// Build input sparsity mask based on wiring connections
// The mask should have shape: [prev_layer_size, num_hidden]
// Extended with identity for recurrent connections: [prev_layer_size + num_hidden, num_hidden]
let input_sparsity = if l == 0 {
// First layer: use sensory adjacency matrix
let sensory_matrix = wiring
.sensory_adjacency_matrix()
.expect("Sensory adjacency matrix required for first layer");
// Extract columns for this layer's neurons
let mut mask = Array2::zeros((input_dim, num_hidden));
for (i, &neuron_id) in hidden_units.iter().enumerate() {
let col = sensory_matrix.column(neuron_id);
for (row, &val) in col.iter().enumerate() {
mask[[row, i]] = val.abs() as f32;
}
}
mask
} else {
// Subsequent layers: use adjacency matrix from previous layer
let adj_matrix = wiring.adjacency_matrix();
let prev_layer_neurons = wiring.get_neurons_of_layer(l - 1);
// Create mask: [prev_layer_size x current_layer_size]
let mut mask = Array2::zeros((prev_layer_neurons.len(), num_hidden));
for (i, ¤t_neuron) in hidden_units.iter().enumerate() {
for (j, &prev_neuron) in prev_layer_neurons.iter().enumerate() {
mask[[j, i]] = adj_matrix[[prev_neuron, current_neuron]].abs() as f32;
}
}
mask
};
// Extend mask with identity matrix for recurrent connections
// The extended mask should be: [prev_layer_size + num_hidden, num_hidden]
let mut extended_mask =
Array2::zeros((input_sparsity.nrows() + num_hidden, num_hidden));
// Copy input sparsity to top portion
for i in 0..input_sparsity.nrows() {
for j in 0..num_hidden {
extended_mask[[i, j]] = input_sparsity[[i, j]];
}
}
// Set identity matrix in bottom portion (recurrent connections)
for i in 0..num_hidden {
extended_mask[[input_sparsity.nrows() + i, i]] = 1.0;
}
// Create CfC cell:
// - input_size should be the previous layer's output dimension (or sensory size)
// - hidden_size is this layer's number of neurons
// CfCCell will internally concatenate [input, hx], so it creates
// a Linear layer with input_size=input_size+hidden_size
let cell = CfCCell::new(prev_layer_size, num_hidden, device)
.with_mode(mode)
.with_sparsity_mask(extended_mask, device);
layers.push(cell);
}
let layer_sizes: Vec<usize> = (0..num_layers)
.map(|l| wiring.get_neurons_of_layer(l).len())
.collect();
Self {
layers,
state_size,
motor_size,
sensory_size: input_dim,
layer_sizes,
}
}
/// Create a new WiredCfCCell with default mode
pub fn with_default_mode(wiring: &dyn Wiring, device: &B::Device) -> Self {
Self::new(wiring, device, CfcMode::Default)
}
/// Get the total state size (sum of all layer neurons)
pub fn state_size(&self) -> usize {
self.state_size
}
/// Get the motor (output) size
pub fn motor_size(&self) -> usize {
self.motor_size
}
/// Get the number of layers
pub fn num_layers(&self) -> usize {
self.layer_sizes.len()
}
/// Get the sizes of each layer
pub fn layer_sizes(&self) -> &[usize] {
&self.layer_sizes
}
/// Get the sensory (input) size
pub fn sensory_size(&self) -> usize {
self.sensory_size
}
/// Get the output size (alias for motor_size)
pub fn output_size(&self) -> usize {
self.motor_size()
}
/// Perform a forward pass through the wired CfC cell
///
/// # Arguments
/// * `input` - Input tensor of shape [batch_size, sensory_size]
/// * `hx` - Hidden state tensor of shape [batch_size, state_size]
/// * `ts` - Time step (scalar)
///
/// # Returns
/// * `(output, new_hidden)` - Output is motor neurons, new_hidden is full state
pub fn forward(
&self,
input: Tensor<B, 2>,
hx: Tensor<B, 2>,
ts: f32,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
// Split hx into layer states using narrow
let mut h_states: Vec<Tensor<B, 2>> = Vec::with_capacity(self.num_layers());
let mut start_idx = 0;
for &layer_size in &self.layer_sizes {
// Use narrow to extract slice [batch, start:start+size]
let layer_state = hx.clone().narrow(1, start_idx, layer_size);
h_states.push(layer_state);
start_idx += layer_size;
}
// Forward through each layer
let mut new_h_states: Vec<Tensor<B, 2>> = Vec::with_capacity(self.num_layers());
let mut layer_input = input;
for (i, layer) in self.layers.iter().enumerate() {
let h_state = h_states[i].clone();
let (new_h, _) = layer.forward(layer_input, h_state, ts);
layer_input = new_h.clone();
new_h_states.push(new_h);
}
// Concatenate new states
let new_hx = Tensor::cat(new_h_states, 1);
// For FullyConnected wiring: output_dim might differ from units
// The output neurons are the first motor_size neurons of the state
// For NCP wiring: motor neurons are naturally the first layer
let output = if self.motor_size != self.state_size {
// Need to narrow to get only motor neurons
layer_input.narrow(1, 0, self.motor_size)
} else {
layer_input
};
(output, new_hx)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wirings::{AutoNCP, FullyConnected, NCP};
    use burn::backend::NdArray;
    use burn::tensor::backend::Backend as BurnBackend;

    type TestBackend = NdArray<f32>;
    type TestDevice = <TestBackend as BurnBackend>::Device;

    /// Device shared by every test (the NdArray backend's default device).
    fn get_test_device() -> TestDevice {
        Default::default()
    }

    /// Zero-filled 2-D tensor of the given shape on the test device.
    fn zeros(shape: [usize; 2]) -> Tensor<TestBackend, 2> {
        Tensor::<TestBackend, 2>::zeros(shape, &get_test_device())
    }

    /// Cell over an AutoNCP wiring with 32 neurons (8 motor) built for 16 inputs.
    fn create_wired_cell_with_ncp() -> WiredCfCCell<TestBackend> {
        let mut ncp = AutoNCP::new(32, 8, 0.5, 22222);
        ncp.build(16);
        WiredCfCCell::new(&ncp, &get_test_device(), CfcMode::Default)
    }

    #[test]
    fn test_wired_cfc_creation() {
        let cell = create_wired_cell_with_ncp();
        assert_eq!(
            (cell.state_size(), cell.motor_size(), cell.num_layers()),
            (32, 8, 3)
        );
    }

    #[test]
    fn test_wired_cfc_layer_sizes() {
        let cell = create_wired_cell_with_ncp();
        let sizes = cell.layer_sizes();
        // One entry per NCP layer, and together they cover the whole state.
        assert_eq!(sizes.len(), 3);
        assert_eq!(sizes.iter().sum::<usize>(), cell.state_size());
    }

    #[test]
    fn test_wired_cfc_forward() {
        const BATCH: usize = 4;
        let cell = create_wired_cell_with_ncp();
        let (output, new_hidden) = cell.forward(zeros([BATCH, 16]), zeros([BATCH, 32]), 1.0);
        // The output carries only the motor neurons; the new state keeps every neuron.
        assert_eq!(output.dims(), [BATCH, 8]);
        assert_eq!(new_hidden.dims(), [BATCH, 32]);
    }

    #[test]
    fn test_wired_cfc_state_partitioning() {
        let device = get_test_device();
        let cell = create_wired_cell_with_ncp();
        // Fill layer k's slice of the state with the constant k + 1 so that each
        // layer starts from a distinguishable value.
        let mut slices = Vec::new();
        for (k, &width) in cell.layer_sizes().iter().enumerate() {
            let value = (k + 1) as f32;
            slices.push(Tensor::<TestBackend, 2>::full([2, width], value, &device));
        }
        let hx = Tensor::cat(slices, 1);
        let (output, new_hidden) = cell.forward(zeros([2, 16]), hx, 1.0);
        assert_eq!(new_hidden.dims(), [2, 32]);
        assert_eq!(output.dims(), [2, 8]);
    }

    #[test]
    fn test_wired_cfc_with_different_wirings() {
        // Hand-configured NCP whose three neuron groups hold 10, 8 and 5 neurons.
        let mut manual = NCP::new(10, 8, 5, 6, 6, 4, 6, 22222);
        manual.build(10);
        let cell =
            WiredCfCCell::<TestBackend>::new(&manual, &get_test_device(), CfcMode::Default);
        assert_eq!(cell.state_size(), 10 + 8 + 5);
        assert_eq!(cell.num_layers(), 3);
    }

    #[test]
    fn test_wired_cfc_information_flow() {
        let device = get_test_device();
        let cell = create_wired_cell_with_ncp();
        // Same starting state, different sensory input: the outputs must differ,
        // showing the input propagates through the layers to the output.
        let start_state = Tensor::<TestBackend, 2>::zeros([1, 32], &device);
        let (out_from_zeros, _) = cell.forward(
            Tensor::<TestBackend, 2>::zeros([1, 16], &device),
            start_state.clone(),
            1.0,
        );
        let (out_from_ones, _) = cell.forward(
            Tensor::<TestBackend, 2>::ones([1, 16], &device),
            start_state,
            1.0,
        );
        let total_abs_diff = (out_from_zeros - out_from_ones).abs().sum().into_scalar();
        assert!(
            total_abs_diff > 0.0,
            "Different inputs should produce different outputs"
        );
    }

    #[test]
    fn test_wired_cfc_with_fully_connected() {
        let mut dense = FullyConnected::new(20, Some(5), 1234, true);
        dense.build(10);
        let cell =
            WiredCfCCell::<TestBackend>::new(&dense, &get_test_device(), CfcMode::Default);
        assert_eq!(cell.state_size(), 20);
        assert_eq!(cell.motor_size(), 5);
        // A fully connected wiring forms a single layer.
        assert_eq!(cell.num_layers(), 1);

        let (output, new_hidden) = cell.forward(zeros([2, 10]), zeros([2, 20]), 1.0);
        assert_eq!(output.dims(), [2, 5]);
        assert_eq!(new_hidden.dims(), [2, 20]);
    }
}